1use crate::binding_model::{BindError, BindGroup, ImmediateUploadError};
4use crate::command::encoder::EncodingState;
5use crate::command::{
6 bind::Binder, memory_init::SurfacesInDiscardState, query::QueryResetMap, DebugGroupError,
7 QueryUseError,
8};
9use crate::device::{Device, DeviceError, MissingFeatures};
10use crate::pipeline::LateSizedBufferGroup;
11use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet};
12use crate::track::{ResourceUsageCompatibilityError, UsageScope};
13use crate::{api_log, binding_model};
14use alloc::sync::Arc;
15use alloc::vec::Vec;
16use core::str;
17use thiserror::Error;
18use wgt::error::{ErrorType, WebGpuError};
19use wgt::DynamicOffset;
20
21#[derive(Clone, Debug, Error)]
22#[error(
23 "Bind group index {index} is greater than the device's configured `max_bind_groups` limit {max}"
24)]
25pub struct BindGroupIndexOutOfRange {
26 pub index: u32,
27 pub max: u32,
28}
29
/// Error returned when a pass command needs a bound pipeline layout but none
/// has been set (e.g. `set_immediates` when the binder has no pipeline layout).
#[derive(Clone, Debug, Error)]
#[error("Pipeline must be set")]
pub struct MissingPipeline;
33
/// Error returned by `set_immediates` when it is called with
/// `values_offset == None`.
///
/// As the message states, a `None` values offset is reserved for internal use
/// in render bundles.
#[derive(Clone, Debug, Error)]
#[error("Setting `values_offset` to be `None` is only for internal use in render bundles")]
pub struct InvalidValuesOffset;
37
// `InvalidValuesOffset` is always reported to WebGPU as a validation error.
impl WebGpuError for InvalidValuesOffset {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
43
/// Mutable state threaded through the pass-command helpers in this file
/// (`set_bind_group`, `flush_bindings_helper`, `set_immediates`, debug markers, ...).
///
/// NOTE(review): field declaration order also fixes drop order; do not reorder
/// without checking what `EncodingState` / `UsageScope` hold.
pub(crate) struct PassState<'scope, 'snatch_guard, 'cmd_enc> {
    /// Shared encoding state. The helpers here use its `device`, `tracker`,
    /// `raw_encoder`, `snatch_guard`, `buffer_memory_init_actions`,
    /// `texture_memory_actions`, `as_actions` and `debug_scope_depth` fields.
    pub(crate) base: EncodingState<'snatch_guard, 'cmd_enc>,

    /// Surfaces returned by `texture_memory_actions.register_init_action` while
    /// flushing bind groups; extended in `flush_bindings_helper`.
    pub(crate) pending_discard_init_fixups: SurfacesInDiscardState,

    /// Usage scope that bind-group resource usages are merged into when
    /// `set_bind_group` is called with `merge_bind_groups == true`.
    pub(crate) scope: UsageScope<'scope>,

    /// Holds the current pipeline layout and the bind groups assigned to each slot.
    pub(crate) binder: Binder,

    /// Scratch buffer holding the dynamic offsets of the bind group currently
    /// being set; cleared and refilled by every `set_bind_group` call.
    pub(crate) temp_offsets: Vec<u32>,

    /// Cursor into the `dynamic_offsets` slice passed to `set_bind_group`:
    /// the number of offsets consumed so far.
    pub(crate) dynamic_offset_count: usize,

