// wgpu_core/command/compute.rs

1use thiserror::Error;
2use wgt::{
3    error::{ErrorType, WebGpuError},
4    BufferAddress, DynamicOffset,
5};
6
7use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
8use core::{fmt, str};
9
10use crate::command::{
11    pass, CommandEncoder, DebugGroupError, EncoderStateError, PassStateError, TimestampWritesError,
12};
13use crate::resource::DestroyedResourceError;
14use crate::{binding_model::BindError, resource::RawResourceAccess};
15use crate::{
16    binding_model::{LateMinBufferBindingSizeMismatch, PushConstantUploadError},
17    command::{
18        bind::{Binder, BinderError},
19        compute_command::ArcComputeCommand,
20        end_pipeline_statistics_query,
21        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
22        pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
23        BasePass, BindGroupStateChange, CommandEncoderError, MapPassErr, PassErrorScope,
24        PassTimestampWrites, QueryUseError, StateChange,
25    },
26    device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
27    global::Global,
28    hal_label, id,
29    init_tracker::MemoryInitKind,
30    pipeline::ComputePipeline,
31    resource::{
32        self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
33    },
34    track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex},
35    Label,
36};
37
/// Base pass data specialized for compute passes: commands are
/// [`ArcComputeCommand`]s and recording errors are [`ComputePassError`]s.
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;
39
/// A pass's [encoder state](https://www.w3.org/TR/webgpu/#encoder-state) and
/// its validity are two distinct conditions, i.e., the full matrix of
/// (open, ended) x (valid, invalid) is possible.
///
/// The presence or absence of the `parent` `Option` indicates the pass's state.
/// The presence or absence of an error in `base.error` indicates the pass's
/// validity.
pub struct ComputePass {
    /// All pass data & records is stored here.
    base: ComputeBasePass,

    /// Parent command encoder that this pass records commands into.
    ///
    /// If this is `Some`, then the pass is in WebGPU's "open" state. If it is
    /// `None`, then the pass is in the "ended" state.
    /// See <https://www.w3.org/TR/webgpu/#encoder-state>
    parent: Option<Arc<CommandEncoder>>,

    /// Timestamp writes requested for this pass, with their query set already
    /// resolved. `take`n when the pass is replayed in
    /// [`Global::compute_pass_end`].
    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state.
    /// Last-seen bind group bindings, used to detect redundant rebinds.
    current_bind_groups: BindGroupStateChange,
    /// Last-seen pipeline id, used to detect redundant pipeline sets.
    current_pipeline: StateChange<id::ComputePipelineId>,
}
64
65impl ComputePass {
66    /// If the parent command encoder is invalid, the returned pass will be invalid.
67    fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
68        let ArcComputePassDescriptor {
69            label,
70            timestamp_writes,
71        } = desc;
72
73        Self {
74            base: BasePass::new(&label),
75            parent: Some(parent),
76            timestamp_writes,
77
78            current_bind_groups: BindGroupStateChange::new(),
79            current_pipeline: StateChange::new(),
80        }
81    }
82
83    fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
84        Self {
85            base: BasePass::new_invalid(label, err),
86            parent: Some(parent),
87            timestamp_writes: None,
88            current_bind_groups: BindGroupStateChange::new(),
89            current_pipeline: StateChange::new(),
90        }
91    }
92
93    #[inline]
94    pub fn label(&self) -> Option<&str> {
95        self.base.label.as_deref()
96    }
97}
98
99impl fmt::Debug for ComputePass {
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        match self.parent {
102            Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
103            None => write!(f, "ComputePass {{ parent: None }}"),
104        }
105    }
106}
107
/// Descriptor used when beginning a [`ComputePass`].
///
/// The `PTW` parameter selects how timestamp writes are represented: id-based
/// ([`PassTimestampWrites`], the default) or with a resolved query set
/// ([`ArcPassTimestampWrites`]).
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    /// Debug label of the pass.
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}
114
/// [`ComputePassDescriptor`] whose timestamp writes hold a resolved
/// (`Arc`-based) query set.
///
/// cbindgen:ignore
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;
117
/// Error encountered while validating a dispatch in a compute pass.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}
132
impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        // Every dispatch error is a violation of WebGPU validation rules.
        ErrorType::Validation
    }
}
138
/// Error encountered when performing a compute pass.
///
/// Unlike [`ComputePassError`], this does not record the scope in which the
/// error occurred.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4gb of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    // This one is unreachable, but required for generic pass support
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}
194
/// Error encountered when performing a compute pass, stored for later reporting
/// when encoding ends.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    /// The command or recording phase during which the error occurred.
    pub scope: PassErrorScope,
    /// The underlying error.
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
204
impl From<pass::MissingPipeline> for ComputePassErrorInner {
    fn from(value: pass::MissingPipeline) -> Self {
        // A missing pipeline is reported as a dispatch error.
        Self::Dispatch(DispatchError::MissingPipeline(value))
    }
}
210
impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    /// Attaches `scope` to this error, producing a full [`ComputePassError`].
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}
222
impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        // Variants that wrap another error delegate to that error's type; the
        // remaining variants are all plain validation errors.
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::PushConstants(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::PushConstantOffsetAlignment
            | ComputePassErrorInner::PushConstantSizeAlignment
            | ComputePassErrorInner::PushConstantOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}
