1use thiserror::Error;
2use wgt::{
3 error::{ErrorType, WebGpuError},
4 BufferAddress, DynamicOffset,
5};
6
7use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
8use core::{convert::Infallible, fmt, str};
9
10use crate::{
11 api_log, binding_model::BindError, command::pass::flush_bindings_helper,
12 resource::RawResourceAccess,
13};
14use crate::{
15 binding_model::{LateMinBufferBindingSizeMismatch, PushConstantUploadError},
16 command::{
17 bind::{Binder, BinderError},
18 compute_command::ArcComputeCommand,
19 end_pipeline_statistics_query,
20 memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
21 pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
22 BasePass, BindGroupStateChange, CommandEncoderError, MapPassErr, PassErrorScope,
23 PassTimestampWrites, QueryUseError, StateChange,
24 },
25 device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
26 global::Global,
27 hal_label, id,
28 init_tracker::MemoryInitKind,
29 pipeline::ComputePipeline,
30 resource::{
31 self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
32 },
33 track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex},
34 Label,
35};
36use crate::{command::InnerCommandEncoder, resource::DestroyedResourceError};
37use crate::{
38 command::{
39 encoder::EncodingState, pass, ArcCommand, CommandEncoder, DebugGroupError,
40 EncoderStateError, PassStateError, TimestampWritesError,
41 },
42 device::Device,
43};
44
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;

/// A recorded compute pass.
///
/// Created by [`Global::command_encoder_begin_compute_pass`]; commands are
/// accumulated in `base` and replayed into the parent encoder when
/// [`Global::compute_pass_end`] is called.
pub struct ComputePass {
    base: ComputeBasePass,

    // Parent command encoder. `Some` while the pass is open; taken (set to
    // `None`) by `compute_pass_end`, so a second end call fails cleanly.
    parent: Option<Arc<CommandEncoder>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Redundant-state tracking used at recording time to drop no-op
    // `SetBindGroup` / `SetPipeline` commands.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}
71
72impl ComputePass {
73 fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
75 let ArcComputePassDescriptor {
76 label,
77 timestamp_writes,
78 } = desc;
79
80 Self {
81 base: BasePass::new(&label),
82 parent: Some(parent),
83 timestamp_writes,
84
85 current_bind_groups: BindGroupStateChange::new(),
86 current_pipeline: StateChange::new(),
87 }
88 }
89
90 fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
91 Self {
92 base: BasePass::new_invalid(label, err),
93 parent: Some(parent),
94 timestamp_writes: None,
95 current_bind_groups: BindGroupStateChange::new(),
96 current_pipeline: StateChange::new(),
97 }
98 }
99
100 #[inline]
101 pub fn label(&self) -> Option<&str> {
102 self.base.label.as_deref()
103 }
104}
105
106impl fmt::Debug for ComputePass {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 match self.parent {
109 Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
110 None => write!(f, "ComputePass {{ parent: None }}"),
111 }
112 }
113}
114
/// Describes a compute pass to be begun on a command encoder.
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    /// Where (query set / indices) timestamps are written for this pass.
    pub timestamp_writes: Option<PTW>,
}

// Descriptor variant whose timestamp-writes query set has been resolved to
// an `Arc`-backed handle.
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;
124
/// Errors that can occur when validating a `Dispatch`/`DispatchIndirect`
/// command (see [`State::is_ready`] and [`dispatch`]).
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}
139
impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        // Every variant of `DispatchError` is a validation failure.
        ErrorType::Validation
    }
}
145
/// The specific failure underlying a [`ComputePassError`]; most variants
/// wrap errors from validation helpers shared with other pass types.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4gb of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}
201
/// Error encountered while recording or encoding a compute pass, tagged with
/// the command scope in which it occurred.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
211
// A missing pipeline is reported as a dispatch error, matching where the
// condition is actually detected (`State::is_ready`).
impl From<pass::MissingPipeline> for ComputePassErrorInner {
    fn from(value: pass::MissingPipeline) -> Self {
        Self::Dispatch(DispatchError::MissingPipeline(value))
    }
}
217
// Allows `err.map_pass_err(scope)` on anything convertible into
// `ComputePassErrorInner`, attaching the command scope to the error.
impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}
229
impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        // Delegate to the wrapped error where it implements `WebGpuError`;
        // the remaining variants are all plain validation failures.
        let Self { scope: _, inner } = self;
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::PushConstants(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::PushConstantOffsetAlignment
            | ComputePassErrorInner::PushConstantSizeAlignment
            | ComputePassErrorInner::PushConstantOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}
