wgpu_core/command/
pass.rs

1//! Generic pass functions that both compute and render passes need.
2
3use crate::binding_model::{BindError, BindGroup, PushConstantUploadError};
4use crate::command::bind::Binder;
5use crate::command::encoder::EncodingState;
6use crate::command::memory_init::SurfacesInDiscardState;
7use crate::command::{DebugGroupError, QueryResetMap, QueryUseError};
8use crate::device::{Device, DeviceError, MissingFeatures};
9use crate::pipeline::LateSizedBufferGroup;
10use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet};
11use crate::track::{ResourceUsageCompatibilityError, UsageScope};
12use crate::{api_log, binding_model};
13use alloc::sync::Arc;
14use alloc::vec::Vec;
15use core::str;
16use thiserror::Error;
17use wgt::error::{ErrorType, WebGpuError};
18use wgt::DynamicOffset;
19
/// Error raised when a bind group index is not below the device's
/// `max_bind_groups` limit. Returned by [`set_bind_group`].
#[derive(Clone, Debug, Error)]
#[error(
    "Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}"
)]
pub struct BindGroupIndexOutOfRange {
    /// The bind group index that was requested.
    pub index: u32,
    /// The device's `max_bind_groups` limit that was exceeded.
    pub max: u32,
}
28
/// Error raised by pass operations that require a pipeline (and hence a
/// pipeline layout) to have been set first, e.g. [`set_push_constant`].
#[derive(Clone, Debug, Error)]
#[error("Pipeline must be set")]
pub struct MissingPipeline;
32
/// Error raised when `set_push_constant` is called with `values_offset: None`
/// from the public API; `None` is reserved for render-bundle-internal use.
#[derive(Clone, Debug, Error)]
#[error("Setting `values_offset` to be `None` is only for internal use in render bundles")]
pub struct InvalidValuesOffset;
36
impl WebGpuError for InvalidValuesOffset {
    fn webgpu_error_type(&self) -> ErrorType {
        // A `None` values offset coming from the public API is a validation
        // failure, not an internal/device error.
        ErrorType::Validation
    }
}
42
/// State shared by compute and render passes while they are being encoded.
///
/// Bundles the generic [`EncodingState`] with the bookkeeping that the
/// pass-generic functions in this module operate on.
pub(crate) struct PassState<'scope, 'snatch_guard, 'cmd_enc> {
    pub(crate) base: EncodingState<'snatch_guard, 'cmd_enc>,

    /// Immediate texture inits required because of prior discards. Need to
    /// be inserted before texture reads.
    pub(crate) pending_discard_init_fixups: SurfacesInDiscardState,

    /// Usage scope that bind group resource uses are merged into
    /// (see [`set_bind_group`]).
    pub(crate) scope: UsageScope<'scope>,

    /// Tracks assigned bind groups, their dynamic offsets, and the current
    /// pipeline layout.
    pub(crate) binder: Binder,

    /// Scratch buffer holding the dynamic offsets of a single
    /// `set_bind_group` call; cleared and refilled on each call.
    pub(crate) temp_offsets: Vec<u32>,

    /// Number of dynamic offsets consumed so far from the pass's flat
    /// offsets list.
    pub(crate) dynamic_offset_count: usize,

    /// Number of bytes consumed so far from the pass's string data
    /// (debug group / marker labels).
    pub(crate) string_offset: usize,
}
60
61pub(crate) fn set_bind_group<E>(
62    state: &mut PassState,
63    device: &Arc<Device>,
64    dynamic_offsets: &[DynamicOffset],
65    index: u32,
66    num_dynamic_offsets: usize,
67    bind_group: Option<Arc<BindGroup>>,
68    merge_bind_groups: bool,
69) -> Result<(), E>
70where
71    E: From<DeviceError>
72        + From<BindGroupIndexOutOfRange>
73        + From<ResourceUsageCompatibilityError>
74        + From<DestroyedResourceError>
75        + From<BindError>,
76{
77    if bind_group.is_none() {
78        api_log!("Pass::set_bind_group {index} None");
79    } else {
80        api_log!(
81            "Pass::set_bind_group {index} {}",
82            bind_group.as_ref().unwrap().error_ident()
83        );
84    }
85
86    let max_bind_groups = state.base.device.limits.max_bind_groups;
87    if index >= max_bind_groups {
88        return Err(BindGroupIndexOutOfRange {
89            index,
90            max: max_bind_groups,
91        }
92        .into());
93    }
94
95    state.temp_offsets.clear();
96    state.temp_offsets.extend_from_slice(
97        &dynamic_offsets
98            [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
99    );
100    state.dynamic_offset_count += num_dynamic_offsets;
101
102    let Some(bind_group) = bind_group else {
103        // TODO: Handle bind_group None.
104        return Ok(());
105    };
106
107    // Add the bind group to the tracker. This is done for both compute and
108    // render passes, and is used to fail submission of the command buffer if
109    // any resource in any of the bind groups has been destroyed, whether or
110    // not the bind group is actually used by the pipeline.
111    let bind_group = state.base.tracker.bind_groups.insert_single(bind_group);
112
113    bind_group.same_device(device)?;
114
115    bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;
116
117    if merge_bind_groups {
118        // Merge the bind group's resources into the tracker. We only do this
119        // for render passes. For compute passes it is done per dispatch in
120        // [`flush_bindings`].
121        unsafe {
122            state.scope.merge_bind_group(&bind_group.used)?;
123        }
124    }
125    //Note: stateless trackers are not merged: the lifetime reference
126    // is held to the bind group itself.
127
128    state
129        .binder
130        .assign_group(index as usize, bind_group, &state.temp_offsets);
131
132    Ok(())
133}
134
135/// Helper for `flush_bindings` implementing the portions that are the same for
136/// compute and render passes.
137pub(super) fn flush_bindings_helper<F>(
138    state: &mut PassState,
139    mut f: F,
140) -> Result<(), DestroyedResourceError>
141where
142    F: FnMut(&Arc<BindGroup>),
143{
144    for bind_group in state.binder.list_active() {
145        f(bind_group);
146
147        state.base.buffer_memory_init_actions.extend(
148            bind_group.used_buffer_ranges.iter().filter_map(|action| {
149                action
150                    .buffer
151                    .initialization_status
152                    .read()
153                    .check_action(action)
154            }),
155        );
156        for action in bind_group.used_texture_ranges.iter() {
157            state.pending_discard_init_fixups.extend(
158                state
159                    .base
160                    .texture_memory_actions
161                    .register_init_action(action),
162            );
163        }
164
165        let used_resource = bind_group
166            .used
167            .acceleration_structures
168            .into_iter()
169            .map(|tlas| crate::ray_tracing::AsAction::UseTlas(tlas.clone()));
170
171        state.base.as_actions.extend(used_resource);
172    }
173
174    let range = state.binder.take_rebind_range();
175    let entries = state.binder.entries(range);
176    match state.binder.pipeline_layout.as_ref() {
177        Some(pipeline_layout) if entries.len() != 0 => {
178            for (i, e) in entries {
179                if let Some(group) = e.group.as_ref() {
180                    let raw_bg = group.try_raw(state.base.snatch_guard)?;
181                    unsafe {
182                        state.base.raw_encoder.set_bind_group(
183                            pipeline_layout.raw(),
184                            i as u32,
185                            Some(raw_bg),
186                            &e.dynamic_offsets,
187                        );
188                    }
189                }
190            }
191        }
192        _ => {}
193    }
194
195    Ok(())
196}
197
198pub(super) fn change_pipeline_layout<E, F: FnOnce()>(
199    state: &mut PassState,
200    pipeline_layout: &Arc<binding_model::PipelineLayout>,
201    late_sized_buffer_groups: &[LateSizedBufferGroup],
202    f: F,
203) -> Result<(), E>
204where
205    E: From<DestroyedResourceError>,
206{
207    if state.binder.pipeline_layout.is_none()
208        || !state
209            .binder
210            .pipeline_layout
211            .as_ref()
212            .unwrap()
213            .is_equal(pipeline_layout)
214    {
215        state
216            .binder
217            .change_pipeline_layout(pipeline_layout, late_sized_buffer_groups);
218
219        f();
220
221        let non_overlapping =
222            super::bind::compute_nonoverlapping_ranges(&pipeline_layout.push_constant_ranges);
223
224        // Clear push constant ranges
225        for range in non_overlapping {
226            let offset = range.range.start;
227            let size_bytes = range.range.end - offset;
228            super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
229                state.base.raw_encoder.set_push_constants(
230                    pipeline_layout.raw(),
231                    range.stages,
232                    clear_offset,
233                    clear_data,
234                );
235            });
236        }
237    }
238    Ok(())
239}
240
241pub(crate) fn set_push_constant<E, F: FnOnce(&[u32])>(
242    state: &mut PassState,
243    push_constant_data: &[u32],
244    stages: wgt::ShaderStages,
245    offset: u32,
246    size_bytes: u32,
247    values_offset: Option<u32>,
248    f: F,
249) -> Result<(), E>
250where
251    E: From<PushConstantUploadError> + From<InvalidValuesOffset> + From<MissingPipeline>,
252{
253    api_log!("Pass::set_push_constants");
254
255    let values_offset = values_offset.ok_or(InvalidValuesOffset)?;
256
257    let end_offset_bytes = offset + size_bytes;
258    let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
259    let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];
260
261    let pipeline_layout = state
262        .binder
263        .pipeline_layout
264        .as_ref()
265        .ok_or(MissingPipeline)?;
266
267    pipeline_layout.validate_push_constant_ranges(stages, offset, end_offset_bytes)?;
268
269    f(data_slice);
270
271    unsafe {
272        state
273            .base
274            .raw_encoder
275            .set_push_constants(pipeline_layout.raw(), stages, offset, data_slice)
276    }
277    Ok(())
278}
279
280pub(crate) fn write_timestamp<E>(
281    state: &mut PassState,
282    device: &Arc<Device>,
283    pending_query_resets: Option<&mut QueryResetMap>,
284    query_set: Arc<QuerySet>,
285    query_index: u32,
286) -> Result<(), E>
287where
288    E: From<MissingFeatures> + From<QueryUseError> + From<DeviceError>,
289{
290    api_log!(
291        "Pass::write_timestamps {query_index} {}",
292        query_set.error_ident()
293    );
294
295    query_set.same_device(device)?;
296
297    state
298        .base
299        .device
300        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;
301
302    let query_set = state.base.tracker.query_sets.insert_single(query_set);
303
304    query_set.validate_and_write_timestamp(
305        state.base.raw_encoder,
306        query_index,
307        pending_query_resets,
308    )?;
309    Ok(())
310}
311
312pub(crate) fn push_debug_group(state: &mut PassState, string_data: &[u8], len: usize) {
313    *state.base.debug_scope_depth += 1;
314    if !state
315        .base
316        .device
317        .instance_flags
318        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
319    {
320        let label =
321            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
322
323        api_log!("Pass::push_debug_group {label:?}");
324        unsafe {
325            state.base.raw_encoder.begin_debug_marker(label);
326        }
327    }
328    state.string_offset += len;
329}
330
331pub(crate) fn pop_debug_group<E>(state: &mut PassState) -> Result<(), E>
332where
333    E: From<DebugGroupError>,
334{
335    api_log!("Pass::pop_debug_group");
336
337    if *state.base.debug_scope_depth == 0 {
338        return Err(DebugGroupError::InvalidPop.into());
339    }
340    *state.base.debug_scope_depth -= 1;
341    if !state
342        .base
343        .device
344        .instance_flags
345        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
346    {
347        unsafe {
348            state.base.raw_encoder.end_debug_marker();
349        }
350    }
351    Ok(())
352}
353
354pub(crate) fn insert_debug_marker(state: &mut PassState, string_data: &[u8], len: usize) {
355    if !state
356        .base
357        .device
358        .instance_flags
359        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
360    {
361        let label =
362            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
363        api_log!("Pass::insert_debug_marker {label:?}");
364        unsafe {
365            state.base.raw_encoder.insert_debug_marker(label);
366        }
367    }
368    state.string_offset += len;
369}