wgpu_examples/big_compute_buffers/mod.rs
use std::num::NonZeroU32;

use wgpu::{util::DeviceExt, Features};
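
// A single buffer may be at most `MAX_BUFFER_SIZE` bytes (128 MiB here), so
// larger inputs are split across several buffers; `MAX_DISPATCH_SIZE` is the
// largest workgroup count allowed along one dispatch axis.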
const MAX_BUFFER_SIZE: u64 = 1 << 27;
const MAX_DISPATCH_SIZE: u32 = (1 << 16) - 1;
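
/// Acquires an adapter and a device with the binding-array features and the
/// raised buffer limits this example needs, then runs the compute workload.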
pub async fn execute_gpu(numbers: &[f32]) -> Vec<f32> {
    let instance = wgpu::Instance::default();

    let adapter = instance
        .request_adapter(&wgpu::RequestAdapterOptions::default())
        .await
        .unwrap();

    let (device, queue) = adapter
        .request_device(&wgpu::DeviceDescriptor {
            label: None,
            // These features let the shader bind several storage buffers
            // through a binding array in a single bind group.
            required_features: Features::BUFFER_BINDING_ARRAY
                | Features::STORAGE_RESOURCE_BINDING_ARRAY
                | Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
            memory_hints: wgpu::MemoryHints::Performance,
            required_limits: wgpu::Limits {
                max_buffer_size: MAX_BUFFER_SIZE,
                // Allow up to 8 elements in the binding array.
                max_binding_array_elements_per_shader_stage: 8,
                ..Default::default()
            },
            ..Default::default()
        })
        .await
        .unwrap();

    execute_gpu_inner(&device, &queue, numbers).await
}
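
/// Encodes the compute pass, copies each storage buffer into its staging
/// counterpart, and reads the results back into a single `Vec<f32>`.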
pub async fn execute_gpu_inner(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    numbers: &[f32],
) -> Vec<f32> {
    let (staging_buffers, storage_buffers, bind_group, compute_pipeline) = setup(device, numbers);

    let mut encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("compute pass descriptor"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&compute_pipeline);
        cpass.set_bind_group(0, Some(&bind_group), &[]);

        // Clamp the workgroup count to the per-axis dispatch limit.
        cpass.dispatch_workgroups(MAX_DISPATCH_SIZE.min(numbers.len() as u32), 1, 1);
    }

    // Copy each storage buffer into its same-sized staging buffer so the
    // results can be mapped on the CPU.
    for (storage_buffer, staging_buffer) in storage_buffers.iter().zip(staging_buffers.iter()) {
        let stg_size = staging_buffer.size();

        encoder.copy_buffer_to_buffer(storage_buffer, 0, staging_buffer, 0, stg_size);
    }

    queue.submit(Some(encoder.finish()));

    for staging_buffer in &staging_buffers {
        let slice = staging_buffer.slice(..);
        slice.map_async(wgpu::MapMode::Read, |_| {});
    }

    // Block until the submitted work, and with it the map requests, completes.
    device.poll(wgpu::PollType::Wait).unwrap();

    // Stitch the mapped chunks back together into one Vec<f32>.
    let mut data = Vec::new();
    for staging_buffer in &staging_buffers {
        let slice = staging_buffer.slice(..);
        let mapped = slice.get_mapped_range();
        data.extend_from_slice(bytemuck::cast_slice(&mapped));
        drop(mapped);
        staging_buffer.unmap();
    }

    data
}
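
/// Builds everything the pass needs: the shader module from `shader.wgsl`,
/// the per-chunk storage and staging buffers, the bind group, and the
/// compute pipeline.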
fn setup(
    device: &wgpu::Device,
    numbers: &[f32],
) -> (
    Vec<wgpu::Buffer>,
    Vec<wgpu::Buffer>,
    wgpu::BindGroup,
    wgpu::ComputePipeline,
) {
    let cs_module = device.create_shader_module(wgpu::include_wgsl!("shader.wgsl"));

    let staging_buffers = create_staging_buffers(device, numbers);
    let storage_buffers = create_storage_buffers(device, numbers);

    let (bind_group_layout, bind_group) = setup_binds(&storage_buffers, device);

    let compute_pipeline = setup_pipeline(device, bind_group_layout, cs_module);
    (
        staging_buffers,
        storage_buffers,
        bind_group,
        compute_pipeline,
    )
}
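
/// Wraps the bind group layout in a pipeline layout and compiles the compute
/// pipeline around the shader's `main` entry point.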
fn setup_pipeline(
    device: &wgpu::Device,
    bind_group_layout: wgpu::BindGroupLayout,
    cs_module: wgpu::ShaderModule,
) -> wgpu::ComputePipeline {
    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("Compute Pipeline Layout"),
        bind_group_layouts: &[&bind_group_layout],
        push_constant_ranges: &[],
    });

    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("Compute Pipeline"),
        layout: Some(&pipeline_layout),
        module: &cs_module,
        entry_point: Some("main"),
        compilation_options: Default::default(),
        cache: None,
    })
}
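
/// Creates one bind group slot per storage buffer. Each layout entry uses
/// `count: Some(1)`, i.e. a single-element binding array, which is what the
/// binding-array features requested above enable.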
fn setup_binds(
    storage_buffers: &[wgpu::Buffer],
    device: &wgpu::Device,
) -> (wgpu::BindGroupLayout, wgpu::BindGroup) {
    let bind_group_entries: Vec<wgpu::BindGroupEntry> = storage_buffers
        .iter()
        .enumerate()
        .map(|(bind_idx, buffer)| wgpu::BindGroupEntry {
            binding: bind_idx as u32,
            resource: buffer.as_entire_binding(),
        })
        .collect();

    let bind_group_layout_entries: Vec<wgpu::BindGroupLayoutEntry> = (0..storage_buffers.len())
        .map(|bind_idx| wgpu::BindGroupLayoutEntry {
            binding: bind_idx as u32,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            // A binding array of length 1: each chunk lives in its own slot.
            count: Some(NonZeroU32::new(1).unwrap()),
        })
        .collect();

    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some("Custom Storage Bind Group Layout"),
        entries: &bind_group_layout_entries,
    });

    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("Combined Storage Bind Group"),
        layout: &bind_group_layout,
        entries: &bind_group_entries,
    });

    (bind_group_layout, bind_group)
}
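
/// Splits the input into slices that each fit within `max_buffer_size` bytes.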
fn calculate_chunks(numbers: &[f32], max_buffer_size: u64) -> Vec<&[f32]> {
    let max_elements_per_chunk = max_buffer_size as usize / std::mem::size_of::<f32>();
    numbers.chunks(max_elements_per_chunk).collect()
}
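
/// Uploads each chunk of the input into its own GPU storage buffer.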
fn create_storage_buffers(device: &wgpu::Device, numbers: &[f32]) -> Vec<wgpu::Buffer> {
    let chunks = calculate_chunks(numbers, MAX_BUFFER_SIZE);

    chunks
        .iter()
        .enumerate()
        .map(|(e, seg)| {
            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some(&format!("Storage Buffer-{e}")),
                contents: bytemuck::cast_slice(seg),
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_DST
                    | wgpu::BufferUsages::COPY_SRC,
            })
        })
        .collect()
}
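
/// Creates a mappable (`MAP_READ`) readback buffer sized to match each chunk.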
fn create_staging_buffers(device: &wgpu::Device, numbers: &[f32]) -> Vec<wgpu::Buffer> {
    let chunks = calculate_chunks(numbers, MAX_BUFFER_SIZE);

    chunks
        .iter()
        .enumerate()
        .map(|(e, chunk)| {
            let size = std::mem::size_of_val(*chunk) as u64;

            device.create_buffer(&wgpu::BufferDescriptor {
                label: Some(&format!("staging buffer-{e}")),
                size,
                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            })
        })
        .collect()
}
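
/// Pushes 1 GiB of `0.0_f32`s through the pipeline and asserts that every
/// element comes back as `1.0`; the shader in `shader.wgsl` is expected to
/// add `1.0` to each value, judging by these assertions.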
#[cfg_attr(target_arch = "wasm32", allow(clippy::allow_attributes, dead_code))]
async fn run() {
    let numbers = {
        const BYTES_PER_GB: usize = 1024 * 1024 * 1024;
        // 1 GiB worth of f32s: BYTES_PER_GB / 4 elements.
        let elements = (BYTES_PER_GB as f32 / 4.0) as usize;
        vec![0.0; elements]
    };
    assert!(numbers.iter().all(|n| *n == 0.0));
    log::info!("All 0.0s");
    let t1 = std::time::Instant::now();
    let results = execute_gpu(&numbers).await;
    log::info!("GPU RUNTIME: {}ms", t1.elapsed().as_millis());
    assert_eq!(numbers.len(), results.len());
    assert!(results.iter().all(|n| *n == 1.0));
    log::info!("All 1.0s");
}
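
/// Native entry point; the body is compiled out on wasm32.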
pub fn main() {
    #[cfg(not(target_arch = "wasm32"))]
    {
        env_logger::init();
        pollster::block_on(run());
    }
}

#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
pub mod tests;