1use alloc::{
8 sync::{Arc, Weak},
9 vec::Vec,
10};
11
12use hal::BufferBarrier;
13use wgt::{strict_assert, strict_assert_eq, BufferUses};
14
15use super::{PendingTransition, TrackerIndex};
16use crate::{
17 resource::{Buffer, Trackable},
18 snatch::SnatchGuard,
19 track::{
20 invalid_resource_state, skip_barrier, ResourceMetadata, ResourceMetadataProvider,
21 ResourceUsageCompatibilityError, ResourceUses,
22 },
23};
24
25impl ResourceUses for BufferUses {
26 const EXCLUSIVE: Self = Self::EXCLUSIVE;
27
28 type Selector = ();
29
30 fn bits(self) -> u16 {
31 Self::bits(&self)
32 }
33
34 fn any_exclusive(self) -> bool {
35 self.intersects(Self::EXCLUSIVE)
36 }
37}
38
/// Stores the buffers used by a bind group, each paired with the usage the
/// bind group requires of it.
#[derive(Debug)]
pub(crate) struct BufferBindGroupState {
    // Buffer handle plus the usage it is bound with; filled by
    // `insert_single` and later merged into a `BufferUsageScope`.
    buffers: Vec<(Arc<Buffer>, BufferUses)>,
}
44impl BufferBindGroupState {
45 pub fn new() -> Self {
46 Self {
47 buffers: Vec::new(),
48 }
49 }
50
51 pub(crate) fn optimize(&mut self) {
56 self.buffers
57 .sort_unstable_by_key(|(b, _)| b.tracker_index());
58 }
59
60 pub fn used_tracker_indices(&self) -> impl Iterator<Item = TrackerIndex> + '_ {
62 self.buffers
63 .iter()
64 .map(|(b, _)| b.tracker_index())
65 .collect::<Vec<_>>()
66 .into_iter()
67 }
68
69 pub fn insert_single(&mut self, buffer: Arc<Buffer>, state: BufferUses) {
71 self.buffers.push((buffer, state));
72 }
73}
74
/// Accumulates the combined usage of each buffer over some span of work,
/// merged in from bind groups, single uses, or other scopes.
#[derive(Debug)]
pub(crate) struct BufferUsageScope {
    /// Accumulated usage per buffer, indexed by tracker index.
    state: Vec<BufferUses>,
    /// Strong handles for the tracked buffers, parallel to `state`.
    metadata: ResourceMetadata<Arc<Buffer>>,
    /// Mask of "ordered" usages. NOTE(review): only written (via
    /// `set_ordered_uses_mask`) and never read within this type —
    /// presumably consumed by code outside this view; confirm.
    ordered_uses_mask: BufferUses,
}
82
83impl Default for BufferUsageScope {
84 fn default() -> Self {
85 Self {
86 state: Vec::new(),
87 metadata: ResourceMetadata::new(),
88 ordered_uses_mask: BufferUses::empty(),
89 }
90 }
91}
92
impl BufferUsageScope {
    /// Checks (under `strict_assert`) that `index` is in bounds of both the
    /// state vector and the metadata.
    fn tracker_assert_in_bounds(&self, index: usize) {
        strict_assert!(index < self.state.len());
        self.metadata.tracker_assert_in_bounds(index);
    }

    /// Removes every buffer from the scope.
    pub fn clear(&mut self) {
        self.state.clear();
        self.metadata.clear();
    }

    /// Sets the size of all the vectors inside the tracker.
    ///
    /// Must be called with the highest possible buffer tracker index before
    /// the unsafe merge functions are used — they index without bounds
    /// checks.
    pub fn set_size(&mut self, size: usize) {
        self.state.resize(size, BufferUses::empty());
        self.metadata.set_size(size);
    }

