1use thiserror::Error;
2use wgt::{
3 error::{ErrorType, WebGpuError},
4 BufferAddress, DynamicOffset,
5};
6
7use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
8use core::{convert::Infallible, fmt, str};
9
10use crate::{
11 api_log,
12 binding_model::{BindError, ImmediateUploadError, LateMinBufferBindingSizeMismatch},
13 command::{
14 bind::{Binder, BinderError},
15 compute_command::ArcComputeCommand,
16 encoder::EncodingState,
17 memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
18 pass::{self, flush_bindings_helper},
19 pass_base, pass_try,
20 query::{end_pipeline_statistics_query, validate_and_begin_pipeline_statistics_query},
21 ArcCommand, ArcPassTimestampWrites, BasePass, BindGroupStateChange, CommandEncoder,
22 CommandEncoderError, DebugGroupError, EncoderStateError, InnerCommandEncoder, MapPassErr,
23 PassErrorScope, PassStateError, PassTimestampWrites, QueryUseError, StateChange,
24 TimestampWritesError,
25 },
26 device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
27 global::Global,
28 hal_label, id,
29 init_tracker::MemoryInitKind,
30 pipeline::ComputePipeline,
31 resource::{
32 self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
33 MissingBufferUsageError, ParentDevice, RawResourceAccess, Trackable,
34 },
35 track::{ResourceUsageCompatibilityError, Tracker},
36 Label,
37};
38
/// A [`BasePass`] specialized to compute commands and compute-pass errors.
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;

/// Client-side recording state for an in-progress compute pass.
///
/// Commands are accumulated in `base` and only replayed against the HAL when
/// the pass is ended (see `compute_pass_end` / `encode_compute_pass`).
pub struct ComputePass {
    /// Recorded commands plus associated string/offset/immediate data; also
    /// carries the pass's validity state (see `BasePass::new_invalid`).
    base: ComputeBasePass,

    /// Parent command encoder; taken (set to `None`) when the pass is ended,
    /// so a second end reports `EncoderStateError::Ended`.
    parent: Option<Arc<CommandEncoder>>,

    /// Timestamp writes requested in the descriptor, resolved to `Arc`s.
    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Redundancy trackers used by the `compute_pass_set_*` front-ends to
    // drop no-op SetBindGroup / SetPipeline commands before recording.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}
65
66impl ComputePass {
67 fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
69 let ArcComputePassDescriptor {
70 label,
71 timestamp_writes,
72 } = desc;
73
74 Self {
75 base: BasePass::new(&label),
76 parent: Some(parent),
77 timestamp_writes,
78
79 current_bind_groups: BindGroupStateChange::new(),
80 current_pipeline: StateChange::new(),
81 }
82 }
83
84 fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
85 Self {
86 base: BasePass::new_invalid(label, err),
87 parent: Some(parent),
88 timestamp_writes: None,
89 current_bind_groups: BindGroupStateChange::new(),
90 current_pipeline: StateChange::new(),
91 }
92 }
93
94 #[inline]
95 pub fn label(&self) -> Option<&str> {
96 self.base.label.as_deref()
97 }
98}
99
100impl fmt::Debug for ComputePass {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 match self.parent {
103 Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
104 None => write!(f, "ComputePass {{ parent: None }}"),
105 }
106 }
107}
108
/// Descriptor for creating a compute pass.
///
/// `PTW` is the timestamp-writes payload type: id-based for the public API
/// (`PassTimestampWrites`) and `Arc`-resolved internally.
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    /// Where (if anywhere) to record beginning/end-of-pass timestamps.
    pub timestamp_writes: Option<PTW>,
}

