wgpu_core/command/
compute.rs

1use parking_lot::Mutex;
2use thiserror::Error;
3use wgt::{
4    error::{ErrorType, WebGpuError},
5    BufferAddress, DynamicOffset,
6};
7
8use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
9use core::{convert::Infallible, fmt, str};
10
11use crate::{
12    api_log,
13    binding_model::{BindError, ImmediateUploadError, LateMinBufferBindingSizeMismatch},
14    command::{
15        bind::{Binder, BinderError},
16        compute_command::ArcComputeCommand,
17        encoder::EncodingState,
18        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
19        pass::{self, flush_bindings_helper},
20        pass_base, pass_try,
21        query::{
22            end_pipeline_statistics_query, record_pass_timestamp_writes,
23            validate_and_begin_pipeline_statistics_query,
24        },
25        ArcCommand, ArcPassTimestampWrites, BasePass, BindGroupStateChange, CommandEncoder,
26        CommandEncoderError, DebugGroupError, EncoderStateError, InnerCommandEncoder, MapPassErr,
27        PassErrorScope, PassStateError, PassTimestampWrites, QueryUseError, StateChange,
28        TimestampWritesError, TransitionResourcesError,
29    },
30    device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
31    global::Global,
32    hal_label, id, impl_resource_type,
33    init_tracker::MemoryInitKind,
34    pipeline::ComputePipeline,
35    resource::{
36        self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
37        MissingBufferUsageError, ParentDevice, RawResourceAccess, TextureView, Trackable,
38    },
39    track::{ResourceUsageCompatibilityError, TextureViewBindGroupState, Tracker},
40    Label,
41};
42
/// The recorded contents of a compute pass (label, commands and their side
/// data), together with an optional [`ComputePassError`] whose presence marks
/// the pass as invalid.
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;
44
/// A compute pass being recorded on a command encoder.
///
/// A pass's [encoder state](https://www.w3.org/TR/webgpu/#encoder-state) and
/// its validity are two distinct conditions, i.e., the full matrix of
/// (open, ended) x (valid, invalid) is possible.
///
/// The presence or absence of the `parent` `Option` indicates the pass's state.
/// The presence or absence of an error in `base.error` indicates the pass's
/// validity.
pub struct ComputePass {
    /// All pass data & records are stored here.
    base: ComputeBasePass,

    /// Parent command encoder that this pass records commands into.
    ///
    /// If this is `Some`, then the pass is in WebGPU's "open" state. If it is
    /// `None`, then the pass is in the "ended" state.
    /// See <https://www.w3.org/TR/webgpu/#encoder-state>
    parent: Option<Arc<CommandEncoder>>,

    /// Timestamp writes requested by the pass descriptor, with the query set
    /// already resolved. Always `None` for passes created invalid; taken when
    /// the pass is ended and handed to the parent encoder.
    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state.
    // NOTE(review): presumably lets repeated identical `set_bind_group` /
    // `set_pipeline` calls be skipped while recording — the recording code is
    // not part of this section; confirm.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}
69
// Registers `ComputePass` as a resource type; the macro is defined elsewhere in
// the crate (NOTE(review): see its definition for the exact trait it implements).
impl_resource_type!(ComputePass);

/// Associates `ComputePass` with the `ComputePassEncoder` id marker so passes
/// can be stored and looked up by `ComputePassEncoderId`.
impl crate::storage::StorageItem for ComputePass {
    type Marker = id::markers::ComputePassEncoder;
}
75
76impl ComputePass {
77    /// If the parent command encoder is invalid, the returned pass will be invalid.
78    fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
79        let ArcComputePassDescriptor {
80            label,
81            timestamp_writes,
82        } = desc;
83
84        Self {
85            base: BasePass::new(&label),
86            parent: Some(parent),
87            timestamp_writes,
88
89            current_bind_groups: BindGroupStateChange::new(),
90            current_pipeline: StateChange::new(),
91        }
92    }
93
94    fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
95        Self {
96            base: BasePass::new_invalid(label, err),
97            parent: Some(parent),
98            timestamp_writes: None,
99            current_bind_groups: BindGroupStateChange::new(),
100            current_pipeline: StateChange::new(),
101        }
102    }
103
104    #[inline]
105    pub fn label(&self) -> Option<&str> {
106        self.base.label.as_deref()
107    }
108}
109
110impl fmt::Debug for ComputePass {
111    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112        match self.parent {
113            Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
114            None => write!(f, "ComputePass {{ parent: None }}"),
115        }
116    }
117}
118
/// Describes a compute pass to begin on a command encoder.
///
/// `PTW` is the representation of the timestamp writes. By default it is the
/// id-based [`PassTimestampWrites`]; internally the query set is resolved to an
/// `Arc` before the pass is created.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    /// Optional label for the pass.
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}

// Internal descriptor form whose timestamp writes hold a resolved query set
// reference instead of an id.
/// cbindgen:ignore
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;
129
/// Errors that prevent a dispatch from being issued.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    /// No compute pipeline was set before the dispatch.
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    /// The currently bound bind groups are not compatible with the pipeline.
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    /// A dimension of the dispatch group size (`current`) exceeds `limit`.
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    /// A bound buffer is smaller than the pipeline requires.
    /// NOTE(review): "late" presumably means the check is deferred to
    /// dispatch time because the required size comes from the pipeline.
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
    /// The pipeline requires immediate-data slots that were never written via
    /// `set_immediates`; `missing` holds exactly those slots.
    #[error("Not all immediate data required by the pipeline has been set via set_immediates (missing byte ranges: {missing})")]
    MissingImmediateData {
        missing: naga::valid::ImmediateSlots,
    },
}