    /// Stores the mask of usages treated as "ordered".
    /// NOTE(review): this field is never read by methods on this type —
    /// presumably consumed elsewhere; confirm against callers.
    pub fn set_ordered_uses_mask(&mut self, ordered_uses_mask: BufferUses) {
        self.ordered_uses_mask = ordered_uses_mask;
    }

    /// Grows the internal vectors, if needed, so that `index` is valid.
    fn allow_index(&mut self, index: usize) {
        if index >= self.state.len() {
            self.set_size(index + 1);
        }
    }

    /// Merges the buffer states of the given bind group into this scope.
    ///
    /// Returns an error if any buffer's combined usage would be invalid
    /// (per `invalid_resource_state`).
    ///
    /// # Safety
    ///
    /// [`Self::set_size`] must have been called with a size covering every
    /// tracker index in `bind_group`; indices are used without bounds
    /// checks.
    pub unsafe fn merge_bind_group(
        &mut self,
        bind_group: &BufferBindGroupState,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        for &(ref resource, state) in bind_group.buffers.iter() {
            let index = resource.tracker_index().as_usize();

            unsafe {
                self.insert_or_merge(
                    index as _,
                    index,
                    BufferStateProvider::Direct { state },
                    ResourceMetadataProvider::Direct { resource },
                )?
            };
        }

        Ok(())
    }

    /// Merges another usage scope into this one, growing this scope if
    /// necessary.
    ///
    /// Returns an error if any buffer's combined usage would be invalid.
    pub fn merge_usage_scope(
        &mut self,
        scope: &Self,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        let incoming_size = scope.state.len();
        if incoming_size > self.state.len() {
            self.set_size(incoming_size);
        }

        for index in scope.metadata.owned_indices() {
            self.tracker_assert_in_bounds(index);
            scope.tracker_assert_in_bounds(index);

            // SAFETY: `index` is owned in `scope.metadata`, and both scopes
            // were just checked to be large enough for it.
            unsafe {
                self.insert_or_merge(
                    index as u32,
                    index,
                    BufferStateProvider::Indirect {
                        state: &scope.state,
                    },
                    ResourceMetadataProvider::Indirect {
                        metadata: &scope.metadata,
                    },
                )?;
            };
        }

        Ok(())
    }

    /// Merges a single buffer with the given usage into this scope, growing
    /// the scope if necessary.
    ///
    /// Returns an error if the buffer's combined usage would be invalid.
    pub fn merge_single(
        &mut self,
        buffer: &Arc<Buffer>,
        new_state: BufferUses,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        let index = buffer.tracker_index().as_usize();

        self.allow_index(index);

        self.tracker_assert_in_bounds(index);

        // SAFETY: `allow_index` grew the vectors to cover `index`.
        unsafe {
            self.insert_or_merge(
                index as _,
                index,
                BufferStateProvider::Direct { state: new_state },
                ResourceMetadataProvider::Direct { resource: buffer },
            )?;
        }

        Ok(())
    }

    /// Inserts the provided state if the buffer is not yet tracked, or
    /// merges (unions) it with the existing state otherwise.
    ///
    /// # Safety
    ///
    /// `index` must be in bounds of all internal vectors and valid for the
    /// given providers.
    #[inline(always)]
    unsafe fn insert_or_merge(
        &mut self,
        index32: u32,
        index: usize,
        state_provider: BufferStateProvider<'_>,
        metadata_provider: ResourceMetadataProvider<'_, Arc<Buffer>>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        let currently_owned = unsafe { self.metadata.contains_unchecked(index) };

        if !currently_owned {
            // First sighting of this buffer: record its state directly.
            unsafe {
                insert(
                    None,
                    &mut self.state,
                    &mut self.metadata,
                    index,
                    state_provider,
                    None,
                    metadata_provider,
                )
            };
            return Ok(());
        }

        // Already tracked: union in the new usage, failing if the
        // combination is invalid.
        unsafe {
            merge(
                &mut self.state,
                index32,
                index,
                state_provider,
                metadata_provider,
            )
        }
    }

