1use alloc::{
8 sync::{Arc, Weak},
9 vec::Vec,
10};
11
12use hal::BufferBarrier;
13use wgt::{strict_assert, strict_assert_eq, BufferUses};
14
15use super::{PendingTransition, TrackerIndex};
16use crate::{
17 resource::{Buffer, Trackable},
18 snatch::SnatchGuard,
19 track::{
20 invalid_resource_state, skip_barrier, ResourceMetadata, ResourceMetadataProvider,
21 ResourceUsageCompatibilityError, ResourceUses,
22 },
23};
24
impl ResourceUses for BufferUses {
    /// Usage flags that require exclusive access to the buffer.
    const EXCLUSIVE: Self = Self::EXCLUSIVE;

    /// Buffers are tracked as a whole, so no sub-resource selector is needed.
    type Selector = ();

    fn bits(self) -> u16 {
        // Fully-qualified call to the inherent bitflags `bits(&self)`;
        // a plain `self.bits()` would be ambiguous with this trait method.
        Self::bits(&self)
    }

    fn all_ordered(self) -> bool {
        // True when every set flag is contained in the ORDERED set.
        Self::ORDERED.contains(self)
    }

    fn any_exclusive(self) -> bool {
        // True when at least one EXCLUSIVE flag is set.
        self.intersects(Self::EXCLUSIVE)
    }
}
42
/// The buffers referenced by a bind group, paired with the usage each
/// buffer was bound with.
#[derive(Debug)]
pub(crate) struct BufferBindGroupState {
    buffers: Vec<(Arc<Buffer>, BufferUses)>,
}
48impl BufferBindGroupState {
49 pub fn new() -> Self {
50 Self {
51 buffers: Vec::new(),
52 }
53 }
54
55 pub(crate) fn optimize(&mut self) {
60 self.buffers
61 .sort_unstable_by_key(|(b, _)| b.tracker_index());
62 }
63
64 pub fn used_tracker_indices(&self) -> impl Iterator<Item = TrackerIndex> + '_ {
66 self.buffers
67 .iter()
68 .map(|(b, _)| b.tracker_index())
69 .collect::<Vec<_>>()
70 .into_iter()
71 }
72
73 pub fn insert_single(&mut self, buffer: Arc<Buffer>, state: BufferUses) {
75 self.buffers.push((buffer, state));
76 }
77}
78
/// Accumulated buffer usage within a single usage scope.
#[derive(Debug)]
pub(crate) struct BufferUsageScope {
    /// Per-tracker-index usage state; untracked slots hold `BufferUses::empty()`.
    state: Vec<BufferUses>,
    /// Ownership records (`Arc`) for every buffer tracked in this scope.
    metadata: ResourceMetadata<Arc<Buffer>>,
}
85
86impl Default for BufferUsageScope {
87 fn default() -> Self {
88 Self {
89 state: Vec::new(),
90 metadata: ResourceMetadata::new(),
91 }
92 }
93}
94
impl BufferUsageScope {
    /// Strict-asserts that `index` is within bounds of both the state
    /// vector and the metadata. No-op when strict assertions are disabled.
    fn tracker_assert_in_bounds(&self, index: usize) {
        strict_assert!(index < self.state.len());
        self.metadata.tracker_assert_in_bounds(index);
    }
    /// Resets the scope to empty, dropping all owned resource references.
    pub fn clear(&mut self) {
        self.state.clear();
        self.metadata.clear();
    }

    /// Grows (or shrinks) the internal vectors to cover `size` tracker
    /// indices; new state slots start as `BufferUses::empty()`.
    pub fn set_size(&mut self, size: usize) {
        self.state.resize(size, BufferUses::empty());
        self.metadata.set_size(size);
    }

    /// Ensures the vectors are large enough for `index` to be a valid slot.
    fn allow_index(&mut self, index: usize) {
        if index >= self.state.len() {
            self.set_size(index + 1);
        }
    }