255
/// Mutable state used while replaying the commands of a compute pass in
/// [`Global::compute_pass_end`].
struct State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> {
    /// Currently bound compute pipeline, if any.
    pipeline: Option<Arc<ComputePipeline>>,

    /// State shared with other pass kinds (binder, usage scope, trackers,
    /// memory-init actions, etc.).
    general: pass::BaseState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>,

    /// The query set and query index of the currently open
    /// pipeline-statistics query, if one has been begun and not yet ended.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    /// CPU-side shadow copy of the pass's push constant data; cleared on
    /// pipeline changes (see `set_pipeline` below).
    push_constants: Vec<u32>,

    /// Trackers accumulated per dispatch; barriers derived from them are
    /// recorded via [`CommandEncoder::drain_barriers`] and a pre-pass
    /// transition encoder.
    intermediate_trackers: Tracker,
}
267
impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
    State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
{
    /// Checks that the pass is ready to dispatch: a pipeline must be set, the
    /// bound bind groups must be compatible with it, and late-sized buffer
    /// bindings must satisfy their minimum sizes.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.general.binder.check_compatibility(pipeline.as_ref())?;
            self.general.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Merges the resource usage of the active bind groups into the pass's
    /// usage scope, moves those states into `intermediate_trackers`, and
    /// records the resulting transition barriers into the pass encoder.
    ///
    /// `indirect_buffer` is there to represent the indirect buffer that is
    /// also part of the usage scope.
    fn flush_states(
        &mut self,
        indirect_buffer: Option<TrackerIndex>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        for bind_group in self.general.binder.list_active() {
            unsafe { self.general.scope.merge_bind_group(&bind_group.used)? };
            // Note: stateless trackers are not merged: the lifetime reference
            // is held to the bind group itself.
        }

        for bind_group in self.general.binder.list_active() {
            unsafe {
                self.intermediate_trackers
                    .set_and_remove_from_usage_scope_sparse(
                        &mut self.general.scope,
                        &bind_group.used,
                    )
            }
        }

        // Add the state of the indirect buffer if it hasn't been hit before.
        unsafe {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.general.scope.buffers,
                    indirect_buffer,
                );
        }

        CommandEncoder::drain_barriers(
            self.general.raw_encoder,
            &mut self.intermediate_trackers,
            self.general.snatch_guard,
        );
        Ok(())
    }
}
321
322// Running the compute pass.
323
324impl Global {
    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned. Attempting to record
    /// commands into an invalid pass is permitted, but a validation error will
    /// ultimately be generated when the parent encoder is finished, and it is
    /// not possible to run any commands from the invalid pass.
    ///
    /// If successful, puts the encoder into the [`Locked`] state.
    ///
    /// [`Locked`]: crate::command::CommandEncoderStatus::Locked
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                // The encoder is now locked; release the data mutex so it can
                // be re-acquired when the pass is ended.
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                // Validate any requested timestamp writes, resolving their
                // query set id against the registry.
                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Attempting to open a new pass while the encoder is locked
                // invalidates the encoder, but does not generate a validation
                // error.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // Attempting to open a new pass after the encode has ended
                // generates an immediate validation error.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                // Passes can be opened even on an invalid encoder. Such passes
                // are even valid, but since there's no visible side-effect of
                // the pass being valid and there's no point in storing recorded
                // commands that will ultimately be discarded, we open an
                // invalid pass to save that work.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }
422
    /// Note that this differs from [`Self::compute_pass_end`], it will
    /// create a new pass, replay the commands and end the pass.
    ///
    /// Used when replaying serialized (id-based) commands, e.g. from a trace.
    ///
    /// # Panics
    /// On any error.
    #[doc(hidden)]
    #[cfg(any(feature = "serde", feature = "replay"))]
    pub fn compute_pass_end_with_unresolved_commands(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePass<super::ComputeCommand, core::convert::Infallible>,
        timestamp_writes: Option<&PassTimestampWrites>,
    ) {
        // If tracing is enabled, record a copy of the pass into the trace
        // before replaying it.
        #[cfg(feature = "trace")]
        {
            let cmd_enc = self.hub.command_encoders.get(encoder_id);
            let mut cmd_buf_data = cmd_enc.data.lock();
            let cmd_buf_data = cmd_buf_data.get_inner();

            if let Some(ref mut list) = cmd_buf_data.commands {
                list.push(crate::device::trace::Command::RunComputePass {
                    base: BasePass {
                        label: base.label.clone(),
                        error: None,
                        commands: base.commands.clone(),
                        dynamic_offsets: base.dynamic_offsets.clone(),
                        string_data: base.string_data.clone(),
                        push_constant_data: base.push_constant_data.clone(),
                    },
                    timestamp_writes: timestamp_writes.cloned(),
                });
            }
        }

        let BasePass {
            label,
            error: _,
            commands,
            dynamic_offsets,
            string_data,
            push_constant_data,
        } = base;

        let (mut compute_pass, encoder_error) = self.command_encoder_begin_compute_pass(
            encoder_id,
            &ComputePassDescriptor {
                label: label.as_deref().map(Cow::Borrowed),
                timestamp_writes: timestamp_writes.cloned(),
            },
        );
        if let Some(err) = encoder_error {
            panic!("{:?}", err);
        };

        // Replace the freshly-created pass's (empty) recording with the given
        // commands, resolving their ids to `Arc`ed resources.
        compute_pass.base = BasePass {
            label,
            error: None,
            commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)
                .unwrap(),
            dynamic_offsets,
            string_data,
            push_constant_data,
        };

        self.compute_pass_end(&mut compute_pass).unwrap();
    }
