wgpu_core/command/
pass.rs

1//! Generic pass functions that both compute and render passes need.
2
3use crate::binding_model::{BindError, BindGroup, PushConstantUploadError};
4use crate::command::bind::Binder;
5use crate::command::memory_init::{CommandBufferTextureMemoryActions, SurfacesInDiscardState};
6use crate::command::{CommandEncoder, DebugGroupError, QueryResetMap, QueryUseError};
7use crate::device::{Device, DeviceError, MissingFeatures};
8use crate::init_tracker::BufferInitTrackerAction;
9use crate::pipeline::LateSizedBufferGroup;
10use crate::ray_tracing::AsAction;
11use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet};
12use crate::snatch::SnatchGuard;
13use crate::track::{ResourceUsageCompatibilityError, Tracker, UsageScope};
14use crate::{api_log, binding_model};
15use alloc::sync::Arc;
16use alloc::vec::Vec;
17use core::str;
18use thiserror::Error;
19use wgt::error::{ErrorType, WebGpuError};
20use wgt::DynamicOffset;
21
22#[derive(Clone, Debug, Error)]
23#[error(
24    "Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}"
25)]
26pub struct BindGroupIndexOutOfRange {
27    pub index: u32,
28    pub max: u32,
29}
30
/// Error returned when an operation requires a pipeline to be set but none is.
///
/// `set_push_constant` produces it when the binder has no pipeline layout.
31#[derive(Clone, Debug, Error)]
32#[error("Pipeline must be set")]
33pub struct MissingPipeline;
34
/// Error returned when a push constant command carries no `values_offset`.
///
/// A `None` values offset is reserved for internal use by render bundles;
/// `set_push_constant` rejects it with this error.
35#[derive(Clone, Debug, Error)]
36#[error("Setting `values_offset` to be `None` is only for internal use in render bundles")]
37pub struct InvalidValuesOffset;
38
/// [`InvalidValuesOffset`] is reported as a WebGPU validation error.
39impl WebGpuError for InvalidValuesOffset {
40    fn webgpu_error_type(&self) -> ErrorType {
41        ErrorType::Validation
42    }
43}
44
/// Encoding state shared by the generic pass helpers in this module
/// (bind groups, push constants, timestamps and debug markers).
45pub(crate) struct BaseState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> {
    /// Device the pass is recorded for; its limits, instance flags and
    /// enabled features are consulted while encoding.
46    pub(crate) device: &'cmd_enc Arc<Device>,
47
    /// HAL command encoder that the pass's commands are written to.
48    pub(crate) raw_encoder: &'raw_encoder mut dyn hal::DynCommandEncoder,
49
    /// Tracker that bind groups and query sets used by the pass are inserted into.
50    pub(crate) tracker: &'cmd_enc mut Tracker,
    /// Buffer initialization actions, extended from each bound group's used
    /// buffer ranges (those that `check_action` still reports).
51    pub(crate) buffer_memory_init_actions: &'cmd_enc mut Vec<BufferInitTrackerAction>,
    /// Texture memory actions; each bound group's used texture ranges are
    /// registered here.
52    pub(crate) texture_memory_actions: &'cmd_enc mut CommandBufferTextureMemoryActions,
    /// Acceleration structure actions, e.g. `AsAction::UseTlas` for the TLASes
    /// referenced by bound groups.
53    pub(crate) as_actions: &'cmd_enc mut Vec<AsAction>,
54
55    /// Immediate texture inits required because of prior discards. Need to
56    /// be inserted before texture reads.
57    pub(crate) pending_discard_init_fixups: SurfacesInDiscardState,
58
    /// Usage scope that bind group resource usage is merged into when
    /// `set_bind_group` is called with `merge_bind_groups == true`.
59    pub(crate) scope: UsageScope<'scope>,
60
    /// Holds the current pipeline layout and the bind group assignments.
61    pub(crate) binder: Binder,
62
    /// Scratch buffer with the dynamic offsets of the `set_bind_group` call
    /// currently being processed.
63    pub(crate) temp_offsets: Vec<u32>,
64
    /// Number of entries already consumed from the pass's dynamic offset list.
65    pub(crate) dynamic_offset_count: usize,
66
    /// Snatch guard passed to `try_raw` when fetching raw bind groups.
67    pub(crate) snatch_guard: &'snatch_guard SnatchGuard<'snatch_guard>,
68
    /// Number of debug groups pushed and not yet popped.
69    pub(crate) debug_scope_depth: u32,
    /// Byte offset into the pass's debug string data up to which labels have
    /// been consumed.
70    pub(crate) string_offset: usize,
71}
72
73pub(crate) fn set_bind_group<E>(
74    state: &mut BaseState,
75    cmd_enc: &CommandEncoder,
76    dynamic_offsets: &[DynamicOffset],
77    index: u32,
78    num_dynamic_offsets: usize,
79    bind_group: Option<Arc<BindGroup>>,
80    merge_bind_groups: bool,
81) -> Result<(), E>
82where
83    E: From<DeviceError>
84        + From<BindGroupIndexOutOfRange>
85        + From<ResourceUsageCompatibilityError>
86        + From<DestroyedResourceError>
87        + From<BindError>,
88{
89    if bind_group.is_none() {
90        api_log!("Pass::set_bind_group {index} None");
91    } else {
92        api_log!(
93            "Pass::set_bind_group {index} {}",
94            bind_group.as_ref().unwrap().error_ident()
95        );
96    }
97
98    let max_bind_groups = state.device.limits.max_bind_groups;
99    if index >= max_bind_groups {
100        return Err(BindGroupIndexOutOfRange {
101            index,
102            max: max_bind_groups,
103        }
104        .into());
105    }
106
107    state.temp_offsets.clear();
108    state.temp_offsets.extend_from_slice(
109        &dynamic_offsets
110            [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
111    );
112    state.dynamic_offset_count += num_dynamic_offsets;
113
114    if bind_group.is_none() {
115        // TODO: Handle bind_group None.
116        return Ok(());
117    }
118
119    let bind_group = bind_group.unwrap();
120    let bind_group = state.tracker.bind_groups.insert_single(bind_group);
121
122    bind_group.same_device_as(cmd_enc)?;
123
124    bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;
125
126    if merge_bind_groups {
127        // merge the resource tracker in
128        unsafe {
129            state.scope.merge_bind_group(&bind_group.used)?;
130        }
131    }
132    //Note: stateless trackers are not merged: the lifetime reference
133    // is held to the bind group itself.
134
135    state
136        .buffer_memory_init_actions
137        .extend(bind_group.used_buffer_ranges.iter().filter_map(|action| {
138            action
139                .buffer
140                .initialization_status
141                .read()
142                .check_action(action)
143        }));
144    for action in bind_group.used_texture_ranges.iter() {
145        state
146            .pending_discard_init_fixups
147            .extend(state.texture_memory_actions.register_init_action(action));
148    }
149
150    let used_resource = bind_group
151        .used
152        .acceleration_structures
153        .into_iter()
154        .map(|tlas| AsAction::UseTlas(tlas.clone()));
155
156    state.as_actions.extend(used_resource);
157
158    let pipeline_layout = state.binder.pipeline_layout.clone();
159    let entries = state
160        .binder
161        .assign_group(index as usize, bind_group, &state.temp_offsets);
162    if !entries.is_empty() && pipeline_layout.is_some() {
163        let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
164        for (i, e) in entries.iter().enumerate() {
165            if let Some(group) = e.group.as_ref() {
166                let raw_bg = group.try_raw(state.snatch_guard)?;
167                unsafe {
168                    state.raw_encoder.set_bind_group(
169                        pipeline_layout,
170                        index + i as u32,
171                        Some(raw_bg),
172                        &e.dynamic_offsets,
173                    );
174                }
175            }
176        }
177    }
178    Ok(())
179}
180
181/// After a pipeline has been changed, resources must be rebound
182pub(crate) fn rebind_resources<E, F: FnOnce()>(
183    state: &mut BaseState,
184    pipeline_layout: &Arc<binding_model::PipelineLayout>,
185    late_sized_buffer_groups: &[LateSizedBufferGroup],
186    f: F,
187) -> Result<(), E>
188where
189    E: From<DestroyedResourceError>,
190{
191    if state.binder.pipeline_layout.is_none()
192        || !state
193            .binder
194            .pipeline_layout
195            .as_ref()
196            .unwrap()
197            .is_equal(pipeline_layout)
198    {
199        let (start_index, entries) = state
200            .binder
201            .change_pipeline_layout(pipeline_layout, late_sized_buffer_groups);
202        if !entries.is_empty() {
203            for (i, e) in entries.iter().enumerate() {
204                if let Some(group) = e.group.as_ref() {
205                    let raw_bg = group.try_raw(state.snatch_guard)?;
206                    unsafe {
207                        state.raw_encoder.set_bind_group(
208                            pipeline_layout.raw(),
209                            start_index as u32 + i as u32,
210                            Some(raw_bg),
211                            &e.dynamic_offsets,
212                        );
213                    }
214                }
215            }
216        }
217
218        f();
219
220        let non_overlapping =
221            super::bind::compute_nonoverlapping_ranges(&pipeline_layout.push_constant_ranges);
222
223        // Clear push constant ranges
224        for range in non_overlapping {
225            let offset = range.range.start;
226            let size_bytes = range.range.end - offset;
227            super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
228                state.raw_encoder.set_push_constants(
229                    pipeline_layout.raw(),
230                    range.stages,
231                    clear_offset,
232                    clear_data,
233                );
234            });
235        }
236    }
237    Ok(())
238}
239
240pub(crate) fn set_push_constant<E, F: FnOnce(&[u32])>(
241    state: &mut BaseState,
242    push_constant_data: &[u32],
243    stages: wgt::ShaderStages,
244    offset: u32,
245    size_bytes: u32,
246    values_offset: Option<u32>,
247    f: F,
248) -> Result<(), E>
249where
250    E: From<PushConstantUploadError> + From<InvalidValuesOffset> + From<MissingPipeline>,
251{
252    api_log!("Pass::set_push_constants");
253
254    let values_offset = values_offset.ok_or(InvalidValuesOffset)?;
255
256    let end_offset_bytes = offset + size_bytes;
257    let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
258    let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];
259
260    let pipeline_layout = state
261        .binder
262        .pipeline_layout
263        .as_ref()
264        .ok_or(MissingPipeline)?;
265
266    pipeline_layout.validate_push_constant_ranges(stages, offset, end_offset_bytes)?;
267
268    f(data_slice);
269
270    unsafe {
271        state
272            .raw_encoder
273            .set_push_constants(pipeline_layout.raw(), stages, offset, data_slice)
274    }
275    Ok(())
276}
277
278pub(crate) fn write_timestamp<E>(
279    state: &mut BaseState,
280    cmd_enc: &CommandEncoder,
281    pending_query_resets: Option<&mut QueryResetMap>,
282    query_set: Arc<QuerySet>,
283    query_index: u32,
284) -> Result<(), E>
285where
286    E: From<MissingFeatures> + From<QueryUseError> + From<DeviceError>,
287{
288    api_log!(
289        "Pass::write_timestamps {query_index} {}",
290        query_set.error_ident()
291    );
292
293    query_set.same_device_as(cmd_enc)?;
294
295    state
296        .device
297        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;
298
299    let query_set = state.tracker.query_sets.insert_single(query_set);
300
301    query_set.validate_and_write_timestamp(state.raw_encoder, query_index, pending_query_resets)?;
302    Ok(())
303}
304
305pub(crate) fn push_debug_group(state: &mut BaseState, string_data: &[u8], len: usize) {
306    state.debug_scope_depth += 1;
307    if !state
308        .device
309        .instance_flags
310        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
311    {
312        let label =
313            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
314
315        api_log!("Pass::push_debug_group {label:?}");
316        unsafe {
317            state.raw_encoder.begin_debug_marker(label);
318        }
319    }
320    state.string_offset += len;
321}
322
323pub(crate) fn pop_debug_group<E>(state: &mut BaseState) -> Result<(), E>
324where
325    E: From<DebugGroupError>,
326{
327    api_log!("Pass::pop_debug_group");
328
329    if state.debug_scope_depth == 0 {
330        return Err(DebugGroupError::InvalidPop.into());
331    }
332    state.debug_scope_depth -= 1;
333    if !state
334        .device
335        .instance_flags
336        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
337    {
338        unsafe {
339            state.raw_encoder.end_debug_marker();
340        }
341    }
342    Ok(())
343}
344
345pub(crate) fn insert_debug_marker(state: &mut BaseState, string_data: &[u8], len: usize) {
346    if !state
347        .device
348        .instance_flags
349        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
350    {
351        let label =
352            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
353        api_log!("Pass::insert_debug_marker {label:?}");
354        unsafe {
355            state.raw_encoder.insert_debug_marker(label);
356        }
357    }
358    state.string_offset += len;
359}