1use core::{iter::zip, ops::Range};
2
3use alloc::{boxed::Box, sync::Arc, vec::Vec};
4
5use thiserror::Error;
6
7use crate::{
8 binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
9 pipeline::LateSizedBufferGroup,
10 resource::{Labeled, ParentDevice, ResourceErrorIdent},
11};
12
13mod compat {
14 use alloc::{
15 string::{String, ToString as _},
16 sync::{Arc, Weak},
17 vec::Vec,
18 };
19 use core::{num::NonZeroU32, ops::Range};
20
21 use arrayvec::ArrayVec;
22 use thiserror::Error;
23 use wgt::{BindingType, ShaderStages};
24
25 use crate::{
26 binding_model::BindGroupLayout,
27 error::MultiError,
28 resource::{Labeled, ParentDevice, ResourceErrorIdent},
29 };
30
    /// Reason a single bind group slot failed the compatibility check
    /// performed by `Entry::check`.
    pub(crate) enum Error {
        /// A layout is assigned to the slot, but it is not equal to the
        /// layout that is expected there.
        Incompatible {
            /// Identifier of the expected bind group layout.
            expected_bgl: ResourceErrorIdent,
            /// Identifier of the bind group layout that is actually assigned.
            assigned_bgl: ResourceErrorIdent,
            /// The individual differences found between the two layouts.
            inner: MultiError,
        },
        /// A layout is expected at the slot, but nothing has been assigned.
        Missing,
    }
39
    /// Compatibility state of a single bind group slot.
    #[derive(Debug, Clone)]
    struct Entry {
        /// Layout most recently assigned to this slot, if any.
        assigned: Option<Arc<BindGroupLayout>>,
        /// Layout currently expected at this slot, if any.
        expected: Option<Arc<BindGroupLayout>>,
    }
45
46 impl Entry {
47 fn empty() -> Self {
48 Self {
49 assigned: None,
50 expected: None,
51 }
52 }
53 fn is_active(&self) -> bool {
54 self.assigned.is_some() && self.expected.is_some()
55 }
56
57 fn is_valid(&self) -> bool {
58 if let Some(expected_bgl) = self.expected.as_ref() {
59 if let Some(assigned_bgl) = self.assigned.as_ref() {
60 expected_bgl.is_equal(assigned_bgl)
61 } else {
62 false
63 }
64 } else {
65 true
66 }
67 }
68
69 fn is_incompatible(&self) -> bool {
70 self.expected.is_none() || !self.is_valid()
71 }
72
73 fn check(&self) -> Result<(), Error> {
74 if let Some(expected_bgl) = self.expected.as_ref() {
75 if let Some(assigned_bgl) = self.assigned.as_ref() {
76 if expected_bgl.is_equal(assigned_bgl) {
77 Ok(())
78 } else {
79 #[derive(Clone, Debug, Error)]
80 #[error(
81 "Exclusive pipelines don't match: expected {expected}, got {assigned}"
82 )]
83 struct IncompatibleExclusivePipelines {
84 expected: String,
85 assigned: String,
86 }
87
88 use crate::binding_model::ExclusivePipeline;
89 match (
90 expected_bgl.exclusive_pipeline.get().unwrap(),
91 assigned_bgl.exclusive_pipeline.get().unwrap(),
92 ) {
93 (ExclusivePipeline::None, ExclusivePipeline::None) => {}
94 (
95 ExclusivePipeline::Render(e_pipeline),
96 ExclusivePipeline::Render(a_pipeline),
97 ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
98 (
99 ExclusivePipeline::Compute(e_pipeline),
100 ExclusivePipeline::Compute(a_pipeline),
101 ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
102 (expected, assigned) => {
103 return Err(Error::Incompatible {
104 expected_bgl: expected_bgl.error_ident(),
105 assigned_bgl: assigned_bgl.error_ident(),
106 inner: MultiError::new(core::iter::once(
107 IncompatibleExclusivePipelines {
108 expected: expected.to_string(),
109 assigned: assigned.to_string(),
110 },
111 ))
112 .unwrap(),
113 });
114 }
115 }
116
117 #[derive(Clone, Debug, Error)]
118 enum EntryError {
119 #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
120 Visibility {
121 binding: u32,
122 expected: ShaderStages,
123 assigned: ShaderStages,
124 },
125 #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
126 Type {
127 binding: u32,
128 expected: BindingType,
129 assigned: BindingType,
130 },
131 #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
132 Count {
133 binding: u32,
134 expected: Option<NonZeroU32>,
135 assigned: Option<NonZeroU32>,
136 },
137 #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
138 ExtraExpected { binding: u32 },
139 #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
140 ExtraAssigned { binding: u32 },
141 }
142
143 let mut errors = Vec::new();
144
145 for (&binding, expected_entry) in expected_bgl.entries.iter() {
146 if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
147 if assigned_entry.visibility != expected_entry.visibility {
148 errors.push(EntryError::Visibility {
149 binding,
150 expected: expected_entry.visibility,
151 assigned: assigned_entry.visibility,
152 });
153 }
154 if assigned_entry.ty != expected_entry.ty {
155 errors.push(EntryError::Type {
156 binding,
157 expected: expected_entry.ty,
158 assigned: assigned_entry.ty,
159 });
160 }
161 if assigned_entry.count != expected_entry.count {
162 errors.push(EntryError::Count {
163 binding,
164 expected: expected_entry.count,
165 assigned: assigned_entry.count,
166 });
167 }
168 } else {
169 errors.push(EntryError::ExtraExpected { binding });
170 }
171 }
172
173 for (&binding, _) in assigned_bgl.entries.iter() {
174 if !expected_bgl.entries.contains_key(binding) {
175 errors.push(EntryError::ExtraAssigned { binding });
176 }
177 }
178
179 Err(Error::Incompatible {
180 expected_bgl: expected_bgl.error_ident(),
181 assigned_bgl: assigned_bgl.error_ident(),
182 inner: MultiError::new(errors.drain(..)).unwrap(),
183 })
184 }
185 } else {
186 Err(Error::Missing)
187 }
188 } else {
189 Ok(())
190 }
191 }
192 }
193
194 #[derive(Debug, Default)]
195 pub(super) struct BoundBindGroupLayouts {
196 entries: ArrayVec<Entry, { hal::MAX_BIND_GROUPS }>,
197 rebind_start: usize,
198 }
199
200 impl BoundBindGroupLayouts {
201 pub fn new() -> Self {
202 Self {
203 entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
204 rebind_start: 0,
205 }
206 }
207
208 pub fn num_valid_entries(&self) -> usize {
209 self.entries
211 .iter()
212 .position(|e| e.is_incompatible())
213 .unwrap_or(self.entries.len())
214 }
215
216 pub fn take_rebind_range(&mut self) -> Range<usize> {
218 let end = self.num_valid_entries();
219 let start = self.rebind_start;
220 self.rebind_start = end;
221 start..end.max(start)
222 }
223
224 pub fn update_start_index(&mut self, start_index: usize) {
225 self.rebind_start = self.rebind_start.min(start_index);
226 }
227
228 pub fn update_expectations(&mut self, expectations: &[Arc<BindGroupLayout>]) {
229 let start_index = self
230 .entries
231 .iter()
232 .zip(expectations)
233 .position(|(e, expect)| {
234 e.expected.is_none() || !e.expected.as_ref().unwrap().is_equal(expect)
235 })
236 .unwrap_or(expectations.len());
237 for (e, expect) in self.entries[start_index..]
238 .iter_mut()
239 .zip(expectations[start_index..].iter())
240 {
241 e.expected = Some(expect.clone());
242 }
243 for e in self.entries[expectations.len()..].iter_mut() {
244 e.expected = None;
245 }
246 self.update_start_index(start_index);
247 }
248
249 pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) {
250 self.entries[index].assigned = Some(value);
251 self.update_start_index(index);
252 }
253
254 pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
255 self.entries
256 .iter()
257 .enumerate()
258 .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
259 }
260
261 #[allow(clippy::result_large_err)]
262 pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
263 for (index, entry) in self.entries.iter().enumerate() {
264 entry.check().map_err(|e| (index, e))?;
265 }
266 Ok(())
267 }
268 }
269}
270
/// Error produced by `Binder::check_compatibility` when the assigned bind
/// groups do not satisfy the current expectations.
#[derive(Clone, Debug, Error)]
pub enum BinderError {
    /// A bind group is expected at `index`, but none has been assigned.
    #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
    MissingBindGroup {
        /// Slot at which the bind group is missing.
        index: usize,
        /// Identifier of the pipeline the check was performed for.
        pipeline: ResourceErrorIdent,
    },
    /// The bind group assigned at `index` uses a layout that is not equal to
    /// the expected one.
    #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
    IncompatibleBindGroup {
        /// Identifier of the expected bind group layout.
        expected_bgl: ResourceErrorIdent,
        /// Identifier of the layout of the assigned bind group.
        assigned_bgl: ResourceErrorIdent,
        /// Identifier of the assigned bind group itself.
        assigned_bg: ResourceErrorIdent,
        /// Slot at which the mismatch was found.
        index: usize,
        /// Identifier of the pipeline the check was performed for.
        pipeline: ResourceErrorIdent,
        /// The individual differences between the two layouts.
        #[source]
        inner: crate::error::MultiError,
    },
}
289
/// Size bookkeeping for one buffer binding whose size is validated late, in
/// `Binder::check_late_buffer_bindings`.
#[derive(Debug)]
struct LateBufferBinding {
    /// Binding index copied from the bind group's late buffer binding info
    /// (0 until a bind group provides one).
    binding_index: u32,
    /// Size taken from the pipeline's late-sized buffer group
    /// (0 until a pipeline layout provides one).
    shader_expect_size: wgt::BufferAddress,
    /// Size taken from the assigned bind group
    /// (0 until a bind group provides one).
    bound_size: wgt::BufferAddress,
}
296
/// Per-slot data stored by `Binder` alongside the layout tracking.
#[derive(Debug, Default)]
pub(super) struct EntryPayload {
    /// Bind group most recently assigned to this slot, if any.
    pub(super) group: Option<Arc<BindGroup>>,
    /// Dynamic offsets supplied with the assigned bind group.
    pub(super) dynamic_offsets: Vec<wgt::DynamicOffset>,
    /// Late-validated buffer bindings; may hold more entries than the current
    /// pipeline uses, see `late_bindings_effective_count`.
    late_buffer_bindings: Vec<LateBufferBinding>,
    /// Number of leading `late_buffer_bindings` entries that
    /// `Binder::check_late_buffer_bindings` validates; set from the
    /// pipeline's shader sizes when the pipeline layout changes.
    pub(super) late_bindings_effective_count: usize,
}
306
307impl EntryPayload {
308 fn reset(&mut self) {
309 self.group = None;
310 self.dynamic_offsets.clear();
311 self.late_buffer_bindings.clear();
312 self.late_bindings_effective_count = 0;
313 }
314}
315
/// Tracks the current pipeline layout together with the bind groups assigned
/// to each slot, and validates them against each other.
#[derive(Debug, Default)]
pub(super) struct Binder {
    /// Pipeline layout most recently passed to `change_pipeline_layout`.
    pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
    /// Per-slot assigned/expected layout tracking.
    manager: compat::BoundBindGroupLayouts,
    /// Per-slot bind group, dynamic offsets and late buffer binding sizes.
    payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
}
322
323impl Binder {
324 pub(super) fn new() -> Self {
325 Self {
326 pipeline_layout: None,
327 manager: compat::BoundBindGroupLayouts::new(),
328 payloads: Default::default(),
329 }
330 }
331 pub(super) fn reset(&mut self) {
332 self.pipeline_layout = None;
333 self.manager = compat::BoundBindGroupLayouts::new();
334 for payload in self.payloads.iter_mut() {
335 payload.reset();
336 }
337 }
338
339 pub(super) fn change_pipeline_layout<'a>(
342 &'a mut self,
343 new: &Arc<PipelineLayout>,
344 late_sized_buffer_groups: &[LateSizedBufferGroup],
345 ) -> bool {
346 if let Some(old) = self.pipeline_layout.as_ref() {
347 if old.is_equal(new) {
348 return false;
349 }
350 }
351
352 let old = self.pipeline_layout.replace(new.clone());
353
354 self.manager.update_expectations(&new.bind_group_layouts);
355
356 for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
359 payload.late_bindings_effective_count = late_group.shader_sizes.len();
360 for (late_binding, &shader_expect_size) in payload
363 .late_buffer_bindings
364 .iter_mut()
365 .zip(late_group.shader_sizes.iter())
366 {
367 late_binding.shader_expect_size = shader_expect_size;
368 }
369 if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
371 for &shader_expect_size in
372 late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
373 {
374 payload.late_buffer_bindings.push(LateBufferBinding {
375 binding_index: 0,
376 shader_expect_size,
377 bound_size: 0,
378 });
379 }
380 }
381 }
382
383 if let Some(old) = old {
384 if old.immediate_size != new.immediate_size {
386 self.manager.update_start_index(0);
387 }
388 }
389
390 true
391 }
392
393 pub(super) fn assign_group<'a>(
394 &'a mut self,
395 index: usize,
396 bind_group: &Arc<BindGroup>,
397 offsets: &[wgt::DynamicOffset],
398 ) {
399 let payload = &mut self.payloads[index];
400 payload.group = Some(bind_group.clone());
401 payload.dynamic_offsets.clear();
402 payload.dynamic_offsets.extend_from_slice(offsets);
403
404 for (late_binding, late_info) in payload
410 .late_buffer_bindings
411 .iter_mut()
412 .zip(bind_group.late_buffer_binding_infos.iter())
413 {
414 late_binding.binding_index = late_info.binding_index;
415 late_binding.bound_size = late_info.size.get();
416 }
417
418 if bind_group.late_buffer_binding_infos.len() > payload.late_buffer_bindings.len() {
420 for late_info in
421 bind_group.late_buffer_binding_infos[payload.late_buffer_bindings.len()..].iter()
422 {
423 payload.late_buffer_bindings.push(LateBufferBinding {
424 binding_index: late_info.binding_index,
425 shader_expect_size: 0,
426 bound_size: late_info.size.get(),
427 });
428 }
429 }
430
431 self.manager.assign(index, bind_group.layout.clone());
432 }
433
434 pub(super) fn take_rebind_range(&mut self) -> Range<usize> {
436 self.manager.take_rebind_range()
437 }
438
439 pub(super) fn entries(
440 &self,
441 range: Range<usize>,
442 ) -> impl ExactSizeIterator<Item = (usize, &'_ EntryPayload)> + Clone + '_ {
443 let payloads = &self.payloads[range.clone()];
444 zip(range, payloads)
445 }
446
447 pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup>> + 'a {
448 let payloads = &self.payloads;
449 self.manager
450 .list_active()
451 .map(move |index| payloads[index].group.as_ref().unwrap())
452 }
453
454 pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + 'a {
455 self.payloads
456 .iter()
457 .take(self.manager.num_valid_entries())
458 .enumerate()
459 }
460
461 pub(super) fn check_compatibility<T: Labeled>(
462 &self,
463 pipeline: &T,
464 ) -> Result<(), Box<BinderError>> {
465 self.manager.get_invalid().map_err(|(index, error)| {
466 Box::new(match error {
467 compat::Error::Incompatible {
468 expected_bgl,
469 assigned_bgl,
470 inner,
471 } => BinderError::IncompatibleBindGroup {
472 expected_bgl,
473 assigned_bgl,
474 assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
475 index,
476 pipeline: pipeline.error_ident(),
477 inner,
478 },
479 compat::Error::Missing => BinderError::MissingBindGroup {
480 index,
481 pipeline: pipeline.error_ident(),
482 },
483 })
484 })
485 }
486
487 pub(super) fn check_late_buffer_bindings(
489 &self,
490 ) -> Result<(), LateMinBufferBindingSizeMismatch> {
491 for group_index in self.manager.list_active() {
492 let payload = &self.payloads[group_index];
493 for late_binding in
494 &payload.late_buffer_bindings[..payload.late_bindings_effective_count]
495 {
496 if late_binding.bound_size < late_binding.shader_expect_size {
497 return Err(LateMinBufferBindingSizeMismatch {
498 group_index: group_index as u32,
499 binding_index: late_binding.binding_index,
500 shader_size: late_binding.shader_expect_size,
501 bound_size: late_binding.bound_size,
502 });
503 }
504 }
505 }
506 Ok(())
507 }
508}