/// Descriptor whose timestamp writes have been validated and resolved to
/// `Arc`ed query sets (see `command_encoder_begin_compute_pass`).
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;
118
/// Error validating a dispatch (direct or indirect); see `State::is_ready`
/// and `dispatch`.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    /// The bound bind groups are incompatible with the current pipeline layout.
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    /// A workgroup-count dimension exceeds `max_compute_workgroups_per_dimension`.
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}
133
impl WebGpuError for DispatchError {
    /// Every dispatch failure is reported as a validation error.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
139
/// Error encountered while recording or encoding a compute pass, without the
/// scope information added by [`ComputePassError`].
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    ImmediateData(#[from] ImmediateUploadError),
    #[error("Immediate data offset must be aligned to 4 bytes")]
    ImmediateOffsetAlignment,
    // NOTE(review): variant name has a typo ("Dataize" for "DataSize");
    // renaming would break this public enum, so it is left as-is.
    #[error("Immediate data size must be aligned to 4 bytes")]
    ImmediateDataizeAlignment,
    #[error("Ran out of immediate data space. Don't set 4gb of immediates per ComputePass.")]
    ImmediateOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}
195
/// Error encountered when performing a compute pass, paired with the scope
/// (which command/phase) it occurred in. `Display` shows only the scope; the
/// underlying cause is exposed through `source()`.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
205
206impl From<pass::MissingPipeline> for ComputePassErrorInner {
207 fn from(value: pass::MissingPipeline) -> Self {
208 Self::Dispatch(DispatchError::MissingPipeline(value))
209 }
210}
211
212impl<E> MapPassErr<ComputePassError> for E
213where
214 E: Into<ComputePassErrorInner>,
215{
216 fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
217 ComputePassError {
218 scope,
219 inner: self.into(),
220 }
221 }
222}
223
impl WebGpuError for ComputePassError {
    /// Classifies the error by delegating to the wrapped inner error where it
    /// implements [`WebGpuError`] itself; the remaining variants are all
    /// validation errors. The match is intentionally exhaustive (no `_` arm)
    /// so adding a variant forces a decision here.
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::ImmediateData(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            // Variants with no inner WebGpuError: always validation.
            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::ImmediateOffsetAlignment
            | ComputePassErrorInner::ImmediateDataizeAlignment
            | ComputePassErrorInner::ImmediateOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}
256
/// Mutable state threaded through `encode_compute_pass` while replaying the
/// recorded commands against the HAL encoder.
struct State<'scope, 'snatch_guard, 'cmd_enc> {
    /// The currently bound compute pipeline, if any; checked by `is_ready`.
    pipeline: Option<Arc<ComputePipeline>>,

    /// Pass-generic state (binder, usage scope, encoding state, …) shared
    /// with the render-pass machinery in `pass`.
    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    /// The in-flight pipeline-statistics query, if one has been begun and not
    /// yet ended.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    /// CPU-side copy of the immediate data currently set, in 4-byte units;
    /// re-uploaded after the indirect-validation path clobbers bindings.
    immediates: Vec<u32>,

    /// Per-dispatch tracker used to compute barriers before each dispatch;
    /// merged back into the parent tracker after the pass ends.
    intermediate_trackers: Tracker,
}
268
impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    /// Checks that a dispatch may be issued: a pipeline must be bound, the
    /// bound bind groups must be layout-compatible with it, and late-sized
    /// buffer bindings must satisfy their minimum sizes.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Merges the usage of all active bind groups (and, optionally, an
    /// indirect buffer) into the pass usage scope, moves that usage into the
    /// per-dispatch trackers, and emits the resulting barriers.
    ///
    /// `track_indirect_buffer` selects whether the indirect buffer's INDIRECT
    /// usage is transferred into `intermediate_trackers` (normal indirect
    /// dispatch) or merely removed from the scope (the indirect-validation
    /// path, which handles that buffer's transitions itself).
    fn flush_bindings(
        &mut self,
        indirect_buffer: Option<&Arc<Buffer>>,
        track_indirect_buffer: bool,
    ) -> Result<(), ComputePassErrorInner> {
        // First merge all usage into the scope so conflicts are detected
        // across bind groups (and with the indirect buffer) as a whole.
        for bind_group in self.pass.binder.list_active() {
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
        }

        if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .merge_single(buffer, wgt::BufferUses::INDIRECT)?;
        }

        // Only after all merges succeeded, transfer the usage out of the
        // scope into the per-dispatch trackers (which computes transitions).
        for bind_group in self.pass.binder.list_active() {
            self.intermediate_trackers
                .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used);
        }

        if track_indirect_buffer {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer.map(|buf| buf.tracker_index()),
                );
        } else if let Some(buffer) = indirect_buffer {
            // Validation path: drop the INDIRECT usage from the scope without
            // recording it; the caller transitions this buffer manually.
            self.pass
                .scope
                .buffers
                .remove_usage(buffer, wgt::BufferUses::INDIRECT);
        }

        flush_bindings_helper(&mut self.pass)?;

        // Emit the barriers accumulated in the intermediate trackers.
        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}
