1use thiserror::Error;
2use wgt::{
3 error::{ErrorType, WebGpuError},
4 BufferAddress, DynamicOffset,
5};
6
7use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
8use core::{convert::Infallible, fmt, str};
9
10use crate::{
11 api_log,
12 binding_model::{BindError, ImmediateUploadError, LateMinBufferBindingSizeMismatch},
13 command::{
14 bind::{Binder, BinderError},
15 compute_command::ArcComputeCommand,
16 encoder::EncodingState,
17 memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
18 pass::{self, flush_bindings_helper},
19 pass_base, pass_try,
20 query::{end_pipeline_statistics_query, validate_and_begin_pipeline_statistics_query},
21 ArcCommand, ArcPassTimestampWrites, BasePass, BindGroupStateChange, CommandEncoder,
22 CommandEncoderError, DebugGroupError, EncoderStateError, InnerCommandEncoder, MapPassErr,
23 PassErrorScope, PassStateError, PassTimestampWrites, QueryUseError, StateChange,
24 TimestampWritesError, TransitionResourcesError,
25 },
26 device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
27 global::Global,
28 hal_label, id,
29 init_tracker::MemoryInitKind,
30 pipeline::ComputePipeline,
31 resource::{
32 self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
33 MissingBufferUsageError, ParentDevice, RawResourceAccess, TextureView, Trackable,
34 },
35 track::{ResourceUsageCompatibilityError, TextureViewBindGroupState, Tracker},
36 Label,
37};
38
/// Base pass storage for compute passes: the recorded [`ArcComputeCommand`]
/// stream plus any [`ComputePassError`] captured during recording.
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;
40
pub struct ComputePass {
    /// The recorded command stream (or the first recording error).
    base: ComputeBasePass,

    /// The encoder this pass records into. `Some` while the pass is open;
    /// taken (set to `None`) when the pass is ended.
    parent: Option<Arc<CommandEncoder>>,

