use thiserror::Error;
use wgt::{
    error::{ErrorType, WebGpuError},
    BufferAddress, DynamicOffset,
};

use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{fmt, str};

use crate::command::{
    encoder::EncodingState, pass, CommandBufferMutable, CommandEncoder, DebugGroupError,
    EncoderStateError, PassStateError, TimestampWritesError,
};
use crate::resource::DestroyedResourceError;
use crate::{binding_model::BindError, resource::RawResourceAccess};
use crate::{
    binding_model::{LateMinBufferBindingSizeMismatch, PushConstantUploadError},
    command::{
        bind::{Binder, BinderError},
        compute_command::ArcComputeCommand,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
        BasePass, BindGroupStateChange, CommandEncoderError, MapPassErr, PassErrorScope,
        PassTimestampWrites, QueryUseError, StateChange,
    },
    device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
    global::Global,
    hal_label, id,
    init_tracker::MemoryInitKind,
    pipeline::ComputePipeline,
    resource::{
        self, Buffer, InvalidResourceError, Labeled, MissingBufferUsageError, ParentDevice,
    },
    track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex},
    Label,
};

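/// The [`BasePass`] specialization used to record compute pass commands.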
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;

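/// State for an in-progress compute pass, created by
/// [`Global::command_encoder_begin_compute_pass`] and consumed by
/// [`Global::compute_pass_end`].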
pub struct ComputePass {
    /// All pass data and recorded commands are stored here.
    base: ComputeBasePass,

    /// The parent command encoder. Taken (set to `None`) when the pass is
    /// ended; any further operation on the pass then reports an error.
    parent: Option<Arc<CommandEncoder>>,

    timestamp_writes: Option<ArcPassTimestampWrites>,

    // State-change deduplication, so redundant `SetBindGroup`/`SetPipeline`
    // commands can be dropped at record time.
    current_bind_groups: BindGroupStateChange,
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    fn new(parent: Arc<CommandEncoder>, desc: ArcComputePassDescriptor) -> Self {
        let ArcComputePassDescriptor {
            label,
            timestamp_writes,
        } = desc;

        Self {
            base: BasePass::new(&label),
            parent: Some(parent),
            timestamp_writes,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    fn new_invalid(parent: Arc<CommandEncoder>, label: &Label, err: ComputePassError) -> Self {
        Self {
            base: BasePass::new_invalid(label, err),
            parent: Some(parent),
            timestamp_writes: None,
            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.base.label.as_deref()
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.parent {
            Some(ref cmd_enc) => write!(f, "ComputePass {{ parent: {} }}", cmd_enc.error_ident()),
            None => write!(f, "ComputePass {{ parent: None }}"),
        }
    }
}

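/// Describes a compute pass: an optional debug label plus optional
/// timestamp writes.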
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a, PTW = PassTimestampWrites> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<PTW>,
}

type ArcComputePassDescriptor<'a> = ComputePassDescriptor<'a, ArcPassTimestampWrites>;

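/// Errors that can occur when validating a dispatch command.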
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline(pass::MissingPipeline),
    #[error(transparent)]
    IncompatibleBindGroup(#[from] Box<BinderError>),
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

impl WebGpuError for DispatchError {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}

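/// Error encountered while encoding a compute pass.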
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    #[error("Parent encoder is invalid")]
    InvalidParentEncoder,
    #[error(transparent)]
    DebugGroupError(#[from] DebugGroupError),
    #[error(transparent)]
    BindGroupIndexOutOfRange(#[from] pass::BindGroupIndexOutOfRange),
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    #[error("Indirect buffer offset {0:?} is not a multiple of 4")]
    UnalignedIndirectBufferOffset(BufferAddress),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error(transparent)]
    ResourceUsageCompatibility(#[from] ResourceUsageCompatibilityError),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error("Push constant offset must be aligned to 4 bytes")]
    PushConstantOffsetAlignment,
    #[error("Push constant size must be aligned to 4 bytes")]
    PushConstantSizeAlignment,
    #[error("Ran out of push constant space. Don't set 4gb of push constants per ComputePass.")]
    PushConstantOutOfMemory,
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    #[error("The compute pass has already been ended and no further commands can be recorded")]
    PassEnded,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    #[error(transparent)]
    TimestampWrites(#[from] TimestampWritesError),
    #[error(transparent)]
    InvalidValuesOffset(#[from] pass::InvalidValuesOffset),
}

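/// Error encountered when performing a compute pass, qualified by the
/// [`PassErrorScope`] in which it occurred.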
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}

impl From<pass::MissingPipeline> for ComputePassErrorInner {
    fn from(value: pass::MissingPipeline) -> Self {
        Self::Dispatch(DispatchError::MissingPipeline(value))
    }
}

impl<E> MapPassErr<ComputePassError> for E
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> ComputePassError {
        ComputePassError {
            scope,
            inner: self.into(),
        }
    }
}

impl WebGpuError for ComputePassError {
    fn webgpu_error_type(&self) -> ErrorType {
        let Self { scope: _, inner } = self;
        let e: &dyn WebGpuError = match inner {
            ComputePassErrorInner::Device(e) => e,
            ComputePassErrorInner::EncoderState(e) => e,
            ComputePassErrorInner::DebugGroupError(e) => e,
            ComputePassErrorInner::DestroyedResource(e) => e,
            ComputePassErrorInner::ResourceUsageCompatibility(e) => e,
            ComputePassErrorInner::MissingBufferUsage(e) => e,
            ComputePassErrorInner::Dispatch(e) => e,
            ComputePassErrorInner::Bind(e) => e,
            ComputePassErrorInner::PushConstants(e) => e,
            ComputePassErrorInner::QueryUse(e) => e,
            ComputePassErrorInner::MissingFeatures(e) => e,
            ComputePassErrorInner::MissingDownlevelFlags(e) => e,
            ComputePassErrorInner::InvalidResource(e) => e,
            ComputePassErrorInner::TimestampWrites(e) => e,
            ComputePassErrorInner::InvalidValuesOffset(e) => e,

            ComputePassErrorInner::InvalidParentEncoder
            | ComputePassErrorInner::BindGroupIndexOutOfRange { .. }
            | ComputePassErrorInner::UnalignedIndirectBufferOffset(_)
            | ComputePassErrorInner::IndirectBufferOverrun { .. }
            | ComputePassErrorInner::PushConstantOffsetAlignment
            | ComputePassErrorInner::PushConstantSizeAlignment
            | ComputePassErrorInner::PushConstantOutOfMemory
            | ComputePassErrorInner::PassEnded => return ErrorType::Validation,
        };
        e.webgpu_error_type()
    }
}

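/// State tracked while encoding the commands of a single compute pass.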
struct State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> {
    pipeline: Option<Arc<ComputePipeline>>,

    pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>,

    active_query: Option<(Arc<resource::QuerySet>, u32)>,

    push_constants: Vec<u32>,

    intermediate_trackers: Tracker,
}

impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
    State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
{
    fn is_ready(&self) -> Result<(), DispatchError> {
        if let Some(pipeline) = self.pipeline.as_ref() {
            self.pass.binder.check_compatibility(pipeline.as_ref())?;
            self.pass.binder.check_late_buffer_bindings()?;
            Ok(())
        } else {
            Err(DispatchError::MissingPipeline(pass::MissingPipeline))
        }
    }

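    /// Merge the resource usage of all active bind groups (plus an optional
    /// indirect buffer) into the pass's usage scope, then record the
    /// resulting transition barriers on the raw encoder.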
    fn flush_states(
        &mut self,
        indirect_buffer: Option<TrackerIndex>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        for bind_group in self.pass.binder.list_active() {
            unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
        }

        for bind_group in self.pass.binder.list_active() {
            unsafe {
                self.intermediate_trackers
                    .set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used)
            }
        }

        unsafe {
            self.intermediate_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(
                    &mut self.pass.scope.buffers,
                    indirect_buffer,
                );
        }

        CommandEncoder::drain_barriers(
            self.pass.base.raw_encoder,
            &mut self.intermediate_trackers,
            self.pass.base.snatch_guard,
        );
        Ok(())
    }
}

impl Global {
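    /// Creates a compute pass.
    ///
    /// If creation fails, an invalid pass is returned. Recording commands
    /// into an invalid pass is allowed; the error surfaces when the pass is
    /// ended.
    ///
    /// A minimal usage sketch (illustrative only; `global` and `encoder_id`
    /// are assumed to exist and be valid):
    ///
    /// ```ignore
    /// let (mut pass, err) = global.command_encoder_begin_compute_pass(
    ///     encoder_id,
    ///     &ComputePassDescriptor {
    ///         label: Some(Cow::Borrowed("example pass")),
    ///         timestamp_writes: None,
    ///     },
    /// );
    /// assert!(err.is_none());
    /// // ... record commands with the `compute_pass_*` methods ...
    /// global.compute_pass_end(&mut pass)?;
    /// ```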
    pub fn command_encoder_begin_compute_pass(
        &self,
        encoder_id: id::CommandEncoderId,
        desc: &ComputePassDescriptor<'_>,
    ) -> (ComputePass, Option<CommandEncoderError>) {
        use EncoderStateError as SErr;

        let scope = PassErrorScope::Pass;
        let hub = &self.hub;

        let label = desc.label.as_deref().map(Cow::Borrowed);

        let cmd_enc = hub.command_encoders.get(encoder_id);
        let mut cmd_buf_data = cmd_enc.data.lock();

        match cmd_buf_data.lock_encoder() {
            Ok(()) => {
                drop(cmd_buf_data);
                if let Err(err) = cmd_enc.device.check_is_valid() {
                    return (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    );
                }

                match desc
                    .timestamp_writes
                    .as_ref()
                    .map(|tw| {
                        Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
                            &cmd_enc.device,
                            &hub.query_sets.read(),
                            tw,
                        )
                    })
                    .transpose()
                {
                    Ok(timestamp_writes) => {
                        let arc_desc = ArcComputePassDescriptor {
                            label,
                            timestamp_writes,
                        };
                        (ComputePass::new(cmd_enc, arc_desc), None)
                    }
                    Err(err) => (
                        ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                        None,
                    ),
                }
            }
            Err(err @ SErr::Locked) => {
                // Attempting to open a new pass while the encoder is locked
                // (i.e. another pass is open) invalidates the encoder.
                cmd_buf_data.invalidate(err.clone());
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(err @ (SErr::Ended | SErr::Submitted)) => {
                // The encoder is already finished: report an immediate error
                // in addition to returning an invalid pass.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.clone().map_pass_err(scope)),
                    Some(err.into()),
                )
            }
            Err(err @ SErr::Invalid) => {
                // The encoder is already invalid; defer the error until the
                // pass is ended.
                drop(cmd_buf_data);
                (
                    ComputePass::new_invalid(cmd_enc, &label, err.map_pass_err(scope)),
                    None,
                )
            }
            Err(SErr::Unlocked) => {
                unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
            }
        }
    }

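    /// Replays a pass from its recorded, id-based commands: resolves the ids,
    /// re-records the pass, and ends it. Hidden: intended for trace replay
    /// (`serde`/`replay` features), not general use.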
    #[doc(hidden)]
    #[cfg(any(feature = "serde", feature = "replay"))]
    pub fn compute_pass_end_with_unresolved_commands(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePass<super::ComputeCommand, core::convert::Infallible>,
        timestamp_writes: Option<&PassTimestampWrites>,
    ) {
        #[cfg(feature = "trace")]
        {
            let cmd_enc = self.hub.command_encoders.get(encoder_id);
            let mut cmd_buf_data = cmd_enc.data.lock();
            let cmd_buf_data = cmd_buf_data.get_inner();

            if let Some(ref mut list) = cmd_buf_data.trace_commands {
                list.push(crate::command::Command::RunComputePass {
                    base: BasePass {
                        label: base.label.clone(),
                        error: None,
                        commands: base.commands.clone(),
                        dynamic_offsets: base.dynamic_offsets.clone(),
                        string_data: base.string_data.clone(),
                        push_constant_data: base.push_constant_data.clone(),
                    },
                    timestamp_writes: timestamp_writes.cloned(),
                });
            }
        }

        let BasePass {
            label,
            error: _,
            commands,
            dynamic_offsets,
            string_data,
            push_constant_data,
        } = base;

        let (mut compute_pass, encoder_error) = self.command_encoder_begin_compute_pass(
            encoder_id,
            &ComputePassDescriptor {
                label: label.as_deref().map(Cow::Borrowed),
                timestamp_writes: timestamp_writes.cloned(),
            },
        );
        if let Some(err) = encoder_error {
            panic!("{:?}", err);
        };

        compute_pass.base = BasePass {
            label,
            error: None,
            commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)
                .unwrap(),
            dynamic_offsets,
            string_data,
            push_constant_data,
        };

        self.compute_pass_end(&mut compute_pass).unwrap();
    }

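    /// Ends the pass, encoding its recorded commands into the parent
    /// encoder. Returns an error only if the pass was already ended; other
    /// recording errors are reported through the encoder's error state.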
    pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
        profiling::scope!(
            "CommandEncoder::run_compute_pass {}",
            pass.base.label.as_deref().unwrap_or("")
        );

        let cmd_enc = pass.parent.take().ok_or(EncoderStateError::Ended)?;
        let mut cmd_buf_data = cmd_enc.data.lock();

        if let Some(err) = pass.base.error.take() {
            if matches!(
                err,
                ComputePassError {
                    inner: ComputePassErrorInner::EncoderState(EncoderStateError::Ended),
                    scope: _,
                }
            ) {
                // The pass was already ended; report that directly to the
                // caller rather than invalidating the encoder.
                return Err(EncoderStateError::Ended);
            } else {
                // A recording error was deferred to this point; invalidate
                // the encoder and report success, WebGPU-style.
                cmd_buf_data.invalidate(err);
                return Ok(());
            }
        }

        cmd_buf_data.unlock_and_record(|cmd_buf_data| -> Result<(), ComputePassError> {
            encode_compute_pass(cmd_buf_data, &cmd_enc, pass)
        })
    }
}

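/// Encodes the commands recorded in `pass` onto the encoder's raw HAL
/// command encoder, validating each command as it goes.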
fn encode_compute_pass(
    cmd_buf_data: &mut CommandBufferMutable,
    cmd_enc: &Arc<CommandEncoder>,
    pass: &mut ComputePass,
) -> Result<(), ComputePassError> {
    let pass_scope = PassErrorScope::Pass;

    let device = &cmd_enc.device;
    device.check_is_valid().map_pass_err(pass_scope)?;

    let base = &mut pass.base;

    let encoder = &mut cmd_buf_data.encoder;

    // Close the open command buffer (if any) so the pass gets its own,
    // which lets us later insert a "pre pass" buffer in front of it.
    encoder.close_if_open().map_pass_err(pass_scope)?;
    let raw_encoder = encoder
        .open_pass(base.label.as_deref())
        .map_pass_err(pass_scope)?;

    let snatch_guard = device.snatchable_lock.read();
    let mut debug_scope_depth = 0;

    let mut state = State {
        pipeline: None,

        pass: pass::PassState {
            base: EncodingState {
                device,
                raw_encoder,
                tracker: &mut cmd_buf_data.trackers,
                buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
                texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
                as_actions: &mut cmd_buf_data.as_actions,
                indirect_draw_validation_resources: &mut cmd_buf_data
                    .indirect_draw_validation_resources,
                snatch_guard: &snatch_guard,
                debug_scope_depth: &mut debug_scope_depth,
            },
            binder: Binder::new(),
            temp_offsets: Vec::new(),
            dynamic_offset_count: 0,

            pending_discard_init_fixups: SurfacesInDiscardState::new(),

            scope: device.new_usage_scope(),

            string_offset: 0,
        },
        active_query: None,

        push_constants: Vec::new(),

        intermediate_trackers: Tracker::new(),
    };

    let indices = &state.pass.base.device.tracker_indices;
    state
        .pass
        .base
        .tracker
        .buffers
        .set_size(indices.buffers.size());
    state
        .pass
        .base
        .tracker
        .textures
        .set_size(indices.textures.size());

    let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
        if let Some(tw) = pass.timestamp_writes.take() {
            tw.query_set
                .same_device_as(cmd_enc.as_ref())
                .map_pass_err(pass_scope)?;

            let query_set = state
                .pass
                .base
                .tracker
                .query_sets
                .insert_single(tw.query_set);

            // Queries must be reset before they can be written; compute the
            // smallest index range covering both timestamp writes.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            if let Some(range) = range {
                unsafe {
                    state
                        .pass
                        .base
                        .raw_encoder
                        .reset_queries(query_set.raw(), range);
                }
            }

            Some(hal::PassTimestampWrites {
                query_set: query_set.raw(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

    let hal_desc = hal::ComputePassDescriptor {
        label: hal_label(base.label.as_deref(), device.instance_flags),
        timestamp_writes,
    };

    unsafe {
        state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
    }

    for command in base.commands.drain(..) {
        match command {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group,
            } => {
                let scope = PassErrorScope::SetBindGroup;
                pass::set_bind_group::<ComputePassErrorInner>(
                    &mut state.pass,
                    cmd_enc.as_ref(),
                    &base.dynamic_offsets,
                    index,
                    num_dynamic_offsets,
                    bind_group,
                    false,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPipeline(pipeline) => {
                let scope = PassErrorScope::SetPipelineCompute;
                set_pipeline(&mut state, cmd_enc.as_ref(), pipeline).map_pass_err(scope)?;
            }
            ArcComputeCommand::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            } => {
                let scope = PassErrorScope::SetPushConstant;
                pass::set_push_constant::<ComputePassErrorInner, _>(
                    &mut state.pass,
                    &base.push_constant_data,
                    wgt::ShaderStages::COMPUTE,
                    offset,
                    size_bytes,
                    Some(values_offset),
                    |data_slice| {
                        let offset_in_elements = (offset / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        let size_in_elements = (size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                        state.push_constants[offset_in_elements..][..size_in_elements]
                            .copy_from_slice(data_slice);
                    },
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::Dispatch(groups) => {
                let scope = PassErrorScope::Dispatch { indirect: false };
                dispatch(&mut state, groups).map_pass_err(scope)?;
            }
            ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                let scope = PassErrorScope::Dispatch { indirect: true };
                dispatch_indirect(&mut state, cmd_enc.as_ref(), buffer, offset)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::PushDebugGroup { color: _, len } => {
                pass::push_debug_group(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::PopDebugGroup => {
                let scope = PassErrorScope::PopDebugGroup;
                pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
                    .map_pass_err(scope)?;
            }
            ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
            }
            ArcComputeCommand::WriteTimestamp {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::WriteTimestamp;
                pass::write_timestamp::<ComputePassErrorInner>(
                    &mut state.pass,
                    cmd_enc.as_ref(),
                    None,
                    query_set,
                    query_index,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            } => {
                let scope = PassErrorScope::BeginPipelineStatisticsQuery;
                validate_and_begin_pipeline_statistics_query(
                    query_set,
                    state.pass.base.raw_encoder,
                    &mut state.pass.base.tracker.query_sets,
                    cmd_enc.as_ref(),
                    query_index,
                    None,
                    &mut state.active_query,
                )
                .map_pass_err(scope)?;
            }
            ArcComputeCommand::EndPipelineStatisticsQuery => {
                let scope = PassErrorScope::EndPipelineStatisticsQuery;
                end_pipeline_statistics_query(state.pass.base.raw_encoder, &mut state.active_query)
                    .map_pass_err(scope)?;
            }
        }
    }

    if *state.pass.base.debug_scope_depth > 0 {
        Err(
            ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
                .map_pass_err(pass_scope),
        )?;
    }

    unsafe {
        state.pass.base.raw_encoder.end_compute_pass();
    }

    let State {
        pass:
            pass::PassState {
                base: EncodingState { tracker, .. },
                pending_discard_init_fixups,
                ..
            },
        intermediate_trackers,
        ..
    } = state;

    // The pass is over; close its command buffer.
    encoder.close().map_pass_err(pass_scope)?;

    // Open a new command buffer that will be swapped in _before_ the pass,
    // carrying discarded-surface fixups and the barriers the pass requires.
    let transit = encoder
        .open_pass(hal_label(
            Some("(wgpu internal) Pre Pass"),
            device.instance_flags,
        ))
        .map_pass_err(pass_scope)?;
    fixup_discarded_surfaces(
        pending_discard_init_fixups.into_iter(),
        transit,
        &mut tracker.textures,
        device,
        &snatch_guard,
    );
    CommandEncoder::insert_barriers_from_tracker(
        transit,
        tracker,
        &intermediate_trackers,
        &snatch_guard,
    );
    encoder.close_and_swap().map_pass_err(pass_scope)?;

    Ok(())
}

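/// Validates that `pipeline` belongs to the pass's device, makes it the
/// current pipeline on the raw encoder, and rebinds resources against the
/// new pipeline layout.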
fn set_pipeline(
    state: &mut State,
    cmd_enc: &CommandEncoder,
    pipeline: Arc<ComputePipeline>,
) -> Result<(), ComputePassErrorInner> {
    pipeline.same_device_as(cmd_enc)?;

    state.pipeline = Some(pipeline.clone());

    let pipeline = state
        .pass
        .base
        .tracker
        .compute_pipelines
        .insert_single(pipeline)
        .clone();

    unsafe {
        state
            .pass
            .base
            .raw_encoder
            .set_compute_pipeline(pipeline.raw());
    }

    pass::rebind_resources::<ComputePassErrorInner, _>(
        &mut state.pass,
        &pipeline.layout,
        &pipeline.late_sized_buffer_groups,
        || {
            // Reset the shadow copy of push constant data, zero-filled to
            // the size of the pipeline's COMPUTE-stage range (if any).
            state.push_constants.clear();
            if let Some(push_constant_range) =
                pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
                    pcr.stages
                        .contains(wgt::ShaderStages::COMPUTE)
                        .then_some(pcr.range.clone())
                })
            {
                let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
                state.push_constants.extend(core::iter::repeat_n(0, len));
            }
        },
    )
}

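/// Validates and records a direct dispatch: a compatible pipeline must be
/// bound, and each workgroup-count dimension must not exceed
/// `max_compute_workgroups_per_dimension`.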
fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorInner> {
    state.is_ready()?;

    state.flush_states(None)?;

    let groups_size_limit = state
        .pass
        .base
        .device
        .limits
        .max_compute_workgroups_per_dimension;

    if groups[0] > groups_size_limit
        || groups[1] > groups_size_limit
        || groups[2] > groups_size_limit
    {
        return Err(ComputePassErrorInner::Dispatch(
            DispatchError::InvalidGroupSize {
                current: groups,
                limit: groups_size_limit,
            },
        ));
    }

    unsafe {
        state.pass.base.raw_encoder.dispatch(groups);
    }
    Ok(())
}

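/// Validates and records an indirect dispatch. When the device has indirect
/// validation enabled, a small internal dispatch first validates the
/// arguments into a device-owned buffer, and the real dispatch then reads
/// from that buffer instead of the user's.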
fn dispatch_indirect(
    state: &mut State,
    cmd_enc: &CommandEncoder,
    buffer: Arc<Buffer>,
    offset: u64,
) -> Result<(), ComputePassErrorInner> {
    buffer.same_device_as(cmd_enc)?;

    state.is_ready()?;

    state
        .pass
        .base
        .device
        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;

    buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
    buffer.check_destroyed(state.pass.base.snatch_guard)?;

    if offset % 4 != 0 {
        return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
    }

    let end_offset = offset + size_of::<wgt::DispatchIndirectArgs>() as u64;
    if end_offset > buffer.size {
        return Err(ComputePassErrorInner::IndirectBufferOverrun {
            offset,
            end_offset,
            buffer_size: buffer.size,
        });
    }

    let stride = 3 * 4; // three u32 workgroup counts
    state.pass.base.buffer_memory_init_actions.extend(
        buffer.initialization_status.read().create_action(
            &buffer,
            offset..(offset + stride),
            MemoryInitKind::NeedsInitializedMemory,
        ),
    );

    if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
        let params = indirect_validation.dispatch.params(
            &state.pass.base.device.limits,
            offset,
            buffer.size,
        );

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .set_compute_pipeline(params.pipeline);
        }

        unsafe {
            state.pass.base.raw_encoder.set_push_constants(
                params.pipeline_layout,
                wgt::ShaderStages::COMPUTE,
                0,
                &[params.offset_remainder as u32 / 4],
            );
        }

        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                0,
                Some(params.dst_bind_group),
                &[],
            );
        }
        unsafe {
            state.pass.base.raw_encoder.set_bind_group(
                params.pipeline_layout,
                1,
                Some(
                    buffer
                        .indirect_validation_bind_groups
                        .get(state.pass.base.snatch_guard)
                        .unwrap()
                        .dispatch
                        .as_ref(),
                ),
                &[params.aligned_offset as u32],
            );
        }

        let src_transition = state
            .intermediate_trackers
            .buffers
            .set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
        let src_barrier = src_transition
            .map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(src_barrier.as_slice());
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::INDIRECT,
                        to: wgt::BufferUses::STORAGE_READ_WRITE,
                    },
                }]);
        }

        unsafe {
            state.pass.base.raw_encoder.dispatch([1, 1, 1]);
        }

        // Restore the pipeline, push constants, and bind groups the user
        // had set before the validation dispatch.
        {
            let pipeline = state.pipeline.as_ref().unwrap();

            unsafe {
                state
                    .pass
                    .base
                    .raw_encoder
                    .set_compute_pipeline(pipeline.raw());
            }

            if !state.push_constants.is_empty() {
                unsafe {
                    state.pass.base.raw_encoder.set_push_constants(
                        pipeline.layout.raw(),
                        wgt::ShaderStages::COMPUTE,
                        0,
                        &state.push_constants,
                    );
                }
            }

            for (i, e) in state.pass.binder.list_valid() {
                let group = e.group.as_ref().unwrap();
                let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
                unsafe {
                    state.pass.base.raw_encoder.set_bind_group(
                        pipeline.layout.raw(),
                        i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }

        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .transition_buffers(&[hal::BufferBarrier {
                    buffer: params.dst_buffer,
                    usage: hal::StateTransition {
                        from: wgt::BufferUses::STORAGE_READ_WRITE,
                        to: wgt::BufferUses::INDIRECT,
                    },
                }]);
        }

        state.flush_states(None)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(params.dst_buffer, 0);
        }
    } else {
        state
            .pass
            .scope
            .buffers
            .merge_single(&buffer, wgt::BufferUses::INDIRECT)?;

        use crate::resource::Trackable;
        state.flush_states(Some(buffer.tracker_index()))?;

        let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
        unsafe {
            state
                .pass
                .base
                .raw_encoder
                .dispatch_indirect(buf_raw, offset);
        }
    }

    Ok(())
}

impl Global {
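    // Recording interface for `ComputePass`. Each method validates what it
    // can cheaply, records an `ArcComputeCommand` into the pass, and defers
    // the rest of validation to `compute_pass_end`. An illustrative
    // sequence (all ids assumed valid):
    //
    //     global.compute_pass_set_pipeline(&mut pass, pipeline_id)?;
    //     global.compute_pass_set_bind_group(&mut pass, 0, Some(bg_id), &[])?;
    //     global.compute_pass_dispatch_workgroups(&mut pass, 64, 1, 1)?;
    //     global.compute_pass_end(&mut pass)?;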
    pub fn compute_pass_set_bind_group(
        &self,
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: Option<id::BindGroupId>,
        offsets: &[DynamicOffset],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetBindGroup;

        let base = pass_base!(pass, scope);

        // Drop the command if it is redundant with the current bind group
        // state.
        if pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut base.dynamic_offsets,
            offsets,
        ) {
            return Ok(());
        }

        let mut bind_group = None;
        if let Some(bind_group_id) = bind_group_id {
            let hub = &self.hub;
            bind_group = Some(pass_try!(
                base,
                scope,
                hub.bind_groups.get(bind_group_id).get(),
            ));
        }

        base.commands.push(ArcComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group,
        });

        Ok(())
    }

    pub fn compute_pass_set_pipeline(
        &self,
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) -> Result<(), PassStateError> {
        let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);

        let scope = PassErrorScope::SetPipelineCompute;

        let base = pass_base!(pass, scope);

        if redundant {
            return Ok(());
        }

        let hub = &self.hub;
        let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());

        base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

        Ok(())
    }

    pub fn compute_pass_set_push_constants(
        &self,
        pass: &mut ComputePass,
        offset: u32,
        data: &[u8],
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::SetPushConstant;
        let base = pass_base!(pass, scope);

        if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantOffsetAlignment),
            );
        }

        if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
            pass_try!(
                base,
                scope,
                Err(ComputePassErrorInner::PushConstantSizeAlignment),
            )
        }
        let value_offset = pass_try!(
            base,
            scope,
            base.push_constant_data
                .len()
                .try_into()
                .map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
        );

        base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        base.commands.push(ArcComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups(
        &self,
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::Dispatch { indirect: false };

        pass_base!(pass, scope)
            .commands
            .push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));

        Ok(())
    }

    pub fn compute_pass_dispatch_workgroups_indirect(
        &self,
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) -> Result<(), PassStateError> {
        let hub = &self.hub;
        let scope = PassErrorScope::Dispatch { indirect: true };
        let base = pass_base!(pass, scope);

        let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());

        base.commands
            .push(ArcComputeCommand::DispatchIndirect { buffer, offset });

        Ok(())
    }

    pub fn compute_pass_push_debug_group(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PushDebugGroup);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_pop_debug_group(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::PopDebugGroup);

        base.commands.push(ArcComputeCommand::PopDebugGroup);

        Ok(())
    }

    pub fn compute_pass_insert_debug_marker(
        &self,
        pass: &mut ComputePass,
        label: &str,
        color: u32,
    ) -> Result<(), PassStateError> {
        let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);

        let bytes = label.as_bytes();
        base.string_data.extend_from_slice(bytes);

        base.commands.push(ArcComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });

        Ok(())
    }

    pub fn compute_pass_write_timestamp(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::WriteTimestamp;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands.push(ArcComputeCommand::WriteTimestamp {
            query_set,
            query_index,
        });

        Ok(())
    }

    pub fn compute_pass_begin_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) -> Result<(), PassStateError> {
        let scope = PassErrorScope::BeginPipelineStatisticsQuery;
        let base = pass_base!(pass, scope);

        let hub = &self.hub;
        let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());

        base.commands
            .push(ArcComputeCommand::BeginPipelineStatisticsQuery {
                query_set,
                query_index,
            });

        Ok(())
    }

    pub fn compute_pass_end_pipeline_statistics_query(
        &self,
        pass: &mut ComputePass,
    ) -> Result<(), PassStateError> {
        pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
            .commands
            .push(ArcComputeCommand::EndPipelineStatisticsQuery);

        Ok(())
    }
}