// wgpu_examples/big_compute_buffers/mod.rs
use std::num::{NonZeroU32, NonZeroU64};

use wgpu::{util::DeviceExt, Features};
/// Upper bound, in bytes, for a single GPU buffer allocation (128 MiB).
/// Input larger than this is split across multiple storage buffers.
const MAX_BUFFER_SIZE: u64 = 1 << 27;

/// Cap on the workgroup count passed per dispatch dimension (2^16 - 1).
const MAX_DISPATCH_SIZE: u32 = (1 << 16) - 1;
14pub async fn execute_gpu(numbers: &[f32]) -> Vec<f32> {
15 let instance = wgpu::Instance::default();
16
17 let adapter = instance
18 .request_adapter(&wgpu::RequestAdapterOptions::default())
19 .await
20 .unwrap();
21
22 let (device, queue) = adapter
23 .request_device(&wgpu::DeviceDescriptor {
24 label: None,
25 required_features: Features::BUFFER_BINDING_ARRAY
28 | Features::STORAGE_RESOURCE_BINDING_ARRAY
29 | Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
30 memory_hints: wgpu::MemoryHints::Performance,
31 required_limits: wgpu::Limits {
32 max_buffer_size: MAX_BUFFER_SIZE,
33 max_binding_array_elements_per_shader_stage: 8,
34 ..Default::default()
35 },
36 ..Default::default()
37 })
38 .await
39 .unwrap();
40
41 execute_gpu_inner(&device, &queue, numbers).await
42}
43
44pub async fn execute_gpu_inner(
45 device: &wgpu::Device,
46 queue: &wgpu::Queue,
47 numbers: &[f32],
48) -> Vec<f32> {
49 let (staging_buffers, storage_buffers, bind_group, compute_pipeline) = setup(device, numbers);
50
51 let mut encoder =
52 device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
53 {
54 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
55 label: Some("compute pass descriptor"),
56 timestamp_writes: None,
57 });
58 cpass.set_pipeline(&compute_pipeline);
59 cpass.set_bind_group(0, Some(&bind_group), &[]);
60
61 cpass.dispatch_workgroups(MAX_DISPATCH_SIZE.min(numbers.len() as u32), 1, 1);
62 }
63
64 for (storage_buffer, staging_buffer) in storage_buffers.iter().zip(staging_buffers.iter()) {
65 let stg_size = staging_buffer.size();
66
67 encoder.copy_buffer_to_buffer(
68 storage_buffer, 0,
70 staging_buffer, 0,
72 stg_size,
73 );
74 }
75
76 queue.submit(Some(encoder.finish()));
77
78 for staging_buffer in &staging_buffers {
79 let slice = staging_buffer.slice(..);
80 slice.map_async(wgpu::MapMode::Read, |_| {});
81 }
82
83 device.poll(wgpu::PollType::wait_indefinitely()).unwrap();
84
85 let mut data = Vec::new();
86 for staging_buffer in &staging_buffers {
87 let slice = staging_buffer.slice(..);
88 let mapped = slice.get_mapped_range();
89 data.extend_from_slice(bytemuck::cast_slice(&mapped));
90 drop(mapped);
91 staging_buffer.unmap();
92 }
93
94 data
95}
96
97fn setup(
98 device: &wgpu::Device,
99 numbers: &[f32],
100) -> (
101 Vec<wgpu::Buffer>,
102 Vec<wgpu::Buffer>,
103 wgpu::BindGroup,
104 wgpu::ComputePipeline,
105) {
106 let cs_module = device.create_shader_module(wgpu::include_wgsl!("shader.wgsl"));
107
108 let staging_buffers = create_staging_buffers(device, numbers);
109 let storage_buffers = create_storage_buffers(device, numbers);
110
111 let (bind_group_layout, bind_group) = setup_binds(&storage_buffers, device);
112
113 let compute_pipeline = setup_pipeline(device, bind_group_layout, cs_module);
114 (
115 staging_buffers,
116 storage_buffers,
117 bind_group,
118 compute_pipeline,
119 )
120}
121
122fn setup_pipeline(
123 device: &wgpu::Device,
124 bind_group_layout: wgpu::BindGroupLayout,
125 cs_module: wgpu::ShaderModule,
126) -> wgpu::ComputePipeline {
127 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
128 label: Some("Compute Pipeline Layout"),
129 bind_group_layouts: &[Some(&bind_group_layout)],
130 immediate_size: 0,
131 });
132
133 device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
134 label: Some("Compute Pipeline"),
135 layout: Some(&pipeline_layout),
136 module: &cs_module,
137 entry_point: Some("main"),
138 compilation_options: Default::default(),
139 cache: None,
140 })
141}
142
143fn setup_binds(
144 storage_buffers: &[wgpu::Buffer],
145 device: &wgpu::Device,
146) -> (wgpu::BindGroupLayout, wgpu::BindGroup) {
147 let buffers: Vec<_> = storage_buffers
148 .iter()
149 .map(|b| b.as_entire_buffer_binding())
150 .collect();
151
152 let entry = wgpu::BindGroupEntry {
153 binding: 0,
154 resource: wgpu::BindingResource::BufferArray(&buffers),
155 };
156
157 let bgl_entry = wgpu::BindGroupLayoutEntry {
158 binding: 0,
159 visibility: wgpu::ShaderStages::COMPUTE,
160 ty: wgpu::BindingType::Buffer {
161 ty: wgpu::BufferBindingType::Storage { read_only: false },
162 has_dynamic_offset: false,
163 min_binding_size: Some(NonZeroU64::new(4).unwrap()),
164 },
165 count: Some(NonZeroU32::new(buffers.len() as u32).unwrap()),
166 };
167
168 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
169 label: Some("Custom Storage Bind Group Layout"),
170 entries: &[bgl_entry],
171 });
172
173 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
174 label: Some("Combined Storage Bind Group"),
175 layout: &bind_group_layout,
176 entries: &[entry],
177 });
178
179 (bind_group_layout, bind_group)
180}
181
/// Splits `numbers` into contiguous sub-slices, each fitting within
/// `max_buffer_size` bytes when stored as `f32`s. The final chunk may be
/// shorter; an empty input yields no chunks.
fn calculate_chunks(numbers: &[f32], max_buffer_size: u64) -> Vec<&[f32]> {
    let max_elements_per_chunk = max_buffer_size as usize / std::mem::size_of::<f32>();
    numbers.chunks(max_elements_per_chunk).collect()
}
186
187fn create_storage_buffers(device: &wgpu::Device, numbers: &[f32]) -> Vec<wgpu::Buffer> {
188 let chunks = calculate_chunks(numbers, MAX_BUFFER_SIZE);
189
190 chunks
191 .iter()
192 .enumerate()
193 .map(|(e, seg)| {
194 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
195 label: Some(&format!("Storage Buffer-{e}")),
196 contents: bytemuck::cast_slice(seg),
197 usage: wgpu::BufferUsages::STORAGE
198 | wgpu::BufferUsages::COPY_DST
199 | wgpu::BufferUsages::COPY_SRC,
200 })
201 })
202 .collect()
203}
204
205fn create_staging_buffers(device: &wgpu::Device, numbers: &[f32]) -> Vec<wgpu::Buffer> {
206 let chunks = calculate_chunks(numbers, MAX_BUFFER_SIZE);
207
208 (0..chunks.len())
209 .map(|e| {
210 let size = std::mem::size_of_val(chunks[e]) as u64;
211
212 device.create_buffer(&wgpu::BufferDescriptor {
213 label: Some(&format!("staging buffer-{e}")),
214 size,
215 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
216 mapped_at_creation: false,
217 })
218 })
219 .collect()
220}
221
222#[cfg_attr(target_arch = "wasm32", allow(clippy::allow_attributes, dead_code))]
223async fn run() {
224 let numbers = {
225 const BYTES_PER_GB: usize = 1024 * 1024 * 1024;
226 let elements = (BYTES_PER_GB as f32 / 4.0) as usize;
228 vec![0.0; elements]
229 };
230 assert!(numbers.iter().all(|n| *n == 0.0));
231 log::info!("All 0.0s");
232 let t1 = std::time::Instant::now();
233 let results = execute_gpu(&numbers).await;
234 log::info!("GPU RUNTIME: {}ms", t1.elapsed().as_millis());
235 assert_eq!(numbers.len(), results.len());
236 assert!(results.iter().all(|n| *n == 1.0));
237 log::info!("All 1.0s");
238}
239
240pub fn main() {
241 #[cfg(not(target_arch = "wasm32"))]
242 {
243 env_logger::init();
244 pollster::block_on(run());
245 }
246}
247
248#[cfg(test)]
249#[cfg(not(target_arch = "wasm32"))]
250pub mod tests;