    /// Timestamp writes requested in the pass descriptor, resolved to Arcs.
    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Redundancy filters: used to skip recording state-setting commands that
    // would not change the current bind groups / pipeline.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}
65
66impl ComputePass {
67 fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
69 let ArcComputePassDescriptor {
70 label,
71 timestamp_writes,
72 } = desc;
73
74 Self {
75 base: BasePass::new(&label),
76 parent: Some(parent),
77 timestamp_writes,
78
79 current_bind_groups: BindGroupStateChange::new(),
80 current_pipeline: StateChange::new(),
81 }
82 }
83
84 fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
85 Self {
86 base: BasePass::new_invalid(label, err),
87 parent: Some(parent),
88 timestamp_writes: None,
89 current_bind_groups: BindGroupStateChange::new(),
90 current_pipeline: StateChange::new(),
91 }
92 }
93
94 #[inline]
95 pub fn label(&self) -> Option<&str> {
96 self.base.label.as_deref()
97 }
98}
99
100impl fmt::Debug for ComputePass {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 match self.parent {
103 Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
104 None => write!(f, "ComputePass {{ parent: None }}"),
105 }
106 }
107}
108
/// Descriptor for a compute pass.
///
/// Generic over how the timestamp-writes query set is referenced: `PTW` is
/// [`PassTimestampWrites`] (by id) at the public API boundary and
/// [`ArcPassTimestampWrites`] (resolved Arc) internally.
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}
115
/// Internal descriptor form with the timestamp query set already resolved.
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;
118
/// Validation errors raised at dispatch time
/// (`dispatch_workgroups` / `dispatch_workgroups_indirect`).
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    /// No compute pipeline was bound before dispatching.
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    /// The bound bind groups do not match the pipeline's layout.
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    /// A late-sized buffer binding is smaller than the pipeline requires.
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
    /// Immediate-data byte ranges the pipeline reads were never written.
    #[error("Not all immediate data required by the pipeline has been set via set_immediates (missing byte ranges: {missing})")]
    MissingImmediateData {
        missing: naga::valid::ImmediateSlots,
    },
}
137
138impl WebGpuError for DispatchError {
139 fn webgpu_error_type(&self) -> ErrorType {
140 ErrorType::Validation
141 }
142}
143
/// The specific reason a compute-pass command failed.
///
/// Paired with the offending command's scope in [`ComputePassError`].
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer of {args_size} bytes starting at offset {offset} would overrun buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        args_size: u64,
        offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    ImmediateData(#[from] ImmediateUploadError),
    #[error("Immediate data offset must be aligned to 4 bytes")]
    ImmediateOffsetAlignment,
    // NOTE(review): variant name has a typo ("Dataize" should be "DataSize"),
    // but renaming it would break code matching on this public enum.
    #[error("Immediate data size must be aligned to 4 bytes")]
    ImmediateDataizeAlignment,
    #[error("Ran out of immediate data space. Don't set 4gb of immediates per ComputePass.")]
    ImmediateOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    TransitionResources(#[from] TransitionResourcesError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}
201
/// Error encountered while recording or encoding a compute pass.
///
/// The `Display` form shows the command scope; the cause is the `#[source]`.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    /// Which command the error occurred in.
    pub scope: PassErrorScope,
    /// Why the command failed.
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
211
212impl From<pass::MissingPipeline> for ComputePassErrorInner {
213 fn from(value: pass::MissingPipeline) -> Self {
214 Self::Dispatch(DispatchError::MissingPipeline(value))
215 }
216}
217
218impl<E> MapPassErr<ComputePassError> for E
219where
220 E: Into<ComputePassErrorInner>,
221{
222 fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
223 ComputePassError {
224 scope,
225 inner: self.into(),
226 }
227 }
228}
229
230impl WebGpuError for ComputePassError {
231 fn webgpu_error_type(&self) -> ErrorType {
232 let Self { scope: _, inner } = self;
233 match inner {
234 ComputePassErrorInner::Device(e) => e.webgpu_error_type(),
235 ComputePassErrorInner::EncoderState(e) => e.webgpu_error_type(),
236 ComputePassErrorInner::DebugGroupError(e) => e.webgpu_error_type(),
237 ComputePassErrorInner::DestroyedResource(e) => e.webgpu_error_type(),
238 ComputePassErrorInner::ResourceUsageCompatibility(e) => e.webgpu_error_type(),
239 ComputePassErrorInner::MissingBufferUsage(e) => e.webgpu_error_type(),
240 ComputePassErrorInner::Dispatch(e) => e.webgpu_error_type(),
241 ComputePassErrorInner::Bind(e) => e.webgpu_error_type(),
242 ComputePassErrorInner::ImmediateData(e) => e.webgpu_error_type(),
243 ComputePassErrorInner::QueryUse(e) => e.webgpu_error_type(),
244 ComputePassErrorInner::TransitionResources(e) => e.webgpu_error_type(),
245 ComputePassErrorInner::MissingFeatures(e) => e.webgpu_error_type(),
246 ComputePassErrorInner::MissingDownlevelFlags(e) => e.webgpu_error_type(),
247 ComputePassErrorInner::InvalidResource(e) => e.webgpu_error_type(),
248 ComputePassErrorInner::TimestampWrites(e) => e.webgpu_error_type(),
249 ComputePassErrorInner::InvalidValuesOffset(e) => e.webgpu_error_type(),
250
251 ComputePassErrorInner::InvalidParentEncoder
252 | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
253 | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
254 | ComputePassErrorInner::IndirectBufferOverrun { .. }
255 | ComputePassErrorInner::ImmediateOffsetAlignment
256 | ComputePassErrorInner::ImmediateDataizeAlignment
257 | ComputePassErrorInner::ImmediateOutOfMemory
258 | ComputePassErrorInner::PassEnded => ErrorType::Validation,
259 }
260 }
261}
262
/// Mutable state tracked while encoding a compute pass's command stream.
struct State<'scope, 'snatch_guard, 'cmd_enc> {
    /// The currently bound compute pipeline, if any.
    pipeline: Option<Arc<ComputePipeline>>,

    /// Pass state shared with other pass kinds: raw encoder, binder,
    /// usage scope, pending memory-init fixups, etc.
    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    /// The pipeline-statistics query currently begun and not yet ended.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    /// Shadow copy of the immediate data words set so far; replayed onto the
    /// raw encoder after the indirect-validation dispatch overwrites them.
    immediates: Vec<u32>,

    /// Which immediate byte ranges have been written via `SetImmediate`;
    /// checked against the pipeline's requirements in `is_ready`.
    immediate_slots_set: naga::valid::ImmediateSlots,

    /// Tracker local to this pass, used to compute barriers inside the pass;
    /// its accumulated state is merged into the parent encoder at pass end.
    intermediate_trackers: Tracker,
}
278
impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    /// Checks that all state required for a dispatch is in place: a pipeline
    /// is bound, the bind groups are compatible with its layout, late-sized
    /// buffer bindings are large enough, and every immediate-data range the
    /// pipeline requires has been written.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            if !self
                .immediate_slots_set
                .contains(pipeline.immediate_slots_required)
            {
                return Err(DispatchError::MissingImmediateData {
                    missing: pipeline
                        .immediate_slots_required
                        .difference(self.immediate_slots_set),
                });
            }
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Merges the active bind groups' (and the optional indirect buffer's)
    /// resource uses into the pass usage scope, moves them into the
    /// intermediate trackers to compute transitions, flushes deferred binder
    /// state to the raw encoder, and emits the resulting barriers.
    ///
    /// `track_indirect_buffer`: when `true` the indirect buffer's INDIRECT
    /// use is recorded in the intermediate trackers (native indirect path);
    /// when `false` it is only validated against the scope and then removed
    /// again (used by the indirect-validation path, whose actual dispatch
    /// reads a device-internal buffer).
    fn flush_bindings(
        &mut self,
        indirect_buffer: Option<&Arc<Buffer>>,
        track_indirect_buffer: bool,
    ) -> Result<(), ComputePassErrorInner> {
        // First merge all uses into the scope so incompatible combined usage
        // surfaces as a validation error.
        for bind_group in self.pass.binder.list_active() {
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
        }

        if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .merge_single(buffer, wgt::BufferUses::INDIRECT)?;
        }

        // Then move the merged uses out of the scope and into the
        // intermediate trackers, which records the needed state transitions.
        for bind_group in self.pass.binder.list_active() {
            self.intermediate_trackers
                .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used);
        }

        if track_indirect_buffer {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer.map(|buf| buf.tracker_index()),
                );
        } else if let Some(buffer) = indirect_buffer {
            // Validation-only: drop the INDIRECT use again without tracking.
            self.pass
                .scope
                .buffers
                .remove_usage(buffer, wgt::BufferUses::INDIRECT);
        }

        // Flush deferred bind group / dynamic offset state to the encoder.
        flush_bindings_helper(&mut self.pass)?;

        // Emit the transitions accumulated above.
        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}
