wgpu_core/command/pass.rs

//! Generic pass functions that both compute and render passes need.

use crate::binding_model::{BindError, BindGroup, PushConstantUploadError};
use crate::command::bind::Binder;
use crate::command::encoder::EncodingState;
use crate::command::memory_init::SurfacesInDiscardState;
use crate::command::{DebugGroupError, QueryResetMap, QueryUseError};
use crate::device::{Device, DeviceError, MissingFeatures};
use crate::pipeline::LateSizedBufferGroup;
use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet};
use crate::track::{ResourceUsageCompatibilityError, UsageScope};
use crate::{api_log, binding_model};
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::str;
use thiserror::Error;
use wgt::error::{ErrorType, WebGpuError};
use wgt::DynamicOffset;

#[derive(Clone, Debug, Error)]
#[error(
    "Bind group index {index} is out of range for the device's requested `max_bind_groups` limit {max}"
)]
pub struct BindGroupIndexOutOfRange {
    pub index: u32,
    pub max: u32,
}

#[derive(Clone, Debug, Error)]
#[error("Pipeline must be set")]
pub struct MissingPipeline;

#[derive(Clone, Debug, Error)]
#[error("Setting `values_offset` to be `None` is only for internal use in render bundles")]
pub struct InvalidValuesOffset;

impl WebGpuError for InvalidValuesOffset {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}

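/// State shared by the compute- and render-pass implementations while a pass
/// is being recorded.
///
/// `dynamic_offset_count` and `string_offset` are cursors into the pass's
/// flattened dynamic-offset and debug-string arrays, and `temp_offsets` is
/// scratch space holding the offsets of the bind group currently being set.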
pub(crate) struct PassState<'scope, 'snatch_guard, 'cmd_enc> {
    pub(crate) base: EncodingState<'snatch_guard, 'cmd_enc>,

    /// Immediate texture inits required because of prior discards. Need to
    /// be inserted before texture reads.
    pub(crate) pending_discard_init_fixups: SurfacesInDiscardState,

    pub(crate) scope: UsageScope<'scope>,

    pub(crate) binder: Binder,

    pub(crate) temp_offsets: Vec<u32>,

    pub(crate) dynamic_offset_count: usize,

    pub(crate) string_offset: usize,
}

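/// Set the bind group at `index` for the current pass, consuming the next
/// `num_dynamic_offsets` entries of `dynamic_offsets`.
///
/// When `merge_bind_groups` is true (render passes), the bind group's
/// resources are merged into the pass's usage scope immediately; compute
/// passes pass `false` and merge per dispatch in [`flush_bindings_helper`].
///
/// A sketch of a call site, with a pass's error type standing in for `E`
/// (`PassErrorInner` and the surrounding names are illustrative, not actual
/// call sites in this crate):
///
/// ```ignore
/// set_bind_group::<PassErrorInner>(
///     &mut state,
///     &device,
///     dynamic_offsets,
///     index,
///     num_dynamic_offsets,
///     bind_group,
///     /* merge_bind_groups */ false, // compute: merging happens at dispatch
/// )?;
/// ```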
pub(crate) fn set_bind_group<E>(
    state: &mut PassState,
    device: &Arc<Device>,
    dynamic_offsets: &[DynamicOffset],
    index: u32,
    num_dynamic_offsets: usize,
    bind_group: Option<Arc<BindGroup>>,
    merge_bind_groups: bool,
) -> Result<(), E>
where
    E: From<DeviceError>
        + From<BindGroupIndexOutOfRange>
        + From<ResourceUsageCompatibilityError>
        + From<DestroyedResourceError>
        + From<BindError>,
{
    match bind_group.as_ref() {
        Some(bind_group) => {
            api_log!("Pass::set_bind_group {index} {}", bind_group.error_ident());
        }
        None => {
            api_log!("Pass::set_bind_group {index} None");
        }
    }

    let max_bind_groups = state.base.device.limits.max_bind_groups;
    if index >= max_bind_groups {
        return Err(BindGroupIndexOutOfRange {
            index,
            max: max_bind_groups,
        }
        .into());
    }

    state.temp_offsets.clear();
    state.temp_offsets.extend_from_slice(
        &dynamic_offsets
            [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
    );
    state.dynamic_offset_count += num_dynamic_offsets;

    let Some(bind_group) = bind_group else {
        // TODO: Handle bind_group None.
        return Ok(());
    };

    // Add the bind group to the tracker. This is done for both compute and
    // render passes, and is used to fail submission of the command buffer if
    // any resource in any of the bind groups has been destroyed, whether or
    // not the bind group is actually used by the pipeline.
    let bind_group = state.base.tracker.bind_groups.insert_single(bind_group);

    bind_group.same_device(device)?;

    bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;

    if merge_bind_groups {
        // Merge the bind group's resources into the tracker. We only do this
        // for render passes. For compute passes it is done per dispatch in
        // [`flush_bindings`].
        unsafe {
            state.scope.merge_bind_group(&bind_group.used)?;
        }
    }
    // Note: stateless trackers are not merged: the lifetime reference
    // is held to the bind group itself.

    state
        .binder
        .assign_group(index as usize, bind_group, &state.temp_offsets);

    Ok(())
}

/// Implementation of `flush_bindings` for both compute and render passes.
///
/// See the compute pass version of `State::flush_bindings` for an explanation
/// of some differences in handling the two types of passes.
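///
/// Resolves the binder's pending rebind range: records buffer and texture
/// initialization actions and acceleration structure uses for each rebound
/// group, then (if a pipeline layout is bound) sets the raw bind groups on
/// the HAL encoder.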
pub(super) fn flush_bindings_helper(state: &mut PassState) -> Result<(), DestroyedResourceError> {
    let range = state.binder.take_rebind_range();
    if range.is_empty() {
        return Ok(());
    }

    let entries = state.binder.entries(range);

    for (_, entry) in entries.clone() {
        let bind_group = entry.group.as_ref().unwrap();

        state.base.buffer_memory_init_actions.extend(
            bind_group.used_buffer_ranges.iter().filter_map(|action| {
                action
                    .buffer
                    .initialization_status
                    .read()
                    .check_action(action)
            }),
        );
        for action in bind_group.used_texture_ranges.iter() {
            state.pending_discard_init_fixups.extend(
                state
                    .base
                    .texture_memory_actions
                    .register_init_action(action),
            );
        }

        let used_resource = bind_group
            .used
            .acceleration_structures
            .into_iter()
            .map(|tlas| crate::ray_tracing::AsAction::UseTlas(tlas.clone()));

        state.base.as_actions.extend(used_resource);
    }

    if let Some(pipeline_layout) = state.binder.pipeline_layout.as_ref() {
        for (i, e) in entries {
            if let Some(group) = e.group.as_ref() {
                let raw_bg = group.try_raw(state.base.snatch_guard)?;
                unsafe {
                    state.base.raw_encoder.set_bind_group(
                        pipeline_layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }
    }

    Ok(())
}

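/// Switch the binder to `pipeline_layout` if it is not already bound.
///
/// On a change, the binder is repointed at the new layout, the caller-supplied
/// hook `f` runs, and every push constant range of the new layout is cleared
/// to zero so no stale values carry over from the previous layout.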
pub(super) fn change_pipeline_layout<E, F: FnOnce()>(
    state: &mut PassState,
    pipeline_layout: &Arc<binding_model::PipelineLayout>,
    late_sized_buffer_groups: &[LateSizedBufferGroup],
    f: F,
) -> Result<(), E>
where
    E: From<DestroyedResourceError>,
{
    if state.binder.pipeline_layout.is_none()
        || !state
            .binder
            .pipeline_layout
            .as_ref()
            .unwrap()
            .is_equal(pipeline_layout)
    {
        state
            .binder
            .change_pipeline_layout(pipeline_layout, late_sized_buffer_groups);

        f();

        let non_overlapping =
            super::bind::compute_nonoverlapping_ranges(&pipeline_layout.push_constant_ranges);

        // Clear push constant ranges
        for range in non_overlapping {
            let offset = range.range.start;
            let size_bytes = range.range.end - offset;
            super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
                state.base.raw_encoder.set_push_constants(
                    pipeline_layout.raw(),
                    range.stages,
                    clear_offset,
                    clear_data,
                );
            });
        }
    }
    Ok(())
}

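/// Upload a run of push constant data to the pass.
///
/// `offset` and `size_bytes` are byte quantities, while `values_offset`
/// indexes `push_constant_data` in `u32` words (`wgt::PUSH_CONSTANT_ALIGNMENT`
/// is 4 bytes). As a worked example with illustrative values: `offset = 8`,
/// `size_bytes = 16`, and `values_offset = Some(2)` upload
/// `push_constant_data[2..6]`, i.e. four words (16 bytes) written at byte
/// offset 8.
///
/// A `values_offset` of `None` is reserved for render bundles; passes reject
/// it with [`InvalidValuesOffset`]. `f` is called with the validated data
/// slice before it is handed to the HAL encoder.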
pub(crate) fn set_push_constant<E, F: FnOnce(&[u32])>(
    state: &mut PassState,
    push_constant_data: &[u32],
    stages: wgt::ShaderStages,
    offset: u32,
    size_bytes: u32,
    values_offset: Option<u32>,
    f: F,
) -> Result<(), E>
where
    E: From<PushConstantUploadError> + From<InvalidValuesOffset> + From<MissingPipeline>,
{
    api_log!("Pass::set_push_constants");

    let values_offset = values_offset.ok_or(InvalidValuesOffset)?;

    let end_offset_bytes = offset + size_bytes;
    let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];

    let pipeline_layout = state
        .binder
        .pipeline_layout
        .as_ref()
        .ok_or(MissingPipeline)?;

    pipeline_layout.validate_push_constant_ranges(stages, offset, end_offset_bytes)?;

    f(data_slice);

    unsafe {
        state
            .base
            .raw_encoder
            .set_push_constants(pipeline_layout.raw(), stages, offset, data_slice)
    }
    Ok(())
}

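/// Write a timestamp to `query_set` at `query_index`.
///
/// Writing timestamps from inside a pass requires
/// [`wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES`], and the query set must
/// belong to the same device as the pass.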
pub(crate) fn write_timestamp<E>(
    state: &mut PassState,
    device: &Arc<Device>,
    pending_query_resets: Option<&mut QueryResetMap>,
    query_set: Arc<QuerySet>,
    query_index: u32,
) -> Result<(), E>
where
    E: From<MissingFeatures> + From<QueryUseError> + From<DeviceError>,
{
    api_log!(
        "Pass::write_timestamp {query_index} {}",
        query_set.error_ident()
    );

    query_set.same_device(device)?;

    state
        .base
        .device
        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;

    let query_set = state.base.tracker.query_sets.insert_single(query_set);

    query_set.validate_and_write_timestamp(
        state.base.raw_encoder,
        query_index,
        pending_query_resets,
    )?;
    Ok(())
}

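/// Open a debug group whose label is the next `len` bytes of `string_data`.
///
/// With [`wgt::InstanceFlags::DISCARD_HAL_LABELS`] set, no marker reaches the
/// HAL encoder, but the scope depth and `string_offset` still advance so the
/// pass's bookkeeping stays consistent.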
pub(crate) fn push_debug_group(state: &mut PassState, string_data: &[u8], len: usize) {
    *state.base.debug_scope_depth += 1;
    if !state
        .base
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();

        api_log!("Pass::push_debug_group {label:?}");
        unsafe {
            state.base.raw_encoder.begin_debug_marker(label);
        }
    }
    state.string_offset += len;
}

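/// Close the innermost debug group, returning [`DebugGroupError::InvalidPop`]
/// if no group is open.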
pub(crate) fn pop_debug_group<E>(state: &mut PassState) -> Result<(), E>
where
    E: From<DebugGroupError>,
{
    api_log!("Pass::pop_debug_group");

    if *state.base.debug_scope_depth == 0 {
        return Err(DebugGroupError::InvalidPop.into());
    }
    *state.base.debug_scope_depth -= 1;
    if !state
        .base
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        unsafe {
            state.base.raw_encoder.end_debug_marker();
        }
    }
    Ok(())
}

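/// Insert a standalone debug marker whose label is the next `len` bytes of
/// `string_data`, honoring the same label-discarding rules as
/// [`push_debug_group`].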
pub(crate) fn insert_debug_marker(state: &mut PassState, string_data: &[u8], len: usize) {
    if !state
        .base
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
        api_log!("Pass::insert_debug_marker {label:?}");
        unsafe {
            state.base.raw_encoder.insert_debug_marker(label);
        }
    }
    state.string_offset += len;
}