// wgpu_core/command/compute.rs

1use thiserror::Error;
2use wgt::{
3    error::{ErrorType, WebGpuError},
4    BufferAddress, DynamicOffset,
5};
6
7use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
8use core::{fmt, str};
9
10use crate::command::{
11    encoder::EncodingState, pass, CommandBufferMutable, CommandEncoder, DebugGroupError,
12    EncoderStateError, PassStateError, TimestampWritesError,
13};
14use crate::resource::DestroyedResourceError;
15use crate::{binding_model::BindError, resource::RawResourceAccess};
16use crate::{
17    binding_model::{LateMinBufferBindingSizeMismatch, PushConstantUploadError},
18    command::{
19        bind::{Binder, BinderError},
20        compute_command::ArcComputeCommand,
21        end_pipeline_statistics_query,
22        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
23        pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
24        BasePass, BindGroupStateChange, CommandEncoderError, MapPassErr, PassErrorScope,
25        PassTimestampWrites, QueryUseError, StateChange,
26    },
27    device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
28    global::Global,
29    hal_label, id,
30    init_tracker::MemoryInitKind,
31    pipeline::ComputePipeline,
32    resource::{
33        self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
34    },
35    track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex},
36    Label,
37};
38
/// Base pass data for compute passes: recorded commands, dynamic offsets,
/// string/push-constant data, and any deferred [`ComputePassError`].
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;
40
/// A pass's [encoder state](https://www.w3.org/TR/webgpu/#encoder-state) and
/// its validity are two distinct conditions, i.e., the full matrix of
/// (open, ended) x (valid, invalid) is possible.
///
/// The presence or absence of the `parent` `Option` indicates the pass's state.
/// The presence or absence of an error in `base.error` indicates the pass's
/// validity.
pub struct ComputePass {
    /// All pass data & records is stored here.
    base: ComputeBasePass,

    /// Parent command encoder that this pass records commands into.
    ///
    /// If this is `Some`, then the pass is in WebGPU's "open" state. If it is
    /// `None`, then the pass is in the "ended" state.
    /// See <https://www.w3.org/TR/webgpu/#encoder-state>
    parent: Option<Arc<CommandEncoder>>,

    /// Timestamp writes requested for this pass; consumed (`take`n) when the
    /// pass is encoded.
    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Resource binding dedupe state.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}
65
66impl ComputePass {
67    /// If the parent command encoder is invalid, the returned pass will be invalid.
68    fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
69        let ArcComputePassDescriptor {
70            label,
71            timestamp_writes,
72        } = desc;
73
74        Self {
75            base: BasePass::new(&label),
76            parent: Some(parent),
77            timestamp_writes,
78
79            current_bind_groups: BindGroupStateChange::new(),
80            current_pipeline: StateChange::new(),
81        }
82    }
83
84    fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
85        Self {
86            base: BasePass::new_invalid(label, err),
87            parent: Some(parent),
88            timestamp_writes: None,
89            current_bind_groups: BindGroupStateChange::new(),
90            current_pipeline: StateChange::new(),
91        }
92    }
93
94    #[inline]
95    pub fn label(&self) -> Option<&str> {
96        self.base.label.as_deref()
97    }
98}
99
100impl fmt::Debug for ComputePass {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        match self.parent {
103            Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
104            None => write!(f, "ComputePass {{ parent: None }}"),
105        }
106    }
107}
108
/// Descriptor for a compute pass.
///
/// `PTW` is the timestamp-writes representation: id-based
/// [`PassTimestampWrites`] (the default) or the `Arc`-resolved form used
/// internally.
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    /// Debug label of the compute pass.
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}
115
/// [`ComputePassDescriptor`] with timestamp writes resolved to `Arc`ed resources.
/// cbindgen:ignore
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;
118
/// Errors that can occur when validating a dispatch (direct or indirect).
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    /// The currently bound bind groups are incompatible with the pipeline's layout.
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}
133
/// All dispatch errors map to WebGPU validation errors.
impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
139
/// Error encountered when performing a compute pass.
///
/// Wrapped with a [`PassErrorScope`] into [`ComputePassError`] before being
/// stored or reported.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4gb of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    // This one is unreachable, but required for generic pass support
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}
195
/// Error encountered when performing a compute pass, stored for later reporting
/// when encoding ends.
///
/// Pairs the underlying error with the [`PassErrorScope`] (which command was
/// being recorded) in which it occurred.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
205
/// Routes a missing-pipeline error into the dispatch error category.
impl From<pass::MissingPipeline> for ComputePassErrorInner {
    fn from(value: pass::MissingPipeline) -> Self {
        Self::Dispatch(DispatchError::MissingPipeline(value))
    }
}
211
/// Lets any error convertible to [`ComputePassErrorInner`] be tagged with the
/// [`PassErrorScope`] in which it occurred, via `.map_pass_err(scope)`.
impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}
223
impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        // Variants wrapping another error delegate to that error's type; the
        // remaining variants are all plain validation errors. The match is
        // exhaustive so adding a variant forces a decision here.
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::PushConstants(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::PushConstantOffsetAlignment
            | ComputePassErrorInner::PushConstantSizeAlignment
            | ComputePassErrorInner::PushConstantOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}