    /// Merges every buffer state recorded in `bind_group` into this scope.
    ///
    /// Returns an error if any buffer ends up with an invalid (mutually
    /// exclusive) usage combination.
    ///
    /// # Safety
    ///
    /// Every tracker index in `bind_group` must already be within the bounds
    /// of this scope (e.g. via a prior `set_size`): `insert_or_merge` accesses
    /// `state`/`metadata` without bounds checks.
    pub unsafe fn merge_bind_group(
        &mut self,
        bind_group: &BufferBindGroupState,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        for &(ref resource, state) in bind_group.buffers.iter() {
            let index = resource.tracker_index().as_usize();

            unsafe {
                self.insert_or_merge(
                    index as _,
                    index,
                    BufferStateProvider::Direct { state },
                    ResourceMetadataProvider::Direct { resource },
                )?
            };
        }

        Ok(())
    }

    /// Merges another usage scope into this one, growing this scope first if
    /// the incoming scope covers more indices.
    ///
    /// Returns an error if any merged buffer state becomes invalid.
    pub fn merge_usage_scope(
        &mut self,
        scope: &Self,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        let incoming_size = scope.state.len();
        if incoming_size > self.state.len() {
            self.set_size(incoming_size);
        }

        for index in scope.metadata.owned_indices() {
            self.tracker_assert_in_bounds(index);
            scope.tracker_assert_in_bounds(index);

            // SAFETY: both scopes were just strict-asserted to cover `index`,
            // and `self` was resized above to at least the incoming size.
            unsafe {
                self.insert_or_merge(
                    index as u32,
                    index,
                    BufferStateProvider::Indirect {
                        state: &scope.state,
                    },
                    ResourceMetadataProvider::Indirect {
                        metadata: &scope.metadata,
                    },
                )?;
            };
        }

        Ok(())
    }

    /// Merges a single buffer's usage into this scope, growing the scope as
    /// needed so the buffer's tracker index is always in bounds.
    ///
    /// Returns an error if the combined state for the buffer is invalid.
    pub fn merge_single(
        &mut self,
        buffer: &Arc<Buffer>,
        new_state: BufferUses,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        let index = buffer.tracker_index().as_usize();

        self.allow_index(index);

        self.tracker_assert_in_bounds(index);

        // SAFETY: `allow_index` guarantees `index` is in bounds.
        unsafe {
            self.insert_or_merge(
                index as _,
                index,
                BufferStateProvider::Direct { state: new_state },
                ResourceMetadataProvider::Direct { resource: buffer },
            )?;
        }

        Ok(())
    }

    /// Inserts the provided state if `index` is not yet owned, otherwise
    /// merges it with the existing state (erroring on an invalid combination).
    ///
    /// `index32` is forwarded to `merge`, which currently ignores it.
    ///
    /// # Safety
    ///
    /// `index` must be in bounds of `self.state`/`self.metadata` and valid
    /// for both providers — all accesses below are unchecked.
    #[inline(always)]
    unsafe fn insert_or_merge(
        &mut self,
        index32: u32,
        index: usize,
        state_provider: BufferStateProvider<'_>,
        metadata_provider: ResourceMetadataProvider<'_, Arc<Buffer>>,
    ) -> Result<(), ResourceUsageCompatibilityError> {
        let currently_owned = unsafe { self.metadata.contains_unchecked(index) };

        if !currently_owned {
            // First sighting of this buffer in the scope: record it directly.
            unsafe {
                insert(
                    None,
                    &mut self.state,
                    &mut self.metadata,
                    index,
                    state_provider,
                    None,
                    metadata_provider,
                )
            };
            return Ok(());
        }

        // Already tracked: OR the states together, rejecting invalid mixes.
        unsafe {
            merge(
                &mut self.state,
                index32,
                index,
                state_provider,
                metadata_provider,
            )
        }
    }
}
265
/// Buffer state tracker holding, per tracker index, the first (`start`) and
/// most recent (`end`) usage seen, plus ownership metadata and a scratch
/// list of pending transitions produced while updating state.
pub(crate) struct BufferTracker {
    start: Vec<BufferUses>,
    end: Vec<BufferUses>,

