// wgpu_examples/hello_synchronization/mod.rs

/// Number of `u32` result slots requested from `execute` by `run`; `execute`
/// dispatches one workgroup per slot.
const ARR_SIZE: usize = 128;
2
/// Values read back from the GPU by `execute`, one vector per compute entry
/// point that was dispatched.
struct ExecuteResults {
    /// Storage-buffer contents read back after the `patient_main` dispatch.
    patient_workgroup_results: Vec<u32>,
    /// Storage-buffer contents read back after the `hasty_main` dispatch.
    hasty_workgroup_results: Vec<u32>,
}
7
8async fn run() {
9 let instance = wgpu::Instance::default();
10 let adapter = instance
11 .request_adapter(&wgpu::RequestAdapterOptions::default())
12 .await
13 .unwrap();
14 let (device, queue) = adapter
15 .request_device(&wgpu::DeviceDescriptor {
16 label: None,
17 required_features: wgpu::Features::empty(),
18 required_limits: wgpu::Limits::downlevel_defaults(),
19 experimental_features: wgpu::ExperimentalFeatures::disabled(),
20 memory_hints: wgpu::MemoryHints::Performance,
21 trace: wgpu::Trace::Off,
22 })
23 .await
24 .unwrap();
25
26 let ExecuteResults {
27 patient_workgroup_results,
28 hasty_workgroup_results,
29 } = execute(&device, &queue, ARR_SIZE).await;
30
31 log::info!("Patient results: {patient_workgroup_results:?}");
33 if !patient_workgroup_results.iter().any(|e| *e != 16) {
34 log::info!("patient_main was patient.");
35 } else {
36 log::error!("patient_main was not patient!");
37 }
38 log::info!("Hasty results: {hasty_workgroup_results:?}");
39 if hasty_workgroup_results.iter().any(|e| *e != 16) {
40 log::info!("hasty_main was not patient.");
41 } else {
42 log::info!("hasty_main got lucky.");
43 }
44}
45
46async fn execute(
47 device: &wgpu::Device,
48 queue: &wgpu::Queue,
49 result_vec_size: usize,
50) -> ExecuteResults {
51 let mut local_patient_workgroup_results = vec![0u32; result_vec_size];
52 let mut local_hasty_workgroup_results = local_patient_workgroup_results.clone();
53
54 let shaders_module = device.create_shader_module(wgpu::include_wgsl!("shaders.wgsl"));
55
56 let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor {
57 label: None,
58 size: size_of_val(local_patient_workgroup_results.as_slice()) as u64,
59 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
60 mapped_at_creation: false,
61 });
62 let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
63 label: None,
64 size: size_of_val(local_patient_workgroup_results.as_slice()) as u64,
65 usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
66 mapped_at_creation: false,
67 });
68
69 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
70 label: None,
71 entries: &[wgpu::BindGroupLayoutEntry {
72 binding: 0,
73 visibility: wgpu::ShaderStages::COMPUTE,
74 ty: wgpu::BindingType::Buffer {
75 ty: wgpu::BufferBindingType::Storage { read_only: false },
76 has_dynamic_offset: false,
77 min_binding_size: None,
78 },
79 count: None,
80 }],
81 });
82 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
83 label: None,
84 layout: &bind_group_layout,
85 entries: &[wgpu::BindGroupEntry {
86 binding: 0,
87 resource: storage_buffer.as_entire_binding(),
88 }],
89 });
90
91 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
92 label: None,
93 bind_group_layouts: &[&bind_group_layout],
94 push_constant_ranges: &[],
95 });
96 let patient_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
97 label: None,
98 layout: Some(&pipeline_layout),
99 module: &shaders_module,
100 entry_point: Some("patient_main"),
101 compilation_options: Default::default(),
102 cache: None,
103 });
104 let hasty_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
105 label: None,
106 layout: Some(&pipeline_layout),
107 module: &shaders_module,
108 entry_point: Some("hasty_main"),
109 compilation_options: Default::default(),
110 cache: None,
111 });
112
113 let mut command_encoder =
116 device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
117 {
118 let mut compute_pass = command_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
119 label: None,
120 timestamp_writes: None,
121 });
122 compute_pass.set_pipeline(&patient_pipeline);
123 compute_pass.set_bind_group(0, &bind_group, &[]);
124 compute_pass.dispatch_workgroups(local_patient_workgroup_results.len() as u32, 1, 1);
125 }
126 queue.submit(Some(command_encoder.finish()));
127
128 get_data(
129 local_patient_workgroup_results.as_mut_slice(),
130 &storage_buffer,
131 &output_staging_buffer,
132 device,
133 queue,
134 )
135 .await;
136
137 let mut command_encoder =
138 device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
139 {
140 let mut compute_pass = command_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
141 label: None,
142 timestamp_writes: None,
143 });
144 compute_pass.set_pipeline(&hasty_pipeline);
145 compute_pass.set_bind_group(0, &bind_group, &[]);
146 compute_pass.dispatch_workgroups(local_patient_workgroup_results.len() as u32, 1, 1);
147 }
148 queue.submit(Some(command_encoder.finish()));
149
150 get_data(
151 local_hasty_workgroup_results.as_mut_slice(),
152 &storage_buffer,
153 &output_staging_buffer,
154 device,
155 queue,
156 )
157 .await;
158
159 ExecuteResults {
160 patient_workgroup_results: local_patient_workgroup_results,
161 hasty_workgroup_results: local_hasty_workgroup_results,
162 }
163}
164
165async fn get_data<T: bytemuck::Pod>(
166 output: &mut [T],
167 storage_buffer: &wgpu::Buffer,
168 staging_buffer: &wgpu::Buffer,
169 device: &wgpu::Device,
170 queue: &wgpu::Queue,
171) {
172 let mut command_encoder =
173 device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
174 command_encoder.copy_buffer_to_buffer(
175 storage_buffer,
176 0,
177 staging_buffer,
178 0,
179 size_of_val(output) as u64,
180 );
181 queue.submit(Some(command_encoder.finish()));
182 let buffer_slice = staging_buffer.slice(..);
183 let (sender, receiver) = flume::bounded(1);
184 buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
185 device.poll(wgpu::PollType::wait()).unwrap();
186 receiver.recv_async().await.unwrap().unwrap();
187 output.copy_from_slice(bytemuck::cast_slice(&buffer_slice.get_mapped_range()[..]));
188 staging_buffer.unmap();
189}
190
191pub fn main() {
192 #[cfg(not(target_arch = "wasm32"))]
193 {
194 env_logger::builder()
195 .filter_level(log::LevelFilter::Info)
196 .format_timestamp_nanos()
197 .init();
198 pollster::block_on(run());
199 }
200 #[cfg(target_arch = "wasm32")]
201 {
202 std::panic::set_hook(Box::new(console_error_panic_hook::hook));
203 console_log::init_with_level(log::Level::Info).expect("could not initialize logger");
204
205 crate::utils::add_web_nothing_to_see_msg();
206
207 wasm_bindgen_futures::spawn_local(run());
208 }
209}
210
211#[cfg(test)]
212pub mod tests;