371
impl Global {
    /// Begins a compute pass on `encoder_id`.
    ///
    /// Never fails outright: on error it returns an *invalid* `ComputePass`
    /// that stores the error, to be reported when the pass is ended. The
    /// second tuple element is `Some` only for errors that must surface
    /// immediately (encoder already ended/submitted).
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        // Lock the encoder so no other recording can interleave with the pass.
        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                // Resolve + validate the timestamp writes (if requested)
                // against the device and the query-set registry.
                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Another pass is already open on this encoder; that
                // invalidates the encoder itself as well as this pass.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // Recording on a finished encoder is reported immediately.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

    /// Ends the pass: unlocks the parent encoder and pushes the recorded
    /// commands onto it as a single `RunComputePass` command. Errors recorded
    /// while the pass was open are surfaced here.
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        // `parent` is taken so a second call observes `None` => Ended.
        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.unlock_encoder()?;

        let base = pass.base.take();

        // Encoder-state errors stored in the pass are returned directly as
        // `EncoderStateError` rather than being deferred to submit time.
        if let Err(ComputePassError {
            inner:
                ComputePassErrorInner::EncoderState(
                    err @ (EncoderStateError::Locked | EncoderStateError::Ended),
                ),
            scope: _,
        }) = base
        {
            return Err(err.clone());
        }

