1use thiserror::Error;
2use wgt::{
3 error::{ErrorType, WebGpuError},
4 BufferAddress, DynamicOffset,
5};
6
7use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
8use core::{convert::Infallible, fmt, str};
9
10use crate::{
11 api_log,
12 binding_model::{BindError, ImmediateUploadError, LateMinBufferBindingSizeMismatch},
13 command::{
14 bind::{Binder, BinderError},
15 compute_command::ArcComputeCommand,
16 encoder::EncodingState,
17 memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
18 pass::{self, flush_bindings_helper},
19 pass_base, pass_try,
20 query::{end_pipeline_statistics_query, validate_and_begin_pipeline_statistics_query},
21 ArcCommand, ArcPassTimestampWrites, BasePass, BindGroupStateChange, CommandEncoder,
22 CommandEncoderError, DebugGroupError, EncoderStateError, InnerCommandEncoder, MapPassErr,
23 PassErrorScope, PassStateError, PassTimestampWrites, QueryUseError, StateChange,
24 TimestampWritesError,
25 },
26 device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
27 global::Global,
28 hal_label, id,
29 init_tracker::MemoryInitKind,
30 pipeline::ComputePipeline,
31 resource::{
32 self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
33 MissingBufferUsageError, ParentDevice, RawResourceAccess, Trackable,
34 },
35 track::{ResourceUsageCompatibilityError, Tracker},
36 Label,
37};
38
/// The base pass type for compute passes: a recorded stream of
/// `ArcComputeCommand`s whose encoding errors are `ComputePassError`s.
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;
40
pub struct ComputePass {
    // The recorded command stream plus its side tables (label, string data,
    // dynamic offsets, immediates) — or, if the pass is invalid, the first
    // error that was recorded.
    base: ComputeBasePass,

    // The command encoder this pass was begun on. `Some` while the pass is
    // open; taken (set to `None`) when the pass is ended (see
    // `Global::compute_pass_end`).
    parent: Option<Arc<CommandEncoder>>,

    // Timestamp writes from the pass descriptor, with the query set already
    // resolved to an `Arc`. Taken when the pass is ended.
    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Redundancy trackers used to skip recording bind-group / pipeline
    // commands that would be no-ops.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}
65
66impl ComputePass {
67 fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
69 let ArcComputePassDescriptor {
70 label,
71 timestamp_writes,
72 } = desc;
73
74 Self {
75 base: BasePass::new(&label),
76 parent: Some(parent),
77 timestamp_writes,
78
79 current_bind_groups: BindGroupStateChange::new(),
80 current_pipeline: StateChange::new(),
81 }
82 }
83
84 fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
85 Self {
86 base: BasePass::new_invalid(label, err),
87 parent: Some(parent),
88 timestamp_writes: None,
89 current_bind_groups: BindGroupStateChange::new(),
90 current_pipeline: StateChange::new(),
91 }
92 }
93
94 #[inline]
95 pub fn label(&self) -> Option<&str> {
96 self.base.label.as_deref()
97 }
98}
99
100impl fmt::Debug for ComputePass {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 match self.parent {
103 Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
104 None => write!(f, "ComputePass {{ parent: None }}"),
105 }
106 }
107}
108
/// Describes a compute pass.
///
/// `PTW` selects the timestamp-writes representation: the id-based
/// `PassTimestampWrites` (the default, used at the public API boundary), or
/// `ArcPassTimestampWrites` once the query set has been resolved.
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}

/// Descriptor whose timestamp-writes query set is resolved to an `Arc`.
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;
118
/// Errors that can occur when validating a dispatch call.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    /// No compute pipeline was set before dispatching.
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    /// A bound bind group is incompatible with the current pipeline's layout.
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    /// A workgroup-count dimension exceeds
    /// `Limits::max_compute_workgroups_per_dimension`.
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    /// A late-sized buffer binding is smaller than the pipeline requires.
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}
133
impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        // Every dispatch failure above is a validation error; none map to
        // out-of-memory or internal errors.
        ErrorType::Validation
    }
}
139
/// Error encountered while recording or encoding a compute pass, without the
/// scope information (see [`ComputePassError`] for the scoped wrapper).
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    ImmediateData(#[from] ImmediateUploadError),
    #[error("Immediate data offset must be aligned to 4 bytes")]
    ImmediateOffsetAlignment,
    // NOTE(review): variant name has a typo ("Dataize" should read
    // "DataSize"); it is matched elsewhere, so renaming would be a breaking
    // change and is left as-is.
    #[error("Immediate data size must be aligned to 4 bytes")]
    ImmediateDataizeAlignment,
    #[error("Ran out of immediate data space. Don't set 4gb of immediates per ComputePass.")]
    ImmediateOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}
