wgpu_core/command/
pass.rs

1//! Generic pass functions that both compute and render passes need.
2
3use crate::binding_model::{BindError, BindGroup, ImmediateUploadError};
4use crate::command::encoder::EncodingState;
5use crate::command::{
6    bind::Binder, memory_init::SurfacesInDiscardState, query::QueryResetMap, DebugGroupError,
7    QueryUseError,
8};
9use crate::device::{Device, DeviceError, MissingFeatures};
10use crate::pipeline::LateSizedBufferGroup;
11use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet};
12use crate::track::{ResourceUsageCompatibilityError, UsageScope};
13use crate::{api_log, binding_model};
14use alloc::sync::Arc;
15use alloc::vec::Vec;
16use core::str;
17use thiserror::Error;
18use wgt::error::{ErrorType, WebGpuError};
19use wgt::DynamicOffset;
20
21#[derive(Clone, Debug, Error)]
22#[error(
23    "Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}"
24)]
25pub struct BindGroupIndexOutOfRange {
26    pub index: u32,
27    pub max: u32,
28}
29
30#[derive(Clone, Debug, Error)]
31#[error("Pipeline must be set")]
32pub struct MissingPipeline;
33
34#[derive(Clone, Debug, Error)]
35#[error("Setting `values_offset` to be `None` is only for internal use in render bundles")]
36pub struct InvalidValuesOffset;
37
38impl WebGpuError for InvalidValuesOffset {
39    fn webgpu_error_type(&self) -> ErrorType {
40        ErrorType::Validation
41    }
42}
43
44pub(crate) struct PassState<'scope, 'snatch_guard, 'cmd_enc> {
45    pub(crate) base: EncodingState<'snatch_guard, 'cmd_enc>,
46
47    /// Immediate texture inits required because of prior discards. Need to
48    /// be inserted before texture reads.
49    pub(crate) pending_discard_init_fixups: SurfacesInDiscardState,
50
51    pub(crate) scope: UsageScope<'scope>,
52
53    pub(crate) binder: Binder,
54
55    pub(crate) temp_offsets: Vec<u32>,
56
57    pub(crate) dynamic_offset_count: usize,
58
59    pub(crate) string_offset: usize,
60}
61
62pub(crate) fn set_bind_group<E>(
63    state: &mut PassState,
64    device: &Arc<Device>,
65    dynamic_offsets: &[DynamicOffset],
66    index: u32,
67    num_dynamic_offsets: usize,
68    bind_group: Option<Arc<BindGroup>>,
69    merge_bind_groups: bool,
70) -> Result<(), E>
71where
72    E: From<DeviceError>
73        + From<BindGroupIndexOutOfRange>
74        + From<ResourceUsageCompatibilityError>
75        + From<DestroyedResourceError>
76        + From<BindError>,
77{
78    if let Some(ref bind_group) = bind_group {
79        api_log!("Pass::set_bind_group {index} {}", bind_group.error_ident());
80    } else {
81        api_log!("Pass::set_bind_group {index} None");
82    }
83
84    let max_bind_groups = state.base.device.limits.max_bind_groups;
85    if index >= max_bind_groups {
86        return Err(BindGroupIndexOutOfRange {
87            index,
88            max: max_bind_groups,
89        }
90        .into());
91    }
92
93    state.temp_offsets.clear();
94    state.temp_offsets.extend_from_slice(
95        &dynamic_offsets
96            [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
97    );
98    state.dynamic_offset_count += num_dynamic_offsets;
99
100    let Some(bind_group) = bind_group else {
101        // TODO: Handle bind_group None.
102        return Ok(());
103    };
104
105    // Add the bind group to the tracker. This is done for both compute and
106    // render passes, and is used to fail submission of the command buffer if
107    // any resource in any of the bind groups has been destroyed, whether or
108    // not the bind group is actually used by the pipeline.
109    let bind_group = state.base.tracker.bind_groups.insert_single(bind_group);
110
111    bind_group.same_device(device)?;
112
113    bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;
114
115    if merge_bind_groups {
116        // Merge the bind group's resources into the tracker. We only do this
117        // for render passes. For compute passes it is done per dispatch in
118        // [`flush_bindings`].
119        unsafe {
120            state.scope.merge_bind_group(&bind_group.used)?;
121        }
122    }
123    //Note: stateless trackers are not merged: the lifetime reference
124    // is held to the bind group itself.
125
126    state
127        .binder
128        .assign_group(index as usize, bind_group, &state.temp_offsets);
129
130    Ok(())
131}
132
133/// Implementation of `flush_bindings` for both compute and render passes.
134///
135/// See the compute pass version of `State::flush_bindings` for an explanation
136/// of some differences in handling the two types of passes.
137pub(super) fn flush_bindings_helper(state: &mut PassState) -> Result<(), DestroyedResourceError> {
138    let range = state.binder.take_rebind_range();
139    if range.is_empty() {
140        return Ok(());
141    }
142
143    let entries = state.binder.entries(range);
144
145    for (_, entry) in entries.clone() {
146        let bind_group = entry.group.as_ref().unwrap();
147
148        state.base.buffer_memory_init_actions.extend(
149            bind_group.used_buffer_ranges.iter().filter_map(|action| {
150                action
151                    .buffer
152                    .initialization_status
153                    .read()
154                    .check_action(action)
155            }),
156        );
157        for action in bind_group.used_texture_ranges.iter() {
158            state.pending_discard_init_fixups.extend(
159                state
160                    .base
161                    .texture_memory_actions
162                    .register_init_action(action),
163            );
164        }
165
166        let used_resource = bind_group
167            .used
168            .acceleration_structures
169            .into_iter()
170            .map(|tlas| crate::ray_tracing::AsAction::UseTlas(tlas.clone()));
171
172        state.base.as_actions.extend(used_resource);
173    }
174
175    if let Some(pipeline_layout) = state.binder.pipeline_layout.as_ref() {
176        for (i, e) in entries {
177            if let Some(group) = e.group.as_ref() {
178                let raw_bg = group.try_raw(state.base.snatch_guard)?;
179                unsafe {
180                    state.base.raw_encoder.set_bind_group(
181                        pipeline_layout.raw(),
182                        i as u32,
183                        Some(raw_bg),
184                        &e.dynamic_offsets,
185                    );
186                }
187            }
188        }
189    }
190
191    Ok(())
192}
193
194pub(super) fn change_pipeline_layout<E, F: FnOnce()>(
195    state: &mut PassState,
196    pipeline_layout: &Arc<binding_model::PipelineLayout>,
197    late_sized_buffer_groups: &[LateSizedBufferGroup],
198    f: F,
199) -> Result<(), E>
200where
201    E: From<DestroyedResourceError>,
202{
203    if state
204        .binder
205        .change_pipeline_layout(pipeline_layout, late_sized_buffer_groups)
206    {
207        f();
208
209        super::immediates_clear(
210            0,
211            pipeline_layout.immediate_size,
212            |clear_offset, clear_data| unsafe {
213                state.base.raw_encoder.set_immediates(
214                    pipeline_layout.raw(),
215                    clear_offset,
216                    clear_data,
217                );
218            },
219        );
220    }
221    Ok(())
222}
223
224pub(crate) fn set_immediates<E, F: FnOnce(&[u32])>(
225    state: &mut PassState,
226    immediates_data: &[u32],
227    offset: u32,
228    size_bytes: u32,
229    values_offset: Option<u32>,
230    f: F,
231) -> Result<(), E>
232where
233    E: From<ImmediateUploadError> + From<InvalidValuesOffset> + From<MissingPipeline>,
234{
235    api_log!("Pass::set_immediates");
236
237    let values_offset = values_offset.ok_or(InvalidValuesOffset)?;
238
239    let end_offset_bytes = offset + size_bytes;
240    let values_end_offset = (values_offset + size_bytes / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
241    let data_slice = &immediates_data[(values_offset as usize)..values_end_offset];
242
243    let pipeline_layout = state
244        .binder
245        .pipeline_layout
246        .as_ref()
247        .ok_or(MissingPipeline)?;
248
249    pipeline_layout.validate_immediates_ranges(offset, end_offset_bytes)?;
250
251    f(data_slice);
252
253    unsafe {
254        state
255            .base
256            .raw_encoder
257            .set_immediates(pipeline_layout.raw(), offset, data_slice)
258    }
259    Ok(())
260}
261
262pub(crate) fn write_timestamp<E>(
263    state: &mut PassState,
264    device: &Arc<Device>,
265    pending_query_resets: Option<&mut QueryResetMap>,
266    query_set: Arc<QuerySet>,
267    query_index: u32,
268) -> Result<(), E>
269where
270    E: From<MissingFeatures> + From<QueryUseError> + From<DeviceError>,
271{
272    api_log!(
273        "Pass::write_timestamps {query_index} {}",
274        query_set.error_ident()
275    );
276
277    query_set.same_device(device)?;
278
279    state
280        .base
281        .device
282        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;
283
284    let query_set = state.base.tracker.query_sets.insert_single(query_set);
285
286    query_set.validate_and_write_timestamp(
287        state.base.raw_encoder,
288        query_index,
289        pending_query_resets,
290    )?;
291    Ok(())
292}
293
294pub(crate) fn push_debug_group(state: &mut PassState, string_data: &[u8], len: usize) {
295    *state.base.debug_scope_depth += 1;
296    if !state
297        .base
298        .device
299        .instance_flags
300        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
301    {
302        let label =
303            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
304
305        api_log!("Pass::push_debug_group {label:?}");
306        unsafe {
307            state.base.raw_encoder.begin_debug_marker(label);
308        }
309    }
310    state.string_offset += len;
311}
312
313pub(crate) fn pop_debug_group<E>(state: &mut PassState) -> Result<(), E>
314where
315    E: From<DebugGroupError>,
316{
317    api_log!("Pass::pop_debug_group");
318
319    if *state.base.debug_scope_depth == 0 {
320        return Err(DebugGroupError::InvalidPop.into());
321    }
322    *state.base.debug_scope_depth -= 1;
323    if !state
324        .base
325        .device
326        .instance_flags
327        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
328    {
329        unsafe {
330            state.base.raw_encoder.end_debug_marker();
331        }
332    }
333    Ok(())
334}
335
336pub(crate) fn insert_debug_marker(state: &mut PassState, string_data: &[u8], len: usize) {
337    if !state
338        .base
339        .device
340        .instance_flags
341        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
342    {
343        let label =
344            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
345        api_log!("Pass::insert_debug_marker {label:?}");
346        unsafe {
347            state.base.raw_encoder.insert_debug_marker(label);
348        }
349    }
350    state.string_offset += len;
351}