wgpu_core/command/compute.rs

use thiserror::Error;
use wgt::{
    error::{ErrorType, WebGpuError},
    BufferAddress, DynamicOffset,
};

use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{convert::Infallible, fmt, str};

use crate::{
    api_log,
    binding_model::{BindError, ImmediateUploadError, LateMinBufferBindingSizeMismatch},
    command::{
        bind::{Binder, BinderError},
        compute_command::ArcComputeCommand,
        encoder::EncodingState,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        pass::{self, flush_bindings_helper},
        pass_base, pass_try,
        query::{end_pipeline_statistics_query, validate_and_begin_pipeline_statistics_query},
        ArcCommand, ArcPassTimestampWrites, BasePass, BindGroupStateChange, CommandEncoder,
        CommandEncoderError, DebugGroupError, EncoderStateError, InnerCommandEncoder, MapPassErr,
        PassErrorScope, PassStateError, PassTimestampWrites, QueryUseError, StateChange,
        TimestampWritesError,
    },
    device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::MemoryInitKind,
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
        MissingBufferUsageError, ParentDevice, RawResourceAccess, Trackable,
    },
    track::{ResourceUsageCompatibilityError, Tracker},
    Label,
};
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;

/// A pass's [encoder state](https://www.w3.org/TR/webgpu/#encoder-state) and
/// its validity are two distinct conditions, i.e., the full matrix of
/// (open, ended) x (valid, invalid) is possible.
///
/// The presence or absence of the `parent` `Option` indicates the pass's state.
/// The presence or absence of an error in `base.error` indicates the pass's
/// validity.
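///
/// For reference, the four resulting combinations (a summary derived from the
/// two conditions above):
///
/// | `parent`  | `base.error` | Pass state        |
/// |-----------|--------------|-------------------|
/// | `Some(_)` | `None`       | open and valid    |
/// | `Some(_)` | `Some(_)`    | open and invalid  |
/// | `None`    | `None`       | ended and valid   |
/// | `None`    | `Some(_)`    | ended and invalid |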
pub struct ComputePass {
    /// All pass data and records are stored here.
    base: ComputeBasePass,

    /// Parent command encoder that this pass records commands into.
    ///
    /// If this is `Some`, then the pass is in WebGPU's "open" state. If it is
    /// `None`, then the pass is in the "ended" state.
    /// See <https://www.w3.org/TR/webgpu/#encoder-state>
    parent: Option<Arc<CommandEncoder>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    /// If the parent command encoder is invalid, the returned pass will be invalid.
    fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: BasePass::new(&label),
            parent: Some(parent),
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
        Self {
            base: BasePass::new_invalid(label, err),
            parent: Some(parent),
            timestamp_writes: None,
            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.label.as_deref()
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}

/// cbindgen:ignore
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;

#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    ImmediateData(#[from] ImmediateUploadError),
    #[error("Immediate data offset must be aligned to 4 bytes")]
    ImmediateOffsetAlignment,
    #[error("Immediate data size must be aligned to 4 bytes")]
    ImmediateDataSizeAlignment,
    #[error("Ran out of immediate data space. Don't set 4 GiB of immediates per ComputePass.")]
    ImmediateOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    // This one is unreachable, but required for generic pass support.
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}

/// Error encountered when performing a compute pass, stored for later reporting
/// when encoding ends.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl From<pass::MissingPipeline> for ComputePassErrorInner {
    fn from(value: pass::MissingPipeline) -> Self {
        Self::Dispatch(DispatchError::MissingPipeline(value))
    }
}

impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}

impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::ImmediateData(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::ImmediateOffsetAlignment
            | ComputePassErrorInner::ImmediateDataSizeAlignment
            | ComputePassErrorInner::ImmediateOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}

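/// Mutable state tracked while encoding a compute pass's commands. In brief
/// (a summary of the fields below): the currently bound pipeline, the generic
/// per-pass state, any in-flight pipeline statistics query, a CPU-side shadow
/// of the immediate data (used to restore it after indirect-dispatch
/// validation), and the per-dispatch trackers used to emit barriers.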
struct State<'scope, 'snatch_guard, 'cmd_enc> {
    pipeline: Option<Arc<ComputePipeline>>,

    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    immediates: Vec<u32>,

    intermediate_trackers: Tracker,
}

impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Flush binding state in preparation for a dispatch.
    ///
    /// # Differences between render and compute passes
    ///
    /// There are differences between the `flush_bindings` implementations for
    /// render and compute passes, because render passes have a single usage
    /// scope for the entire pass, and compute passes have a separate usage
    /// scope for each dispatch.
    ///
    /// For compute passes, bind groups are merged into a fresh usage scope
    /// here, not into the pass usage scope within calls to `set_bind_group`. As
    /// specified by WebGPU, for compute passes, we merge only the bind groups
    /// that are actually used by the pipeline, unlike render passes, which
    /// merge every bind group that is ever set, even if it is not ultimately
    /// used by the pipeline.
    ///
    /// For compute passes, we call `drain_barriers` here, because barriers may
    /// be needed before each dispatch if a previous dispatch had a conflicting
    /// usage. For render passes, barriers are emitted once at the start of the
    /// render pass.
    ///
    /// # Indirect buffer handling
    ///
    /// The `indirect_buffer` argument should be passed for any indirect
    /// dispatch (with or without validation). It will be checked for
    /// conflicting usages according to WebGPU rules. For the purpose of
    /// these rules, the fact that we have actually processed the buffer in
    /// the validation pass is an implementation detail.
    ///
    /// The `track_indirect_buffer` argument should be set when doing indirect
    /// dispatch *without* validation. In this case, the indirect buffer will
    /// be added to the tracker in order to generate any necessary transitions
    /// for that usage.
    ///
    /// When doing indirect dispatch *with* validation, the indirect buffer is
    /// processed by the validation pass and is not used by the actual dispatch.
    /// The indirect validation code handles transitions for the validation
    /// pass.
    fn flush_bindings(
        &mut self,
        indirect_buffer: Option<&Arc<Buffer>>,
        track_indirect_buffer: bool,
    ) -> Result<(), ComputePassErrorInner> {
        for bind_group in self.pass.binder.list_active() {
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
        }

        // Add the indirect buffer. Because usage scopes are per-dispatch, this
        // is the only place where INDIRECT usage could be added, and it is safe
        // for us to remove it below.
        if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .merge_single(buffer, wgt::BufferUses::INDIRECT)?;
        }

        // For compute, usage scopes are associated with each dispatch and not
        // with the pass as a whole. However, because the cost of creating and
        // dropping `UsageScope`s is significant (even with the pool), we
        // add and then remove usage from a single usage scope.

        for bind_group in self.pass.binder.list_active() {
            self.intermediate_trackers
                .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used);
        }

        if track_indirect_buffer {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer.map(|buf| buf.tracker_index()),
                );
        } else if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .remove_usage(buffer, wgt::BufferUses::INDIRECT);
        }

        flush_bindings_helper(&mut self.pass)?;

        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}

// Running the compute pass.

impl Global {
    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned. Attempting to record
    /// commands into an invalid pass is permitted, but a validation error will
    /// ultimately be generated when the parent encoder is finished, and it is
    /// not possible to run any commands from the invalid pass.
    ///
    /// If successful, puts the encoder into the [`Locked`] state.
    ///
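    /// A minimal usage sketch (`global` and `encoder_id` are assumed to be a
    /// configured [`Global`] and a live encoder ID; illustration only):
    ///
    /// ```ignore
    /// let (mut pass, err) = global.command_encoder_begin_compute_pass(
    ///     encoder_id,
    ///     &ComputePassDescriptor {
    ///         label: Some(Cow::Borrowed("example pass")),
    ///         timestamp_writes: None,
    ///     },
    /// );
    /// // `err` is `Some` only for errors that must be reported immediately;
    /// // most failures instead invalidate the returned pass.
    /// assert!(err.is_none());
    /// ```
    ///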
    /// [`Locked`]: crate::command::CommandEncoderStatus::Locked
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Attempting to open a new pass while the encoder is locked
                // invalidates the encoder, but does not generate a validation
                // error.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // Attempting to open a new pass after the encoder has ended
                // generates an immediate validation error.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                // Passes can be opened even on an invalid encoder. Such passes
                // are technically valid, but since the pass's validity has no
                // visible side effect and there's no point in storing recorded
                // commands that will ultimately be discarded, we open an
                // invalid pass to save that work.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

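    /// Ends the compute pass and hands its recorded commands to the parent
    /// encoder.
    ///
    /// Most errors recorded in the pass are deferred and reported when
    /// `CommandEncoder::finish` is called; an error is returned immediately
    /// only if the pass was opened within another pass or on a finished
    /// encoder (see the comment in the body below).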
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.unlock_encoder()?;

        let base = pass.base.take();

        if let Err(ComputePassError {
            inner:
                ComputePassErrorInner::EncoderState(
                    err @ (EncoderStateError::Locked | EncoderStateError::Ended),
                ),
            scope: _,
        }) = &base
        {
            // Most encoding errors are detected and raised within `finish()`.
            //
            // However, we raise a validation error here if the pass was opened
            // within another pass, or on a finished encoder. The latter is
            // particularly important, because in that case reporting errors via
            // `CommandEncoder::finish` is not possible.
            return Err(err.clone());
        }

        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
            Ok(ArcCommand::RunComputePass {
                pass: base?,
                timestamp_writes: pass.timestamp_writes.take(),
            })
        })
    }
}

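/// Encodes a previously recorded compute pass into `parent_state`'s HAL
/// encoder.
///
/// Replays the pass's command stream, then opens a "(wgpu internal) Pre Pass"
/// encoder that is swapped in *before* the pass body to hold memory-init
/// fixups and the barriers accumulated in `intermediate_trackers`.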
pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

    // We automatically keep extending command buffers over time, and because
    // we want to insert a command buffer _before_ what we're about to record,
    // we need to make sure to close the previous one.
    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        immediates: Vec::new(),

        intermediate_trackers: Tracker::new(),
    };

    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            // Unlike in render passes, we can't delay resetting the query sets,
            // since there is no auxiliary pass.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            // The range should always be `Some`: both values being `None` should
            // have led to a validation error. But there's no point in erroring
            // over that nuance here!
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetImmediate {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetImmediate;
                pass::set_immediates::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.immediates_data,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    |data_slice| {
                        let offset_in_elements = (offset / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        let size_in_elements =
                            (size_bytes / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        state.immediates[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, device, buffer, offset).map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    None, // compute passes do not attempt to coalesce query resets
                    query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    device,
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass: pass::PassState {
            pending_discard_init_fixups,
            ..
        },
        intermediate_trackers,
        ..
    } = state;

    // Stop the current command encoder.
    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;

    // Create a new command encoder, which we will insert _before_ the body of the compute pass.
    //
    // Use that buffer to insert barriers and clear discarded images.
    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    // Close the command encoder, and swap it with the previous.
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;

    Ok(())
}

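/// Binds `pipeline` on the HAL encoder and rebinds whatever binding state
/// survives the layout change. The CPU-side immediate-data shadow in
/// `State::immediates` is cleared and resized to match the new layout.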
fn set_pipeline(
    state: &mut State,
    device: &Arc<Device>,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device(device)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    // Rebind resources
    pass::change_pipeline_layout::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            // This only needs to be here for compute pipelines, because they use
            // immediates for validating indirect dispatches.
            state.immediates.clear();
            // Note that there can only be one range for each stage. See the
            // `MoreThanOneImmediateRangePerStage` error.
            if pipeline.layout.immediate_size != 0 {
                // Note that a non-0 range start doesn't work anyway:
                // https://github.com/gfx-rs/wgpu/issues/4502
                let len = pipeline.layout.immediate_size as usize
                    / wgt::IMMEDIATE_DATA_ALIGNMENT as usize;
                state.immediates.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}

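/// Validates and records a direct dispatch: checks that a compatible pipeline
/// is bound, flushes binding state, and rejects any workgroup-count dimension
/// exceeding the device's `max_compute_workgroups_per_dimension` limit.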
fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch {groups:?}");

    state.is_ready()?;

    state.flush_bindings(None, false)?;

    let groups_size_limit = state
        .pass
        .base
        .device
        .limits
        .max_compute_workgroups_per_dimension;

    if groups[0] > groups_size_limit
        || groups[1] > groups_size_limit
        || groups[2] > groups_size_limit
    {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }

    unsafe {
        state.pass.base.raw_encoder.dispatch(groups);
    }
    Ok(())
}

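/// Validates and records an indirect dispatch.
///
/// After the usage, alignment, and bounds checks, one of two paths is taken:
/// if the device has indirect validation enabled, a validation dispatch
/// copies the validated arguments into an internal destination buffer and the
/// real dispatch reads from that buffer at offset 0; otherwise the dispatch
/// reads directly from the caller's buffer at `offset`.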
fn dispatch_indirect(
    state: &mut State,
    device: &Arc<Device>,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch_indirect");

    buffer.same_device(device)?;

    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    let stride = 3 * 4; // 3 integers, x/y/z group size
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.pass.base.raw_encoder.set_immediates(
                params.pipeline_layout,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.pass.base.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

        // Reset the state that the validation dispatch clobbered: pipeline,
        // immediates, and bind groups.
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.immediates.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_immediates(
                        pipeline.layout.raw(),
                        0,
                        &state.immediates,
                    );
                }
            }

            for (i, e) in state.pass.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_bindings(Some(&buffer), false)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        state.flush_bindings(Some(&buffer), true)?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}

// Recording a compute pass.
//
// The only error that should be returned from these methods is
// `EncoderStateError::Ended`, when the pass has already ended and an immediate
// validation error is raised.
//
// All other errors should be stored in the pass for later reporting when
// `CommandEncoder.finish()` is called.
//
// The `pass_try!` macro should be used to handle errors appropriately. Note
// that the `pass_try!` and `pass_base!` macros may return early from the
// function that invokes them, like the `?` operator.
impl Global {
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        // This statement will return an error if the pass is ended. It's
        // important that the error check comes before the early-out for
        // `set_and_check_redundant`.
        let base = pass_base!(pass, scope);

        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            return Ok(());
        }

        let mut bind_group = None;
        if let Some(bind_group_id) = bind_group_id {
            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        // This statement will return an error if the pass is ended. It's
        // important that the error check comes before the early-out for
        // `redundant`.
        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

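    /// Records an immediate-data upload for the pass.
    ///
    /// Both `offset` and `data.len()` must be multiples of
    /// `wgt::IMMEDIATE_DATA_ALIGNMENT` (4 bytes); `data` is packed into
    /// native-endian `u32` words. A sketch of a conforming call (`global` and
    /// `pass` are assumed to be an existing [`Global`] and an open pass):
    ///
    /// ```ignore
    /// let words: [u32; 2] = [0xdead_beef, 42];
    /// let bytes: Vec<u8> = words.iter().flat_map(|w| w.to_ne_bytes()).collect();
    /// global.compute_pass_set_immediates(&mut pass, 0, &bytes)?;
    /// ```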
    pub fn compute_pass_set_immediates(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetImmediate;
        let base = pass_base!(pass, scope);

        if offset & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateOffsetAlignment),
            );
        }

        if data.len() as u32 & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateDataSizeAlignment),
            )
        }
        let value_offset = pass_try!(
            base,
            scope,
            base.immediates_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::ImmediateOutOfMemory)
        );

        base.immediates_data.extend(
            data.chunks_exact(wgt::IMMEDIATE_DATA_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetImmediate {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}