use thiserror::Error;
use wgt::{
    error::{ErrorType, WebGpuError},
    BufferAddress, DynamicOffset,
};

use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{convert::Infallible, fmt, str};

use crate::{
    api_log,
    binding_model::BindError,
    command::pass::flush_bindings_helper,
    resource::{RawResourceAccess, Trackable},
};
use crate::{
    binding_model::{ImmediateUploadError, LateMinBufferBindingSizeMismatch},
    command::{
        bind::{Binder, BinderError},
        compute_command::ArcComputeCommand,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
        BasePass, BindGroupStateChange, CommandEncoderError, MapPassErr, PassErrorScope,
        PassTimestampWrites, QueryUseError, StateChange,
    },
    device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::MemoryInitKind,
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
    },
    track::{ResourceUsageCompatibilityError, Tracker},
    Label,
};
use crate::{command::InnerCommandEncoder, resource::DestroyedResourceError};
use crate::{
    command::{
        encoder::EncodingState, pass, ArcCommand, CommandEncoder, DebugGroupError,
        EncoderStateError, PassStateError, TimestampWritesError,
    },
    device::Device,
};

pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;

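/// A recorded compute pass.
///
/// Commands are validated lazily: they are buffered in `base` and only
/// replayed against the HAL encoder once the parent command encoder is
/// finished.
///
/// A minimal recording sketch (hypothetical ids, error handling elided):
///
/// ```ignore
/// let (mut pass, _) = global.command_encoder_begin_compute_pass(encoder_id, &desc);
/// global.compute_pass_set_pipeline(&mut pass, pipeline_id)?;
/// global.compute_pass_set_bind_group(&mut pass, 0, Some(bind_group_id), &[])?;
/// global.compute_pass_dispatch_workgroups(&mut pass, 64, 1, 1)?;
/// global.compute_pass_end(&mut pass)?;
/// ```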
pub struct ComputePass {
    base: ComputeBasePass,

    parent: Option<Arc<CommandEncoder>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: BasePass::new(&label),
            parent: Some(parent),
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
        Self {
            base: BasePass::new_invalid(label, err),
            parent: Some(parent),
            timestamp_writes: None,
            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.label.as_deref()
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    pub timestamp_writes: Option<PTW>,
}

type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;

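/// Error encountered when dispatching work in a compute pass.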
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}

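/// Error encountered while recording or replaying a compute pass command.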
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    ImmediateData(#[from] ImmediateUploadError),
    #[error("Immediate data offset must be aligned to 4 bytes")]
    ImmediateOffsetAlignment,
    #[error("Immediate data size must be aligned to 4 bytes")]
    ImmediateDataSizeAlignment,
    #[error("Ran out of immediate data space. Don't set 4 GiB of immediates per ComputePass.")]
    ImmediateOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}

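/// Error encountered when recording a compute pass, tagged with the scope of
/// the offending command.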
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl From<pass::MissingPipeline> for ComputePassErrorInner {
    fn from(value: pass::MissingPipeline) -> Self {
        Self::Dispatch(DispatchError::MissingPipeline(value))
    }
}

impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}

impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::ImmediateData(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::ImmediateOffsetAlignment
            | ComputePassErrorInner::ImmediateDataSizeAlignment
            | ComputePassErrorInner::ImmediateOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}

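/// Mutable state carried across the commands of a compute pass while it is
/// replayed into the HAL encoder.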
struct State<'scope, 'snatch_guard, 'cmd_enc> {
    pipeline: Option<Arc<ComputePipeline>>,

    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    immediates: Vec<u32>,

    intermediate_trackers: Tracker,
}

impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

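    /// Flushes pending bind group and resource state to the HAL encoder
    /// ahead of a dispatch.
    ///
    /// This merges the resource usage of every active bind group (and the
    /// indirect buffer, if any) into the pass usage scope, moves those
    /// usages into `intermediate_trackers` to compute the required
    /// transitions, rebinds the groups, and records the resulting barriers.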
    fn flush_bindings(
        &mut self,
        indirect_buffer: Option<&Arc<Buffer>>,
        track_indirect_buffer: bool,
    ) -> Result<(), ComputePassErrorInner> {
        for bind_group in self.pass.binder.list_active() {
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
        }

        if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .merge_single(buffer, wgt::BufferUses::INDIRECT)?;
        }

        for bind_group in self.pass.binder.list_active() {
            self.intermediate_trackers
                .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used);
        }

        if track_indirect_buffer {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer.map(|buf| buf.tracker_index()),
                );
        } else if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .remove_usage(buffer, wgt::BufferUses::INDIRECT);
        }

        flush_bindings_helper(&mut self.pass)?;

        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}

impl Global {
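    /// Begins a compute pass on `encoder_id`.
    ///
    /// If creation fails, an invalid pass is returned along with an optional
    /// immediate error; recording into an invalid pass is allowed, and its
    /// stored error resurfaces when the pass or the parent encoder is
    /// finished.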
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                cmd_buf_data.invalidate(err.clone());
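                // Beginning a pass while another is open invalidates the
                // encoder; the error is also baked into the returned pass so
                // it surfaces again when the pass ends.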
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                drop(cmd_buf_data);
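                // The encoder is no longer recording, which is an immediate
                // validation error: it is returned to the caller in addition
                // to being baked into the invalid pass.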
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                drop(cmd_buf_data);
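                // The encoder is already invalid; its error will surface when
                // it is finished, so no immediate error is returned here.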
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.unlock_encoder()?;

        let base = pass.base.take();

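        // A pass recorded against a locked or ended encoder reports that
        // state error here directly, rather than storing it in the command
        // buffer.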
        if let Err(ComputePassError {
            inner:
                ComputePassErrorInner::EncoderState(
                    err @ (EncoderStateError::Locked | EncoderStateError::Ended),
                ),
            scope: _,
        }) = &base
        {
            return Err(err.clone());
        }

        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
            Ok(ArcCommand::RunComputePass {
                pass: base?,
                timestamp_writes: pass.timestamp_writes.take(),
            })
        })
    }
}

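/// Replays a recorded compute pass against the parent encoder's HAL encoder,
/// performing the validation and state tracking that was deferred while the
/// pass was recorded.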
pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

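    // The pass is recorded into its own raw encoder, so close whatever is
    // currently open. After the pass is encoded, a "pre pass" encoder is
    // opened and swapped in front of it (see `close_and_swap` below) to hold
    // init fixups and barriers.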
    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        immediates: Vec::new(),

        intermediate_trackers: Tracker::new(),
    };

    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
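            // `range` should always be `Some` here: both indices being
            // `None` is rejected by timestamp-writes validation, so there is
            // no point in erroring over that nuance again.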
609 if let Some(range) = range {
612 unsafe {
613 state
614 .pass
615 .base
616 .raw_encoder
617 .reset_queries(query_set.raw(), range);
618 }
619 }
620
621 Some(hal::PassTimestampWrites {
622 query_set: query_set.raw(),
623 beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
624 end_of_pass_write_index: tw.end_of_pass_write_index,
625 })
626 } else {
627 None
628 };
629
630 let hal_desc = hal::ComputePassDescriptor {
631 label: hal_label(base.label.as_deref(), device.instance_flags),
632 timestamp_writes,
633 };
634
635 unsafe {
636 state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
637 }
638
639 for command in base.commands.drain(..) {
640 match command {
641 ArcComputeCommand::SetBindGroup {
642 index,
643 num_dynamic_offsets,
644 bind_group,
645 } => {
646 let scope = PassErrorScope::SetBindGroup;
647 pass::set_bind_group::<ComputePassErrorInner>(
648 &mut state.pass,
649 device,
650 &base.dynamic_offsets,
651 index,
652 num_dynamic_offsets,
653 bind_group,
654 false,
655 )
656 .map_pass_err(scope)?;
657 }
658 ArcComputeCommand::SetPipeline(pipeline) => {
659 let scope = PassErrorScope::SetPipelineCompute;
660 set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
661 }
662 ArcComputeCommand::SetImmediate {
663 offset,
664 size_bytes,
665 values_offset,
666 } => {
667 let scope = PassErrorScope::SetImmediate;
668 pass::set_immediates::<ComputePassErrorInner, _>(
669 &mut state.pass,
670 &base.immediates_data,
671 offset,
672 size_bytes,
673 Some(values_offset),
674 |data_slice| {
675 let offset_in_elements = (offset / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
676 let size_in_elements =
677 (size_bytes / wgt::IMMEDIATE_DATA_ALIGNMENT) as usize;
678 state.immediates[offset_in_elements..][..size_in_elements]
679 .copy_from_slice(data_slice);
680 },
681 )
682 .map_pass_err(scope)?;
683 }
684 ArcComputeCommand::Dispatch(groups) => {
685 let scope = PassErrorScope::Dispatch { indirect: false };
686 dispatch(&mut state, groups).map_pass_err(scope)?;
687 }
688 ArcComputeCommand::DispatchIndirect { buffer, offset } => {
689 let scope = PassErrorScope::Dispatch { indirect: true };
690 dispatch_indirect(&mut state, device, buffer, offset).map_pass_err(scope)?;
691 }
692 ArcComputeCommand::PushDebugGroup { color: _, len } => {
693 pass::push_debug_group(&mut state.pass, &base.string_data, len);
694 }
695 ArcComputeCommand::PopDebugGroup => {
696 let scope = PassErrorScope::PopDebugGroup;
697 pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
698 .map_pass_err(scope)?;
699 }
700 ArcComputeCommand::InsertDebugMarker { color: _, len } => {
701 pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
702 }
703 ArcComputeCommand::WriteTimestamp {
704 query_set,
705 query_index,
706 } => {
707 let scope = PassErrorScope::WriteTimestamp;
708 pass::write_timestamp::<ComputePassErrorInner>(
709 &mut state.pass,
710 device,
711 None, query_set,
713 query_index,
714 )
715 .map_pass_err(scope)?;
716 }
717 ArcComputeCommand::BeginPipelineStatisticsQuery {
718 query_set,
719 query_index,
720 } => {
721 let scope = PassErrorScope::BeginPipelineStatisticsQuery;
722 validate_and_begin_pipeline_statistics_query(
723 query_set,
724 state.pass.base.raw_encoder,
725 &mut state.pass.base.tracker.query_sets,
726 device,
727 query_index,
728 None,
729 &mut state.active_query,
730 )
731 .map_pass_err(scope)?;
732 }
733 ArcComputeCommand::EndPipelineStatisticsQuery => {
734 let scope = PassErrorScope::EndPipelineStatisticsQuery;
735 end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
736 .map_pass_err(scope)?;
737 }
738 }
739 }
740
741 if *state.pass.base.debug_scope_depth > 0 {
742 Err(
743 ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
744 .map_pass_err(pass_scope),
745 )?;
746 }
747
748 unsafe {
749 state.pass.base.raw_encoder.end_compute_pass();
750 }
751
752 let State {
753 pass: pass::PassState {
754 pending_discard_init_fixups,
755 ..
756 },
757 intermediate_trackers,
758 ..
759 } = state;
760
761 parent_state.raw_encoder.close().map_pass_err(pass_scope)?;
763
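    // Open the "pre pass" encoder; `close_and_swap` below moves it in front
    // of the pass so that the discarded-surface fixups and the barriers
    // derived from `intermediate_trackers` execute before the pass itself.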
    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;

    Ok(())
}

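/// Validates and binds `pipeline`, rebinding bind groups and resizing the
/// CPU-side immediate-data shadow as required by the new pipeline layout.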
fn set_pipeline(
    state: &mut State,
    device: &Arc<Device>,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device(device)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    pass::change_pipeline_layout::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            state.immediates.clear();
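            // Resize the CPU-side shadow of the immediate data to match the
            // new layout, zero-filled, so later partial `SetImmediate` writes
            // (and re-uploads after indirect validation) see a buffer of the
            // expected size.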
            if pipeline.layout.immediate_size != 0 {
                let len = pipeline.layout.immediate_size as usize
                    / wgt::IMMEDIATE_DATA_ALIGNMENT as usize;
                state.immediates.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}

fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch {groups:?}");

    state.is_ready()?;

    state.flush_bindings(None, false)?;

    let groups_size_limit = state
        .pass
        .base
        .device
        .limits
        .max_compute_workgroups_per_dimension;

    if groups[0] > groups_size_limit
        || groups[1] > groups_size_limit
        || groups[2] > groups_size_limit
    {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }

    unsafe {
        state.pass.base.raw_encoder.dispatch(groups);
    }
    Ok(())
}

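/// Validates and records an indirect dispatch.
///
/// If the device has indirect validation enabled, the dispatch arguments are
/// first sanitized by a built-in validation dispatch, and the real dispatch
/// reads them from an internal destination buffer instead of the user's
/// buffer.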
fn dispatch_indirect(
    state: &mut State,
    device: &Arc<Device>,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch_indirect");

    buffer.same_device(device)?;

    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    buffer.check_destroyed(state.pass.base.snatch_guard)?;

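    // The dispatch reads `wgt::DispatchIndirectArgs` (three u32s, hence a
    // stride of 12 bytes) starting at `offset`; that range must be
    // initialized memory.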
    let stride = 3 * 4;
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

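    // Indirect-validation path: a built-in compute pipeline reads the
    // dispatch arguments from the user's buffer and writes a sanitized copy
    // into `params.dst_buffer`; the real dispatch then reads from that
    // buffer at offset 0.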
    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.pass.base.raw_encoder.set_immediates(
                params.pipeline_layout,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.pass.base.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

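        // The validation dispatch clobbered the user's pipeline, immediate
        // data, and bind groups; restore them before the real dispatch.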
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.immediates.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_immediates(
                        pipeline.layout.raw(),
                        0,
                        &state.immediates,
                    );
                }
            }

            for (i, e) in state.pass.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_bindings(Some(&buffer), false)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        state.flush_bindings(Some(&buffer), true)?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}

impl Global {
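    /// Records a `SetBindGroup` command on the pass.
    ///
    /// Redundant calls (the same bind group and offsets at the same index)
    /// are skipped. Errors are stored in the pass and reported when the
    /// parent encoder is finished.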
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        let base = pass_base!(pass, scope);
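        // NOTE: `pass_base!` performs the pass-validity check, and it must
        // come before the redundancy check below so that recording on an
        // ended or invalid pass still registers the error.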

        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            return Ok(());
        }

        let mut bind_group = None;
        if let Some(bind_group_id) = bind_group_id {
            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    pub fn compute_pass_set_immediates(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetImmediate;
        let base = pass_base!(pass, scope);

        if offset & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateOffsetAlignment),
            );
        }

        if data.len() as u32 & (wgt::IMMEDIATE_DATA_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::ImmediateDataSizeAlignment),
            )
        }
        let value_offset = pass_try!(
            base,
            scope,
            base.immediates_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::ImmediateOutOfMemory)
        );

        base.immediates_data.extend(
            data.chunks_exact(wgt::IMMEDIATE_DATA_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetImmediate {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}