1use crate::binding_model::{BindError, BindGroup, PushConstantUploadError};
4use crate::command::bind::Binder;
5use crate::command::encoder::EncodingState;
6use crate::command::memory_init::SurfacesInDiscardState;
7use crate::command::{CommandEncoder, DebugGroupError, QueryResetMap, QueryUseError};
8use crate::device::{DeviceError, MissingFeatures};
9use crate::pipeline::LateSizedBufferGroup;
10use crate::ray_tracing::AsAction;
11use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet};
12use crate::track::{ResourceUsageCompatibilityError, UsageScope};
13use crate::{api_log, binding_model};
14use alloc::sync::Arc;
15use alloc::vec::Vec;
16use core::str;
17use thiserror::Error;
18use wgt::error::{ErrorType, WebGpuError};
19use wgt::DynamicOffset;
20
21#[derive(Clone, Debug, Error)]
22#[error(
23 "Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}"
24)]
25pub struct BindGroupIndexOutOfRange {
26 pub index: u32,
27 pub max: u32,
28}
29
30#[derive(Clone, Debug, Error)]
31#[error("Pipeline must be set")]
32pub struct MissingPipeline;
33
34#[derive(Clone, Debug, Error)]
35#[error("Setting `values_offset` to be `None` is only for internal use in render bundles")]
36pub struct InvalidValuesOffset;
37
38impl WebGpuError for InvalidValuesOffset {
39 fn webgpu_error_type(&self) -> ErrorType {
40 ErrorType::Validation
41 }
42}
43
/// Mutable recording state shared by the pass helpers in this module
/// (`set_bind_group`, `rebind_resources`, `set_push_constant`, the debug
/// marker functions, ...).
///
/// NOTE(review): field declaration order is also drop order; keep it stable.
pub(crate) struct PassState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> {
    /// Encoder-level state used by the helpers: device, resource tracker,
    /// raw HAL encoder, snatch guard, buffer/texture memory-init actions,
    /// acceleration-structure actions and the debug scope depth.
    pub(crate) base: EncodingState<'snatch_guard, 'cmd_enc, 'raw_encoder>,

    /// Surfaces returned by `texture_memory_actions.register_init_action` for
    /// the texture ranges of bound bind groups; collected here and consumed
    /// elsewhere (NOTE(review): consumer not visible in this file).
    pub(crate) pending_discard_init_fixups: SurfacesInDiscardState,

    /// Usage scope that bind group resource usage is merged into when
    /// `set_bind_group` is called with `merge_bind_groups == true`.
    pub(crate) scope: UsageScope<'scope>,

    /// Tracks the currently bound pipeline layout and bind group assignments.
    pub(crate) binder: Binder,

    /// Scratch buffer holding the dynamic offsets of the current
    /// `set_bind_group` call; cleared and refilled on every call.
    pub(crate) temp_offsets: Vec<u32>,

    /// How many entries of the pass's dynamic-offset array have been
    /// consumed so far by `set_bind_group` commands.
    pub(crate) dynamic_offset_count: usize,

