wgpu_examples/hello_workgroups/mod.rs

//! This example assumes that you've seen hello-compute and/or repeated-compute
//! and thus have a general understanding of what's going on here.
//!
//! The README explains what this example does, what workgroups are, and the
//! meaning of the `@workgroup_size(size_x, size_y, size_z)` attribute. See also
//! the comments in shader.wgsl.
//!
//! Only parts specific to this example will be commented.

use wgpu::util::DeviceExt;

async fn run() {
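    // Build two CPU-side input arrays; they'll be uploaded into GPU storage buffers below.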
    let mut local_a = [0i32; 100];
    for (i, e) in local_a.iter_mut().enumerate() {
        *e = i as i32;
    }
    log::info!("Input a: {local_a:?}");
    let mut local_b = [0i32; 100];
    for (i, e) in local_b.iter_mut().enumerate() {
        *e = i as i32 * 2;
    }
    log::info!("Input b: {local_b:?}");

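    // Standard wgpu setup: create an instance, request an adapter, then a device and queue.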
    let instance = wgpu::Instance::default();
    let adapter = instance
        .request_adapter(&wgpu::RequestAdapterOptions::default())
        .await
        .unwrap();
    let (device, queue) = adapter
        .request_device(&wgpu::DeviceDescriptor {
            label: None,
            required_features: wgpu::Features::empty(),
            required_limits: wgpu::Limits::downlevel_defaults(),
            experimental_features: wgpu::ExperimentalFeatures::disabled(),
            memory_hints: wgpu::MemoryHints::MemoryUsage,
            trace: wgpu::Trace::Off,
        })
        .await
        .unwrap();

    let shader = device.create_shader_module(wgpu::include_wgsl!("shader.wgsl"));

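    // Upload each input array into its own storage buffer. COPY_SRC lets us later
    // copy the results out into the staging buffer for read-back.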
    let storage_buffer_a = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: None,
        contents: bytemuck::cast_slice(&local_a[..]),
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
    });
    let storage_buffer_b = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: None,
        contents: bytemuck::cast_slice(&local_b[..]),
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
    });
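    // A single staging buffer, sized to hold one array, is reused to read back both
    // storage buffers; MAP_READ usage generally can't be combined with STORAGE, so
    // results are copied here before being mapped for CPU access.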
    let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: None,
        size: size_of_val(&local_a) as u64,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

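    // Expose both buffers to the shader as read-write storage at bindings 0 and 1
    // of group 0; this layout has to match the declarations in shader.wgsl.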
    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: None,
        entries: &[
            wgpu::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 1,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
        ],
    });
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: None,
        layout: &bind_group_layout,
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: storage_buffer_a.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: storage_buffer_b.as_entire_binding(),
            },
        ],
    });

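    // Build the compute pipeline from the bind group layout above and the shader's
    // `main` entry point.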
    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: None,
        bind_group_layouts: &[&bind_group_layout],
        push_constant_ranges: &[],
    });
    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: None,
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: Some("main"),
        compilation_options: Default::default(),
        cache: None,
    });

    //----------------------------------------------------------

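    // Record a single compute pass: bind the pipeline and buffers, then dispatch one
    // workgroup per array element. The inner scope drops (ends) the pass before the
    // encoder is finished and submitted.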
    let mut command_encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let mut compute_pass = command_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);
        /* Since each workgroup covers both arrays, we only need to dispatch enough
        workgroups to cover the length of one array. */
        compute_pass.dispatch_workgroups(local_a.len() as u32, 1, 1);
    }
    queue.submit(Some(command_encoder.finish()));

    //----------------------------------------------------------

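    // Read the results back from the GPU into the local arrays, one buffer at a time.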
    get_data(
        &mut local_a[..],
        &storage_buffer_a,
        &output_staging_buffer,
        &device,
        &queue,
    )
    .await;
    get_data(
        &mut local_b[..],
        &storage_buffer_b,
        &output_staging_buffer,
        &device,
        &queue,
    )
    .await;

    log::info!("Output in A: {local_a:?}");
    log::info!("Output in B: {local_b:?}");
}

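/// Copies the contents of `storage_buffer` into `staging_buffer`, maps the staging
/// buffer for reading, and writes the resulting bytes back into `output` on the CPU.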
async fn get_data<T: bytemuck::Pod>(
    output: &mut [T],
    storage_buffer: &wgpu::Buffer,
    staging_buffer: &wgpu::Buffer,
    device: &wgpu::Device,
    queue: &wgpu::Queue,
) {
    let mut command_encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    command_encoder.copy_buffer_to_buffer(
        storage_buffer,
        0,
        staging_buffer,
        0,
        size_of_val(output) as u64,
    );
    queue.submit(Some(command_encoder.finish()));
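    // Map the staging buffer asynchronously. A waiting poll drives the map operation
    // to completion (on native it blocks until the GPU work above is done), and the
    // channel lets us await the result of the mapping callback.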
    let buffer_slice = staging_buffer.slice(..);
    let (sender, receiver) = flume::bounded(1);
    buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
    device.poll(wgpu::PollType::wait()).unwrap();
    receiver.recv_async().await.unwrap().unwrap();
    output.copy_from_slice(bytemuck::cast_slice(&buffer_slice.get_mapped_range()[..]));
    staging_buffer.unmap();
}

pub fn main() {
    #[cfg(not(target_arch = "wasm32"))]
    {
        env_logger::builder()
            .filter_level(log::LevelFilter::Info)
            .format_timestamp_nanos()
            .init();
        pollster::block_on(run());
    }
    #[cfg(target_arch = "wasm32")]
    {
        std::panic::set_hook(Box::new(console_error_panic_hook::hook));
        console_log::init_with_level(log::Level::Info).expect("could not initialize logger");

        crate::utils::add_web_nothing_to_see_msg();

        wasm_bindgen_futures::spawn_local(run());
    }
}