wgpu_examples/hello_synchronization/
mod.rs

// Size of each result array; also the number of workgroups dispatched per pass.
const ARR_SIZE: usize = 128;

struct ExecuteResults {
    patient_workgroup_results: Vec<u32>,
    hasty_workgroup_results: Vec<u32>,
}

async fn run() {
    let instance = wgpu::Instance::default();
    let adapter = instance
        .request_adapter(&wgpu::RequestAdapterOptions::default())
        .await
        .unwrap();
    let (device, queue) = adapter
        .request_device(&wgpu::DeviceDescriptor {
            label: None,
            required_features: wgpu::Features::empty(),
            required_limits: wgpu::Limits::downlevel_defaults(),
            experimental_features: wgpu::ExperimentalFeatures::disabled(),
            memory_hints: wgpu::MemoryHints::Performance,
            trace: wgpu::Trace::Off,
        })
        .await
        .unwrap();

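    // Run both shader passes and read each pass's per-workgroup results back to the CPU.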
    let ExecuteResults {
        patient_workgroup_results,
        hasty_workgroup_results,
    } = execute(&device, &queue, ARR_SIZE).await;

    // Print data
    log::info!("Patient results: {patient_workgroup_results:?}");
    if patient_workgroup_results.iter().all(|e| *e == 16) {
        log::info!("patient_main was patient.");
    } else {
        log::error!("patient_main was not patient!");
    }
    log::info!("Hasty results: {hasty_workgroup_results:?}");
    if hasty_workgroup_results.iter().any(|e| *e != 16) {
        log::info!("hasty_main was not patient.");
    } else {
        log::info!("hasty_main got lucky.");
    }
}

async fn execute(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    result_vec_size: usize,
) -> ExecuteResults {
    let mut local_patient_workgroup_results = vec![0u32; result_vec_size];
    let mut local_hasty_workgroup_results = local_patient_workgroup_results.clone();

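    // shaders.wgsl is not shown in this listing. As a rough, assumed sketch
    // (not the example's verbatim source): both entry points bind
    //     @group(0) @binding(0) var<storage, read_write> output: array<u32>;
    // each invocation of a 16-wide workgroup bumps a workgroup-shared counter,
    // `patient_main` issues a workgroupBarrier() before storing the counter to
    // output[workgroup_id.x], and `hasty_main` skips the barrier, which is why
    // only the patient variant reliably yields 16 per element below.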
    let shaders_module = device.create_shader_module(wgpu::include_wgsl!("shaders.wgsl"));

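    // `storage_buffer` stays on the GPU and is written by the shaders;
    // `output_staging_buffer` is CPU-mappable and used only for readback.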
    let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: None,
        size: size_of_val(local_patient_workgroup_results.as_slice()) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });
    let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: None,
        size: size_of_val(local_patient_workgroup_results.as_slice()) as u64,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

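    // A single read-write storage binding at binding 0, visible to compute.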
    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: None,
        entries: &[wgpu::BindGroupLayoutEntry {
            binding: 0,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }],
    });
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: None,
        layout: &bind_group_layout,
        entries: &[wgpu::BindGroupEntry {
            binding: 0,
            resource: storage_buffer.as_entire_binding(),
        }],
    });

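    // Both pipelines share the same module and layout; they differ only in
    // which entry point ("patient_main" vs. "hasty_main") they run.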
    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: None,
        bind_group_layouts: &[&bind_group_layout],
        push_constant_ranges: &[],
    });
    let patient_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: None,
        layout: Some(&pipeline_layout),
        module: &shaders_module,
        entry_point: Some("patient_main"),
        compilation_options: Default::default(),
        cache: None,
    });
    let hasty_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: None,
        layout: Some(&pipeline_layout),
        module: &shaders_module,
        entry_point: Some("hasty_main"),
        compilation_options: Default::default(),
        cache: None,
    });

    //----------------------------------------------------------

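    // First pass: one workgroup per result element, running `patient_main`.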
    let mut command_encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let mut compute_pass = command_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&patient_pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);
        compute_pass.dispatch_workgroups(local_patient_workgroup_results.len() as u32, 1, 1);
    }
    queue.submit(Some(command_encoder.finish()));

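    // Read the patient results out now, before the next pass overwrites the
    // same storage buffer.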
    get_data(
        local_patient_workgroup_results.as_mut_slice(),
        &storage_buffer,
        &output_staging_buffer,
        device,
        queue,
    )
    .await;

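    // Second pass: identical dispatch, but running the unsynchronized `hasty_main`.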
    let mut command_encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let mut compute_pass = command_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&hasty_pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);
        compute_pass.dispatch_workgroups(local_patient_workgroup_results.len() as u32, 1, 1);
    }
    queue.submit(Some(command_encoder.finish()));

    get_data(
        local_hasty_workgroup_results.as_mut_slice(),
        &storage_buffer,
        &output_staging_buffer,
        device,
        queue,
    )
    .await;

    ExecuteResults {
        patient_workgroup_results: local_patient_workgroup_results,
        hasty_workgroup_results: local_hasty_workgroup_results,
    }
}

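// Copies `storage_buffer` into `staging_buffer` on the GPU, then maps the
// staging buffer and copies its contents into `output`.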
async fn get_data<T: bytemuck::Pod>(
    output: &mut [T],
    storage_buffer: &wgpu::Buffer,
    staging_buffer: &wgpu::Buffer,
    device: &wgpu::Device,
    queue: &wgpu::Queue,
) {
    let mut command_encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    command_encoder.copy_buffer_to_buffer(
        storage_buffer,
        0,
        staging_buffer,
        0,
        size_of_val(output) as u64,
    );
    queue.submit(Some(command_encoder.finish()));
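    // map_async only queues the mapping request; poll(wait) blocks until the
    // GPU has finished the submitted work and the map callback has fired, and
    // the channel then hands the mapping result back to this async context.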
    let buffer_slice = staging_buffer.slice(..);
    let (sender, receiver) = flume::bounded(1);
    buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
    device.poll(wgpu::PollType::wait()).unwrap();
    receiver.recv_async().await.unwrap().unwrap();
    output.copy_from_slice(bytemuck::cast_slice(&buffer_slice.get_mapped_range()[..]));
    staging_buffer.unmap();
}

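// On native targets, block on `run`; on the web, hand it to the browser's
// event loop as a spawned future.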
pub fn main() {
    #[cfg(not(target_arch = "wasm32"))]
    {
        env_logger::builder()
            .filter_level(log::LevelFilter::Info)
            .format_timestamp_nanos()
            .init();
        pollster::block_on(run());
    }
    #[cfg(target_arch = "wasm32")]
    {
        std::panic::set_hook(Box::new(console_error_panic_hook::hook));
        console_log::init_with_level(log::Level::Info).expect("could not initialize logger");

        crate::utils::add_web_nothing_to_see_msg();

        wasm_bindgen_futures::spawn_local(run());
    }
}

#[cfg(test)]
pub mod tests;