wgpu_examples/hello_workgroups/
mod.rs

//! This example assumes that you've seen hello-compute and/or repeated-compute
//! and thus have a general understanding of what's going on here.
//!
//! The README explains exactly what this example does, what workgroups are,
//! and what `@workgroup_size(size_x, size_y, size_z)` means. Also see the
//! comments in shader.wgsl.
//!
//! Only parts specific to this example will be commented.
9
10use wgpu::util::DeviceExt;
11
12async fn run() {
13    let mut local_a = [0i32; 100];
14    for (i, e) in local_a.iter_mut().enumerate() {
15        *e = i as i32;
16    }
17    log::info!("Input a: {local_a:?}");
18    let mut local_b = [0i32; 100];
19    for (i, e) in local_b.iter_mut().enumerate() {
20        *e = i as i32 * 2;
21    }
22    log::info!("Input b: {local_b:?}");
23
24    let instance = wgpu::Instance::default();
25    let adapter = instance
26        .request_adapter(&wgpu::RequestAdapterOptions::default())
27        .await
28        .unwrap();
29    let (device, queue) = adapter
30        .request_device(&wgpu::DeviceDescriptor {
31            label: None,
32            required_features: wgpu::Features::empty(),
33            required_limits: wgpu::Limits::downlevel_defaults(),
34            memory_hints: wgpu::MemoryHints::MemoryUsage,
35            trace: wgpu::Trace::Off,
36        })
37        .await
38        .unwrap();
39
40    let shader = device.create_shader_module(wgpu::include_wgsl!("shader.wgsl"));
41
42    let storage_buffer_a = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
43        label: None,
44        contents: bytemuck::cast_slice(&local_a[..]),
45        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
46    });
47    let storage_buffer_b = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
48        label: None,
49        contents: bytemuck::cast_slice(&local_b[..]),
50        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
51    });
52    let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
53        label: None,
54        size: size_of_val(&local_a) as u64,
55        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
56        mapped_at_creation: false,
57    });
58
59    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
60        label: None,
61        entries: &[
62            wgpu::BindGroupLayoutEntry {
63                binding: 0,
64                visibility: wgpu::ShaderStages::COMPUTE,
65                ty: wgpu::BindingType::Buffer {
66                    ty: wgpu::BufferBindingType::Storage { read_only: false },
67                    has_dynamic_offset: false,
68                    min_binding_size: None,
69                },
70                count: None,
71            },
72            wgpu::BindGroupLayoutEntry {
73                binding: 1,
74                visibility: wgpu::ShaderStages::COMPUTE,
75                ty: wgpu::BindingType::Buffer {
76                    ty: wgpu::BufferBindingType::Storage { read_only: false },
77                    has_dynamic_offset: false,
78                    min_binding_size: None,
79                },
80                count: None,
81            },
82        ],
83    });
84    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
85        label: None,
86        layout: &bind_group_layout,
87        entries: &[
88            wgpu::BindGroupEntry {
89                binding: 0,
90                resource: storage_buffer_a.as_entire_binding(),
91            },
92            wgpu::BindGroupEntry {
93                binding: 1,
94                resource: storage_buffer_b.as_entire_binding(),
95            },
96        ],
97    });
98
99    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
100        label: None,
101        bind_group_layouts: &[&bind_group_layout],
102        push_constant_ranges: &[],
103    });
104    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
105        label: None,
106        layout: Some(&pipeline_layout),
107        module: &shader,
108        entry_point: Some("main"),
109        compilation_options: Default::default(),
110        cache: None,
111    });
112
113    //----------------------------------------------------------
114
115    let mut command_encoder =
116        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
117    {
118        let mut compute_pass = command_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
119            label: None,
120            timestamp_writes: None,
121        });
122        compute_pass.set_pipeline(&pipeline);
123        compute_pass.set_bind_group(0, &bind_group, &[]);
124        /* Note that since each workgroup will cover both arrays, we only need to
125        cover the length of one array. */
126        compute_pass.dispatch_workgroups(local_a.len() as u32, 1, 1);
127    }
128    queue.submit(Some(command_encoder.finish()));
129
130    //----------------------------------------------------------
131
132    get_data(
133        &mut local_a[..],
134        &storage_buffer_a,
135        &output_staging_buffer,
136        &device,
137        &queue,
138    )
139    .await;
140    get_data(
141        &mut local_b[..],
142        &storage_buffer_b,
143        &output_staging_buffer,
144        &device,
145        &queue,
146    )
147    .await;
148
149    log::info!("Output in A: {local_a:?}");
150    log::info!("Output in B: {local_b:?}");
151}
152
153async fn get_data<T: bytemuck::Pod>(
154    output: &mut [T],
155    storage_buffer: &wgpu::Buffer,
156    staging_buffer: &wgpu::Buffer,
157    device: &wgpu::Device,
158    queue: &wgpu::Queue,
159) {
160    let mut command_encoder =
161        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
162    command_encoder.copy_buffer_to_buffer(
163        storage_buffer,
164        0,
165        staging_buffer,
166        0,
167        size_of_val(output) as u64,
168    );
169    queue.submit(Some(command_encoder.finish()));
170    let buffer_slice = staging_buffer.slice(..);
171    let (sender, receiver) = flume::bounded(1);
172    buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
173    device.poll(wgpu::PollType::wait()).unwrap();
174    receiver.recv_async().await.unwrap().unwrap();
175    output.copy_from_slice(bytemuck::cast_slice(&buffer_slice.get_mapped_range()[..]));
176    staging_buffer.unmap();
177}
178
179pub fn main() {
180    #[cfg(not(target_arch = "wasm32"))]
181    {
182        env_logger::builder()
183            .filter_level(log::LevelFilter::Info)
184            .format_timestamp_nanos()
185            .init();
186        pollster::block_on(run());
187    }
188    #[cfg(target_arch = "wasm32")]
189    {
190        std::panic::set_hook(Box::new(console_error_panic_hook::hook));
191        console_log::init_with_level(log::Level::Info).expect("could not initialize logger");
192
193        crate::utils::add_web_nothing_to_see_msg();
194
195        wasm_bindgen_futures::spawn_local(run());
196    }
197}