1use crate::binding_model::{BindError, BindGroup, PushConstantUploadError};
4use crate::command::bind::Binder;
5use crate::command::encoder::EncodingState;
6use crate::command::memory_init::SurfacesInDiscardState;
7use crate::command::{DebugGroupError, QueryResetMap, QueryUseError};
8use crate::device::{Device, DeviceError, MissingFeatures};
9use crate::pipeline::LateSizedBufferGroup;
10use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet};
11use crate::track::{ResourceUsageCompatibilityError, UsageScope};
12use crate::{api_log, binding_model};
13use alloc::sync::Arc;
14use alloc::vec::Vec;
15use core::str;
16use thiserror::Error;
17use wgt::error::{ErrorType, WebGpuError};
18use wgt::DynamicOffset;
19
20#[derive(Clone, Debug, Error)]
21#[error(
22 "Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}"
23)]
24pub struct BindGroupIndexOutOfRange {
25 pub index: u32,
26 pub max: u32,
27}
28
29#[derive(Clone, Debug, Error)]
30#[error("Pipeline must be set")]
31pub struct MissingPipeline;
32
33#[derive(Clone, Debug, Error)]
34#[error("Setting `values_offset` to be `None` is only for internal use in render bundles")]
35pub struct InvalidValuesOffset;
36
37impl WebGpuError for InvalidValuesOffset {
38 fn webgpu_error_type(&self) -> ErrorType {
39 ErrorType::Validation
40 }
41}
42
/// Mutable state threaded through the pass-level helpers in this module
/// (`set_bind_group`, `flush_bindings_helper`, `change_pipeline_layout`,
/// `set_push_constant`, `write_timestamp`, and the debug-marker helpers).
pub(crate) struct PassState<'scope, 'snatch_guard, 'cmd_enc> {
    // Encoding state shared with the command encoder. The helpers here reach the
    // device, resource tracker, raw HAL encoder, snatch guard, memory-init action
    // lists, acceleration-structure actions and debug scope depth through it.
    pub(crate) base: EncodingState<'snatch_guard, 'cmd_enc>,

    // Accumulates the results of `texture_memory_actions.register_init_action`
    // produced while flushing bind groups.
    pub(crate) pending_discard_init_fixups: SurfacesInDiscardState,

    // Usage scope into which a bind group's `used` resources are merged when
    // `set_bind_group` is called with `merge_bind_groups == true`.
    pub(crate) scope: UsageScope<'scope>,

    // Holds the assigned bind groups, their dynamic offsets, the rebind range and
    // the current pipeline layout.
    pub(crate) binder: Binder,

    // Scratch buffer for the dynamic offsets of the bind group currently being set;
    // cleared and refilled on every `set_bind_group` call.
    pub(crate) temp_offsets: Vec<u32>,

    // Cursor into the pass's dynamic-offset array: how many entries previous
    // `set_bind_group` calls have consumed.
    pub(crate) dynamic_offset_count: usize,

    // Cursor into the pass's debug-string byte data: how many bytes previous
    // debug-group / debug-marker commands have consumed.
    pub(crate) string_offset: usize,
}
60
61pub(crate) fn set_bind_group<E>(
62 state: &mut PassState,
63 device: &Arc<Device>,
64 dynamic_offsets: &[DynamicOffset],
65 index: u32,
66 num_dynamic_offsets: usize,
67 bind_group: Option<Arc<BindGroup>>,
68 merge_bind_groups: bool,
69) -> Result<(), E>
70where
71 E: From<DeviceError>
72 + From<BindGroupIndexOutOfRange>
73 + From<ResourceUsageCompatibilityError>
74 + From<DestroyedResourceError>
75 + From<BindError>,
76{
77 if bind_group.is_none() {
78 api_log!("Pass::set_bind_group {index} None");
79 } else {
80 api_log!(
81 "Pass::set_bind_group {index} {}",
82 bind_group.as_ref().unwrap().error_ident()
83 );
84 }
85
86 let max_bind_groups = state.base.device.limits.max_bind_groups;
87 if index >= max_bind_groups {
88 return Err(BindGroupIndexOutOfRange {
89 index,
90 max: max_bind_groups,
91 }
92 .into());
93 }
94
95 state.temp_offsets.clear();
96 state.temp_offsets.extend_from_slice(
97 &dynamic_offsets
98 [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
99 );
100 state.dynamic_offset_count += num_dynamic_offsets;
101
102 let Some(bind_group) = bind_group else {
103 return Ok(());
105 };
106
107 let bind_group = state.base.tracker.bind_groups.insert_single(bind_group);
112
113 bind_group.same_device(device)?;
114
115 bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;
116
117 if merge_bind_groups {
118 unsafe {
122 state.scope.merge_bind_group(&bind_group.used)?;
123 }
124 }
125 state
129 .binder
130 .assign_group(index as usize, bind_group, &state.temp_offsets);
131
132 Ok(())
133}
134
135pub(super) fn flush_bindings_helper<F>(
138 state: &mut PassState,
139 mut f: F,
140) -> Result<(), DestroyedResourceError>
141where
142 F: FnMut(&Arc<BindGroup>),
143{
144 for bind_group in state.binder.list_active() {
145 f(bind_group);
146
147 state.base.buffer_memory_init_actions.extend(
148 bind_group.used_buffer_ranges.iter().filter_map(|action| {
149 action
150 .buffer
151 .initialization_status
152 .read()
153 .check_action(action)
154 }),
155 );
156 for action in bind_group.used_texture_ranges.iter() {
157 state.pending_discard_init_fixups.extend(
158 state
159 .base
160 .texture_memory_actions
161 .register_init_action(action),
162 );
163 }
164
165 let used_resource = bind_group
166 .used
167 .acceleration_structures
168 .into_iter()
169 .map(|tlas| crate::ray_tracing::AsAction::UseTlas(tlas.clone()));
170
171 state.base.as_actions.extend(used_resource);
172 }
173
174 let range = state.binder.take_rebind_range();
175 let entries = state.binder.entries(range);
176 match state.binder.pipeline_layout.as_ref() {
177 Some(pipeline_layout) if entries.len() != 0 => {
178 for (i, e) in entries {
179 if let Some(group) = e.group.as_ref() {
180 let raw_bg = group.try_raw(state.base.snatch_guard)?;
181 unsafe {
182 state.base.raw_encoder.set_bind_group(
183 pipeline_layout.raw(),
184 i as u32,
185 Some(raw_bg),
186 &e.dynamic_offsets,
187 );
188 }
189 }
190 }
191 }
192 _ => {}
193 }
194
195 Ok(())
196}
197
198pub(super) fn change_pipeline_layout<E, F: FnOnce()>(
199 state: &mut PassState,
200 pipeline_layout: &Arc<binding_model::PipelineLayout>,
201 late_sized_buffer_groups: &[LateSizedBufferGroup],
202 f: F,
203) -> Result<(), E>
204where
205 E: From<DestroyedResourceError>,
206{
207 if state.binder.pipeline_layout.is_none()
208 || !state
209 .binder
210 .pipeline_layout
211 .as_ref()
212 .unwrap()
213 .is_equal(pipeline_layout)
214 {
215 state
216 .binder
217 .change_pipeline_layout(pipeline_layout, late_sized_buffer_groups);
218
219 f();
220
221 let non_overlapping =
222 super::bind::compute_nonoverlapping_ranges(&pipeline_layout.push_constant_ranges);
223
224 for range in non_overlapping {
226 let offset = range.range.start;
227 let size_bytes = range.range.end - offset;
228 super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
229 state.base.raw_encoder.set_push_constants(
230 pipeline_layout.raw(),
231 range.stages,
232 clear_offset,
233 clear_data,
234 );
235 });
236 }
237 }
238 Ok(())
239}
240
241pub(crate) fn set_push_constant<E, F: FnOnce(&[u32])>(
242 state: &mut PassState,
243 push_constant_data: &[u32],
244 stages: wgt::ShaderStages,
245 offset: u32,
246 size_bytes: u32,
247 values_offset: Option<u32>,
248 f: F,
249) -> Result<(), E>
250where
251 E: From<PushConstantUploadError> + From<InvalidValuesOffset> + From<MissingPipeline>,
252{
253 api_log!("Pass::set_push_constants");
254
255 let values_offset = values_offset.ok_or(InvalidValuesOffset)?;
256
257 let end_offset_bytes = offset + size_bytes;
258 let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
259 let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];
260
261 let pipeline_layout = state
262 .binder
263 .pipeline_layout
264 .as_ref()
265 .ok_or(MissingPipeline)?;
266
267 pipeline_layout.validate_push_constant_ranges(stages, offset, end_offset_bytes)?;
268
269 f(data_slice);
270
271 unsafe {
272 state
273 .base
274 .raw_encoder
275 .set_push_constants(pipeline_layout.raw(), stages, offset, data_slice)
276 }
277 Ok(())
278}
279
280pub(crate) fn write_timestamp<E>(
281 state: &mut PassState,
282 device: &Arc<Device>,
283 pending_query_resets: Option<&mut QueryResetMap>,
284 query_set: Arc<QuerySet>,
285 query_index: u32,
286) -> Result<(), E>
287where
288 E: From<MissingFeatures> + From<QueryUseError> + From<DeviceError>,
289{
290 api_log!(
291 "Pass::write_timestamps {query_index} {}",
292 query_set.error_ident()
293 );
294
295 query_set.same_device(device)?;
296
297 state
298 .base
299 .device
300 .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;
301
302 let query_set = state.base.tracker.query_sets.insert_single(query_set);
303
304 query_set.validate_and_write_timestamp(
305 state.base.raw_encoder,
306 query_index,
307 pending_query_resets,
308 )?;
309 Ok(())
310}
311
312pub(crate) fn push_debug_group(state: &mut PassState, string_data: &[u8], len: usize) {
313 *state.base.debug_scope_depth += 1;
314 if !state
315 .base
316 .device
317 .instance_flags
318 .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
319 {
320 let label =
321 str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
322
323 api_log!("Pass::push_debug_group {label:?}");
324 unsafe {
325 state.base.raw_encoder.begin_debug_marker(label);
326 }
327 }
328 state.string_offset += len;
329}
330
331pub(crate) fn pop_debug_group<E>(state: &mut PassState) -> Result<(), E>
332where
333 E: From<DebugGroupError>,
334{
335 api_log!("Pass::pop_debug_group");
336
337 if *state.base.debug_scope_depth == 0 {
338 return Err(DebugGroupError::InvalidPop.into());
339 }
340 *state.base.debug_scope_depth -= 1;
341 if !state
342 .base
343 .device
344 .instance_flags
345 .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
346 {
347 unsafe {
348 state.base.raw_encoder.end_debug_marker();
349 }
350 }
351 Ok(())
352}
353
354pub(crate) fn insert_debug_marker(state: &mut PassState, string_data: &[u8], len: usize) {
355 if !state
356 .base
357 .device
358 .instance_flags
359 .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
360 {
361 let label =
362 str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
363 api_log!("Pass::insert_debug_marker {label:?}");
364 unsafe {
365 state.base.raw_encoder.insert_debug_marker(label);
366 }
367 }
368 state.string_offset += len;
369}