1use alloc::{
8 sync::{Arc, Weak},
9 vec::Vec,
10};
11
12use hal::BufferBarrier;
13use wgt::{strict_assert, strict_assert_eq, BufferUses};
14
15use super::{PendingTransition, TrackerIndex};
16use crate::{
17 resource::{Buffer, Trackable},
18 snatch::SnatchGuard,
19 track::{
20 invalid_resource_state, skip_barrier, ResourceMetadata, ResourceMetadataProvider,
21 ResourceUsageCompatibilityError, ResourceUses,
22 },
23};
24
25impl ResourceUses for BufferUses {
26 const EXCLUSIVE: Self = Self::EXCLUSIVE;
27
28 type Selector = ();
29
30 fn bits(self) -> u16 {
31 Self::bits(&self)
32 }
33
34 fn all_ordered(self) -> bool {
35 Self::ORDERED.contains(self)
36 }
37
38 fn any_exclusive(self) -> bool {
39 self.intersects(Self::EXCLUSIVE)
40 }
41}
42
/// The buffers used by a bind group, each paired with the usage state the
/// bind group requires of it.
///
/// Entries are appended in arbitrary order and may contain duplicates; the
/// list is sorted by tracker index by `optimize`, and conflicts are resolved
/// when the state is merged into a usage scope.
#[derive(Debug)]
pub(crate) struct BufferBindGroupState {
    // (buffer, required usage) pairs, in insertion order until sorted.
    buffers: Vec<(Arc<Buffer>, BufferUses)>,
}
48impl BufferBindGroupState {
49 pub fn new() -> Self {
50 Self {
51 buffers: Vec::new(),
52 }
53 }
54
55 pub(crate) fn optimize(&mut self) {
60 self.buffers
61 .sort_unstable_by_key(|(b, _)| b.tracker_index());
62 }
63
64 pub fn used_tracker_indices(&self) -> impl Iterator<Item = TrackerIndex> + '_ {
66 self.buffers
67 .iter()
68 .map(|(b, _)| b.tracker_index())
69 .collect::<Vec<_>>()
70 .into_iter()
71 }
72
73 pub fn insert_single(&mut self, buffer: Arc<Buffer>, state: BufferUses) {
75 self.buffers.push((buffer, state));
76 }
77}
78
/// A collection of buffer usages within a single usage scope, with exactly
/// one merged state per buffer.
///
/// Merging in a usage that conflicts with a buffer's existing state produces
/// a [`ResourceUsageCompatibilityError`] rather than a barrier.
#[derive(Debug)]
pub(crate) struct BufferUsageScope {
    // Merged usage per tracker index; `BufferUses::empty()` in slots not
    // owned by `metadata`.
    state: Vec<BufferUses>,
    // Which indices are owned, plus a strong reference to each owned buffer.
    metadata: ResourceMetadata<Arc<Buffer>>,
}
85
86impl Default for BufferUsageScope {
87 fn default() -> Self {
88 Self {
89 state: Vec::new(),
90 metadata: ResourceMetadata::new(),
91 }
92 }
93}
94
impl BufferUsageScope {
    /// Asserts (under strict asserts only) that `index` is in bounds of all
    /// internal vectors.
    fn tracker_assert_in_bounds(&self, index: usize) {
        strict_assert!(index < self.state.len());
        self.metadata.tracker_assert_in_bounds(index);
    }
    /// Removes all buffers from the scope, keeping the allocated capacity.
    pub fn clear(&mut self) {
        self.state.clear();
        self.metadata.clear();
    }

    /// Sets the size of all vectors inside the scope.
    ///
    /// Must be called with the highest possible buffer tracker index before
    /// any of the `unsafe` methods are called, as they assume the index is
    /// in bounds.
    pub fn set_size(&mut self, size: usize) {
        self.state.resize(size, BufferUses::empty());
        self.metadata.set_size(size);
    }

    /// Grows the internal vectors, if needed, so that `index` is valid.
    fn allow_index(&mut self, index: usize) {
        if index >= self.state.len() {
            self.set_size(index + 1);
        }
    }