    /// Cursor into the `string_data` passed to `push_debug_group` /
    /// `insert_debug_marker`: the byte offset of the next unread label.
    pub(crate) string_offset: usize,
}
61
62pub(crate) fn set_bind_group<E>(
63 state: &mut PassState,
64 device: &Arc<Device>,
65 dynamic_offsets: &[DynamicOffset],
66 index: u32,
67 num_dynamic_offsets: usize,
68 bind_group: Option<Arc<BindGroup>>,
69 merge_bind_groups: bool,
70) -> Result<(), E>
71where
72 E: From<DeviceError>
73 + From<BindGroupIndexOutOfRange>
74 + From<ResourceUsageCompatibilityError>
75 + From<DestroyedResourceError>
76 + From<BindError>,
77{
78 if let Some(ref bind_group) = bind_group {
79 api_log!("Pass::set_bind_group {index} {}", bind_group.error_ident());
80 } else {
81 api_log!("Pass::set_bind_group {index} None");
82 }
83
84 let max_bind_groups = state.base.device.limits.max_bind_groups;
85 if index >= max_bind_groups {
86 return Err(BindGroupIndexOutOfRange {
87 index,
88 max: max_bind_groups,
89 }
90 .into());
91 }
92
93 state.temp_offsets.clear();
94 state.temp_offsets.extend_from_slice(
95 &dynamic_offsets
96 [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
97 );
98 state.dynamic_offset_count += num_dynamic_offsets;
99
100 if let Some(bind_group) = bind_group {
101 let bind_group = state.base.tracker.bind_groups.insert_single(bind_group);
106
107 bind_group.same_device(device)?;
108
109 bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;
110
111 if merge_bind_groups {
112 unsafe {
116 state.scope.merge_bind_group(&bind_group.used)?;
117 }
118 }
119 state
123 .binder
124 .assign_group(index as usize, bind_group, &state.temp_offsets);
125 } else {
126 if !state.temp_offsets.is_empty() {
127 return Err(BindError::DynamicOffsetCountNotZero {
128 group: index,
129 actual: state.temp_offsets.len(),
130 }
131 .into());
132 }
133
134 state.binder.clear_group(index as usize);
135 };
136
137 Ok(())
138}
139
140pub(super) fn flush_bindings_helper(state: &mut PassState) -> Result<(), DestroyedResourceError> {
145 let start = state.binder.take_rebind_start_index();
146 let entries = state.binder.list_valid_with_start(start);
147 let pipeline_layout = state.binder.pipeline_layout.as_ref().unwrap();
148
149 for (i, bind_group, dynamic_offsets) in entries {
150 state.base.buffer_memory_init_actions.extend(
151 bind_group.used_buffer_ranges.iter().filter_map(|action| {
152 action
153 .buffer
154 .initialization_status
155 .read()
156 .check_action(action)
157 }),
158 );
159 for action in bind_group.used_texture_ranges.iter() {
160 state.pending_discard_init_fixups.extend(
161 state
162 .base
163 .texture_memory_actions
164 .register_init_action(action),
165 );
166 }
167
168 let used_resource = bind_group
169 .used
170 .acceleration_structures
171 .into_iter()
172 .map(|tlas| crate::ray_tracing::AsAction::UseTlas(tlas.clone()));
173
174 state.base.as_actions.extend(used_resource);
175
176 let raw_bg = bind_group.try_raw(state.base.snatch_guard)?;
177 unsafe {
178 state.base.raw_encoder.set_bind_group(
179 pipeline_layout.raw(),
180 i as u32,
181 raw_bg,
182 dynamic_offsets,
183 );
184 }
185 }
186
187 Ok(())
188}
189
190pub(super) fn change_pipeline_layout<E, F: FnOnce()>(
191 state: &mut PassState,
192 pipeline_layout: &Arc<binding_model::PipelineLayout>,
193 late_sized_buffer_groups: &[LateSizedBufferGroup],
194 f: F,
195) -> Result<(), E>
196where
197 E: From<DestroyedResourceError>,
198{
199 if state
200 .binder
201 .change_pipeline_layout(pipeline_layout, late_sized_buffer_groups)
202 {
203 f();
204
205 super::immediates_clear(
206 0,
207 pipeline_layout.immediate_size,
208 |clear_offset, clear_data| unsafe {
209 state.base.raw_encoder.set_immediates(
210 pipeline_layout.raw(),
211 clear_offset,
212 clear_data,
213 );
214 },
215 );
216 }
217 Ok(())
218}
219
220pub(crate) fn set_immediates<E, F: FnOnce(&[u32])>(
221 state: &mut PassState,
222 immediates_data: &[u32],
223 offset: u32,
224 size_bytes: u32,
225 values_offset: Option<u32>,
226 f: F,
227) -> Result<(), E>
228where
229 E: From<ImmediateUploadError> + From<InvalidValuesOffset> + From<MissingPipeline>,
230{
231 api_log!("Pass::set_immediates");
232
233 let values_offset = values_offset.ok_or(InvalidValuesOffset)?;
234
235 let end_offset_bytes = offset + size_bytes;
236 let values_end_offset = (values_offset + size_bytes / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
237 let data_slice = &immediates_data[(values_offset as usize)..values_end_offset];
238
239 let pipeline_layout = state
240 .binder
241 .pipeline_layout
242 .as_ref()
243 .ok_or(MissingPipeline)?;
244
245 pipeline_layout.validate_immediates_ranges(offset, end_offset_bytes)?;
246
247 f(data_slice);
248
249 unsafe {
250 state
251 .base
252 .raw_encoder
253 .set_immediates(pipeline_layout.raw(), offset, data_slice)
254 }
255 Ok(())
256}
257
258pub(crate) fn write_timestamp<E>(
259 state: &mut PassState,
260 device: &Arc<Device>,
261 pending_query_resets: Option<&mut QueryResetMap>,
262 query_set: Arc<QuerySet>,
263 query_index: u32,
264) -> Result<(), E>
265where
266 E: From<MissingFeatures> + From<QueryUseError> + From<DeviceError>,
267{
268 api_log!(
269 "Pass::write_timestamps {query_index} {}",
270 query_set.error_ident()
271 );
272
273 query_set.same_device(device)?;
274
275 state
276 .base
277 .device
278 .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;
279
280 let query_set = state.base.tracker.query_sets.insert_single(query_set);
281
282 query_set.validate_and_write_timestamp(
283 state.base.raw_encoder,
284 query_index,
285 pending_query_resets,
286 )?;
287 Ok(())
288}
289
290pub(crate) fn push_debug_group(state: &mut PassState, string_data: &[u8], len: usize) {
291 *state.base.debug_scope_depth += 1;
292 if !state
293 .base
294 .device
295 .instance_flags
296 .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
297 {
298 let label =
299 str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
300
301 api_log!("Pass::push_debug_group {label:?}");
302 unsafe {
303 state.base.raw_encoder.begin_debug_marker(label);
304 }
305 }
306 state.string_offset += len;
307}
308
309pub(crate) fn pop_debug_group<E>(state: &mut PassState) -> Result<(), E>
310where
311 E: From<DebugGroupError>,
312{
313 api_log!("Pass::pop_debug_group");
314
315 if *state.base.debug_scope_depth == 0 {
316 return Err(DebugGroupError::InvalidPop.into());
317 }
318 *state.base.debug_scope_depth -= 1;
319 if !state
320 .base
321 .device
322 .instance_flags
323 .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
324 {
325 unsafe {
326 state.base.raw_encoder.end_debug_marker();
327 }
328 }
329 Ok(())
330}
331
332pub(crate) fn insert_debug_marker(state: &mut PassState, string_data: &[u8], len: usize) {
333 if !state
334 .base
335 .device
336 .instance_flags
337 .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
338 {
339 let label =
340 str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
341 api_log!("Pass::insert_debug_marker {label:?}");
342 unsafe {
343 state.base.raw_encoder.insert_debug_marker(label);
344 }
345 }
346 state.string_offset += len;
347}