wgpu_core/command/compute.rs

use thiserror::Error;
use wgt::{
    error::{ErrorType, WebGpuError},
    BufferAddress, DynamicOffset,
};

use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{convert::Infallible, fmt, str};

use crate::{
    api_log,
    binding_model::BindError,
    command::pass::flush_bindings_helper,
    resource::{RawResourceAccess, Trackable},
};
use crate::{
    binding_model::{ImmediateUploadError, LateMinBufferBindingSizeMismatch},
    command::{
        bind::{Binder, BinderError},
        compute_command::ArcComputeCommand,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
        BasePass, BindGroupStateChange, CommandEncoderError, MapPassErr, PassErrorScope,
        PassTimestampWrites, QueryUseError, StateChange,
    },
    device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::MemoryInitKind,
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
    },
    track::{ResourceUsageCompatibilityError, Tracker},
    Label,
};
use crate::{command::InnerCommandEncoder, resource::DestroyedResourceError};
use crate::{
    command::{
        encoder::EncodingState, pass, ArcCommand, CommandEncoder, DebugGroupError,
        EncoderStateError, PassStateError, TimestampWritesError,
    },
    device::Device,
};

pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;
/// A pass's [encoder state](https://www.w3.org/TR/webgpu/#encoder-state) and
/// its validity are two distinct conditions, i.e., the full matrix of
/// (open, ended) x (valid, invalid) is possible.
///
/// The presence or absence of the `parent` `Option` indicates the pass's state.
/// The presence or absence of an error in `base.error` indicates the pass's
/// validity.
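///
/// A minimal sketch of the distinction, assuming a `Global` and passes set up
/// elsewhere (hypothetical names): recording into an *invalid* but still open
/// pass is permitted and the error is deferred, while recording into an
/// *ended* pass fails immediately.
///
/// ```ignore
/// // Invalid but open: the command is accepted; the error surfaces later,
/// // when the parent encoder is finished.
/// global.compute_pass_dispatch_workgroups(&mut invalid_pass, 1, 1, 1).unwrap();
///
/// // Ended: further recording fails immediately.
/// global.compute_pass_end(&mut pass).unwrap();
/// assert!(global.compute_pass_dispatch_workgroups(&mut pass, 1, 1, 1).is_err());
/// ```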
pub struct ComputePass {
    /// All of the pass's data and recorded commands are stored here.
    base: ComputeBasePass,

    /// Parent command encoder that this pass records commands into.
    ///
    /// If this is `Some`, then the pass is in WebGPU's "open" state. If it is
    /// `None`, then the pass is in the "ended" state.
    /// See <https://www.w3.org/TR/webgpu/#encoder-state>
    parent: Option<Arc<CommandEncoder>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    /// If the parent command encoder is invalid, the returned pass will be invalid.
    fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: BasePass::new(&label),
            parent: Some(parent),
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
        Self {
            base: BasePass::new_invalid(label, err),
            parent: Some(parent),
            timestamp_writes: None,
            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.label.as_deref()
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}

/// cbindgen:ignore
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;

#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    ImmediateData(#[from] ImmediateUploadError),
    #[error("Immediate data offset must be aligned to 4 bytes")]
    ImmediateOffsetAlignment,
    #[error("Immediate data size must be aligned to 4 bytes")]
    ImmediateDataSizeAlignment,
    #[error("Ran out of immediate data space. Don't set 4gb of immediates per ComputePass.")]
    ImmediateOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    // This variant is unreachable, but is required for generic pass support.
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}

/// Error encountered when performing a compute pass, stored for later reporting
/// when encoding ends.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl From<pass::MissingPipeline> for ComputePassErrorInner {
    fn from(value: pass::MissingPipeline) -> Self {
        Self::Dispatch(DispatchError::MissingPipeline(value))
    }
}

impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}

impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::ImmediateData(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::ImmediateOffsetAlignment
            | ComputePassErrorInner::ImmediateDataSizeAlignment
            | ComputePassErrorInner::ImmediateOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}

struct State<'scope, 'snatch_guard, 'cmd_enc> {
    pipeline: Option<Arc<ComputePipeline>>,

    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    immediates: Vec<u32>,

    intermediate_trackers: Tracker,
}

impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Flush binding state in preparation for a dispatch.
    ///
    /// # Differences between render and compute passes
    ///
    /// There are differences between the `flush_bindings` implementations for
    /// render and compute passes, because render passes have a single usage
    /// scope for the entire pass, and compute passes have a separate usage
    /// scope for each dispatch.
    ///
    /// For compute passes, bind groups are merged into a fresh usage scope
    /// here, not into the pass usage scope within calls to `set_bind_group`. As
    /// specified by WebGPU, for compute passes, we merge only the bind groups
    /// that are actually used by the pipeline, unlike render passes, which
    /// merge every bind group that is ever set, even if it is not ultimately
    /// used by the pipeline.
    ///
    /// For compute passes, we call `drain_barriers` here, because barriers may
    /// be needed before each dispatch if a previous dispatch had a conflicting
    /// usage. For render passes, barriers are emitted once at the start of the
    /// render pass.
    ///
    /// # Indirect buffer handling
    ///
    /// The `indirect_buffer` argument should be passed for any indirect
    /// dispatch (with or without validation). It will be checked for
    /// conflicting usages according to WebGPU rules. For the purpose of
    /// these rules, the fact that we have actually processed the buffer in
    /// the validation pass is an implementation detail.
    ///
    /// The `track_indirect_buffer` argument should be set when doing indirect
    /// dispatch *without* validation. In this case, the indirect buffer will
    /// be added to the tracker in order to generate any necessary transitions
    /// for that usage.
    ///
    /// When doing indirect dispatch *with* validation, the indirect buffer is
    /// processed by the validation pass and is not used by the actual dispatch.
    /// The indirect validation code handles transitions for the validation
    /// pass.
    fn flush_bindings(
        &mut self,
        indirect_buffer: Option<&Arc<Buffer>>,
        track_indirect_buffer: bool,
    ) -> Result<(), ComputePassErrorInner> {
        for bind_group in self.pass.binder.list_active() {
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
        }

        // Add the indirect buffer. Because usage scopes are per-dispatch, this
        // is the only place where INDIRECT usage could be added, and it is safe
        // for us to remove it below.
        if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .merge_single(buffer, wgt::BufferUses::INDIRECT)?;
        }

        // For compute, usage scopes are associated with each dispatch and not
        // with the pass as a whole. However, because the cost of creating and
        // dropping `UsageScope`s is significant (even with the pool), we
        // add and then remove usage from a single usage scope.

        for bind_group in self.pass.binder.list_active() {
            self.intermediate_trackers
                .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used);
        }

        if track_indirect_buffer {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer.map(|buf| buf.tracker_index()),
                );
        } else if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .remove_usage(buffer, wgt::BufferUses::INDIRECT);
        }

        flush_bindings_helper(&mut self.pass)?;

        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}

// Running the compute pass.

impl Global {
    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned. Attempting to record
    /// commands into an invalid pass is permitted, but a validation error will
    /// ultimately be generated when the parent encoder is finished, and it is
    /// not possible to run any commands from the invalid pass.
    ///
    /// If successful, puts the encoder into the [`Locked`] state.
    ///
    /// [`Locked`]: crate::command::CommandEncoderStatus::Locked
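    ///
    /// # Example
    ///
    /// A minimal sketch of the pass lifecycle, assuming a `Global` and a
    /// valid `encoder_id` (hypothetical setup elided):
    ///
    /// ```ignore
    /// let (mut pass, err) = global.command_encoder_begin_compute_pass(
    ///     encoder_id,
    ///     &ComputePassDescriptor {
    ///         label: Some(Cow::Borrowed("example pass")),
    ///         timestamp_writes: None,
    ///     },
    /// );
    /// assert!(err.is_none());
    /// global.compute_pass_dispatch_workgroups(&mut pass, 1, 1, 1).unwrap();
    /// // Ending the pass unlocks the encoder and enqueues the recorded commands.
    /// global.compute_pass_end(&mut pass).unwrap();
    /// ```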
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Attempting to open a new pass while the encoder is locked
                // invalidates the encoder, but does not generate a validation
                // error.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // Attempting to open a new pass after the encoder has ended
                // generates an immediate validation error.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                // Passes can be opened even on an invalid encoder. Such passes
                // are technically valid, but since there is no visible
                // side-effect of the pass being valid, and no point in storing
                // recorded commands that will ultimately be discarded, we open
                // an invalid pass to save that work.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.unlock_encoder()?;

        let base = pass.base.take();

        if let Err(ComputePassError {
            inner:
                ComputePassErrorInner::EncoderState(
                    err @ (EncoderStateError::Locked | EncoderStateError::Ended),
                ),
            scope: _,
        }) = base
        {
            // Most encoding errors are detected and raised within `finish()`.
            //
            // However, we raise a validation error here if the pass was opened
            // within another pass, or on a finished encoder. The latter is
            // particularly important, because in that case reporting errors via
            // `CommandEncoder::finish` is not possible.
            return Err(err.clone());
        }

        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
            Ok(ArcCommand::RunComputePass {
                pass: base?,
                timestamp_writes: pass.timestamp_writes.take(),
            })
        })
    }
}

pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

    // We automatically keep extending command buffers over time, and because
    // we want to insert a command buffer _before_ what we're about to record,
    // we need to make sure to close the previous one.
    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        immediates: Vec::new(),

        intermediate_trackers: Tracker::new(),
    };

    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            // Unlike in render passes we can't delay resetting the query sets since
            // there is no auxiliary pass.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
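            // For example, a beginning index of 5 and an end index of 2
            // produce the reset range 2..6; a single write at index 7
            // produces 7..8.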
            // The range should always be `Some`; both values being `None`
            // should lead to a validation error elsewhere, so there is no
            // point in erroring over that nuance here.
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetImmediate {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetImmediate;
                pass::set_immediates::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.immediates_data,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    |data_slice| {
                        let offset_in_elements = (offset / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        let size_in_elements =
                            (size_bytes / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        state.immediates[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, device, buffer, offset).map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    None, // compute passes do not attempt to coalesce query resets
                    query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    device,
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass: pass::PassState {
            pending_discard_init_fixups,
            ..
        },
        intermediate_trackers,
        ..
    } = state;

    // Stop the current command encoder.
    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;

    // Create a new command encoder, which we will insert _before_ the body of the compute pass.
    //
    // Use that buffer to insert barriers and clear discarded images.
    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    // Close the command encoder, and swap it with the previous one.
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;

    Ok(())
}

fn set_pipeline(
    state: &mut State,
    device: &Arc<Device>,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device(device)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    // Rebind resources
    pass::change_pipeline_layout::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            // This only needs to be here for compute pipelines because they use
            // immediates for validating indirect dispatches.
            state.immediates.clear();
            // Note that there can only be one range for each stage. See the
            // `MoreThanOneImmediateRangePerStage` error.
            if pipeline.layout.immediate_size != 0 {
                // Note that a non-zero range start doesn't work anyway:
                // https://github.com/gfx-rs/wgpu/issues/4502
                let len = pipeline.layout.immediate_size as usize
                    / wgt::IMMEDIATE_DATA_ALIGNMENT as usize;
                state.immediates.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}

fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch {groups:?}");

    state.is_ready()?;

    state.flush_bindings(None, false)?;

    let groups_size_limit = state
        .pass
        .base
        .device
        .limits
        .max_compute_workgroups_per_dimension;

    if groups[0] > groups_size_limit
        || groups[1] > groups_size_limit
        || groups[2] > groups_size_limit
    {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }

    unsafe {
        state.pass.base.raw_encoder.dispatch(groups);
    }
    Ok(())
}

fn dispatch_indirect(
    state: &mut State,
    device: &Arc<Device>,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch_indirect");

    buffer.same_device(device)?;

    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    let stride = 3 * 4; // 3 integers, x/y/z group size
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

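    // When indirect validation is enabled, an internal compute pipeline first
    // validates the dispatch arguments and copies them into an internal
    // `dst_buffer`; the actual dispatch then reads its arguments from that
    // buffer (at offset 0) rather than from the user's buffer.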
    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.pass.base.raw_encoder.set_immediates(
                params.pipeline_layout,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.pass.base.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

        // Reset the pipeline, immediates, and bind groups that the validation
        // dispatch clobbered.
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.immediates.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_immediates(
                        pipeline.layout.raw(),
                        0,
                        &state.immediates,
                    );
                }
            }

            for (i, e) in state.pass.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_bindings(Some(&buffer), false)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        state.flush_bindings(Some(&buffer), true)?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}

// Recording a compute pass.
//
// The only error that should be returned from these methods is
// `EncoderStateError::Ended`, when the pass has already ended and an immediate
// validation error is raised.
//
// All other errors should be stored in the pass for later reporting when
// `CommandEncoder.finish()` is called.
//
// The `pass_try!` macro should be used to handle errors appropriately. Note
// that the `pass_try!` and `pass_base!` macros may return early from the
// function that invokes them, like the `?` operator.
impl Global {
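    /// A minimal sketch of recording a bind group, assuming `global`, an open
    /// `pass`, and a `bind_group_id` whose layout declares one dynamic offset
    /// (hypothetical names):
    ///
    /// ```ignore
    /// // Redundant re-binds are deduplicated via `current_bind_groups`.
    /// global
    ///     .compute_pass_set_bind_group(&mut pass, 0, Some(bind_group_id), &[256])
    ///     .unwrap();
    /// ```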
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        // This statement will return an error if the pass is ended. It's
        // important the error check comes before the early-out for
        // `set_and_check_redundant`.
        let base = pass_base!(pass, scope);

        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            return Ok(());
        }

        let mut bind_group = None;
        if let Some(bind_group_id) = bind_group_id {
            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        // This statement will return an error if the pass is ended. It's
        // important the error check comes before the early-out for `redundant`.
        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

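    /// A minimal sketch, assuming `global` and an open `pass` (hypothetical
    /// names). Both `offset` and the data length must be multiples of
    /// `wgt::IMMEDIATE_DATA_ALIGNMENT` (4 bytes):
    ///
    /// ```ignore
    /// // Upload two u32 values starting at byte offset 8.
    /// let data = [1u32.to_ne_bytes(), 2u32.to_ne_bytes()].concat();
    /// global.compute_pass_set_immediates(&mut pass, 8, &data).unwrap();
    /// ```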
    pub fn compute_pass_set_immediates(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetImmediate;
        let base = pass_base!(pass, scope);

        if offset & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateOffsetAlignment),
            );
        }

        if data.len() as u32 & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateDataSizeAlignment),
            )
        }
        let value_offset = pass_try!(
            base,
            scope,
            base.immediates_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::ImmediateOutOfMemory)
        );

        base.immediates_data.extend(
            data.chunks_exact(wgt::IMMEDIATE_DATA_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetImmediate {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

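    /// A minimal sketch, assuming `global`, an open `pass`, and a `buffer_id`
    /// for a buffer created with `INDIRECT` usage (hypothetical names). The
    /// buffer must contain the `[x, y, z]` workgroup counts as three `u32`s
    /// at a 4-byte-aligned `offset`:
    ///
    /// ```ignore
    /// global
    ///     .compute_pass_dispatch_workgroups_indirect(&mut pass, buffer_id, 0)
    ///     .unwrap();
    /// ```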
    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}