1use super::CreateIndirectValidationPipelineError;
2use crate::{
3 device::DeviceError,
4 pipeline::{CreateComputePipelineError, CreateShaderModuleError},
5};
6use alloc::{boxed::Box, format, string::ToString as _};
7use core::num::NonZeroU64;
8
/// GPU objects backing the validation pass for indirect compute dispatches.
///
/// All fields are created together by `Dispatch::new` and destroyed together
/// by `Dispatch::dispose`. The validation shader uses two bind groups
/// (group 0 for the output, group 1 for the caller's buffer), two storage
/// buffer bindings (one of them with a dynamic offset), 4 bytes of push
/// constants and a workgroup size of 1.
#[derive(Debug)]
pub(crate) struct Dispatch {
    /// Shader module compiled from the WGSL source generated in `new`.
    module: Box<dyn hal::DynShaderModule>,
    /// Layout of bind group 0: a single read-write storage buffer binding
    /// (the 6-`u32` `dst` array of the shader).
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Layout of bind group 1: a single read-only storage buffer binding with
    /// a dynamic offset (the `src` array of the shader).
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Pipeline layout combining both bind group layouts with a 4-byte
    /// compute-stage push constant range.
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    /// Compute pipeline running the shader's `main` entry point.
    pipeline: Box<dyn hal::DynComputePipeline>,
    /// Buffer the shader writes its result into; created with `INDIRECT` and
    /// `STORAGE_READ_WRITE` usage and sized for six `u32`s.
    dst_buffer: Box<dyn hal::DynBuffer>,
    /// Bind group exposing the whole of `dst_buffer` at group 0, binding 0.
    dst_bind_group: Box<dyn hal::DynBindGroup>,
}
30
/// Borrowed resources and offsets needed to record one validation pass.
///
/// Produced by `Dispatch::params`; the references point into the `Dispatch`
/// that created it.
pub struct Params<'a> {
    /// Pipeline layout of the validation pipeline.
    pub pipeline_layout: &'a dyn hal::DynPipelineLayout,
    /// The validation compute pipeline.
    pub pipeline: &'a dyn hal::DynComputePipeline,
    /// Buffer the validation shader writes its output to.
    pub dst_buffer: &'a dyn hal::DynBuffer,
    /// Bind group that exposes `dst_buffer` to the shader.
    pub dst_bind_group: &'a dyn hal::DynBindGroup,
    /// The requested offset rounded down to a multiple of
    /// `min_storage_buffer_offset_alignment`, then clamped so that the bound
    /// window does not run past the end of the source buffer.
    /// NOTE(review): presumably used as the dynamic offset of the source
    /// bind group — confirm against the caller.
    pub aligned_offset: u64,
    /// Distance from `aligned_offset` to the requested offset
    /// (`offset - aligned_offset`).
    /// NOTE(review): presumably in bytes and converted by the caller before
    /// being passed as the shader's push constant — confirm.
    pub offset_remainder: u64,
}
39
40impl Dispatch {
41 pub(super) fn new(
42 device: &dyn hal::DynDevice,
43 limits: &wgt::Limits,
44 ) -> Result<Self, CreateIndirectValidationPipelineError> {
45 let max_compute_workgroups_per_dimension = limits.max_compute_workgroups_per_dimension;
46
47 let src = format!(
48 "
49 @group(0) @binding(0)
50 var<storage, read_write> dst: array<u32, 6>;
51 @group(1) @binding(0)
52 var<storage, read> src: array<u32>;
53 struct OffsetPc {{
54 inner: u32,
55 }}
56 var<push_constant> offset: OffsetPc;
57
58 @compute @workgroup_size(1)
59 fn main() {{
60 let src = vec3(src[offset.inner], src[offset.inner + 1], src[offset.inner + 2]);
61 let max_compute_workgroups_per_dimension = {max_compute_workgroups_per_dimension}u;
62 if (
63 src.x > max_compute_workgroups_per_dimension ||
64 src.y > max_compute_workgroups_per_dimension ||
65 src.z > max_compute_workgroups_per_dimension
66 ) {{
67 dst = array(0u, 0u, 0u, 0u, 0u, 0u);
68 }} else {{
69 dst = array(src.x, src.y, src.z, src.x, src.y, src.z);
70 }}
71 }}
72 "
73 );
74
75 const SRC_BUFFER_SIZE: NonZeroU64 =
77 unsafe { NonZeroU64::new_unchecked(size_of::<u32>() as u64 * 3) };
78
79 const DST_BUFFER_SIZE: NonZeroU64 = unsafe {
81 NonZeroU64::new_unchecked(
82 SRC_BUFFER_SIZE.get() * 2, )
84 };
85
86 #[cfg(feature = "wgsl")]
87 let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
88 CreateShaderModuleError::Parsing(naga::error::ShaderError {
89 source: src.clone(),
90 label: None,
91 inner: Box::new(inner),
92 })
93 })?;
94 #[cfg(not(feature = "wgsl"))]
95 #[allow(clippy::diverging_sub_expression)]
96 let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");
97
98 let info = crate::device::create_validator(
99 wgt::Features::PUSH_CONSTANTS,
100 wgt::DownlevelFlags::empty(),
101 naga::valid::ValidationFlags::all(),
102 )
103 .validate(&module)
104 .map_err(|inner| {
105 CreateShaderModuleError::Validation(naga::error::ShaderError {
106 source: src,
107 label: None,
108 inner: Box::new(inner),
109 })
110 })?;
111 let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
112 module: alloc::borrow::Cow::Owned(module),
113 info,
114 debug_source: None,
115 });
116 let hal_desc = hal::ShaderModuleDescriptor {
117 label: None,
118 runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
119 };
120 let module =
121 unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(|error| {
122 match error {
123 hal::ShaderError::Device(error) => {
124 CreateShaderModuleError::Device(DeviceError::from_hal(error))
125 }
126 hal::ShaderError::Compilation(ref msg) => {
127 log::error!("Shader error: {msg}");
128 CreateShaderModuleError::Generation
129 }
130 }
131 })?;
132
133 let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
134 label: None,
135 flags: hal::BindGroupLayoutFlags::empty(),
136 entries: &[wgt::BindGroupLayoutEntry {
137 binding: 0,
138 visibility: wgt::ShaderStages::COMPUTE,
139 ty: wgt::BindingType::Buffer {
140 ty: wgt::BufferBindingType::Storage { read_only: false },
141 has_dynamic_offset: false,
142 min_binding_size: Some(DST_BUFFER_SIZE),
143 },
144 count: None,
145 }],
146 };
147 let dst_bind_group_layout = unsafe {
148 device
149 .create_bind_group_layout(&dst_bind_group_layout_desc)
150 .map_err(DeviceError::from_hal)?
151 };
152
153 let src_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
154 label: None,
155 flags: hal::BindGroupLayoutFlags::empty(),
156 entries: &[wgt::BindGroupLayoutEntry {
157 binding: 0,
158 visibility: wgt::ShaderStages::COMPUTE,
159 ty: wgt::BindingType::Buffer {
160 ty: wgt::BufferBindingType::Storage { read_only: true },
161 has_dynamic_offset: true,
162 min_binding_size: Some(SRC_BUFFER_SIZE),
163 },
164 count: None,
165 }],
166 };
167 let src_bind_group_layout = unsafe {
168 device
169 .create_bind_group_layout(&src_bind_group_layout_desc)
170 .map_err(DeviceError::from_hal)?
171 };
172
173 let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
174 label: None,
175 flags: hal::PipelineLayoutFlags::empty(),
176 bind_group_layouts: &[
177 dst_bind_group_layout.as_ref(),
178 src_bind_group_layout.as_ref(),
179 ],
180 push_constant_ranges: &[wgt::PushConstantRange {
181 stages: wgt::ShaderStages::COMPUTE,
182 range: 0..4,
183 }],
184 };
185 let pipeline_layout = unsafe {
186 device
187 .create_pipeline_layout(&pipeline_layout_desc)
188 .map_err(DeviceError::from_hal)?
189 };
190
191 let pipeline_desc = hal::ComputePipelineDescriptor {
192 label: None,
193 layout: pipeline_layout.as_ref(),
194 stage: hal::ProgrammableStage {
195 module: module.as_ref(),
196 entry_point: "main",
197 constants: &Default::default(),
198 zero_initialize_workgroup_memory: false,
199 },
200 cache: None,
201 };
202 let pipeline =
203 unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
204 hal::PipelineError::Device(error) => {
205 CreateComputePipelineError::Device(DeviceError::from_hal(error))
206 }
207 hal::PipelineError::Linkage(_stages, msg) => {
208 CreateComputePipelineError::Internal(msg)
209 }
210 hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
211 crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
212 ),
213 hal::PipelineError::PipelineConstants(_, error) => {
214 CreateComputePipelineError::PipelineConstants(error)
215 }
216 })?;
217
218 let dst_buffer_desc = hal::BufferDescriptor {
219 label: None,
220 size: DST_BUFFER_SIZE.get(),
221 usage: wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE,
222 memory_flags: hal::MemoryFlags::empty(),
223 };
224 let dst_buffer =
225 unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?;
226
227 let dst_bind_group_desc = hal::BindGroupDescriptor {
228 label: None,
229 layout: dst_bind_group_layout.as_ref(),
230 entries: &[hal::BindGroupEntry {
231 binding: 0,
232 resource_index: 0,
233 count: 1,
234 }],
235 buffers: &[hal::BufferBinding::new_unchecked(
237 dst_buffer.as_ref(),
238 0,
239 Some(DST_BUFFER_SIZE),
240 )],
241 samplers: &[],
242 textures: &[],
243 acceleration_structures: &[],
244 external_textures: &[],
245 };
246 let dst_bind_group = unsafe {
247 device
248 .create_bind_group(&dst_bind_group_desc)
249 .map_err(DeviceError::from_hal)
250 }?;
251
252 Ok(Self {
253 module,
254 dst_bind_group_layout,
255 src_bind_group_layout,
256 pipeline_layout,
257 pipeline,
258 dst_buffer,
259 dst_bind_group,
260 })
261 }
262
263 pub(super) fn create_src_bind_group(
265 &self,
266 device: &dyn hal::DynDevice,
267 limits: &wgt::Limits,
268 buffer_size: u64,
269 buffer: &dyn hal::DynBuffer,
270 ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
271 let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
272 let Some(binding_size) = NonZeroU64::new(binding_size) else {
273 return Ok(None);
274 };
275 let hal_desc = hal::BindGroupDescriptor {
276 label: None,
277 layout: self.src_bind_group_layout.as_ref(),
278 entries: &[hal::BindGroupEntry {
279 binding: 0,
280 resource_index: 0,
281 count: 1,
282 }],
283 buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
285 samplers: &[],
286 textures: &[],
287 acceleration_structures: &[],
288 external_textures: &[],
289 };
290 unsafe {
291 device
292 .create_bind_group(&hal_desc)
293 .map(Some)
294 .map_err(DeviceError::from_hal)
295 }
296 }
297
298 pub fn params<'a>(&'a self, limits: &wgt::Limits, offset: u64, buffer_size: u64) -> Params<'a> {
299 let alignment = limits.min_storage_buffer_offset_alignment as u64;
314 let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
315 let aligned_offset = offset - offset % alignment;
316 let max_aligned_offset = buffer_size - binding_size;
318 let aligned_offset = aligned_offset.min(max_aligned_offset);
319 let offset_remainder = offset - aligned_offset;
320
321 Params {
322 pipeline_layout: self.pipeline_layout.as_ref(),
323 pipeline: self.pipeline.as_ref(),
324 dst_buffer: self.dst_buffer.as_ref(),
325 dst_bind_group: self.dst_bind_group.as_ref(),
326 aligned_offset,
327 offset_remainder,
328 }
329 }
330
331 pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
332 let Dispatch {
333 module,
334 dst_bind_group_layout,
335 src_bind_group_layout,
336 pipeline_layout,
337 pipeline,
338 dst_buffer,
339 dst_bind_group,
340 } = self;
341
342 unsafe {
343 device.destroy_bind_group(dst_bind_group);
344 device.destroy_buffer(dst_buffer);
345 device.destroy_compute_pipeline(pipeline);
346 device.destroy_pipeline_layout(pipeline_layout);
347 device.destroy_bind_group_layout(src_bind_group_layout);
348 device.destroy_bind_group_layout(dst_bind_group_layout);
349 device.destroy_shader_module(module);
350 }
351 }
352}
353
354fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &wgt::Limits) -> u64 {
355 let alignment = limits.min_storage_buffer_offset_alignment as u64;
356
357 let binding_size = 2 * alignment + (buffer_size % alignment);
388 binding_size.min(buffer_size)
389}