wgpu_core/indirect_validation/
dispatch.rs

1use super::CreateIndirectValidationPipelineError;
2use crate::{
3    device::DeviceError,
4    hal_label,
5    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
6};
7use alloc::{boxed::Box, format, string::ToString as _};
8use core::num::NonZeroU64;
9
/// This machinery requires the following limits:
///
/// - max_bind_groups: 2,
/// - max_dynamic_storage_buffers_per_pipeline_layout: 1,
/// - max_storage_buffers_per_shader_stage: 2,
/// - max_storage_buffer_binding_size: 3 * min_storage_buffer_offset_alignment,
/// - max_immediate_size: 4,
/// - max_compute_invocations_per_workgroup: 1
///
/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also
/// required for this module's functionality to work.
#[derive(Debug)]
pub(crate) struct Dispatch {
    // Compute shader module that validates indirect dispatch arguments.
    module: Box<dyn hal::DynShaderModule>,
    // Layout for group(0): the destination (validated output) buffer.
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    // Layout for group(1): the user-supplied source indirect buffer,
    // bound with a dynamic offset.
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,
    // Holds 6 u32s: the validated x/y/z dispatch dimensions, written twice
    // by the validation shader.
    dst_buffer: Box<dyn hal::DynBuffer>,
    dst_bind_group: Box<dyn hal::DynBindGroup>,
}
31
/// Resources and offsets needed to encode one indirect-dispatch validation pass.
pub struct Params<'a> {
    pub pipeline_layout: &'a dyn hal::DynPipelineLayout,
    pub pipeline: &'a dyn hal::DynComputePipeline,
    /// Buffer that receives the validated dispatch parameters.
    pub dst_buffer: &'a dyn hal::DynBuffer,
    pub dst_bind_group: &'a dyn hal::DynBindGroup,
    /// The caller's offset rounded down to `min_storage_buffer_offset_alignment`
    /// (and clamped so the source binding fits within the buffer).
    pub aligned_offset: u64,
    /// `offset - aligned_offset`; passed to the shader via immediate data.
    pub offset_remainder: u64,
}
40
41impl Dispatch {
    /// Creates all GPU resources needed to validate indirect dispatches:
    /// the validation shader, both bind group layouts, the pipeline layout,
    /// the compute pipeline, and the destination buffer + bind group that
    /// receive the validated dispatch parameters.
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        instance_flags: wgt::InstanceFlags,
        limits: &wgt::Limits,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let max_compute_workgroups_per_dimension = limits.max_compute_workgroups_per_dimension;

        // The shader reads 3 u32s (x/y/z workgroup counts) from `src` at an
        // offset supplied via immediate data, then writes 6 u32s to `dst`:
        // all zeros if any dimension exceeds
        // `max_compute_workgroups_per_dimension`, otherwise x/y/z twice.
        let src = format!(
            "
            @group(0) @binding(0)
            var<storage, read_write> dst: array<u32, 6>;
            @group(1) @binding(0)
            var<storage, read> src: array<u32>;
            struct OffsetPc {{
                inner: u32,
            }}
            var<immediate> offset: OffsetPc;

            @compute @workgroup_size(1)
            fn main() {{
                let src = vec3(src[offset.inner], src[offset.inner + 1], src[offset.inner + 2]);
                let max_compute_workgroups_per_dimension = {max_compute_workgroups_per_dimension}u;
                if (
                    src.x > max_compute_workgroups_per_dimension ||
                    src.y > max_compute_workgroups_per_dimension ||
                    src.z > max_compute_workgroups_per_dimension
                ) {{
                    dst = array(0u, 0u, 0u, 0u, 0u, 0u);
                }} else {{
                    dst = array(src.x, src.y, src.z, src.x, src.y, src.z);
                }}
            }}
        "
        );

        // 3 u32s: one indirect dispatch's x/y/z workgroup counts. The value
        // passed to `new` is not zero, so the const `unwrap` cannot panic
        // (it is evaluated at compile time).
        const SRC_BUFFER_SIZE: NonZeroU64 = NonZeroU64::new(size_of::<u32>() as u64 * 3).unwrap();

        // 6 u32s: the shader writes the validated x/y/z twice. The value
        // passed to `new` is not zero, so the const `unwrap` cannot panic
        // (it is evaluated at compile time).
        const DST_BUFFER_SIZE: NonZeroU64 = NonZeroU64::new(SRC_BUFFER_SIZE.get() * 2).unwrap();

        #[cfg(feature = "wgsl")]
        let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
            CreateShaderModuleError::Parsing(naga::error::ShaderError {
                source: src.clone(),
                label: None,
                inner: Box::new(inner),
            })
        })?;
        #[cfg(not(feature = "wgsl"))]
        #[allow(clippy::diverging_sub_expression)]
        let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

        // `IMMEDIATES` is required because the shader declares
        // `var<immediate> offset`.
        let info = crate::device::create_validator(
            wgt::Features::IMMEDIATES,
            wgt::DownlevelFlags::empty(),
            naga::valid::ValidationFlags::all(),
        )
        .validate(&module)
        .map_err(|inner| {
            CreateShaderModuleError::Validation(naga::error::ShaderError {
                source: src,
                label: None,
                inner: Box::new(inner),
            })
        })?;
        let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
            module: alloc::borrow::Cow::Owned(module),
            info,
            debug_source: None,
        });
        let hal_desc = hal::ShaderModuleDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect dispatch validation shader module"),
                instance_flags,
            ),
            // Internal, already-validated shader: skip wgpu's runtime checks.
            runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
        };
        let module =
            unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(|error| {
                match error {
                    hal::ShaderError::Device(error) => {
                        CreateShaderModuleError::Device(DeviceError::from_hal(error))
                    }
                    hal::ShaderError::Compilation(ref msg) => {
                        log::error!("Shader error: {msg}");
                        CreateShaderModuleError::Generation
                    }
                }
            })?;

        // group(0): the fixed destination buffer the shader writes into.
        let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect dispatch validation destination bind group layout"),
                instance_flags,
            ),
            flags: hal::BindGroupLayoutFlags::empty(),
            entries: &[wgt::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgt::ShaderStages::COMPUTE,
                ty: wgt::BindingType::Buffer {
                    ty: wgt::BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: Some(DST_BUFFER_SIZE),
                },
                count: None,
            }],
        };
        let dst_bind_group_layout = unsafe {
            device
                .create_bind_group_layout(&dst_bind_group_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        // group(1): the user's indirect buffer, rebound per dispatch via a
        // dynamic offset (hence `has_dynamic_offset: true`).
        let src_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect dispatch validation source bind group layout"),
                instance_flags,
            ),
            flags: hal::BindGroupLayoutFlags::empty(),
            entries: &[wgt::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgt::ShaderStages::COMPUTE,
                ty: wgt::BindingType::Buffer {
                    ty: wgt::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: true,
                    min_binding_size: Some(SRC_BUFFER_SIZE),
                },
                count: None,
            }],
        };
        let src_bind_group_layout = unsafe {
            device
                .create_bind_group_layout(&src_bind_group_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect dispatch validation pipeline layout"),
                instance_flags,
            ),
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                Some(dst_bind_group_layout.as_ref()),
                Some(src_bind_group_layout.as_ref()),
            ],
            // One u32 of immediate data: `OffsetPc.inner`.
            immediate_size: 4,
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let pipeline_desc = hal::ComputePipelineDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect dispatch validation pipeline"),
                instance_flags,
            ),
            layout: pipeline_layout.as_ref(),
            stage: hal::ProgrammableStage {
                module: module.as_ref(),
                entry_point: "main",
                constants: &Default::default(),
                zero_initialize_workgroup_memory: false,
            },
            cache: None,
        };
        let pipeline =
            unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
                hal::PipelineError::Device(error) => {
                    CreateComputePipelineError::Device(DeviceError::from_hal(error))
                }
                hal::PipelineError::Linkage(_stages, msg) => {
                    CreateComputePipelineError::Internal(msg)
                }
                hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                    crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
                ),
                hal::PipelineError::PipelineConstants(_, error) => {
                    CreateComputePipelineError::PipelineConstants(error)
                }
            })?;

        // Written by the validation shader, then consumed as the indirect
        // argument buffer of the real dispatch.
        let dst_buffer_desc = hal::BufferDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect dispatch validation destination buffer"),
                instance_flags,
            ),
            size: DST_BUFFER_SIZE.get(),
            usage: wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE,
            memory_flags: hal::MemoryFlags::empty(),
        };
        let dst_buffer =
            unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?;

        let dst_bind_group_desc = hal::BindGroupDescriptor {
            label: hal_label(
                Some("(wgpu internal) Indirect dispatch validation destination bind group"),
                instance_flags,
            ),
            layout: dst_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            // SAFETY: We just created the buffer with this size.
            buffers: &[hal::BufferBinding::new_unchecked(
                dst_buffer.as_ref(),
                0,
                Some(DST_BUFFER_SIZE),
            )],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
            external_textures: &[],
        };
        let dst_bind_group = unsafe {
            device
                .create_bind_group(&dst_bind_group_desc)
                .map_err(DeviceError::from_hal)
        }?;

        Ok(Self {
            module,
            dst_bind_group_layout,
            src_bind_group_layout,
            pipeline_layout,
            pipeline,
            dst_buffer,
            dst_bind_group,
        })
    }