256
/// Mutable state tracked while replaying the recorded commands of a compute
/// pass in [`encode_compute_pass`].
struct State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> {
    /// The currently-set compute pipeline, if any.
    pipeline: Option<Arc<ComputePipeline>>,

    /// Pass state shared with other pass kinds (encoder, binder, usage
    /// scope, pending memory-init fixups, ...).
    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>,

    /// The pipeline-statistics query that has been begun but not yet ended,
    /// with its query index.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    /// Shadow copy of the current push constant data; cleared and re-sized on
    /// pipeline changes (see `set_pipeline`).
    push_constants: Vec<u32>,

    /// Trackers into which bind-group (and indirect-buffer) usages are moved
    /// between dispatches, in order to generate barriers.
    intermediate_trackers: Tracker,
}
268
impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
    State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
{
    /// Checks that a dispatch may proceed: a pipeline must be set, the bound
    /// bind groups must be compatible with its layout, and late-bound buffer
    /// bindings must satisfy their minimum sizes.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Merges the usages of all active bind groups into the pass's usage
    /// scope, moves them (plus, optionally, the indirect buffer identified by
    /// `indirect_buffer`) into `intermediate_trackers`, and emits the
    /// resulting barriers on the raw encoder.
    ///
    /// `indirect_buffer` represents the indirect buffer that is also part of
    /// the usage scope.
    fn flush_states(
        &mut self,
        indirect_buffer: Option<TrackerIndex>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        for bind_group in self.pass.binder.list_active() {
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
            // Note: stateless trackers are not merged: the lifetime reference
            // is held to the bind group itself.
        }

        for bind_group in self.pass.binder.list_active() {
            unsafe {
                self.intermediate_trackers
                    .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used)
            }
        }

        // Add the state of the indirect buffer if it hasn't been hit before.
        unsafe {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer,
                );
        }

        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}
319
320// Running the compute pass.
321
impl Global {
    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned. Attempting to record
    /// commands into an invalid pass is permitted, but a validation error will
    /// ultimately be generated when the parent encoder is finished, and it is
    /// not possible to run any commands from the invalid pass.
    ///
    /// If successful, puts the encoder into the [`Locked`] state.
    ///
    /// [`Locked`]: crate::command::CommandEncoderStatus::Locked
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                // A pass on an invalid device is created invalid, but this is
                // not an immediate error (the error surfaces on finish).
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                // Resolve timestamp-write query-set ids to Arcs, validating
                // them against the device.
                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Attempting to open a new pass while the encoder is locked
                // invalidates the encoder, but does not generate a validation
                // error.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // Attempting to open a new pass after the encode has ended
                // generates an immediate validation error.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                // Passes can be opened even on an invalid encoder. Such passes
                // are even valid, but since there's no visible side-effect of
                // the pass being valid and there's no point in storing recorded
                // commands that will ultimately be discarded, we open an
                // invalid pass to save that work.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