    /// Clears the given usage bits from the buffer's tracked state, if the
    /// buffer is tracked at all.
    pub fn remove_usage(&mut self, buffer: &Buffer, usage: BufferUses) {
        let index = buffer.tracker_index().as_usize();
        if self.metadata.contains(index) {
            // SAFETY: `metadata.contains(index)` implies `index` is in
            // bounds of the parallel `state` vector.
            unsafe {
                *self.state.get_unchecked_mut(index) &= !usage;
            }
        }
    }
}
283
/// Tracks buffer state over a span of work, remembering both the first
/// (`start`) and most recent (`end`) usage of each buffer, and producing
/// the transitions needed between states.
pub(crate) struct BufferTracker {
    /// First recorded usage of each buffer, indexed by tracker index.
    start: Vec<BufferUses>,
    /// Most recent usage of each buffer, indexed by tracker index.
    end: Vec<BufferUses>,

    /// Strong handles for the tracked buffers, parallel to the state vecs.
    metadata: ResourceMetadata<Arc<Buffer>>,

    /// Pending transitions accumulated by `barrier`, drained by callers.
    temp: Vec<PendingTransition<BufferUses>>,

    /// Mask handed to `skip_barrier` when deciding whether a state change
    /// needs a transition.
    ordered_uses_mask: BufferUses,
}
295
impl BufferTracker {
    /// Creates an empty tracker with the given mask of ordered usages.
    pub fn new(ordered_uses_mask: BufferUses) -> Self {
        Self {
            start: Vec::new(),
            end: Vec::new(),

            metadata: ResourceMetadata::new(),

            temp: Vec::new(),

            ordered_uses_mask,
        }
    }

    /// Checks (under `strict_assert`) that `index` is in bounds of all
    /// internal vectors.
    fn tracker_assert_in_bounds(&self, index: usize) {
        strict_assert!(index < self.start.len());
        strict_assert!(index < self.end.len());
        self.metadata.tracker_assert_in_bounds(index);
    }

    /// Sets the size of all the vectors inside the tracker.
    ///
    /// Must be called with the highest possible buffer tracker index before
    /// the unsafe functions are used — they index without bounds checks.
    pub fn set_size(&mut self, size: usize) {
        self.start.resize(size, BufferUses::empty());
        self.end.resize(size, BufferUses::empty());

        self.metadata.set_size(size);
    }

    /// Grows the internal vectors, if needed, so that `index` is valid.
    fn allow_index(&mut self, index: usize) {
        if index >= self.start.len() {
            self.set_size(index + 1);
        }
    }

    /// Returns true if the given buffer is tracked.
    pub fn contains(&self, buffer: &Buffer) -> bool {
        self.metadata.contains(buffer.tracker_index().as_usize())
    }