195
/// Error encountered when performing a compute pass, annotated with the scope
/// (the specific pass command) in which it occurred.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
205
206impl From<pass::MissingPipeline> for ComputePassErrorInner {
207 fn from(value: pass::MissingPipeline) -> Self {
208 Self::Dispatch(DispatchError::MissingPipeline(value))
209 }
210}
211
212impl<E> MapPassErr<ComputePassError> for E
213where
214 E: Into<ComputePassErrorInner>,
215{
216 fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
217 ComputePassError {
218 scope,
219 inner: self.into(),
220 }
221 }
222}
223
impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        // The scope does not influence the WebGPU error category; only the
        // inner error does.
        let Self { scope: _, inner } = self;
        // Variants that wrap another `WebGpuError` delegate to it; the
        // remaining variants are always validation errors.
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::ImmediateData(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::ImmediateOffsetAlignment
            | ComputePassErrorInner::ImmediateDataizeAlignment
            | ComputePassErrorInner::ImmediateOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}
256
/// Mutable encoding state kept while replaying a compute pass's commands into
/// the HAL encoder.
struct State<'scope, 'snatch_guard, 'cmd_enc> {
    // The currently bound compute pipeline, if any.
    pipeline: Option<Arc<ComputePipeline>>,

    // Pass-generic state: raw encoder, binder, usage scope, memory-init
    // bookkeeping, etc.
    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    // Currently open pipeline-statistics query (query set + index), if any.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    // Shadow copy of the immediate data currently set on the pipeline, in
    // 4-byte words; re-uploaded after the indirect-validation dispatch
    // clobbers pipeline state.
    immediates: Vec<u32>,

    // Resource state transitions accumulated during the pass; turned into
    // barriers before the pass runs.
    intermediate_trackers: Tracker,
}
268
impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    /// Checks whether a dispatch may be issued: a pipeline must be set, the
    /// bound bind groups must be compatible with its layout, and late-sized
    /// buffer bindings must be large enough.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Merges the usages of all active bind groups (and, optionally, an
    /// indirect buffer) into the pass usage scope, moves them into
    /// `intermediate_trackers`, flushes bind-group bookkeeping, and records
    /// the resulting barriers on the raw encoder.
    ///
    /// When `track_indirect_buffer` is false but an `indirect_buffer` is
    /// given, the INDIRECT usage is merged for conflict detection only and
    /// then removed from the scope — the caller issues its own transitions
    /// (the indirect-validation path does this).
    fn flush_bindings(
        &mut self,
        indirect_buffer: Option<&Arc<Buffer>>,
        track_indirect_buffer: bool,
    ) -> Result<(), ComputePassErrorInner> {
        // Phase 1: merge all bind-group usages into the usage scope, so that
        // conflicting usages within this dispatch are detected.
        for bind_group in self.pass.binder.list_active() {
            // SAFETY: presumably the bind group's tracker indices fit the
            // scope, which was sized for this device — TODO confirm against
            // `merge_bind_group`'s documented contract.
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
        }

        if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .merge_single(buffer, wgt::BufferUses::INDIRECT)?;
        }

        // Phase 2: move the merged usages out of the scope into the
        // intermediate trackers, which accumulate the transitions to emit.
        for bind_group in self.pass.binder.list_active() {
            self.intermediate_trackers
                .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used);
        }

        if track_indirect_buffer {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer.map(|buf| buf.tracker_index()),
                );
        } else if let Some(buffer) = indirect_buffer {
            // The caller handles this buffer's transitions itself; drop the
            // usage we merged above so it is not tracked twice.
            self.pass
                .scope
                .buffers
                .remove_usage(buffer, wgt::BufferUses::INDIRECT);
        }

        // Phase 3: shared per-pass bind-group flushing (memory init, etc.).
        flush_bindings_helper(&mut self.pass)?;

        // Phase 4: emit the accumulated transitions as barriers.
        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}
371
impl Global {
    /// Begins a compute pass on `encoder_id`.
    ///
    /// Always returns a `ComputePass`; if validation fails the pass is created
    /// in an invalid state carrying the error, which surfaces when the pass is
    /// ended. The second tuple element is an immediate error only for the
    /// already-ended/submitted encoder case.
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        // Lock the encoder so nothing else can record on it while the pass
        // is open; each error branch below drops the data lock before
        // constructing the (invalid) pass.
        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                // Resolve the timestamp-writes query set id to an Arc,
                // validating it against the device.
                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Beginning a pass while another pass is open invalidates the
                // encoder itself, per WebGPU encoder state rules.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // The only case that also reports an immediate error to the
                // caller, in addition to the invalid pass.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