/// Every dispatch error is reported as a WebGPU validation error.
impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
154
/// Error encountered when performing a compute pass.
///
/// This is the `inner` part of a [`ComputePassError`]. Most variants wrap a
/// more specific error via `#[from]`, so `?` converts those errors directly.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    /// The pass's parent command encoder is invalid.
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    /// The offset into an indirect buffer is not 4-byte aligned.
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    /// The indirect arguments (`args_size` bytes at `offset`) do not fit in
    /// the indirect buffer of `buffer_size` bytes.
    #[error("Indirect buffer of {args_size} bytes starting at offset {offset} would overrun buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        args_size: u64,
        offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    ImmediateData(#[from] ImmediateUploadError),
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    TransitionResources(#[from] TransitionResourcesError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    /// A command was recorded after the pass had already been ended.
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    // This one is unreachable, but required for generic pass support
    // (NOTE(review): presumably because the generic `pass` helpers such as
    // `pass::set_immediates` need their error type to convert from it).
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}
206
/// Error encountered when performing a compute pass, stored for later reporting
/// when encoding ends.
///
/// Its `Display` output is only the scope; the underlying cause is exposed as
/// the error's `source`.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    /// Which pass operation produced the error.
    pub scope: PassErrorScope,
    /// The underlying cause of the error.
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
216
217impl From<pass::MissingPipeline> for ComputePassErrorInner {
218    fn from(value: pass::MissingPipeline) -> Self {
219        Self::Dispatch(DispatchError::MissingPipeline(value))
220    }
221}
222
223impl<E> MapPassErr<ComputePassError> for E
224where
225    E: Into<ComputePassErrorInner>,
226{
227    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
228        ComputePassError {
229            scope,
230            inner: self.into(),
231        }
232    }
233}
234
235impl WebGpuError for ComputePassError {
236    fn webgpu_error_type(&self) -> ErrorType {
237        let Self { scope: _, inner } = self;
238        match inner {
239            ComputePassErrorInner::Device(e) => e.webgpu_error_type(),
240            ComputePassErrorInner::EncoderState(e) => e.webgpu_error_type(),
241            ComputePassErrorInner::DebugGroupError(e) => e.webgpu_error_type(),
242            ComputePassErrorInner::DestroyedResource(e) => e.webgpu_error_type(),
243            ComputePassErrorInner::ResourceUsageCompatibility(e) => e.webgpu_error_type(),
244            ComputePassErrorInner::MissingBufferUsage(e) => e.webgpu_error_type(),
245            ComputePassErrorInner::Dispatch(e) => e.webgpu_error_type(),
246            ComputePassErrorInner::Bind(e) => e.webgpu_error_type(),
247            ComputePassErrorInner::ImmediateData(e) => e.webgpu_error_type(),
248            ComputePassErrorInner::QueryUse(e) => e.webgpu_error_type(),
249            ComputePassErrorInner::TransitionResources(e) => e.webgpu_error_type(),
250            ComputePassErrorInner::MissingFeatures(e) => e.webgpu_error_type(),
251            ComputePassErrorInner::MissingDownlevelFlags(e) => e.webgpu_error_type(),
252            ComputePassErrorInner::InvalidResource(e) => e.webgpu_error_type(),
253            ComputePassErrorInner::TimestampWrites(e) => e.webgpu_error_type(),
254            ComputePassErrorInner::InvalidValuesOffset(e) => e.webgpu_error_type(),
255
256            ComputePassErrorInner::InvalidParentEncoder
257            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
258            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
259            | ComputePassErrorInner::IndirectBufferOverrun { .. }
260            | ComputePassErrorInner::PassEnded => ErrorType::Validation,
261        }
262    }
263}
264
/// Mutable state used while encoding a recorded compute pass.
struct State<'scope, 'snatch_guard, 'cmd_enc> {
    /// The compute pipeline currently set on the pass; `None` until one is
    /// set, in which case dispatching fails with a missing-pipeline error.
    pipeline: Option<Arc<ComputePipeline>>,

    /// State shared with the generic pass-encoding helpers in `pass`
    /// (binder, usage scope, parent encoding state, ...).
    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    /// NOTE(review): presumably the pipeline statistics query currently in
    /// progress as (query set, query index) — confirm against its users.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    /// Local copy of the immediate data, indexed in units of
    /// `wgt::IMMEDIATE_DATA_ALIGNMENT` bytes and written by `set_immediates`.
    immediates: Vec<u32>,

    /// A bitmask, tracking which 4-byte slots have been written via `set_immediates`.
    /// Checked against the pipeline's required slots before each dispatch.
    immediate_slots_set: naga::valid::ImmediateSlots,

    /// Tracker that receives usages moved out of the usage scope (per
    /// dispatch and in `transition_resources`); barriers are recorded from it
    /// via `CommandEncoder::drain_barriers`.
    intermediate_trackers: Tracker,
}
280
impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    /// Checks that a dispatch may be issued with the current state.
    ///
    /// A dispatch requires that a pipeline is set, that the bound bind groups
    /// are compatible with it, that the late buffer binding size checks pass,
    /// and that every immediate-data slot the pipeline requires has been
    /// written via `set_immediates`.
    ///
    /// # Errors
    ///
    /// - [`DispatchError::MissingPipeline`] if no pipeline is set.
    /// - Errors from the binder's compatibility / late binding checks,
    ///   converted into [`DispatchError`].
    /// - [`DispatchError::MissingImmediateData`] listing the unwritten slots.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            // Every required slot must have been written; report exactly the
            // slots that are still missing.
            if !self
                .immediate_slots_set
                .contains(pipeline.immediate_slots_required)
            {
                return Err(DispatchError::MissingImmediateData {
                    missing: pipeline
                        .immediate_slots_required
                        .difference(self.immediate_slots_set),
                });
            }
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Flush binding state in preparation for a dispatch.
    ///
    /// # Differences between render and compute passes
    ///
    /// There are differences between the `flush_bindings` implementations for
    /// render and compute passes, because render passes have a single usage
    /// scope for the entire pass, and compute passes have a separate usage
    /// scope for each dispatch.
    ///
    /// For compute passes, bind groups are merged into a fresh usage scope
    /// here, not into the pass usage scope within calls to `set_bind_group`. As
    /// specified by WebGPU, for compute passes, we merge only the bind groups
    /// that are actually used by the pipeline, unlike render passes, which
    /// merge every bind group that is ever set, even if it is not ultimately
    /// used by the pipeline.
    ///
    /// For compute passes, we call `drain_barriers` here, because barriers may
    /// be needed before each dispatch if a previous dispatch had a conflicting
    /// usage. For render passes, barriers are emitted once at the start of the
    /// render pass.
    ///
    /// # Indirect buffer handling
    ///
    /// The `indirect_buffer` argument should be passed for any indirect
    /// dispatch (with or without validation). It will be checked for
    /// conflicting usages according to WebGPU rules. For the purpose of
    /// these rules, the fact that we have actually processed the buffer in
    /// the validation pass is an implementation detail.
    ///
    /// The `track_indirect_buffer` argument should be set when doing indirect
    /// dispatch *without* validation. In this case, the indirect buffer will
    /// be added to the tracker in order to generate any necessary transitions
    /// for that usage.
    ///
    /// When doing indirect dispatch *with* validation, the indirect buffer is
    /// processed by the validation pass and is not used by the actual dispatch.
    /// The indirect validation code handles transitions for the validation
    /// pass.
    ///
    /// # Errors
    ///
    /// Returns an error if merging usages into the usage scope finds a
    /// conflict, or if `flush_bindings_helper` fails.
    fn flush_bindings(
        &mut self,
        indirect_buffer: Option<&Arc<Buffer>>,
        track_indirect_buffer: bool,
    ) -> Result<(), ComputePassErrorInner> {
        // Merge the usages of every active bind group into this dispatch's
        // usage scope; conflicting usages are reported here.
        for bind_group in self.pass.binder.list_active() {
            // NOTE(review): the safety preconditions of `merge_bind_group` are
            // documented where it is defined; they are assumed, not re-checked, here.
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
        }

        // Add the indirect buffer. Because usage scopes are per-dispatch, this
        // is the only place where INDIRECT usage could be added, and it is safe
        // for us to remove it below.
        if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .merge_single(buffer, wgt::BufferUses::INDIRECT)?;
        }

        // For compute, usage scopes are associated with each dispatch and not
        // with the pass as a whole. However, because the cost of creating and
        // dropping `UsageScope`s is significant (even with the pool), we
        // add and then remove usage from a single usage scope.

        for bind_group in self.pass.binder.list_active() {
            self.intermediate_trackers
                .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used);
        }

        if track_indirect_buffer {
            // Unvalidated indirect dispatch: move the INDIRECT usage into the
            // tracker so a transition is generated for it.
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer.map(|buf| buf.tracker_index()),
                );
        } else if let Some(buffer) = indirect_buffer {
            // Validated indirect dispatch: the INDIRECT usage was merged only
            // for conflict checking, so drop it from the scope untracked.
            self.pass
                .scope
                .buffers
                .remove_usage(buffer, wgt::BufferUses::INDIRECT);
        }

        // Shared binding flush logic from the `pass` module.
        flush_bindings_helper(&mut self.pass)?;

        // Record the barriers needed before this dispatch.
        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}