    /// Returns an iterator over all buffers owned by this tracker.
    pub fn used_resources(&self) -> impl Iterator<Item = &Arc<Buffer>> + '_ {
        self.metadata.owned_resources()
    }

    /// Drains all pending transitions, converting each into a hal barrier.
    pub fn drain_transitions<'a, 'b: 'a>(
        &'b mut self,
        snatch_guard: &'a SnatchGuard<'a>,
    ) -> impl Iterator<Item = BufferBarrier<'a, dyn hal::DynBuffer>> {
        let buffer_barriers = self.temp.drain(..).map(|pending| {
            // SAFETY: `pending.id` was recorded from an index owned in
            // `self.metadata` when the transition was pushed.
            let buf = unsafe { self.metadata.get_resource_unchecked(pending.id as _) };
            pending.into_hal(buf, snatch_guard)
        });
        buffer_barriers
    }

    /// Sets the state of a single buffer, growing the tracker if needed.
    ///
    /// Returns the transition from the previously tracked state to `state`,
    /// or `None` if no barrier is required.
    pub fn set_single(
        &mut self,
        buffer: &Arc<Buffer>,
        state: BufferUses,
    ) -> Option<PendingTransition<BufferUses>> {
        let index: usize = buffer.tracker_index().as_usize();

        self.allow_index(index);

        self.tracker_assert_in_bounds(index);

        // SAFETY: `allow_index` grew the vectors to cover `index`.
        unsafe {
            self.insert_or_barrier_update(
                index,
                BufferStateProvider::Direct { state },
                None,
                ResourceMetadataProvider::Direct { resource: buffer },
            )
        };

        // A single-buffer update can produce at most one transition.
        strict_assert!(self.temp.len() <= 1);

        self.temp.pop()
    }

    /// Sets the states of all buffers from another tracker, growing this one
    /// if needed: barriers are generated to each buffer's `start` state in
    /// `tracker`, then its `end` state is recorded as current. Transitions
    /// accumulate in `self.temp` (see [`Self::drain_transitions`]).
    pub fn set_from_tracker(&mut self, tracker: &Self) {
        let incoming_size = tracker.start.len();
        if incoming_size > self.start.len() {
            self.set_size(incoming_size);
        }

        for index in tracker.metadata.owned_indices() {
            self.tracker_assert_in_bounds(index);
            tracker.tracker_assert_in_bounds(index);
            // SAFETY: both trackers were just checked to be large enough
            // for `index`, which is owned in `tracker.metadata`.
            unsafe {
                self.insert_or_barrier_update(
                    index,
                    BufferStateProvider::Indirect {
                        state: &tracker.start,
                    },
                    Some(BufferStateProvider::Indirect {
                        state: &tracker.end,
                    }),
                    ResourceMetadataProvider::Indirect {
                        metadata: &tracker.metadata,
                    },
                )
            }
        }
    }

    /// Sets the states of all buffers from a usage scope, growing this
    /// tracker if needed. Transitions accumulate in `self.temp`.
    pub fn set_from_usage_scope(&mut self, scope: &BufferUsageScope) {
        let incoming_size = scope.state.len();
        if incoming_size > self.start.len() {
            self.set_size(incoming_size);
        }

        for index in scope.metadata.owned_indices() {
            self.tracker_assert_in_bounds(index);
            scope.tracker_assert_in_bounds(index);
            // SAFETY: both containers were just checked to be large enough
            // for `index`, which is owned in `scope.metadata`.
            unsafe {
                self.insert_or_barrier_update(
                    index,
                    BufferStateProvider::Indirect {
                        state: &scope.state,
                    },
                    None,
                    ResourceMetadataProvider::Indirect {
                        metadata: &scope.metadata,
                    },
                )
            }
        }
    }

    /// Like [`Self::set_from_usage_scope`], but only for the tracker indices
    /// yielded by `index_source`; each processed buffer is also *removed*
    /// from `scope`. Indices not currently owned by `scope` are skipped.
    pub fn set_and_remove_from_usage_scope_sparse(
        &mut self,
        scope: &mut BufferUsageScope,
        index_source: impl IntoIterator<Item = TrackerIndex>,
    ) {
        let incoming_size = scope.state.len();
        if incoming_size > self.start.len() {
            self.set_size(incoming_size);
        }

        for index in index_source {
            let index = index.as_usize();

            scope.tracker_assert_in_bounds(index);

            // The scope may not own every index the source yields; skip
            // those instead of touching uninitialized state.
            if unsafe { !scope.metadata.contains_unchecked(index) } {
                continue;
            }

            // SAFETY: `index` was just checked to be in bounds and owned
            // in `scope.metadata`.
            unsafe {
                self.insert_or_barrier_update(
                    index,
                    BufferStateProvider::Indirect {
                        state: &scope.state,
                    },
                    None,
                    ResourceMetadataProvider::Indirect {
                        metadata: &scope.metadata,
                    },
                )
            };

            // SAFETY: same bounds/ownership argument as above.
            unsafe { scope.metadata.remove(index) };
        }
    }

    /// If the buffer at `index` is not yet tracked, inserts it with the
    /// provided start/end states. Otherwise, generates a barrier (pushed to
    /// `self.temp`) from the tracked end state to the new start state, then
    /// records the end provider's state (or the start provider's, if no end
    /// provider was given) as the new end state.
    ///
    /// # Safety
    ///
    /// `index` must be in bounds of all internal vectors and valid for the
    /// given providers.
    #[inline(always)]
    unsafe fn insert_or_barrier_update(
        &mut self,
        index: usize,
        start_state_provider: BufferStateProvider<'_>,
        end_state_provider: Option<BufferStateProvider<'_>>,
        metadata_provider: ResourceMetadataProvider<'_, Arc<Buffer>>,
    ) {
        let currently_owned = unsafe { self.metadata.contains_unchecked(index) };

        if !currently_owned {
            // First sighting: no barrier needed, just record both states.
            unsafe {
                insert(
                    Some(&mut self.start),
                    &mut self.end,
                    &mut self.metadata,
                    index,
                    start_state_provider,
                    end_state_provider,
                    metadata_provider,
                )
            };
            return;
        }

        // With no explicit end state, the start state doubles as the state
        // to record after the barrier.
        let update_state_provider =
            end_state_provider.unwrap_or_else(|| start_state_provider.clone());
        unsafe {
            barrier(
                &mut self.end,
                index,
                start_state_provider,
                &mut self.temp,
                self.ordered_uses_mask,
            )
        };

        // `barrier` only reads the current state; commit the new one here.
        unsafe { update(&mut self.end, index, update_state_provider) };
    }
}
568
/// Device-lifetime buffer state tracker. Stores only [`Weak`] handles, so
/// it does not keep the tracked buffers alive.
pub(crate) struct DeviceBufferTracker {
    /// Current usage of each tracked buffer, indexed by tracker index.
    current_states: Vec<BufferUses>,
    /// Weak handles for the tracked buffers, parallel to `current_states`.
    metadata: ResourceMetadata<Weak<Buffer>>,
    /// Pending transitions accumulated by `barrier`, drained by callers.
    temp: Vec<PendingTransition<BufferUses>>,
    /// Mask handed to `skip_barrier` when deciding whether a state change
    /// needs a transition.
    ordered_uses_mask: BufferUses,
}
576
impl DeviceBufferTracker {
    /// Creates an empty tracker with the given mask of ordered usages.
    pub fn new(ordered_uses_mask: BufferUses) -> Self {
        Self {
            current_states: Vec::new(),
            metadata: ResourceMetadata::new(),
            temp: Vec::new(),
            ordered_uses_mask,
        }
    }