    /// Note that this differs from [`Self::compute_pass_end`], it will
    /// create a new pass, replay the commands and end the pass.
    ///
    /// # Panics
    /// On any error.
    #[doc(hidden)]
    #[cfg(any(feature = "serde", feature = "replay"))]
    pub fn compute_pass_end_with_unresolved_commands(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePass<super::ComputeCommand, core::convert::Infallible>,
        timestamp_writes: Option<&PassTimestampWrites>,
    ) {
        // If tracing is enabled, record the whole pass into the trace before
        // replaying it.
        #[cfg(feature = "trace")]
        {
            let cmd_enc = self.hub.command_encoders.get(encoder_id);
            let mut cmd_buf_data = cmd_enc.data.lock();
            let cmd_buf_data = cmd_buf_data.get_inner();

            if let Some(ref mut list) = cmd_buf_data.trace_commands {
                list.push(crate::command::Command::RunComputePass {
                    base: BasePass {
                        label: base.label.clone(),
                        error: None,
                        commands: base.commands.clone(),
                        dynamic_offsets: base.dynamic_offsets.clone(),
                        string_data: base.string_data.clone(),
                        push_constant_data: base.push_constant_data.clone(),
                    },
                    timestamp_writes: timestamp_writes.cloned(),
                });
            }
        }

        let BasePass {
            label,
            error: _,
            commands,
            dynamic_offsets,
            string_data,
            push_constant_data,
        } = base;

        let (mut compute_pass, encoder_error) = self.command_encoder_begin_compute_pass(
            encoder_id,
            &ComputePassDescriptor {
                label: label.as_deref().map(Cow::Borrowed),
                timestamp_writes: timestamp_writes.cloned(),
            },
        );
        if let Some(err) = encoder_error {
            panic!("{:?}", err);
        };

        // Replace the (empty) pass data with the deserialized commands,
        // resolving ids to Arcs.
        compute_pass.base = BasePass {
            label,
            error: None,
            commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)
                .unwrap(),
            dynamic_offsets,
            string_data,
            push_constant_data,
        };

        self.compute_pass_end(&mut compute_pass).unwrap();
    }

    /// Ends the pass and, if it is valid, encodes its recorded commands into
    /// the parent encoder, unlocking it.
    ///
    /// An error stored in the pass invalidates the parent encoder rather than
    /// being returned here; calling this on an already-ended pass returns
    /// [`EncoderStateError::Ended`].
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        // Taking `parent` moves the pass to the "ended" state.
        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        if let Some(err) = pass.base.error.take() {
            if matches!(
                err,
                ComputePassError {
                    inner: ComputePassErrorInner::EncoderState(EncoderStateError::Ended),
                    scope: _,
                }
            ) {
                // If the encoder was already finished at time of pass creation,
                // then it was not put in the locked state, so we need to
                // generate a validation error here due to the encoder not being
                // locked. The encoder already has a copy of the error.
                return Err(EncoderStateError::Ended);
            } else {
                // If the pass is invalid, invalidate the parent encoder and return.
                // Since we do not track the state of an invalid encoder, it is not
                // necessary to unlock it.
                cmd_buf_data.invalidate(err);
                return Ok(());
            }
        }

        cmd_buf_data.unlock_and_record(|cmd_buf_data| -> Result<(), ComputePassError> {
            encode_compute_pass(cmd_buf_data, &cmd_enc, pass)
        })
    }
}
524
/// Replays the commands recorded in `pass` into the parent encoder's HAL
/// command stream.
///
/// Opens a HAL encoder for the pass body, executes each [`ArcComputeCommand`],
/// then inserts a "pre pass" encoder *before* the body that performs
/// memory-init fixups and the barriers accumulated by the pass.
fn encode_compute_pass(
    cmd_buf_data: &mut CommandBufferMutable,
    cmd_enc: &Arc<CommandEncoder>,
    pass: &mut ComputePass,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = &cmd_enc.device;
    device.check_is_valid().map_pass_err(pass_scope)?;

    let base = &mut pass.base;

    let encoder = &mut cmd_buf_data.encoder;

    // We automatically keep extending command buffers over time, and because
    // we want to insert a command buffer _before_ what we're about to record,
    // we need to make sure to close the previous one.
    encoder.close_if_open().map_pass_err(pass_scope)?;
    let raw_encoder = encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let snatch_guard = device.snatchable_lock.read();
    // Tracks push/pop debug group nesting; must be zero at end of pass.
    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: &mut cmd_buf_data.trackers,
                buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
                texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
                as_actions: &mut cmd_buf_data.as_actions,
                indirect_draw_validation_resources: &mut cmd_buf_data
                    .indirect_draw_validation_resources,
                snatch_guard: &snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,

            pending_discard_init_fixups: SurfacesInDiscardState::new(),

            scope: device.new_usage_scope(),

            string_offset: 0,
        },
        active_query: None,

        push_constants: Vec::new(),

        intermediate_trackers: Tracker::new(),
    };

    // Pre-size the buffer/texture trackers to the device's index capacity.
    let indices = &state.pass.base.device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = pass.timestamp_writes.take() {
            tw.query_set
                .same_device_as(cmd_enc.as_ref())
                .map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            // Unlike in render passes we can't delay resetting the query sets since
            // there is no auxiliary pass.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            // Range should always be Some, both values being None should lead to a validation error.
            // But no point in erroring over that nuance here!
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    // Replay each recorded command, tagging any error with the scope of the
    // command that caused it.
    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    cmd_enc.as_ref(),
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, cmd_enc.as_ref(), pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetPushConstant;
                pass::set_push_constant::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.push_constant_data,
                    wgt::ShaderStages::COMPUTE,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    |data_slice| {
                        // Mirror the data into the CPU-side shadow copy.
                        let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        state.push_constants[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, cmd_enc.as_ref(), buffer, offset)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    cmd_enc.as_ref(),
                    None,
                    query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    cmd_enc.as_ref(),
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    // Every pushed debug group must have been popped.
    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass:
            pass::PassState {
                base: EncodingState { tracker, .. },
                pending_discard_init_fixups,
                ..
            },
        intermediate_trackers,
        ..
    } = state;

    // Stop the current command encoder.
    encoder.close().map_pass_err(pass_scope)?;

    // Create a new command encoder, which we will insert _before_ the body of the compute pass.
    //
    // Use that buffer to insert barriers and clear discarded images.
    let transit = encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut tracker.textures,
        device,
        &snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        tracker,
        &intermediate_trackers,
        &snatch_guard,
    );
    // Close the command encoder, and swap it with the previous.
    encoder.close_and_swap().map_pass_err(pass_scope)?;

    Ok(())
}
806
/// Handles [`ArcComputeCommand::SetPipeline`]: validates the pipeline's
/// device, records it in the tracker, binds it on the HAL encoder, and
/// rebinds resources (including the push-constant shadow copy) for the new
/// pipeline layout.
fn set_pipeline(
    state: &mut State,
    cmd_enc: &CommandEncoder,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device_as(cmd_enc)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    // Rebind resources
    pass::rebind_resources::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            // This only needs to be here for compute pipelines because they use push constants for
            // validating indirect draws.
            state.push_constants.clear();
            // Note that can only be one range for each stage. See the `MoreThanOnePushConstantRangePerStage` error.
            if let Some(push_constant_range) =
                pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                    pcr.stages
                        .contains(wgt::ShaderStages::COMPUTE)
                        .then_some(pcr.range.clone())
                })
            {
                // Note that non-0 range start doesn't work anyway https://github.com/gfx-rs/wgpu/issues/4502
                let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
                state.push_constants.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}