262
/// Mutable state carried through `encode_compute_pass` while replaying the
/// recorded commands into the HAL encoder.
struct State<'scope, 'snatch_guard, 'cmd_enc> {
    // The currently bound compute pipeline, if any.
    pipeline: Option<Arc<ComputePipeline>>,

    // State shared with other pass kinds (binder, encoder, usage scope, ...).
    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc>,

    // Pipeline-statistics query currently in flight, with its index.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    // CPU-side copy of the push-constant words for the bound pipeline; used
    // to restore push constants after the indirect-validation dispatch.
    push_constants: Vec<u32>,

    // Tracks resource states changed inside the pass so barriers can be
    // emitted (within the pass and in the pre-pass transition encoder).
    intermediate_trackers: Tracker,
}
274
impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
    /// Check that a dispatch is legal: a pipeline is bound, the bound bind
    /// groups are compatible with its layout, and late-sized buffer bindings
    /// are large enough.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Merge the usages of all active bind groups (plus an optional indirect
    /// buffer) into a fresh usage scope, record the resulting state changes
    /// in `intermediate_trackers`, and emit the pending barriers.
    ///
    /// `indirect_buffer_index_if_not_validating` is only `Some` when indirect
    /// validation is disabled and the raw buffer will be dispatched from
    /// directly.
    fn flush_bindings(
        &mut self,
        indirect_buffer: Option<&Arc<Buffer>>,
        indirect_buffer_index_if_not_validating: Option<TrackerIndex>,
    ) -> Result<(), ComputePassErrorInner> {
        let mut scope = self.pass.base.device.new_usage_scope();

        // Usage-conflict detection across all currently bound groups.
        for bind_group in self.pass.binder.list_active() {
            unsafe { scope.merge_bind_group(&bind_group.used)? };
        }

        if let Some(buffer) = indirect_buffer {
            scope
                .buffers
                .merge_single(buffer, wgt::BufferUses::INDIRECT)?;
        }

        self.intermediate_trackers
            .buffers
            .set_multiple(&mut scope.buffers, indirect_buffer_index_if_not_validating);

        flush_bindings_helper(&mut self.pass, |bind_group| {
            self.intermediate_trackers
                .set_from_bind_group(&mut scope, &bind_group.used)
        })?;

        // Emit the transitions accumulated above into the raw encoder.
        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}
357
impl Global {
    /// Begin a compute pass on the given encoder.
    ///
    /// Always returns a `ComputePass`; on failure the pass is created
    /// invalid and carries the error, which is surfaced when the pass is
    /// ended. The `Option<CommandEncoderError>` is `Some` only for errors
    /// that must be raised immediately (encoder already ended or submitted).
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                // Release the data lock before doing any further validation.
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                // Resolve the timestamp-writes query set (if any) up front.
                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Beginning a pass while another pass is open invalidates
                // the encoder itself.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // These are immediate errors: report them to the caller now.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

    /// End the pass, transferring its recorded commands (or its stored
    /// error) back into the parent encoder.
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        // `parent` is taken, so ending the same pass twice yields `Ended`.
        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        cmd_buf_data.unlock_encoder()?;

        let base = pass.base.take();