277
278    /// `Ok(None)` will only be returned if `buffer_size` is `0`.
279    pub(super) fn create_src_bind_group(
280        &self,
281        device: &dyn hal::DynDevice,
282        limits: &wgt::Limits,
283        buffer_size: u64,
284        buffer: &dyn hal::DynBuffer,
285        instance_flags: wgt::InstanceFlags,
286    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
287        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
288        let Some(binding_size) = NonZeroU64::new(binding_size) else {
289            return Ok(None);
290        };
291        let hal_desc = hal::BindGroupDescriptor {
292            label: hal_label(
293                Some("(wgpu internal) Indirect dispatch validation source bind group"),
294                instance_flags,
295            ),
296            layout: self.src_bind_group_layout.as_ref(),
297            entries: &[hal::BindGroupEntry {
298                binding: 0,
299                resource_index: 0,
300                count: 1,
301            }],
302            // SAFETY: We calculated the binding size to fit within the buffer.
303            buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
304            samplers: &[],
305            textures: &[],
306            acceleration_structures: &[],
307            external_textures: &[],
308        };
309        unsafe {
310            device
311                .create_bind_group(&hal_desc)
312                .map(Some)
313                .map_err(DeviceError::from_hal)
314        }
315    }
316
317    pub fn params<'a>(&'a self, limits: &wgt::Limits, offset: u64, buffer_size: u64) -> Params<'a> {
318        // The offset we receive is only required to be aligned to 4 bytes.
319        //
320        // Binding offsets and dynamic offsets are required to be aligned to
321        // min_storage_buffer_offset_alignment (256 bytes by default).
322        //
323        // So, we work around this limitation by calculating an aligned offset
324        // and pass the remainder through a immediate data.
325        //
326        // We could bind the whole buffer and only have to pass the offset
327        // through a immediate data but we might run into the
328        // max_storage_buffer_binding_size limit.
329        //
330        // See the inner docs of `calculate_src_buffer_binding_size` to
331        // see how we get the appropriate `binding_size`.
332        let alignment = limits.min_storage_buffer_offset_alignment as u64;
333        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
334        let aligned_offset = offset - offset % alignment;
335        // This works because `binding_size` is either `buffer_size` or `alignment * 2 + buffer_size % alignment`.
336        let max_aligned_offset = buffer_size - binding_size;
337        let aligned_offset = aligned_offset.min(max_aligned_offset);
338        let offset_remainder = offset - aligned_offset;
339
340        Params {
341            pipeline_layout: self.pipeline_layout.as_ref(),
342            pipeline: self.pipeline.as_ref(),
343            dst_buffer: self.dst_buffer.as_ref(),
344            dst_bind_group: self.dst_bind_group.as_ref(),
345            aligned_offset,
346            offset_remainder,
347        }
348    }
349
    /// Destroys every owned GPU resource, in reverse order of creation.
    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        // Exhaustive destructuring: adding a field to `Dispatch` without
        // destroying it here becomes a compile error.
        let Dispatch {
            module,
            dst_bind_group_layout,
            src_bind_group_layout,
            pipeline_layout,
            pipeline,
            dst_buffer,
            dst_bind_group,
        } = self;

        unsafe {
            device.destroy_bind_group(dst_bind_group);
            device.destroy_buffer(dst_buffer);
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
371}
372
373fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &wgt::Limits) -> u64 {
374    let alignment = limits.min_storage_buffer_offset_alignment as u64;
375
376    // We need to choose a binding size that can address all possible sets of 12 contiguous bytes in the buffer taking
377    // into account that the dynamic offset needs to be a multiple of `min_storage_buffer_offset_alignment`.
378
379    // Given the know variables: `offset`, `buffer_size`, `alignment` and the rule `offset + 12 <= buffer_size`.
380
381    // Let `chunks = floor(buffer_size / alignment)`.
382    // Let `chunk` be the interval `[0, chunks]`.
383    // Let `offset = alignment * chunk + r` where `r` is the interval [0, alignment - 4].
384    // Let `binding` be the interval `[offset, offset + 12]`.
385    // Let `aligned_offset = alignment * chunk`.
386    // Let `aligned_binding` be the interval `[aligned_offset, aligned_offset + r + 12]`.
387    // Let `aligned_binding_size = r + 12 = [12, alignment + 8]`.
388    // Let `min_aligned_binding_size = alignment + 8`.
389
390    // `min_aligned_binding_size` is the minimum binding size required to address all 12 contiguous bytes in the buffer
391    // but the last aligned_offset + min_aligned_binding_size might overflow the buffer. In order to avoid this we must
392    // pick a larger `binding_size` that satisfies: `last_aligned_offset + binding_size = buffer_size` and
393    // `binding_size >= min_aligned_binding_size`.
394
395    // Let `buffer_size = alignment * chunks + sr` where `sr` is the interval [0, alignment - 4].
396    // Let `last_aligned_offset = alignment * (chunks - u)` where `u` is the interval [0, chunks].
397    // => `binding_size = buffer_size - last_aligned_offset`
398    // => `binding_size = alignment * chunks + sr - alignment * (chunks - u)`
399    // => `binding_size = alignment * chunks + sr - alignment * chunks + alignment * u`
400    // => `binding_size = sr + alignment * u`
401    // => `min_aligned_binding_size <= sr + alignment * u`
402    // => `alignment + 8 <= sr + alignment * u`
403    // => `u` must be at least 2
404    // => `binding_size = sr + alignment * 2`
405
406    let binding_size = 2 * alignment + (buffer_size % alignment);
407    binding_size.min(buffer_size)
408}