use thiserror::Error;
use wgt::{
    error::{ErrorType, WebGpuError},
    BufferAddress, DynamicOffset,
};

use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{convert::Infallible, fmt, str};

use crate::{
    api_log,
    binding_model::BindError,
    command::pass::flush_bindings_helper,
    resource::{RawResourceAccess, Trackable},
};
use crate::{
    binding_model::{LateMinBufferBindingSizeMismatch, PushConstantUploadError},
    command::{
        bind::{Binder, BinderError},
        compute_command::ArcComputeCommand,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
        BasePass, BindGroupStateChange, CommandEncoderError, MapPassErr, PassErrorScope,
        PassTimestampWrites, QueryUseError, StateChange,
    },
    device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::MemoryInitKind,
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
    },
    track::{ResourceUsageCompatibilityError, Tracker},
    Label,
};
use crate::{command::InnerCommandEncoder, resource::DestroyedResourceError};
use crate::{
    command::{
        encoder::EncodingState, pass, ArcCommand, CommandEncoder, DebugGroupError,
        EncoderStateError, PassStateError, TimestampWritesError,
    },
    device::Device,
};

pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;

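/// A recorded, but not yet submitted, compute pass.
///
/// Commands are accumulated here and only encoded into the parent
/// [`CommandEncoder`] once the pass is ended.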
pub struct ComputePass {
    /// The recorded commands, their string/offset data, and any deferred error state.
    base: ComputeBasePass,

    /// The encoder this pass was begun on; `None` once the pass has ended.
    parent: Option<Arc<CommandEncoder>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    // State-change tracking used to elide redundant commands.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: BasePass::new(&label),
            parent: Some(parent),
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
        Self {
            base: BasePass::new_invalid(label, err),
            parent: Some(parent),
            timestamp_writes: None,
            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.label.as_deref()
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}

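/// Describes a compute pass.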
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}

type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;

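/// Error encountered when dispatching a compute workload.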
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}

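/// Error encountered while encoding a compute pass.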
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4 GiB of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}

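/// Error encountered when performing a compute pass, paired with the scope in
/// which it occurred.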
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl From<pass::MissingPipeline> for ComputePassErrorInner {
    fn from(value: pass::MissingPipeline) -> Self {
        Self::Dispatch(DispatchError::MissingPipeline(value))
    }
}

impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}

impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::PushConstants(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::PushConstantOffsetAlignment
            | ComputePassErrorInner::PushConstantSizeAlignment
            | ComputePassErrorInner::PushConstantOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}

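/// Mutable state threaded through each recorded command while the pass is
/// encoded by `encode_compute_pass`.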
struct State<'scope, 'snatch_guard, 'cmd_enc> {
    /// The currently set compute pipeline, if any.
    pipeline: Option<Arc<ComputePipeline>>,

    /// State shared with other pass types (binder, usage scope, etc.).
    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    /// The currently open pipeline-statistics query, if any, and its index.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    /// A shadow copy of the compute-stage push constants, so they can be
    /// re-applied after an indirect-validation dispatch clobbers them.
    push_constants: Vec<u32>,

    intermediate_trackers: Tracker,
}

impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    /// Checks that a dispatch is possible: a pipeline is set and the bound
    /// bind groups are compatible with it.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Merges the active bind groups (and, for indirect dispatches, the
    /// indirect buffer) into the usage scope, moves them into the
    /// intermediate trackers, and records any resulting barriers.
    fn flush_bindings(
        &mut self,
        indirect_buffer: Option<&Arc<Buffer>>,
        track_indirect_buffer: bool,
    ) -> Result<(), ComputePassErrorInner> {
        for bind_group in self.pass.binder.list_active() {
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
        }

        if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .merge_single(buffer, wgt::BufferUses::INDIRECT)?;
        }

        for bind_group in self.pass.binder.list_active() {
            self.intermediate_trackers
                .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used);
        }

        if track_indirect_buffer {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer.map(|buf| buf.tracker_index()),
                );
        } else if let Some(buffer) = indirect_buffer {
            self.pass
                .scope
                .buffers
                .remove_usage(buffer, wgt::BufferUses::INDIRECT);
        }

        flush_bindings_helper(&mut self.pass)?;

        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}

impl Global {
    /// Begins a compute pass on the given encoder.
    ///
    /// If the pass cannot be begun, an invalid pass is returned. Recording
    /// commands into an invalid pass is permitted; the stored error is
    /// reported when the pass is ended.
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Beginning a new pass while another pass is open invalidates
                // the encoder as well as the new pass.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // The encoder is already finished; report the error
                // immediately in addition to attaching it to the pass.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                // The encoder is already invalid; the pass inherits the error.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

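    /// Ends the compute pass, moving its recorded commands into the parent
    /// encoder. Any error deferred while recording the pass surfaces here or
    /// when the parent encoder is finished.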
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.unlock_encoder()?;

        let base = pass.base.take();

        // Encoder-state errors recorded on the pass are returned to the
        // caller directly rather than being stored on the encoder.
        if let Err(ComputePassError {
            inner:
                ComputePassErrorInner::EncoderState(
                    err @ (EncoderStateError::Locked | EncoderStateError::Ended),
                ),
            scope: _,
        }) = &base
        {
            return Err(err.clone());
        }

        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
            Ok(ArcCommand::RunComputePass {
                pass: base?,
                timestamp_writes: pass.timestamp_writes.take(),
            })
        })
    }
}

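/// Encodes a recorded compute pass into the parent encoder's HAL command
/// stream, validating each command as it is replayed.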
pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

    // Close any open command buffer and open a fresh one for this pass, so
    // that a "pre pass" buffer can later be swapped in ahead of it.
    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        push_constants: Vec::new(),

        intermediate_trackers: Tracker::new(),
    };

    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            // If both write indices are present, reset the contiguous range
            // covering them; otherwise reset just the one that is present.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetPushConstant;
                pass::set_push_constant::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.push_constant_data,
                    wgt::ShaderStages::COMPUTE,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    |data_slice| {
                        // Mirror the values into the shadow copy so they can
                        // be re-applied after an indirect-validation dispatch.
                        let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        state.push_constants[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, device, buffer, offset).map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    None,
                    query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    device,
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass: pass::PassState {
            pending_discard_init_fixups,
            ..
        },
        intermediate_trackers,
        ..
    } = state;

    // Close the pass's command buffer, then record a "pre pass" buffer that
    // is swapped in before it, holding discarded-surface fixups and the
    // barriers that transition resources into the states the pass expects.
    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;

    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;

    Ok(())
}

fn set_pipeline(
    state: &mut State,
    device: &Arc<Device>,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device(device)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    pass::change_pipeline_layout::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            // Resize the shadow copy of the push constants to match the new
            // layout's compute-stage range, zero-filled.
            state.push_constants.clear();
            if let Some(push_constant_range) =
                pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                    pcr.stages
                        .contains(wgt::ShaderStages::COMPUTE)
                        .then_some(pcr.range.clone())
                })
            {
                let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
                state.push_constants.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}

fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch {groups:?}");

    state.is_ready()?;

    state.flush_bindings(None, false)?;

    let groups_size_limit = state
        .pass
        .base
        .device
        .limits
        .max_compute_workgroups_per_dimension;

    if groups[0] > groups_size_limit
        || groups[1] > groups_size_limit
        || groups[2] > groups_size_limit
    {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }

    unsafe {
        state.pass.base.raw_encoder.dispatch(groups);
    }
    Ok(())
}

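/// Validates and encodes an indirect dispatch, routing it through the
/// indirect-validation pipeline when the device has one.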
fn dispatch_indirect(
    state: &mut State,
    device: &Arc<Device>,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    api_log!("ComputePass::dispatch_indirect");

    buffer.same_device(device)?;

    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    // The dispatch arguments are three u32 workgroup counts; the bytes we
    // read from the indirect buffer must be initialized.
    let stride = 3 * 4;
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

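    // If the device has indirect-dispatch validation enabled, first run a
    // tiny internal compute dispatch that copies the arguments into a
    // dedicated buffer (the validation pipeline guards against out-of-limit
    // workgroup counts), then dispatch indirectly from that buffer.
    // Otherwise, dispatch directly from the caller's buffer.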
    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.pass.base.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.pass.base.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        // Transition the source buffer for reading by the validation shader.
        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

        // Restore the state that the validation dispatch clobbered: the
        // pipeline, its push constants, and the bind groups.
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.pass.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        // Transition the destination buffer back for indirect consumption.
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_bindings(Some(&buffer), false)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        state.flush_bindings(Some(&buffer), true)?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}

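// Recording methods for `ComputePass`. Most validation is deferred: failures
// encountered here are stored in the pass (via `pass_try!`) and surfaced
// when the pass or its parent encoder is finished.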
impl Global {
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        let base = pass_base!(pass, scope);

        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            // Re-setting the same bind group with the same offsets is a no-op.
            return Ok(());
        }

        let mut bind_group = None;
        if bind_group_id.is_some() {
            let bind_group_id = bind_group_id.unwrap();

            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    pub fn compute_pass_set_push_constants(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetPushConstant;
        let base = pass_base!(pass, scope);

        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantOffsetAlignment),
            );
        }

        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantSizeAlignment),
            )
        }
        let value_offset = pass_try!(
            base,
            scope,
            base.push_constant_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
        );

        // Store the data as native-endian u32 words alongside the commands.
        base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}