        // A pass that already failed with `Ended` reports that directly
        // rather than being pushed onto the encoder.
        if matches!(
            base,
            Err(ComputePassError {
                inner: ComputePassErrorInner::EncoderState(EncoderStateError::Ended),
                scope: _,
            })
        ) {
            return Err(EncoderStateError::Ended);
        }

        cmd_buf_data.push_with(|| -> Result<_, ComputePassError> {
            Ok(ArcCommand::RunComputePass {
                pass: base?,
                timestamp_writes: pass.timestamp_writes.take(),
            })
        })
    }
}
498
/// Replay a recorded compute pass into the parent encoder's HAL encoder.
///
/// Opens a dedicated pass encoder, validates and encodes every recorded
/// command, then opens a "Pre Pass" transition encoder that is swapped in
/// before the pass to fix up discarded surfaces and issue the resource
/// barriers the pass requires.
pub(super) fn encode_compute_pass(
    parent_state: &mut EncodingState<InnerCommandEncoder>,
    mut base: BasePass<ArcComputeCommand, Infallible>,
    mut timestamp_writes: Option<ArcPassTimestampWrites>,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = parent_state.device;

    // Close the current encoder and open a fresh one for the pass itself;
    // the transition encoder opened at the end is swapped in before it.
    parent_state
        .raw_encoder
        .close_if_open()
        .map_pass_err(pass_scope)?;
    let raw_encoder = parent_state
        .raw_encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: parent_state.tracker,
                buffer_memory_init_actions: parent_state.buffer_memory_init_actions,
                texture_memory_actions: parent_state.texture_memory_actions,
                as_actions: parent_state.as_actions,
                temp_resources: parent_state.temp_resources,
                indirect_draw_validation_resources: parent_state.indirect_draw_validation_resources,
                snatch_guard: parent_state.snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,
            pending_discard_init_fixups: SurfacesInDiscardState::new(),
            scope: device.new_usage_scope(),
            string_offset: 0,
        },
        active_query: None,

        push_constants: Vec::new(),

        intermediate_trackers: Tracker::new(),
    };

    // Size the pass trackers to the device's current index ranges.
    let indices = &device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    // Resolve timestamp writes into HAL form, resetting the covered query
    // range first so stale values are never resolved.
    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = timestamp_writes.take() {
            tw.query_set.same_device(device).map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    // Validate and encode every recorded command, in order.
    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, device, pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetPushConstant;
                pass::set_push_constant::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.push_constant_data,
                    wgt::ShaderStages::COMPUTE,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    // Mirror the data into the CPU-side copy so it can be
                    // re-applied after an indirect-validation dispatch.
                    |data_slice| {
                        let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        state.push_constants[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, device, buffer, offset).map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    device,
                    None, query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    device,
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    // Every pushed debug group must have been popped within the pass.
    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass: pass::PassState {
            pending_discard_init_fixups,
            ..
        },
        intermediate_trackers,
        ..
    } = state;

    parent_state.raw_encoder.close().map_pass_err(pass_scope)?;

    // Encode the pre-pass work (discarded-surface init, resource barriers)
    // and swap it in so it executes before the pass.
    let transit = parent_state
        .raw_encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut parent_state.tracker.textures,
        device,
        parent_state.snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        parent_state.tracker,
        &intermediate_trackers,
        parent_state.snatch_guard,
    );
    parent_state
        .raw_encoder
        .close_and_swap()
        .map_pass_err(pass_scope)?;

    Ok(())
}
774
/// Encode a `SetPipeline` command: bind the pipeline on the HAL encoder and
/// rebind/invalidate bind groups as required by the new layout.
fn set_pipeline(
    state: &mut State,
    device: &Arc<Device>,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device(device)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    pass::change_pipeline_layout::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        // Called when the pipeline layout changed: resize the CPU-side
        // push-constant copy to the COMPUTE-stage range of the new layout.
        || {
            state.push_constants.clear();
            if let Some(push_constant_range) =
                pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                    pcr.stages
                        .contains(wgt::ShaderStages::COMPUTE)
                        .then_some(pcr.range.clone())
                })
            {
                // Note that the sizes of the push constant ranges are
                // in bytes; convert to a count of u32 words.
                let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
                state.push_constants.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}
