1use crate::binding_model::{BindError, BindGroup, ImmediateUploadError};
4use crate::command::encoder::EncodingState;
5use crate::command::{
6 bind::Binder, memory_init::SurfacesInDiscardState, query::QueryResetMap, DebugGroupError,
7 QueryUseError,
8};
9use crate::device::{Device, DeviceError, MissingFeatures};
10use crate::pipeline::LateSizedBufferGroup;
11use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet};
12use crate::track::{ResourceUsageCompatibilityError, UsageScope};
13use crate::{api_log, binding_model};
14use alloc::sync::Arc;
15use alloc::vec::Vec;
16use core::str;
17use thiserror::Error;
18use wgt::error::{ErrorType, WebGpuError};
19use wgt::DynamicOffset;
20
21#[derive(Clone, Debug, Error)]
22#[error(
23 "Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}"
24)]
25pub struct BindGroupIndexOutOfRange {
26 pub index: u32,
27 pub max: u32,
28}
29
30#[derive(Clone, Debug, Error)]
31#[error("Pipeline must be set")]
32pub struct MissingPipeline;
33
34#[derive(Clone, Debug, Error)]
35#[error("Setting `values_offset` to be `None` is only for internal use in render bundles")]
36pub struct InvalidValuesOffset;
37
38impl WebGpuError for InvalidValuesOffset {
39 fn webgpu_error_type(&self) -> ErrorType {
40 ErrorType::Validation
41 }
42}
43
44pub(crate) struct PassState<'scope, 'snatch_guard, 'cmd_enc> {
45 pub(crate) base: EncodingState<'snatch_guard, 'cmd_enc>,
46
47 pub(crate) pending_discard_init_fixups: SurfacesInDiscardState,
50
51 pub(crate) scope: UsageScope<'scope>,
52
53 pub(crate) binder: Binder,
54
55 pub(crate) temp_offsets: Vec<u32>,
56
57 pub(crate) dynamic_offset_count: usize,
58
59 pub(crate) string_offset: usize,
60}
61
62pub(crate) fn set_bind_group<E>(
63 state: &mut PassState,
64 device: &Arc<Device>,
65 dynamic_offsets: &[DynamicOffset],
66 index: u32,
67 num_dynamic_offsets: usize,
68 bind_group: Option<Arc<BindGroup>>,
69 merge_bind_groups: bool,
70) -> Result<(), E>
71where
72 E: From<DeviceError>
73 + From<BindGroupIndexOutOfRange>
74 + From<ResourceUsageCompatibilityError>
75 + From<DestroyedResourceError>
76 + From<BindError>,
77{
78 if let Some(ref bind_group) = bind_group {
79 api_log!("Pass::set_bind_group {index} {}", bind_group.error_ident());
80 } else {
81 api_log!("Pass::set_bind_group {index} None");
82 }
83
84 let max_bind_groups = state.base.device.limits.max_bind_groups;
85 if index >= max_bind_groups {
86 return Err(BindGroupIndexOutOfRange {
87 index,
88 max: max_bind_groups,
89 }
90 .into());
91 }
92
93 state.temp_offsets.clear();
94 state.temp_offsets.extend_from_slice(
95 &dynamic_offsets
96 [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
97 );
98 state.dynamic_offset_count += num_dynamic_offsets;
99
100 let Some(bind_group) = bind_group else {
101 return Ok(());
103 };
104
105 let bind_group = state.base.tracker.bind_groups.insert_single(bind_group);
110
111 bind_group.same_device(device)?;
112
113 bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;
114
115 if merge_bind_groups {
116 unsafe {
120 state.scope.merge_bind_group(&bind_group.used)?;
121 }
122 }
123 state
127 .binder
128 .assign_group(index as usize, bind_group, &state.temp_offsets);
129
130 Ok(())
131}
132
133pub(super) fn flush_bindings_helper(state: &mut PassState) -> Result<(), DestroyedResourceError> {
138 let range = state.binder.take_rebind_range();
139 if range.is_empty() {
140 return Ok(());
141 }
142
143 let entries = state.binder.entries(range);
144
145 for (_, entry) in entries.clone() {
146 let bind_group = entry.group.as_ref().unwrap();
147
148 state.base.buffer_memory_init_actions.extend(
149 bind_group.used_buffer_ranges.iter().filter_map(|action| {
150 action
151 .buffer
152 .initialization_status
153 .read()
154 .check_action(action)
155 }),
156 );
157 for action in bind_group.used_texture_ranges.iter() {
158 state.pending_discard_init_fixups.extend(
159 state
160 .base
161 .texture_memory_actions
162 .register_init_action(action),
163 );
164 }
165
166 let used_resource = bind_group
167 .used
168 .acceleration_structures
169 .into_iter()
170 .map(|tlas| crate::ray_tracing::AsAction::UseTlas(tlas.clone()));
171
172 state.base.as_actions.extend(used_resource);
173 }
174
175 if let Some(pipeline_layout) = state.binder.pipeline_layout.as_ref() {
176 for (i, e) in entries {
177 if let Some(group) = e.group.as_ref() {
178 let raw_bg = group.try_raw(state.base.snatch_guard)?;
179 unsafe {
180 state.base.raw_encoder.set_bind_group(
181 pipeline_layout.raw(),
182 i as u32,
183 Some(raw_bg),
184 &e.dynamic_offsets,
185 );
186 }
187 }
188 }
189 }
190
191 Ok(())
192}
193
194pub(super) fn change_pipeline_layout<E, F: FnOnce()>(
195 state: &mut PassState,
196 pipeline_layout: &Arc<binding_model::PipelineLayout>,
197 late_sized_buffer_groups: &[LateSizedBufferGroup],
198 f: F,
199) -> Result<(), E>
200where
201 E: From<DestroyedResourceError>,
202{
203 if state
204 .binder
205 .change_pipeline_layout(pipeline_layout, late_sized_buffer_groups)
206 {
207 f();
208
209 super::immediates_clear(
210 0,
211 pipeline_layout.immediate_size,
212 |clear_offset, clear_data| unsafe {
213 state.base.raw_encoder.set_immediates(
214 pipeline_layout.raw(),
215 clear_offset,
216 clear_data,
217 );
218 },
219 );
220 }
221 Ok(())
222}
223
224pub(crate) fn set_immediates<E, F: FnOnce(&[u32])>(
225 state: &mut PassState,
226 immediates_data: &[u32],
227 offset: u32,
228 size_bytes: u32,
229 values_offset: Option<u32>,
230 f: F,
231) -> Result<(), E>
232where
233 E: From<ImmediateUploadError> + From<InvalidValuesOffset> + From<MissingPipeline>,
234{
235 api_log!("Pass::set_immediates");
236
237 let values_offset = values_offset.ok_or(InvalidValuesOffset)?;
238
239 let end_offset_bytes = offset + size_bytes;
240 let values_end_offset = (values_offset + size_bytes / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
241 let data_slice = &immediates_data[(values_offset as usize)..values_end_offset];
242
243 let pipeline_layout = state
244 .binder
245 .pipeline_layout
246 .as_ref()
247 .ok_or(MissingPipeline)?;
248
249 pipeline_layout.validate_immediates_ranges(offset, end_offset_bytes)?;
250
251 f(data_slice);
252
253 unsafe {
254 state
255 .base
256 .raw_encoder
257 .set_immediates(pipeline_layout.raw(), offset, data_slice)
258 }
259 Ok(())
260}
261
262pub(crate) fn write_timestamp<E>(
263 state: &mut PassState,
264 device: &Arc<Device>,
265 pending_query_resets: Option<&mut QueryResetMap>,
266 query_set: Arc<QuerySet>,
267 query_index: u32,
268) -> Result<(), E>
269where
270 E: From<MissingFeatures> + From<QueryUseError> + From<DeviceError>,
271{
272 api_log!(
273 "Pass::write_timestamps {query_index} {}",
274 query_set.error_ident()
275 );
276
277 query_set.same_device(device)?;
278
279 state
280 .base
281 .device
282 .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;
283
284 let query_set = state.base.tracker.query_sets.insert_single(query_set);
285
286 query_set.validate_and_write_timestamp(
287 state.base.raw_encoder,
288 query_index,
289 pending_query_resets,
290 )?;
291 Ok(())
292}
293
294pub(crate) fn push_debug_group(state: &mut PassState, string_data: &[u8], len: usize) {
295 *state.base.debug_scope_depth += 1;
296 if !state
297 .base
298 .device
299 .instance_flags
300 .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
301 {
302 let label =
303 str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
304
305 api_log!("Pass::push_debug_group {label:?}");
306 unsafe {
307 state.base.raw_encoder.begin_debug_marker(label);
308 }
309 }
310 state.string_offset += len;
311}
312
313pub(crate) fn pop_debug_group<E>(state: &mut PassState) -> Result<(), E>
314where
315 E: From<DebugGroupError>,
316{
317 api_log!("Pass::pop_debug_group");
318
319 if *state.base.debug_scope_depth == 0 {
320 return Err(DebugGroupError::InvalidPop.into());
321 }
322 *state.base.debug_scope_depth -= 1;
323 if !state
324 .base
325 .device
326 .instance_flags
327 .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
328 {
329 unsafe {
330 state.base.raw_encoder.end_debug_marker();
331 }
332 }
333 Ok(())
334}
335
336pub(crate) fn insert_debug_marker(state: &mut PassState, string_data: &[u8], len: usize) {
337 if !state
338 .base
339 .device
340 .instance_flags
341 .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
342 {
343 let label =
344 str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
345 api_log!("Pass::insert_debug_marker {label:?}");
346 unsafe {
347 state.base.raw_encoder.insert_debug_marker(label);
348 }
349 }
350 state.string_offset += len;
351}