1use thiserror::Error;
2use wgt::{
3 error::{ErrorType, WebGpuError},
4 BufferAddress, DynamicOffset,
5};
6
7use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
8use core::{fmt, str};
9
10use crate::command::{
11 pass, CommandEncoder, DebugGroupError, EncoderStateError, PassStateError, TimestampWritesError,
12};
13use crate::resource::DestroyedResourceError;
14use crate::{binding_model::BindError, resource::RawResourceAccess};
15use crate::{
16 binding_model::{LateMinBufferBindingSizeMismatch, PushConstantUploadError},
17 command::{
18 bind::{Binder, BinderError},
19 compute_command::ArcComputeCommand,
20 end_pipeline_statistics_query,
21 memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
22 pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
23 BasePass, BindGroupStateChange, CommandEncoderError, MapPassErr, PassErrorScope,
24 PassTimestampWrites, QueryUseError, StateChange,
25 },
26 device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
27 global::Global,
28 hal_label, id,
29 init_tracker::MemoryInitKind,
30 pipeline::ComputePipeline,
31 resource::{
32 self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
33 },
34 track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex},
35 Label,
36};
37
38pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;
39
/// A recorded compute pass: commands are buffered here and only validated and
/// encoded when the pass is ended.
pub struct ComputePass {
    // Buffered commands, shared data buffers, and any error captured while
    // recording (surfaced when the pass is ended).
    base: ComputeBasePass,

    // Parent command encoder; `take()`n when the pass is ended, so `None`
    // means the pass has already been consumed.
    parent: Option<Arc<CommandEncoder>>,

    // Timestamp writes with the query set already resolved to an `Arc`.
    timestamp_writes: Option<ArcPassTimestampWrites>,

    // Redundancy trackers: repeated identical set-bind-group / set-pipeline
    // calls are dropped before being recorded.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}
64
65impl ComputePass {
66 fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
68 let ArcComputePassDescriptor {
69 label,
70 timestamp_writes,
71 } = desc;
72
73 Self {
74 base: BasePass::new(&label),
75 parent: Some(parent),
76 timestamp_writes,
77
78 current_bind_groups: BindGroupStateChange::new(),
79 current_pipeline: StateChange::new(),
80 }
81 }
82
83 fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
84 Self {
85 base: BasePass::new_invalid(label, err),
86 parent: Some(parent),
87 timestamp_writes: None,
88 current_bind_groups: BindGroupStateChange::new(),
89 current_pipeline: StateChange::new(),
90 }
91 }
92
93 #[inline]
94 pub fn label(&self) -> Option<&str> {
95 self.base.label.as_deref()
96 }
97}
98
99impl fmt::Debug for ComputePass {
100 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101 match self.parent {
102 Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
103 None => write!(f, "ComputePass {{ parent: None }}"),
104 }
105 }
106}
107
/// Describes a compute pass. Generic over the timestamp-writes type so the
/// same struct serves both the id-based public API and the resolved
/// (`Arc`-based) internal form.
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}
114
/// Internal descriptor form whose timestamp writes hold `Arc`d query sets.
type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;
117
/// Error encountered when validating a dispatch in a compute pass.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    /// A bind group slot required by the current pipeline is missing or
    /// incompatible with the pipeline layout.
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}
132
impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        // Every dispatch failure is a validation error in WebGPU terms.
        ErrorType::Validation
    }
}
138
/// Error encountered while recording or replaying a compute pass, without the
/// command scope (see [`ComputePassError`] for the scoped wrapper).
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4gb of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}
194
/// Error encountered when performing a compute pass, annotated with the
/// command scope (`scope`) in which it occurred.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
204
205impl From<pass::MissingPipeline> for ComputePassErrorInner {
206 fn from(value: pass::MissingPipeline) -> Self {
207 Self::Dispatch(DispatchError::MissingPipeline(value))
208 }
209}
210
211impl<E> MapPassErr<ComputePassError> for E
212where
213 E: Into<ComputePassErrorInner>,
214{
215 fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
216 ComputePassError {
217 scope,
218 inner: self.into(),
219 }
220 }
221}
222
impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        // Wrapping variants delegate to the wrapped error; all remaining
        // pass-local variants are validation errors.
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::PushConstants(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::PushConstantOffsetAlignment
            | ComputePassErrorInner::PushConstantSizeAlignment
            | ComputePassErrorInner::PushConstantOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}
