use bytemuck::{Pod, Zeroable};
use half::f16;

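// Problem size: C (M x N) = A (M x K) * B (K x N).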
const M: u32 = 64;
const N: u32 = 64;
const K: u32 = 64;

#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct Dimensions {
    m: u32,
    n: u32,
    k: u32,
    stride: u32,
}

async fn run() {
    let instance = wgpu::Instance::default();

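    // Prefer a high-performance adapter for this compute workload.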
    let adapter = instance
        .request_adapter(&wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            ..Default::default()
        })
        .await
        .expect("Failed to find an appropriate adapter");

    log::info!("Using adapter: {:?}", adapter.get_info());

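    // Each entry describes one supported tile shape and element-type combination.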
    let coop_props = adapter.cooperative_matrix_properties();

    if coop_props.is_empty() {
        log::error!(
            "Cooperative matrix is not supported on this adapter.\n\
             This feature requires:\n\
             - Metal: Apple7+ (A14/M1) with MSL 2.3+\n\
             - Vulkan: VK_KHR_cooperative_matrix extension"
        );
        return;
    }

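    // Log every supported configuration so failures below are easier to diagnose.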
    log::info!("Supported cooperative matrix configurations:");

    for (i, prop) in coop_props.iter().enumerate() {
        log::info!(
            " [{}] {:?}x{:?}x{:?} - AB: {:?}, CR: {:?}{}",
            i,
            prop.m_size,
            prop.n_size,
            prop.k_size,
            prop.ab_type,
            prop.cr_type,
            if prop.saturating_accumulation {
                " (saturating)"
            } else {
                ""
            }
        );
    }

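    // Prefer 16x16x16 f16 tiles (e.g. AMD), falling back to 8x8x8 f32 (e.g. Apple Metal).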
    let selected_config = coop_props
        .iter()
        .find(|prop| {
            prop.m_size == 16
                && prop.n_size == 16
                && prop.k_size == 16
                && prop.ab_type == wgpu::CooperativeScalarType::F16
                && prop.cr_type == wgpu::CooperativeScalarType::F16
        })
        .or_else(|| {
            coop_props.iter().find(|prop| {
                prop.m_size == 8
                    && prop.n_size == 8
                    && prop.k_size == 8
                    && prop.ab_type == wgpu::CooperativeScalarType::F32
                    && prop.cr_type == wgpu::CooperativeScalarType::F32
            })
        });

    let config = match selected_config {
        Some(c) => {
            log::info!(
                "Selected configuration: {:?}x{:?}x{:?} AB={:?} CR={:?}",
                c.m_size,
                c.n_size,
                c.k_size,
                c.ab_type,
                c.cr_type
            );
            c
        }
        None => {
            log::error!(
                "No suitable cooperative matrix configuration found.\n\
                 This example supports 16x16 f16 (AMD) or 8x8 f32 (Apple Metal).\n\
                 Available configurations are listed above."
            );
            return;
        }
    };

    let tile_size = config.m_size;
    let use_f16 = config.ab_type == wgpu::CooperativeScalarType::F16;

    log::info!(
        "Using {}x{} tiles with {} precision",
        tile_size,
        tile_size,
        if use_f16 { "f16" } else { "f32" }
    );

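    // Cooperative matrices are gated behind an experimental feature flag.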
    let adapter_features = adapter.features();

    if !adapter_features.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX) {
        log::error!("EXPERIMENTAL_COOPERATIVE_MATRIX feature not available");
        return;
    }

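    // f16 matrix elements additionally require f16 support in shaders.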
    if use_f16 && !adapter_features.contains(wgpu::Features::SHADER_F16) {
        log::error!("SHADER_F16 feature not available, but required for f16 cooperative matrices");
        return;
    }

    let mut required_features = wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX;
    if use_f16 {
        required_features |= wgpu::Features::SHADER_F16;
    }

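    // Opting into experimental features requires unsafe, as they may be
    // incompletely validated.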
    let (device, queue) = unsafe {
        adapter
            .request_device(&wgpu::DeviceDescriptor {
                label: Some("Cooperative Matrix Device"),
                required_features,
                required_limits: wgpu::Limits::downlevel_defaults(),
                experimental_features: wgpu::ExperimentalFeatures::enabled(),
                memory_hints: wgpu::MemoryHints::Performance,
                trace: wgpu::Trace::Off,
            })
            .await
            .expect("Failed to create device")
    };

    let results = execute(&device, &queue, config).await;

    log::info!(
        "Matrix multiplication {M}x{K}x{N} completed using {} precision!",
        if use_f16 { "f16" } else { "f32" }
    );
    log::info!("Max error vs CPU reference: {:.6}", results.max_error);

    if results.max_error < results.tolerance {
        log::info!(
            "✓ Results match CPU reference within tolerance ({})",
            results.tolerance
        );
    } else {
        log::warn!(
            "✗ Results differ from CPU reference (tolerance: {})",
            results.tolerance
        );
    }

    log::info!("Sample of result matrix C (top-left 4x4):");
    for i in 0..4 {
        let row: Vec<String> = (0..4)
            .map(|j| format!("{:6.2}", results.matrix[i * N as usize + j]))
            .collect();
        log::info!(" [{}]", row.join(", "));
    }
}

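/// GPU output plus the comparison against a CPU reference.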
struct ExecuteResults {
    max_error: f32,
    tolerance: f32,
    matrix: Vec<f32>,
}

async fn execute(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    config: &wgpu::CooperativeMatrixProperties,
) -> ExecuteResults {
    let use_f16 = config.ab_type == wgpu::CooperativeScalarType::F16;

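    // Each precision has its own WGSL shader written for its tile size.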
    let shader_source = if use_f16 {
        include_str!("shader_f16_16x16.wgsl")
    } else {
        include_str!("shader.wgsl")
    };

    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Cooperative Matrix Shader"),
        source: wgpu::ShaderSource::Wgsl(shader_source.into()),
    });

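    // Deterministic test inputs; the small values keep f16 rounding error bounded.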
    let matrix_a_f32: Vec<f32> = (0..M * K).map(|i| (i % 7) as f32 * 0.1).collect();
    let matrix_b_f32: Vec<f32> = (0..K * N).map(|i| (i % 11) as f32 * 0.1).collect();
    let matrix_c_f32: Vec<f32> = vec![0.0; (M * N) as usize];

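    // Buffer sizes depend on element width: 2 bytes for f16, 4 for f32.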
    let element_size = if use_f16 { 2usize } else { 4usize };
    let num_elements_a = (M * K) as usize;
    let num_elements_b = (K * N) as usize;
    let num_elements_c = (M * N) as usize;

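    // A and B are read-only storage buffers; C is also copied out for readback.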
    let buffer_a = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Matrix A"),
        size: (num_elements_a * element_size) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let buffer_b = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Matrix B"),
        size: (num_elements_b * element_size) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let buffer_c = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Matrix C"),
        size: (num_elements_c * element_size) as u64,
        usage: wgpu::BufferUsages::STORAGE
            | wgpu::BufferUsages::COPY_DST
            | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    let dimensions = Dimensions {
        m: M,
        n: N,
        k: K,
        stride: N,
    };
    let buffer_dims = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Dimensions"),
        size: std::mem::size_of::<Dimensions>() as u64,
        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Staging Buffer"),
        size: (num_elements_c * element_size) as u64,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

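    // Upload the inputs, converting to f16 first when the selected config uses it.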
    if use_f16 {
        let matrix_a_f16: Vec<f16> = matrix_a_f32.iter().map(|&x| f16::from_f32(x)).collect();
        let matrix_b_f16: Vec<f16> = matrix_b_f32.iter().map(|&x| f16::from_f32(x)).collect();
        let matrix_c_f16: Vec<f16> = matrix_c_f32.iter().map(|&x| f16::from_f32(x)).collect();
        queue.write_buffer(&buffer_a, 0, bytemuck::cast_slice(&matrix_a_f16));
        queue.write_buffer(&buffer_b, 0, bytemuck::cast_slice(&matrix_b_f16));
        queue.write_buffer(&buffer_c, 0, bytemuck::cast_slice(&matrix_c_f16));
    } else {
        queue.write_buffer(&buffer_a, 0, bytemuck::cast_slice(&matrix_a_f32));
        queue.write_buffer(&buffer_b, 0, bytemuck::cast_slice(&matrix_b_f32));
        queue.write_buffer(&buffer_c, 0, bytemuck::cast_slice(&matrix_c_f32));
    }
    queue.write_buffer(&buffer_dims, 0, bytemuck::bytes_of(&dimensions));

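    // Bindings 0 and 1 are the inputs, 2 is the output, 3 holds the dimensions.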
    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some("Cooperative Matrix Bind Group Layout"),
        entries: &[
            wgpu::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 1,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 2,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 3,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Uniform,
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
        ],
    });

    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("Cooperative Matrix Bind Group"),
        layout: &bind_group_layout,
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: buffer_a.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: buffer_b.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: buffer_c.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: buffer_dims.as_entire_binding(),
            },
        ],
    });

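    // immediate_size is 0: all shader parameters travel in the uniform buffer.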
    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("Cooperative Matrix Pipeline Layout"),
        bind_group_layouts: &[Some(&bind_group_layout)],
        immediate_size: 0,
    });

    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("Cooperative Matrix Pipeline"),
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: Some("main"),
        compilation_options: Default::default(),
        cache: None,
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("Cooperative Matrix Encoder"),
    });

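    // One workgroup per output tile: the grid is (M / tile_m) x (N / tile_n).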
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("Cooperative Matrix Pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);
        compute_pass.dispatch_workgroups(M / config.m_size, N / config.n_size, 1);
    }

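    // Copy the result into a mappable staging buffer for readback.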
    encoder.copy_buffer_to_buffer(&buffer_c, 0, &staging_buffer, 0, staging_buffer.size());

    queue.submit(Some(encoder.finish()));

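    // Map the staging buffer; polling the device drives the map_async callback.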
    let buffer_slice = staging_buffer.slice(..);
    let (sender, receiver) = flume::bounded(1);
    buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
    device
        .poll(wgpu::PollType::wait_indefinitely())
        .expect("Poll failed");
    receiver
        .recv_async()
        .await
        .expect("Channel receive failed")
        .expect("Buffer mapping failed");

    let data = buffer_slice.get_mapped_range();

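    // Reinterpret the mapped bytes according to the precision actually used.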
    let result: Vec<f32> = if use_f16 {
        let result_f16: &[f16] = bytemuck::cast_slice(&data);
        result_f16.iter().map(|x| x.to_f32()).collect()
    } else {
        bytemuck::cast_slice::<_, f32>(&data).to_vec()
    };

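    // Naive O(M * N * K) CPU reference in f32 for validation.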
    let mut reference = vec![0.0f32; (M * N) as usize];
    for i in 0..M {
        for j in 0..N {
            let mut sum = 0.0f32;
            for k in 0..K {
                sum += matrix_a_f32[(i * K + k) as usize] * matrix_b_f32[(k * N + j) as usize];
            }
            reference[(i * N + j) as usize] = sum;
        }
    }

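    // f16 arithmetic warrants a looser tolerance than f32.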
    let tolerance = if use_f16 { 0.1 } else { 0.01 };
    let mut max_error = 0.0f32;
    for i in 0..(M * N) as usize {
        let error = (result[i] - reference[i]).abs();
        max_error = max_error.max(error);
    }

    ExecuteResults {
        max_error,
        tolerance,
        matrix: result,
    }
}

pub fn main() {
    #[cfg(not(target_arch = "wasm32"))]
    {
        env_logger::builder()
            .filter_level(log::LevelFilter::Info)
            .format_timestamp_nanos()
            .init();
        pollster::block_on(run());
    }
    #[cfg(target_arch = "wasm32")]
    {
        std::panic::set_hook(Box::new(console_error_panic_hook::hook));
        console_log::init_with_level(log::Level::Info).expect("could not initialize logger");
        crate::utils::add_web_nothing_to_see_msg();
        wasm_bindgen_futures::spawn_local(run());
    }
}

#[cfg(test)]
pub mod tests;