824
825fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
826 api_log!("ComputePass::dispatch {groups:?}");
827
828 state.is_ready()?;
829
830 state.flush_bindings(None, None)?;
831
832 let groups_size_limit = state
833 .pass
834 .base
835 .device
836 .limits
837 .max_compute_workgroups_per_dimension;
838
839 if groups[0] > groups_size_limit
840 || groups[1] > groups_size_limit
841 || groups[2] > groups_size_limit
842 {
843 return Err(ComputePassErrorInner::Dispatch(
844 DispatchError::InvalidGroupSize {
845 current: groups,
846 limit: groups_size_limit,
847 },
848 ));
849 }
850
851 unsafe {
852 state.pass.base.raw_encoder.dispatch(groups);
853 }
854 Ok(())
855}
856
857fn dispatch_indirect(
858 state: &mut State,
859 device: &Arc<Device>,
860 buffer: Arc<Buffer>,
861 offset: u64,
862) -> Result<(), ComputePassErrorInner> {
863 api_log!("ComputePass::dispatch_indirect");
864
865 buffer.same_device(device)?;
866
867 state.is_ready()?;
868
869 state
870 .pass
871 .base
872 .device
873 .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;
874
875 buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
876 buffer.check_destroyed(state.pass.base.snatch_guard)?;
877
878 if offset % 4 != 0 {
879 return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
880 }
881
882 let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
883 if end_offset > buffer.size {
884 return Err(ComputePassErrorInner::IndirectBufferOverrun {
885 offset,
886 end_offset,
887 buffer_size: buffer.size,
888 });
889 }
890
891 let stride = 3 * 4; state.pass.base.buffer_memory_init_actions.extend(
893 buffer.initialization_status.read().create_action(
894 &buffer,
895 offset..(offset + stride),
896 MemoryInitKind::NeedsInitializedMemory,
897 ),
898 );
899
900 if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
901 let params = indirect_validation.dispatch.params(
902 &state.pass.base.device.limits,
903 offset,
904 buffer.size,
905 );
906
907 unsafe {
908 state
909 .pass
910 .base
911 .raw_encoder
912 .set_compute_pipeline(params.pipeline);
913 }
914
915 unsafe {
916 state.pass.base.raw_encoder.set_push_constants(
917 params.pipeline_layout,
918 wgt::ShaderStages::COMPUTE,
919 0,
920 &[params.offset_remainder as u32 / 4],
921 );
922 }
923
924 unsafe {
925 state.pass.base.raw_encoder.set_bind_group(
926 params.pipeline_layout,
927 0,
928 Some(params.dst_bind_group),
929 &[],
930 );
931 }
932 unsafe {
933 state.pass.base.raw_encoder.set_bind_group(
934 params.pipeline_layout,
935 1,
936 Some(
937 buffer
938 .indirect_validation_bind_groups
939 .get(state.pass.base.snatch_guard)
940 .unwrap()
941 .dispatch
942 .as_ref(),
943 ),
944 &[params.aligned_offset as u32],
945 );
946 }
947
948 let src_transition = state
949 .intermediate_trackers
950 .buffers
951 .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
952 let src_barrier = src_transition
953 .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
954 unsafe {
955 state
956 .pass
957 .base
958 .raw_encoder
959 .transition_buffers(src_barrier.as_slice());
960 }
961
962 unsafe {
963 state
964 .pass
965 .base
966 .raw_encoder
967 .transition_buffers(&[hal::BufferBarrier {
968 buffer: params.dst_buffer,
969 usage: hal::StateTransition {
970 from: wgt::BufferUses::INDIRECT,
971 to: wgt::BufferUses::STORAGE_READ_WRITE,
972 },
973 }]);
974 }
975
976 unsafe {
977 state.pass.base.raw_encoder.dispatch([1, 1, 1]);
978 }
979
980 {
982 let pipeline = state.pipeline.as_ref().unwrap();
983
984 unsafe {
985 state
986 .pass
987 .base
988 .raw_encoder
989 .set_compute_pipeline(pipeline.raw());
990 }
991
992 if !state.push_constants.is_empty() {
993 unsafe {
994 state.pass.base.raw_encoder.set_push_constants(
995 pipeline.layout.raw(),
996 wgt::ShaderStages::COMPUTE,
997 0,
998 &state.push_constants,
999 );
1000 }
1001 }
1002
1003 for (i, e) in state.pass.binder.list_valid() {
1004 let group = e.group.as_ref().unwrap();
1005 let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
1006 unsafe {
1007 state.pass.base.raw_encoder.set_bind_group(
1008 pipeline.layout.raw(),
1009 i as u32,
1010 Some(raw_bg),
1011 &e.dynamic_offsets,
1012 );
1013 }
1014 }
1015 }
1016
1017 unsafe {
1018 state
1019 .pass
1020 .base
1021 .raw_encoder
1022 .transition_buffers(&[hal::BufferBarrier {
1023 buffer: params.dst_buffer,
1024 usage: hal::StateTransition {
1025 from: wgt::BufferUses::STORAGE_READ_WRITE,
1026 to: wgt::BufferUses::INDIRECT,
1027 },
1028 }]);
1029 }
1030
1031 state.flush_bindings(Some(&buffer), None)?;
1032 unsafe {
1033 state
1034 .pass
1035 .base
1036 .raw_encoder
1037 .dispatch_indirect(params.dst_buffer, 0);
1038 }
1039 } else {
1040 use crate::resource::Trackable;
1041 state.flush_bindings(Some(&buffer), Some(buffer.tracker_index()))?;
1042
1043 let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
1044 unsafe {
1045 state
1046 .pass
1047 .base
1048 .raw_encoder
1049 .dispatch_indirect(buf_raw, offset);
1050 }
1051 }
1052
1053 Ok(())
1054}
1055
impl Global {
    /// Record a `SetBindGroup` command in `pass`, skipping redundant binds.
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        let base = pass_base!(pass, scope);