    metadata: ResourceMetadata<Arc<Buffer>>,

    temp: Vec<PendingTransition<BufferUses>>,
}
275
impl BufferTracker {
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self {
            start: Vec::new(),
            end: Vec::new(),

            metadata: ResourceMetadata::new(),

            temp: Vec::new(),
        }
    }

    /// Strict-asserts that `index` is within bounds of all three trackers
    /// (`start`, `end`, and the metadata).
    fn tracker_assert_in_bounds(&self, index: usize) {
        strict_assert!(index < self.start.len());
        strict_assert!(index < self.end.len());
        self.metadata.tracker_assert_in_bounds(index);
    }

    /// Grows (or shrinks) the internal vectors to cover `size` tracker
    /// indices; new state slots start as `BufferUses::empty()`.
    pub fn set_size(&mut self, size: usize) {
        self.start.resize(size, BufferUses::empty());
        self.end.resize(size, BufferUses::empty());

        self.metadata.set_size(size);
    }

    /// Ensures the vectors are large enough for `index` to be a valid slot.
    fn allow_index(&mut self, index: usize) {
        if index >= self.start.len() {
            self.set_size(index + 1);
        }
    }

    /// Returns true if this tracker owns the given buffer.
    pub fn contains(&self, buffer: &Buffer) -> bool {
        self.metadata.contains(buffer.tracker_index().as_usize())
    }

    /// Returns an iterator over every buffer owned by this tracker.
    pub fn used_resources(&self) -> impl Iterator<Item = &Arc<Buffer>> + '_ {
        self.metadata.owned_resources()
    }

    /// Drains all pending transitions accumulated in `temp`, converting each
    /// into a hal `BufferBarrier` against the owning buffer.
    pub fn drain_transitions<'a, 'b: 'a>(
        &'b mut self,
        snatch_guard: &'a SnatchGuard<'a>,
    ) -> impl Iterator<Item = BufferBarrier<'a, dyn hal::DynBuffer>> {
        let buffer_barriers = self.temp.drain(..).map(|pending| {
            // `pending.id` was recorded as the tracker index in `barrier()`,
            // so the unchecked metadata lookup is for an owned slot.
            let buf = unsafe { self.metadata.get_resource_unchecked(pending.id as _) };
            pending.into_hal(buf, snatch_guard)
        });
        buffer_barriers
    }

    /// Sets the state of a single buffer, growing the tracker as needed.
    ///
    /// Returns the transition needed to bring the buffer from its previous
    /// state to `state`, or `None` if no barrier is required.
    pub fn set_single(
        &mut self,
        buffer: &Arc<Buffer>,
        state: BufferUses,
    ) -> Option<PendingTransition<BufferUses>> {
        let index: usize = buffer.tracker_index().as_usize();

        self.allow_index(index);

        self.tracker_assert_in_bounds(index);

        // SAFETY: `allow_index` guarantees `index` is in bounds.
        unsafe {
            self.insert_or_barrier_update(
                index,
                BufferStateProvider::Direct { state },
                None,
                ResourceMetadataProvider::Direct { resource: buffer },
            )
        };

        // A single update can emit at most one transition.
        strict_assert!(self.temp.len() <= 1);

        self.temp.pop()
    }

    /// Applies all states from another tracker: barriers are generated
    /// against the other tracker's `start` states, then this tracker's
    /// current states are set to the other tracker's `end` states.
    pub fn set_from_tracker(&mut self, tracker: &Self) {
        let incoming_size = tracker.start.len();
        if incoming_size > self.start.len() {
            self.set_size(incoming_size);
        }

        for index in tracker.metadata.owned_indices() {
            self.tracker_assert_in_bounds(index);
            tracker.tracker_assert_in_bounds(index);
            // SAFETY: both trackers were just strict-asserted to cover
            // `index`, and `self` was resized above.
            unsafe {
                self.insert_or_barrier_update(
                    index,
                    BufferStateProvider::Indirect {
                        state: &tracker.start,
                    },
                    Some(BufferStateProvider::Indirect {
                        state: &tracker.end,
                    }),
                    ResourceMetadataProvider::Indirect {
                        metadata: &tracker.metadata,
                    },
                )
            }
        }
    }

    /// Applies all states from a usage scope; with no separate end-state
    /// provider, each buffer's new current state is the scope's state.
    pub fn set_from_usage_scope(&mut self, scope: &BufferUsageScope) {
        let incoming_size = scope.state.len();
        if incoming_size > self.start.len() {
            self.set_size(incoming_size);
        }

        for index in scope.metadata.owned_indices() {
            self.tracker_assert_in_bounds(index);
            scope.tracker_assert_in_bounds(index);
            // SAFETY: both containers were just strict-asserted to cover
            // `index`, and `self` was resized above.
            unsafe {
                self.insert_or_barrier_update(
                    index,
                    BufferStateProvider::Indirect {
                        state: &scope.state,
                    },
                    None,
                    ResourceMetadataProvider::Indirect {
                        metadata: &scope.metadata,
                    },
                )
            }
        }
    }

    /// Like `set_from_usage_scope`, but only for the indices yielded by
    /// `index_source`, and each applied buffer is removed from `scope`.
    /// Indices not owned by the scope are skipped.
    ///
    /// # Safety
    ///
    /// Each yielded index must be within the bounds of `scope`'s vectors —
    /// only strict assertions guard the unchecked metadata accesses below.
    pub unsafe fn set_and_remove_from_usage_scope_sparse(
        &mut self,
        scope: &mut BufferUsageScope,
        index_source: impl IntoIterator<Item = TrackerIndex>,
    ) {
        let incoming_size = scope.state.len();
        if incoming_size > self.start.len() {
            self.set_size(incoming_size);
        }

        for index in index_source {
            let index = index.as_usize();

            scope.tracker_assert_in_bounds(index);

            // Skip indices the scope never recorded.
            if unsafe { !scope.metadata.contains_unchecked(index) } {
                continue;
            }
            unsafe {
                self.insert_or_barrier_update(
                    index,
                    BufferStateProvider::Indirect {
                        state: &scope.state,
                    },
                    None,
                    ResourceMetadataProvider::Indirect {
                        metadata: &scope.metadata,
                    },
                )
            };

            // Consume the entry so it is not applied again later.
            unsafe { scope.metadata.remove(index) };
        }
    }

    /// Inserts the provided states if `index` is not yet owned; otherwise
    /// emits a barrier from the current `end` state to the provided start
    /// state (into `self.temp`) and updates `end` to the new state
    /// (`end_state_provider` if given, else the start state).
    ///
    /// # Safety
    ///
    /// `index` must be in bounds of `start`, `end`, and the metadata, and
    /// valid for the providers — all accesses are unchecked.
    #[inline(always)]
    unsafe fn insert_or_barrier_update(
        &mut self,
        index: usize,
        start_state_provider: BufferStateProvider<'_>,
        end_state_provider: Option<BufferStateProvider<'_>>,
        metadata_provider: ResourceMetadataProvider<'_, Arc<Buffer>>,
    ) {
        let currently_owned = unsafe { self.metadata.contains_unchecked(index) };

        if !currently_owned {
            // First sighting: record both start and end states, no barrier.
            unsafe {
                insert(
                    Some(&mut self.start),
                    &mut self.end,
                    &mut self.metadata,
                    index,
                    start_state_provider,
                    end_state_provider,
                    metadata_provider,
                )
            };
            return;
        }

        // When no explicit end state is given, the start state doubles as
        // the new current state.
        let update_state_provider =
            end_state_provider.unwrap_or_else(|| start_state_provider.clone());
        unsafe { barrier(&mut self.end, index, start_state_provider, &mut self.temp) };

        unsafe { update(&mut self.end, index, update_state_provider) };
    }
}
535
/// Buffer state tracker that stores only `Weak` references, so tracking a
/// buffer here does not keep it alive.
pub(crate) struct DeviceBufferTracker {
    current_states: Vec<BufferUses>,
    metadata: ResourceMetadata<Weak<Buffer>>,
    temp: Vec<PendingTransition<BufferUses>>,
}
542
impl DeviceBufferTracker {
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self {
            current_states: Vec::new(),
            metadata: ResourceMetadata::new(),
            temp: Vec::new(),
        }
    }

    /// Strict-asserts that `index` is within bounds of the state vector and
    /// the metadata.
    fn tracker_assert_in_bounds(&self, index: usize) {
        strict_assert!(index < self.current_states.len());
        self.metadata.tracker_assert_in_bounds(index);
    }

    /// Ensures the vectors are large enough for `index` to be a valid slot.
    fn allow_index(&mut self, index: usize) {
        if index >= self.current_states.len() {
            self.current_states.resize(index + 1, BufferUses::empty());
            self.metadata.set_size(index + 1);
        }
    }

    /// Returns an iterator over the weak handles of all tracked buffers.
    pub fn used_resources(&self) -> impl Iterator<Item = &Weak<Buffer>> + '_ {
        self.metadata.owned_resources()
    }

    /// Inserts a single buffer with the given state, growing the tracker as
    /// needed. Only a downgraded (`Weak`) handle is stored.
    pub fn insert_single(&mut self, buffer: &Arc<Buffer>, state: BufferUses) {
        let index = buffer.tracker_index().as_usize();

        self.allow_index(index);

        self.tracker_assert_in_bounds(index);

        // SAFETY: `allow_index` guarantees `index` is in bounds.
        unsafe {
            insert(
                None,
                &mut self.current_states,
                &mut self.metadata,
                index,
                BufferStateProvider::Direct { state },
                None,
                ResourceMetadataProvider::Direct {
                    resource: &Arc::downgrade(buffer),
                },
            )
        }
    }

    /// Sets the state of an already-tracked buffer, returning the transition
    /// needed to reach `state` (or `None` if no barrier is required).
    ///
    /// Note: unlike `insert_single`, this does not grow the tracker — the
    /// buffer is expected to already be tracked (strict-asserted below).
    pub fn set_single(
        &mut self,
        buffer: &Arc<Buffer>,
        state: BufferUses,
    ) -> Option<PendingTransition<BufferUses>> {
        let index: usize = buffer.tracker_index().as_usize();

        self.tracker_assert_in_bounds(index);

        let start_state_provider = BufferStateProvider::Direct { state };

        unsafe {
            barrier(
                &mut self.current_states,
                index,
                start_state_provider.clone(),
                &mut self.temp,
            )
        };
        unsafe { update(&mut self.current_states, index, start_state_provider) };

        // A single update can emit at most one transition.
        strict_assert!(self.temp.len() <= 1);

        self.temp.pop()
    }

    /// For every buffer owned by `tracker`: emits a barrier from this
    /// tracker's current state to the other tracker's `start` state, updates
    /// the current state to the other tracker's `end` state, and yields the
    /// resulting hal barriers.
    pub fn set_from_tracker_and_drain_transitions<'a, 'b: 'a>(
        &'a mut self,
        tracker: &'a BufferTracker,
        snatch_guard: &'b SnatchGuard<'b>,
    ) -> impl Iterator<Item = BufferBarrier<'a, dyn hal::DynBuffer>> {
        for index in tracker.metadata.owned_indices() {
            self.tracker_assert_in_bounds(index);

            let start_state_provider = BufferStateProvider::Indirect {
                state: &tracker.start,
            };
            let end_state_provider = BufferStateProvider::Indirect {
                state: &tracker.end,
            };
            unsafe {
                barrier(
                    &mut self.current_states,
                    index,
                    start_state_provider,
                    &mut self.temp,
                )
            };
            unsafe { update(&mut self.current_states, index, end_state_provider) };
        }

        self.temp.drain(..).map(|pending| {
            // `pending.id` is the tracker index recorded in `barrier()`;
            // the resource is looked up in the *source* tracker's metadata.
            let buf = unsafe { tracker.metadata.get_resource_unchecked(pending.id as _) };
            pending.into_hal(buf, snatch_guard)
        })
    }
}
660
/// Source of a buffer's usage state when inserting, merging, or barriering.
#[derive(Debug, Clone)]
enum BufferStateProvider<'a> {
    /// A single state value supplied directly by the caller.
    Direct { state: BufferUses },
    /// The state is looked up by tracker index in another state vector.
    Indirect { state: &'a [BufferUses] },
}
impl BufferStateProvider<'_> {
    /// Returns the state for `index`.
    ///
    /// # Safety
    ///
    /// For the `Indirect` variant, `index` must be within bounds of the
    /// state slice — the read uses `get_unchecked`, guarded only by a
    /// strict assertion.
    #[inline(always)]
    unsafe fn get_state(&self, index: usize) -> BufferUses {
        match *self {
            BufferStateProvider::Direct { state } => state,
            BufferStateProvider::Indirect { state } => {
                strict_assert!(index < state.len());
                *unsafe { state.get_unchecked(index) }
            }
        }
    }
}
686
687#[inline(always)]
688unsafe fn insert<T: Clone>(
689 start_states: Option<&mut [BufferUses]>,
690 current_states: &mut [BufferUses],
691 resource_metadata: &mut ResourceMetadata<T>,
692 index: usize,
693 start_state_provider: BufferStateProvider<'_>,
694 end_state_provider: Option<BufferStateProvider<'_>>,
695 metadata_provider: ResourceMetadataProvider<'_, T>,
696) {
697 let new_start_state = unsafe { start_state_provider.get_state(index) };
698 let new_end_state =
699 end_state_provider.map_or(new_start_state, |p| unsafe { p.get_state(index) });
700
701 strict_assert_eq!(invalid_resource_state(new_start_state), false);
704 strict_assert_eq!(invalid_resource_state(new_end_state), false);
705
706 unsafe {
707 if let Some(&mut ref mut start_state) = start_states {
708 *start_state.get_unchecked_mut(index) = new_start_state;
709 }
710 *current_states.get_unchecked_mut(index) = new_end_state;
711
712 let resource = metadata_provider.get(index);
713 resource_metadata.insert(index, resource.clone());
714 }
715}
716
717#[inline(always)]
718unsafe fn merge(
719 current_states: &mut [BufferUses],
720 _index32: u32,
721 index: usize,
722 state_provider: BufferStateProvider<'_>,
723 metadata_provider: ResourceMetadataProvider<'_, Arc<Buffer>>,
724) -> Result<(), ResourceUsageCompatibilityError> {
725 let current_state = unsafe { current_states.get_unchecked_mut(index) };
726 let new_state = unsafe { state_provider.get_state(index) };
727
728 let merged_state = *current_state | new_state;
729
730 if invalid_resource_state(merged_state) {
731 return Err(ResourceUsageCompatibilityError::from_buffer(
732 unsafe { metadata_provider.get(index) },
733 *current_state,
734 new_state,
735 ));
736 }
737
738 *current_state = merged_state;
739
740 Ok(())
741}
742
743#[inline(always)]
744unsafe fn barrier(
745 current_states: &mut [BufferUses],
746 index: usize,
747 state_provider: BufferStateProvider<'_>,
748 barriers: &mut Vec<PendingTransition<BufferUses>>,
749) {
750 let current_state = unsafe { *current_states.get_unchecked(index) };
751 let new_state = unsafe { state_provider.get_state(index) };
752
753 if skip_barrier(current_state, new_state) {
754 return;
755 }
756
757 barriers.push(PendingTransition {
758 id: index as _,
759 selector: (),
760 usage: hal::StateTransition {
761 from: current_state,
762 to: new_state,
763 },
764 });
765}
766
767#[inline(always)]
768unsafe fn update(
769 current_states: &mut [BufferUses],
770 index: usize,
771 state_provider: BufferStateProvider<'_>,
772) {
773 let current_state = unsafe { current_states.get_unchecked_mut(index) };
774 let new_state = unsafe { state_provider.get_state(index) };
775
776 *current_state = new_state;
777}