393
394/// Compute pass version of [`command::transition_resources`](crate::command::transition_resources).
395/// See also `State::flush_bindings` for details on the implementation.
396fn transition_resources(
397    state: &mut State,
398    buffer_transitions: Vec<wgt::BufferTransition<Arc<Buffer>>>,
399    texture_transitions: Vec<wgt::TextureTransition<Arc<TextureView>>>,
400) -> Result<(), TransitionResourcesError> {
401    let indices = &state.pass.base.device.tracker_indices;
402    state.pass.scope.buffers.set_size(indices.buffers.size());
403    state.pass.scope.textures.set_size(indices.textures.size());
404
405    let mut buffer_ids = Vec::with_capacity(buffer_transitions.len());
406    let mut textures = TextureViewBindGroupState::new();
407
408    // Process buffer transitions
409    for buffer_transition in buffer_transitions {
410        buffer_transition
411            .buffer
412            .same_device(state.pass.base.device)?;
413
414        state
415            .pass
416            .scope
417            .buffers
418            .merge_single(&buffer_transition.buffer, buffer_transition.state)?;
419        buffer_ids.push(buffer_transition.buffer.tracker_index());
420    }
421
422    state
423        .intermediate_trackers
424        .buffers
425        .set_and_remove_from_usage_scope_sparse(&mut state.pass.scope.buffers, buffer_ids);
426
427    // Process texture transitions
428    for texture_transition in texture_transitions {
429        texture_transition
430            .texture
431            .same_device(state.pass.base.device)?;
432
433        unsafe {
434            state.pass.scope.textures.merge_single(
435                &texture_transition.texture.parent,
436                texture_transition.selector,
437                texture_transition.state,
438            )
439        }?;
440
441        textures.insert_single(texture_transition.texture, texture_transition.state);
442    }
443
444    state
445        .intermediate_trackers
446        .textures
447        .set_and_remove_from_usage_scope_sparse(&mut state.pass.scope.textures, &textures);
448
449    // Record any needed barriers based on tracker data
450    CommandEncoder::drain_barriers(
451        state.pass.base.raw_encoder,
452        &mut state.intermediate_trackers,
453        state.pass.base.snatch_guard,
454    );
455    Ok(())
456}
457
458impl CommandEncoder {
459    fn begin_compute_pass(
460        self: &Arc<Self>,
461        desc: &ComputePassDescriptor<'_, PassTimestampWrites<Arc<resource::QuerySet>>>,
462    ) -> (ComputePass, Option<CommandEncoderError>) {
463        use EncoderStateError as SErr;
464
465        let scope = PassErrorScope::Pass;
466
467        let label = desc.label.as_deref().map(Cow::Borrowed);
468
469        let mut cmd_buf_data = self.data.lock();
470
471        match cmd_buf_data.lock_encoder() {
472            Ok(()) => {
473                drop(cmd_buf_data);
474                if let Err(err) = self.device.check_is_valid() {
475                    return (
476                        ComputePass::new_invalid(Arc::clone(self), &label, err.map_pass_err(scope)),
477                        None,
478                    );
479                }
480
481                match desc
482                    .timestamp_writes
483                    .as_ref()
484                    .map(|tw| {
485                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
486                            &self.device,
487                            tw,
488                        )
489                    })
490                    .transpose()
491                {
492                    Ok(timestamp_writes) => {
493                        let arc_desc = ArcComputePassDescriptor {
494                            label,
495                            timestamp_writes,
496                        };
497                        (ComputePass::new(Arc::clone(self), arc_desc), None)
498                    }
499                    Err(err) => (
500                        ComputePass::new_invalid(Arc::clone(self), &label, err.map_pass_err(scope)),
501                        None,
502                    ),
503                }
504            }
505            Err(err @ SErr::Locked) => {
506                // Attempting to open a new pass while the encoder is locked
507                // invalidates the encoder, but does not generate a validation
508                // error.
509                cmd_buf_data.invalidate(err.clone());
510                drop(cmd_buf_data);
511                (
512                    ComputePass::new_invalid(Arc::clone(self), &label, err.map_pass_err(scope)),
513                    None,
514                )
515            }
516            Err(err @ (SErr::Ended | SErr::Submitted)) => {
517                // Attempting to open a new pass after the encode has ended
518                // generates an immediate validation error.
519                drop(cmd_buf_data);
520                (
521                    ComputePass::new_invalid(
522                        Arc::clone(self),
523                        &label,
524                        err.clone().map_pass_err(scope),
525                    ),
526                    Some(err.into()),
527                )
528            }
529            Err(err @ SErr::Invalid) => {
530                // Passes can be opened even on an invalid encoder. Such passes
531                // are even valid, but since there's no visible side-effect of
532                // the pass being valid and there's no point in storing recorded
533                // commands that will ultimately be discarded, we open an
534                // invalid pass to save that work.
535                drop(cmd_buf_data);
536                (
537                    ComputePass::new_invalid(Arc::clone(self), &label, err.map_pass_err(scope)),
538                    None,
539                )
540            }
541            Err(SErr::Unlocked) => {
542                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
543            }
544        }
545    }
546}
547
548// Running the compute pass.
549
550impl Global {
551    /// Creates a compute pass.
552    ///
553    /// If creation fails, an invalid pass is returned. Attempting to record
554    /// commands into an invalid pass is permitted, but a validation error will
555    /// ultimately be generated when the parent encoder is finished, and it is
556    /// not possible to run any commands from the invalid pass.
557    ///
558    /// If successful, puts the encoder into the [`Locked`] state.
559    ///
560    /// [`Locked`]: crate::command::CommandEncoderStatus::Locked
561    pub fn command_encoder_begin_compute_pass(
562        &self,
563        encoder_id: id::CommandEncoderId,
564        desc: &ComputePassDescriptor<'_>,
565    ) -> (ComputePass, Option<CommandEncoderError>) {
566        let hub = &self.hub;
567
568        let cmd_enc = hub.command_encoders.get(encoder_id);
569
570        let desc = ComputePassDescriptor {
571            label: desc.label.as_deref().map(Cow::Borrowed),
572            timestamp_writes: desc
573                .timestamp_writes
574                .as_ref()
575                .map(|tw| PassTimestampWrites {
576                    query_set: hub.query_sets.get(tw.query_set),
577                    beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
578                    end_of_pass_write_index: tw.end_of_pass_write_index,
579                }),
580        };
581
582        cmd_enc.begin_compute_pass(&desc)
583    }
584
585    pub fn command_encoder_begin_compute_pass_with_id(
586        &self,
587        encoder_id: id::CommandEncoderId,
588        desc: &ComputePassDescriptor<'_>,
589        id_in: Option<id::ComputePassEncoderId>,
590    ) -> (id::ComputePassEncoderId, Option<CommandEncoderError>) {
591        let fid = self.hub.compute_passes.prepare(id_in);
592
593        let (pass, err) = self.command_encoder_begin_compute_pass(encoder_id, desc);
594
595        // no lock rank here because only one thread should be using compute pass
596        // and it's only used by id variants of compute pass methods on global
597        // so no deadlock (or concurrent lock) should happen in practise
598        let id = fid.assign(Arc::new(Mutex::new(pass)));
599
600        (id, err)
601    }
602
603    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
604        profiling::scope!(
605            "CommandEncoder::run_compute_pass {}",
606            pass.base.label.as_deref().unwrap_or("")
607        );
608
609        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
610        let mut cmd_buf_data = cmd_enc.data.lock();
611
612        cmd_buf_data.unlock_encoder()?;
613
614        let base = pass.base.take();
615
616        if let Err(ComputePassError {
617            inner:
618                ComputePassErrorInner::EncoderState(
619                    err @ (EncoderStateError::Locked | EncoderStateError::Ended),
620                ),
621            scope: _,
622        }) = base
623        {
624            // Most encoding errors are detected and raised within `finish()`.
625            //
626            // However, we raise a validation error here if the pass was opened
627            // within another pass, or on a finished encoder. The latter is
628            // particularly important, because in that case reporting errors via
629            // `CommandEncoder::finish` is not possible.
630            return Err(err.clone());
631        }
632
633        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
634            Ok(ArcCommand::RunComputePass {
635                pass: base?,
636                timestamp_writes: pass.timestamp_writes.take(),
637            })
638        })
639    }
640
641    pub fn compute_pass_end_with_id(
642        &self,
643        pass_id: id::ComputePassEncoderId,
644    ) -> Result<(), EncoderStateError> {
645        let pass = self.hub.compute_passes.get(pass_id);
646        let mut pass = pass
647            .try_lock()
648            .expect("ComputePasses should not be accessed concurrently");
649        self.compute_pass_end(&mut pass)
650    }
651
652    pub fn compute_pass_drop(&self, pass_id: id::ComputePassEncoderId) {
653        self.hub.compute_passes.remove(pass_id);
654    }
655}
656
657pub(super) fn encode_compute_pass(
658    parent_state: &mut EncodingState<InnerCommandEncoder>,
659    mut base: BasePass<ArcComputeCommand, Infallible>,
660    mut timestamp_writes: Option<ArcPassTimestampWrites>,
661) -> Result<(), ComputePassError> {
662    let pass_scope = PassErrorScope::Pass;
663
664    let device = parent_state.device;
665
666    // We automatically keep extending command buffers over time, and because
667    // we want to insert a command buffer _before_ what we're about to record,
668    // we need to make sure to close the previous one.
669    parent_state
670        .raw_encoder
671        .close_if_open()
672        .map_pass_err(pass_scope)?;
673    let raw_encoder = parent_state
674        .raw_encoder
675        .open_pass(base.label.as_deref())
676        .map_pass_err(pass_scope)?;
677
678    let mut debug_scope_depth = 0;
679
680    let mut state = State {
681        pipeline: None,
682
683        pass: pass::PassState {
684            base: EncodingState {
685                device,
686                raw_encoder,
687                tracker: parent_state.tracker,
688                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
689                texture_memory_actions: parent_state.texture_memory_actions,
690                as_actions: parent_state.as_actions,
691                temp_resources: parent_state.temp_resources,
692                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
693                snatch_guard: parent_state.snatch_guard,
694                debug_scope_depth: &mut debug_scope_depth,
695                query_set_writes: parent_state.query_set_writes,
696                deferred_query_set_resolves: parent_state.deferred_query_set_resolves,
697            },
698            binder: Binder::new(),
699            temp_offsets: Vec::new(),
700            dynamic_offset_count: 0,
701            pending_discard_init_fixups: SurfacesInDiscardState::new(),
702            scope: device.new_usage_scope(),
703            string_offset: 0,
704        },
705        active_query: None,
706
707        immediates: Vec::new(),
708
709        immediate_slots_set: Default::default(),
710
711        intermediate_trackers: Tracker::new(
712            device.ordered_buffer_usages,
713            device.ordered_texture_usages,
714        ),
715    };
716
717    let indices = &device.tracker_indices;
718    state
719        .pass
720        .base
721        .tracker
722        .buffers
723        .set_size(indices.buffers.size());
724    state
725        .pass
726        .base
727        .tracker
728        .textures
729        .set_size(indices.textures.size());
730
731    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
732        if let Some(tw) = timestamp_writes.take() {
733            tw.query_set.same_device(device).map_pass_err(pass_scope)?;
734
735            record_pass_timestamp_writes(&tw, state.pass.base.query_set_writes);
736
737            let query_set = state
738                .pass
739                .base
740                .tracker
741                .query_sets
742                .insert_single(tw.query_set);
743
744            // Unlike in render passes we can't delay resetting the query sets since
745            // there is no auxiliary pass.
746            let range = if let (Some(index_a), Some(index_b)) =
747                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
748            {
749                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
750            } else {
751                tw.beginning_of_pass_write_index
752                    .or(tw.end_of_pass_write_index)
753                    .map(|i| i..i + 1)
754            };
755            let raw_query_set = query_set
756                .try_raw(parent_state.snatch_guard)
757                .map_pass_err(pass_scope)?;
758            // Range should always be Some, both values being None should lead to a validation error.
759            // But no point in erroring over that nuance here!
760            if let Some(range) = range {
761                unsafe {
762                    state
763                        .pass
764                        .base
765                        .raw_encoder
766                        .reset_queries(raw_query_set, range);
767                }
768            }
769
770            Some(hal::PassTimestampWrites {
771                query_set: raw_query_set,
772                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
773                end_of_pass_write_index: tw.end_of_pass_write_index,
774            })
775        } else {
776            None
777        };
778
779    let hal_desc = hal::ComputePassDescriptor {
780        label: hal_label(base.label.as_deref(), device.instance_flags),
781        timestamp_writes,
782    };
783
784    unsafe {
785        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
786    }
787
788    for command in base.commands.drain(..) {
789        match command {
790            ArcComputeCommand::SetBindGroup {
791                index,
792                num_dynamic_offsets,
793                bind_group,
794            } => {
795                let scope = PassErrorScope::SetBindGroup;
796                pass::set_bind_group::<ComputePassErrorInner>(
797                    &mut state.pass,
798                    device,
799                    &base.dynamic_offsets,
800                    index,
801                    num_dynamic_offsets,
802                    bind_group,
803                    false,
804                )
805                .map_pass_err(scope)?;
806            }
807            ArcComputeCommand::SetPipeline(pipeline) => {
808                let scope = PassErrorScope::SetPipelineCompute;
809                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
810            }
811            ArcComputeCommand::SetImmediate {
812                offset,
813                size_bytes,
814                values_offset,
815            } => {
816                let scope = PassErrorScope::SetImmediate;
817                pass::set_immediates::<ComputePassErrorInner, _>(
818                    &mut state.pass,
819                    &base.immediates_data,
820                    offset,
821                    size_bytes,
822                    Some(values_offset),
823                    |data_slice| {
824                        let offset_in_elements = (offset / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
825                        let size_in_elements =
826                            (size_bytes / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
827                        state.immediates[offset_in_elements..][..size_in_elements]
828                            .copy_from_slice(data_slice);
829                    },
830                )
831                .map_pass_err(scope)?;
832                state.immediate_slots_set |=
833                    naga::valid::ImmediateSlots::from_range(offset, size_bytes);
834            }
835            ArcComputeCommand::DispatchWorkgroups(groups) => {
836                let scope = PassErrorScope::Dispatch { indirect: false };
837                dispatch_workgroups(&mut state, groups).map_pass_err(scope)?;
838            }
839            ArcComputeCommand::DispatchWorkgroupsIndirect { buffer, offset } => {
840                let scope = PassErrorScope::Dispatch { indirect: true };
841                dispatch_workgroups_indirect(&mut state, device, buffer, offset)
842                    .map_pass_err(scope)?;
843            }
844            ArcComputeCommand::PushDebugGroup { color: _, len } => {
845                pass::push_debug_group(&mut state.pass, &base.string_data, len);
846            }
847            ArcComputeCommand::PopDebugGroup => {
848                let scope = PassErrorScope::PopDebugGroup;
849                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
850                    .map_pass_err(scope)?;
851            }
852            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
853                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
854            }
855            ArcComputeCommand::WriteTimestamp {
856                query_set,
857                query_index,
858            } => {
859                let scope = PassErrorScope::WriteTimestamp;
860                pass::write_timestamp::<ComputePassErrorInner>(
861                    &mut state.pass,
862                    device,
863                    None, // compute passes do not attempt to coalesce query resets
864                    query_set,
865                    query_index,
866                )
867                .map_pass_err(scope)?;
868            }
869            ArcComputeCommand::BeginPipelineStatisticsQuery {
870                query_set,
871                query_index,
872            } => {
873                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
874                validate_and_begin_pipeline_statistics_query(
875                    query_set,
876                    state.pass.base.raw_encoder,
877                    &mut state.pass.base.tracker.query_sets,
878                    device,
879                    query_index,
880                    None,
881                    &mut state.active_query,
882                    state.pass.base.snatch_guard,
883                )
884                .map_pass_err(scope)?;
885            }
886            ArcComputeCommand::EndPipelineStatisticsQuery => {
887                let scope = PassErrorScope::EndPipelineStatisticsQuery;
888                end_pipeline_statistics_query(
889                    state.pass.base.raw_encoder,
890                    &mut state.active_query,
891                    state.pass.base.snatch_guard,
892                    state.pass.base.query_set_writes,
893                )
894                .map_pass_err(scope)?;
895            }
896            ArcComputeCommand::TransitionResources {
897                buffer_transitions,
898                texture_transitions,
899            } => {
900                let scope = PassErrorScope::TransitionResources;
901                transition_resources(&mut state, buffer_transitions, texture_transitions)
902                    .map_pass_err(scope)?;
903            }
904        }
905    }
906
907    if *state.pass.base.debug_scope_depth > 0 {
908        Err(
909            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
910                .map_pass_err(pass_scope),
911        )?;
912    }
913
914    unsafe {
915        state.pass.base.raw_encoder.end_compute_pass();
916    }
917
918    let State {
919        pass: pass::PassState {
920            pending_discard_init_fixups,
921            ..
922        },
923        intermediate_trackers,
924        ..
925    } = state;
926
927    // Stop the current command encoder.
928    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;
929
930    // Create a new command encoder, which we will insert _before_ the body of the compute pass.
931    //
932    // Use that buffer to insert barriers and clear discarded images.
933    let transit = parent_state
934        .raw_encoder
935        .open_pass(hal_label(
936            Some("(wgpu internal) Pre Pass"),
937            device.instance_flags,
938        ))
939        .map_pass_err(pass_scope)?;
940    fixup_discarded_surfaces(
941        pending_discard_init_fixups.into_iter(),
942        transit,
943        &mut parent_state.tracker.textures,
944        device,
945        parent_state.snatch_guard,
946    );
947    CommandEncoder::insert_barriers_from_tracker(
948        transit,
949        parent_state.tracker,
950        &intermediate_trackers,
951        parent_state.snatch_guard,
952    );
953    // Close the command encoder, and swap it with the previous.
954    parent_state
955        .raw_encoder
956        .close_and_swap()
957        .map_pass_err(pass_scope)?;
958
959    Ok(())
960}
961
962fn set_pipeline(
963    state: &mut State,
964    device: &Arc<Device>,
965    pipeline: Arc<ComputePipeline>,
966) -> Result<(), ComputePassErrorInner> {
967    pipeline.same_device(device)?;
968
969    state.pipeline = Some(pipeline.clone());
970
971    let pipeline = state
972        .pass
973        .base
974        .tracker
975        .compute_pipelines
976        .insert_single(pipeline)
977        .clone();
978
979    unsafe {
980        state
981            .pass
982            .base
983            .raw_encoder
984            .set_compute_pipeline(pipeline.raw()?);
985    }
986
987    // Rebind resources
988    let pipeline_layout = pipeline.layout()?;
989    pass::change_pipeline_layout::<ComputePassErrorInner, _>(
990        &mut state.pass,
991        pipeline_layout,
992        &pipeline.late_sized_buffer_groups,
993        || {
994            // This only needs to be here for compute pipelines because they use immediates for
995            // validating indirect draws.
996            state.immediates.clear();
997            // Note that can only be one range for each stage. See the `MoreThanOneImmediateRangePerStage` error.
998            if pipeline_layout.immediate_size != 0 {
999                // Note that non-0 range start doesn't work anyway https://github.com/gfx-rs/wgpu/issues/4502
1000                let len = pipeline_layout.immediate_size as usize
1001                    / wgt::IMMEDIATE_DATA_ALIGNMENT as usize;
1002                state.immediates.extend(core::iter::repeat_n(0, len));
1003            }
1004        },
1005    )
1006}
1007
1008fn dispatch_workgroups(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
1009    api_log!("ComputePass::dispatch {groups:?}");
1010
1011    state.is_ready()?;
1012
1013    state.flush_bindings(None, false)?;
1014
1015    let groups_size_limit = state
1016        .pass
1017        .base
1018        .device
1019        .limits
1020        .max_compute_workgroups_per_dimension;
1021
1022    if groups.iter().copied().any(|g| g > groups_size_limit) {
1023        return Err(ComputePassErrorInner::Dispatch(
1024            DispatchError::InvalidGroupSize {
1025                current: groups,
1026                limit: groups_size_limit,
1027            },
1028        ));
1029    }
1030
1031    unsafe {
1032        state.pass.base.raw_encoder.dispatch_workgroups(groups);
1033    }
1034    Ok(())
1035}
1036
1037fn dispatch_workgroups_indirect(
1038    state: &mut State,
1039    device: &Arc<Device>,
1040    buffer: Arc<Buffer>,
1041    offset: u64,
1042) -> Result<(), ComputePassErrorInner> {
1043    api_log!("ComputePass::dispatch_indirect");
1044
1045    buffer.same_device(device)?;
1046
1047    state.is_ready()?;
1048
1049    state
1050        .pass
1051        .base
1052        .device
1053        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;
1054
1055    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
1056
1057    if !offset.is_multiple_of(4) {
1058        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
1059    }
1060
1061    let args_size = size_of::<wgt::DispatchIndirectArgs>() as u64;
1062    if buffer.size < args_size || buffer.size - args_size < offset {
1063        return Err(ComputePassErrorInner::IndirectBufferOverrun {
1064            args_size,
1065            offset,
1066            buffer_size: buffer.size,
1067        });
1068    }
1069
1070    buffer.check_destroyed(state.pass.base.snatch_guard)?;
1071
1072    let stride = 3 * 4; // 3 integers, x/y/z group size
1073    state.pass.base.buffer_memory_init_actions.extend(
1074        buffer.initialization_status.read().create_action(
1075            &buffer,
1076            offset..(offset + stride),
1077            MemoryInitKind::NeedsInitializedMemory,
1078        ),
1079    );
1080
1081    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
1082        let params = indirect_validation.dispatch.params(
1083            &state.pass.base.device.limits,
1084            offset,
1085            buffer.size,
1086        );
1087
1088        unsafe {
1089            state
1090                .pass
1091                .base
1092                .raw_encoder
1093                .set_compute_pipeline(params.pipeline);
1094        }
1095
1096        unsafe {
1097            state.pass.base.raw_encoder.set_immediates(
1098                params.pipeline_layout,
1099                0,
1100                &[params.offset_remainder as u32 / 4],
1101            );
1102        }
1103
1104        unsafe {
1105            state.pass.base.raw_encoder.set_bind_group(
1106                params.pipeline_layout,
1107                0,
1108                params.dst_bind_group,
1109                &[],
1110            );
1111        }
1112        unsafe {
1113            state.pass.base.raw_encoder.set_bind_group(
1114                params.pipeline_layout,
1115                1,
1116                buffer
1117                    .indirect_validation_bind_groups
1118                    .get(state.pass.base.snatch_guard)
1119                    .unwrap()
1120                    .dispatch
1121                    .as_ref(),
1122                &[params.aligned_offset as u32],
1123            );
1124        }
1125
1126        let src_transition = state
1127            .intermediate_trackers
1128            .buffers
1129            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
1130        let src_barrier = src_transition
1131            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
1132        unsafe {
1133            state
1134                .pass
1135                .base
1136                .raw_encoder
1137                .transition_buffers(src_barrier.as_slice());
1138        }
1139
1140        unsafe {
1141            state
1142                .pass
1143                .base
1144                .raw_encoder
1145                .transition_buffers(&[hal::BufferBarrier {
1146                    buffer: params.dst_buffer,
1147                    usage: hal::StateTransition {
1148                        from: wgt::BufferUses::INDIRECT,
1149                        to: wgt::BufferUses::STORAGE_READ_WRITE,
1150                    },
1151                }]);
1152        }
1153
1154        unsafe {
1155            state.pass.base.raw_encoder.dispatch_workgroups([1, 1, 1]);
1156        }
1157
1158        // reset state
1159        {
1160            let pipeline = state.pipeline.as_ref().unwrap();
1161
1162            unsafe {
1163                state
1164                    .pass
1165                    .base
1166                    .raw_encoder
1167                    .set_compute_pipeline(pipeline.raw()?);
1168            }
1169
1170            let pipeline_layout = pipeline.layout()?;
1171
1172            if !state.immediates.is_empty() {
1173                unsafe {
1174                    state.pass.base.raw_encoder.set_immediates(
1175                        pipeline_layout.raw()?,
1176                        0,
1177                        &state.immediates,
1178                    );
1179                }
1180            }
1181
1182            for (i, group, dynamic_offsets) in state.pass.binder.list_valid() {
1183                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
1184                unsafe {
1185                    state.pass.base.raw_encoder.set_bind_group(
1186                        pipeline_layout.raw()?,
1187                        i as u32,
1188                        raw_bg,
1189                        dynamic_offsets,
1190                    );
1191                }
1192            }
1193        }
1194
1195        unsafe {
1196            state
1197                .pass
1198                .base
1199                .raw_encoder
1200                .transition_buffers(&[hal::BufferBarrier {
1201                    buffer: params.dst_buffer,
1202                    usage: hal::StateTransition {
1203                        from: wgt::BufferUses::STORAGE_READ_WRITE,
1204                        to: wgt::BufferUses::INDIRECT,
1205                    },
1206                }]);
1207        }
1208
1209        state.flush_bindings(Some(&buffer), false)?;
1210        unsafe {
1211            state
1212                .pass
1213                .base
1214                .raw_encoder
1215                .dispatch_workgroups_indirect(params.dst_buffer, 0);
1216        }
1217    } else {
1218        state.flush_bindings(Some(&buffer), true)?;
1219
1220        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
1221        unsafe {
1222            state
1223                .pass
1224                .base
1225                .raw_encoder
1226                .dispatch_workgroups_indirect(buf_raw, offset);
1227        }
1228    }
1229
1230    Ok(())
1231}
1232
1233// Recording a compute pass.
1234//
1235// The only error that should be returned from these methods is
1236// `EncoderStateError::Ended`, when the pass has already ended and an immediate
1237// validation error is raised.
1238//
1239// All other errors should be stored in the pass for later reporting when
1240// `CommandEncoder.finish()` is called.
1241//
1242// The `pass_try!` macro should be used to handle errors appropriately. Note
1243// that the `pass_try!` and `pass_base!` macros may return early from the
1244// function that invokes them, like the `?` operator.
1245impl Global {
1246    pub fn compute_pass_set_bind_group(
1247        &self,
1248        pass: &mut ComputePass,
1249        index: u32,
1250        bind_group_id: Option<id::BindGroupId>,
1251        offsets: &[DynamicOffset],
1252    ) -> Result<(), PassStateError> {
1253        let scope = PassErrorScope::SetBindGroup;
1254
1255        // This statement will return an error if the pass is ended. It's
1256        // important the error check comes before the early-out for
1257        // `set_and_check_redundant`.
1258        let base = pass_base!(pass, scope);
1259
1260        if pass.current_bind_groups.set_and_check_redundant(
1261            bind_group_id,
1262            index,
1263            &mut base.dynamic_offsets,
1264            offsets,
1265        ) {
1266            return Ok(());
1267        }
1268
1269        let mut bind_group = None;
1270        if let Some(bind_group_id) = bind_group_id {
1271            let hub = &self.hub;
1272            bind_group = Some(pass_try!(
1273                base,
1274                scope,
1275                hub.bind_groups.get(bind_group_id).get(),
1276            ));
1277        }
1278
1279        base.commands.push(ArcComputeCommand::SetBindGroup {
1280            index,
1281            num_dynamic_offsets: offsets.len(),
1282            bind_group,
1283        });
1284
1285        Ok(())
1286    }
1287
1288    pub fn compute_pass_set_bind_group_with_id(
1289        &self,
1290        pass_id: id::ComputePassEncoderId,
1291        index: u32,
1292        bind_group_id: Option<id::BindGroupId>,
1293        offsets: &[DynamicOffset],
1294    ) -> Result<(), PassStateError> {
1295        let pass = self.hub.compute_passes.get(pass_id);
1296        let mut pass = pass
1297            .try_lock()
1298            .expect("ComputePasses should not be accessed concurrently");
1299        self.compute_pass_set_bind_group(&mut pass, index, bind_group_id, offsets)
1300    }
1301
1302    pub fn compute_pass_set_pipeline(
1303        &self,
1304        pass: &mut ComputePass,
1305        pipeline_id: id::ComputePipelineId,
1306    ) -> Result<(), PassStateError> {
1307        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);
1308
1309        let scope = PassErrorScope::SetPipelineCompute;
1310
1311        // This statement will return an error if the pass is ended.
1312        // Its important the error check comes before the early-out for `redundant`.
1313        let base = pass_base!(pass, scope);
1314
1315        if redundant {
1316            return Ok(());
1317        }
1318
1319        let hub = &self.hub;
1320        let compute_pipeline = hub.compute_pipelines.get(pipeline_id);
1321        pass_try!(base, scope, compute_pipeline.check_valid());
1322
1323        base.commands
1324            .push(ArcComputeCommand::SetPipeline(compute_pipeline));
1325
1326        Ok(())
1327    }
1328
1329    pub fn compute_pass_set_pipeline_with_id(
1330        &self,
1331        pass_id: id::ComputePassEncoderId,
1332        pipeline_id: id::ComputePipelineId,
1333    ) -> Result<(), PassStateError> {
1334        let pass = self.hub.compute_passes.get(pass_id);
1335        let mut pass = pass
1336            .try_lock()
1337            .expect("ComputePasses should not be accessed concurrently");
1338        self.compute_pass_set_pipeline(&mut pass, pipeline_id)
1339    }
1340
1341    pub fn compute_pass_set_immediates(
1342        &self,
1343        pass: &mut ComputePass,
1344        offset: u32,
1345        data: &[u8],
1346    ) -> Result<(), PassStateError> {
1347        let scope = PassErrorScope::SetImmediate;
1348        let base = pass_base!(pass, scope);
1349
1350        let size_bytes = pass_try!(
1351            base,
1352            scope,
1353            u32::try_from(data.len()).map_err(|_| ImmediateUploadError::ImmediateOutOfMemory)
1354        );
1355        pass_try!(
1356            base,
1357            scope,
1358            pass::validate_immediates_alignment(offset, size_bytes)
1359        );
1360
1361        let values_offset = base.immediates_data.len().try_into().unwrap();
1362
1363        base.immediates_data.extend(
1364            data.chunks_exact(wgt::IMMEDIATE_DATA_ALIGNMENT as usize)
1365                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
1366        );
1367
1368        base.commands.push(ArcComputeCommand::SetImmediate {
1369            offset,
1370            size_bytes: data.len() as u32,
1371            values_offset,
1372        });
1373
1374        Ok(())
1375    }
1376
1377    pub fn compute_pass_set_immediates_with_id(
1378        &self,
1379        pass_id: id::ComputePassEncoderId,
1380        offset: u32,
1381        data: &[u8],
1382    ) -> Result<(), PassStateError> {
1383        let pass = self.hub.compute_passes.get(pass_id);
1384        let mut pass = pass
1385            .try_lock()
1386            .expect("ComputePasses should not be accessed concurrently");
1387        self.compute_pass_set_immediates(&mut pass, offset, data)
1388    }
1389
1390    pub fn compute_pass_dispatch_workgroups(
1391        &self,
1392        pass: &mut ComputePass,
1393        groups_x: u32,
1394        groups_y: u32,
1395        groups_z: u32,
1396    ) -> Result<(), PassStateError> {
1397        let scope = PassErrorScope::Dispatch { indirect: false };
1398
1399        pass_base!(pass, scope)
1400            .commands
1401            .push(ArcComputeCommand::DispatchWorkgroups([
1402                groups_x, groups_y, groups_z,
1403            ]));
1404
1405        Ok(())
1406    }
1407
1408    pub fn compute_pass_dispatch_workgroups_with_id(
1409        &self,
1410        pass_id: id::ComputePassEncoderId,
1411        groups_x: u32,
1412        groups_y: u32,
1413        groups_z: u32,
1414    ) -> Result<(), PassStateError> {
1415        let pass = self.hub.compute_passes.get(pass_id);
1416        let mut pass = pass
1417            .try_lock()
1418            .expect("ComputePasses should not be accessed concurrently");
1419        self.compute_pass_dispatch_workgroups(&mut pass, groups_x, groups_y, groups_z)
1420    }
1421
1422    pub fn compute_pass_dispatch_workgroups_indirect(
1423        &self,
1424        pass: &mut ComputePass,
1425        buffer_id: id::BufferId,
1426        offset: BufferAddress,
1427    ) -> Result<(), PassStateError> {
1428        let hub = &self.hub;
1429        let scope = PassErrorScope::Dispatch { indirect: true };
1430        let base = pass_base!(pass, scope);
1431
1432        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());
1433
1434        base.commands
1435            .push(ArcComputeCommand::DispatchWorkgroupsIndirect { buffer, offset });
1436
1437        Ok(())
1438    }
1439
1440    pub fn compute_pass_dispatch_workgroups_indirect_with_id(
1441        &self,
1442        pass_id: id::ComputePassEncoderId,
1443        buffer_id: id::BufferId,
1444        offset: BufferAddress,
1445    ) -> Result<(), PassStateError> {
1446        let pass = self.hub.compute_passes.get(pass_id);
1447        let mut pass = pass
1448            .try_lock()
1449            .expect("ComputePasses should not be accessed concurrently");
1450        self.compute_pass_dispatch_workgroups_indirect(&mut pass, buffer_id, offset)
1451    }
1452
1453    pub fn compute_pass_push_debug_group(
1454        &self,
1455        pass: &mut ComputePass,
1456        label: &str,
1457        color: u32,
1458    ) -> Result<(), PassStateError> {
1459        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);
1460
1461        let bytes = label.as_bytes();
1462        base.string_data.extend_from_slice(bytes);
1463
1464        base.commands.push(ArcComputeCommand::PushDebugGroup {
1465            color,
1466            len: bytes.len(),
1467        });
1468
1469        Ok(())
1470    }
1471
1472    pub fn compute_pass_push_debug_group_with_id(
1473        &self,
1474        pass_id: id::ComputePassEncoderId,
1475        label: &str,
1476        color: u32,
1477    ) -> Result<(), PassStateError> {
1478        let pass = self.hub.compute_passes.get(pass_id);
1479        let mut pass = pass
1480            .try_lock()
1481            .expect("ComputePasses should not be accessed concurrently");
1482        self.compute_pass_push_debug_group(&mut pass, label, color)
1483    }
1484
1485    pub fn compute_pass_pop_debug_group(
1486        &self,
1487        pass: &mut ComputePass,
1488    ) -> Result<(), PassStateError> {
1489        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);
1490
1491        base.commands.push(ArcComputeCommand::PopDebugGroup);
1492
1493        Ok(())
1494    }
1495
1496    pub fn compute_pass_pop_debug_group_with_id(
1497        &self,
1498        pass_id: id::ComputePassEncoderId,
1499    ) -> Result<(), PassStateError> {
1500        let pass = self.hub.compute_passes.get(pass_id);
1501        let mut pass = pass
1502            .try_lock()
1503            .expect("ComputePasses should not be accessed concurrently");
1504        self.compute_pass_pop_debug_group(&mut pass)
1505    }
1506
1507    pub fn compute_pass_insert_debug_marker(
1508        &self,
1509        pass: &mut ComputePass,
1510        label: &str,
1511        color: u32,
1512    ) -> Result<(), PassStateError> {
1513        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);
1514
1515        let bytes = label.as_bytes();
1516        base.string_data.extend_from_slice(bytes);
1517
1518        base.commands.push(ArcComputeCommand::InsertDebugMarker {
1519            color,
1520            len: bytes.len(),
1521        });
1522
1523        Ok(())
1524    }
1525
1526    pub fn compute_pass_insert_debug_marker_with_id(
1527        &self,
1528        pass_id: id::ComputePassEncoderId,
1529        label: &str,
1530        color: u32,
1531    ) -> Result<(), PassStateError> {
1532        let pass = self.hub.compute_passes.get(pass_id);
1533        let mut pass = pass
1534            .try_lock()
1535            .expect("ComputePasses should not be accessed concurrently");
1536        self.compute_pass_insert_debug_marker(&mut pass, label, color)
1537    }
1538
1539    pub fn compute_pass_write_timestamp(
1540        &self,
1541        pass: &mut ComputePass,
1542        query_set_id: id::QuerySetId,
1543        query_index: u32,
1544    ) -> Result<(), PassStateError> {
1545        let scope = PassErrorScope::WriteTimestamp;
1546        let base = pass_base!(pass, scope);
1547
1548        let hub = &self.hub;
1549        let query_set = hub.query_sets.get(query_set_id);
1550        pass_try!(base, scope, query_set.check_is_valid());
1551
1552        base.commands.push(ArcComputeCommand::WriteTimestamp {
1553            query_set,
1554            query_index,
1555        });
1556
1557        Ok(())
1558    }
1559
1560    pub fn compute_pass_write_timestamp_with_id(
1561        &self,
1562        pass_id: id::ComputePassEncoderId,
1563        query_set_id: id::QuerySetId,
1564        query_index: u32,
1565    ) -> Result<(), PassStateError> {
1566        let pass = self.hub.compute_passes.get(pass_id);
1567        let mut pass = pass
1568            .try_lock()
1569            .expect("ComputePasses should not be accessed concurrently");
1570        self.compute_pass_write_timestamp(&mut pass, query_set_id, query_index)
1571    }
1572
1573    pub fn compute_pass_begin_pipeline_statistics_query(
1574        &self,
1575        pass: &mut ComputePass,
1576        query_set_id: id::QuerySetId,
1577        query_index: u32,
1578    ) -> Result<(), PassStateError> {
1579        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
1580        let base = pass_base!(pass, scope);
1581
1582        let hub = &self.hub;
1583        let query_set = hub.query_sets.get(query_set_id);
1584        pass_try!(base, scope, query_set.check_is_valid());
1585
1586        base.commands
1587            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
1588                query_set,
1589                query_index,
1590            });
1591
1592        Ok(())
1593    }
1594
1595    pub fn compute_pass_begin_pipeline_statistics_query_with_id(
1596        &self,
1597        pass_id: id::ComputePassEncoderId,
1598        query_set_id: id::QuerySetId,
1599        query_index: u32,
1600    ) -> Result<(), PassStateError> {
1601        let pass = self.hub.compute_passes.get(pass_id);
1602        let mut pass = pass
1603            .try_lock()
1604            .expect("ComputePasses should not be accessed concurrently");
1605        self.compute_pass_begin_pipeline_statistics_query(&mut pass, query_set_id, query_index)
1606    }
1607
1608    pub fn compute_pass_end_pipeline_statistics_query(
1609        &self,
1610        pass: &mut ComputePass,
1611    ) -> Result<(), PassStateError> {
1612        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
1613            .commands
1614            .push(ArcComputeCommand::EndPipelineStatisticsQuery);
1615
1616        Ok(())
1617    }
1618
1619    pub fn compute_pass_end_pipeline_statistics_query_with_id(
1620        &self,
1621        pass_id: id::ComputePassEncoderId,
1622    ) -> Result<(), PassStateError> {
1623        let pass = self.hub.compute_passes.get(pass_id);
1624        let mut pass = pass
1625            .try_lock()
1626            .expect("ComputePasses should not be accessed concurrently");
1627        self.compute_pass_end_pipeline_statistics_query(&mut pass)
1628    }
1629
1630    pub fn compute_pass_transition_resources(
1631        &self,
1632        pass: &mut ComputePass,
1633        buffer_transitions: impl Iterator<Item = wgt::BufferTransition<id::BufferId>>,
1634        texture_transitions: impl Iterator<Item = wgt::TextureTransition<id::TextureViewId>>,
1635    ) -> Result<(), PassStateError> {
1636        let scope = PassErrorScope::TransitionResources;
1637        let base = pass_base!(pass, scope);
1638
1639        let hub = &self.hub;
1640
1641        let buffer_transitions = pass_try!(
1642            base,
1643            scope,
1644            buffer_transitions
1645                .map(|buffer_transition| -> Result<_, InvalidResourceError> {
1646                    Ok(wgt::BufferTransition {
1647                        buffer: hub.buffers.get(buffer_transition.buffer).get()?,
1648                        state: buffer_transition.state,
1649                    })
1650                })
1651                .collect::<Result<Vec<_>, _>>()
1652        );
1653
1654        let texture_transitions = pass_try!(
1655            base,
1656            scope,
1657            texture_transitions
1658                .map(|texture_transition| -> Result<_, InvalidResourceError> {
1659                    let texture_view = hub.texture_views.get(texture_transition.texture);
1660                    texture_view.check_valid()?;
1661                    Ok(wgt::TextureTransition {
1662                        texture: texture_view,
1663                        selector: texture_transition.selector,
1664                        state: texture_transition.state,
1665                    })
1666                })
1667                .collect::<Result<Vec<_>, _>>()
1668        );
1669
1670        base.commands.push(ArcComputeCommand::TransitionResources {
1671            buffer_transitions,
1672            texture_transitions,
1673        });
1674
1675        Ok(())
1676    }
1677}