wgpu_core/indirect_validation/dispatch.rs

use super::CreateIndirectValidationPipelineError;
use crate::{
    device::DeviceError,
    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
};
use alloc::{boxed::Box, format, string::ToString as _};
use core::num::NonZeroU64;

/// This machinery requires the following limits:
///
/// - max_bind_groups: 2,
/// - max_dynamic_storage_buffers_per_pipeline_layout: 1,
/// - max_storage_buffers_per_shader_stage: 2,
/// - max_storage_buffer_binding_size: 3 * min_storage_buffer_offset_alignment,
/// - max_push_constant_size: 4,
/// - max_compute_invocations_per_workgroup: 1
///
/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also
/// required for this module's functionality to work.
#[derive(Debug)]
pub(crate) struct Dispatch {
    module: Box<dyn hal::DynShaderModule>,
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,
    dst_buffer: Box<dyn hal::DynBuffer>,
    dst_bind_group: Box<dyn hal::DynBindGroup>,
}

pub struct Params<'a> {
    pub pipeline_layout: &'a dyn hal::DynPipelineLayout,
    pub pipeline: &'a dyn hal::DynComputePipeline,
    pub dst_buffer: &'a dyn hal::DynBuffer,
    pub dst_bind_group: &'a dyn hal::DynBindGroup,
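    /// `offset` rounded down to a multiple of
    /// `min_storage_buffer_offset_alignment` and clamped so that the binding
    /// stays within the source buffer.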
    pub aligned_offset: u64,
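    /// `offset - aligned_offset`, passed to the shader via a push constant.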
    pub offset_remainder: u64,
}

impl Dispatch {
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        limits: &wgt::Limits,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let max_compute_workgroups_per_dimension = limits.max_compute_workgroups_per_dimension;

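        // Validation shader: reads the three dispatch dimensions from `src` at
        // a dynamic offset and writes them to `dst` twice, zeroing them out if
        // any dimension exceeds `max_compute_workgroups_per_dimension`.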
        let src = format!(
            "
            @group(0) @binding(0)
            var<storage, read_write> dst: array<u32, 6>;
            @group(1) @binding(0)
            var<storage, read> src: array<u32>;
            struct OffsetPc {{
                inner: u32,
            }}
            var<push_constant> offset: OffsetPc;

            @compute @workgroup_size(1)
            fn main() {{
                let src = vec3(src[offset.inner], src[offset.inner + 1], src[offset.inner + 2]);
                let max_compute_workgroups_per_dimension = {max_compute_workgroups_per_dimension}u;
                if (
                    src.x > max_compute_workgroups_per_dimension ||
                    src.y > max_compute_workgroups_per_dimension ||
                    src.z > max_compute_workgroups_per_dimension
                ) {{
                    dst = array(0u, 0u, 0u, 0u, 0u, 0u);
                }} else {{
                    dst = array(src.x, src.y, src.z, src.x, src.y, src.z);
                }}
            }}
        "
        );

        // SAFETY: The value we are passing to `new_unchecked` is not zero, so this is safe.
        const SRC_BUFFER_SIZE: NonZeroU64 =
            unsafe { NonZeroU64::new_unchecked(size_of::<u32>() as u64 * 3) };

        // SAFETY: The value we are passing to `new_unchecked` is not zero, so this is safe.
        const DST_BUFFER_SIZE: NonZeroU64 = unsafe {
            NonZeroU64::new_unchecked(
                SRC_BUFFER_SIZE.get() * 2, // From above: `dst: array<u32, 6>`
            )
        };

        #[cfg(feature = "wgsl")]
        let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
            CreateShaderModuleError::Parsing(naga::error::ShaderError {
                source: src.clone(),
                label: None,
                inner: Box::new(inner),
            })
        })?;
        #[cfg(not(feature = "wgsl"))]
        #[allow(clippy::diverging_sub_expression)]
        let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

        let info = crate::device::create_validator(
            wgt::Features::PUSH_CONSTANTS,
            wgt::DownlevelFlags::empty(),
            naga::valid::ValidationFlags::all(),
        )
        .validate(&module)
        .map_err(|inner| {
            CreateShaderModuleError::Validation(naga::error::ShaderError {
                source: src,
                label: None,
                inner: Box::new(inner),
            })
        })?;
        let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
            module: alloc::borrow::Cow::Owned(module),
            info,
            debug_source: None,
        });
        let hal_desc = hal::ShaderModuleDescriptor {
            label: None,
            runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
        };
        let module =
            unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(|error| {
                match error {
                    hal::ShaderError::Device(error) => {
                        CreateShaderModuleError::Device(DeviceError::from_hal(error))
                    }
                    hal::ShaderError::Compilation(ref msg) => {
                        log::error!("Shader error: {msg}");
                        CreateShaderModuleError::Generation
                    }
                }
            })?;

        let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
            label: None,
            flags: hal::BindGroupLayoutFlags::empty(),
            entries: &[wgt::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgt::ShaderStages::COMPUTE,
                ty: wgt::BindingType::Buffer {
                    ty: wgt::BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: Some(DST_BUFFER_SIZE),
                },
                count: None,
            }],
        };
        let dst_bind_group_layout = unsafe {
            device
                .create_bind_group_layout(&dst_bind_group_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let src_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
            label: None,
            flags: hal::BindGroupLayoutFlags::empty(),
            entries: &[wgt::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgt::ShaderStages::COMPUTE,
                ty: wgt::BindingType::Buffer {
                    ty: wgt::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: true,
                    min_binding_size: Some(SRC_BUFFER_SIZE),
                },
                count: None,
            }],
        };
        let src_bind_group_layout = unsafe {
            device
                .create_bind_group_layout(&src_bind_group_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: None,
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                dst_bind_group_layout.as_ref(),
                src_bind_group_layout.as_ref(),
            ],
            push_constant_ranges: &[wgt::PushConstantRange {
                stages: wgt::ShaderStages::COMPUTE,
                range: 0..4,
            }],
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let pipeline_desc = hal::ComputePipelineDescriptor {
            label: None,
            layout: pipeline_layout.as_ref(),
            stage: hal::ProgrammableStage {
                module: module.as_ref(),
                entry_point: "main",
                constants: &Default::default(),
                zero_initialize_workgroup_memory: false,
            },
            cache: None,
        };
        let pipeline =
            unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
                hal::PipelineError::Device(error) => {
                    CreateComputePipelineError::Device(DeviceError::from_hal(error))
                }
                hal::PipelineError::Linkage(_stages, msg) => {
                    CreateComputePipelineError::Internal(msg)
                }
                hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                    crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
                ),
                hal::PipelineError::PipelineConstants(_, error) => {
                    CreateComputePipelineError::PipelineConstants(error)
                }
            })?;

        let dst_buffer_desc = hal::BufferDescriptor {
            label: None,
            size: DST_BUFFER_SIZE.get(),
            usage: wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE,
            memory_flags: hal::MemoryFlags::empty(),
        };
        let dst_buffer =
            unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?;

        let dst_bind_group_desc = hal::BindGroupDescriptor {
            label: None,
            layout: dst_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            // SAFETY: We just created the buffer with this size.
            buffers: &[hal::BufferBinding::new_unchecked(
                dst_buffer.as_ref(),
                0,
                Some(DST_BUFFER_SIZE),
            )],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
            external_textures: &[],
        };
        let dst_bind_group = unsafe {
            device
                .create_bind_group(&dst_bind_group_desc)
                .map_err(DeviceError::from_hal)
        }?;

        Ok(Self {
            module,
            dst_bind_group_layout,
            src_bind_group_layout,
            pipeline_layout,
            pipeline,
            dst_buffer,
            dst_bind_group,
        })
    }

    /// `Ok(None)` will only be returned if `buffer_size` is `0`.
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &wgt::Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: None,
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            // SAFETY: We calculated the binding size to fit within the buffer.
            buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
            external_textures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }

    pub fn params<'a>(&'a self, limits: &wgt::Limits, offset: u64, buffer_size: u64) -> Params<'a> {
        // The offset we receive is only required to be aligned to 4 bytes.
        //
        // Binding offsets and dynamic offsets are required to be aligned to
        // min_storage_buffer_offset_alignment (256 bytes by default).
        //
        // So, we work around this limitation by calculating an aligned offset
        // and passing the remainder through a push constant.
        //
        // We could bind the whole buffer and only pass the offset through a
        // push constant, but then we might run into the
        // max_storage_buffer_binding_size limit.
        //
        // See the inner docs of `calculate_src_buffer_binding_size` for how
        // we arrive at the appropriate `binding_size`.
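        // A worked example with assumed values (not taken from any caller):
        // alignment = 256, buffer_size = 1024, offset = 516. Then
        // binding_size = min(2 * 256 + 1024 % 256, 1024) = 512,
        // aligned_offset = 516 - 516 % 256 = 512 (equal to
        // max_aligned_offset = 1024 - 512, so no clamping occurs), and
        // offset_remainder = 516 - 512 = 4: the binding covers bytes
        // [512, 1024) and the shader reads its 12 bytes starting 4 bytes in.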
        let alignment = limits.min_storage_buffer_offset_alignment as u64;
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let aligned_offset = offset - offset % alignment;
        // This works because `binding_size` is either `buffer_size` or `alignment * 2 + buffer_size % alignment`.
        let max_aligned_offset = buffer_size - binding_size;
        let aligned_offset = aligned_offset.min(max_aligned_offset);
        let offset_remainder = offset - aligned_offset;

        Params {
            pipeline_layout: self.pipeline_layout.as_ref(),
            pipeline: self.pipeline.as_ref(),
            dst_buffer: self.dst_buffer.as_ref(),
            dst_bind_group: self.dst_bind_group.as_ref(),
            aligned_offset,
            offset_remainder,
        }
    }

    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        let Dispatch {
            module,
            dst_bind_group_layout,
            src_bind_group_layout,
            pipeline_layout,
            pipeline,
            dst_buffer,
            dst_bind_group,
        } = self;

        unsafe {
            device.destroy_bind_group(dst_bind_group);
            device.destroy_buffer(dst_buffer);
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
}

fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &wgt::Limits) -> u64 {
    let alignment = limits.min_storage_buffer_offset_alignment as u64;

    // We need to choose a binding size that can address all possible sets of 12 contiguous bytes
    // in the buffer, taking into account that the dynamic offset needs to be a multiple of
    // `min_storage_buffer_offset_alignment`.

    // Given the known variables: `offset`, `buffer_size`, `alignment` and the rule `offset + 12 <= buffer_size`.

    // Let `chunks = floor(buffer_size / alignment)`.
    // Let `chunk` be the interval `[0, chunks]`.
    // Let `offset = alignment * chunk + r` where `r` is the interval `[0, alignment - 4]`.
    // Let `binding` be the interval `[offset, offset + 12]`.
    // Let `aligned_offset = alignment * chunk`.
    // Let `aligned_binding` be the interval `[aligned_offset, aligned_offset + r + 12]`.
    // Let `aligned_binding_size = r + 12 = [12, alignment + 8]`.
    // Let `min_aligned_binding_size = alignment + 8`.

    // `min_aligned_binding_size` is the minimum binding size required to address all sets of 12
    // contiguous bytes in the buffer, but the last `aligned_offset + min_aligned_binding_size`
    // might overflow the buffer. In order to avoid this we must pick a larger `binding_size`
    // that satisfies: `last_aligned_offset + binding_size = buffer_size` and
    // `binding_size >= min_aligned_binding_size`.

    // Let `buffer_size = alignment * chunks + sr` where `sr` is the interval `[0, alignment - 4]`.
    // Let `last_aligned_offset = alignment * (chunks - u)` where `u` is the interval `[0, chunks]`.
    // => `binding_size = buffer_size - last_aligned_offset`
    // => `binding_size = alignment * chunks + sr - alignment * (chunks - u)`
    // => `binding_size = alignment * chunks + sr - alignment * chunks + alignment * u`
    // => `binding_size = sr + alignment * u`
    // => `min_aligned_binding_size <= sr + alignment * u`
    // => `alignment + 8 <= sr + alignment * u`
    // => `u` must be at least 2
    // => `binding_size = sr + alignment * 2`
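    // A worked example with assumed values: alignment = 256, buffer_size = 1000.
    // Then sr = 1000 % 256 = 232 and binding_size = 2 * 256 + 232 = 744.
    // The largest valid offset is 988 (988 + 12 = 1000); `Dispatch::params`
    // clamps its aligned offset to 1000 - 744 = 256, leaving a remainder of
    // 732, which still fits: 732 + 12 = 744 = binding_size.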

    let binding_size = 2 * alignment + (buffer_size % alignment);
    binding_size.min(buffer_size)
}
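
// A minimal test sketch (illustrative, not from the original module) checking
// the invariants derived above: the binding size never exceeds the buffer, the
// clamped aligned offset stays aligned, and the 12 bytes at any valid
// 4-byte-aligned offset fit inside the binding. It assumes the default limits'
// `min_storage_buffer_offset_alignment` (256).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn binding_size_covers_all_offsets() {
        let limits = wgt::Limits::default();
        let alignment = limits.min_storage_buffer_offset_alignment as u64;
        for buffer_size in [12u64, 100, 256, 1000, 4096] {
            let binding_size = calculate_src_buffer_binding_size(buffer_size, &limits);
            assert!(binding_size <= buffer_size);
            // Mirror the offset computation in `Dispatch::params`.
            let mut offset = 0;
            while offset + 12 <= buffer_size {
                let aligned_offset =
                    (offset - offset % alignment).min(buffer_size - binding_size);
                assert_eq!(aligned_offset % alignment, 0);
                assert!(offset - aligned_offset + 12 <= binding_size);
                offset += 4;
            }
        }
    }
}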