wgpu_examples/big_compute_buffers/
mod.rs

//! This example shows one possible approach for when your data is too
//! large to fit in a single `Buffer`.
3//!
4//! A lot of things aren't explained here via comments. See hello-compute and
5//! repeated-compute for code that is more thoroughly commented.
6
7use std::num::{NonZeroU32, NonZeroU64};
8use wgpu::{util::DeviceExt, Features};
9
// These are set by the minimum required defaults for webgpu.
// Largest single buffer we allocate (2^27 = 134_217_728 bytes, ~134 MB);
// inputs larger than this are split across multiple buffers.
const MAX_BUFFER_SIZE: u64 = 1 << 27; // 134_217_728 // 134MB
// Maximum workgroup count per dispatch dimension (2^16 - 1 = 65_535).
const MAX_DISPATCH_SIZE: u32 = (1 << 16) - 1;
13
14pub async fn execute_gpu(numbers: &[f32]) -> Vec<f32> {
15    let instance = wgpu::Instance::default();
16
17    let adapter = instance
18        .request_adapter(&wgpu::RequestAdapterOptions::default())
19        .await
20        .unwrap();
21
22    let (device, queue) = adapter
23        .request_device(&wgpu::DeviceDescriptor {
24            label: None,
25            // These features are required to use `binding_array` in your wgsl.
26            // Without them your shader may fail to compile.
27            required_features: Features::BUFFER_BINDING_ARRAY
28                | Features::STORAGE_RESOURCE_BINDING_ARRAY
29                | Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
30            memory_hints: wgpu::MemoryHints::Performance,
31            required_limits: wgpu::Limits {
32                max_buffer_size: MAX_BUFFER_SIZE,
33                max_binding_array_elements_per_shader_stage: 8,
34                ..Default::default()
35            },
36            ..Default::default()
37        })
38        .await
39        .unwrap();
40
41    execute_gpu_inner(&device, &queue, numbers).await
42}
43
/// Dispatches one compute pass over `numbers` (already split across
/// multiple storage buffers by `setup`) and reads every result back
/// into a single `Vec<f32>`.
pub async fn execute_gpu_inner(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    numbers: &[f32],
) -> Vec<f32> {
    let (staging_buffers, storage_buffers, bind_group, compute_pipeline) = setup(device, numbers);

    let mut encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        // Scoped so `cpass` is dropped (ending the pass) before the
        // encoder records the buffer-to-buffer copies below.
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("compute pass descriptor"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&compute_pipeline);
        cpass.set_bind_group(0, Some(&bind_group), &[]);

        // Clamp the workgroup count to the device limit. NOTE(review):
        // this assumes shader.wgsl strides over the data when
        // numbers.len() exceeds the dispatch count — confirm in the shader.
        cpass.dispatch_workgroups(MAX_DISPATCH_SIZE.min(numbers.len() as u32), 1, 1);
    }

    // Copy each storage buffer into its matching MAP_READ staging buffer
    // so results can be mapped for CPU access.
    for (storage_buffer, staging_buffer) in storage_buffers.iter().zip(staging_buffers.iter()) {
        let stg_size = staging_buffer.size();

        encoder.copy_buffer_to_buffer(
            storage_buffer, // Source buffer
            0,
            staging_buffer, // Destination buffer
            0,
            stg_size,
        );
    }

    queue.submit(Some(encoder.finish()));

    // Queue a map request for every staging buffer; the requests resolve
    // once the poll below drives the submitted work to completion.
    for staging_buffer in &staging_buffers {
        let slice = staging_buffer.slice(..);
        slice.map_async(wgpu::MapMode::Read, |_| {});
    }

    // Block until the GPU finishes and the map requests complete.
    device.poll(wgpu::PollType::wait_indefinitely()).unwrap();

    // Concatenate the mapped bytes from each staging buffer, in order.
    let mut data = Vec::new();
    for staging_buffer in &staging_buffers {
        let slice = staging_buffer.slice(..);
        let mapped = slice.get_mapped_range();
        data.extend_from_slice(bytemuck::cast_slice(&mapped));
        // The mapped view must be dropped before unmap().
        drop(mapped);
        staging_buffer.unmap();
    }

    data
}
96
97fn setup(
98    device: &wgpu::Device,
99    numbers: &[f32],
100) -> (
101    Vec<wgpu::Buffer>,
102    Vec<wgpu::Buffer>,
103    wgpu::BindGroup,
104    wgpu::ComputePipeline,
105) {
106    let cs_module = device.create_shader_module(wgpu::include_wgsl!("shader.wgsl"));
107
108    let staging_buffers = create_staging_buffers(device, numbers);
109    let storage_buffers = create_storage_buffers(device, numbers);
110
111    let (bind_group_layout, bind_group) = setup_binds(&storage_buffers, device);
112
113    let compute_pipeline = setup_pipeline(device, bind_group_layout, cs_module);
114    (
115        staging_buffers,
116        storage_buffers,
117        bind_group,
118        compute_pipeline,
119    )
120}
121
122fn setup_pipeline(
123    device: &wgpu::Device,
124    bind_group_layout: wgpu::BindGroupLayout,
125    cs_module: wgpu::ShaderModule,
126) -> wgpu::ComputePipeline {
127    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
128        label: Some("Compute Pipeline Layout"),
129        bind_group_layouts: &[Some(&bind_group_layout)],
130        immediate_size: 0,
131    });
132
133    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
134        label: Some("Compute Pipeline"),
135        layout: Some(&pipeline_layout),
136        module: &cs_module,
137        entry_point: Some("main"),
138        compilation_options: Default::default(),
139        cache: None,
140    })
141}
142
143fn setup_binds(
144    storage_buffers: &[wgpu::Buffer],
145    device: &wgpu::Device,
146) -> (wgpu::BindGroupLayout, wgpu::BindGroup) {
147    let buffers: Vec<_> = storage_buffers
148        .iter()
149        .map(|b| b.as_entire_buffer_binding())
150        .collect();
151
152    let entry = wgpu::BindGroupEntry {
153        binding: 0,
154        resource: wgpu::BindingResource::BufferArray(&buffers),
155    };
156
157    let bgl_entry = wgpu::BindGroupLayoutEntry {
158        binding: 0,
159        visibility: wgpu::ShaderStages::COMPUTE,
160        ty: wgpu::BindingType::Buffer {
161            ty: wgpu::BufferBindingType::Storage { read_only: false },
162            has_dynamic_offset: false,
163            min_binding_size: Some(NonZeroU64::new(4).unwrap()),
164        },
165        count: Some(NonZeroU32::new(buffers.len() as u32).unwrap()),
166    };
167
168    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
169        label: Some("Custom Storage Bind Group Layout"),
170        entries: &[bgl_entry],
171    });
172
173    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
174        label: Some("Combined Storage Bind Group"),
175        layout: &bind_group_layout,
176        entries: &[entry],
177    });
178
179    (bind_group_layout, bind_group)
180}
181
/// Splits `numbers` into slices that each fit within `max_buffer_size` bytes.
///
/// Each chunk holds at most `max_buffer_size / size_of::<f32>()` elements,
/// clamped to at least one element so a pathologically small
/// `max_buffer_size` (< 4 bytes) cannot trigger the `chunks(0)` panic.
/// The final chunk may be shorter; an empty input yields no chunks.
fn calculate_chunks(numbers: &[f32], max_buffer_size: u64) -> Vec<&[f32]> {
    let max_elements_per_chunk = (max_buffer_size as usize / std::mem::size_of::<f32>()).max(1);
    numbers.chunks(max_elements_per_chunk).collect()
}
186
187fn create_storage_buffers(device: &wgpu::Device, numbers: &[f32]) -> Vec<wgpu::Buffer> {
188    let chunks = calculate_chunks(numbers, MAX_BUFFER_SIZE);
189
190    chunks
191        .iter()
192        .enumerate()
193        .map(|(e, seg)| {
194            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
195                label: Some(&format!("Storage Buffer-{e}")),
196                contents: bytemuck::cast_slice(seg),
197                usage: wgpu::BufferUsages::STORAGE
198                    | wgpu::BufferUsages::COPY_DST
199                    | wgpu::BufferUsages::COPY_SRC,
200            })
201        })
202        .collect()
203}
204
205fn create_staging_buffers(device: &wgpu::Device, numbers: &[f32]) -> Vec<wgpu::Buffer> {
206    let chunks = calculate_chunks(numbers, MAX_BUFFER_SIZE);
207
208    (0..chunks.len())
209        .map(|e| {
210            let size = std::mem::size_of_val(chunks[e]) as u64;
211
212            device.create_buffer(&wgpu::BufferDescriptor {
213                label: Some(&format!("staging buffer-{e}")),
214                size,
215                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
216                mapped_at_creation: false,
217            })
218        })
219        .collect()
220}
221
222#[cfg_attr(target_arch = "wasm32", allow(clippy::allow_attributes, dead_code))]
223async fn run() {
224    let numbers = {
225        const BYTES_PER_GB: usize = 1024 * 1024 * 1024;
226        // 4 bytes per f32
227        let elements = (BYTES_PER_GB as f32 / 4.0) as usize;
228        vec![0.0; elements]
229    };
230    assert!(numbers.iter().all(|n| *n == 0.0));
231    log::info!("All 0.0s");
232    let t1 = std::time::Instant::now();
233    let results = execute_gpu(&numbers).await;
234    log::info!("GPU RUNTIME: {}ms", t1.elapsed().as_millis());
235    assert_eq!(numbers.len(), results.len());
236    assert!(results.iter().all(|n| *n == 1.0));
237    log::info!("All 1.0s");
238}
239
/// Native entry point: initializes logging and blocks on the async `run`.
/// The body is compiled out on wasm32 targets (no blocking executor there).
pub fn main() {
    #[cfg(not(target_arch = "wasm32"))]
    {
        env_logger::init();
        pollster::block_on(run());
    }
}
247
248#[cfg(test)]
249#[cfg(not(target_arch = "wasm32"))]
250pub mod tests;