856
857fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
858    state.is_ready()?;
859
860    state.flush_states(None)?;
861
862    let groups_size_limit = state
863        .pass
864        .base
865        .device
866        .limits
867        .max_compute_workgroups_per_dimension;
868
869    if groups[0] > groups_size_limit
870        || groups[1] > groups_size_limit
871        || groups[2] > groups_size_limit
872    {
873        return Err(ComputePassErrorInner::Dispatch(
874            DispatchError::InvalidGroupSize {
875                current: groups,
876                limit: groups_size_limit,
877            },
878        ));
879    }
880
881    unsafe {
882        state.pass.base.raw_encoder.dispatch(groups);
883    }
884    Ok(())
885}
886
/// Records an indirect dispatch that reads its workgroup counts from
/// `buffer` at `offset`.
///
/// Validates device ownership, INDIRECT usage, buffer liveness, 4-byte
/// offset alignment, and that the three `u32` arguments fit inside the
/// buffer. If the device has its indirect-validation layer enabled, the
/// arguments are first processed by an internal validation dispatch and the
/// real indirect dispatch reads from the validation output buffer instead of
/// `buffer` directly.
fn dispatch_indirect(
    state: &mut State,
    cmd_enc: &CommandEncoder,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    buffer.same_device_as(cmd_enc)?;

    state.is_ready()?;

    // Indirect execution is a downlevel capability not all backends have.
    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    // WebGPU requires the indirect offset to be a multiple of 4.
    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    // The dispatch arguments must lie entirely within the buffer.
    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    let stride = 3 * 4; // 3 integers, x/y/z group size
    // The argument region must be initialized (zeroed) before it is read.
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        // Validation path: run an internal compute dispatch that reads the
        // user's arguments from `buffer` and writes arguments into
        // `params.dst_buffer`, then dispatch indirectly from that buffer.
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        // The validation shader receives the offset remainder, in u32 words,
        // via push constants.
        unsafe {
            state.pass.base.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        // Group 0: the destination argument buffer.
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        // Group 1: the user's source buffer, bound with a dynamic offset at
        // the aligned offset computed by `params`.
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.pass.base.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        // Transition the source buffer so the validation shader can read it.
        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        // Make the destination buffer writable for the validation shader.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        // Run the validation dispatch itself (a single workgroup).
        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

        // reset state
        // The validation dispatch clobbered the pipeline, push constants and
        // bind groups; restore the user's recorded state.
        {
            // `is_ready()` above ensured a pipeline is bound, so this
            // `unwrap` cannot fire.
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            // Re-bind every valid bind group with its dynamic offsets.
            for (i, e) in state.pass.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        // Transition the destination buffer back so it can be consumed as an
        // indirect argument buffer.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_states(None)?;
        // Dispatch from the validated arguments at offset 0.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        // Direct path: no validation layer; record INDIRECT usage for the
        // user's buffer and dispatch straight from it.
        state
            .pass
            .scope
            .buffers
            .merge_single(&buffer, wgt::BufferUses::INDIRECT)?;

        use crate::resource::Trackable;
        state.flush_states(Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}
1089
1090// Recording a compute pass.
1091//
1092// The only error that should be returned from these methods is
1093// `EncoderStateError::Ended`, when the pass has already ended and an immediate
1094// validation error is raised.
1095//
1096// All other errors should be stored in the pass for later reporting when
1097// `CommandEncoder.finish()` is called.
1098//
1099// The `pass_try!` macro should be used to handle errors appropriately. Note
1100// that the `pass_try!` and `pass_base!` macros may return early from the
1101// function that invokes them, like the `?` operator.
1102impl Global {
1103    pub fn compute_pass_set_bind_group(
1104        &self,
1105        pass: &mut ComputePass,
1106        index: u32,
1107        bind_group_id: Option<id::BindGroupId>,
1108        offsets: &[DynamicOffset],
1109    ) -> Result<(), PassStateError> {
1110        let scope = PassErrorScope::SetBindGroup;
1111
1112        // This statement will return an error if the pass is ended. It's
1113        // important the error check comes before the early-out for
1114        // `set_and_check_redundant`.
1115        let base = pass_base!(pass, scope);
1116
1117        if pass.current_bind_groups.set_and_check_redundant(
1118            bind_group_id,
1119            index,
1120            &mut base.dynamic_offsets,
1121            offsets,
1122        ) {
1123            return Ok(());
1124        }
1125
1126        let mut bind_group = None;
1127        if bind_group_id.is_some() {
1128            let bind_group_id = bind_group_id.unwrap();
1129
1130            let hub = &self.hub;
1131            bind_group = Some(pass_try!(
1132                base,
1133                scope,
1134                hub.bind_groups.get(bind_group_id).get(),
1135            ));
1136        }
1137
1138        base.commands.push(ArcComputeCommand::SetBindGroup {
1139            index,
1140            num_dynamic_offsets: offsets.len(),
1141            bind_group,
1142        });
1143
1144        Ok(())
1145    }
1146
1147    pub fn compute_pass_set_pipeline(
1148        &self,
1149        pass: &mut ComputePass,
1150        pipeline_id: id::ComputePipelineId,
1151    ) -> Result<(), PassStateError> {
1152        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);
1153
1154        let scope = PassErrorScope::SetPipelineCompute;
1155
1156        // This statement will return an error if the pass is ended.
1157        // Its important the error check comes before the early-out for `redundant`.
1158        let base = pass_base!(pass, scope);
1159
1160        if redundant {
1161            return Ok(());
1162        }
1163
1164        let hub = &self.hub;
1165        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());
1166
1167        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));
1168
1169        Ok(())
1170    }
1171
1172    pub fn compute_pass_set_push_constants(
1173        &self,
1174        pass: &mut ComputePass,
1175        offset: u32,
1176        data: &[u8],
1177    ) -> Result<(), PassStateError> {
1178        let scope = PassErrorScope::SetPushConstant;
1179        let base = pass_base!(pass, scope);
1180
1181        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
1182            pass_try!(
1183                base,
1184                scope,
1185                Err(ComputePassErrorInner::PushConstantOffsetAlignment),
1186            );
1187        }
1188
1189        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
1190            pass_try!(
1191                base,
1192                scope,
1193                Err(ComputePassErrorInner::PushConstantSizeAlignment),
1194            )
1195        }
1196        let value_offset = pass_try!(
1197            base,
1198            scope,
1199            base.push_constant_data
1200                .len()
1201                .try_into()
1202                .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
1203        );
1204
1205        base.push_constant_data.extend(
1206            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
1207                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
1208        );
1209
1210        base.commands.push(ArcComputeCommand::SetPushConstant {
1211            offset,
1212            size_bytes: data.len() as u32,
1213            values_offset: value_offset,
1214        });
1215
1216        Ok(())
1217    }
1218
1219    pub fn compute_pass_dispatch_workgroups(
1220        &self,
1221        pass: &mut ComputePass,
1222        groups_x: u32,
1223        groups_y: u32,
1224        groups_z: u32,
1225    ) -> Result<(), PassStateError> {
1226        let scope = PassErrorScope::Dispatch { indirect: false };
1227
1228        pass_base!(pass, scope)
1229            .commands
1230            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
1231
1232        Ok(())
1233    }
1234
1235    pub fn compute_pass_dispatch_workgroups_indirect(
1236        &self,
1237        pass: &mut ComputePass,
1238        buffer_id: id::BufferId,
1239        offset: BufferAddress,
1240    ) -> Result<(), PassStateError> {
1241        let hub = &self.hub;
1242        let scope = PassErrorScope::Dispatch { indirect: true };
1243        let base = pass_base!(pass, scope);
1244
1245        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());
1246
1247        base.commands
1248            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });
1249
1250        Ok(())
1251    }
1252
1253    pub fn compute_pass_push_debug_group(
1254        &self,
1255        pass: &mut ComputePass,
1256        label: &str,
1257        color: u32,
1258    ) -> Result<(), PassStateError> {
1259        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);
1260
1261        let bytes = label.as_bytes();
1262        base.string_data.extend_from_slice(bytes);
1263
1264        base.commands.push(ArcComputeCommand::PushDebugGroup {
1265            color,
1266            len: bytes.len(),
1267        });
1268
1269        Ok(())
1270    }
1271
1272    pub fn compute_pass_pop_debug_group(
1273        &self,
1274        pass: &mut ComputePass,
1275    ) -> Result<(), PassStateError> {
1276        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);
1277
1278        base.commands.push(ArcComputeCommand::PopDebugGroup);
1279
1280        Ok(())
1281    }
1282
1283    pub fn compute_pass_insert_debug_marker(
1284        &self,
1285        pass: &mut ComputePass,
1286        label: &str,
1287        color: u32,
1288    ) -> Result<(), PassStateError> {
1289        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);
1290
1291        let bytes = label.as_bytes();
1292        base.string_data.extend_from_slice(bytes);
1293
1294        base.commands.push(ArcComputeCommand::InsertDebugMarker {
1295            color,
1296            len: bytes.len(),
1297        });
1298
1299        Ok(())
1300    }
1301
1302    pub fn compute_pass_write_timestamp(
1303        &self,
1304        pass: &mut ComputePass,
1305        query_set_id: id::QuerySetId,
1306        query_index: u32,
1307    ) -> Result<(), PassStateError> {
1308        let scope = PassErrorScope::WriteTimestamp;
1309        let base = pass_base!(pass, scope);
1310
1311        let hub = &self.hub;
1312        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
1313
1314        base.commands.push(ArcComputeCommand::WriteTimestamp {
1315            query_set,
1316            query_index,
1317        });
1318
1319        Ok(())
1320    }
1321
1322    pub fn compute_pass_begin_pipeline_statistics_query(
1323        &self,
1324        pass: &mut ComputePass,
1325        query_set_id: id::QuerySetId,
1326        query_index: u32,
1327    ) -> Result<(), PassStateError> {
1328        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
1329        let base = pass_base!(pass, scope);
1330
1331        let hub = &self.hub;
1332        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
1333
1334        base.commands
1335            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
1336                query_set,
1337                query_index,
1338            });
1339
1340        Ok(())
1341    }
1342
1343    pub fn compute_pass_end_pipeline_statistics_query(
1344        &self,
1345        pass: &mut ComputePass,
1346    ) -> Result<(), PassStateError> {
1347        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
1348            .commands
1349            .push(ArcComputeCommand::EndPipelineStatisticsQuery);
1350
1351        Ok(())
1352    }
1353}