1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2
3use thiserror::Error;
4
5use crate::{
6 binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
7 pipeline::LateSizedBufferGroup,
8 resource::{Labeled, ParentDevice, ResourceErrorIdent},
9};
10
11mod compat {
12 use alloc::{
13 string::{String, ToString as _},
14 sync::{Arc, Weak},
15 vec::Vec,
16 };
17 use core::num::NonZeroU32;
18
19 use thiserror::Error;
20 use wgt::{BindingType, ShaderStages};
21
22 use crate::{
23 binding_model::BindGroupLayout,
24 error::MultiError,
25 resource::{Labeled, ParentDevice, ResourceErrorIdent},
26 };
27
28 pub(crate) enum Error {
29 Incompatible {
30 expected_bgl: ResourceErrorIdent,
31 assigned_bgl: ResourceErrorIdent,
32 inner: MultiError,
33 },
34 Missing,
35 }
36
37 #[derive(Debug, Clone)]
38 struct Entry {
39 assigned: Option<Arc<BindGroupLayout>>,
40 expected: Option<Arc<BindGroupLayout>>,
41 }
42
43 impl Entry {
44 const fn empty() -> Self {
45 Self {
46 assigned: None,
47 expected: None,
48 }
49 }
50
51 fn is_assigned(&self) -> bool {
52 self.assigned.is_some()
53 }
54
55 fn is_active(&self) -> bool {
56 self.assigned.is_some() && self.expected.is_some()
57 }
58
59 fn is_valid(&self) -> bool {
60 if let Some(expected_bgl) = self.expected.as_ref() {
61 if let Some(assigned_bgl) = self.assigned.as_ref() {
62 expected_bgl.is_equal(assigned_bgl)
63 } else {
64 false
65 }
66 } else {
67 false
68 }
69 }
70
71 fn check(&self) -> Result<(), Error> {
72 if let Some(expected_bgl) = self.expected.as_ref() {
73 if let Some(assigned_bgl) = self.assigned.as_ref() {
74 if expected_bgl.is_equal(assigned_bgl) {
75 Ok(())
76 } else {
77 #[derive(Clone, Debug, Error)]
78 #[error(
79 "Exclusive pipelines don't match: expected {expected}, got {assigned}"
80 )]
81 struct IncompatibleExclusivePipelines {
82 expected: String,
83 assigned: String,
84 }
85
86 use crate::binding_model::ExclusivePipeline;
87 match (
88 expected_bgl.exclusive_pipeline.get().unwrap(),
89 assigned_bgl.exclusive_pipeline.get().unwrap(),
90 ) {
91 (ExclusivePipeline::None, ExclusivePipeline::None) => {}
92 (
93 ExclusivePipeline::Render(e_pipeline),
94 ExclusivePipeline::Render(a_pipeline),
95 ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
96 (
97 ExclusivePipeline::Compute(e_pipeline),
98 ExclusivePipeline::Compute(a_pipeline),
99 ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
100 (expected, assigned) => {
101 return Err(Error::Incompatible {
102 expected_bgl: expected_bgl.error_ident(),
103 assigned_bgl: assigned_bgl.error_ident(),
104 inner: MultiError::new(core::iter::once(
105 IncompatibleExclusivePipelines {
106 expected: expected.to_string(),
107 assigned: assigned.to_string(),
108 },
109 ))
110 .unwrap(),
111 });
112 }
113 }
114
115 #[derive(Clone, Debug, Error)]
116 enum EntryError {
117 #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
118 Visibility {
119 binding: u32,
120 expected: ShaderStages,
121 assigned: ShaderStages,
122 },
123 #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
124 Type {
125 binding: u32,
126 expected: BindingType,
127 assigned: BindingType,
128 },
129 #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
130 Count {
131 binding: u32,
132 expected: Option<NonZeroU32>,
133 assigned: Option<NonZeroU32>,
134 },
135 #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
136 ExtraExpected { binding: u32 },
137 #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
138 ExtraAssigned { binding: u32 },
139 }
140
141 let mut errors = Vec::new();
142
143 for (&binding, expected_entry) in expected_bgl.entries.iter() {
144 if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
145 if assigned_entry.visibility != expected_entry.visibility {
146 errors.push(EntryError::Visibility {
147 binding,
148 expected: expected_entry.visibility,
149 assigned: assigned_entry.visibility,
150 });
151 }
152 if assigned_entry.ty != expected_entry.ty {
153 errors.push(EntryError::Type {
154 binding,
155 expected: expected_entry.ty,
156 assigned: assigned_entry.ty,
157 });
158 }
159 if assigned_entry.count != expected_entry.count {
160 errors.push(EntryError::Count {
161 binding,
162 expected: expected_entry.count,
163 assigned: assigned_entry.count,
164 });
165 }
166 } else {
167 errors.push(EntryError::ExtraExpected { binding });
168 }
169 }
170
171 for (&binding, _) in assigned_bgl.entries.iter() {
172 if !expected_bgl.entries.contains_key(binding) {
173 errors.push(EntryError::ExtraAssigned { binding });
174 }
175 }
176
177 Err(Error::Incompatible {
178 expected_bgl: expected_bgl.error_ident(),
179 assigned_bgl: assigned_bgl.error_ident(),
180 inner: MultiError::new(errors.drain(..)).unwrap(),
181 })
182 }
183 } else {
184 Err(Error::Missing)
185 }
186 } else {
187 Ok(())
188 }
189 }
190 }
191
192 #[derive(Debug)]
193 pub(super) struct BoundBindGroupLayouts {
194 entries: [Entry; hal::MAX_BIND_GROUPS],
195 rebind_start: usize,
196 }
197
198 impl BoundBindGroupLayouts {
199 pub fn new() -> Self {
200 Self {
201 entries: [const { Entry::empty() }; hal::MAX_BIND_GROUPS],
202 rebind_start: 0,
203 }
204 }
205
206 pub fn take_rebind_start_index(&mut self) -> usize {
208 let start = self.rebind_start;
209 self.rebind_start = self.entries.len();
210 start
211 }
212
213 pub fn update_rebind_start_index(&mut self, start_index: usize) {
214 self.rebind_start = self.rebind_start.min(start_index);
215 }
216
217 pub fn update_expectations(&mut self, expectations: &[Option<Arc<BindGroupLayout>>]) {
218 let mut rebind_start_index = None;
219
220 for (i, (e, new_expected_bgl)) in self
221 .entries
222 .iter_mut()
223 .zip(expectations.iter().chain(core::iter::repeat(&None)))
224 .enumerate()
225 {
226 let (must_set, must_rebind) = match (&mut e.expected, new_expected_bgl) {
227 (None, None) => (false, false),
228 (None, Some(_)) => (true, true),
229 (Some(_), None) => (true, false),
230 (Some(old_expected_bgl), Some(new_expected_bgl)) => {
231 let is_different = !old_expected_bgl.is_equal(new_expected_bgl);
232 (is_different, is_different)
233 }
234 };
235 if must_set {
236 e.expected = new_expected_bgl.clone();
237 }
238 if must_rebind && rebind_start_index.is_none() {
239 rebind_start_index = Some(i);
240 }
241 }
242
243 if let Some(rebind_start_index) = rebind_start_index {
244 self.update_rebind_start_index(rebind_start_index);
245 }
246 }
247
248 pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) {
249 self.entries[index].assigned = Some(value);
250 self.update_rebind_start_index(index);
251 }
252
253 pub fn clear(&mut self, index: usize) {
254 self.entries[index].assigned = None;
255 }
256
257 pub fn list_assigned(&self) -> impl Iterator<Item = usize> + '_ {
258 self.entries
259 .iter()
260 .enumerate()
261 .filter_map(|(i, e)| if e.is_assigned() { Some(i) } else { None })
262 }
263
264 pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
265 self.entries
266 .iter()
267 .enumerate()
268 .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
269 }
270
271 pub fn list_valid(&self) -> impl Iterator<Item = usize> + '_ {
272 self.entries
273 .iter()
274 .enumerate()
275 .filter_map(|(i, e)| if e.is_valid() { Some(i) } else { None })
276 }
277
278 #[allow(clippy::result_large_err)]
279 pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
280 for (index, entry) in self.entries.iter().enumerate() {
281 entry.check().map_err(|e| (index, e))?;
282 }
283 Ok(())
284 }
285 }
286}
287
288#[derive(Clone, Debug, Error)]
289pub enum BinderError {
290 #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
291 MissingBindGroup {
292 index: usize,
293 pipeline: ResourceErrorIdent,
294 },
295 #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
296 IncompatibleBindGroup {
297 expected_bgl: ResourceErrorIdent,
298 assigned_bgl: ResourceErrorIdent,
299 assigned_bg: ResourceErrorIdent,
300 index: usize,
301 pipeline: ResourceErrorIdent,
302 #[source]
303 inner: crate::error::MultiError,
304 },
305}
306
307#[derive(Debug)]
308struct LateBufferBinding {
309 binding_index: u32,
310 shader_expect_size: wgt::BufferAddress,
311 bound_size: wgt::BufferAddress,
312}
313
314#[derive(Debug, Default)]
315struct EntryPayload {
316 group: Option<Arc<BindGroup>>,
317 dynamic_offsets: Vec<wgt::DynamicOffset>,
318 late_buffer_bindings: Vec<LateBufferBinding>,
319 late_bindings_effective_count: usize,
322}
323
324impl EntryPayload {
325 fn reset(&mut self) {
326 self.group = None;
327 self.dynamic_offsets.clear();
328 self.late_buffer_bindings.clear();
329 self.late_bindings_effective_count = 0;
330 }
331}
332
333#[derive(Debug)]
334pub(super) struct Binder {
335 pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
336 manager: compat::BoundBindGroupLayouts,
337 payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
338}
339
340impl Binder {
341 pub(super) fn new() -> Self {
342 Self {
343 pipeline_layout: None,
344 manager: compat::BoundBindGroupLayouts::new(),
345 payloads: Default::default(),
346 }
347 }
348 pub(super) fn reset(&mut self) {
349 self.pipeline_layout = None;
350 self.manager = compat::BoundBindGroupLayouts::new();
351 for payload in self.payloads.iter_mut() {
352 payload.reset();
353 }
354 }
355
356 pub(super) fn change_pipeline_layout<'a>(
359 &'a mut self,
360 new: &Arc<PipelineLayout>,
361 late_sized_buffer_groups: &[LateSizedBufferGroup],
362 ) -> bool {
363 self.update_late_buffer_bindings(late_sized_buffer_groups);
364
365 if let Some(old) = self.pipeline_layout.as_ref() {
366 if old.is_equal(new) {
367 return false;
368 }
369 }
370
371 let old = self.pipeline_layout.replace(new.clone());
372
373 self.manager.update_expectations(&new.bind_group_layouts);
374
375 if let Some(old) = old {
376 if old.immediate_size != new.immediate_size {
378 self.manager.update_rebind_start_index(0);
379 }
380 }
381
382 true
383 }
384
385 pub(super) fn assign_group<'a>(
386 &'a mut self,
387 index: usize,
388 bind_group: &Arc<BindGroup>,
389 offsets: &[wgt::DynamicOffset],
390 ) {
391 let payload = &mut self.payloads[index];
392 payload.group = Some(bind_group.clone());
393 payload.dynamic_offsets.clear();
394 payload.dynamic_offsets.extend_from_slice(offsets);
395
396 for (late_binding, late_info) in payload
402 .late_buffer_bindings
403 .iter_mut()
404 .zip(bind_group.late_buffer_binding_infos.iter())
405 {
406 late_binding.binding_index = late_info.binding_index;
407 late_binding.bound_size = late_info.size.get();
408 }
409
410 if bind_group.late_buffer_binding_infos.len() > payload.late_buffer_bindings.len() {
412 for late_info in
413 bind_group.late_buffer_binding_infos[payload.late_buffer_bindings.len()..].iter()
414 {
415 payload.late_buffer_bindings.push(LateBufferBinding {
416 binding_index: late_info.binding_index,
417 shader_expect_size: 0,
418 bound_size: late_info.size.get(),
419 });
420 }
421 }
422
423 self.manager.assign(index, bind_group.layout.clone());
424 }
425
426 pub(super) fn clear_group(&mut self, index: usize) {
427 self.payloads[index].reset();
428 self.manager.clear(index);
429 }
430
431 pub(super) fn take_rebind_start_index(&mut self) -> usize {
433 self.manager.take_rebind_start_index()
434 }
435
436 pub(super) fn list_valid_with_start(
437 &self,
438 start: usize,
439 ) -> impl Iterator<Item = (usize, &Arc<BindGroup>, &[wgt::DynamicOffset])> + '_ {
440 let payloads = &self.payloads;
441 self.manager
442 .list_valid()
443 .filter(move |i| *i >= start)
444 .map(move |index| {
445 (
446 index,
447 payloads[index].group.as_ref().unwrap(),
448 payloads[index].dynamic_offsets.as_slice(),
449 )
450 })
451 }
452
453 pub(super) fn last_assigned_index(&self) -> Option<usize> {
454 self.manager.list_assigned().last()
455 }
456
457 pub(super) fn list_active(&self) -> impl Iterator<Item = &Arc<BindGroup>> + '_ {
458 let payloads = &self.payloads;
459 self.manager
460 .list_active()
461 .map(move |index| payloads[index].group.as_ref().unwrap())
462 }
463
464 pub(super) fn list_valid(
465 &self,
466 ) -> impl Iterator<Item = (usize, &Arc<BindGroup>, &[wgt::DynamicOffset])> + '_ {
467 self.list_valid_with_start(0)
468 }
469
470 pub(super) fn check_compatibility<T: Labeled>(
471 &self,
472 pipeline: &T,
473 ) -> Result<(), Box<BinderError>> {
474 self.manager.get_invalid().map_err(|(index, error)| {
475 Box::new(match error {
476 compat::Error::Incompatible {
477 expected_bgl,
478 assigned_bgl,
479 inner,
480 } => BinderError::IncompatibleBindGroup {
481 expected_bgl,
482 assigned_bgl,
483 assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
484 index,
485 pipeline: pipeline.error_ident(),
486 inner,
487 },
488 compat::Error::Missing => BinderError::MissingBindGroup {
489 index,
490 pipeline: pipeline.error_ident(),
491 },
492 })
493 })
494 }
495
496 pub(super) fn check_late_buffer_bindings(
498 &self,
499 ) -> Result<(), LateMinBufferBindingSizeMismatch> {
500 for group_index in self.manager.list_active() {
501 let payload = &self.payloads[group_index];
502 for late_binding in
503 &payload.late_buffer_bindings[..payload.late_bindings_effective_count]
504 {
505 if late_binding.bound_size < late_binding.shader_expect_size {
506 return Err(LateMinBufferBindingSizeMismatch {
507 group_index: group_index as u32,
508 binding_index: late_binding.binding_index,
509 shader_size: late_binding.shader_expect_size,
510 bound_size: late_binding.bound_size,
511 });
512 }
513 }
514 }
515 Ok(())
516 }
517
518 fn update_late_buffer_bindings(&mut self, late_sized_buffer_groups: &[LateSizedBufferGroup]) {
522 for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
523 payload.late_bindings_effective_count = late_group.shader_sizes.len();
524
525 for (late_binding, &shader_expect_size) in payload
528 .late_buffer_bindings
529 .iter_mut()
530 .zip(late_group.shader_sizes.iter())
531 {
532 late_binding.shader_expect_size = shader_expect_size;
533 }
534
535 if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
537 for &shader_expect_size in
538 late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
539 {
540 payload.late_buffer_bindings.push(LateBufferBinding {
541 binding_index: 0,
542 shader_expect_size,
543 bound_size: 0,
544 });
545 }
546 }
547 }
548 }
549}