        // Skip the command if it rebinds the same group with the same
        // dynamic offsets.
        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            return Ok(());
        }

        let mut bind_group = None;
        if bind_group_id.is_some() {
            let bind_group_id = bind_group_id.unwrap();

            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    /// Record a `SetPipeline` command, skipping redundant re-binds.
    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    /// Record a `SetPushConstant` command, validating 4-byte alignment of
    /// `offset` and `data.len()` and appending `data` to the pass's
    /// push-constant staging buffer.
    pub fn compute_pass_set_push_constants(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetPushConstant;
        let base = pass_base!(pass, scope);

        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantOffsetAlignment),
            );
        }

        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantSizeAlignment),
            )
        }
        // Index (in u32 words) where this data starts in `push_constant_data`.
        let value_offset = pass_try!(
            base,
            scope,
            base.push_constant_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
        );

        base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    /// Record a direct `Dispatch` command.
    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    /// Record a `DispatchIndirect` command reading args from `buffer_id`.
    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    /// Record a `PushDebugGroup` command; the label bytes are stored in the
    /// pass's shared string buffer.
    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    /// Record a `PopDebugGroup` command.
    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    /// Record an `InsertDebugMarker` command; the label bytes are stored in
    /// the pass's shared string buffer.
    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    /// Record a `WriteTimestamp` command.
    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    /// Record a `BeginPipelineStatisticsQuery` command.
    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    /// Record an `EndPipelineStatisticsQuery` command.
    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}