wgpu_core/command/
compute.rs

use thiserror::Error;
use wgt::{
    error::{ErrorType, WebGpuError},
    BufferAddress, DynamicOffset,
};

use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{convert::Infallible, fmt, str};

use crate::{
    api_log,
    binding_model::BindError,
    command::pass::flush_bindings_helper,
    resource::{RawResourceAccess, Trackable},
};
use crate::{
    binding_model::{LateMinBufferBindingSizeMismatch, PushConstantUploadError},
    command::{
        bind::{Binder, BinderError},
        compute_command::ArcComputeCommand,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
        BasePass, BindGroupStateChange, CommandEncoderError, MapPassErr, PassErrorScope,
        PassTimestampWrites, QueryUseError, StateChange,
    },
    device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::MemoryInitKind,
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
    },
    track::{ResourceUsageCompatibilityError, Tracker},
    Label,
};
use crate::{command::InnerCommandEncoder, resource::DestroyedResourceError};
use crate::{
    command::{
        encoder::EncodingState, pass, ArcCommand, CommandEncoder, DebugGroupError,
        EncoderStateError, PassStateError, TimestampWritesError,
    },
    device::Device,
};

pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;

/// A pass's [encoder state](https://www.w3.org/TR/webgpu/#encoder-state) and
/// its validity are two distinct conditions, i.e., the full matrix of
/// (open, ended) x (valid, invalid) is possible.
///
/// The presence or absence of the `parent` `Option` indicates the pass's state.
/// The presence or absence of an error in `base.error` indicates the pass's
/// validity.
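///
/// A sketch of the four combinations (illustrative):
///
/// | `parent`  | `base.error` | encoder state | validity |
/// |-----------|--------------|---------------|----------|
/// | `Some(_)` | `None`       | open          | valid    |
/// | `Some(_)` | `Some(_)`    | open          | invalid  |
/// | `None`    | `None`       | ended         | valid    |
/// | `None`    | `Some(_)`    | ended         | invalid  |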
pub struct ComputePass {
    /// All pass data and recorded commands are stored here.
    base: ComputeBasePass,

    /// Parent command encoder that this pass records commands into.
    ///
    /// If this is `Some`, then the pass is in WebGPU's "open" state. If it is
    /// `None`, then the pass is in the "ended" state.
    /// See <https://www.w3.org/TR/webgpu/#encoder-state>
    parent: Option<Arc<CommandEncoder>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    /// If the parent command encoder is invalid, the returned pass will be invalid.
    fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: BasePass::new(&label),
            parent: Some(parent),
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
        Self {
            base: BasePass::new_invalid(label, err),
            parent: Some(parent),
            timestamp_writes: None,
            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.label.as_deref()
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}

/// cbindgen:ignore
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;

#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4 GiB of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    // This one is unreachable, but required for generic pass support.
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}

/// Error encountered when performing a compute pass, stored for later reporting
/// when encoding ends.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl From<pass::MissingPipeline> for ComputePassErrorInner {
    fn from(value: pass::MissingPipeline) -> Self {
        Self::Dispatch(DispatchError::MissingPipeline(value))
    }
}

impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}

impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::PushConstants(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::PushConstantOffsetAlignment
            | ComputePassErrorInner::PushConstantSizeAlignment
            | ComputePassErrorInner::PushConstantOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}

struct State<'scope, 'snatch_guard, 'cmd_enc> {
    pipeline: Option<Arc<ComputePipeline>>,

    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    push_constants: Vec<u32>,

    intermediate_trackers: Tracker,
}

impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Flush binding state in preparation for a dispatch.
    ///
    /// # Differences between render and compute passes
    ///
    /// There are differences between the `flush_bindings` implementations for
    /// render and compute passes, because render passes have a single usage
    /// scope for the entire pass, and compute passes have a separate usage
    /// scope for each dispatch.
    ///
    /// For compute passes, bind groups are merged into a fresh usage scope
    /// here, not into the pass usage scope within calls to `set_bind_group`. As
    /// specified by WebGPU, for compute passes, we merge only the bind groups
    /// that are actually used by the pipeline, unlike render passes, which
    /// merge every bind group that is ever set, even if it is not ultimately
    /// used by the pipeline.
    ///
    /// For compute passes, we call `drain_barriers` here, because barriers may
    /// be needed before each dispatch if a previous dispatch had a conflicting
    /// usage. For render passes, barriers are emitted once at the start of the
    /// render pass.
    ///
    /// # Indirect buffer handling
    ///
    /// The `indirect_buffer` argument should be passed for any indirect
    /// dispatch (with or without validation). It will be checked for
    /// conflicting usages according to WebGPU rules. For the purpose of
    /// these rules, the fact that we have actually processed the buffer in
    /// the validation pass is an implementation detail.
    ///
    /// The `track_indirect_buffer` argument should be set when doing indirect
    /// dispatch *without* validation. In this case, the indirect buffer will
    /// be added to the tracker in order to generate any necessary transitions
    /// for that usage.
    ///
    /// When doing indirect dispatch *with* validation, the indirect buffer is
    /// processed by the validation pass and is not used by the actual dispatch.
    /// The indirect validation code handles transitions for the validation
    /// pass.
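    ///
    /// # Call patterns
    ///
    /// A sketch of the three call shapes used in this file (illustrative):
    ///
    /// ```ignore
    /// // Direct dispatch: no indirect buffer involved.
    /// state.flush_bindings(None, false)?;
    /// // Indirect dispatch with validation: the buffer is checked for
    /// // conflicting usage, but transitions are handled by the validation pass.
    /// state.flush_bindings(Some(&buffer), false)?;
    /// // Indirect dispatch without validation: the buffer is checked and
    /// // tracked so any necessary transitions are generated.
    /// state.flush_bindings(Some(&buffer), true)?;
    /// ```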
    fn flush_bindings(
        &mut self,
        indirect_buffer: Option<&Arc<Buffer>>,
        track_indirect_buffer: bool,
    ) -> Result<(), ComputePassErrorInner> {
        for bind_group in self.pass.binder.list_active() {
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
        }

        // Add the indirect buffer. Because usage scopes are per-dispatch, this
        // is the only place where INDIRECT usage could be added, and it is safe
        // for us to remove it below.
        if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .merge_single(buffer, wgt::BufferUses::INDIRECT)?;
        }

        // For compute, usage scopes are associated with each dispatch and not
        // with the pass as a whole. However, because the cost of creating and
        // dropping `UsageScope`s is significant (even with the pool), we
        // add and then remove usage from a single usage scope.

        for bind_group in self.pass.binder.list_active() {
            self.intermediate_trackers
                .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used);
        }

        if track_indirect_buffer {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer.map(|buf| buf.tracker_index()),
                );
        } else if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .remove_usage(buffer, wgt::BufferUses::INDIRECT);
        }

        flush_bindings_helper(&mut self.pass)?;

        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}

// Running the compute pass.

impl Global {
    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned. Attempting to record
    /// commands into an invalid pass is permitted, but a validation error will
    /// ultimately be generated when the parent encoder is finished, and it is
    /// not possible to run any commands from the invalid pass.
    ///
    /// If successful, puts the encoder into the [`Locked`] state.
    ///
    /// [`Locked`]: crate::command::CommandEncoderStatus::Locked
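    ///
    /// A minimal call sketch (hypothetical `global` and `encoder_id`; not a
    /// doctest):
    ///
    /// ```ignore
    /// let (mut pass, err) = global.command_encoder_begin_compute_pass(
    ///     encoder_id,
    ///     &ComputePassDescriptor {
    ///         label: Some("compute".into()),
    ///         timestamp_writes: None,
    ///     },
    /// );
    /// assert!(err.is_none());
    /// // ... record commands into `pass`, then:
    /// global.compute_pass_end(&mut pass)?;
    /// ```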
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Attempting to open a new pass while the encoder is locked
                // invalidates the encoder, but does not generate a validation
                // error.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // Attempting to open a new pass after the encoder has ended
                // generates an immediate validation error.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                // Passes can be opened even on an invalid encoder. Such passes
                // are even valid, but since the pass's validity has no visible
                // side effect and there's no point in storing recorded commands
                // that will ultimately be discarded, we open an invalid pass to
                // save that work.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.unlock_encoder()?;

        let base = pass.base.take();

        if let Err(ComputePassError {
            inner:
                ComputePassErrorInner::EncoderState(
                    err @ (EncoderStateError::Locked | EncoderStateError::Ended),
                ),
            scope: _,
        }) = base
        {
            // Most encoding errors are detected and raised within `finish()`.
            //
            // However, we raise a validation error here if the pass was opened
            // within another pass, or on a finished encoder. The latter is
            // particularly important, because in that case reporting errors via
            // `CommandEncoder::finish` is not possible.
            return Err(err.clone());
        }

        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
            Ok(ArcCommand::RunComputePass {
                pass: base?,
                timestamp_writes: pass.timestamp_writes.take(),
            })
        })
    }
}

pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

    // We automatically keep extending command buffers over time, and because
    // we want to insert a command buffer _before_ what we're about to record,
    // we need to make sure to close the previous one.
    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        push_constants: Vec::new(),

        intermediate_trackers: Tracker::new(),
    };

    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            // Unlike in render passes, we can't delay resetting the query sets,
            // since there is no auxiliary pass.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
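            // E.g. (illustrative): begin=5, end=2 yields the reset range 2..6;
            // begin=None, end=7 yields 7..8.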
            // The range should always be `Some`; both indices being `None`
            // should already have produced a validation error. But there's no
            // point in erroring over that nuance here!
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetPushConstant;
                pass::set_push_constant::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.push_constant_data,
                    wgt::ShaderStages::COMPUTE,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    |data_slice| {
                        let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        state.push_constants[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, device, buffer, offset).map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    None, // compute passes do not attempt to coalesce query resets
                    query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    device,
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass: pass::PassState {
            pending_discard_init_fixups,
            ..
        },
        intermediate_trackers,
        ..
    } = state;

    // Stop the current command encoder.
    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;

    // Create a new command encoder, which we will insert _before_ the body of the compute pass.
    //
    // Use that buffer to insert barriers and clear discarded images.
    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    // Close the command encoder, and swap it with the previous.
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;
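
    // After `close_and_swap`, the raw command stream is ordered (illustrative):
    // [commands recorded before the pass] -> [pre-pass fixups and barriers] ->
    // [compute pass body].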

    Ok(())
}

fn set_pipeline(
    state: &mut State,
    device: &Arc<Device>,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device(device)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    // Rebind resources.
    pass::change_pipeline_layout::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            // This only needs to be here for compute pipelines, because they
            // use push constants for validating indirect draws.
            state.push_constants.clear();
            // Note that there can only be one range for each stage. See the
            // `MoreThanOnePushConstantRangePerStage` error.
            if let Some(push_constant_range) =
                pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                    pcr.stages
                        .contains(wgt::ShaderStages::COMPUTE)
                        .then_some(pcr.range.clone())
                })
            {
                // Note that a non-zero range start doesn't work anyway:
                // https://github.com/gfx-rs/wgpu/issues/4502
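                // E.g. (illustrative): a COMPUTE push-constant range of
                // 0..16 bytes reserves four zeroed u32 slots below.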
                let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
                state.push_constants.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}

fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch {groups:?}");

    state.is_ready()?;

    state.flush_bindings(None, false)?;

    let groups_size_limit = state
        .pass
        .base
        .device
        .limits
        .max_compute_workgroups_per_dimension;

    if groups[0] > groups_size_limit
        || groups[1] > groups_size_limit
        || groups[2] > groups_size_limit
    {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }
    unsafe {
        state.pass.base.raw_encoder.dispatch(groups);
    }
    Ok(())
}

fn dispatch_indirect(
    state: &mut State,
    device: &Arc<Device>,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch_indirect");

    buffer.same_device(device)?;

    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }
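
    // E.g. (illustrative): `DispatchIndirectArgs` is three u32s (12 bytes), so
    // with a 64-byte buffer, offset 56 gives end_offset 68 and fails, while
    // offset 52 gives end_offset 64 and passes.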

    let stride = 3 * 4; // 3 integers, x/y/z group size
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.pass.base.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.pass.base.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

        // reset state
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.pass.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_bindings(Some(&buffer), false)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        state.flush_bindings(Some(&buffer), true)?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}

// Recording a compute pass.
//
// The only error that should be returned from these methods is
// `EncoderStateError::Ended`, when the pass has already ended and an immediate
// validation error is raised.
//
// All other errors should be stored in the pass for later reporting when
// `CommandEncoder.finish()` is called.
//
// The `pass_try!` macro should be used to handle errors appropriately. Note
// that the `pass_try!` and `pass_base!` macros may return early from the
// function that invokes them, like the `?` operator.
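//
// A sketch of the pattern (illustrative): `pass_base!` returns early with an
// error only if the pass has already ended; `pass_try!` stores any other error
// in the pass and returns early with `Ok(())`:
//
//     let base = pass_base!(pass, scope);
//     let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());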
impl Global {
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        // This statement will return an error if the pass is ended. It's
        // important the error check comes before the early-out for
        // `set_and_check_redundant`.
        let base = pass_base!(pass, scope);

        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            return Ok(());
        }

        let mut bind_group = None;
        if bind_group_id.is_some() {
            let bind_group_id = bind_group_id.unwrap();

            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        // This statement will return an error if the pass is ended. It's
        // important the error check comes before the early-out for `redundant`.
        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    pub fn compute_pass_set_push_constants(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetPushConstant;
        let base = pass_base!(pass, scope);

        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantOffsetAlignment),
            );
        }

        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantSizeAlignment),
            )
        }
        let value_offset = pass_try!(
            base,
            scope,
            base.push_constant_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
        );

        base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );
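        // E.g. (illustrative): an 8-byte `data` slice packs into two u32
        // entries starting at `value_offset` in `push_constant_data`.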

        base.commands.push(ArcComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}