wgpu_examples/big_compute_buffers/mod.rs

//! This example shows one possible approach for when your data is too large
//! to fit in a single buffer.
//!
//! A lot of things aren't explained here via comments. See hello-compute and
//! repeated-compute for code that is more thoroughly commented.

7use std::num::NonZeroU32;
8use wgpu::{util::DeviceExt, Features};
9
// Sized to WebGPU's minimum guaranteed defaults so the example runs on any
// conformant adapter.
const MAX_BUFFER_SIZE: u64 = 134_217_728; // 1 << 27, ~134 MB
const MAX_DISPATCH_SIZE: u32 = 65_535; // (1 << 16) - 1
13
14pub async fn execute_gpu(numbers: &[f32]) -> Vec<f32> {
15    let instance = wgpu::Instance::default();
16
17    let adapter = instance
18        .request_adapter(&wgpu::RequestAdapterOptions::default())
19        .await
20        .unwrap();
21
22    let (device, queue) = adapter
23        .request_device(&wgpu::DeviceDescriptor {
24            label: None,
25            // These features are required to use `binding_array` in your wgsl.
26            // Without them your shader may fail to compile.
27            required_features: Features::BUFFER_BINDING_ARRAY
28                | Features::STORAGE_RESOURCE_BINDING_ARRAY
29                | Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
30            memory_hints: wgpu::MemoryHints::Performance,
31            required_limits: wgpu::Limits {
32                max_buffer_size: MAX_BUFFER_SIZE,
33                max_binding_array_elements_per_shader_stage: 8,
34                ..Default::default()
35            },
36            ..Default::default()
37        })
38        .await
39        .unwrap();
40
41    execute_gpu_inner(&device, &queue, numbers).await
42}
43
44pub async fn execute_gpu_inner(
45    device: &wgpu::Device,
46    queue: &wgpu::Queue,
47    numbers: &[f32],
48) -> Vec<f32> {
49    let (staging_buffers, storage_buffers, bind_group, compute_pipeline) = setup(device, numbers);
50
51    let mut encoder =
52        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
53    {
54        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
55            label: Some("compute pass descriptor"),
56            timestamp_writes: None,
57        });
58        cpass.set_pipeline(&compute_pipeline);
59        cpass.set_bind_group(0, Some(&bind_group), &[]);
60
61        cpass.dispatch_workgroups(MAX_DISPATCH_SIZE.min(numbers.len() as u32), 1, 1);
62    }
63
64    for (storage_buffer, staging_buffer) in storage_buffers.iter().zip(staging_buffers.iter()) {
65        let stg_size = staging_buffer.size();
66
67        encoder.copy_buffer_to_buffer(
68            storage_buffer, // Source buffer
69            0,
70            staging_buffer, // Destination buffer
71            0,
72            stg_size,
73        );
74    }
75
76    queue.submit(Some(encoder.finish()));
77
78    for staging_buffer in &staging_buffers {
79        let slice = staging_buffer.slice(..);
80        slice.map_async(wgpu::MapMode::Read, |_| {});
81    }
82
83    device.poll(wgpu::PollType::Wait).unwrap();
84
85    let mut data = Vec::new();
86    for staging_buffer in &staging_buffers {
87        let slice = staging_buffer.slice(..);
88        let mapped = slice.get_mapped_range();
89        data.extend_from_slice(bytemuck::cast_slice(&mapped));
90        drop(mapped);
91        staging_buffer.unmap();
92    }
93
94    data
95}
96
97fn setup(
98    device: &wgpu::Device,
99    numbers: &[f32],
100) -> (
101    Vec<wgpu::Buffer>,
102    Vec<wgpu::Buffer>,
103    wgpu::BindGroup,
104    wgpu::ComputePipeline,
105) {
106    let cs_module = device.create_shader_module(wgpu::include_wgsl!("shader.wgsl"));
107
108    let staging_buffers = create_staging_buffers(device, numbers);
109    let storage_buffers = create_storage_buffers(device, numbers);
110
111    let (bind_group_layout, bind_group) = setup_binds(&storage_buffers, device);
112
113    let compute_pipeline = setup_pipeline(device, bind_group_layout, cs_module);
114    (
115        staging_buffers,
116        storage_buffers,
117        bind_group,
118        compute_pipeline,
119    )
120}
121
122fn setup_pipeline(
123    device: &wgpu::Device,
124    bind_group_layout: wgpu::BindGroupLayout,
125    cs_module: wgpu::ShaderModule,
126) -> wgpu::ComputePipeline {
127    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
128        label: Some("Compute Pipeline Layout"),
129        bind_group_layouts: &[&bind_group_layout],
130        push_constant_ranges: &[],
131    });
132
133    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
134        label: Some("Compute Pipeline"),
135        layout: Some(&pipeline_layout),
136        module: &cs_module,
137        entry_point: Some("main"),
138        compilation_options: Default::default(),
139        cache: None,
140    })
141}
142
143fn setup_binds(
144    storage_buffers: &[wgpu::Buffer],
145    device: &wgpu::Device,
146) -> (wgpu::BindGroupLayout, wgpu::BindGroup) {
147    let bind_group_entries: Vec<wgpu::BindGroupEntry> = storage_buffers
148        .iter()
149        .enumerate()
150        .map(|(bind_idx, buffer)| wgpu::BindGroupEntry {
151            binding: bind_idx as u32,
152            resource: buffer.as_entire_binding(),
153        })
154        .collect();
155
156    let bind_group_layout_entries: Vec<wgpu::BindGroupLayoutEntry> = (0..storage_buffers.len())
157        .map(|bind_idx| wgpu::BindGroupLayoutEntry {
158            binding: bind_idx as u32,
159            visibility: wgpu::ShaderStages::COMPUTE,
160            ty: wgpu::BindingType::Buffer {
161                ty: wgpu::BufferBindingType::Storage { read_only: false },
162                has_dynamic_offset: false,
163                min_binding_size: None,
164            },
165            count: Some(NonZeroU32::new(1).unwrap()),
166        })
167        .collect();
168
169    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
170        label: Some("Custom Storage Bind Group Layout"),
171        entries: &bind_group_layout_entries,
172    });
173
174    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
175        label: Some("Combined Storage Bind Group"),
176        layout: &bind_group_layout,
177        entries: &bind_group_entries,
178    });
179
180    (bind_group_layout, bind_group)
181}
182
/// Splits `numbers` into slices whose byte size each fits within
/// `max_buffer_size`.
fn calculate_chunks(numbers: &[f32], max_buffer_size: u64) -> Vec<&[f32]> {
    let elements_per_chunk = (max_buffer_size / std::mem::size_of::<f32>() as u64) as usize;
    numbers.chunks(elements_per_chunk).collect()
}
187
188fn create_storage_buffers(device: &wgpu::Device, numbers: &[f32]) -> Vec<wgpu::Buffer> {
189    let chunks = calculate_chunks(numbers, MAX_BUFFER_SIZE);
190
191    chunks
192        .iter()
193        .enumerate()
194        .map(|(e, seg)| {
195            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
196                label: Some(&format!("Storage Buffer-{e}")),
197                contents: bytemuck::cast_slice(seg),
198                usage: wgpu::BufferUsages::STORAGE
199                    | wgpu::BufferUsages::COPY_DST
200                    | wgpu::BufferUsages::COPY_SRC,
201            })
202        })
203        .collect()
204}
205
206fn create_staging_buffers(device: &wgpu::Device, numbers: &[f32]) -> Vec<wgpu::Buffer> {
207    let chunks = calculate_chunks(numbers, MAX_BUFFER_SIZE);
208
209    (0..chunks.len())
210        .map(|e| {
211            let size = std::mem::size_of_val(chunks[e]) as u64;
212
213            device.create_buffer(&wgpu::BufferDescriptor {
214                label: Some(&format!("staging buffer-{e}")),
215                size,
216                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
217                mapped_at_creation: false,
218            })
219        })
220        .collect()
221}
222
223#[cfg_attr(target_arch = "wasm32", allow(clippy::allow_attributes, dead_code))]
224async fn run() {
225    let numbers = {
226        const BYTES_PER_GB: usize = 1024 * 1024 * 1024;
227        // 4 bytes per f32
228        let elements = (BYTES_PER_GB as f32 / 4.0) as usize;
229        vec![0.0; elements]
230    };
231    assert!(numbers.iter().all(|n| *n == 0.0));
232    log::info!("All 0.0s");
233    let t1 = std::time::Instant::now();
234    let results = execute_gpu(&numbers).await;
235    log::info!("GPU RUNTIME: {}ms", t1.elapsed().as_millis());
236    assert_eq!(numbers.len(), results.len());
237    assert!(results.iter().all(|n| *n == 1.0));
238    log::info!("All 1.0s");
239}
240
/// Native entry point: initializes logging and blocks on the async `run`.
/// On wasm32 the body is compiled out, so this is a no-op there.
pub fn main() {
    #[cfg(not(target_arch = "wasm32"))]
    {
        env_logger::init();
        pollster::block_on(run());
    }
}
248
249#[cfg(test)]
250#[cfg(not(target_arch = "wasm32"))]
251pub mod tests;