489
    /// Ends the compute pass, replaying its recorded commands into the parent
    /// command encoder and unlocking the encoder.
    ///
    /// Errors encountered while replaying invalidate the parent encoder
    /// rather than being returned; the returned `Result` only reports
    /// encoder-state errors (e.g. the pass or encoder having already ended).
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        let pass_scope = PassErrorScope::Pass;
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        // Taking `parent` moves the pass into the "ended" state; a second
        // call on the same pass fails here.
        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        if let Some(err) = pass.base.error.take() {
            if matches!(
                err,
                ComputePassError {
                    inner: ComputePassErrorInner::EncoderState(EncoderStateError::Ended),
                    scope: _,
                }
            ) {
                // If the encoder was already finished at time of pass creation,
                // then it was not put in the locked state, so we need to
                // generate a validation error here due to the encoder not being
                // locked. The encoder already has a copy of the error.
                return Err(EncoderStateError::Ended);
            } else {
                // If the pass is invalid, invalidate the parent encoder and return.
                // Since we do not track the state of an invalid encoder, it is not
                // necessary to unlock it.
                cmd_buf_data.invalidate(err);
                return Ok(());
            }
        }

        cmd_buf_data.unlock_and_record(|cmd_buf_data| -> Result<(), ComputePassError> {
            let device = &cmd_enc.device;
            device.check_is_valid().map_pass_err(pass_scope)?;

            let base = &mut pass.base;

            let encoder = &mut cmd_buf_data.encoder;

            // We automatically keep extending command buffers over time, and because
            // we want to insert a command buffer _before_ what we're about to record,
            // we need to make sure to close the previous one.
            encoder.close_if_open().map_pass_err(pass_scope)?;
            let raw_encoder = encoder
                .open_pass(base.label.as_deref())
                .map_pass_err(pass_scope)?;

            let snatch_guard = device.snatchable_lock.read();

            let mut state = State {
                pipeline: None,

                general: pass::BaseState {
                    device,
                    raw_encoder,
                    tracker: &mut cmd_buf_data.trackers,
                    buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
                    texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
                    as_actions: &mut cmd_buf_data.as_actions,
                    binder: Binder::new(),
                    temp_offsets: Vec::new(),
                    dynamic_offset_count: 0,

                    pending_discard_init_fixups: SurfacesInDiscardState::new(),

                    snatch_guard: &snatch_guard,
                    scope: device.new_usage_scope(),

                    debug_scope_depth: 0,
                    string_offset: 0,
                },
                active_query: None,

                push_constants: Vec::new(),

                intermediate_trackers: Tracker::new(),
            };

            // Size the trackers up front so per-command inserts don't grow them.
            let indices = &state.general.device.tracker_indices;
            state
                .general
                .tracker
                .buffers
                .set_size(indices.buffers.size());
            state
                .general
                .tracker
                .textures
                .set_size(indices.textures.size());

            let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
                if let Some(tw) = pass.timestamp_writes.take() {
                    tw.query_set
                        .same_device_as(cmd_enc.as_ref())
                        .map_pass_err(pass_scope)?;

                    let query_set = state.general.tracker.query_sets.insert_single(tw.query_set);

                    // Unlike in render passes we can't delay resetting the query sets since
                    // there is no auxiliary pass.
                    let range = if let (Some(index_a), Some(index_b)) =
                        (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
                    {
                        Some(index_a.min(index_b)..index_a.max(index_b) + 1)
                    } else {
                        tw.beginning_of_pass_write_index
                            .or(tw.end_of_pass_write_index)
                            .map(|i| i..i + 1)
                    };
                    // Range should always be Some, both values being None should lead to a validation error.
                    // But no point in erroring over that nuance here!
                    if let Some(range) = range {
                        unsafe {
                            state
                                .general
                                .raw_encoder
                                .reset_queries(query_set.raw(), range);
                        }
                    }

                    Some(hal::PassTimestampWrites {
                        query_set: query_set.raw(),
                        beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                        end_of_pass_write_index: tw.end_of_pass_write_index,
                    })
                } else {
                    None
                };

            let hal_desc = hal::ComputePassDescriptor {
                label: hal_label(base.label.as_deref(), device.instance_flags),
                timestamp_writes,
            };

            unsafe {
                state.general.raw_encoder.begin_compute_pass(&hal_desc);
            }

            // Replay the recorded commands, validating each and encoding it
            // into the raw (hal) pass. The first error aborts the replay.
            for command in base.commands.drain(..) {
                match command {
                    ArcComputeCommand::SetBindGroup {
                        index,
                        num_dynamic_offsets,
                        bind_group,
                    } => {
                        let scope = PassErrorScope::SetBindGroup;
                        pass::set_bind_group::<ComputePassErrorInner>(
                            &mut state.general,
                            cmd_enc.as_ref(),
                            &base.dynamic_offsets,
                            index,
                            num_dynamic_offsets,
                            bind_group,
                            false,
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::SetPipeline(pipeline) => {
                        let scope = PassErrorScope::SetPipelineCompute;
                        set_pipeline(&mut state, cmd_enc.as_ref(), pipeline).map_pass_err(scope)?;
                    }
                    ArcComputeCommand::SetPushConstant {
                        offset,
                        size_bytes,
                        values_offset,
                    } => {
                        let scope = PassErrorScope::SetPushConstant;
                        pass::set_push_constant::<ComputePassErrorInner, _>(
                            &mut state.general,
                            &base.push_constant_data,
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            size_bytes,
                            Some(values_offset),
                            |data_slice| {
                                // Mirror the data into the CPU-side shadow copy.
                                let offset_in_elements =
                                    (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                                let size_in_elements =
                                    (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                                state.push_constants[offset_in_elements..][..size_in_elements]
                                    .copy_from_slice(data_slice);
                            },
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::Dispatch(groups) => {
                        let scope = PassErrorScope::Dispatch { indirect: false };
                        dispatch(&mut state, groups).map_pass_err(scope)?;
                    }
                    ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                        let scope = PassErrorScope::Dispatch { indirect: true };
                        dispatch_indirect(&mut state, cmd_enc.as_ref(), buffer, offset)
                            .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::PushDebugGroup { color: _, len } => {
                        pass::push_debug_group(&mut state.general, &base.string_data, len);
                    }
                    ArcComputeCommand::PopDebugGroup => {
                        let scope = PassErrorScope::PopDebugGroup;
                        pass::pop_debug_group::<ComputePassErrorInner>(&mut state.general)
                            .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                        pass::insert_debug_marker(&mut state.general, &base.string_data, len);
                    }
                    ArcComputeCommand::WriteTimestamp {
                        query_set,
                        query_index,
                    } => {
                        let scope = PassErrorScope::WriteTimestamp;
                        pass::write_timestamp::<ComputePassErrorInner>(
                            &mut state.general,
                            cmd_enc.as_ref(),
                            None,
                            query_set,
                            query_index,
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::BeginPipelineStatisticsQuery {
                        query_set,
                        query_index,
                    } => {
                        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                        validate_and_begin_pipeline_statistics_query(
                            query_set,
                            state.general.raw_encoder,
                            &mut state.general.tracker.query_sets,
                            cmd_enc.as_ref(),
                            query_index,
                            None,
                            &mut state.active_query,
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::EndPipelineStatisticsQuery => {
                        let scope = PassErrorScope::EndPipelineStatisticsQuery;
                        end_pipeline_statistics_query(
                            state.general.raw_encoder,
                            &mut state.active_query,
                        )
                        .map_pass_err(scope)?;
                    }
                }
            }

            // Every PushDebugGroup must have been matched by a PopDebugGroup.
            if state.general.debug_scope_depth > 0 {
                Err(
                    ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                        .map_pass_err(pass_scope),
                )?;
            }

            unsafe {
                state.general.raw_encoder.end_compute_pass();
            }

            let State {
                general:
                    pass::BaseState {
                        tracker,
                        pending_discard_init_fixups,
                        ..
                    },
                intermediate_trackers,
                ..
            } = state;

            // Stop the current command encoder.
            encoder.close().map_pass_err(pass_scope)?;

            // Create a new command encoder, which we will insert _before_ the body of the compute pass.
            //
            // Use that buffer to insert barriers and clear discarded images.
            let transit = encoder
                .open_pass(hal_label(
                    Some("(wgpu internal) Pre Pass"),
                    self.instance.flags,
                ))
                .map_pass_err(pass_scope)?;
            fixup_discarded_surfaces(
                pending_discard_init_fixups.into_iter(),
                transit,
                &mut tracker.textures,
                device,
                &snatch_guard,
            );
            CommandEncoder::insert_barriers_from_tracker(
                transit,
                tracker,
                &intermediate_trackers,
                &snatch_guard,
            );
            // Close the command encoder, and swap it with the previous.
            encoder.close_and_swap().map_pass_err(pass_scope)?;

            Ok(())
        })
    }
790}
791
/// Replays a `SetPipeline` command: binds the compute pipeline on the HAL
/// encoder and rebinds any resources whose layout-compatible prefix changed.
///
/// Errors if the pipeline was created on a different device than the
/// encoder (`same_device_as`).
fn set_pipeline(
    state: &mut State,
    cmd_enc: &CommandEncoder,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device_as(cmd_enc)?;

    // Remember the pipeline so later commands (e.g. indirect-dispatch
    // validation) can restore it after temporarily switching pipelines.
    state.pipeline = Some(pipeline.clone());

    // Register the pipeline with the pass tracker to keep it alive for the
    // lifetime of the command buffer.
    let pipeline = state
        .general
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .general
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    // Rebind resources
    pass::rebind_resources::<ComputePassErrorInner, _>(
        &mut state.general,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            // This only needs to be here for compute pipelines because they use push constants for
            // validating indirect draws.
            state.push_constants.clear();
            // Note that can only be one range for each stage. See the `MoreThanOnePushConstantRangePerStage` error.
            if let Some(push_constant_range) =
                pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                    pcr.stages
                        .contains(wgt::ShaderStages::COMPUTE)
                        .then_some(pcr.range.clone())
                })
            {
                // Note that non-0 range start doesn't work anyway https://github.com/gfx-rs/wgpu/issues/4502
                let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
                // Zero-fill the shadow copy; actual values are written by
                // `SetPushConstant` commands.
                state.push_constants.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}
839
840fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
841    state.is_ready()?;
842
843    state.flush_states(None)?;
844
845    let groups_size_limit = state
846        .general
847        .device
848        .limits
849        .max_compute_workgroups_per_dimension;
850
851    if groups[0] > groups_size_limit
852        || groups[1] > groups_size_limit
853        || groups[2] > groups_size_limit
854    {
855        return Err(ComputePassErrorInner::Dispatch(
856            DispatchError::InvalidGroupSize {
857                current: groups,
858                limit: groups_size_limit,
859            },
860        ));
861    }
862
863    unsafe {
864        state.general.raw_encoder.dispatch(groups);
865    }
866    Ok(())
867}
868
/// Replays a `DispatchIndirect` command.
///
/// Validates the indirect buffer (device, usage, destruction, offset
/// alignment, bounds), records a memory-initialization requirement for the
/// argument range, and then encodes the dispatch. When the device has
/// indirect validation enabled, the arguments are first copied/clamped into
/// a device-owned destination buffer by a small validation compute dispatch,
/// and the indirect dispatch reads from that buffer instead of the user's.
fn dispatch_indirect(
    state: &mut State,
    cmd_enc: &CommandEncoder,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    buffer.same_device_as(cmd_enc)?;

    state.is_ready()?;

    // Indirect execution is a downlevel-optional capability.
    state
        .general
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
    buffer.check_destroyed(state.general.snatch_guard)?;

    // WebGPU requires the indirect offset to be a multiple of 4.
    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    // The three u32 dispatch arguments must fit inside the buffer.
    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    let stride = 3 * 4; // 3 integers, x/y/z group size
    // The argument bytes must be initialized before the GPU reads them.
    state.general.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.general.device.indirect_validation {
        let params =
            indirect_validation
                .dispatch
                .params(&state.general.device.limits, offset, buffer.size);

        // Bind the validation pipeline that sanitizes the user-supplied
        // dispatch arguments into `params.dst_buffer`.
        unsafe {
            state
                .general
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.general.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        // Group 0: destination (validated args); group 1: the user's
        // indirect buffer, bound with a dynamic offset.
        unsafe {
            state.general.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.general.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.general.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        // Transition the source buffer for the validation shader's read.
        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.general.snatch_guard));
        unsafe {
            state
                .general
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        // Make the destination buffer writable for the validation dispatch.
        unsafe {
            state
                .general
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        // Run the single-workgroup validation shader.
        unsafe {
            state.general.raw_encoder.dispatch([1, 1, 1]);
        }

        // reset state
        {
            // Restore the user's pipeline, push constants, and bind groups
            // that the validation dispatch clobbered. `is_ready()` above
            // guarantees a pipeline is set, so the unwrap is safe.
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .general
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.general.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.general.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.general.snatch_guard)?;
                unsafe {
                    state.general.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        // Return the destination buffer to INDIRECT usage for the real
        // dispatch below.
        unsafe {
            state
                .general
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_states(None)?;
        // Dispatch using the validated copy of the arguments at offset 0.
        unsafe {
            state
                .general
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        // No validation layer: track the user's buffer for INDIRECT usage
        // and dispatch straight from it.
        state
            .general
            .scope
            .buffers
            .merge_single(&buffer, wgt::BufferUses::INDIRECT)?;

        use crate::resource::Trackable;
        state.flush_states(Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(state.general.snatch_guard)?;
        unsafe {
            state.general.raw_encoder.dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}
1059
1060// Recording a compute pass.
1061//
1062// The only error that should be returned from these methods is
1063// `EncoderStateError::Ended`, when the pass has already ended and an immediate
1064// validation error is raised.
1065//
1066// All other errors should be stored in the pass for later reporting when
1067// `CommandEncoder.finish()` is called.
1068//
1069// The `pass_try!` macro should be used to handle errors appropriately. Note
1070// that the `pass_try!` and `pass_base!` macros may return early from the
1071// function that invokes them, like the `?` operator.
impl Global {
    /// Records a `set_bind_group` call into the pass, skipping redundant
    /// rebinds. Resolves the bind-group id to an `Arc` now; validation is
    /// deferred to pass end.
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        // This statement will return an error if the pass is ended. It's
        // important the error check comes before the early-out for
        // `set_and_check_redundant`.
        let base = pass_base!(pass, scope);

        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            return Ok(());
        }

        let mut bind_group = None;
        if bind_group_id.is_some() {
            let bind_group_id = bind_group_id.unwrap();

            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    /// Records a `set_pipeline` call into the pass, skipping redundant
    /// re-sets of the same pipeline.
    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        // This statement will return an error if the pass is ended.
        // Its important the error check comes before the early-out for `redundant`.
        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    /// Records a `set_push_constants` call. Validates offset/size alignment
    /// immediately and stages the data words in `push_constant_data`.
    pub fn compute_pass_set_push_constants(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetPushConstant;
        let base = pass_base!(pass, scope);

        // Offset must be a multiple of PUSH_CONSTANT_ALIGNMENT.
        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantOffsetAlignment),
            );
        }

        // Data length must also be a multiple of PUSH_CONSTANT_ALIGNMENT.
        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantSizeAlignment),
            )
        }
        // Index into `push_constant_data` where this command's words start.
        let value_offset = pass_try!(
            base,
            scope,
            base.push_constant_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
        );

        // Reinterpret the byte slice as native-endian u32 words.
        base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    /// Records a direct `dispatch_workgroups` call.
    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    /// Records an indirect dispatch. The buffer id is resolved now; usage,
    /// alignment, and bounds checks happen at pass end.
    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    /// Records a `push_debug_group` call. The label bytes are stashed in
    /// `string_data` and referenced by length from the command.
    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    /// Records a `pop_debug_group` call; balance against pushes is checked
    /// at pass end.
    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    /// Records an `insert_debug_marker` call, staging the label bytes like
    /// `compute_pass_push_debug_group` does.
    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    /// Records a `write_timestamp` call, resolving the query set id now.
    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    /// Records the start of a pipeline-statistics query; validation of
    /// nesting/indices happens when the pass is replayed.
    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    /// Records the end of the currently active pipeline-statistics query.
    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}