    /// Checks (under `strict_assert`) that `index` is in bounds of the
    /// state vector and the metadata.
    fn tracker_assert_in_bounds(&self, index: usize) {
        strict_assert!(index < self.current_states.len());
        self.metadata.tracker_assert_in_bounds(index);
    }

    /// Grows the internal vectors, if needed, so that `index` is valid.
    fn allow_index(&mut self, index: usize) {
        if index >= self.current_states.len() {
            self.current_states.resize(index + 1, BufferUses::empty());
            self.metadata.set_size(index + 1);
        }
    }

    /// Returns an iterator over the weak handles of all tracked buffers.
    pub fn used_resources(&self) -> impl Iterator<Item = &Weak<Buffer>> + '_ {
        self.metadata.owned_resources()
    }

    /// Inserts a single buffer with the given state, growing the tracker if
    /// needed. Only a downgraded [`Weak`] handle is stored.
    pub fn insert_single(&mut self, buffer: &Arc<Buffer>, state: BufferUses) {
        let index = buffer.tracker_index().as_usize();

        self.allow_index(index);

        self.tracker_assert_in_bounds(index);

        // SAFETY: `allow_index` grew the vectors to cover `index`.
        unsafe {
            insert(
                None,
                &mut self.current_states,
                &mut self.metadata,
                index,
                BufferStateProvider::Direct { state },
                None,
                ResourceMetadataProvider::Direct {
                    resource: &Arc::downgrade(buffer),
                },
            )
        }
    }

    /// Sets the state of a single buffer, returning the transition from the
    /// previously tracked state, or `None` if no barrier is needed.
    ///
    /// Unlike [`BufferTracker::set_single`], this does not grow the tracker:
    /// the buffer is expected to have been registered via
    /// [`Self::insert_single`] first (checked only under `strict_assert`).
    pub fn set_single(
        &mut self,
        buffer: &Arc<Buffer>,
        state: BufferUses,
    ) -> Option<PendingTransition<BufferUses>> {
        let index: usize = buffer.tracker_index().as_usize();

        self.tracker_assert_in_bounds(index);

        let start_state_provider = BufferStateProvider::Direct { state };

        unsafe {
            barrier(
                &mut self.current_states,
                index,
                start_state_provider.clone(),
                &mut self.temp,
                self.ordered_uses_mask,
            )
        };
        unsafe { update(&mut self.current_states, index, start_state_provider) };

        // A single-buffer update can produce at most one transition.
        strict_assert!(self.temp.len() <= 1);

        self.temp.pop()
    }

    /// For every buffer owned by `tracker`: generates a barrier from this
    /// tracker's current state to the buffer's `start` state in `tracker`,
    /// records `tracker`'s `end` state as current, and returns the resulting
    /// hal barriers.
    pub fn set_from_tracker_and_drain_transitions<'a, 'b: 'a>(
        &'a mut self,
        tracker: &'a BufferTracker,
        snatch_guard: &'b SnatchGuard<'b>,
    ) -> impl Iterator<Item = BufferBarrier<'a, dyn hal::DynBuffer>> {
        for index in tracker.metadata.owned_indices() {
            self.tracker_assert_in_bounds(index);

            let start_state_provider = BufferStateProvider::Indirect {
                state: &tracker.start,
            };
            let end_state_provider = BufferStateProvider::Indirect {
                state: &tracker.end,
            };
            // SAFETY: `index` is owned in `tracker.metadata`; bounds were
            // asserted above (strict builds only).
            unsafe {
                barrier(
                    &mut self.current_states,
                    index,
                    start_state_provider,
                    &mut self.temp,
                    self.ordered_uses_mask,
                )
            };
            unsafe { update(&mut self.current_states, index, end_state_provider) };
        }

        self.temp.drain(..).map(|pending| {
            // SAFETY: `pending.id` came from an index owned in
            // `tracker.metadata` in the loop above.
            let buf = unsafe { tracker.metadata.get_resource_unchecked(pending.id as _) };
            pending.into_hal(buf, snatch_guard)
        })
    }
}
697
/// Source of a buffer's state for the insert/merge/barrier/update helpers:
/// either a single state given directly, or a lookup into a state slice.
#[derive(Debug, Clone)]
enum BufferStateProvider<'a> {
    /// The state is given directly.
    Direct { state: BufferUses },
    /// The state is looked up by index in the given slice.
    Indirect { state: &'a [BufferUses] },
}
impl BufferStateProvider<'_> {
    /// Returns the state for `index`.
    ///
    /// # Safety
    ///
    /// For the `Indirect` variant, `index` must be in bounds of the state
    /// slice; this is checked only under `strict_assert`.
    #[inline(always)]
    unsafe fn get_state(&self, index: usize) -> BufferUses {
        match *self {
            BufferStateProvider::Direct { state } => state,
            BufferStateProvider::Indirect { state } => {
                strict_assert!(index < state.len());
                *unsafe { state.get_unchecked(index) }
            }
        }
    }
}
723
/// Writes the provided start/end states into the tracking slices at `index`
/// and records the resource in the metadata.
///
/// If `start_states` is `None`, only the current (end) state is written.
/// If `end_state_provider` is `None`, the start state doubles as the end
/// state.
///
/// # Safety
///
/// `index` must be in bounds of every given slice and the metadata, and
/// valid for the given providers.
#[inline(always)]
unsafe fn insert<T: Clone>(
    start_states: Option<&mut [BufferUses]>,
    current_states: &mut [BufferUses],
    resource_metadata: &mut ResourceMetadata<T>,
    index: usize,
    start_state_provider: BufferStateProvider<'_>,
    end_state_provider: Option<BufferStateProvider<'_>>,
    metadata_provider: ResourceMetadataProvider<'_, T>,
) {
    let new_start_state = unsafe { start_state_provider.get_state(index) };
    let new_end_state =
        end_state_provider.map_or(new_start_state, |p| unsafe { p.get_state(index) });

    // A freshly inserted state must be valid on its own (no incompatible
    // usage combination), checked under strict asserts.
    strict_assert_eq!(invalid_resource_state(new_start_state), false);
    strict_assert_eq!(invalid_resource_state(new_end_state), false);

    unsafe {
        if let Some(&mut ref mut start_state) = start_states {
            *start_state.get_unchecked_mut(index) = new_start_state;
        }
        *current_states.get_unchecked_mut(index) = new_end_state;

        let resource = metadata_provider.get(index);
        resource_metadata.insert(index, resource.clone());
    }
}
753
/// Merges (unions) the provider's state at `index` into `current_states`.
///
/// Returns an error built from the metadata provider's resource if the
/// combined usage is invalid (per `invalid_resource_state`); the tracked
/// state is only updated when the merge succeeds.
///
/// # Safety
///
/// `index` must be in bounds of `current_states` and valid for both
/// providers.
#[inline(always)]
unsafe fn merge(
    current_states: &mut [BufferUses],
    _index32: u32,
    index: usize,
    state_provider: BufferStateProvider<'_>,
    metadata_provider: ResourceMetadataProvider<'_, Arc<Buffer>>,
) -> Result<(), ResourceUsageCompatibilityError> {
    let current_state = unsafe { current_states.get_unchecked_mut(index) };
    let new_state = unsafe { state_provider.get_state(index) };

    let merged_state = *current_state | new_state;

    if invalid_resource_state(merged_state) {
        return Err(ResourceUsageCompatibilityError::from_buffer(
            unsafe { metadata_provider.get(index) },
            *current_state,
            new_state,
        ));
    }

    *current_state = merged_state;

    Ok(())
}
779
/// Compares the tracked state at `index` with the provider's new state and,
/// unless `skip_barrier` says the change is free, pushes a
/// [`PendingTransition`] onto `barriers`.
///
/// Note: this only *reads* `current_states`; committing the new state is
/// done separately by [`update`].
///
/// # Safety
///
/// `index` must be in bounds of `current_states` and valid for
/// `state_provider`.
#[inline(always)]
unsafe fn barrier(
    current_states: &mut [BufferUses],
    index: usize,
    state_provider: BufferStateProvider<'_>,
    barriers: &mut Vec<PendingTransition<BufferUses>>,
    ordered_uses_mask: BufferUses,
) {
    let current_state = unsafe { *current_states.get_unchecked(index) };
    let new_state = unsafe { state_provider.get_state(index) };

    if skip_barrier(current_state, ordered_uses_mask, new_state) {
        return;
    }

    barriers.push(PendingTransition {
        id: index as _,
        selector: (),
        usage: hal::StateTransition {
            from: current_state,
            to: new_state,
        },
    });
}
804
805#[inline(always)]
806unsafe fn update(
807 current_states: &mut [BufferUses],
808 index: usize,
809 state_provider: BufferStateProvider<'_>,
810) {
811 let current_state = unsafe { current_states.get_unchecked_mut(index) };
812 let new_state = unsafe { state_provider.get_state(index) };
813
814 *current_state = new_state;
815}