// wgpu_examples/hello_synchronization/mod.rs
/// Length of each result vector produced by `execute`.
///
/// This is also the number of workgroups dispatched per compute pass and the number of
/// `u32` slots in the storage buffer.
const ARR_SIZE: usize = 128;
2
/// Per-workgroup values read back from the two compute passes run by `execute`.
///
/// Derives the common traits so results can be logged, compared in tests and cloned
/// without hand-written impls.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
struct ExecuteResults {
    /// One entry per dispatched workgroup, read back after the `patient_main` pass.
    patient_workgroup_results: Vec<u32>,
    /// One entry per dispatched workgroup, read back after the `hasty_main` pass.
    hasty_workgroup_results: Vec<u32>,
}
7
8async fn run() {
9 let instance = wgpu::Instance::default();
10 let adapter = instance
11 .request_adapter(&wgpu::RequestAdapterOptions::default())
12 .await
13 .unwrap();
14 let (device, queue) = adapter
15 .request_device(&wgpu::DeviceDescriptor {
16 label: None,
17 required_features: wgpu::Features::empty(),
18 required_limits: wgpu::Limits::downlevel_defaults(),
19 memory_hints: wgpu::MemoryHints::Performance,
20 trace: wgpu::Trace::Off,
21 })
22 .await
23 .unwrap();
24
25 let ExecuteResults {
26 patient_workgroup_results,
27 hasty_workgroup_results,
28 } = execute(&device, &queue, ARR_SIZE).await;
29
30 log::info!("Patient results: {patient_workgroup_results:?}");
32 if !patient_workgroup_results.iter().any(|e| *e != 16) {
33 log::info!("patient_main was patient.");
34 } else {
35 log::error!("patient_main was not patient!");
36 }
37 log::info!("Hasty results: {hasty_workgroup_results:?}");
38 if hasty_workgroup_results.iter().any(|e| *e != 16) {
39 log::info!("hasty_main was not patient.");
40 } else {
41 log::info!("hasty_main got lucky.");
42 }
43}
44
45async fn execute(
46 device: &wgpu::Device,
47 queue: &wgpu::Queue,
48 result_vec_size: usize,
49) -> ExecuteResults {
50 let mut local_patient_workgroup_results = vec![0u32; result_vec_size];
51 let mut local_hasty_workgroup_results = local_patient_workgroup_results.clone();
52
53 let shaders_module = device.create_shader_module(wgpu::include_wgsl!("shaders.wgsl"));
54
55 let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor {
56 label: None,
57 size: size_of_val(local_patient_workgroup_results.as_slice()) as u64,
58 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
59 mapped_at_creation: false,
60 });
61 let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
62 label: None,
63 size: size_of_val(local_patient_workgroup_results.as_slice()) as u64,
64 usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
65 mapped_at_creation: false,
66 });
67
68 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
69 label: None,
70 entries: &[wgpu::BindGroupLayoutEntry {
71 binding: 0,
72 visibility: wgpu::ShaderStages::COMPUTE,
73 ty: wgpu::BindingType::Buffer {
74 ty: wgpu::BufferBindingType::Storage { read_only: false },
75 has_dynamic_offset: false,
76 min_binding_size: None,
77 },
78 count: None,
79 }],
80 });
81 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
82 label: None,
83 layout: &bind_group_layout,
84 entries: &[wgpu::BindGroupEntry {
85 binding: 0,
86 resource: storage_buffer.as_entire_binding(),
87 }],
88 });
89
90 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
91 label: None,
92 bind_group_layouts: &[&bind_group_layout],
93 push_constant_ranges: &[],
94 });
95 let patient_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
96 label: None,
97 layout: Some(&pipeline_layout),
98 module: &shaders_module,
99 entry_point: Some("patient_main"),
100 compilation_options: Default::default(),
101 cache: None,
102 });
103 let hasty_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
104 label: None,
105 layout: Some(&pipeline_layout),
106 module: &shaders_module,
107 entry_point: Some("hasty_main"),
108 compilation_options: Default::default(),
109 cache: None,
110 });
111
112 let mut command_encoder =
115 device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
116 {
117 let mut compute_pass = command_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
118 label: None,
119 timestamp_writes: None,
120 });
121 compute_pass.set_pipeline(&patient_pipeline);
122 compute_pass.set_bind_group(0, &bind_group, &[]);
123 compute_pass.dispatch_workgroups(local_patient_workgroup_results.len() as u32, 1, 1);
124 }
125 queue.submit(Some(command_encoder.finish()));
126
127 get_data(
128 local_patient_workgroup_results.as_mut_slice(),
129 &storage_buffer,
130 &output_staging_buffer,
131 device,
132 queue,
133 )
134 .await;
135
136 let mut command_encoder =
137 device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
138 {
139 let mut compute_pass = command_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
140 label: None,
141 timestamp_writes: None,
142 });
143 compute_pass.set_pipeline(&hasty_pipeline);
144 compute_pass.set_bind_group(0, &bind_group, &[]);
145 compute_pass.dispatch_workgroups(local_patient_workgroup_results.len() as u32, 1, 1);
146 }
147 queue.submit(Some(command_encoder.finish()));
148
149 get_data(
150 local_hasty_workgroup_results.as_mut_slice(),
151 &storage_buffer,
152 &output_staging_buffer,
153 device,
154 queue,
155 )
156 .await;
157
158 ExecuteResults {
159 patient_workgroup_results: local_patient_workgroup_results,
160 hasty_workgroup_results: local_hasty_workgroup_results,
161 }
162}
163
164async fn get_data<T: bytemuck::Pod>(
165 output: &mut [T],
166 storage_buffer: &wgpu::Buffer,
167 staging_buffer: &wgpu::Buffer,
168 device: &wgpu::Device,
169 queue: &wgpu::Queue,
170) {
171 let mut command_encoder =
172 device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
173 command_encoder.copy_buffer_to_buffer(
174 storage_buffer,
175 0,
176 staging_buffer,
177 0,
178 size_of_val(output) as u64,
179 );
180 queue.submit(Some(command_encoder.finish()));
181 let buffer_slice = staging_buffer.slice(..);
182 let (sender, receiver) = flume::bounded(1);
183 buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
184 device.poll(wgpu::PollType::wait()).unwrap();
185 receiver.recv_async().await.unwrap().unwrap();
186 output.copy_from_slice(bytemuck::cast_slice(&buffer_slice.get_mapped_range()[..]));
187 staging_buffer.unmap();
188}
189
190pub fn main() {
191 #[cfg(not(target_arch = "wasm32"))]
192 {
193 env_logger::builder()
194 .filter_level(log::LevelFilter::Info)
195 .format_timestamp_nanos()
196 .init();
197 pollster::block_on(run());
198 }
199 #[cfg(target_arch = "wasm32")]
200 {
201 std::panic::set_hook(Box::new(console_error_panic_hook::hook));
202 console_log::init_with_level(log::Level::Info).expect("could not initialize logger");
203
204 crate::utils::add_web_nothing_to_see_msg();
205
206 wasm_bindgen_futures::spawn_local(run());
207 }
208}
209
// Test submodule for this example. Its contents live in a separate `tests` module file,
// and it is compiled only under `cfg(test)`.
#[cfg(test)]
pub mod tests;