wgpu_core/command/
pass.rs

1//! Generic pass functions that both compute and render passes need.
2
3use crate::binding_model::{BindError, BindGroup, PushConstantUploadError};
4use crate::command::bind::Binder;
5use crate::command::encoder::EncodingState;
6use crate::command::memory_init::SurfacesInDiscardState;
7use crate::command::{CommandEncoder, DebugGroupError, QueryResetMap, QueryUseError};
8use crate::device::{DeviceError, MissingFeatures};
9use crate::pipeline::LateSizedBufferGroup;
10use crate::ray_tracing::AsAction;
11use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet};
12use crate::track::{ResourceUsageCompatibilityError, UsageScope};
13use crate::{api_log, binding_model};
14use alloc::sync::Arc;
15use alloc::vec::Vec;
16use core::str;
17use thiserror::Error;
18use wgt::error::{ErrorType, WebGpuError};
19use wgt::DynamicOffset;
20
21#[derive(Clone, Debug, Error)]
22#[error(
23    "Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}"
24)]
25pub struct BindGroupIndexOutOfRange {
26    pub index: u32,
27    pub max: u32,
28}
29
30#[derive(Clone, Debug, Error)]
31#[error("Pipeline must be set")]
32pub struct MissingPipeline;
33
34#[derive(Clone, Debug, Error)]
35#[error("Setting `values_offset` to be `None` is only for internal use in render bundles")]
36pub struct InvalidValuesOffset;
37
38impl WebGpuError for InvalidValuesOffset {
39    fn webgpu_error_type(&self) -> ErrorType {
40        ErrorType::Validation
41    }
42}
43
/// Mutable state threaded through the generic pass helpers in this module
/// (bind groups, push constants, timestamps, debug markers) while a pass is
/// being encoded.
pub(crate) struct PassState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> {
    /// Shared encoding state. Among other things it provides the device, the
    /// resource tracker, the raw encoder, the snatch guard, the debug scope
    /// depth, and the pending buffer/texture memory-init and
    /// acceleration-structure action lists.
    pub(crate) base: EncodingState<'snatch_guard, 'cmd_enc, 'raw_encoder>,

    /// Immediate texture inits required because of prior discards. Need to
    /// be inserted before texture reads.
    pub(crate) pending_discard_init_fixups: SurfacesInDiscardState,

    /// Resource usage scope of the pass. `set_bind_group` merges a bind
    /// group's usages into it when bind-group merging is requested.
    pub(crate) scope: UsageScope<'scope>,

    /// Holds the current pipeline layout and bind group assignments, and
    /// reports which groups must be (re)bound on the raw encoder.
    pub(crate) binder: Binder,

    /// Scratch buffer holding the dynamic offsets of the bind group currently
    /// being set; `set_bind_group` clears and refills it on every call.
    pub(crate) temp_offsets: Vec<u32>,

    /// Cursor into the pass's flattened dynamic offset array; advanced by the
    /// number of offsets each `set_bind_group` command consumes.
    pub(crate) dynamic_offset_count: usize,

