1use super::CreateIndirectValidationPipelineError;
2use crate::{
3 device::DeviceError,
4 hal_label,
5 pipeline::{CreateComputePipelineError, CreateShaderModuleError},
6};
7use alloc::{boxed::Box, format, string::ToString as _};
8use core::num::NonZeroU64;
9
/// HAL objects backing the compute pass that validates indirect dispatch
/// arguments.
///
/// Built by [`Dispatch::new`] and released with [`Dispatch::dispose`], which
/// hands every object back to the device's `destroy_*` methods.
#[derive(Debug)]
pub(crate) struct Dispatch {
    /// Shader module compiled from the generated WGSL validation shader.
    module: Box<dyn hal::DynShaderModule>,
    /// Layout of bind group 0: the read-write destination storage buffer.
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Layout of bind group 1: the read-only source storage buffer, bound with
    /// a dynamic offset.
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Pipeline layout over both bind group layouts plus 4 bytes of immediates
    /// (the source offset).
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    /// Compute pipeline running the validation shader's `main` entry point.
    pipeline: Box<dyn hal::DynComputePipeline>,
    /// Buffer the shader writes the validated arguments into; created with
    /// `INDIRECT | STORAGE_READ_WRITE` usage.
    dst_buffer: Box<dyn hal::DynBuffer>,
    /// Bind group exposing `dst_buffer` through `dst_bind_group_layout`.
    dst_bind_group: Box<dyn hal::DynBindGroup>,
}
31
32pub struct Params<'a> {
33 pub pipeline_layout: &'a dyn hal::DynPipelineLayout,
34 pub pipeline: &'a dyn hal::DynComputePipeline,
35 pub dst_buffer: &'a dyn hal::DynBuffer,
36 pub dst_bind_group: &'a dyn hal::DynBindGroup,
37 pub aligned_offset: u64,
38 pub offset_remainder: u64,
39}
40
41impl Dispatch {
42 pub(super) fn new(
43 device: &dyn hal::DynDevice,
44 instance_flags: wgt::InstanceFlags,
45 limits: &wgt::Limits,
46 ) -> Result<Self, CreateIndirectValidationPipelineError> {
47 let max_compute_workgroups_per_dimension = limits.max_compute_workgroups_per_dimension;
48
49 let src = format!(
50 "
51 @group(0) @binding(0)
52 var<storage, read_write> dst: array<u32, 6>;
53 @group(1) @binding(0)
54 var<storage, read> src: array<u32>;
55 struct OffsetPc {{
56 inner: u32,
57 }}
58 var<immediate> offset: OffsetPc;
59
60 @compute @workgroup_size(1)
61 fn main() {{
62 let src = vec3(src[offset.inner], src[offset.inner + 1], src[offset.inner + 2]);
63 let max_compute_workgroups_per_dimension = {max_compute_workgroups_per_dimension}u;
64 if (
65 src.x > max_compute_workgroups_per_dimension ||
66 src.y > max_compute_workgroups_per_dimension ||
67 src.z > max_compute_workgroups_per_dimension
68 ) {{
69 dst = array(0u, 0u, 0u, 0u, 0u, 0u);
70 }} else {{
71 dst = array(src.x, src.y, src.z, src.x, src.y, src.z);
72 }}
73 }}
74 "
75 );
76
77 const SRC_BUFFER_SIZE: NonZeroU64 = NonZeroU64::new(size_of::<u32>() as u64 * 3).unwrap();
79
80 const DST_BUFFER_SIZE: NonZeroU64 = NonZeroU64::new(SRC_BUFFER_SIZE.get() * 2).unwrap();
82
83 #[cfg(feature = "wgsl")]
84 let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
85 CreateShaderModuleError::Parsing(naga::error::ShaderError {
86 source: src.clone(),
87 label: None,
88 inner: Box::new(inner),
89 })
90 })?;
91 #[cfg(not(feature = "wgsl"))]
92 #[allow(clippy::diverging_sub_expression)]
93 let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");
94
95 let info = crate::device::create_validator(
96 wgt::Features::IMMEDIATES,
97 wgt::DownlevelFlags::empty(),
98 naga::valid::ValidationFlags::all(),
99 )
100 .validate(&module)
101 .map_err(|inner| {
102 CreateShaderModuleError::Validation(naga::error::ShaderError {
103 source: src,
104 label: None,
105 inner: Box::new(inner),
106 })
107 })?;
108 let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
109 module: alloc::borrow::Cow::Owned(module),
110 info,
111 debug_source: None,
112 });
113 let hal_desc = hal::ShaderModuleDescriptor {
114 label: hal_label(
115 Some("(wgpu internal) Indirect dispatch validation shader module"),
116 instance_flags,
117 ),
118 runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
119 };
120 let module =
121 unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(|error| {
122 match error {
123 hal::ShaderError::Device(error) => {
124 CreateShaderModuleError::Device(DeviceError::from_hal(error))
125 }
126 hal::ShaderError::Compilation(ref msg) => {
127 log::error!("Shader error: {msg}");
128 CreateShaderModuleError::Generation
129 }
130 }
131 })?;
132
133 let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
134 label: hal_label(
135 Some("(wgpu internal) Indirect dispatch validation destination bind group layout"),
136 instance_flags,
137 ),
138 flags: hal::BindGroupLayoutFlags::empty(),
139 entries: &[wgt::BindGroupLayoutEntry {
140 binding: 0,
141 visibility: wgt::ShaderStages::COMPUTE,
142 ty: wgt::BindingType::Buffer {
143 ty: wgt::BufferBindingType::Storage { read_only: false },
144 has_dynamic_offset: false,
145 min_binding_size: Some(DST_BUFFER_SIZE),
146 },
147 count: None,
148 }],
149 };
150 let dst_bind_group_layout = unsafe {
151 device
152 .create_bind_group_layout(&dst_bind_group_layout_desc)
153 .map_err(DeviceError::from_hal)?
154 };
155
156 let src_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
157 label: hal_label(
158 Some("(wgpu internal) Indirect dispatch validation source bind group layout"),
159 instance_flags,
160 ),
161 flags: hal::BindGroupLayoutFlags::empty(),
162 entries: &[wgt::BindGroupLayoutEntry {
163 binding: 0,
164 visibility: wgt::ShaderStages::COMPUTE,
165 ty: wgt::BindingType::Buffer {
166 ty: wgt::BufferBindingType::Storage { read_only: true },
167 has_dynamic_offset: true,
168 min_binding_size: Some(SRC_BUFFER_SIZE),
169 },
170 count: None,
171 }],
172 };
173 let src_bind_group_layout = unsafe {
174 device
175 .create_bind_group_layout(&src_bind_group_layout_desc)
176 .map_err(DeviceError::from_hal)?
177 };
178
179 let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
180 label: hal_label(
181 Some("(wgpu internal) Indirect dispatch validation pipeline layout"),
182 instance_flags,
183 ),
184 flags: hal::PipelineLayoutFlags::empty(),
185 bind_group_layouts: &[
186 Some(dst_bind_group_layout.as_ref()),
187 Some(src_bind_group_layout.as_ref()),
188 ],
189 immediate_size: 4,
190 };
191 let pipeline_layout = unsafe {
192 device
193 .create_pipeline_layout(&pipeline_layout_desc)
194 .map_err(DeviceError::from_hal)?
195 };
196
197 let pipeline_desc = hal::ComputePipelineDescriptor {
198 label: hal_label(
199 Some("(wgpu internal) Indirect dispatch validation pipeline"),
200 instance_flags,
201 ),
202 layout: pipeline_layout.as_ref(),
203 stage: hal::ProgrammableStage {
204 module: module.as_ref(),
205 entry_point: "main",
206 constants: &Default::default(),
207 zero_initialize_workgroup_memory: false,
208 },
209 cache: None,
210 };
211 let pipeline =
212 unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
213 hal::PipelineError::Device(error) => {
214 CreateComputePipelineError::Device(DeviceError::from_hal(error))
215 }
216 hal::PipelineError::Linkage(_stages, msg) => {
217 CreateComputePipelineError::Internal(msg)
218 }
219 hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
220 crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
221 ),
222 hal::PipelineError::PipelineConstants(_, error) => {
223 CreateComputePipelineError::PipelineConstants(error)
224 }
225 })?;
226
227 let dst_buffer_desc = hal::BufferDescriptor {
228 label: hal_label(
229 Some("(wgpu internal) Indirect dispatch validation destination buffer"),
230 instance_flags,
231 ),
232 size: DST_BUFFER_SIZE.get(),
233 usage: wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE,
234 memory_flags: hal::MemoryFlags::empty(),
235 };
236 let dst_buffer =
237 unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?;
238
239 let dst_bind_group_desc = hal::BindGroupDescriptor {
240 label: hal_label(
241 Some("(wgpu internal) Indirect dispatch validation destination bind group"),
242 instance_flags,
243 ),
244 layout: dst_bind_group_layout.as_ref(),
245 entries: &[hal::BindGroupEntry {
246 binding: 0,
247 resource_index: 0,
248 count: 1,
249 }],
250 buffers: &[hal::BufferBinding::new_unchecked(
252 dst_buffer.as_ref(),
253 0,
254 Some(DST_BUFFER_SIZE),
255 )],
256 samplers: &[],
257 textures: &[],
258 acceleration_structures: &[],
259 external_textures: &[],
260 };
261 let dst_bind_group = unsafe {
262 device
263 .create_bind_group(&dst_bind_group_desc)
264 .map_err(DeviceError::from_hal)
265 }?;
266
267 Ok(Self {
268 module,
269 dst_bind_group_layout,
270 src_bind_group_layout,
271 pipeline_layout,
272 pipeline,
273 dst_buffer,
274 dst_bind_group,
275 })
276 }
277
278 pub(super) fn create_src_bind_group(
280 &self,
281 device: &dyn hal::DynDevice,
282 limits: &wgt::Limits,
283 buffer_size: u64,
284 buffer: &dyn hal::DynBuffer,
285 instance_flags: wgt::InstanceFlags,
286 ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
287 let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
288 let Some(binding_size) = NonZeroU64::new(binding_size) else {
289 return Ok(None);
290 };
291 let hal_desc = hal::BindGroupDescriptor {
292 label: hal_label(
293 Some("(wgpu internal) Indirect dispatch validation source bind group"),
294 instance_flags,
295 ),
296 layout: self.src_bind_group_layout.as_ref(),
297 entries: &[hal::BindGroupEntry {
298 binding: 0,
299 resource_index: 0,
300 count: 1,
301 }],
302 buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
304 samplers: &[],
305 textures: &[],
306 acceleration_structures: &[],
307 external_textures: &[],
308 };
309 unsafe {
310 device
311 .create_bind_group(&hal_desc)
312 .map(Some)
313 .map_err(DeviceError::from_hal)
314 }
315 }
316
317 pub fn params<'a>(&'a self, limits: &wgt::Limits, offset: u64, buffer_size: u64) -> Params<'a> {
318 let alignment = limits.min_storage_buffer_offset_alignment as u64;
333 let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
334 let aligned_offset = offset - offset % alignment;
335 let max_aligned_offset = buffer_size - binding_size;
337 let aligned_offset = aligned_offset.min(max_aligned_offset);
338 let offset_remainder = offset - aligned_offset;
339
340 Params {
341 pipeline_layout: self.pipeline_layout.as_ref(),
342 pipeline: self.pipeline.as_ref(),
343 dst_buffer: self.dst_buffer.as_ref(),
344 dst_bind_group: self.dst_bind_group.as_ref(),
345 aligned_offset,
346 offset_remainder,
347 }
348 }
349
350 pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
351 let Dispatch {
352 module,
353 dst_bind_group_layout,
354 src_bind_group_layout,
355 pipeline_layout,
356 pipeline,
357 dst_buffer,
358 dst_bind_group,
359 } = self;
360
361 unsafe {
362 device.destroy_bind_group(dst_bind_group);
363 device.destroy_buffer(dst_buffer);
364 device.destroy_compute_pipeline(pipeline);
365 device.destroy_pipeline_layout(pipeline_layout);
366 device.destroy_bind_group_layout(src_bind_group_layout);
367 device.destroy_bind_group_layout(dst_bind_group_layout);
368 device.destroy_shader_module(module);
369 }
370 }
371}
372
373fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &wgt::Limits) -> u64 {
374 let alignment = limits.min_storage_buffer_offset_alignment as u64;
375
376 let binding_size = 2 * alignment + (buffer_size % alignment);
407 binding_size.min(buffer_size)
408}