        // Any other recorded error propagates through `base?` inside
        // `push_with` and invalidates the encoder.
        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
            Ok(ArcCommand::RunComputePass {
                pass: base?,
                timestamp_writes: pass.timestamp_writes.take(),
            })
        })
    }
}
511
/// Replays a recorded compute pass against the parent encoder's HAL encoder.
///
/// Opens a dedicated raw pass encoder, encodes every recorded command with
/// validation, then closes it and inserts a "Pre Pass" transit encoder
/// *before* it (via `close_and_swap`) carrying the barriers and texture-init
/// fixups the pass turned out to need.
pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

    // Finish whatever raw encoder is open, then open a fresh one for the pass.
    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        immediates: Vec::new(),

        intermediate_trackers: Tracker::new(),
    };

    // Pre-size the trackers to the device's current resource counts.
    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    // Resolve timestamp writes into their HAL form, resetting the query range
    // that will be written so stale results cannot be resolved later.
    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            // Covering range of the one or two indices that will be written.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    // Replay each recorded command, validating as we go. Each arm sets its
    // own error scope so failures report which command was at fault.
    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetImmediate {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetImmediate;
                // Besides uploading, mirror the data into `state.immediates`
                // so it can be re-uploaded after indirect validation rebinds.
                pass::set_immediates::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.immediates_data,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    |data_slice| {
                        let offset_in_elements = (offset / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        let size_in_elements =
                            (size_bytes / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        state.immediates[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, device, buffer, offset).map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    None,
                    query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    device,
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    // Every PushDebugGroup must have been matched by a PopDebugGroup.
    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass: pass::PassState {
            pending_discard_init_fixups,
            ..
        },
        intermediate_trackers,
        ..
    } = state;

    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;

    // Open a transit encoder that `close_and_swap` will order *before* the
    // pass we just recorded; it receives init fixups and entry barriers.
    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;

    Ok(())
}
787
/// Binds `pipeline` for subsequent dispatches: validates the device, records
/// it in the tracker, sets it on the HAL encoder, and rebinds/resets layout-
/// dependent state (bind groups, immediates) as needed.
fn set_pipeline(
    state: &mut State,
    device: &Arc<Device>,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device(device)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    // The closure runs only when the pipeline layout actually changed; it
    // resizes the CPU-side immediates mirror to the new layout's size
    // (zero-filled, matching the zero-initialized GPU immediate state).
    pass::change_pipeline_layout::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            state.immediates.clear();
            if pipeline.layout.immediate_size != 0 {
                let len = pipeline.layout.immediate_size as usize
                    / wgt::IMMEDIATE_DATA_ALIGNMENT as usize;
                state.immediates.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}
832
833fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
834 api_log!("ComputePass::dispatch {groups:?}");
835
836 state.is_ready()?;
837
838 state.flush_bindings(None, false)?;
839
840 let groups_size_limit = state
841 .pass
842 .base
843 .device
844 .limits
845 .max_compute_workgroups_per_dimension;
846
847 if groups[0] > groups_size_limit
848 || groups[1] > groups_size_limit
849 || groups[2] > groups_size_limit
850 {
851 return Err(ComputePassErrorInner::Dispatch(
852 DispatchError::InvalidGroupSize {
853 current: groups,
854 limit: groups_size_limit,
855 },
856 ));
857 }
858
859 unsafe {
860 state.pass.base.raw_encoder.dispatch(groups);
861 }
862 Ok(())
863}
864
/// Encodes an indirect dispatch from `buffer` at `offset`.
///
/// Validates usage, alignment, bounds, and initialization of the argument
/// buffer. When the device has indirect validation enabled, the arguments are
/// first clamped by an internal validation compute dispatch into a staging
/// buffer, the user's pipeline/bindings/immediates are restored, and the
/// dispatch reads from the staging buffer instead of the user's buffer.
fn dispatch_indirect(
    state: &mut State,
    device: &Arc<Device>,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch_indirect");

    buffer.same_device(device)?;

    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;

    // Indirect offsets must be 4-byte aligned.
    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    // The three u32 arguments must lie entirely inside the buffer.
    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    // 3 u32s = 12 bytes: the region that must be zero/user-initialized.
    let stride = 3 * 4;
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        // --- Validation path: clamp the args with an internal dispatch. ---
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        // Bind the internal validation pipeline.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        // The validation shader receives the intra-element remainder of the
        // offset (in u32 units) as an immediate.
        unsafe {
            state.pass.base.raw_encoder.set_immediates(
                params.pipeline_layout,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        // Group 0: destination (staging) buffer; group 1: the source args,
        // bound with a dynamic offset at the aligned read position.
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.pass.base.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        // Transition the user's buffer to readable storage for the
        // validation shader (tracked, so the pass-exit state is correct).
        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        // The internal staging buffer is transitioned manually (it is not in
        // the trackers): INDIRECT -> writable for the validation dispatch.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        // Run the validation shader (a single workgroup).
        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

        {
            // Restore the user's pipeline, immediates, and bind groups,
            // which the validation dispatch clobbered.
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.immediates.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_immediates(
                        pipeline.layout.raw(),
                        0,
                        &state.immediates,
                    );
                }
            }

            for (i, e) in state.pass.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        // Staging buffer back to INDIRECT for the real dispatch.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        // `track_indirect_buffer == false`: the staging buffer, not the
        // user's buffer, is read by the dispatch.
        state.flush_bindings(Some(&buffer), false)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        // --- Direct path: dispatch straight from the user's buffer. ---
        state.flush_bindings(Some(&buffer), true)?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}
1061
/// Recording front-ends for an open [`ComputePass`]. Each method validates
/// cheaply, resolves ids to `Arc`s, and appends an [`ArcComputeCommand`];
/// full validation happens at encode time. `pass_base!`/`pass_try!` handle
/// the invalid-pass and error-sinking early returns.
impl Global {
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        let base = pass_base!(pass, scope);

        // Skip recording if this exact bind group + offsets is already set
        // (the offsets are still appended to `base.dynamic_offsets` here).
        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            return Ok(());
        }

        let mut bind_group = None;
        if let Some(bind_group_id) = bind_group_id {
            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        // Check pass validity even for redundant sets, so an invalid pass
        // still reports the error.
        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    pub fn compute_pass_set_immediates(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetImmediate;
        let base = pass_base!(pass, scope);

        // Offset and size must be multiples of IMMEDIATE_DATA_ALIGNMENT
        // (a power of two, hence the mask tests).
        if offset & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateOffsetAlignment),
            );
        }

        if data.len() as u32 & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateDataizeAlignment),
            )
        }
        // Index into `immediates_data` where this write's payload begins;
        // overflow of the index type means too much immediate data recorded.
        let value_offset = pass_try!(
            base,
            scope,
            base.immediates_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::ImmediateOutOfMemory)
        );

        // Store the payload as native-endian u32 words alongside the pass.
        base.immediates_data.extend(
            data.chunks_exact(wgt::IMMEDIATE_DATA_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetImmediate {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        // Label bytes are appended to the shared string buffer; the command
        // stores only the length (labels are consumed in recording order).
        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}