391
/// Encodes explicit buffer/texture state transitions requested by the user.
///
/// Each resource is validated against the device, merged into the pass usage
/// scope (surfacing conflicts as errors), moved into the intermediate
/// trackers to compute the transition, and the resulting barriers are
/// drained into the raw encoder.
fn transition_resources(
    state: &mut State,
    buffer_transitions: Vec<wgt::BufferTransition<Arc<Buffer>>>,
    texture_transitions: Vec<wgt::TextureTransition<Arc<TextureView>>>,
) -> Result<(), TransitionResourcesError> {
    // Size the scope to the device's current tracker index ranges.
    let indices = &state.pass.base.device.tracker_indices;
    state.pass.scope.buffers.set_size(indices.buffers.size());
    state.pass.scope.textures.set_size(indices.textures.size());

    let mut buffer_ids = Vec::with_capacity(buffer_transitions.len());
    let mut textures = TextureViewBindGroupState::new();

    for buffer_transition in buffer_transitions {
        buffer_transition
            .buffer
            .same_device(state.pass.base.device)?;

        state
            .pass
            .scope
            .buffers
            .merge_single(&buffer_transition.buffer, buffer_transition.state)?;
        buffer_ids.push(buffer_transition.buffer.tracker_index());
    }

    state
        .intermediate_trackers
        .buffers
        .set_and_remove_from_usage_scope_sparse(&mut state.pass.scope.buffers, buffer_ids);

    for texture_transition in texture_transitions {
        texture_transition
            .texture
            .same_device(state.pass.base.device)?;

        // SAFETY: relies on merge_single's contract for texture usage
        // scopes; the selector comes straight from the user transition.
        unsafe {
            state.pass.scope.textures.merge_single(
                &texture_transition.texture.parent,
                texture_transition.selector,
                texture_transition.state,
            )
        }?;

        textures.insert_single(texture_transition.texture, texture_transition.state);
    }

    state
        .intermediate_trackers
        .textures
        .set_and_remove_from_usage_scope_sparse(&mut state.pass.scope.textures, &textures);

    // Emit the barriers computed by the moves above.
    CommandEncoder::drain_barriers(
        state.pass.base.raw_encoder,
        &mut state.intermediate_trackers,
        state.pass.base.snatch_guard,
    );
    Ok(())
}
455
456impl Global {
459 pub fn command_encoder_begin_compute_pass(
470 &self,
471 encoder_id: id::CommandEncoderId,
472 desc: &ComputePassDescriptor<'_>,
473 ) -> (ComputePass, Option<CommandEncoderError>) {
474 use EncoderStateError as SErr;
475
476 let scope = PassErrorScope::Pass;
477 let hub = &self.hub;
478
479 let label = desc.label.as_deref().map(Cow::Borrowed);
480
481 let cmd_enc = hub.command_encoders.get(encoder_id);
482 let mut cmd_buf_data = cmd_enc.data.lock();
483
484 match cmd_buf_data.lock_encoder() {
485 Ok(()) => {
486 drop(cmd_buf_data);
487 if let Err(err) = cmd_enc.device.check_is_valid() {
488 return (
489 ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
490 None,
491 );
492 }
493
494 match desc
495 .timestamp_writes
496 .as_ref()
497 .map(|tw| {
498 Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
499 &cmd_enc.device,
500 &hub.query_sets.read(),
501 tw,
502 )
503 })
504 .transpose()
505 {
506 Ok(timestamp_writes) => {
507 let arc_desc = ArcComputePassDescriptor {
508 label,
509 timestamp_writes,
510 };
511 (ComputePass::new(cmd_enc, arc_desc), None)
512 }
513 Err(err) => (
514 ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
515 None,
516 ),
517 }
518 }
519 Err(err @ SErr::Locked) => {
520 cmd_buf_data.invalidate(err.clone());
524 drop(cmd_buf_data);
525 (
526 ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
527 None,
528 )
529 }
530 Err(err @ (SErr::Ended | SErr::Submitted)) => {
531 drop(cmd_buf_data);
534 (
535 ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
536 Some(err.into()),
537 )
538 }
539 Err(err @ SErr::Invalid) => {
540 drop(cmd_buf_data);
546 (
547 ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
548 None,
549 )
550 }
551 Err(SErr::Unlocked) => {
552 unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
553 }
554 }
555 }
556
557 pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
558 profiling::scope!(
559 "CommandEncoder::run_compute_pass {}",
560 pass.base.label.as_deref().unwrap_or("")
561 );
562
563 let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
564 let mut cmd_buf_data = cmd_enc.data.lock();
565
566 cmd_buf_data.unlock_encoder()?;
567
568 let base = pass.base.take();
569
570 if let Err(ComputePassError {
571 inner:
572 ComputePassErrorInner::EncoderState(
573 err @ (EncoderStateError::Locked | EncoderStateError::Ended),
574 ),
575 scope: _,
576 }) = base
577 {
578 return Err(err.clone());
585 }
586
587 cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
588 Ok(ArcCommand::RunComputePass {
589 pass: base?,
590 timestamp_writes: pass.timestamp_writes.take(),
591 })
592 })
593 }
594}
595
/// Encodes a recorded compute pass into the parent command encoder.
///
/// Replays the `ArcComputeCommand` stream against fresh pass state with full
/// validation, then re-opens the parent encoder with an internal "Pre Pass"
/// that fixes up discarded surfaces and inserts the barriers accumulated by
/// this pass, and swaps it in front of the pass commands.
pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

    // Close the encoder's current raw command stream and open a dedicated
    // one for this pass.
    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        immediates: Vec::new(),

        immediate_slots_set: Default::default(),

        intermediate_trackers: Tracker::new(
            device.ordered_buffer_usages,
            device.ordered_texture_usages,
        ),
    };

    // Size the parent trackers for the device's current resource counts.
    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    // Resolve timestamp writes to their HAL form, resetting the affected
    // queries first.
    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            // Reset the smallest query range covering both write indices,
            // or the single one that is present.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    // Replay every recorded command against the fresh pass state.
    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetImmediate {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetImmediate;
                // The callback mirrors the uploaded words into
                // `state.immediates` so they can be re-applied later (e.g.
                // after the indirect-validation dispatch).
                pass::set_immediates::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.immediates_data,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    |data_slice| {
                        let offset_in_elements = (offset / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        let size_in_elements =
                            (size_bytes / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        state.immediates[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
                // Record which byte ranges are now populated for `is_ready`.
                state.immediate_slots_set |=
                    naga::valid::ImmediateSlots::from_range(offset, size_bytes);
            }
            ArcComputeCommand::DispatchWorkgroups(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch_workgroups(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchWorkgroupsIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_workgroups_indirect(&mut state, device, buffer, offset)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    None, query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    device,
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::TransitionResources {
                buffer_transitions,
                texture_transitions,
            } => {
                let scope = PassErrorScope::TransitionResources;
                transition_resources(&mut state, buffer_transitions, texture_transitions)
                    .map_pass_err(scope)?;
            }
        }
    }

    // Every pushed debug group must have been popped within the pass.
    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass: pass::PassState {
            pending_discard_init_fixups,
            ..
        },
        intermediate_trackers,
        ..
    } = state;

    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;

    // Open an internal "Pre Pass" that initializes any surfaces this pass
    // read while discarded and inserts the barriers accumulated in
    // `intermediate_trackers`; `close_and_swap` places it before the pass.
    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;

    Ok(())
}
887
/// Binds a compute pipeline: records it in the tracker, sets it on the raw
/// encoder, and lets `change_pipeline_layout` rebind/reset layout-dependent
/// state (bind groups, immediates).
fn set_pipeline(
    state: &mut State,
    device: &Arc<Device>,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device(device)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    // The callback (invoked by `change_pipeline_layout`) resets the shadow
    // immediate data to zeroed words sized for the new layout.
    pass::change_pipeline_layout::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            state.immediates.clear();
            if pipeline.layout.immediate_size != 0 {
                let len = pipeline.layout.immediate_size as usize
                    / wgt::IMMEDIATE_DATA_ALIGNMENT as usize;
                state.immediates.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}
932
933fn dispatch_workgroups(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
934 api_log!("ComputePass::dispatch {groups:?}");
935
936 state.is_ready()?;
937
938 state.flush_bindings(None, false)?;
939
940 let groups_size_limit = state
941 .pass
942 .base
943 .device
944 .limits
945 .max_compute_workgroups_per_dimension;
946
947 if groups.iter().copied().any(|g| g > groups_size_limit) {
948 return Err(ComputePassErrorInner::Dispatch(
949 DispatchError::InvalidGroupSize {
950 current: groups,
951 limit: groups_size_limit,
952 },
953 ));
954 }
955
956 unsafe {
957 state.pass.base.raw_encoder.dispatch_workgroups(groups);
958 }
959 Ok(())
960}
961
/// Validates and encodes an indirect dispatch.
///
/// When the device has indirect validation enabled, a small internal compute
/// dispatch first processes the user's arguments into a device-internal
/// destination buffer, and the real indirect dispatch reads from that buffer
/// instead of the user's.
fn dispatch_workgroups_indirect(
    state: &mut State,
    device: &Arc<Device>,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch_indirect");

    buffer.same_device(device)?;

    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;

    if !offset.is_multiple_of(4) {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    // The args must fit entirely inside the buffer at `offset`.
    let args_size = size_of::<wgt::DispatchIndirectArgs>() as u64;
    if buffer.size < args_size || buffer.size - args_size < offset {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            args_size,
            offset,
            buffer_size: buffer.size,
        });
    }

    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    // The three u32 dispatch arguments must be initialized memory.
    let stride = 3 * 4; state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        // Bind the validation pipeline and its resources: group 0 is the
        // destination args buffer, group 1 views the user's source buffer.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.pass.base.raw_encoder.set_immediates(
                params.pipeline_layout,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                params.dst_bind_group,
                &[],
            );
        }
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                buffer
                    .indirect_validation_bind_groups
                    .get(state.pass.base.snatch_guard)
                    .unwrap()
                    .dispatch
                    .as_ref(),
                &[params.aligned_offset as u32],
            );
        }

        // Transition the user's buffer for the validation shader's read...
        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        // ...and the destination buffer for the shader's write.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        // Run the validation dispatch itself.
        unsafe {
            state.pass.base.raw_encoder.dispatch_workgroups([1, 1, 1]);
        }

        // The validation dispatch overwrote the pipeline, immediates and
        // bind groups on the raw encoder; restore the user's state.
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.immediates.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_immediates(
                        pipeline.layout.raw(),
                        0,
                        &state.immediates,
                    );
                }
            }

            for (i, group, dynamic_offsets) in state.pass.binder.list_valid() {
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        raw_bg,
                        dynamic_offsets,
                    );
                }
            }
        }

        // Return the destination buffer to INDIRECT for the real dispatch.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        // `track_indirect_buffer: false` — the dispatch reads the internal
        // destination buffer, not the user's buffer.
        state.flush_bindings(Some(&buffer), false)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_workgroups_indirect(params.dst_buffer, 0);
        }
    } else {
        // No validation: dispatch straight from the user's buffer.
        state.flush_bindings(Some(&buffer), true)?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_workgroups_indirect(buf_raw, offset);
        }
    }

    Ok(())
}
1155
1156impl Global {
1169 pub fn compute_pass_set_bind_group(
1170 &self,
1171 pass: &mut ComputePass,
1172 index: u32,
1173 bind_group_id: Option<id::BindGroupId>,
1174 offsets: &[DynamicOffset],
1175 ) -> Result<(), PassStateError> {
1176 let scope = PassErrorScope::SetBindGroup;
1177
1178 let base = pass_base!(pass, scope);
1182
1183 if pass.current_bind_groups.set_and_check_redundant(
1184 bind_group_id,
1185 index,
1186 &mut base.dynamic_offsets,
1187 offsets,
1188 ) {
1189 return Ok(());
1190 }
1191
1192 let mut bind_group = None;
1193 if let Some(bind_group_id) = bind_group_id {
1194 let hub = &self.hub;
1195 bind_group = Some(pass_try!(
1196 base,
1197 scope,
1198 hub.bind_groups.get(bind_group_id).get(),
1199 ));
1200 }
1201
1202 base.commands.push(ArcComputeCommand::SetBindGroup {
1203 index,
1204 num_dynamic_offsets: offsets.len(),
1205 bind_group,
1206 });
1207
1208 Ok(())
1209 }
1210
1211 pub fn compute_pass_set_pipeline(
1212 &self,
1213 pass: &mut ComputePass,
1214 pipeline_id: id::ComputePipelineId,
1215 ) -> Result<(), PassStateError> {
1216 let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);
1217
1218 let scope = PassErrorScope::SetPipelineCompute;
1219
1220 let base = pass_base!(pass, scope);
1223
1224 if redundant {
1225 return Ok(());
1226 }
1227
1228 let hub = &self.hub;
1229 let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());
1230
1231 base.commands.push(ArcComputeCommand::SetPipeline(pipeline));
1232
1233 Ok(())
1234 }
1235
1236 pub fn compute_pass_set_immediates(
1237 &self,
1238 pass: &mut ComputePass,
1239 offset: u32,
1240 data: &[u8],
1241 ) -> Result<(), PassStateError> {
1242 let scope = PassErrorScope::SetImmediate;
1243 let base = pass_base!(pass, scope);
1244
1245 if offset & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
1246 pass_try!(
1247 base,
1248 scope,
1249 Err(ComputePassErrorInner::ImmediateOffsetAlignment),
1250 );
1251 }
1252
1253 if data.len() as u32 & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
1254 pass_try!(
1255 base,
1256 scope,
1257 Err(ComputePassErrorInner::ImmediateDataizeAlignment),
1258 )
1259 }
1260 let value_offset = pass_try!(
1261 base,
1262 scope,
1263 base.immediates_data
1264 .len()
1265 .try_into()
1266 .map_err(|_| ComputePassErrorInner::ImmediateOutOfMemory)
1267 );
1268
1269 base.immediates_data.extend(
1270 data.chunks_exact(wgt::IMMEDIATE_DATA_ALIGNMENT as usize)
1271 .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
1272 );
1273
1274 base.commands.push(ArcComputeCommand::SetImmediate {
1275 offset,
1276 size_bytes: data.len() as u32,
1277 values_offset: value_offset,
1278 });
1279
1280 Ok(())
1281 }
1282
1283 pub fn compute_pass_dispatch_workgroups(
1284 &self,
1285 pass: &mut ComputePass,
1286 groups_x: u32,
1287 groups_y: u32,
1288 groups_z: u32,
1289 ) -> Result<(), PassStateError> {
1290 let scope = PassErrorScope::Dispatch { indirect: false };
1291
1292 pass_base!(pass, scope)
1293 .commands
1294 .push(ArcComputeCommand::DispatchWorkgroups([
1295 groups_x, groups_y, groups_z,
1296 ]));
1297
1298 Ok(())
1299 }
1300
1301 pub fn compute_pass_dispatch_workgroups_indirect(
1302 &self,
1303 pass: &mut ComputePass,
1304 buffer_id: id::BufferId,
1305 offset: BufferAddress,
1306 ) -> Result<(), PassStateError> {
1307 let hub = &self.hub;
1308 let scope = PassErrorScope::Dispatch { indirect: true };
1309 let base = pass_base!(pass, scope);
1310
1311 let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());
1312
1313 base.commands
1314 .push(ArcComputeCommand::DispatchWorkgroupsIndirect { buffer, offset });
1315
1316 Ok(())
1317 }
1318
1319 pub fn compute_pass_push_debug_group(
1320 &self,
1321 pass: &mut ComputePass,
1322 label: &str,
1323 color: u32,
1324 ) -> Result<(), PassStateError> {
1325 let base = pass_base!(pass, PassErrorScope::PushDebugGroup);
1326
1327 let bytes = label.as_bytes();
1328 base.string_data.extend_from_slice(bytes);
1329
1330 base.commands.push(ArcComputeCommand::PushDebugGroup {
1331 color,
1332 len: bytes.len(),
1333 });
1334
1335 Ok(())
1336 }
1337
1338 pub fn compute_pass_pop_debug_group(
1339 &self,
1340 pass: &mut ComputePass,
1341 ) -> Result<(), PassStateError> {
1342 let base = pass_base!(pass, PassErrorScope::PopDebugGroup);
1343
1344 base.commands.push(ArcComputeCommand::PopDebugGroup);
1345
1346 Ok(())
1347 }
1348
1349 pub fn compute_pass_insert_debug_marker(
1350 &self,
1351 pass: &mut ComputePass,
1352 label: &str,
1353 color: u32,
1354 ) -> Result<(), PassStateError> {
1355 let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);
1356
1357 let bytes = label.as_bytes();
1358 base.string_data.extend_from_slice(bytes);
1359
1360 base.commands.push(ArcComputeCommand::InsertDebugMarker {
1361 color,
1362 len: bytes.len(),
1363 });
1364
1365 Ok(())
1366 }
1367
1368 pub fn compute_pass_write_timestamp(
1369 &self,
1370 pass: &mut ComputePass,
1371 query_set_id: id::QuerySetId,
1372 query_index: u32,
1373 ) -> Result<(), PassStateError> {
1374 let scope = PassErrorScope::WriteTimestamp;
1375 let base = pass_base!(pass, scope);
1376
1377 let hub = &self.hub;
1378 let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
1379
1380 base.commands.push(ArcComputeCommand::WriteTimestamp {
1381 query_set,
1382 query_index,
1383 });
1384
1385 Ok(())
1386 }
1387
1388 pub fn compute_pass_begin_pipeline_statistics_query(
1389 &self,
1390 pass: &mut ComputePass,
1391 query_set_id: id::QuerySetId,
1392 query_index: u32,
1393 ) -> Result<(), PassStateError> {
1394 let scope = PassErrorScope::BeginPipelineStatisticsQuery;
1395 let base = pass_base!(pass, scope);
1396
1397 let hub = &self.hub;
1398 let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
1399
1400 base.commands
1401 .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
1402 query_set,
1403 query_index,
1404 });
1405
1406 Ok(())
1407 }
1408
1409 pub fn compute_pass_end_pipeline_statistics_query(
1410 &self,
1411 pass: &mut ComputePass,
1412 ) -> Result<(), PassStateError> {
1413 pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
1414 .commands
1415 .push(ArcComputeCommand::EndPipelineStatisticsQuery);
1416
1417 Ok(())
1418 }
1419
1420 pub fn compute_pass_transition_resources(
1421 &self,
1422 pass: &mut ComputePass,
1423 buffer_transitions: impl Iterator<Item = wgt::BufferTransition<id::BufferId>>,
1424 texture_transitions: impl Iterator<Item = wgt::TextureTransition<id::TextureViewId>>,
1425 ) -> Result<(), PassStateError> {
1426 let scope = PassErrorScope::TransitionResources;
1427 let base = pass_base!(pass, scope);
1428
1429 let hub = &self.hub;
1430 let buffer_transitions = pass_try!(
1431 base,
1432 scope,
1433 buffer_transitions
1434 .map(|buffer_transition| -> Result<_, InvalidResourceError> {
1435 Ok(wgt::BufferTransition {
1436 buffer: hub.buffers.get(buffer_transition.buffer).get()?,
1437 state: buffer_transition.state,
1438 })
1439 })
1440 .collect::<Result<Vec<_>, _>>()
1441 );
1442
1443 let texture_transitions = pass_try!(
1444 base,
1445 scope,
1446 texture_transitions
1447 .map(|texture_transition| -> Result<_, InvalidResourceError> {
1448 Ok(wgt::TextureTransition {
1449 texture: hub.texture_views.get(texture_transition.texture).get()?,
1450 selector: texture_transition.selector,
1451 state: texture_transition.state,
1452 })
1453 })
1454 .collect::<Result<Vec<_>, _>>()
1455 );
1456
1457 base.commands.push(ArcComputeCommand::TransitionResources {
1458 buffer_transitions,
1459 texture_transitions,
1460 });
1461
1462 Ok(())
1463 }
1464}