255
/// Mutable state threaded through command replay in `compute_pass_end`.
struct State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> {
    /// Currently bound compute pipeline, if any.
    pipeline: Option<Arc<ComputePipeline>>,

    /// State shared with other pass kinds: binder, trackers, usage scope,
    /// raw encoder, debug-group depth, etc.
    general: pass::BaseState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>,

    /// Pipeline-statistics query opened by `BeginPipelineStatisticsQuery`
    /// and not yet ended.
    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    /// CPU-side shadow of the COMPUTE-stage push constants, so they can be
    /// re-applied after an indirect-validation dispatch clobbers them.
    push_constants: Vec<u32>,

    /// Trackers used between dispatches to compute barriers; merged back into
    /// the encoder-level tracker when the pass ends.
    intermediate_trackers: Tracker,
}
267
impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
    State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
{
    /// Checks that a dispatch may be issued: a pipeline is bound, every bind
    /// group it requires is compatible with its layout, and late-sized buffer
    /// bindings meet the pipeline's minimum sizes.
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.general.binder.check_compatibility(pipeline.as_ref())?;
            self.general.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

    /// Merges resource usage of all active bind groups (plus an optional
    /// indirect buffer) into the pass usage scope, moves it into the
    /// intermediate trackers, and encodes the resulting barriers.
    /// Called before each dispatch.
    fn flush_states(
        &mut self,
        indirect_buffer: Option<TrackerIndex>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        // Merge all bind groups first so that usage conflicts *between*
        // groups are detected before any tracker state is mutated.
        for bind_group in self.general.binder.list_active() {
            unsafe { self.general.scope.merge_bind_group(&bind_group.used)? };
        }

        for bind_group in self.general.binder.list_active() {
            unsafe {
                self.intermediate_trackers
                    .set_and_remove_from_usage_scope_sparse(
                        &mut self.general.scope,
                        &bind_group.used,
                    )
            }
        }

        // The indirect buffer (when present) was merged into the scope by the
        // caller; move it into the trackers the same way.
        unsafe {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.general.scope.buffers,
                    indirect_buffer,
                );
        }

        // Encode the transitions accumulated above.
        CommandEncoder::drain_barriers(
            self.general.raw_encoder,
            &mut self.intermediate_trackers,
            self.general.snatch_guard,
        );
        Ok(())
    }
}
321
322impl Global {
    /// Begins a compute pass on `encoder_id`.
    ///
    /// On most failures the returned pass is still usable as a command sink
    /// but is marked invalid, and the error surfaces when the pass is ended.
    /// Only `Ended`/`Submitted` encoder states are additionally reported
    /// immediately through the second tuple element.
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                // Resolve and validate the timestamp writes (if any) before
                // constructing the pass.
                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Beginning a pass while another is open invalidates the
                // encoder itself, not just the new pass.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // The encoder is already finished; this is the one case that
                // reports an immediate error to the caller.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }
422
    /// Replays a deserialized compute pass (trace/replay support): optionally
    /// re-records it into the trace, resolves raw command ids to `Arc`s, and
    /// ends the pass on `encoder_id`.
    ///
    /// Panics if the encoder cannot begin a pass, if id resolution fails, or
    /// if ending the pass fails — acceptable for the replay tooling this
    /// serves, not for user-facing paths.
    #[doc(hidden)]
    #[cfg(any(feature = "serde", feature = "replay"))]
    pub fn compute_pass_end_with_unresolved_commands(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePass<super::ComputeCommand, core::convert::Infallible>,
        timestamp_writes: Option<&PassTimestampWrites>,
    ) {
        #[cfg(feature = "trace")]
        {
            let cmd_enc = self.hub.command_encoders.get(encoder_id);
            let mut cmd_buf_data = cmd_enc.data.lock();
            let cmd_buf_data = cmd_buf_data.get_inner();

            if let Some(ref mut list) = cmd_buf_data.commands {
                list.push(crate::device::trace::Command::RunComputePass {
                    base: BasePass {
                        label: base.label.clone(),
                        error: None,
                        commands: base.commands.clone(),
                        dynamic_offsets: base.dynamic_offsets.clone(),
                        string_data: base.string_data.clone(),
                        push_constant_data: base.push_constant_data.clone(),
                    },
                    timestamp_writes: timestamp_writes.cloned(),
                });
            }
        }

        let BasePass {
            label,
            error: _,
            commands,
            dynamic_offsets,
            string_data,
            push_constant_data,
        } = base;

        let (mut compute_pass, encoder_error) = self.command_encoder_begin_compute_pass(
            encoder_id,
            &ComputePassDescriptor {
                label: label.as_deref().map(Cow::Borrowed),
                timestamp_writes: timestamp_writes.cloned(),
            },
        );
        if let Some(err) = encoder_error {
            panic!("{:?}", err);
        };

        // Swap in the resolved commands; the pass created above starts empty.
        compute_pass.base = BasePass {
            label,
            error: None,
            commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)
                .unwrap(),
            dynamic_offsets,
            string_data,
            push_constant_data,
        };

        self.compute_pass_end(&mut compute_pass).unwrap();
    }