    /// Ends `pass`, unlocks its parent encoder, and appends the recorded
    /// command stream to the encoder as a `RunComputePass` command.
    ///
    /// Errors recorded during pass recording are carried in `base` and
    /// surface here (or later, via `push_with`).
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        // Taking `parent` makes ending a pass twice an error.
        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.unlock_encoder()?;

        let base = pass.base.take();

        // Encoder-state errors recorded at pass-begin time are returned
        // directly rather than deferred to submission.
        if let Err(ComputePassError {
            inner:
                ComputePassErrorInner::EncoderState(
                    err @ (EncoderStateError::Locked | EncoderStateError::Ended),
                ),
            scope: _,
        }) = base
        {
            return Err(err.clone());
        }

        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
            Ok(ArcCommand::RunComputePass {
                pass: base?,
                timestamp_writes: pass.timestamp_writes.take(),
            })
        })
    }
}
511
/// Replays a recorded compute pass into the parent encoder's HAL command
/// stream.
///
/// Layout of the produced HAL work: a "Pre Pass" holding discarded-surface
/// fixups and resource barriers is spliced in *before* the pass itself (via
/// `close_and_swap`), then the pass encoder containing
/// `begin_compute_pass` .. `end_compute_pass`.
pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

    // Finish any encoder that is currently open, then open a fresh one
    // dedicated to this pass.
    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        immediates: Vec::new(),

        intermediate_trackers: Tracker::new(),
    };

    // Size the buffer/texture trackers up front so per-command tracking
    // doesn't need to grow them.
    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    // Translate the resolved timestamp writes into their HAL form, resetting
    // the affected query range first.
    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            // Compute the (inclusive..exclusive) range of query indices to
            // reset: both endpoints if both are set, a single index if only
            // one is, nothing otherwise.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    // Replay the recorded command stream. Each command validates against the
    // current state and encodes into the raw encoder.
    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetImmediate {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetImmediate;
                pass::set_immediates::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.immediates_data,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    // Mirror the uploaded words into our shadow copy so they
                    // can be re-uploaded after indirect validation.
                    |data_slice| {
                        let offset_in_elements = (offset / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        let size_in_elements =
                            (size_bytes / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
                        state.immediates[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, device, buffer, offset).map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    None,
                    query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    device,
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    // Every PushDebugGroup must have been matched by a PopDebugGroup.
    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass: pass::PassState {
            pending_discard_init_fixups,
            ..
        },
        intermediate_trackers,
        ..
    } = state;

    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;

    // Open an internal "Pre Pass" encoder for init fixups and barriers; it is
    // swapped to run *before* the pass encoder closed above.
    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;

    Ok(())
}
787
/// Handles `SetPipeline`: validates the pipeline's device, records it in the
/// tracker, binds it on the raw encoder, and rebinds/resizes layout-dependent
/// state (bind groups, immediates) as needed.
fn set_pipeline(
    state: &mut State,
    device: &Arc<Device>,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device(device)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    // Rebind bind groups / reset immediates if the pipeline layout changed.
    // The closure resizes the immediates shadow buffer (zero-filled) to match
    // the new layout's immediate size.
    pass::change_pipeline_layout::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            state.immediates.clear();
            if pipeline.layout.immediate_size != 0 {
                // Immediate size is in bytes; the shadow buffer stores
                // 4-byte-aligned words.
                let len = pipeline.layout.immediate_size as usize
                    / wgt::IMMEDIATE_DATA_ALIGNMENT as usize;
                state.immediates.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}
832
833fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
834 api_log!("ComputePass::dispatch {groups:?}");
835
836 state.is_ready()?;
837
838 state.flush_bindings(None, false)?;
839
840 let groups_size_limit = state
841 .pass
842 .base
843 .device
844 .limits
845 .max_compute_workgroups_per_dimension;
846
847 if groups[0] > groups_size_limit
848 || groups[1] > groups_size_limit
849 || groups[2] > groups_size_limit
850 {
851 return Err(ComputePassErrorInner::Dispatch(
852 DispatchError::InvalidGroupSize {
853 current: groups,
854 limit: groups_size_limit,
855 },
856 ));
857 }
858
859 unsafe {
860 state.pass.base.raw_encoder.dispatch(groups);
861 }
862 Ok(())
863}
864
/// Handles `DispatchIndirect`: validates the indirect buffer (device, usage,
/// alignment, bounds, liveness), then either runs the device's indirect
/// validation shader and dispatches from its sanitized buffer, or dispatches
/// directly from the user's buffer.
fn dispatch_indirect(
    state: &mut State,
    device: &Arc<Device>,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch_indirect");

    buffer.same_device(device)?;

    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;

    if !offset.is_multiple_of(4) {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    // The indirect args (three u32 workgroup counts) must lie fully within
    // the buffer.
    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    // Three u32s = 12 bytes of indirect args must be initialized memory.
    let stride = 3 * 4;
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        // Indirect-validation path: run a small compute shader that clamps
        // the args into a device-owned destination buffer, then dispatch
        // indirectly from that buffer.
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        // The validation shader receives the remaining word offset of the
        // args within its aligned binding window.
        unsafe {
            state.pass.base.raw_encoder.set_immediates(
                params.pipeline_layout,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                params.dst_bind_group,
                &[],
            );
        }
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                buffer
                    .indirect_validation_bind_groups
                    .get(state.pass.base.snatch_guard)
                    .unwrap()
                    .dispatch
                    .as_ref(),
                &[params.aligned_offset as u32],
            );
        }

        // Transition the source buffer for the validation shader's read.
        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        // The destination buffer is manually transitioned INDIRECT <->
        // STORAGE_READ_WRITE around the validation dispatch; it is not in
        // the trackers.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

        // Restore the user's pipeline state, which the validation dispatch
        // clobbered: pipeline, immediates, and all valid bind groups.
        {
            // is_ready() above guarantees a pipeline is set.
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.immediates.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_immediates(
                        pipeline.layout.raw(),
                        0,
                        &state.immediates,
                    );
                }
            }

            for (i, group, dynamic_offsets) in state.pass.binder.list_valid() {
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        raw_bg,
                        dynamic_offsets,
                    );
                }
            }
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        // `track_indirect_buffer: false` — the source buffer's transition
        // was handled manually above.
        state.flush_bindings(Some(&buffer), false)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        // No validation: dispatch directly from the user's buffer, letting
        // the trackers emit its INDIRECT transition.
        state.flush_bindings(Some(&buffer), true)?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}