    /// Merges the buffer states of the given bind group into this scope.
    ///
    /// Stops at the first resulting state that is invalid and returns a
    /// usage-conflict error describing it; earlier entries remain merged.
    ///
    /// # Safety
    ///
    /// [`Self::set_size`] must have been called with the maximum possible
    /// buffer tracker index before this method is called — unlike the other
    /// merge methods, this one does not grow the vectors on demand.
    pub unsafe fn merge_bind_group(
        &mut self,
        bind_group: &BufferBindGroupState,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        for &(ref resource, state) in bind_group.buffers.iter() {
            let index = resource.tracker_index().as_usize();

            unsafe {
                self.insert_or_merge(
                    index as _,
                    index,
                    BufferStateProvider::Direct { state },
                    ResourceMetadataProvider::Direct { resource },
                )?
            };
        }

        Ok(())
    }

    /// Merges all buffer states of another usage scope into this one.
    ///
    /// Stops at the first resulting state that is invalid and returns a
    /// usage-conflict error describing it.
    ///
    /// If the incoming scope uses indices beyond the length of the internal
    /// vectors, they are extended; a call to [`Self::set_size`] is not
    /// needed.
    pub fn merge_usage_scope(
        &mut self,
        scope: &Self,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        let incoming_size = scope.state.len();
        if incoming_size > self.state.len() {
            self.set_size(incoming_size);
        }

        for index in scope.metadata.owned_indices() {
            self.tracker_assert_in_bounds(index);
            scope.tracker_assert_in_bounds(index);

            unsafe {
                self.insert_or_merge(
                    index as u32,
                    index,
                    BufferStateProvider::Indirect {
                        state: &scope.state,
                    },
                    ResourceMetadataProvider::Indirect {
                        metadata: &scope.metadata,
                    },
                )?;
            };
        }

        Ok(())
    }

    /// Merges a single buffer state into this scope.
    ///
    /// If the resulting state is invalid, returns a usage-conflict error
    /// describing it.
    ///
    /// If the buffer's index is beyond the length of the internal vectors,
    /// they are extended; a call to [`Self::set_size`] is not needed.
    pub fn merge_single(
        &mut self,
        buffer: &Arc<Buffer>,
        new_state: BufferUses,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        let index = buffer.tracker_index().as_usize();

        self.allow_index(index);

        self.tracker_assert_in_bounds(index);

        // SAFETY: `allow_index` above guarantees `index` is in bounds.
        unsafe {
            self.insert_or_merge(
                index as _,
                index,
                BufferStateProvider::Direct { state: new_state },
                ResourceMetadataProvider::Direct { resource: buffer },
            )?;
        }

        Ok(())
    }

    /// Does an insert or merge at `index`.
    ///
    /// If the resource isn't tracked, inserts it with the provided state.
    /// If it is tracked, ORs the provided state into the existing one,
    /// returning a usage-conflict error when the union is invalid.
    ///
    /// # Safety
    ///
    /// `index` must be a valid index into all arrays passed in to this
    /// function, either directly or via the metadata/state providers.
    #[inline(always)]
    unsafe fn insert_or_merge(
        &mut self,
        index32: u32,
        index: usize,
        state_provider: BufferStateProvider<'_>,
        metadata_provider: ResourceMetadataProvider<'_, Arc<Buffer>>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        let currently_owned = unsafe { self.metadata.contains_unchecked(index) };

        if !currently_owned {
            unsafe {
                insert(
                    None,
                    &mut self.state,
                    &mut self.metadata,
                    index,
                    state_provider,
                    None,
                    metadata_provider,
                )
            };
            return Ok(());
        }

        unsafe {
            merge(
                &mut self.state,
                index32,
                index,
                state_provider,
                metadata_provider,
            )
        }
    }

