1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2
3use thiserror::Error;
4
5use crate::{
6 binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
7 pipeline::LateSizedBufferGroup,
8 resource::{Labeled, ParentDevice, ResourceErrorIdent},
9};
10
11mod compat {
12 use alloc::{
13 string::{String, ToString as _},
14 sync::{Arc, Weak},
15 vec::Vec,
16 };
17 use core::num::NonZeroU32;
18
19 use thiserror::Error;
20 use wgt::{BindingType, ShaderStages};
21
22 use crate::{
23 binding_model::BindGroupLayout,
24 error::MultiError,
25 resource::{Labeled, ParentDevice, ResourceErrorIdent},
26 };
27
    /// Reason a single bind group slot failed [`Entry::check`].
    pub(crate) enum Error {
        /// A layout is assigned to the slot, but it is not equal to the layout
        /// the slot expects.
        Incompatible {
            /// Identifier of the layout the slot expects.
            expected_bgl: ResourceErrorIdent,
            /// Identifier of the layout currently assigned to the slot.
            assigned_bgl: ResourceErrorIdent,
            /// The individual differences found between the two layouts.
            inner: MultiError,
        },
        /// The slot expects a layout, but nothing is assigned to it.
        Missing,
    }
36
    /// State of one bind group slot: which layout has been assigned to it and
    /// which layout is expected for it.
    #[derive(Debug, Clone)]
    struct Entry {
        /// Layout stored by `BoundBindGroupLayouts::assign`, if any.
        assigned: Option<Arc<BindGroupLayout>>,
        /// Layout stored by `BoundBindGroupLayouts::update_expectations`, if any.
        expected: Option<Arc<BindGroupLayout>>,
    }
42
43 impl Entry {
44 const fn empty() -> Self {
45 Self {
46 assigned: None,
47 expected: None,
48 }
49 }
50
51 fn is_active(&self) -> bool {
52 self.assigned.is_some() && self.expected.is_some()
53 }
54
55 fn is_valid(&self) -> bool {
56 if let Some(expected_bgl) = self.expected.as_ref() {
57 if let Some(assigned_bgl) = self.assigned.as_ref() {
58 expected_bgl.is_equal(assigned_bgl)
59 } else {
60 false
61 }
62 } else {
63 false
64 }
65 }
66
67 fn check(&self) -> Result<(), Error> {
68 if let Some(expected_bgl) = self.expected.as_ref() {
69 if let Some(assigned_bgl) = self.assigned.as_ref() {
70 if expected_bgl.is_equal(assigned_bgl) {
71 Ok(())
72 } else {
73 #[derive(Clone, Debug, Error)]
74 #[error(
75 "Exclusive pipelines don't match: expected {expected}, got {assigned}"
76 )]
77 struct IncompatibleExclusivePipelines {
78 expected: String,
79 assigned: String,
80 }
81
82 use crate::binding_model::ExclusivePipeline;
83 match (
84 expected_bgl.exclusive_pipeline.get().unwrap(),
85 assigned_bgl.exclusive_pipeline.get().unwrap(),
86 ) {
87 (ExclusivePipeline::None, ExclusivePipeline::None) => {}
88 (
89 ExclusivePipeline::Render(e_pipeline),
90 ExclusivePipeline::Render(a_pipeline),
91 ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
92 (
93 ExclusivePipeline::Compute(e_pipeline),
94 ExclusivePipeline::Compute(a_pipeline),
95 ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
96 (expected, assigned) => {
97 return Err(Error::Incompatible {
98 expected_bgl: expected_bgl.error_ident(),
99 assigned_bgl: assigned_bgl.error_ident(),
100 inner: MultiError::new(core::iter::once(
101 IncompatibleExclusivePipelines {
102 expected: expected.to_string(),
103 assigned: assigned.to_string(),
104 },
105 ))
106 .unwrap(),
107 });
108 }
109 }
110
111 #[derive(Clone, Debug, Error)]
112 enum EntryError {
113 #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
114 Visibility {
115 binding: u32,
116 expected: ShaderStages,
117 assigned: ShaderStages,
118 },
119 #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
120 Type {
121 binding: u32,
122 expected: BindingType,
123 assigned: BindingType,
124 },
125 #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
126 Count {
127 binding: u32,
128 expected: Option<NonZeroU32>,
129 assigned: Option<NonZeroU32>,
130 },
131 #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
132 ExtraExpected { binding: u32 },
133 #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
134 ExtraAssigned { binding: u32 },
135 }
136
137 let mut errors = Vec::new();
138
139 for (&binding, expected_entry) in expected_bgl.entries.iter() {
140 if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
141 if assigned_entry.visibility != expected_entry.visibility {
142 errors.push(EntryError::Visibility {
143 binding,
144 expected: expected_entry.visibility,
145 assigned: assigned_entry.visibility,
146 });
147 }
148 if assigned_entry.ty != expected_entry.ty {
149 errors.push(EntryError::Type {
150 binding,
151 expected: expected_entry.ty,
152 assigned: assigned_entry.ty,
153 });
154 }
155 if assigned_entry.count != expected_entry.count {
156 errors.push(EntryError::Count {
157 binding,
158 expected: expected_entry.count,
159 assigned: assigned_entry.count,
160 });
161 }
162 } else {
163 errors.push(EntryError::ExtraExpected { binding });
164 }
165 }
166
167 for (&binding, _) in assigned_bgl.entries.iter() {
168 if !expected_bgl.entries.contains_key(binding) {
169 errors.push(EntryError::ExtraAssigned { binding });
170 }
171 }
172
173 Err(Error::Incompatible {
174 expected_bgl: expected_bgl.error_ident(),
175 assigned_bgl: assigned_bgl.error_ident(),
176 inner: MultiError::new(errors.drain(..)).unwrap(),
177 })
178 }
179 } else {
180 Err(Error::Missing)
181 }
182 } else {
183 Ok(())
184 }
185 }
186 }
187
    /// Tracks, for every bind group slot, the assigned and the expected
    /// layout, plus the lowest slot index that has changed since the last
    /// call to `take_rebind_start_index`.
    #[derive(Debug)]
    pub(super) struct BoundBindGroupLayouts {
        /// One slot per bind group index, `hal::MAX_BIND_GROUPS` in total.
        entries: [Entry; hal::MAX_BIND_GROUPS],
        /// Lowest slot index recorded through `update_rebind_start_index`.
        /// Starts at 0 and is set to `entries.len()` whenever it is taken.
        rebind_start: usize,
    }