489
    /// Ends the compute pass, validating and encoding its buffered commands
    /// into the parent command encoder, then unlocking the encoder.
    ///
    /// Errors recorded while building the pass invalidate the encoder here
    /// rather than being returned; only "pass/encoder already ended" states
    /// surface as `EncoderStateError`.
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        let pass_scope = PassErrorScope::Pass;
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        // Taking `parent` marks the pass as consumed; ending twice errors.
        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        if let Some(err) = pass.base.error.take() {
            if matches!(
                err,
                ComputePassError {
                    inner: ComputePassErrorInner::EncoderState(EncoderStateError::Ended),
                    scope: _,
                }
            ) {
                // The error stems from the encoder already being finished;
                // report that state directly instead of invalidating.
                return Err(EncoderStateError::Ended);
            } else {
                // Recording-time errors poison the encoder but are not
                // surfaced from this function (WebGPU error model).
                cmd_buf_data.invalidate(err);
                return Ok(());
            }
        }

        cmd_buf_data.unlock_and_record(|cmd_buf_data| -> Result<(), ComputePassError> {
            let device = &cmd_enc.device;
            device.check_is_valid().map_pass_err(pass_scope)?;

            let base = &mut pass.base;

            let encoder = &mut cmd_buf_data.encoder;

            // Close the currently open raw encoder: a "pre pass" command
            // buffer carrying barriers is later swapped in *before* the one
            // opened here (see `close_and_swap` below).
            encoder.close_if_open().map_pass_err(pass_scope)?;
            let raw_encoder = encoder
                .open_pass(base.label.as_deref())
                .map_pass_err(pass_scope)?;

            let snatch_guard = device.snatchable_lock.read();

            let mut state = State {
                pipeline: None,

                general: pass::BaseState {
                    device,
                    raw_encoder,
                    tracker: &mut cmd_buf_data.trackers,
                    buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
                    texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
                    as_actions: &mut cmd_buf_data.as_actions,
                    binder: Binder::new(),
                    temp_offsets: Vec::new(),
                    dynamic_offset_count: 0,

                    pending_discard_init_fixups: SurfacesInDiscardState::new(),

                    snatch_guard: &snatch_guard,
                    scope: device.new_usage_scope(),

                    debug_scope_depth: 0,
                    string_offset: 0,
                },
                active_query: None,

                push_constants: Vec::new(),

                intermediate_trackers: Tracker::new(),
            };

            let indices = &state.general.device.tracker_indices;
            state
                .general
                .tracker
                .buffers
                .set_size(indices.buffers.size());
            state
                .general
                .tracker
                .textures
                .set_size(indices.textures.size());

            let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
                if let Some(tw) = pass.timestamp_writes.take() {
                    tw.query_set
                        .same_device_as(cmd_enc.as_ref())
                        .map_pass_err(pass_scope)?;

                    let query_set = state.general.tracker.query_sets.insert_single(tw.query_set);

                    // Reset every query index the pass may write: the span
                    // between both indices when both are set, otherwise just
                    // the single one that is.
                    let range = if let (Some(index_a), Some(index_b)) =
                        (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
                    {
                        Some(index_a.min(index_b)..index_a.max(index_b) + 1)
                    } else {
                        tw.beginning_of_pass_write_index
                            .or(tw.end_of_pass_write_index)
                            .map(|i| i..i + 1)
                    };
                    if let Some(range) = range {
                        unsafe {
                            state
                                .general
                                .raw_encoder
                                .reset_queries(query_set.raw(), range);
                        }
                    }

                    Some(hal::PassTimestampWrites {
                        query_set: query_set.raw(),
                        beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                        end_of_pass_write_index: tw.end_of_pass_write_index,
                    })
                } else {
                    None
                };

            let hal_desc = hal::ComputePassDescriptor {
                label: hal_label(base.label.as_deref(), device.instance_flags),
                timestamp_writes,
            };

            unsafe {
                state.general.raw_encoder.begin_compute_pass(&hal_desc);
            }

            // Replay the buffered commands, validating each against the
            // current pass state.
            for command in base.commands.drain(..) {
                match command {
                    ArcComputeCommand::SetBindGroup {
                        index,
                        num_dynamic_offsets,
                        bind_group,
                    } => {
                        let scope = PassErrorScope::SetBindGroup;
                        pass::set_bind_group::<ComputePassErrorInner>(
                            &mut state.general,
                            cmd_enc.as_ref(),
                            &base.dynamic_offsets,
                            index,
                            num_dynamic_offsets,
                            bind_group,
                            false,
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::SetPipeline(pipeline) => {
                        let scope = PassErrorScope::SetPipelineCompute;
                        set_pipeline(&mut state, cmd_enc.as_ref(), pipeline).map_pass_err(scope)?;
                    }
                    ArcComputeCommand::SetPushConstant {
                        offset,
                        size_bytes,
                        values_offset,
                    } => {
                        let scope = PassErrorScope::SetPushConstant;
                        pass::set_push_constant::<ComputePassErrorInner, _>(
                            &mut state.general,
                            &base.push_constant_data,
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            size_bytes,
                            Some(values_offset),
                            |data_slice| {
                                // Mirror the write into the CPU-side shadow so
                                // indirect validation can restore it later.
                                let offset_in_elements =
                                    (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                                let size_in_elements =
                                    (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                                state.push_constants[offset_in_elements..][..size_in_elements]
                                    .copy_from_slice(data_slice);
                            },
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::Dispatch(groups) => {
                        let scope = PassErrorScope::Dispatch { indirect: false };
                        dispatch(&mut state, groups).map_pass_err(scope)?;
                    }
                    ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                        let scope = PassErrorScope::Dispatch { indirect: true };
                        dispatch_indirect(&mut state, cmd_enc.as_ref(), buffer, offset)
                            .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::PushDebugGroup { color: _, len } => {
                        pass::push_debug_group(&mut state.general, &base.string_data, len);
                    }
                    ArcComputeCommand::PopDebugGroup => {
                        let scope = PassErrorScope::PopDebugGroup;
                        pass::pop_debug_group::<ComputePassErrorInner>(&mut state.general)
                            .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                        pass::insert_debug_marker(&mut state.general, &base.string_data, len);
                    }
                    ArcComputeCommand::WriteTimestamp {
                        query_set,
                        query_index,
                    } => {
                        let scope = PassErrorScope::WriteTimestamp;
                        pass::write_timestamp::<ComputePassErrorInner>(
                            &mut state.general,
                            cmd_enc.as_ref(),
                            None,
                            query_set,
                            query_index,
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::BeginPipelineStatisticsQuery {
                        query_set,
                        query_index,
                    } => {
                        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                        validate_and_begin_pipeline_statistics_query(
                            query_set,
                            state.general.raw_encoder,
                            &mut state.general.tracker.query_sets,
                            cmd_enc.as_ref(),
                            query_index,
                            None,
                            &mut state.active_query,
                        )
                        .map_pass_err(scope)?;
                    }
                    ArcComputeCommand::EndPipelineStatisticsQuery => {
                        let scope = PassErrorScope::EndPipelineStatisticsQuery;
                        end_pipeline_statistics_query(
                            state.general.raw_encoder,
                            &mut state.active_query,
                        )
                        .map_pass_err(scope)?;
                    }
                }
            }

            // Every PushDebugGroup must have been matched by a PopDebugGroup.
            if state.general.debug_scope_depth > 0 {
                Err(
                    ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                        .map_pass_err(pass_scope),
                )?;
            }

            unsafe {
                state.general.raw_encoder.end_compute_pass();
            }

            let State {
                general:
                    pass::BaseState {
                        tracker,
                        pending_discard_init_fixups,
                        ..
                    },
                intermediate_trackers,
                ..
            } = state;

            // Close the pass's command buffer, then record memory-init fixups
            // and barriers into a separate buffer that is swapped in *before*
            // the pass.
            encoder.close().map_pass_err(pass_scope)?;

            let transit = encoder
                .open_pass(hal_label(
                    Some("(wgpu internal) Pre Pass"),
                    self.instance.flags,
                ))
                .map_pass_err(pass_scope)?;
            fixup_discarded_surfaces(
                pending_discard_init_fixups.into_iter(),
                transit,
                &mut tracker.textures,
                device,
                &snatch_guard,
            );
            CommandEncoder::insert_barriers_from_tracker(
                transit,
                tracker,
                &intermediate_trackers,
                &snatch_guard,
            );
            encoder.close_and_swap().map_pass_err(pass_scope)?;

            Ok(())
        })
    }
790}
791
/// Binds `pipeline` for subsequent dispatches and rebinds any bind groups and
/// push constants that the new pipeline's layout requires.
fn set_pipeline(
    state: &mut State,
    cmd_enc: &CommandEncoder,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device_as(cmd_enc)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .general
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .general
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    pass::rebind_resources::<ComputePassErrorInner, _>(
        &mut state.general,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            // Resize the CPU-side push-constant shadow (zero-filled) to the
            // COMPUTE-stage range of the new layout; indirect validation
            // re-uploads from this buffer after its own dispatch.
            state.push_constants.clear();
            // `BaseState::rebind_resources` clears the raw push constants, so
            // the shadow must match that cleared state.
            if let Some(push_constant_range) =
                pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                    pcr.stages
                        .contains(wgt::ShaderStages::COMPUTE)
                        .then_some(pcr.range.clone())
                })
            {
                // Range is in bytes; the shadow stores u32 elements.
                let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
                state.push_constants.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}
839
840fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
841 state.is_ready()?;
842
843 state.flush_states(None)?;
844
845 let groups_size_limit = state
846 .general
847 .device
848 .limits
849 .max_compute_workgroups_per_dimension;
850
851 if groups[0] > groups_size_limit
852 || groups[1] > groups_size_limit
853 || groups[2] > groups_size_limit
854 {
855 return Err(ComputePassErrorInner::Dispatch(
856 DispatchError::InvalidGroupSize {
857 current: groups,
858 limit: groups_size_limit,
859 },
860 ));
861 }
862
863 unsafe {
864 state.general.raw_encoder.dispatch(groups);
865 }
866 Ok(())
867}
868
/// Validates pass state and encodes an indirect dispatch whose workgroup
/// counts are read from `buffer` at `offset`.
fn dispatch_indirect(
    state: &mut State,
    cmd_enc: &CommandEncoder,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    buffer.same_device_as(cmd_enc)?;

    state.is_ready()?;

    state
        .general
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
    buffer.check_destroyed(state.general.snatch_guard)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    // The three u32 workgroup counts must come from initialized memory.
    let stride = 3 * 4;
    state.general.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.general.device.indirect_validation {
        // Route the dispatch through the device's indirect-validation
        // pipeline, which processes the args into `params.dst_buffer` and
        // dispatches from there (presumably clamping them to device limits —
        // see the `indirect_validation` module for exact semantics).
        let params =
            indirect_validation
                .dispatch
                .params(&state.general.device.limits, offset, buffer.size);

        unsafe {
            state
                .general
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.general.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.general.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.general.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.general.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        // Transition the user's buffer so the validation shader can read it.
        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.general.snatch_guard));
        unsafe {
            state
                .general
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        // Make the destination buffer writable for the validation shader.
        unsafe {
            state
                .general
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        // Run the validation shader itself.
        unsafe {
            state.general.raw_encoder.dispatch([1, 1, 1]);
        }

        // Restore the user's pipeline, push constants, and bind groups, all
        // of which the validation dispatch just clobbered.
        {
            // `is_ready()` above guarantees a pipeline is bound.
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .general
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.general.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.general.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.general.snatch_guard)?;
                unsafe {
                    state.general.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        // Return the destination buffer to INDIRECT usage for the real
        // dispatch.
        unsafe {
            state
                .general
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_states(None)?;
        unsafe {
            state
                .general
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        // No validation layer: merge the buffer into the usage scope and
        // dispatch straight from it.
        state
            .general
            .scope
            .buffers
            .merge_single(&buffer, wgt::BufferUses::INDIRECT)?;

        use crate::resource::Trackable;
        state.flush_states(Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(state.general.snatch_guard)?;
        unsafe {
            state.general.raw_encoder.dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}
1059
1060impl Global {
1073 pub fn compute_pass_set_bind_group(
1074 &self,
1075 pass: &mut ComputePass,
1076 index: u32,
1077 bind_group_id: Option<id::BindGroupId>,
1078 offsets: &[DynamicOffset],
1079 ) -> Result<(), PassStateError> {
1080 let scope = PassErrorScope::SetBindGroup;
1081
1082 let base = pass_base!(pass, scope);
1086
1087 if pass.current_bind_groups.set_and_check_redundant(
1088 bind_group_id,
1089 index,
1090 &mut base.dynamic_offsets,
1091 offsets,
1092 ) {
1093 return Ok(());
1094 }
1095
1096 let mut bind_group = None;
1097 if bind_group_id.is_some() {
1098 let bind_group_id = bind_group_id.unwrap();
1099
1100 let hub = &self.hub;
1101 bind_group = Some(pass_try!(
1102 base,
1103 scope,
1104 hub.bind_groups.get(bind_group_id).get(),
1105 ));
1106 }
1107
1108 base.commands.push(ArcComputeCommand::SetBindGroup {
1109 index,
1110 num_dynamic_offsets: offsets.len(),
1111 bind_group,
1112 });
1113
1114 Ok(())
1115 }
1116
    /// Records a `SetPipeline` command; redundant rebinds of the same
    /// pipeline are dropped without recording.
    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        // NOTE(review): the redundancy tracker is updated *before* the
        // pass_base! early-out, apparently so it stays current even when the
        // pass is invalid — confirm before reordering.
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        // An invalid id poisons the pass via pass_try! rather than returning
        // an error here.
        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }
1141
    /// Records a `SetPushConstant` command, staging `data` (as native-endian
    /// u32 words) into the pass's shared push-constant buffer.
    ///
    /// Alignment violations poison the pass via `pass_try!` instead of
    /// returning an error here.
    pub fn compute_pass_set_push_constants(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetPushConstant;
        let base = pass_base!(pass, scope);

        // PUSH_CONSTANT_ALIGNMENT is a power of two, so `& (align - 1)`
        // tests alignment.
        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantOffsetAlignment),
            );
        }

        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantSizeAlignment),
            )
        }
        // Element index into `push_constant_data` where this write's values
        // begin; overflow of the index type means we ran out of space.
        let value_offset = pass_try!(
            base,
            scope,
            base.push_constant_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
        );

        // Pack the byte data into native-endian u32 words.
        base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }
1188
1189 pub fn compute_pass_dispatch_workgroups(
1190 &self,
1191 pass: &mut ComputePass,
1192 groups_x: u32,
1193 groups_y: u32,
1194 groups_z: u32,
1195 ) -> Result<(), PassStateError> {
1196 let scope = PassErrorScope::Dispatch { indirect: false };
1197
1198 pass_base!(pass, scope)
1199 .commands
1200 .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
1201
1202 Ok(())
1203 }
1204
1205 pub fn compute_pass_dispatch_workgroups_indirect(
1206 &self,
1207 pass: &mut ComputePass,
1208 buffer_id: id::BufferId,
1209 offset: BufferAddress,
1210 ) -> Result<(), PassStateError> {
1211 let hub = &self.hub;
1212 let scope = PassErrorScope::Dispatch { indirect: true };
1213 let base = pass_base!(pass, scope);
1214
1215 let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());
1216
1217 base.commands
1218 .push(ArcComputeCommand::DispatchIndirect { buffer, offset });
1219
1220 Ok(())
1221 }
1222
1223 pub fn compute_pass_push_debug_group(
1224 &self,
1225 pass: &mut ComputePass,
1226 label: &str,
1227 color: u32,
1228 ) -> Result<(), PassStateError> {
1229 let base = pass_base!(pass, PassErrorScope::PushDebugGroup);
1230
1231 let bytes = label.as_bytes();
1232 base.string_data.extend_from_slice(bytes);
1233
1234 base.commands.push(ArcComputeCommand::PushDebugGroup {
1235 color,
1236 len: bytes.len(),
1237 });
1238
1239 Ok(())
1240 }
1241
1242 pub fn compute_pass_pop_debug_group(
1243 &self,
1244 pass: &mut ComputePass,
1245 ) -> Result<(), PassStateError> {
1246 let base = pass_base!(pass, PassErrorScope::PopDebugGroup);
1247
1248 base.commands.push(ArcComputeCommand::PopDebugGroup);
1249
1250 Ok(())
1251 }
1252
1253 pub fn compute_pass_insert_debug_marker(
1254 &self,
1255 pass: &mut ComputePass,
1256 label: &str,
1257 color: u32,
1258 ) -> Result<(), PassStateError> {
1259 let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);
1260
1261 let bytes = label.as_bytes();
1262 base.string_data.extend_from_slice(bytes);
1263
1264 base.commands.push(ArcComputeCommand::InsertDebugMarker {
1265 color,
1266 len: bytes.len(),
1267 });
1268
1269 Ok(())
1270 }
1271
1272 pub fn compute_pass_write_timestamp(
1273 &self,
1274 pass: &mut ComputePass,
1275 query_set_id: id::QuerySetId,
1276 query_index: u32,
1277 ) -> Result<(), PassStateError> {
1278 let scope = PassErrorScope::WriteTimestamp;
1279 let base = pass_base!(pass, scope);
1280
1281 let hub = &self.hub;
1282 let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
1283
1284 base.commands.push(ArcComputeCommand::WriteTimestamp {
1285 query_set,
1286 query_index,
1287 });
1288
1289 Ok(())
1290 }
1291
1292 pub fn compute_pass_begin_pipeline_statistics_query(
1293 &self,
1294 pass: &mut ComputePass,
1295 query_set_id: id::QuerySetId,
1296 query_index: u32,
1297 ) -> Result<(), PassStateError> {
1298 let scope = PassErrorScope::BeginPipelineStatisticsQuery;
1299 let base = pass_base!(pass, scope);
1300
1301 let hub = &self.hub;
1302 let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
1303
1304 base.commands
1305 .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
1306 query_set,
1307 query_index,
1308 });
1309
1310 Ok(())
1311 }
1312
1313 pub fn compute_pass_end_pipeline_statistics_query(
1314 &self,
1315 pass: &mut ComputePass,
1316 ) -> Result<(), PassStateError> {
1317 pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
1318 .commands
1319 .push(ArcComputeCommand::EndPipelineStatisticsQuery);
1320
1321 Ok(())
1322 }
1323}