wgpu_core/command/compute.rs

use thiserror::Error;
use wgt::{
    error::{ErrorType, WebGpuError},
    BufferAddress, DynamicOffset,
};

use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{convert::Infallible, fmt, str};

use crate::{
    api_log, binding_model::BindError, command::pass::flush_bindings_helper,
    resource::RawResourceAccess,
};
use crate::{
    binding_model::{LateMinBufferBindingSizeMismatch, PushConstantUploadError},
    command::{
        bind::{Binder, BinderError},
        compute_command::ArcComputeCommand,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
        BasePass, BindGroupStateChange, CommandEncoderError, MapPassErr, PassErrorScope,
        PassTimestampWrites, QueryUseError, StateChange,
    },
    device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::MemoryInitKind,
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
    },
    track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex},
    Label,
};
use crate::{command::InnerCommandEncoder, resource::DestroyedResourceError};
use crate::{
    command::{
        encoder::EncodingState, pass, ArcCommand, CommandEncoder, DebugGroupError,
        EncoderStateError, PassStateError, TimestampWritesError,
    },
    device::Device,
};

pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;

/// A pass's [encoder state](https://www.w3.org/TR/webgpu/#encoder-state) and
/// its validity are two distinct conditions, i.e., the full matrix of
/// (open, ended) x (valid, invalid) is possible.
///
/// The presence or absence of the `parent` `Option` indicates the pass's state.
/// The presence or absence of an error in `base.error` indicates the pass's
/// validity.
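///
/// In tabular form (a restatement of the two paragraphs above, not an
/// additional invariant):
///
/// | `parent`  | error stored in `base` | pass condition |
/// |-----------|------------------------|----------------|
/// | `Some(_)` | no                     | open, valid    |
/// | `Some(_)` | yes                    | open, invalid  |
/// | `None`    | no                     | ended, valid   |
/// | `None`    | yes                    | ended, invalid |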
pub struct ComputePass {
    /// All of the pass's data and recorded commands are stored here.
    base: ComputeBasePass,

    /// Parent command encoder that this pass records commands into.
    ///
    /// If this is `Some`, then the pass is in WebGPU's "open" state. If it is
    /// `None`, then the pass is in the "ended" state.
    /// See <https://www.w3.org/TR/webgpu/#encoder-state>
    parent: Option<Arc<CommandEncoder>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    /// If the parent command encoder is invalid, the returned pass will be invalid.
    fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: BasePass::new(&label),
            parent: Some(parent),
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
        Self {
            base: BasePass::new_invalid(label, err),
            parent: Some(parent),
            timestamp_writes: None,
            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.label.as_deref()
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}

/// cbindgen:ignore
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;

#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4 GiB of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    // This one is unreachable, but required for generic pass support
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}

/// Error encountered when performing a compute pass, stored for later reporting
/// when encoding ends.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl From<pass::MissingPipeline> for ComputePassErrorInner {
    fn from(value: pass::MissingPipeline) -> Self {
        Self::Dispatch(DispatchError::MissingPipeline(value))
    }
}

impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}

impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::PushConstants(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::PushConstantOffsetAlignment
            | ComputePassErrorInner::PushConstantSizeAlignment
            | ComputePassErrorInner::PushConstantOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}

struct State<'scope, 'snatch_guard, 'cmd_enc> {
    pipeline: Option<Arc<ComputePipeline>>,

    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    push_constants: Vec<u32>,

    intermediate_trackers: Tracker,
}

impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Flush binding state in preparation for a dispatch.
    ///
    /// # Differences between render and compute passes
    ///
    /// There are differences between the `flush_bindings` implementations for
    /// render and compute passes, because render passes have a single usage
    /// scope for the entire pass, and compute passes have a separate usage
    /// scope for each dispatch.
    ///
    /// For compute passes, bind groups are merged into a fresh usage scope
    /// here, not into the pass usage scope within calls to `set_bind_group`. As
    /// specified by WebGPU, for compute passes, we merge only the bind groups
    /// that are actually used by the pipeline, unlike render passes, which
    /// merge every bind group that is ever set, even if it is not ultimately
    /// used by the pipeline.
    ///
    /// For compute passes, we call `drain_barriers` here, because barriers may
    /// be needed before each dispatch if a previous dispatch had a conflicting
    /// usage. For render passes, barriers are emitted once at the start of the
    /// render pass.
    ///
    /// # Indirect buffer handling
    ///
    /// For indirect dispatches without validation, pass both `indirect_buffer`
    /// and `indirect_buffer_index_if_not_validating`. The indirect buffer will
    /// be added to the usage scope and the tracker.
    ///
    /// For indirect dispatches with validation, pass only `indirect_buffer`.
    /// The indirect buffer will be added to the usage scope to detect usage
    /// conflicts. The indirect buffer does not need to be added to the tracker;
    /// the indirect validation code handles transitions manually.
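    ///
    /// As a sketch, the three call shapes used by `dispatch` and
    /// `dispatch_indirect` below are:
    ///
    /// ```ignore
    /// // Direct dispatch: no indirect buffer involved.
    /// state.flush_bindings(None, None)?;
    ///
    /// // Indirect dispatch with validation: usage-scope checks only.
    /// state.flush_bindings(Some(&buffer), None)?;
    ///
    /// // Indirect dispatch without validation: also track the buffer's state.
    /// state.flush_bindings(Some(&buffer), Some(buffer.tracker_index()))?;
    /// ```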
    fn flush_bindings(
        &mut self,
        indirect_buffer: Option<&Arc<Buffer>>,
        indirect_buffer_index_if_not_validating: Option<TrackerIndex>,
    ) -> Result<(), ComputePassErrorInner> {
        let mut scope = self.pass.base.device.new_usage_scope();

        for bind_group in self.pass.binder.list_active() {
            unsafe { scope.merge_bind_group(&bind_group.used)? };
        }

        // When indirect validation is turned on, our actual use of the buffer
        // is `STORAGE_READ_ONLY`, but for usage scope validation, we still want
        // to treat it as indirect so we can detect the conflicts prescribed by
        // WebGPU. The usage scope we construct here never leaves this function
        // (and is not used to populate a tracker), so it's fine to do this.
        if let Some(buffer) = indirect_buffer {
            scope
                .buffers
                .merge_single(buffer, wgt::BufferUses::INDIRECT)?;
        }

        // Add the state of the indirect buffer, if needed (see above).
        self.intermediate_trackers
            .buffers
            .set_multiple(&mut scope.buffers, indirect_buffer_index_if_not_validating);

        flush_bindings_helper(&mut self.pass, |bind_group| {
            self.intermediate_trackers
                .set_from_bind_group(&mut scope, &bind_group.used)
        })?;

        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}

// Running the compute pass.

impl Global {
    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned. Attempting to record
    /// commands into an invalid pass is permitted, but a validation error will
    /// ultimately be generated when the parent encoder is finished, and it is
    /// not possible to run any commands from the invalid pass.
    ///
    /// If successful, puts the encoder into the [`Locked`] state.
    ///
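    /// A minimal sketch of the calling convention (error handling elided):
    ///
    /// ```ignore
    /// let (mut pass, err) = global.command_encoder_begin_compute_pass(encoder_id, &desc);
    /// if let Some(err) = err {
    ///     // Only opening a pass on an already-ended encoder reports an error
    ///     // here; all other failures are recorded in the returned (invalid)
    ///     // pass and surface when the parent encoder is finished.
    /// }
    /// ```
    ///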
    /// [`Locked`]: crate::command::CommandEncoderStatus::Locked
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Attempting to open a new pass while the encoder is locked
                // invalidates the encoder, but does not generate a validation
                // error.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // Attempting to open a new pass after the encoder has ended
                // generates an immediate validation error.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                // Passes can be opened even on an invalid encoder. Such passes
                // are even valid, but since the pass's validity has no visible
                // side effect and there's no point in storing recorded commands
                // that will ultimately be discarded, we open an invalid pass to
                // save that work.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.unlock_encoder()?;

        let base = pass.base.take();

        if matches!(
            base,
            Err(ComputePassError {
                inner: ComputePassErrorInner::EncoderState(EncoderStateError::Ended),
                scope: _,
            })
        ) {
            // If the encoder was already finished at the time of pass
            // creation, it was never put in the locked state, so we generate a
            // validation error here and now. The encoder already holds an
            // error from when the pass was opened, or earlier.
            //
            // All other errors are propagated to the encoder within
            // `push_with`, and will be reported later.
            return Err(EncoderStateError::Ended);
        }

        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
            Ok(ArcCommand::RunComputePass {
                pass: base?,
                timestamp_writes: pass.timestamp_writes.take(),
            })
        })
    }
}

pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

    // We automatically keep extending command buffers over time, and because
    // we want to insert a command buffer _before_ what we're about to record,
    // we need to make sure to close the previous one.
    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        push_constants: Vec::new(),

        intermediate_trackers: Tracker::new(),
    };

    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            // Unlike in render passes we can't delay resetting the query sets since
            // there is no auxiliary pass.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            // The range should always be `Some`; both values being `None`
            // should already have produced a validation error. But there's no
            // point in erroring over that nuance here!
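            // For example, write indices (5, 2) reset queries 2..6, while
            // (Some(3), None) resets just 3..4.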
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetPushConstant;
                pass::set_push_constant::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.push_constant_data,
                    wgt::ShaderStages::COMPUTE,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    |data_slice| {
                        let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        state.push_constants[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, device, buffer, offset).map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    None, // compute passes do not attempt to coalesce query resets
                    query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    device,
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass: pass::PassState {
            pending_discard_init_fixups,
            ..
        },
        intermediate_trackers,
        ..
    } = state;

    // Stop the current command encoder.
    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;

    // Create a new command encoder, which we will insert _before_ the body of
    // the compute pass, and use it to insert barriers and clear discarded
    // images.
    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    // Close the command encoder, and swap it with the previous.
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;

    Ok(())
}

fn set_pipeline(
    state: &mut State,
    device: &Arc<Device>,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device(device)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    // Rebind resources
    pass::change_pipeline_layout::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            // This only needs to be here for compute pipelines because they use push constants for
            // validating indirect draws.
            state.push_constants.clear();
            // Note that there can only be one range for each stage. See the `MoreThanOnePushConstantRangePerStage` error.
            if let Some(push_constant_range) =
                pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                    pcr.stages
                        .contains(wgt::ShaderStages::COMPUTE)
                        .then_some(pcr.range.clone())
                })
            {
                // Note that a non-zero range start doesn't work anyway; see https://github.com/gfx-rs/wgpu/issues/4502
                let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
                state.push_constants.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}

fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch {groups:?}");

    state.is_ready()?;

    state.flush_bindings(None, None)?;

    let groups_size_limit = state
        .pass
        .base
        .device
        .limits
        .max_compute_workgroups_per_dimension;

    if groups[0] > groups_size_limit
        || groups[1] > groups_size_limit
        || groups[2] > groups_size_limit
    {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }

    unsafe {
        state.pass.base.raw_encoder.dispatch(groups);
    }
    Ok(())
}

fn dispatch_indirect(
    state: &mut State,
    device: &Arc<Device>,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch_indirect");

    buffer.same_device(device)?;

    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    let stride = 3 * 4; // 3 integers, x/y/z group size
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

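    // When indirect validation is enabled, the dispatch happens in two steps:
    // a single-workgroup validation dispatch reads the indirect args (via the
    // bind groups set below) and writes a validated copy into the internal
    // `params.dst_buffer`, and the real dispatch is then issued indirectly
    // from that buffer at offset 0.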
    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.pass.base.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.pass.base.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

        // Restore the pass's own pipeline, push constants, and bind groups,
        // which the validation dispatch above clobbered.
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.pass.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_bindings(Some(&buffer), None)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        use crate::resource::Trackable;
        state.flush_bindings(Some(&buffer), Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}

// Recording a compute pass.
//
// The only error that should be returned from these methods is
// `EncoderStateError::Ended`, when the pass has already ended and an immediate
// validation error is raised.
//
// All other errors should be stored in the pass for later reporting when
// `CommandEncoder.finish()` is called.
//
// The `pass_try!` macro should be used to handle errors appropriately. Note
// that the `pass_try!` and `pass_base!` macros may return early from the
// function that invokes them, like the `?` operator.
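//
// As a sketch, most of these methods follow the same shape (`things` and
// `DoThing` are illustrative stand-ins, not real names):
//
//     let base = pass_base!(pass, scope);   // errors only if the pass ended
//     let thing = pass_try!(base, scope, hub.things.get(id).get());
//     base.commands.push(ArcComputeCommand::DoThing { thing });
//     Ok(())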
impl Global {
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        // This statement will return an error if the pass is ended. It's
        // important the error check comes before the early-out for
        // `set_and_check_redundant`.
        let base = pass_base!(pass, scope);

        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            return Ok(());
        }

        let mut bind_group = None;
        if let Some(bind_group_id) = bind_group_id {
            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        // This statement will return an error if the pass is ended. It's
        // important that the error check comes before the early-out for `redundant`.
        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

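    /// Records a `SetPushConstant` command into the pass.
    ///
    /// Both `offset` and `data.len()` must be multiples of
    /// `wgt::PUSH_CONSTANT_ALIGNMENT` (4 bytes); misaligned values are
    /// recorded as validation errors in the pass.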
    pub fn compute_pass_set_push_constants(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetPushConstant;
        let base = pass_base!(pass, scope);

        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantOffsetAlignment),
            );
        }

        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantSizeAlignment),
            )
        }
        let value_offset = pass_try!(
            base,
            scope,
            base.push_constant_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
        );

        base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}