193
194 impl BoundBindGroupLayouts {
195 pub fn new() -> Self {
196 Self {
197 entries: [const { Entry::empty() }; hal::MAX_BIND_GROUPS],
198 rebind_start: 0,
199 }
200 }
201
202 pub fn take_rebind_start_index(&mut self) -> usize {
204 let start = self.rebind_start;
205 self.rebind_start = self.entries.len();
206 start
207 }
208
209 pub fn update_rebind_start_index(&mut self, start_index: usize) {
210 self.rebind_start = self.rebind_start.min(start_index);
211 }
212
213 pub fn update_expectations(&mut self, expectations: &[Option<Arc<BindGroupLayout>>]) {
214 let mut rebind_start_index = None;
215
216 for (i, (e, new_expected_bgl)) in self
217 .entries
218 .iter_mut()
219 .zip(expectations.iter().chain(core::iter::repeat(&None)))
220 .enumerate()
221 {
222 let (must_set, must_rebind) = match (&mut e.expected, new_expected_bgl) {
223 (None, None) => (false, false),
224 (None, Some(_)) => (true, true),
225 (Some(_), None) => (true, false),
226 (Some(old_expected_bgl), Some(new_expected_bgl)) => {
227 let is_different = !old_expected_bgl.is_equal(new_expected_bgl);
228 (is_different, is_different)
229 }
230 };
231 if must_set {
232 e.expected = new_expected_bgl.clone();
233 }
234 if must_rebind && rebind_start_index.is_none() {
235 rebind_start_index = Some(i);
236 }
237 }
238
239 if let Some(rebind_start_index) = rebind_start_index {
240 self.update_rebind_start_index(rebind_start_index);
241 }
242 }
243
244 pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) {
245 self.entries[index].assigned = Some(value);
246 self.update_rebind_start_index(index);
247 }
248
249 pub fn clear(&mut self, index: usize) {
250 self.entries[index].assigned = None;
251 }
252
253 pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
254 self.entries
255 .iter()
256 .enumerate()
257 .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
258 }
259
260 pub fn list_valid(&self) -> impl Iterator<Item = usize> + '_ {
261 self.entries
262 .iter()
263 .enumerate()
264 .filter_map(|(i, e)| if e.is_valid() { Some(i) } else { None })
265 }
266
267 #[allow(clippy::result_large_err)]
268 pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
269 for (index, entry) in self.entries.iter().enumerate() {
270 entry.check().map_err(|e| (index, e))?;
271 }
272 Ok(())
273 }
274 }
275}
276
/// Error reported by [`Binder::check_compatibility`] for the first bind group
/// slot that does not satisfy the pipeline.
#[derive(Clone, Debug, Error)]
pub enum BinderError {
    /// A bind group is expected at `index`, but none is set there.
    #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
    MissingBindGroup {
        /// Index of the bind group slot.
        index: usize,
        /// Identifier of the pipeline that was checked.
        pipeline: ResourceErrorIdent,
    },
    /// The layout of the bind group set at `index` is not equal to the layout
    /// expected for that slot.
    #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
    IncompatibleBindGroup {
        /// Identifier of the layout expected for the slot.
        expected_bgl: ResourceErrorIdent,
        /// Identifier of the layout of the bind group that is set.
        assigned_bgl: ResourceErrorIdent,
        /// Identifier of the bind group that is set.
        assigned_bg: ResourceErrorIdent,
        /// Index of the bind group slot.
        index: usize,
        /// Identifier of the pipeline that was checked.
        pipeline: ResourceErrorIdent,
        /// The individual differences between the two layouts.
        #[source]
        inner: crate::error::MultiError,
    },
}
295
/// One buffer binding whose size is validated by
/// `Binder::check_late_buffer_bindings` rather than up front.
#[derive(Debug)]
struct LateBufferBinding {
    /// Binding index, copied from the bind group's `late_buffer_binding_infos`
    /// (0 until a bind group supplies it).
    binding_index: u32,
    /// Size copied from the pipeline's `LateSizedBufferGroup::shader_sizes`
    /// (0 until a pipeline supplies it).
    shader_expect_size: wgt::BufferAddress,
    /// Size copied from the bind group's `late_buffer_binding_infos`
    /// (0 until a bind group supplies it).
    bound_size: wgt::BufferAddress,
}
302
/// Per-slot data kept by [`Binder`] alongside the layout tracking.
#[derive(Debug, Default)]
struct EntryPayload {
    /// Bind group currently set at this slot, if any.
    group: Option<Arc<BindGroup>>,
    /// Dynamic offsets supplied together with `group`.
    dynamic_offsets: Vec<wgt::DynamicOffset>,
    /// Late-validated buffer bindings; entries are filled in partly from the
    /// bind group and partly from the pipeline, whichever arrives first.
    late_buffer_bindings: Vec<LateBufferBinding>,
    /// Number of leading entries of `late_buffer_bindings` that
    /// `Binder::check_late_buffer_bindings` validates. Set to the length of
    /// the pipeline's `shader_sizes` for this slot.
    late_bindings_effective_count: usize,
}
312
313impl EntryPayload {
314 fn reset(&mut self) {
315 self.group = None;
316 self.dynamic_offsets.clear();
317 self.late_buffer_bindings.clear();
318 self.late_bindings_effective_count = 0;
319 }
320}
321
/// Tracks the current pipeline layout and the bind groups set on each slot,
/// and validates the latter against the former.
#[derive(Debug)]
pub(super) struct Binder {
    /// Pipeline layout most recently passed to `change_pipeline_layout`, if any.
    pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
    /// Expected/assigned bind group layout per slot.
    manager: compat::BoundBindGroupLayouts,
    /// Bind group, dynamic offsets and late buffer binding sizes per slot.
    payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
}
328
329impl Binder {
330 pub(super) fn new() -> Self {
331 Self {
332 pipeline_layout: None,
333 manager: compat::BoundBindGroupLayouts::new(),
334 payloads: Default::default(),
335 }
336 }
337 pub(super) fn reset(&mut self) {
338 self.pipeline_layout = None;
339 self.manager = compat::BoundBindGroupLayouts::new();
340 for payload in self.payloads.iter_mut() {
341 payload.reset();
342 }
343 }
344
345 pub(super) fn change_pipeline_layout<'a>(
348 &'a mut self,
349 new: &Arc<PipelineLayout>,
350 late_sized_buffer_groups: &[LateSizedBufferGroup],
351 ) -> bool {
352 if let Some(old) = self.pipeline_layout.as_ref() {
353 if old.is_equal(new) {
354 return false;
355 }
356 }
357
358 let old = self.pipeline_layout.replace(new.clone());
359
360 self.manager.update_expectations(&new.bind_group_layouts);
361
362 for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
365 payload.late_bindings_effective_count = late_group.shader_sizes.len();
366 for (late_binding, &shader_expect_size) in payload
369 .late_buffer_bindings
370 .iter_mut()
371 .zip(late_group.shader_sizes.iter())
372 {
373 late_binding.shader_expect_size = shader_expect_size;
374 }
375 if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
377 for &shader_expect_size in
378 late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
379 {
380 payload.late_buffer_bindings.push(LateBufferBinding {
381 binding_index: 0,
382 shader_expect_size,
383 bound_size: 0,
384 });
385 }
386 }
387 }
388
389 if let Some(old) = old {
390 if old.immediate_size != new.immediate_size {
392 self.manager.update_rebind_start_index(0);
393 }
394 }
395
396 true
397 }
398
399 pub(super) fn assign_group<'a>(
400 &'a mut self,
401 index: usize,
402 bind_group: &Arc<BindGroup>,
403 offsets: &[wgt::DynamicOffset],
404 ) {
405 let payload = &mut self.payloads[index];
406 payload.group = Some(bind_group.clone());
407 payload.dynamic_offsets.clear();
408 payload.dynamic_offsets.extend_from_slice(offsets);
409
410 for (late_binding, late_info) in payload
416 .late_buffer_bindings
417 .iter_mut()
418 .zip(bind_group.late_buffer_binding_infos.iter())
419 {
420 late_binding.binding_index = late_info.binding_index;
421 late_binding.bound_size = late_info.size.get();
422 }
423
424 if bind_group.late_buffer_binding_infos.len() > payload.late_buffer_bindings.len() {
426 for late_info in
427 bind_group.late_buffer_binding_infos[payload.late_buffer_bindings.len()..].iter()
428 {
429 payload.late_buffer_bindings.push(LateBufferBinding {
430 binding_index: late_info.binding_index,
431 shader_expect_size: 0,
432 bound_size: late_info.size.get(),
433 });
434 }
435 }
436
437 self.manager.assign(index, bind_group.layout.clone());
438 }
439
440 pub(super) fn clear_group(&mut self, index: usize) {
441 self.payloads[index].reset();
442 self.manager.clear(index);
443 }
444
445 pub(super) fn take_rebind_start_index(&mut self) -> usize {
447 self.manager.take_rebind_start_index()
448 }
449
450 pub(super) fn list_valid_with_start(
451 &self,
452 start: usize,
453 ) -> impl Iterator<Item = (usize, &Arc<BindGroup>, &[wgt::DynamicOffset])> + '_ {
454 let payloads = &self.payloads;
455 self.manager
456 .list_valid()
457 .filter(move |i| *i >= start)
458 .map(move |index| {
459 (
460 index,
461 payloads[index].group.as_ref().unwrap(),
462 payloads[index].dynamic_offsets.as_slice(),
463 )
464 })
465 }
466
467 pub(super) fn list_active(&self) -> impl Iterator<Item = &Arc<BindGroup>> + '_ {
468 let payloads = &self.payloads;
469 self.manager
470 .list_active()
471 .map(move |index| payloads[index].group.as_ref().unwrap())
472 }
473
474 pub(super) fn list_valid(
475 &self,
476 ) -> impl Iterator<Item = (usize, &Arc<BindGroup>, &[wgt::DynamicOffset])> + '_ {
477 self.list_valid_with_start(0)
478 }
479
480 pub(super) fn check_compatibility<T: Labeled>(
481 &self,
482 pipeline: &T,
483 ) -> Result<(), Box<BinderError>> {
484 self.manager.get_invalid().map_err(|(index, error)| {
485 Box::new(match error {
486 compat::Error::Incompatible {
487 expected_bgl,
488 assigned_bgl,
489 inner,
490 } => BinderError::IncompatibleBindGroup {
491 expected_bgl,
492 assigned_bgl,
493 assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
494 index,
495 pipeline: pipeline.error_ident(),
496 inner,
497 },
498 compat::Error::Missing => BinderError::MissingBindGroup {
499 index,
500 pipeline: pipeline.error_ident(),
501 },
502 })
503 })
504 }
505
506 pub(super) fn check_late_buffer_bindings(
508 &self,
509 ) -> Result<(), LateMinBufferBindingSizeMismatch> {
510 for group_index in self.manager.list_active() {
511 let payload = &self.payloads[group_index];
512 for late_binding in
513 &payload.late_buffer_bindings[..payload.late_bindings_effective_count]
514 {
515 if late_binding.bound_size < late_binding.shader_expect_size {
516 return Err(LateMinBufferBindingSizeMismatch {
517 group_index: group_index as u32,
518 binding_index: late_binding.binding_index,
519 shader_size: late_binding.shader_expect_size,
520 bound_size: late_binding.bound_size,
521 });
522 }
523 }
524 }
525 Ok(())
526 }
527}