// wgpu_examples/cooperative_matrix/mod.rs

//! Cooperative Matrix Multiplication Example
//!
//! This example demonstrates how to use cooperative matrix operations
//! (also known as tensor cores on NVIDIA GPUs or simdgroup matrix
//! operations on Apple GPUs) to perform efficient matrix multiplication.
//!
//! Cooperative matrices allow a workgroup to collectively load, store,
//! and perform matrix operations on small tiles of data, enabling
//! hardware-accelerated matrix math.
//!
//! Note: This feature requires hardware support and is currently
//! experimental. Use `adapter.cooperative_matrix_properties()` to query
//! supported configurations:
//! - Metal (Apple): 8x8 f32, 8x8 f16, mixed precision (f16 inputs, f32 accumulator)
//! - Vulkan (AMD): Typically 16x16 f16
//! - Vulkan (NVIDIA): Varies by GPU generation

use bytemuck::{Pod, Zeroable};
use half::f16;
/// Matrix dimensions for our example (must be divisible by tile size).
/// 64 is evenly divisible by both supported tile sizes (8 and 16).
const M: u32 = 64; // Rows of A and C
const N: u32 = 64; // Cols of B and C
const K: u32 = 64; // Cols of A, Rows of B
/// Uniform data describing the matrix dimensions, uploaded to the shader.
///
/// `#[repr(C)]` fixes the field order/layout so it matches the WGSL uniform
/// struct, and the `Pod`/`Zeroable` derives let `bytemuck` cast it to raw
/// bytes for `queue.write_buffer`. Do not reorder fields.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct Dimensions {
    m: u32,      // rows of A and C
    n: u32,      // cols of B and C
    k: u32,      // cols of A / rows of B
    stride: u32, // row stride (in elements) used by the shader; set to N below
}
/// Drives the whole example: initializes wgpu, queries and selects a
/// supported cooperative-matrix configuration, requests a device with the
/// required (experimental) features, runs the GPU matrix multiply via
/// [`execute`], and logs the result versus a CPU reference.
///
/// Exits early (with an error log, no panic) when the adapter lacks
/// cooperative-matrix support or no known tile configuration is available.
async fn run() {
    // Initialize wgpu
    let instance = wgpu::Instance::default();
    let adapter = instance
        .request_adapter(&wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            ..Default::default()
        })
        .await
        .expect("Failed to find an appropriate adapter");

    log::info!("Using adapter: {:?}", adapter.get_info());

    // Query supported cooperative matrix configurations. An empty list means
    // the backend/hardware cannot do cooperative matrices at all.
    let coop_props = adapter.cooperative_matrix_properties();
    if coop_props.is_empty() {
        log::error!(
            "Cooperative matrix is not supported on this adapter.\n\
            This feature requires:\n\
            - Metal: Apple7+ (A14/M1) with MSL 2.3+\n\
            - Vulkan: VK_KHR_cooperative_matrix extension"
        );
        return;
    }

    // Display supported configurations
    log::info!("Supported cooperative matrix configurations:");
    for (i, prop) in coop_props.iter().enumerate() {
        log::info!(
            "  [{}] {:?}x{:?}x{:?} - AB: {:?}, CR: {:?}{}",
            i,
            prop.m_size,
            prop.n_size,
            prop.k_size,
            prop.ab_type,
            prop.cr_type,
            if prop.saturating_accumulation {
                " (saturating)"
            } else {
                ""
            }
        );
    }

    // Find a suitable configuration - prefer f32, but accept f16
    // Try 16x16 first (AMD), then 8x8 (Apple Metal). The choice of shader in
    // `execute` depends on which of these two matches.
    let selected_config = coop_props
        .iter()
        .find(|prop| {
            prop.m_size == 16
                && prop.n_size == 16
                && prop.k_size == 16
                && prop.ab_type == wgpu::CooperativeScalarType::F16
                && prop.cr_type == wgpu::CooperativeScalarType::F16
        })
        .or_else(|| {
            coop_props.iter().find(|prop| {
                prop.m_size == 8
                    && prop.n_size == 8
                    && prop.k_size == 8
                    && prop.ab_type == wgpu::CooperativeScalarType::F32
                    && prop.cr_type == wgpu::CooperativeScalarType::F32
            })
        });

    let config = match selected_config {
        Some(c) => {
            log::info!(
                "Selected configuration: {:?}x{:?}x{:?} AB={:?} CR={:?}",
                c.m_size,
                c.n_size,
                c.k_size,
                c.ab_type,
                c.cr_type
            );
            c
        }
        None => {
            log::error!(
                "No suitable cooperative matrix configuration found.\n\
                This example supports 16x16 f16 (AMD) or 8x8 f32 (Apple Metal).\n\
                Available configurations are listed above."
            );
            return;
        }
    };

    // Both accepted configurations are square, so m_size alone names the tile.
    let tile_size = config.m_size;
    let use_f16 = config.ab_type == wgpu::CooperativeScalarType::F16;

    log::info!(
        "Using {}x{} tiles with {} precision",
        tile_size,
        tile_size,
        if use_f16 { "f16" } else { "f32" }
    );

    // Check if cooperative matrix is supported
    let adapter_features = adapter.features();
    if !adapter_features.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX) {
        log::error!("EXPERIMENTAL_COOPERATIVE_MATRIX feature not available");
        return;
    }

    // Check if f16 is needed and available
    if use_f16 && !adapter_features.contains(wgpu::Features::SHADER_F16) {
        log::error!("SHADER_F16 feature not available, but required for f16 cooperative matrices");
        return;
    }

    // Build required features
    let mut required_features = wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX;
    if use_f16 {
        required_features |= wgpu::Features::SHADER_F16;
    }

    // Request device with experimental features enabled.
    // SAFETY: presumably `ExperimentalFeatures::enabled()` is the unsafe
    // contract here (opting into unvalidated experimental API surface) —
    // confirm against the wgpu version in use.
    let (device, queue) = unsafe {
        adapter
            .request_device(&wgpu::DeviceDescriptor {
                label: Some("Cooperative Matrix Device"),
                required_features,
                required_limits: wgpu::Limits::downlevel_defaults(),
                experimental_features: wgpu::ExperimentalFeatures::enabled(),
                memory_hints: wgpu::MemoryHints::Performance,
                trace: wgpu::Trace::Off,
            })
            .await
            .expect("Failed to create device")
    };

    let results = execute(&device, &queue, config).await;

    log::info!(
        "Matrix multiplication {M}x{K}x{N} completed using {} precision!",
        if use_f16 { "f16" } else { "f32" }
    );
    log::info!("Max error vs CPU reference: {:.6}", results.max_error);

    if results.max_error < results.tolerance {
        log::info!(
            "✓ Results match CPU reference within tolerance ({})",
            results.tolerance
        );
    } else {
        log::warn!(
            "✗ Results differ from CPU reference (tolerance: {})",
            results.tolerance
        );
    }

    // Print a small sample of the result
    log::info!("Sample of result matrix C (top-left 4x4):");
    for i in 0..4 {
        let row: Vec<String> = (0..4)
            .map(|j| format!("{:6.2}", results.matrix[i * N as usize + j]))
            .collect();
        log::info!("  [{}]", row.join(", "));
    }
}
/// Outcome of the GPU computation, returned by `execute` for reporting.
struct ExecuteResults {
    // Largest absolute difference between GPU result and the CPU reference.
    max_error: f32,
    // Acceptable error bound; wider when computing in f16.
    tolerance: f32,
    // The full result matrix C (M x N, row-major), converted to f32.
    matrix: Vec<f32>,
}
202async fn execute(
203    device: &wgpu::Device,
204    queue: &wgpu::Queue,
205    config: &wgpu::CooperativeMatrixProperties,
206) -> ExecuteResults {
207    let use_f16 = config.ab_type == wgpu::CooperativeScalarType::F16;
208
209    // Select the appropriate shader based on configuration
210    let shader_source = if use_f16 {
211        include_str!("shader_f16_16x16.wgsl")
212    } else {
213        include_str!("shader.wgsl")
214    };
215
216    // Create the shader module using the standard validated path
217    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
218        label: Some("Cooperative Matrix Shader"),
219        source: wgpu::ShaderSource::Wgsl(shader_source.into()),
220    });
221
222    // Initialize matrices
223    // A is MxK, B is KxN, C is MxN (result)
224    // Use f32 for computation, convert to f16 if needed for GPU
225    let matrix_a_f32: Vec<f32> = (0..M * K).map(|i| (i % 7) as f32 * 0.1).collect();
226    let matrix_b_f32: Vec<f32> = (0..K * N).map(|i| (i % 11) as f32 * 0.1).collect();
227    let matrix_c_f32: Vec<f32> = vec![0.0; (M * N) as usize];
228
229    // Element size depends on precision
230    let element_size = if use_f16 { 2usize } else { 4usize };
231    let num_elements_a = (M * K) as usize;
232    let num_elements_b = (K * N) as usize;
233    let num_elements_c = (M * N) as usize;
234
235    // Create buffers
236    let buffer_a = device.create_buffer(&wgpu::BufferDescriptor {
237        label: Some("Matrix A"),
238        size: (num_elements_a * element_size) as u64,
239        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
240        mapped_at_creation: false,
241    });
242
243    let buffer_b = device.create_buffer(&wgpu::BufferDescriptor {
244        label: Some("Matrix B"),
245        size: (num_elements_b * element_size) as u64,
246        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
247        mapped_at_creation: false,
248    });
249
250    let buffer_c = device.create_buffer(&wgpu::BufferDescriptor {
251        label: Some("Matrix C"),
252        size: (num_elements_c * element_size) as u64,
253        usage: wgpu::BufferUsages::STORAGE
254            | wgpu::BufferUsages::COPY_DST
255            | wgpu::BufferUsages::COPY_SRC,
256        mapped_at_creation: false,
257    });
258
259    let dimensions = Dimensions {
260        m: M,
261        n: N,
262        k: K,
263        stride: N,
264    };
265    let buffer_dims = device.create_buffer(&wgpu::BufferDescriptor {
266        label: Some("Dimensions"),
267        size: std::mem::size_of::<Dimensions>() as u64,
268        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
269        mapped_at_creation: false,
270    });
271
272    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
273        label: Some("Staging Buffer"),
274        size: (num_elements_c * element_size) as u64,
275        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
276        mapped_at_creation: false,
277    });
278
279    // Upload data (convert to f16 if needed)
280    if use_f16 {
281        let matrix_a_f16: Vec<f16> = matrix_a_f32.iter().map(|&x| f16::from_f32(x)).collect();
282        let matrix_b_f16: Vec<f16> = matrix_b_f32.iter().map(|&x| f16::from_f32(x)).collect();
283        let matrix_c_f16: Vec<f16> = matrix_c_f32.iter().map(|&x| f16::from_f32(x)).collect();
284        queue.write_buffer(&buffer_a, 0, bytemuck::cast_slice(&matrix_a_f16));
285        queue.write_buffer(&buffer_b, 0, bytemuck::cast_slice(&matrix_b_f16));
286        queue.write_buffer(&buffer_c, 0, bytemuck::cast_slice(&matrix_c_f16));
287    } else {
288        queue.write_buffer(&buffer_a, 0, bytemuck::cast_slice(&matrix_a_f32));
289        queue.write_buffer(&buffer_b, 0, bytemuck::cast_slice(&matrix_b_f32));
290        queue.write_buffer(&buffer_c, 0, bytemuck::cast_slice(&matrix_c_f32));
291    }
292    queue.write_buffer(&buffer_dims, 0, bytemuck::bytes_of(&dimensions));
293
294    // Create bind group layout and bind group
295    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
296        label: Some("Cooperative Matrix Bind Group Layout"),
297        entries: &[
298            wgpu::BindGroupLayoutEntry {
299                binding: 0,
300                visibility: wgpu::ShaderStages::COMPUTE,
301                ty: wgpu::BindingType::Buffer {
302                    ty: wgpu::BufferBindingType::Storage { read_only: true },
303                    has_dynamic_offset: false,
304                    min_binding_size: None,
305                },
306                count: None,
307            },
308            wgpu::BindGroupLayoutEntry {
309                binding: 1,
310                visibility: wgpu::ShaderStages::COMPUTE,
311                ty: wgpu::BindingType::Buffer {
312                    ty: wgpu::BufferBindingType::Storage { read_only: true },
313                    has_dynamic_offset: false,
314                    min_binding_size: None,
315                },
316                count: None,
317            },
318            wgpu::BindGroupLayoutEntry {
319                binding: 2,
320                visibility: wgpu::ShaderStages::COMPUTE,
321                ty: wgpu::BindingType::Buffer {
322                    ty: wgpu::BufferBindingType::Storage { read_only: false },
323                    has_dynamic_offset: false,
324                    min_binding_size: None,
325                },
326                count: None,
327            },
328            wgpu::BindGroupLayoutEntry {
329                binding: 3,
330                visibility: wgpu::ShaderStages::COMPUTE,
331                ty: wgpu::BindingType::Buffer {
332                    ty: wgpu::BufferBindingType::Uniform,
333                    has_dynamic_offset: false,
334                    min_binding_size: None,
335                },
336                count: None,
337            },
338        ],
339    });
340
341    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
342        label: Some("Cooperative Matrix Bind Group"),
343        layout: &bind_group_layout,
344        entries: &[
345            wgpu::BindGroupEntry {
346                binding: 0,
347                resource: buffer_a.as_entire_binding(),
348            },
349            wgpu::BindGroupEntry {
350                binding: 1,
351                resource: buffer_b.as_entire_binding(),
352            },
353            wgpu::BindGroupEntry {
354                binding: 2,
355                resource: buffer_c.as_entire_binding(),
356            },
357            wgpu::BindGroupEntry {
358                binding: 3,
359                resource: buffer_dims.as_entire_binding(),
360            },
361        ],
362    });
363
364    // Create compute pipeline
365    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
366        label: Some("Cooperative Matrix Pipeline Layout"),
367        bind_group_layouts: &[Some(&bind_group_layout)],
368        immediate_size: 0,
369    });
370
371    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
372        label: Some("Cooperative Matrix Pipeline"),
373        layout: Some(&pipeline_layout),
374        module: &shader,
375        entry_point: Some("main"),
376        compilation_options: Default::default(),
377        cache: None,
378    });
379
380    // Dispatch compute
381    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
382        label: Some("Cooperative Matrix Encoder"),
383    });
384
385    {
386        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
387            label: Some("Cooperative Matrix Pass"),
388            timestamp_writes: None,
389        });
390        compute_pass.set_pipeline(&pipeline);
391        compute_pass.set_bind_group(0, &bind_group, &[]);
392        // Dispatch one workgroup per tile of the output
393        compute_pass.dispatch_workgroups(M / config.m_size, N / config.m_size, 1);
394    }
395
396    // Copy result to staging buffer
397    encoder.copy_buffer_to_buffer(&buffer_c, 0, &staging_buffer, 0, staging_buffer.size());
398
399    queue.submit(Some(encoder.finish()));
400
401    // Read back results
402    let buffer_slice = staging_buffer.slice(..);
403    let (sender, receiver) = flume::bounded(1);
404    buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
405    device
406        .poll(wgpu::PollType::wait_indefinitely())
407        .expect("Poll failed");
408    receiver
409        .recv_async()
410        .await
411        .expect("Channel receive failed")
412        .expect("Buffer mapping failed");
413
414    let data = buffer_slice.get_mapped_range();
415
416    // Convert result back to f32 for comparison
417    let result: Vec<f32> = if use_f16 {
418        let result_f16: &[f16] = bytemuck::cast_slice(&data);
419        result_f16.iter().map(|x| x.to_f32()).collect()
420    } else {
421        bytemuck::cast_slice::<_, f32>(&data).to_vec()
422    };
423
424    // Compute reference result on CPU for verification
425    let mut reference = vec![0.0f32; (M * N) as usize];
426    for i in 0..M {
427        for j in 0..N {
428            let mut sum = 0.0f32;
429            for k in 0..K {
430                sum += matrix_a_f32[(i * K + k) as usize] * matrix_b_f32[(k * N + j) as usize];
431            }
432            reference[(i * N + j) as usize] = sum;
433        }
434    }
435
436    // Verify results (use larger tolerance for f16)
437    let tolerance = if use_f16 { 0.1 } else { 0.01 };
438    let mut max_error = 0.0f32;
439    for i in 0..(M * N) as usize {
440        let error = (result[i] - reference[i]).abs();
441        max_error = max_error.max(error);
442    }
443
444    ExecuteResults {
445        max_error,
446        tolerance,
447        matrix: result,
448    }
449}
/// Example entry point: sets up per-platform logging and drives [`run`] to
/// completion — blocking on native targets, spawned on the wasm event loop.
pub fn main() {
    #[cfg(not(target_arch = "wasm32"))]
    {
        env_logger::builder()
            .filter_level(log::LevelFilter::Info)
            .format_timestamp_nanos()
            .init();
        // Block the main thread on the async example body.
        pollster::block_on(run());
    }
    #[cfg(target_arch = "wasm32")]
    {
        // Route panics to the browser console instead of failing silently.
        std::panic::set_hook(Box::new(console_error_panic_hook::hook));
        console_log::init_with_level(log::Level::Info).expect("could not initialize logger");
        crate::utils::add_web_nothing_to_see_msg();
        // No blocking on wasm; detach the future onto the browser event loop.
        wasm_bindgen_futures::spawn_local(run());
    }
}
// Example test harness, compiled only under `cargo test`.
#[cfg(test)]
pub mod tests;