1use bytemuck::{Pod, Zeroable};
19use half::f16;
20
21const M: u32 = 64; const N: u32 = 64; const K: u32 = 64; #[repr(C)]
27#[derive(Clone, Copy, Pod, Zeroable)]
28struct Dimensions {
29 m: u32,
30 n: u32,
31 k: u32,
32 stride: u32,
33}
34
35async fn run() {
36 let instance = wgpu::Instance::default();
38 let adapter = instance
39 .request_adapter(&wgpu::RequestAdapterOptions {
40 power_preference: wgpu::PowerPreference::HighPerformance,
41 ..Default::default()
42 })
43 .await
44 .expect("Failed to find an appropriate adapter");
45
46 log::info!("Using adapter: {:?}", adapter.get_info());
47
48 let coop_props = adapter.cooperative_matrix_properties();
50 if coop_props.is_empty() {
51 log::error!(
52 "Cooperative matrix is not supported on this adapter.\n\
53 This feature requires:\n\
54 - Metal: Apple7+ (A14/M1) with MSL 2.3+\n\
55 - Vulkan: VK_KHR_cooperative_matrix extension"
56 );
57 return;
58 }
59
60 log::info!("Supported cooperative matrix configurations:");
62 for (i, prop) in coop_props.iter().enumerate() {
63 log::info!(
64 " [{}] {:?}x{:?}x{:?} - AB: {:?}, CR: {:?}{}",
65 i,
66 prop.m_size,
67 prop.n_size,
68 prop.k_size,
69 prop.ab_type,
70 prop.cr_type,
71 if prop.saturating_accumulation {
72 " (saturating)"
73 } else {
74 ""
75 }
76 );
77 }
78
79 let selected_config = coop_props
82 .iter()
83 .find(|prop| {
84 prop.m_size == 16
85 && prop.n_size == 16
86 && prop.k_size == 16
87 && prop.ab_type == wgpu::CooperativeScalarType::F16
88 && prop.cr_type == wgpu::CooperativeScalarType::F16
89 })
90 .or_else(|| {
91 coop_props.iter().find(|prop| {
92 prop.m_size == 8
93 && prop.n_size == 8
94 && prop.k_size == 8
95 && prop.ab_type == wgpu::CooperativeScalarType::F32
96 && prop.cr_type == wgpu::CooperativeScalarType::F32
97 })
98 });
99
100 let config = match selected_config {
101 Some(c) => {
102 log::info!(
103 "Selected configuration: {:?}x{:?}x{:?} AB={:?} CR={:?}",
104 c.m_size,
105 c.n_size,
106 c.k_size,
107 c.ab_type,
108 c.cr_type
109 );
110 c
111 }
112 None => {
113 log::error!(
114 "No suitable cooperative matrix configuration found.\n\
115 This example supports 16x16 f16 (AMD) or 8x8 f32 (Apple Metal).\n\
116 Available configurations are listed above."
117 );
118 return;
119 }
120 };
121
122 let tile_size = config.m_size;
123 let use_f16 = config.ab_type == wgpu::CooperativeScalarType::F16;
124
125 log::info!(
126 "Using {}x{} tiles with {} precision",
127 tile_size,
128 tile_size,
129 if use_f16 { "f16" } else { "f32" }
130 );
131
132 let adapter_features = adapter.features();
134 if !adapter_features.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX) {
135 log::error!("EXPERIMENTAL_COOPERATIVE_MATRIX feature not available");
136 return;
137 }
138
139 if use_f16 && !adapter_features.contains(wgpu::Features::SHADER_F16) {
141 log::error!("SHADER_F16 feature not available, but required for f16 cooperative matrices");
142 return;
143 }
144
145 let mut required_features = wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX;
147 if use_f16 {
148 required_features |= wgpu::Features::SHADER_F16;
149 }
150
151 let (device, queue) = unsafe {
153 adapter
154 .request_device(&wgpu::DeviceDescriptor {
155 label: Some("Cooperative Matrix Device"),
156 required_features,
157 required_limits: wgpu::Limits::downlevel_defaults(),
158 experimental_features: wgpu::ExperimentalFeatures::enabled(),
159 memory_hints: wgpu::MemoryHints::Performance,
160 trace: wgpu::Trace::Off,
161 })
162 .await
163 .expect("Failed to create device")
164 };
165
166 let results = execute(&device, &queue, config).await;
167
168 log::info!(
169 "Matrix multiplication {M}x{K}x{N} completed using {} precision!",
170 if use_f16 { "f16" } else { "f32" }
171 );
172 log::info!("Max error vs CPU reference: {:.6}", results.max_error);
173
174 if results.max_error < results.tolerance {
175 log::info!(
176 "✓ Results match CPU reference within tolerance ({})",
177 results.tolerance
178 );
179 } else {
180 log::warn!(
181 "✗ Results differ from CPU reference (tolerance: {})",
182 results.tolerance
183 );
184 }
185
186 log::info!("Sample of result matrix C (top-left 4x4):");
188 for i in 0..4 {
189 let row: Vec<String> = (0..4)
190 .map(|j| format!("{:6.2}", results.matrix[i * N as usize + j]))
191 .collect();
192 log::info!(" [{}]", row.join(", "));
193 }
194}
195
/// Outcome of one GPU matrix multiplication, as returned by `execute`.
struct ExecuteResults {
    /// Result matrix C read back from the GPU, converted to `f32`.
    matrix: Vec<f32>,
    /// Largest absolute difference between the GPU result and the CPU
    /// reference.
    max_error: f32,
    /// Error threshold below which the GPU result counts as matching.
    tolerance: f32,
}
201
202async fn execute(
203 device: &wgpu::Device,
204 queue: &wgpu::Queue,
205 config: &wgpu::CooperativeMatrixProperties,
206) -> ExecuteResults {
207 let use_f16 = config.ab_type == wgpu::CooperativeScalarType::F16;
208
209 let shader_source = if use_f16 {
211 include_str!("shader_f16_16x16.wgsl")
212 } else {
213 include_str!("shader.wgsl")
214 };
215
216 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
218 label: Some("Cooperative Matrix Shader"),
219 source: wgpu::ShaderSource::Wgsl(shader_source.into()),
220 });
221
222 let matrix_a_f32: Vec<f32> = (0..M * K)
233 .map(|idx| {
234 let (i, j) = (idx / K, idx % K);
235 ((i * 3 + j * 5) % 11) as f32 * 0.1
236 })
237 .collect();
238 let matrix_b_f32: Vec<f32> = (0..K * N)
239 .map(|idx| {
240 let (i, j) = (idx / N, idx % N);
241 ((i * 7 + j * 11) % 13) as f32 * 0.1
242 })
243 .collect();
244 let matrix_c_f32: Vec<f32> = vec![0.0; (M * N) as usize];
245
246 let element_size = if use_f16 { 2usize } else { 4usize };
248 let num_elements_a = (M * K) as usize;
249 let num_elements_b = (K * N) as usize;
250 let num_elements_c = (M * N) as usize;
251
252 let buffer_a = device.create_buffer(&wgpu::BufferDescriptor {
254 label: Some("Matrix A"),
255 size: (num_elements_a * element_size) as u64,
256 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
257 mapped_at_creation: false,
258 });
259
260 let buffer_b = device.create_buffer(&wgpu::BufferDescriptor {
261 label: Some("Matrix B"),
262 size: (num_elements_b * element_size) as u64,
263 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
264 mapped_at_creation: false,
265 });
266
267 let buffer_c = device.create_buffer(&wgpu::BufferDescriptor {
268 label: Some("Matrix C"),
269 size: (num_elements_c * element_size) as u64,
270 usage: wgpu::BufferUsages::STORAGE
271 | wgpu::BufferUsages::COPY_DST
272 | wgpu::BufferUsages::COPY_SRC,
273 mapped_at_creation: false,
274 });
275
276 let dimensions = Dimensions {
277 m: M,
278 n: N,
279 k: K,
280 stride: N,
281 };
282 let buffer_dims = device.create_buffer(&wgpu::BufferDescriptor {
283 label: Some("Dimensions"),
284 size: std::mem::size_of::<Dimensions>() as u64,
285 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
286 mapped_at_creation: false,
287 });
288
289 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
290 label: Some("Staging Buffer"),
291 size: (num_elements_c * element_size) as u64,
292 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
293 mapped_at_creation: false,
294 });
295
296 if use_f16 {
298 let matrix_a_f16: Vec<f16> = matrix_a_f32.iter().map(|&x| f16::from_f32(x)).collect();
299 let matrix_b_f16: Vec<f16> = matrix_b_f32.iter().map(|&x| f16::from_f32(x)).collect();
300 let matrix_c_f16: Vec<f16> = matrix_c_f32.iter().map(|&x| f16::from_f32(x)).collect();
301 queue.write_buffer(&buffer_a, 0, bytemuck::cast_slice(&matrix_a_f16));
302 queue.write_buffer(&buffer_b, 0, bytemuck::cast_slice(&matrix_b_f16));
303 queue.write_buffer(&buffer_c, 0, bytemuck::cast_slice(&matrix_c_f16));
304 } else {
305 queue.write_buffer(&buffer_a, 0, bytemuck::cast_slice(&matrix_a_f32));
306 queue.write_buffer(&buffer_b, 0, bytemuck::cast_slice(&matrix_b_f32));
307 queue.write_buffer(&buffer_c, 0, bytemuck::cast_slice(&matrix_c_f32));
308 }
309 queue.write_buffer(&buffer_dims, 0, bytemuck::bytes_of(&dimensions));
310
311 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
313 label: Some("Cooperative Matrix Bind Group Layout"),
314 entries: &[
315 wgpu::BindGroupLayoutEntry {
316 binding: 0,
317 visibility: wgpu::ShaderStages::COMPUTE,
318 ty: wgpu::BindingType::Buffer {
319 ty: wgpu::BufferBindingType::Storage { read_only: true },
320 has_dynamic_offset: false,
321 min_binding_size: None,
322 },
323 count: None,
324 },
325 wgpu::BindGroupLayoutEntry {
326 binding: 1,
327 visibility: wgpu::ShaderStages::COMPUTE,
328 ty: wgpu::BindingType::Buffer {
329 ty: wgpu::BufferBindingType::Storage { read_only: true },
330 has_dynamic_offset: false,
331 min_binding_size: None,
332 },
333 count: None,
334 },
335 wgpu::BindGroupLayoutEntry {
336 binding: 2,
337 visibility: wgpu::ShaderStages::COMPUTE,
338 ty: wgpu::BindingType::Buffer {
339 ty: wgpu::BufferBindingType::Storage { read_only: false },
340 has_dynamic_offset: false,
341 min_binding_size: None,
342 },
343 count: None,
344 },
345 wgpu::BindGroupLayoutEntry {
346 binding: 3,
347 visibility: wgpu::ShaderStages::COMPUTE,
348 ty: wgpu::BindingType::Buffer {
349 ty: wgpu::BufferBindingType::Uniform,
350 has_dynamic_offset: false,
351 min_binding_size: None,
352 },
353 count: None,
354 },
355 ],
356 });
357
358 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
359 label: Some("Cooperative Matrix Bind Group"),
360 layout: &bind_group_layout,
361 entries: &[
362 wgpu::BindGroupEntry {
363 binding: 0,
364 resource: buffer_a.as_entire_binding(),
365 },
366 wgpu::BindGroupEntry {
367 binding: 1,
368 resource: buffer_b.as_entire_binding(),
369 },
370 wgpu::BindGroupEntry {
371 binding: 2,
372 resource: buffer_c.as_entire_binding(),
373 },
374 wgpu::BindGroupEntry {
375 binding: 3,
376 resource: buffer_dims.as_entire_binding(),
377 },
378 ],
379 });
380
381 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
383 label: Some("Cooperative Matrix Pipeline Layout"),
384 bind_group_layouts: &[Some(&bind_group_layout)],
385 immediate_size: 0,
386 });
387
388 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
389 label: Some("Cooperative Matrix Pipeline"),
390 layout: Some(&pipeline_layout),
391 module: &shader,
392 entry_point: Some("main"),
393 compilation_options: Default::default(),
394 cache: None,
395 });
396
397 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
399 label: Some("Cooperative Matrix Encoder"),
400 });
401
402 {
403 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
404 label: Some("Cooperative Matrix Pass"),
405 timestamp_writes: None,
406 });
407 compute_pass.set_pipeline(&pipeline);
408 compute_pass.set_bind_group(0, &bind_group, &[]);
409 compute_pass.dispatch_workgroups(M / config.m_size, N / config.m_size, 1);
411 }
412
413 encoder.copy_buffer_to_buffer(&buffer_c, 0, &staging_buffer, 0, staging_buffer.size());
415
416 queue.submit(Some(encoder.finish()));
417
418 let buffer_slice = staging_buffer.slice(..);
420 let (sender, receiver) = flume::bounded(1);
421 buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
422 device
423 .poll(wgpu::PollType::wait_indefinitely())
424 .expect("Poll failed");
425 receiver
426 .recv_async()
427 .await
428 .expect("Channel receive failed")
429 .expect("Buffer mapping failed");
430
431 let data = buffer_slice.get_mapped_range().unwrap();
432
433 let result: Vec<f32> = if use_f16 {
435 let result_f16: Vec<f16> = bytemuck::allocation::pod_collect_to_vec(&data);
436 result_f16.iter().map(|x| x.to_f32()).collect()
437 } else {
438 bytemuck::allocation::pod_collect_to_vec(&data)
439 };
440
441 let mut reference = vec![0.0f32; (M * N) as usize];
443 for i in 0..M {
444 for j in 0..N {
445 let mut sum = 0.0f32;
446 for k in 0..K {
447 sum += matrix_a_f32[(i * K + k) as usize] * matrix_b_f32[(k * N + j) as usize];
448 }
449 reference[(i * N + j) as usize] = sum;
450 }
451 }
452
453 let tolerance = if use_f16 { 0.1 } else { 0.01 };
455 let mut max_error = 0.0f32;
456 for i in 0..(M * N) as usize {
457 let error = (result[i] - reference[i]).abs();
458 max_error = max_error.max(error);
459 }
460
461 ExecuteResults {
462 max_error,
463 tolerance,
464 matrix: result,
465 }
466}
467
468pub fn main() {
469 #[cfg(not(target_arch = "wasm32"))]
470 {
471 env_logger::builder()
472 .filter_level(log::LevelFilter::Info)
473 .format_timestamp_nanos()
474 .init();
475 pollster::block_on(run());
476 }
477 #[cfg(target_arch = "wasm32")]
478 {
479 std::panic::set_hook(Box::new(console_error_panic_hook::hook));
480 console_log::init_with_level(log::Level::Info).expect("could not initialize logger");
481 crate::utils::add_web_nothing_to_see_msg();
482 wasm_bindgen_futures::spawn_local(run());
483 }
484}
485
486#[cfg(test)]
487pub mod tests;