wgpu_examples/cooperative_matrix/
mod.rs

1//! Cooperative Matrix Multiplication Example
2//!
3//! This example demonstrates how to use cooperative matrix operations
4//! (also known as tensor cores on NVIDIA GPUs or simdgroup matrix
5//! operations on Apple GPUs) to perform efficient matrix multiplication.
6//!
7//! Cooperative matrices allow a workgroup to collectively load, store,
8//! and perform matrix operations on small tiles of data, enabling
9//! hardware-accelerated matrix math.
10//!
11//! Note: This feature requires hardware support and is currently
12//! experimental. Use `adapter.cooperative_matrix_properties()` to query
13//! supported configurations:
14//! - Metal (Apple): 8x8 f32, 8x8 f16, mixed precision (f16 inputs, f32 accumulator)
15//! - Vulkan (AMD): Typically 16x16 f16
16//! - Vulkan (NVIDIA): Varies by GPU generation
17
18use bytemuck::{Pod, Zeroable};
19use half::f16;
20
/// Number of rows in matrices A and C.
///
/// Together with [`N`] and [`K`], this must be divisible by the
/// cooperative matrix tile size used by the example.
const M: u32 = 64;
/// Number of columns in matrices B and C.
const N: u32 = 64;
/// Shared inner dimension: columns of A and rows of B.
const K: u32 = 64;
25
26#[repr(C)]
27#[derive(Clone, Copy, Pod, Zeroable)]
28struct Dimensions {
29    m: u32,
30    n: u32,
31    k: u32,
32    stride: u32,
33}
34
35async fn run() {
36    // Initialize wgpu
37    let instance = wgpu::Instance::default();
38    let adapter = instance
39        .request_adapter(&wgpu::RequestAdapterOptions {
40            power_preference: wgpu::PowerPreference::HighPerformance,
41            ..Default::default()
42        })
43        .await
44        .expect("Failed to find an appropriate adapter");
45
46    log::info!("Using adapter: {:?}", adapter.get_info());
47
48    // Query supported cooperative matrix configurations
49    let coop_props = adapter.cooperative_matrix_properties();
50    if coop_props.is_empty() {
51        log::error!(
52            "Cooperative matrix is not supported on this adapter.\n\
53            This feature requires:\n\
54            - Metal: Apple7+ (A14/M1) with MSL 2.3+\n\
55            - Vulkan: VK_KHR_cooperative_matrix extension"
56        );
57        return;
58    }
59
60    // Display supported configurations
61    log::info!("Supported cooperative matrix configurations:");
62    for (i, prop) in coop_props.iter().enumerate() {
63        log::info!(
64            "  [{}] {:?}x{:?}x{:?} - AB: {:?}, CR: {:?}{}",
65            i,
66            prop.m_size,
67            prop.n_size,
68            prop.k_size,
69            prop.ab_type,
70            prop.cr_type,
71            if prop.saturating_accumulation {
72                " (saturating)"
73            } else {
74                ""
75            }
76        );
77    }
78
79    // Find a suitable configuration - prefer f32, but accept f16
80    // Try 16x16 first (AMD), then 8x8 (Apple Metal)
81    let selected_config = coop_props
82        .iter()
83        .find(|prop| {
84            prop.m_size == 16
85                && prop.n_size == 16
86                && prop.k_size == 16
87                && prop.ab_type == wgpu::CooperativeScalarType::F16
88                && prop.cr_type == wgpu::CooperativeScalarType::F16
89        })
90        .or_else(|| {
91            coop_props.iter().find(|prop| {
92                prop.m_size == 8
93                    && prop.n_size == 8
94                    && prop.k_size == 8
95                    && prop.ab_type == wgpu::CooperativeScalarType::F32
96                    && prop.cr_type == wgpu::CooperativeScalarType::F32
97            })
98        });
99
100    let config = match selected_config {
101        Some(c) => {
102            log::info!(
103                "Selected configuration: {:?}x{:?}x{:?} AB={:?} CR={:?}",
104                c.m_size,
105                c.n_size,
106                c.k_size,
107                c.ab_type,
108                c.cr_type
109            );
110            c
111        }
112        None => {
113            log::error!(
114                "No suitable cooperative matrix configuration found.\n\
115                This example supports 16x16 f16 (AMD) or 8x8 f32 (Apple Metal).\n\
116                Available configurations are listed above."
117            );
118            return;
119        }
120    };
121
122    let tile_size = config.m_size;
123    let use_f16 = config.ab_type == wgpu::CooperativeScalarType::F16;
124
125    log::info!(
126        "Using {}x{} tiles with {} precision",
127        tile_size,
128        tile_size,
129        if use_f16 { "f16" } else { "f32" }
130    );
131
132    // Check if cooperative matrix is supported
133    let adapter_features = adapter.features();
134    if !adapter_features.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX) {
135        log::error!("EXPERIMENTAL_COOPERATIVE_MATRIX feature not available");
136        return;
137    }
138
139    // Check if f16 is needed and available
140    if use_f16 && !adapter_features.contains(wgpu::Features::SHADER_F16) {
141        log::error!("SHADER_F16 feature not available, but required for f16 cooperative matrices");
142        return;
143    }
144
145    // Build required features
146    let mut required_features = wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX;
147    if use_f16 {
148        required_features |= wgpu::Features::SHADER_F16;
149    }
150
151    // Request device with experimental features enabled
152    let (device, queue) = unsafe {
153        adapter
154            .request_device(&wgpu::DeviceDescriptor {
155                label: Some("Cooperative Matrix Device"),
156                required_features,
157                required_limits: wgpu::Limits::downlevel_defaults(),
158                experimental_features: wgpu::ExperimentalFeatures::enabled(),
159                memory_hints: wgpu::MemoryHints::Performance,
160                trace: wgpu::Trace::Off,
161            })
162            .await
163            .expect("Failed to create device")
164    };
165
166    let results = execute(&device, &queue, config).await;
167
168    log::info!(
169        "Matrix multiplication {M}x{K}x{N} completed using {} precision!",
170        if use_f16 { "f16" } else { "f32" }
171    );
172    log::info!("Max error vs CPU reference: {:.6}", results.max_error);
173
174    if results.max_error < results.tolerance {
175        log::info!(
176            "✓ Results match CPU reference within tolerance ({})",
177            results.tolerance
178        );
179    } else {
180        log::warn!(
181            "✗ Results differ from CPU reference (tolerance: {})",
182            results.tolerance
183        );
184    }
185
186    // Print a small sample of the result
187    log::info!("Sample of result matrix C (top-left 4x4):");
188    for i in 0..4 {
189        let row: Vec<String> = (0..4)
190            .map(|j| format!("{:6.2}", results.matrix[i * N as usize + j]))
191            .collect();
192        log::info!("  [{}]", row.join(", "));
193    }
194}
195
/// Outcome of one GPU matrix multiplication, as produced by `execute`.
///
/// `Debug`, `Clone` and `PartialEq` are derived so results can be
/// printed in assertion messages and compared in tests.
#[derive(Debug, Clone, PartialEq)]
struct ExecuteResults {
    /// Largest absolute difference between a GPU result element and the
    /// corresponding element of the CPU reference.
    max_error: f32,
    /// Error threshold the caller compares `max_error` against.
    tolerance: f32,
    /// GPU result matrix C converted to `f32`, stored row-major.
    matrix: Vec<f32>,
}
201
202async fn execute(
203    device: &wgpu::Device,
204    queue: &wgpu::Queue,
205    config: &wgpu::CooperativeMatrixProperties,
206) -> ExecuteResults {
207    let use_f16 = config.ab_type == wgpu::CooperativeScalarType::F16;
208
209    // Select the appropriate shader based on configuration
210    let shader_source = if use_f16 {
211        include_str!("shader_f16_16x16.wgsl")
212    } else {
213        include_str!("shader.wgsl")
214    };
215
216    // Create the shader module using the standard validated path
217    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
218        label: Some("Cooperative Matrix Shader"),
219        source: wgpu::ShaderSource::Wgsl(shader_source.into()),
220    });
221
222    // Initialize matrices
223    // A is MxK, B is KxN, C is MxN (result)
224    // Use f32 for computation, convert to f16 if needed for GPU.
225    //
226    // The init weights `i * col_stride + j * row_stride` are chosen so
227    // neither A nor B is symmetric in (i, j): if the row/col index
228    // weighting reduced to the same residue class modulo the divisor,
229    // the matrix would become symmetric and the test would no longer
230    // distinguish row-major from column-major loads. The primes here
231    // (`3, 5` for A; `7, 11` for B) ensure asymmetry for any M/N/K.
232    let matrix_a_f32: Vec<f32> = (0..M * K)
233        .map(|idx| {
234            let (i, j) = (idx / K, idx % K);
235            ((i * 3 + j * 5) % 11) as f32 * 0.1
236        })
237        .collect();
238    let matrix_b_f32: Vec<f32> = (0..K * N)
239        .map(|idx| {
240            let (i, j) = (idx / N, idx % N);
241            ((i * 7 + j * 11) % 13) as f32 * 0.1
242        })
243        .collect();
244    let matrix_c_f32: Vec<f32> = vec![0.0; (M * N) as usize];
245
246    // Element size depends on precision
247    let element_size = if use_f16 { 2usize } else { 4usize };
248    let num_elements_a = (M * K) as usize;
249    let num_elements_b = (K * N) as usize;
250    let num_elements_c = (M * N) as usize;
251
252    // Create buffers
253    let buffer_a = device.create_buffer(&wgpu::BufferDescriptor {
254        label: Some("Matrix A"),
255        size: (num_elements_a * element_size) as u64,
256        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
257        mapped_at_creation: false,
258    });
259
260    let buffer_b = device.create_buffer(&wgpu::BufferDescriptor {
261        label: Some("Matrix B"),
262        size: (num_elements_b * element_size) as u64,
263        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
264        mapped_at_creation: false,
265    });
266
267    let buffer_c = device.create_buffer(&wgpu::BufferDescriptor {
268        label: Some("Matrix C"),
269        size: (num_elements_c * element_size) as u64,
270        usage: wgpu::BufferUsages::STORAGE
271            | wgpu::BufferUsages::COPY_DST
272            | wgpu::BufferUsages::COPY_SRC,
273        mapped_at_creation: false,
274    });
275
276    let dimensions = Dimensions {
277        m: M,
278        n: N,
279        k: K,
280        stride: N,
281    };
282    let buffer_dims = device.create_buffer(&wgpu::BufferDescriptor {
283        label: Some("Dimensions"),
284        size: std::mem::size_of::<Dimensions>() as u64,
285        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
286        mapped_at_creation: false,
287    });
288
289    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
290        label: Some("Staging Buffer"),
291        size: (num_elements_c * element_size) as u64,
292        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
293        mapped_at_creation: false,
294    });
295
296    // Upload data (convert to f16 if needed)
297    if use_f16 {
298        let matrix_a_f16: Vec<f16> = matrix_a_f32.iter().map(|&x| f16::from_f32(x)).collect();
299        let matrix_b_f16: Vec<f16> = matrix_b_f32.iter().map(|&x| f16::from_f32(x)).collect();
300        let matrix_c_f16: Vec<f16> = matrix_c_f32.iter().map(|&x| f16::from_f32(x)).collect();
301        queue.write_buffer(&buffer_a, 0, bytemuck::cast_slice(&matrix_a_f16));
302        queue.write_buffer(&buffer_b, 0, bytemuck::cast_slice(&matrix_b_f16));
303        queue.write_buffer(&buffer_c, 0, bytemuck::cast_slice(&matrix_c_f16));
304    } else {
305        queue.write_buffer(&buffer_a, 0, bytemuck::cast_slice(&matrix_a_f32));
306        queue.write_buffer(&buffer_b, 0, bytemuck::cast_slice(&matrix_b_f32));
307        queue.write_buffer(&buffer_c, 0, bytemuck::cast_slice(&matrix_c_f32));
308    }
309    queue.write_buffer(&buffer_dims, 0, bytemuck::bytes_of(&dimensions));
310
311    // Create bind group layout and bind group
312    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
313        label: Some("Cooperative Matrix Bind Group Layout"),
314        entries: &[
315            wgpu::BindGroupLayoutEntry {
316                binding: 0,
317                visibility: wgpu::ShaderStages::COMPUTE,
318                ty: wgpu::BindingType::Buffer {
319                    ty: wgpu::BufferBindingType::Storage { read_only: true },
320                    has_dynamic_offset: false,
321                    min_binding_size: None,
322                },
323                count: None,
324            },
325            wgpu::BindGroupLayoutEntry {
326                binding: 1,
327                visibility: wgpu::ShaderStages::COMPUTE,
328                ty: wgpu::BindingType::Buffer {
329                    ty: wgpu::BufferBindingType::Storage { read_only: true },
330                    has_dynamic_offset: false,
331                    min_binding_size: None,
332                },
333                count: None,
334            },
335            wgpu::BindGroupLayoutEntry {
336                binding: 2,
337                visibility: wgpu::ShaderStages::COMPUTE,
338                ty: wgpu::BindingType::Buffer {
339                    ty: wgpu::BufferBindingType::Storage { read_only: false },
340                    has_dynamic_offset: false,
341                    min_binding_size: None,
342                },
343                count: None,
344            },
345            wgpu::BindGroupLayoutEntry {
346                binding: 3,
347                visibility: wgpu::ShaderStages::COMPUTE,
348                ty: wgpu::BindingType::Buffer {
349                    ty: wgpu::BufferBindingType::Uniform,
350                    has_dynamic_offset: false,
351                    min_binding_size: None,
352                },
353                count: None,
354            },
355        ],
356    });
357
358    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
359        label: Some("Cooperative Matrix Bind Group"),
360        layout: &bind_group_layout,
361        entries: &[
362            wgpu::BindGroupEntry {
363                binding: 0,
364                resource: buffer_a.as_entire_binding(),
365            },
366            wgpu::BindGroupEntry {
367                binding: 1,
368                resource: buffer_b.as_entire_binding(),
369            },
370            wgpu::BindGroupEntry {
371                binding: 2,
372                resource: buffer_c.as_entire_binding(),
373            },
374            wgpu::BindGroupEntry {
375                binding: 3,
376                resource: buffer_dims.as_entire_binding(),
377            },
378        ],
379    });
380
381    // Create compute pipeline
382    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
383        label: Some("Cooperative Matrix Pipeline Layout"),
384        bind_group_layouts: &[Some(&bind_group_layout)],
385        immediate_size: 0,
386    });
387
388    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
389        label: Some("Cooperative Matrix Pipeline"),
390        layout: Some(&pipeline_layout),
391        module: &shader,
392        entry_point: Some("main"),
393        compilation_options: Default::default(),
394        cache: None,
395    });
396
397    // Dispatch compute
398    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
399        label: Some("Cooperative Matrix Encoder"),
400    });
401
402    {
403        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
404            label: Some("Cooperative Matrix Pass"),
405            timestamp_writes: None,
406        });
407        compute_pass.set_pipeline(&pipeline);
408        compute_pass.set_bind_group(0, &bind_group, &[]);
409        // Dispatch one workgroup per tile of the output
410        compute_pass.dispatch_workgroups(M / config.m_size, N / config.m_size, 1);
411    }
412
413    // Copy result to staging buffer
414    encoder.copy_buffer_to_buffer(&buffer_c, 0, &staging_buffer, 0, staging_buffer.size());
415
416    queue.submit(Some(encoder.finish()));
417
418    // Read back results
419    let buffer_slice = staging_buffer.slice(..);
420    let (sender, receiver) = flume::bounded(1);
421    buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
422    device
423        .poll(wgpu::PollType::wait_indefinitely())
424        .expect("Poll failed");
425    receiver
426        .recv_async()
427        .await
428        .expect("Channel receive failed")
429        .expect("Buffer mapping failed");
430
431    let data = buffer_slice.get_mapped_range().unwrap();
432
433    // Convert result back to f32 for comparison
434    let result: Vec<f32> = if use_f16 {
435        let result_f16: Vec<f16> = bytemuck::allocation::pod_collect_to_vec(&data);
436        result_f16.iter().map(|x| x.to_f32()).collect()
437    } else {
438        bytemuck::allocation::pod_collect_to_vec(&data)
439    };
440
441    // Compute reference result on CPU for verification
442    let mut reference = vec![0.0f32; (M * N) as usize];
443    for i in 0..M {
444        for j in 0..N {
445            let mut sum = 0.0f32;
446            for k in 0..K {
447                sum += matrix_a_f32[(i * K + k) as usize] * matrix_b_f32[(k * N + j) as usize];
448            }
449            reference[(i * N + j) as usize] = sum;
450        }
451    }
452
453    // Verify results (use larger tolerance for f16)
454    let tolerance = if use_f16 { 0.1 } else { 0.01 };
455    let mut max_error = 0.0f32;
456    for i in 0..(M * N) as usize {
457        let error = (result[i] - reference[i]).abs();
458        max_error = max_error.max(error);
459    }
460
461    ExecuteResults {
462        max_error,
463        tolerance,
464        matrix: result,
465    }
466}
467
468pub fn main() {
469    #[cfg(not(target_arch = "wasm32"))]
470    {
471        env_logger::builder()
472            .filter_level(log::LevelFilter::Info)
473            .format_timestamp_nanos()
474            .init();
475        pollster::block_on(run());
476    }
477    #[cfg(target_arch = "wasm32")]
478    {
479        std::panic::set_hook(Box::new(console_error_panic_hook::hook));
480        console_log::init_with_level(log::Level::Info).expect("could not initialize logger");
481        crate::utils::add_web_nothing_to_see_msg();
482        wasm_bindgen_futures::spawn_local(run());
483    }
484}
485
486#[cfg(test)]
487pub mod tests;