    /// Clears the given usage bits from the buffer's tracked state, if the
    /// buffer is owned by this scope; otherwise does nothing.
    pub fn remove_usage(&mut self, buffer: &Buffer, usage: BufferUses) {
        let index = buffer.tracker_index().as_usize();
        if self.metadata.contains(index) {
            // SAFETY: `metadata.contains(index)` implies the index is within
            // the tracked size, so `state[index]` exists.
            unsafe {
                *self.state.get_unchecked_mut(index) &= !usage;
            }
        }
    }
}
281
/// A double-sided buffer state tracker.
///
/// For each owned buffer it records both the state at first use (`start`)
/// and the most recent state (`end`); setting a new state on an
/// already-tracked buffer generates a [`PendingTransition`] rather than an
/// error.
pub(crate) struct BufferTracker {
    /// State each buffer was first observed in, indexed by tracker index.
    start: Vec<BufferUses>,
    /// Current (most recently set) state of each buffer.
    end: Vec<BufferUses>,

    /// Ownership bits plus strong references to the tracked buffers.
    metadata: ResourceMetadata<Arc<Buffer>>,

    /// Pending transitions produced by state changes; drained by
    /// [`Self::drain_transitions`] or popped by [`Self::set_single`].
    temp: Vec<PendingTransition<BufferUses>>,
}
291
impl BufferTracker {
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self {
            start: Vec::new(),
            end: Vec::new(),

            metadata: ResourceMetadata::new(),

            temp: Vec::new(),
        }
    }

    /// Asserts (under strict asserts only) that `index` is in bounds of all
    /// internal vectors.
    fn tracker_assert_in_bounds(&self, index: usize) {
        strict_assert!(index < self.start.len());
        strict_assert!(index < self.end.len());
        self.metadata.tracker_assert_in_bounds(index);
    }

    /// Sets the size of all vectors inside the tracker.
    ///
    /// Must be called with the highest possible buffer tracker index before
    /// any of the `unsafe` methods are called, as they assume the index is
    /// in bounds.
    pub fn set_size(&mut self, size: usize) {
        self.start.resize(size, BufferUses::empty());
        self.end.resize(size, BufferUses::empty());

        self.metadata.set_size(size);
    }

    /// Grows the internal vectors, if needed, so that `index` is valid.
    fn allow_index(&mut self, index: usize) {
        if index >= self.start.len() {
            self.set_size(index + 1);
        }
    }

    /// Returns true if the tracker owns the given buffer.
    pub fn contains(&self, buffer: &Buffer) -> bool {
        self.metadata.contains(buffer.tracker_index().as_usize())
    }

    /// Returns an iterator over all buffers tracked.
    pub fn used_resources(&self) -> impl Iterator<Item = &Arc<Buffer>> + '_ {
        self.metadata.owned_resources()
    }

    /// Drains all currently pending transitions, converting each into a hal
    /// [`BufferBarrier`].
    pub fn drain_transitions<'a, 'b: 'a>(
        &'b mut self,
        snatch_guard: &'a SnatchGuard<'a>,
    ) -> impl Iterator<Item = BufferBarrier<'a, dyn hal::DynBuffer>> {
        let buffer_barriers = self.temp.drain(..).map(|pending| {
            // SAFETY: `pending.id` was recorded from an owned index, so the
            // metadata entry for it exists.
            let buf = unsafe { self.metadata.get_resource_unchecked(pending.id as _) };
            pending.into_hal(buf, snatch_guard)
        });
        buffer_barriers
    }

    /// Sets the state of a single buffer.
    ///
    /// If a transition is needed to get the buffer into the given state,
    /// that transition is returned; no more than one transition is needed.
    ///
    /// If the buffer's index is beyond the length of the internal vectors,
    /// they are extended; a call to [`Self::set_size`] is not needed.
    pub fn set_single(
        &mut self,
        buffer: &Arc<Buffer>,
        state: BufferUses,
    ) -> Option<PendingTransition<BufferUses>> {
        let index: usize = buffer.tracker_index().as_usize();

        self.allow_index(index);

        self.tracker_assert_in_bounds(index);

        // SAFETY: `allow_index` above guarantees `index` is in bounds.
        unsafe {
            self.insert_or_barrier_update(
                index,
                BufferStateProvider::Direct { state },
                None,
                ResourceMetadataProvider::Direct { resource: buffer },
            )
        };

        // A single buffer update can produce at most one transition.
        strict_assert!(self.temp.len() <= 1);

        self.temp.pop()
    }

    /// Sets the states of all buffers in the given tracker on this one.
    ///
    /// Any transitions needed are stored within this tracker; a subsequent
    /// call to [`Self::drain_transitions`] retrieves them.
    ///
    /// If the incoming tracker uses indices beyond the length of the
    /// internal vectors, they are extended; a call to [`Self::set_size`] is
    /// not needed.
    pub fn set_from_tracker(&mut self, tracker: &Self) {
        let incoming_size = tracker.start.len();
        if incoming_size > self.start.len() {
            self.set_size(incoming_size);
        }

        for index in tracker.metadata.owned_indices() {
            self.tracker_assert_in_bounds(index);
            tracker.tracker_assert_in_bounds(index);
            // SAFETY: both trackers were just sized/asserted to cover `index`.
            unsafe {
                self.insert_or_barrier_update(
                    index,
                    // Barriers are generated towards the other tracker's
                    // *start* state ...
                    BufferStateProvider::Indirect {
                        state: &tracker.start,
                    },
                    // ... while our recorded end state adopts its *end* state.
                    Some(BufferStateProvider::Indirect {
                        state: &tracker.end,
                    }),
                    ResourceMetadataProvider::Indirect {
                        metadata: &tracker.metadata,
                    },
                )
            }
        }
    }

    /// Sets the states of all buffers in the given usage scope on this
    /// tracker.
    ///
    /// Any transitions needed are stored within this tracker; a subsequent
    /// call to [`Self::drain_transitions`] retrieves them.
    ///
    /// If the incoming scope uses indices beyond the length of the internal
    /// vectors, they are extended; a call to [`Self::set_size`] is not
    /// needed.
    pub fn set_from_usage_scope(&mut self, scope: &BufferUsageScope) {
        let incoming_size = scope.state.len();
        if incoming_size > self.start.len() {
            self.set_size(incoming_size);
        }

        for index in scope.metadata.owned_indices() {
            self.tracker_assert_in_bounds(index);
            scope.tracker_assert_in_bounds(index);
            // SAFETY: both containers were just sized/asserted to cover `index`.
            unsafe {
                self.insert_or_barrier_update(
                    index,
                    BufferStateProvider::Indirect {
                        state: &scope.state,
                    },
                    None,
                    ResourceMetadataProvider::Indirect {
                        metadata: &scope.metadata,
                    },
                )
            }
        }
    }

    /// Adopts the states of the buffers named by `index_source` from the
    /// given usage scope, and removes those buffers from the scope.
    ///
    /// Indices not currently owned by the scope are skipped. Any transitions
    /// needed are stored within this tracker; a subsequent call to
    /// [`Self::drain_transitions`] retrieves them.
    ///
    /// This lets a caller apply only a subset of a usage scope (e.g. the
    /// buffers touched since the last call) without iterating the whole
    /// scope.
    pub fn set_and_remove_from_usage_scope_sparse(
        &mut self,
        scope: &mut BufferUsageScope,
        index_source: impl IntoIterator<Item = TrackerIndex>,
    ) {
        let incoming_size = scope.state.len();
        if incoming_size > self.start.len() {
            self.set_size(incoming_size);
        }

        for index in index_source {
            let index = index.as_usize();

            scope.tracker_assert_in_bounds(index);

            // Skip indices the scope no longer owns (e.g. already consumed
            // by an earlier call).
            if unsafe { !scope.metadata.contains_unchecked(index) } {
                continue;
            }

            // SAFETY: `self` was sized above; the scope assert covers `index`.
            unsafe {
                self.insert_or_barrier_update(
                    index,
                    BufferStateProvider::Indirect {
                        state: &scope.state,
                    },
                    None,
                    ResourceMetadataProvider::Indirect {
                        metadata: &scope.metadata,
                    },
                )
            };

            // SAFETY: ownership at `index` was checked above.
            unsafe { scope.metadata.remove(index) };
        }
    }

    /// Does an insert or barrier update at `index`.
    ///
    /// If the resource isn't tracked:
    /// - inserts it,
    /// - populates `start` from `start_state_provider`,
    /// - populates `end` from `end_state_provider`, falling back to
    ///   `start_state_provider`.
    ///
    /// If the resource is tracked:
    /// - generates a barrier (if one is needed) from the current `end` state
    ///   to the state from `start_state_provider`, pushing it onto `temp`,
    /// - updates `end` from `end_state_provider`, falling back to
    ///   `start_state_provider`.
    ///
    /// # Safety
    ///
    /// `index` must be a valid index into all internal vectors and into the
    /// provider structs.
    #[inline(always)]
    unsafe fn insert_or_barrier_update(
        &mut self,
        index: usize,
        start_state_provider: BufferStateProvider<'_>,
        end_state_provider: Option<BufferStateProvider<'_>>,
        metadata_provider: ResourceMetadataProvider<'_, Arc<Buffer>>,
    ) {
        let currently_owned = unsafe { self.metadata.contains_unchecked(index) };

        if !currently_owned {
            unsafe {
                insert(
                    Some(&mut self.start),
                    &mut self.end,
                    &mut self.metadata,
                    index,
                    start_state_provider,
                    end_state_provider,
                    metadata_provider,
                )
            };
            return;
        }

        let update_state_provider =
            end_state_provider.unwrap_or_else(|| start_state_provider.clone());
        // Barrier first (reads the old `end` state), then overwrite `end`.
        unsafe { barrier(&mut self.end, index, start_state_provider, &mut self.temp) };

        unsafe { update(&mut self.end, index, update_state_provider) };
    }
}
554
/// Tracks the last known state of each buffer at the device level.
///
/// Holds only [`Weak`] references, so tracking a buffer here does not keep
/// it alive.
pub(crate) struct DeviceBufferTracker {
    /// Last known state of each buffer, indexed by tracker index.
    current_states: Vec<BufferUses>,
    /// Ownership bits plus weak references to the tracked buffers.
    metadata: ResourceMetadata<Weak<Buffer>>,
    /// Pending transitions produced by state changes; drained/popped by the
    /// `set_*` methods.
    temp: Vec<PendingTransition<BufferUses>>,
}
561
impl DeviceBufferTracker {
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self {
            current_states: Vec::new(),
            metadata: ResourceMetadata::new(),
            temp: Vec::new(),
        }
    }

    /// Asserts (under strict asserts only) that `index` is in bounds of all
    /// internal vectors.
    fn tracker_assert_in_bounds(&self, index: usize) {
        strict_assert!(index < self.current_states.len());
        self.metadata.tracker_assert_in_bounds(index);
    }

    /// Grows the internal vectors, if needed, so that `index` is valid.
    fn allow_index(&mut self, index: usize) {
        if index >= self.current_states.len() {
            self.current_states.resize(index + 1, BufferUses::empty());
            self.metadata.set_size(index + 1);
        }
    }

    /// Returns an iterator over all buffers tracked (as weak references).
    pub fn used_resources(&self) -> impl Iterator<Item = &Weak<Buffer>> + '_ {
        self.metadata.owned_resources()
    }

    /// Inserts a single buffer and its state into the tracker, growing the
    /// internal vectors as needed. Only a weak reference to the buffer is
    /// retained.
    pub fn insert_single(&mut self, buffer: &Arc<Buffer>, state: BufferUses) {
        let index = buffer.tracker_index().as_usize();

        self.allow_index(index);

        self.tracker_assert_in_bounds(index);

        // SAFETY: `allow_index` above guarantees `index` is in bounds.
        unsafe {
            insert(
                None,
                &mut self.current_states,
                &mut self.metadata,
                index,
                BufferStateProvider::Direct { state },
                None,
                ResourceMetadataProvider::Direct {
                    resource: &Arc::downgrade(buffer),
                },
            )
        }
    }

    /// Sets the state of a single buffer.
    ///
    /// If a transition is needed to get the buffer into the given state,
    /// that transition is returned; no more than one transition is needed.
    ///
    /// Note: unlike [`BufferTracker::set_single`], this does not grow the
    /// vectors — the buffer is expected to have been inserted already.
    pub fn set_single(
        &mut self,
        buffer: &Arc<Buffer>,
        state: BufferUses,
    ) -> Option<PendingTransition<BufferUses>> {
        let index: usize = buffer.tracker_index().as_usize();

        self.tracker_assert_in_bounds(index);

        let start_state_provider = BufferStateProvider::Direct { state };

        // Barrier first (reads the old state), then overwrite the state.
        unsafe {
            barrier(
                &mut self.current_states,
                index,
                start_state_provider.clone(),
                &mut self.temp,
            )
        };
        unsafe { update(&mut self.current_states, index, start_state_provider) };

        // A single buffer update can produce at most one transition.
        strict_assert!(self.temp.len() <= 1);

        self.temp.pop()
    }

    /// Sets the state of every buffer in the given tracker on this one,
    /// transitioning each from its current state to the other tracker's
    /// *start* state and then adopting the other tracker's *end* state.
    ///
    /// Returns the hal barriers for all transitions that were needed.
    pub fn set_from_tracker_and_drain_transitions<'a, 'b: 'a>(
        &'a mut self,
        tracker: &'a BufferTracker,
        snatch_guard: &'b SnatchGuard<'b>,
    ) -> impl Iterator<Item = BufferBarrier<'a, dyn hal::DynBuffer>> {
        for index in tracker.metadata.owned_indices() {
            self.tracker_assert_in_bounds(index);

            let start_state_provider = BufferStateProvider::Indirect {
                state: &tracker.start,
            };
            let end_state_provider = BufferStateProvider::Indirect {
                state: &tracker.end,
            };
            // SAFETY: the assert above covers `index` for our vectors; the
            // incoming tracker owns `index` so its arrays cover it too.
            unsafe {
                barrier(
                    &mut self.current_states,
                    index,
                    start_state_provider,
                    &mut self.temp,
                )
            };
            unsafe { update(&mut self.current_states, index, end_state_provider) };
        }

        self.temp.drain(..).map(|pending| {
            // SAFETY: `pending.id` came from an index owned by `tracker`.
            let buf = unsafe { tracker.metadata.get_resource_unchecked(pending.id as _) };
            pending.into_hal(buf, snatch_guard)
        })
    }
}
679
/// Source of a buffer usage state for the insert/merge/barrier helpers.
#[derive(Debug, Clone)]
enum BufferStateProvider<'a> {
    /// Use this single state for the buffer.
    Direct { state: BufferUses },
    /// Look the state up by index in this array.
    Indirect { state: &'a [BufferUses] },
}
impl BufferStateProvider<'_> {
    /// Returns the state for the resource at `index`.
    ///
    /// # Safety
    ///
    /// For the `Indirect` variant, `index` must be in bounds of the state
    /// array (only checked under strict asserts).
    #[inline(always)]
    unsafe fn get_state(&self, index: usize) -> BufferUses {
        match *self {
            BufferStateProvider::Direct { state } => state,
            BufferStateProvider::Indirect { state } => {
                strict_assert!(index < state.len());
                // SAFETY: caller guarantees `index < state.len()`.
                *unsafe { state.get_unchecked(index) }
            }
        }
    }
}
705
/// Writes the provided states into the tracking arrays at `index` and
/// records ownership of the resource in `resource_metadata`.
///
/// `start_states`, when present, receives the start state; `current_states`
/// receives the end state, which falls back to the start state when
/// `end_state_provider` is `None`.
///
/// # Safety
///
/// `index` must be in bounds of every array passed in, directly or via the
/// provider structs.
#[inline(always)]
unsafe fn insert<T: Clone>(
    start_states: Option<&mut [BufferUses]>,
    current_states: &mut [BufferUses],
    resource_metadata: &mut ResourceMetadata<T>,
    index: usize,
    start_state_provider: BufferStateProvider<'_>,
    end_state_provider: Option<BufferStateProvider<'_>>,
    metadata_provider: ResourceMetadataProvider<'_, T>,
) {
    let new_start_state = unsafe { start_state_provider.get_state(index) };
    let new_end_state =
        end_state_provider.map_or(new_start_state, |p| unsafe { p.get_state(index) });

    // These should only ever fire due to a bug in the tracker itself:
    // inserted states must never be an invalid combination.
    strict_assert_eq!(invalid_resource_state(new_start_state), false);
    strict_assert_eq!(invalid_resource_state(new_end_state), false);

    // SAFETY: caller guarantees `index` is in bounds of all arrays.
    unsafe {
        if let Some(&mut ref mut start_state) = start_states {
            *start_state.get_unchecked_mut(index) = new_start_state;
        }
        *current_states.get_unchecked_mut(index) = new_end_state;

        let resource = metadata_provider.get(index);
        resource_metadata.insert(index, resource.clone());
    }
}
735
/// ORs the provided state into the state currently tracked at `index`.
///
/// Returns a [`ResourceUsageCompatibilityError`] if the union of the two
/// states is invalid (as judged by `invalid_resource_state`), leaving the
/// tracked state unchanged; otherwise stores the merged state.
///
/// # Safety
///
/// `index` must be in bounds of `current_states` and valid for both
/// providers.
#[inline(always)]
unsafe fn merge(
    current_states: &mut [BufferUses],
    _index32: u32,
    index: usize,
    state_provider: BufferStateProvider<'_>,
    metadata_provider: ResourceMetadataProvider<'_, Arc<Buffer>>,
) -> Result<(), ResourceUsageCompatibilityError> {
    // SAFETY: caller guarantees `index < current_states.len()`.
    let current_state = unsafe { current_states.get_unchecked_mut(index) };
    let new_state = unsafe { state_provider.get_state(index) };

    let merged_state = *current_state | new_state;

    if invalid_resource_state(merged_state) {
        return Err(ResourceUsageCompatibilityError::from_buffer(
            unsafe { metadata_provider.get(index) },
            *current_state,
            new_state,
        ));
    }

    *current_state = merged_state;

    Ok(())
}
761
/// Compares the tracked state at `index` with the provided state and, unless
/// the transition can be elided (per `skip_barrier`), pushes a
/// [`PendingTransition`] from the current state to the new one onto
/// `barriers`.
///
/// Note: this does NOT update `current_states`; callers follow up with
/// `update`.
///
/// # Safety
///
/// `index` must be in bounds of `current_states` and valid for the provider.
#[inline(always)]
unsafe fn barrier(
    current_states: &mut [BufferUses],
    index: usize,
    state_provider: BufferStateProvider<'_>,
    barriers: &mut Vec<PendingTransition<BufferUses>>,
) {
    // SAFETY: caller guarantees `index < current_states.len()`.
    let current_state = unsafe { *current_states.get_unchecked(index) };
    let new_state = unsafe { state_provider.get_state(index) };

    if skip_barrier(current_state, new_state) {
        return;
    }

    barriers.push(PendingTransition {
        id: index as _,
        selector: (),
        usage: hal::StateTransition {
            from: current_state,
            to: new_state,
        },
    });
}
785
786#[inline(always)]
787unsafe fn update(
788 current_states: &mut [BufferUses],
789 index: usize,
790 state_provider: BufferStateProvider<'_>,
791) {
792 let current_state = unsafe { current_states.get_unchecked_mut(index) };
793 let new_state = unsafe { state_provider.get_state(index) };
794
795 *current_state = new_state;
796}