wgpu_core/indirect_validation/
dispatch.rs1use super::CreateIndirectValidationPipelineError;
2use crate::{
3 device::DeviceError,
4 pipeline::{CreateComputePipelineError, CreateShaderModuleError},
5};
6use alloc::{boxed::Box, format, string::ToString as _};
7use core::num::NonZeroU64;
8
9#[derive(Debug)]
21pub(crate) struct Dispatch {
22 module: Box<dyn hal::DynShaderModule>,
23 dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
24 src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
25 pipeline_layout: Box<dyn hal::DynPipelineLayout>,
26 pipeline: Box<dyn hal::DynComputePipeline>,
27 dst_buffer: Box<dyn hal::DynBuffer>,
28 dst_bind_group: Box<dyn hal::DynBindGroup>,
29}
30
31pub struct Params<'a> {
32 pub pipeline_layout: &'a dyn hal::DynPipelineLayout,
33 pub pipeline: &'a dyn hal::DynComputePipeline,
34 pub dst_buffer: &'a dyn hal::DynBuffer,
35 pub dst_bind_group: &'a dyn hal::DynBindGroup,
36 pub aligned_offset: u64,
37 pub offset_remainder: u64,
38}
39
40impl Dispatch {
41 pub(super) fn new(
42 device: &dyn hal::DynDevice,
43 limits: &wgt::Limits,
44 ) -> Result<Self, CreateIndirectValidationPipelineError> {
45 let max_compute_workgroups_per_dimension = limits.max_compute_workgroups_per_dimension;
46
47 let src = format!(
48 "
49 @group(0) @binding(0)
50 var<storage, read_write> dst: array<u32, 6>;
51 @group(1) @binding(0)
52 var<storage, read> src: array<u32>;
53 struct OffsetPc {{
54 inner: u32,
55 }}
56 var<immediate> offset: OffsetPc;
57
58 @compute @workgroup_size(1)
59 fn main() {{
60 let src = vec3(src[offset.inner], src[offset.inner + 1], src[offset.inner + 2]);
61 let max_compute_workgroups_per_dimension = {max_compute_workgroups_per_dimension}u;
62 if (
63 src.x > max_compute_workgroups_per_dimension ||
64 src.y > max_compute_workgroups_per_dimension ||
65 src.z > max_compute_workgroups_per_dimension
66 ) {{
67 dst = array(0u, 0u, 0u, 0u, 0u, 0u);
68 }} else {{
69 dst = array(src.x, src.y, src.z, src.x, src.y, src.z);
70 }}
71 }}
72 "
73 );
74
75 const SRC_BUFFER_SIZE: NonZeroU64 =
77 unsafe { NonZeroU64::new_unchecked(size_of::<u32>() as u64 * 3) };
78
79 const DST_BUFFER_SIZE: NonZeroU64 = unsafe {
81 NonZeroU64::new_unchecked(
82 SRC_BUFFER_SIZE.get() * 2, )
84 };
85
86 #[cfg(feature = "wgsl")]
87 let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
88 CreateShaderModuleError::Parsing(naga::error::ShaderError {
89 source: src.clone(),
90 label: None,
91 inner: Box::new(inner),
92 })
93 })?;
94 #[cfg(not(feature = "wgsl"))]
95 #[allow(clippy::diverging_sub_expression)]
96 let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");
97
98 let info = crate::device::create_validator(
99 wgt::Features::IMMEDIATES,
100 wgt::DownlevelFlags::empty(),
101 naga::valid::ValidationFlags::all(),
102 )
103 .validate(&module)
104 .map_err(|inner| {
105 CreateShaderModuleError::Validation(naga::error::ShaderError {
106 source: src,
107 label: None,
108 inner: Box::new(inner),
109 })
110 })?;
111 let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
112 module: alloc::borrow::Cow::Owned(module),
113 info,
114 debug_source: None,
115 });
116 let hal_desc = hal::ShaderModuleDescriptor {
117 label: None,
118 runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
119 };
120 let module =
121 unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(|error| {
122 match error {
123 hal::ShaderError::Device(error) => {
124 CreateShaderModuleError::Device(DeviceError::from_hal(error))
125 }
126 hal::ShaderError::Compilation(ref msg) => {
127 log::error!("Shader error: {msg}");
128 CreateShaderModuleError::Generation
129 }
130 }
131 })?;
132
133 let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
134 label: None,
135 flags: hal::BindGroupLayoutFlags::empty(),
136 entries: &[wgt::BindGroupLayoutEntry {
137 binding: 0,
138 visibility: wgt::ShaderStages::COMPUTE,
139 ty: wgt::BindingType::Buffer {
140 ty: wgt::BufferBindingType::Storage { read_only: false },
141 has_dynamic_offset: false,
142 min_binding_size: Some(DST_BUFFER_SIZE),
143 },
144 count: None,
145 }],
146 };
147 let dst_bind_group_layout = unsafe {
148 device
149 .create_bind_group_layout(&dst_bind_group_layout_desc)
150 .map_err(DeviceError::from_hal)?
151 };
152
153 let src_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
154 label: None,
155 flags: hal::BindGroupLayoutFlags::empty(),
156 entries: &[wgt::BindGroupLayoutEntry {
157 binding: 0,
158 visibility: wgt::ShaderStages::COMPUTE,
159 ty: wgt::BindingType::Buffer {
160 ty: wgt::BufferBindingType::Storage { read_only: true },
161 has_dynamic_offset: true,
162 min_binding_size: Some(SRC_BUFFER_SIZE),
163 },
164 count: None,
165 }],
166 };
167 let src_bind_group_layout = unsafe {
168 device
169 .create_bind_group_layout(&src_bind_group_layout_desc)
170 .map_err(DeviceError::from_hal)?
171 };
172
173 let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
174 label: None,
175 flags: hal::PipelineLayoutFlags::empty(),
176 bind_group_layouts: &[
177 dst_bind_group_layout.as_ref(),
178 src_bind_group_layout.as_ref(),
179 ],
180 immediate_size: 4,
181 };
182 let pipeline_layout = unsafe {
183 device
184 .create_pipeline_layout(&pipeline_layout_desc)
185 .map_err(DeviceError::from_hal)?
186 };
187
188 let pipeline_desc = hal::ComputePipelineDescriptor {
189 label: None,
190 layout: pipeline_layout.as_ref(),
191 stage: hal::ProgrammableStage {
192 module: module.as_ref(),
193 entry_point: "main",
194 constants: &Default::default(),
195 zero_initialize_workgroup_memory: false,
196 },
197 cache: None,
198 };
199 let pipeline =
200 unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
201 hal::PipelineError::Device(error) => {
202 CreateComputePipelineError::Device(DeviceError::from_hal(error))
203 }
204 hal::PipelineError::Linkage(_stages, msg) => {
205 CreateComputePipelineError::Internal(msg)
206 }
207 hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
208 crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
209 ),
210 hal::PipelineError::PipelineConstants(_, error) => {
211 CreateComputePipelineError::PipelineConstants(error)
212 }
213 })?;
214
215 let dst_buffer_desc = hal::BufferDescriptor {
216 label: None,
217 size: DST_BUFFER_SIZE.get(),
218 usage: wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE,
219 memory_flags: hal::MemoryFlags::empty(),
220 };
221 let dst_buffer =
222 unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?;
223
224 let dst_bind_group_desc = hal::BindGroupDescriptor {
225 label: None,
226 layout: dst_bind_group_layout.as_ref(),
227 entries: &[hal::BindGroupEntry {
228 binding: 0,
229 resource_index: 0,
230 count: 1,
231 }],
232 buffers: &[hal::BufferBinding::new_unchecked(
234 dst_buffer.as_ref(),
235 0,
236 Some(DST_BUFFER_SIZE),
237 )],
238 samplers: &[],
239 textures: &[],
240 acceleration_structures: &[],
241 external_textures: &[],
242 };
243 let dst_bind_group = unsafe {
244 device
245 .create_bind_group(&dst_bind_group_desc)
246 .map_err(DeviceError::from_hal)
247 }?;
248
249 Ok(Self {
250 module,
251 dst_bind_group_layout,
252 src_bind_group_layout,
253 pipeline_layout,
254 pipeline,
255 dst_buffer,
256 dst_bind_group,
257 })
258 }
259
260 pub(super) fn create_src_bind_group(
262 &self,
263 device: &dyn hal::DynDevice,
264 limits: &wgt::Limits,
265 buffer_size: u64,
266 buffer: &dyn hal::DynBuffer,
267 ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
268 let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
269 let Some(binding_size) = NonZeroU64::new(binding_size) else {
270 return Ok(None);
271 };
272 let hal_desc = hal::BindGroupDescriptor {
273 label: None,
274 layout: self.src_bind_group_layout.as_ref(),
275 entries: &[hal::BindGroupEntry {
276 binding: 0,
277 resource_index: 0,
278 count: 1,
279 }],
280 buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, binding_size)],
282 samplers: &[],
283 textures: &[],
284 acceleration_structures: &[],
285 external_textures: &[],
286 };
287 unsafe {
288 device
289 .create_bind_group(&hal_desc)
290 .map(Some)
291 .map_err(DeviceError::from_hal)
292 }
293 }
294
295 pub fn params<'a>(&'a self, limits: &wgt::Limits, offset: u64, buffer_size: u64) -> Params<'a> {
296 let alignment = limits.min_storage_buffer_offset_alignment as u64;
311 let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
312 let aligned_offset = offset - offset % alignment;
313 let max_aligned_offset = buffer_size - binding_size;
315 let aligned_offset = aligned_offset.min(max_aligned_offset);
316 let offset_remainder = offset - aligned_offset;
317
318 Params {
319 pipeline_layout: self.pipeline_layout.as_ref(),
320 pipeline: self.pipeline.as_ref(),
321 dst_buffer: self.dst_buffer.as_ref(),
322 dst_bind_group: self.dst_bind_group.as_ref(),
323 aligned_offset,
324 offset_remainder,
325 }
326 }
327
328 pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
329 let Dispatch {
330 module,
331 dst_bind_group_layout,
332 src_bind_group_layout,
333 pipeline_layout,
334 pipeline,
335 dst_buffer,
336 dst_bind_group,
337 } = self;
338
339 unsafe {
340 device.destroy_bind_group(dst_bind_group);
341 device.destroy_buffer(dst_buffer);
342 device.destroy_compute_pipeline(pipeline);
343 device.destroy_pipeline_layout(pipeline_layout);
344 device.destroy_bind_group_layout(src_bind_group_layout);
345 device.destroy_bind_group_layout(dst_bind_group_layout);
346 device.destroy_shader_module(module);
347 }
348 }
349}
350
351fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &wgt::Limits) -> u64 {
352 let alignment = limits.min_storage_buffer_offset_alignment as u64;
353
354 let binding_size = 2 * alignment + (buffer_size % alignment);
385 binding_size.min(buffer_size)
386}