wgpu_examples/hello_synchronization/
mod.rs
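
//! Demonstrates workgroup synchronization. Two compute entry points,
//! `patient_main` and `hasty_main`, fill the same storage buffer; the
//! result check in `run` expects every element to equal 16, i.e. each
//! workgroup's 16 invocations cooperatively count to 16. Presumably
//! `patient_main` waits on a workgroup barrier before recording the
//! count, while `hasty_main` records it immediately and may observe a
//! partial value.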

const ARR_SIZE: usize = 128;

struct ExecuteResults {
    patient_workgroup_results: Vec<u32>,
    hasty_workgroup_results: Vec<u32>,
}

async fn run() {
    let instance = wgpu::Instance::default();
    let adapter = instance
        .request_adapter(&wgpu::RequestAdapterOptions::default())
        .await
        .unwrap();
    let (device, queue) = adapter
        .request_device(&wgpu::DeviceDescriptor {
            label: None,
            required_features: wgpu::Features::empty(),
            required_limits: wgpu::Limits::downlevel_defaults(),
            memory_hints: wgpu::MemoryHints::Performance,
            trace: wgpu::Trace::Off,
        })
        .await
        .unwrap();

    let ExecuteResults {
        patient_workgroup_results,
        hasty_workgroup_results,
    } = execute(&device, &queue, ARR_SIZE).await;

    // Print and check the results.
    log::info!("Patient results: {patient_workgroup_results:?}");
    if patient_workgroup_results.iter().all(|e| *e == 16) {
        log::info!("patient_main was patient.");
    } else {
        log::error!("patient_main was not patient!");
    }
    log::info!("Hasty results: {hasty_workgroup_results:?}");
    if hasty_workgroup_results.iter().any(|e| *e != 16) {
        log::info!("hasty_main was not patient.");
    } else {
        log::info!("hasty_main got lucky.");
    }
}

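/// Dispatches `patient_main` and then `hasty_main` over the same storage
/// buffer, reading the results back after each pass.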
async fn execute(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    result_vec_size: usize,
) -> ExecuteResults {
    let mut local_patient_workgroup_results = vec![0u32; result_vec_size];
    let mut local_hasty_workgroup_results = local_patient_workgroup_results.clone();

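    // `shaders.wgsl` is not reproduced here. From the host side we can tell
    // it exposes two entry points, `patient_main` and `hasty_main`, sharing
    // one read-write storage buffer at @group(0) @binding(0), and that each
    // workgroup is expected to end up writing 16. A plausible sketch (the
    // names `count`, `local`, and `wg` are assumptions, not the real file):
    //
    //     @group(0) @binding(0)
    //     var<storage, read_write> output: array<u32>;
    //
    //     var<workgroup> count: atomic<u32>;
    //
    //     @compute @workgroup_size(16)
    //     fn patient_main(
    //         @builtin(local_invocation_id) local: vec3<u32>,
    //         @builtin(workgroup_id) wg: vec3<u32>,
    //     ) {
    //         atomicAdd(&count, 1u);
    //         workgroupBarrier(); // wait for all 16 invocations
    //         if local.x == 0u {
    //             output[wg.x] = atomicLoad(&count);
    //         }
    //     }
    //
    //     // `hasty_main` would be identical minus the barrier, so it may
    //     // record a partial count.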
    let shaders_module = device.create_shader_module(wgpu::include_wgsl!("shaders.wgsl"));

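    // Two buffers: the shader writes `storage_buffer` (STORAGE | COPY_SRC),
    // and results are copied into `output_staging_buffer` (COPY_DST |
    // MAP_READ) so the CPU can map and read them.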
    let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: None,
        size: size_of_val(local_patient_workgroup_results.as_slice()) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });
    let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: None,
        size: size_of_val(local_patient_workgroup_results.as_slice()) as u64,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: None,
        entries: &[wgpu::BindGroupLayoutEntry {
            binding: 0,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }],
    });
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: None,
        layout: &bind_group_layout,
        entries: &[wgpu::BindGroupEntry {
            binding: 0,
            resource: storage_buffer.as_entire_binding(),
        }],
    });

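    // Both pipelines share one layout and bind group; they differ only in
    // which entry point of the shader module they run.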
    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: None,
        bind_group_layouts: &[&bind_group_layout],
        push_constant_ranges: &[],
    });
    let patient_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: None,
        layout: Some(&pipeline_layout),
        module: &shaders_module,
        entry_point: Some("patient_main"),
        compilation_options: Default::default(),
        cache: None,
    });
    let hasty_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: None,
        layout: Some(&pipeline_layout),
        module: &shaders_module,
        entry_point: Some("hasty_main"),
        compilation_options: Default::default(),
        cache: None,
    });

    //----------------------------------------------------------

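    // One workgroup per output element: each workgroup writes the count it
    // observed into its own slot of the storage buffer.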
    let mut command_encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let mut compute_pass = command_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&patient_pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);
        compute_pass.dispatch_workgroups(local_patient_workgroup_results.len() as u32, 1, 1);
    }
    queue.submit(Some(command_encoder.finish()));

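    // Read the patient results back before the hasty pass overwrites the
    // shared storage buffer.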
    get_data(
        local_patient_workgroup_results.as_mut_slice(),
        &storage_buffer,
        &output_staging_buffer,
        device,
        queue,
    )
    .await;

    let mut command_encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let mut compute_pass = command_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&hasty_pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);
        compute_pass.dispatch_workgroups(local_patient_workgroup_results.len() as u32, 1, 1);
    }
    queue.submit(Some(command_encoder.finish()));

    get_data(
        local_hasty_workgroup_results.as_mut_slice(),
        &storage_buffer,
        &output_staging_buffer,
        device,
        queue,
    )
    .await;

    ExecuteResults {
        patient_workgroup_results: local_patient_workgroup_results,
        hasty_workgroup_results: local_hasty_workgroup_results,
    }
}

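/// Copies `storage_buffer` into `staging_buffer`, maps the staging buffer,
/// and reads its contents back into `output`.
///
/// `map_async` only completes once the device makes progress, so the call
/// is followed by `device.poll`, which blocks on native targets; awaiting
/// the channel keeps the same code correct on WebAssembly, where polling
/// is driven by the browser instead.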
async fn get_data<T: bytemuck::Pod>(
    output: &mut [T],
    storage_buffer: &wgpu::Buffer,
    staging_buffer: &wgpu::Buffer,
    device: &wgpu::Device,
    queue: &wgpu::Queue,
) {
    let mut command_encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    command_encoder.copy_buffer_to_buffer(
        storage_buffer,
        0,
        staging_buffer,
        0,
        size_of_val(output) as u64,
    );
    queue.submit(Some(command_encoder.finish()));
    let buffer_slice = staging_buffer.slice(..);
    let (sender, receiver) = flume::bounded(1);
    buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
    device.poll(wgpu::PollType::wait()).unwrap();
    receiver.recv_async().await.unwrap().unwrap();
    output.copy_from_slice(bytemuck::cast_slice(&buffer_slice.get_mapped_range()[..]));
    staging_buffer.unmap();
}

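// On native targets the future is driven to completion with pollster; on
// the web it is spawned onto the browser's event loop instead.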
pub fn main() {
    #[cfg(not(target_arch = "wasm32"))]
    {
        env_logger::builder()
            .filter_level(log::LevelFilter::Info)
            .format_timestamp_nanos()
            .init();
        pollster::block_on(run());
    }
    #[cfg(target_arch = "wasm32")]
    {
        std::panic::set_hook(Box::new(console_error_panic_hook::hook));
        console_log::init_with_level(log::Level::Info).expect("could not initialize logger");

        crate::utils::add_web_nothing_to_see_msg();

        wasm_bindgen_futures::spawn_local(run());
    }
}

#[cfg(test)]
pub mod tests;