1058
impl Global {
    /// Records a `SetBindGroup` command, skipping it when redundant with the
    /// currently recorded state.
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        let base = pass_base!(pass, scope);

        // Note: the offsets are appended to `base.dynamic_offsets` here even
        // when the command itself is elided as redundant.
        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            return Ok(());
        }

        // `None` unbinds the slot; otherwise resolve the id to an Arc.
        let mut bind_group = None;
        if let Some(bind_group_id) = bind_group_id {
            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    /// Records a `SetPipeline` command, skipping it when the pipeline is
    /// already the current one.
    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    /// Records a `SetImmediate` command, validating 4-byte alignment of the
    /// offset and data length and copying the data into the pass's immediate
    /// side table.
    pub fn compute_pass_set_immediates(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetImmediate;
        let base = pass_base!(pass, scope);

        // IMMEDIATE_DATA_ALIGNMENT is a power of two, so masking with
        // (align - 1) tests alignment.
        if offset & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateOffsetAlignment),
            );
        }

        if data.len() as u32 & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateDataizeAlignment),
            )
        }
        // Index into the side table where this command's data starts; errors
        // if the accumulated immediate data no longer fits the index type.
        let value_offset = pass_try!(
            base,
            scope,
            base.immediates_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::ImmediateOutOfMemory)
        );

        // Data length is 4-byte aligned (checked above), so chunks_exact
        // consumes it fully as native-endian u32 words.
        base.immediates_data.extend(
            data.chunks_exact(wgt::IMMEDIATE_DATA_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetImmediate {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    /// Records a direct `Dispatch` command; limits are validated at encode
    /// time, not here.
    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    /// Records a `DispatchIndirect` command, resolving the buffer id; buffer
    /// validation happens at encode time.
    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    /// Records a `PushDebugGroup` command; the label bytes go into the pass's
    /// string side table.
    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    /// Records a `PopDebugGroup` command; balance against pushes is checked
    /// at encode time.
    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    /// Records an `InsertDebugMarker` command; the label bytes go into the
    /// pass's string side table.
    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    /// Records a `WriteTimestamp` command, resolving the query set id.
    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    /// Records a `BeginPipelineStatisticsQuery` command, resolving the query
    /// set id.
    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    /// Records an `EndPipelineStatisticsQuery` command; pairing with a begin
    /// is validated at encode time.
    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}