    /// Cursor into the pass's packed debug-label string data; advanced past
    /// each pushed debug group or inserted debug marker label.
    pub(crate) string_offset: usize,
}
61
62pub(crate) fn set_bind_group<E>(
63    state: &mut PassState,
64    cmd_enc: &CommandEncoder,
65    dynamic_offsets: &[DynamicOffset],
66    index: u32,
67    num_dynamic_offsets: usize,
68    bind_group: Option<Arc<BindGroup>>,
69    merge_bind_groups: bool,
70) -> Result<(), E>
71where
72    E: From<DeviceError>
73        + From<BindGroupIndexOutOfRange>
74        + From<ResourceUsageCompatibilityError>
75        + From<DestroyedResourceError>
76        + From<BindError>,
77{
78    if bind_group.is_none() {
79        api_log!("Pass::set_bind_group {index} None");
80    } else {
81        api_log!(
82            "Pass::set_bind_group {index} {}",
83            bind_group.as_ref().unwrap().error_ident()
84        );
85    }
86
87    let max_bind_groups = state.base.device.limits.max_bind_groups;
88    if index >= max_bind_groups {
89        return Err(BindGroupIndexOutOfRange {
90            index,
91            max: max_bind_groups,
92        }
93        .into());
94    }
95
96    state.temp_offsets.clear();
97    state.temp_offsets.extend_from_slice(
98        &dynamic_offsets
99            [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
100    );
101    state.dynamic_offset_count += num_dynamic_offsets;
102
103    if bind_group.is_none() {
104        // TODO: Handle bind_group None.
105        return Ok(());
106    }
107
108    let bind_group = bind_group.unwrap();
109    let bind_group = state.base.tracker.bind_groups.insert_single(bind_group);
110
111    bind_group.same_device_as(cmd_enc)?;
112
113    bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;
114
115    if merge_bind_groups {
116        // merge the resource tracker in
117        unsafe {
118            state.scope.merge_bind_group(&bind_group.used)?;
119        }
120    }
121    //Note: stateless trackers are not merged: the lifetime reference
122    // is held to the bind group itself.
123
124    state
125        .base
126        .buffer_memory_init_actions
127        .extend(bind_group.used_buffer_ranges.iter().filter_map(|action| {
128            action
129                .buffer
130                .initialization_status
131                .read()
132                .check_action(action)
133        }));
134    for action in bind_group.used_texture_ranges.iter() {
135        state.pending_discard_init_fixups.extend(
136            state
137                .base
138                .texture_memory_actions
139                .register_init_action(action),
140        );
141    }
142
143    let used_resource = bind_group
144        .used
145        .acceleration_structures
146        .into_iter()
147        .map(|tlas| AsAction::UseTlas(tlas.clone()));
148
149    state.base.as_actions.extend(used_resource);
150
151    let pipeline_layout = state.binder.pipeline_layout.clone();
152    let entries = state
153        .binder
154        .assign_group(index as usize, bind_group, &state.temp_offsets);
155    if !entries.is_empty() && pipeline_layout.is_some() {
156        let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
157        for (i, e) in entries.iter().enumerate() {
158            if let Some(group) = e.group.as_ref() {
159                let raw_bg = group.try_raw(state.base.snatch_guard)?;
160                unsafe {
161                    state.base.raw_encoder.set_bind_group(
162                        pipeline_layout,
163                        index + i as u32,
164                        Some(raw_bg),
165                        &e.dynamic_offsets,
166                    );
167                }
168            }
169        }
170    }
171    Ok(())
172}
173
174/// After a pipeline has been changed, resources must be rebound
175pub(crate) fn rebind_resources<E, F: FnOnce()>(
176    state: &mut PassState,
177    pipeline_layout: &Arc<binding_model::PipelineLayout>,
178    late_sized_buffer_groups: &[LateSizedBufferGroup],
179    f: F,
180) -> Result<(), E>
181where
182    E: From<DestroyedResourceError>,
183{
184    if state.binder.pipeline_layout.is_none()
185        || !state
186            .binder
187            .pipeline_layout
188            .as_ref()
189            .unwrap()
190            .is_equal(pipeline_layout)
191    {
192        let (start_index, entries) = state
193            .binder
194            .change_pipeline_layout(pipeline_layout, late_sized_buffer_groups);
195        if !entries.is_empty() {
196            for (i, e) in entries.iter().enumerate() {
197                if let Some(group) = e.group.as_ref() {
198                    let raw_bg = group.try_raw(state.base.snatch_guard)?;
199                    unsafe {
200                        state.base.raw_encoder.set_bind_group(
201                            pipeline_layout.raw(),
202                            start_index as u32 + i as u32,
203                            Some(raw_bg),
204                            &e.dynamic_offsets,
205                        );
206                    }
207                }
208            }
209        }
210
211        f();
212
213        let non_overlapping =
214            super::bind::compute_nonoverlapping_ranges(&pipeline_layout.push_constant_ranges);
215
216        // Clear push constant ranges
217        for range in non_overlapping {
218            let offset = range.range.start;
219            let size_bytes = range.range.end - offset;
220            super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
221                state.base.raw_encoder.set_push_constants(
222                    pipeline_layout.raw(),
223                    range.stages,
224                    clear_offset,
225                    clear_data,
226                );
227            });
228        }
229    }
230    Ok(())
231}
232
233pub(crate) fn set_push_constant<E, F: FnOnce(&[u32])>(
234    state: &mut PassState,
235    push_constant_data: &[u32],
236    stages: wgt::ShaderStages,
237    offset: u32,
238    size_bytes: u32,
239    values_offset: Option<u32>,
240    f: F,
241) -> Result<(), E>
242where
243    E: From<PushConstantUploadError> + From<InvalidValuesOffset> + From<MissingPipeline>,
244{
245    api_log!("Pass::set_push_constants");
246
247    let values_offset = values_offset.ok_or(InvalidValuesOffset)?;
248
249    let end_offset_bytes = offset + size_bytes;
250    let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
251    let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];
252
253    let pipeline_layout = state
254        .binder
255        .pipeline_layout
256        .as_ref()
257        .ok_or(MissingPipeline)?;
258
259    pipeline_layout.validate_push_constant_ranges(stages, offset, end_offset_bytes)?;
260
261    f(data_slice);
262
263    unsafe {
264        state
265            .base
266            .raw_encoder
267            .set_push_constants(pipeline_layout.raw(), stages, offset, data_slice)
268    }
269    Ok(())
270}
271
272pub(crate) fn write_timestamp<E>(
273    state: &mut PassState,
274    cmd_enc: &CommandEncoder,
275    pending_query_resets: Option<&mut QueryResetMap>,
276    query_set: Arc<QuerySet>,
277    query_index: u32,
278) -> Result<(), E>
279where
280    E: From<MissingFeatures> + From<QueryUseError> + From<DeviceError>,
281{
282    api_log!(
283        "Pass::write_timestamps {query_index} {}",
284        query_set.error_ident()
285    );
286
287    query_set.same_device_as(cmd_enc)?;
288
289    state
290        .base
291        .device
292        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;
293
294    let query_set = state.base.tracker.query_sets.insert_single(query_set);
295
296    query_set.validate_and_write_timestamp(
297        state.base.raw_encoder,
298        query_index,
299        pending_query_resets,
300    )?;
301    Ok(())
302}
303
304pub(crate) fn push_debug_group(state: &mut PassState, string_data: &[u8], len: usize) {
305    *state.base.debug_scope_depth += 1;
306    if !state
307        .base
308        .device
309        .instance_flags
310        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
311    {
312        let label =
313            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
314
315        api_log!("Pass::push_debug_group {label:?}");
316        unsafe {
317            state.base.raw_encoder.begin_debug_marker(label);
318        }
319    }
320    state.string_offset += len;
321}
322
323pub(crate) fn pop_debug_group<E>(state: &mut PassState) -> Result<(), E>
324where
325    E: From<DebugGroupError>,
326{
327    api_log!("Pass::pop_debug_group");
328
329    if *state.base.debug_scope_depth == 0 {
330        return Err(DebugGroupError::InvalidPop.into());
331    }
332    *state.base.debug_scope_depth -= 1;
333    if !state
334        .base
335        .device
336        .instance_flags
337        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
338    {
339        unsafe {
340            state.base.raw_encoder.end_debug_marker();
341        }
342    }
343    Ok(())
344}
345
346pub(crate) fn insert_debug_marker(state: &mut PassState, string_data: &[u8], len: usize) {
347    if !state
348        .base
349        .device
350        .instance_flags
351        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
352    {
353        let label =
354            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
355        api_log!("Pass::insert_debug_marker {label:?}");
356        unsafe {
357            state.base.raw_encoder.insert_debug_marker(label);
358        }
359    }
360    state.string_offset += len;
361}