    /// Byte offset into the pass's debug string data where the next
    /// debug group / marker label starts.
    pub(crate) string_offset: usize,
}
61
62pub(crate) fn set_bind_group<E>(
63 state: &mut PassState,
64 cmd_enc: &CommandEncoder,
65 dynamic_offsets: &[DynamicOffset],
66 index: u32,
67 num_dynamic_offsets: usize,
68 bind_group: Option<Arc<BindGroup>>,
69 merge_bind_groups: bool,
70) -> Result<(), E>
71where
72 E: From<DeviceError>
73 + From<BindGroupIndexOutOfRange>
74 + From<ResourceUsageCompatibilityError>
75 + From<DestroyedResourceError>
76 + From<BindError>,
77{
78 if bind_group.is_none() {
79 api_log!("Pass::set_bind_group {index} None");
80 } else {
81 api_log!(
82 "Pass::set_bind_group {index} {}",
83 bind_group.as_ref().unwrap().error_ident()
84 );
85 }
86
87 let max_bind_groups = state.base.device.limits.max_bind_groups;
88 if index >= max_bind_groups {
89 return Err(BindGroupIndexOutOfRange {
90 index,
91 max: max_bind_groups,
92 }
93 .into());
94 }
95
96 state.temp_offsets.clear();
97 state.temp_offsets.extend_from_slice(
98 &dynamic_offsets
99 [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
100 );
101 state.dynamic_offset_count += num_dynamic_offsets;
102
103 if bind_group.is_none() {
104 return Ok(());
106 }
107
108 let bind_group = bind_group.unwrap();
109 let bind_group = state.base.tracker.bind_groups.insert_single(bind_group);
110
111 bind_group.same_device_as(cmd_enc)?;
112
113 bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;
114
115 if merge_bind_groups {
116 unsafe {
118 state.scope.merge_bind_group(&bind_group.used)?;
119 }
120 }
121 state
125 .base
126 .buffer_memory_init_actions
127 .extend(bind_group.used_buffer_ranges.iter().filter_map(|action| {
128 action
129 .buffer
130 .initialization_status
131 .read()
132 .check_action(action)
133 }));
134 for action in bind_group.used_texture_ranges.iter() {
135 state.pending_discard_init_fixups.extend(
136 state
137 .base
138 .texture_memory_actions
139 .register_init_action(action),
140 );
141 }
142
143 let used_resource = bind_group
144 .used
145 .acceleration_structures
146 .into_iter()
147 .map(|tlas| AsAction::UseTlas(tlas.clone()));
148
149 state.base.as_actions.extend(used_resource);
150
151 let pipeline_layout = state.binder.pipeline_layout.clone();
152 let entries = state
153 .binder
154 .assign_group(index as usize, bind_group, &state.temp_offsets);
155 if !entries.is_empty() && pipeline_layout.is_some() {
156 let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
157 for (i, e) in entries.iter().enumerate() {
158 if let Some(group) = e.group.as_ref() {
159 let raw_bg = group.try_raw(state.base.snatch_guard)?;
160 unsafe {
161 state.base.raw_encoder.set_bind_group(
162 pipeline_layout,
163 index + i as u32,
164 Some(raw_bg),
165 &e.dynamic_offsets,
166 );
167 }
168 }
169 }
170 }
171 Ok(())
172}
173
174pub(crate) fn rebind_resources<E, F: FnOnce()>(
176 state: &mut PassState,
177 pipeline_layout: &Arc<binding_model::PipelineLayout>,
178 late_sized_buffer_groups: &[LateSizedBufferGroup],
179 f: F,
180) -> Result<(), E>
181where
182 E: From<DestroyedResourceError>,
183{
184 if state.binder.pipeline_layout.is_none()
185 || !state
186 .binder
187 .pipeline_layout
188 .as_ref()
189 .unwrap()
190 .is_equal(pipeline_layout)
191 {
192 let (start_index, entries) = state
193 .binder
194 .change_pipeline_layout(pipeline_layout, late_sized_buffer_groups);
195 if !entries.is_empty() {
196 for (i, e) in entries.iter().enumerate() {
197 if let Some(group) = e.group.as_ref() {
198 let raw_bg = group.try_raw(state.base.snatch_guard)?;
199 unsafe {
200 state.base.raw_encoder.set_bind_group(
201 pipeline_layout.raw(),
202 start_index as u32 + i as u32,
203 Some(raw_bg),
204 &e.dynamic_offsets,
205 );
206 }
207 }
208 }
209 }
210
211 f();
212
213 let non_overlapping =
214 super::bind::compute_nonoverlapping_ranges(&pipeline_layout.push_constant_ranges);
215
216 for range in non_overlapping {
218 let offset = range.range.start;
219 let size_bytes = range.range.end - offset;
220 super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
221 state.base.raw_encoder.set_push_constants(
222 pipeline_layout.raw(),
223 range.stages,
224 clear_offset,
225 clear_data,
226 );
227 });
228 }
229 }
230 Ok(())
231}
232
233pub(crate) fn set_push_constant<E, F: FnOnce(&[u32])>(
234 state: &mut PassState,
235 push_constant_data: &[u32],
236 stages: wgt::ShaderStages,
237 offset: u32,
238 size_bytes: u32,
239 values_offset: Option<u32>,
240 f: F,
241) -> Result<(), E>
242where
243 E: From<PushConstantUploadError> + From<InvalidValuesOffset> + From<MissingPipeline>,
244{
245 api_log!("Pass::set_push_constants");
246
247 let values_offset = values_offset.ok_or(InvalidValuesOffset)?;
248
249 let end_offset_bytes = offset + size_bytes;
250 let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
251 let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];
252
253 let pipeline_layout = state
254 .binder
255 .pipeline_layout
256 .as_ref()
257 .ok_or(MissingPipeline)?;
258
259 pipeline_layout.validate_push_constant_ranges(stages, offset, end_offset_bytes)?;
260
261 f(data_slice);
262
263 unsafe {
264 state
265 .base
266 .raw_encoder
267 .set_push_constants(pipeline_layout.raw(), stages, offset, data_slice)
268 }
269 Ok(())
270}
271
272pub(crate) fn write_timestamp<E>(
273 state: &mut PassState,
274 cmd_enc: &CommandEncoder,
275 pending_query_resets: Option<&mut QueryResetMap>,
276 query_set: Arc<QuerySet>,
277 query_index: u32,
278) -> Result<(), E>
279where
280 E: From<MissingFeatures> + From<QueryUseError> + From<DeviceError>,
281{
282 api_log!(
283 "Pass::write_timestamps {query_index} {}",
284 query_set.error_ident()
285 );
286
287 query_set.same_device_as(cmd_enc)?;
288
289 state
290 .base
291 .device
292 .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;
293
294 let query_set = state.base.tracker.query_sets.insert_single(query_set);
295
296 query_set.validate_and_write_timestamp(
297 state.base.raw_encoder,
298 query_index,
299 pending_query_resets,
300 )?;
301 Ok(())
302}
303
304pub(crate) fn push_debug_group(state: &mut PassState, string_data: &[u8], len: usize) {
305 *state.base.debug_scope_depth += 1;
306 if !state
307 .base
308 .device
309 .instance_flags
310 .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
311 {
312 let label =
313 str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
314
315 api_log!("Pass::push_debug_group {label:?}");
316 unsafe {
317 state.base.raw_encoder.begin_debug_marker(label);
318 }
319 }
320 state.string_offset += len;
321}
322
323pub(crate) fn pop_debug_group<E>(state: &mut PassState) -> Result<(), E>
324where
325 E: From<DebugGroupError>,
326{
327 api_log!("Pass::pop_debug_group");
328
329 if *state.base.debug_scope_depth == 0 {
330 return Err(DebugGroupError::InvalidPop.into());
331 }
332 *state.base.debug_scope_depth -= 1;
333 if !state
334 .base
335 .device
336 .instance_flags
337 .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
338 {
339 unsafe {
340 state.base.raw_encoder.end_debug_marker();
341 }
342 }
343 Ok(())
344}
345
346pub(crate) fn insert_debug_marker(state: &mut PassState, string_data: &[u8], len: usize) {
347 if !state
348 .base
349 .device
350 .instance_flags
351 .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
352 {
353 let label =
354 str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
355 api_log!("Pass::insert_debug_marker {label:?}");
356 unsafe {
357 state.base.raw_encoder.insert_debug_marker(label);
358 }
359 }
360 state.string_offset += len;
361}