1use core::{iter::zip, ops::Range};
2
3use alloc::{boxed::Box, sync::Arc, vec::Vec};
4
5use arrayvec::ArrayVec;
6use thiserror::Error;
7
8use crate::{
9 binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
10 device::SHADER_STAGE_COUNT,
11 pipeline::LateSizedBufferGroup,
12 resource::{Labeled, ResourceErrorIdent},
13};
14
15mod compat {
16 use alloc::{
17 string::{String, ToString as _},
18 sync::{Arc, Weak},
19 vec::Vec,
20 };
21 use core::{num::NonZeroU32, ops::Range};
22
23 use arrayvec::ArrayVec;
24 use thiserror::Error;
25 use wgt::{BindingType, ShaderStages};
26
27 use crate::{
28 binding_model::BindGroupLayout,
29 error::MultiError,
30 resource::{Labeled, ParentDevice, ResourceErrorIdent},
31 };
32
33 pub(crate) enum Error {
34 Incompatible {
35 expected_bgl: ResourceErrorIdent,
36 assigned_bgl: ResourceErrorIdent,
37 inner: MultiError,
38 },
39 Missing,
40 }
41
42 #[derive(Debug, Clone)]
43 struct Entry {
44 assigned: Option<Arc<BindGroupLayout>>,
45 expected: Option<Arc<BindGroupLayout>>,
46 }
47
48 impl Entry {
49 fn empty() -> Self {
50 Self {
51 assigned: None,
52 expected: None,
53 }
54 }
55 fn is_active(&self) -> bool {
56 self.assigned.is_some() && self.expected.is_some()
57 }
58
59 fn is_valid(&self) -> bool {
60 if let Some(expected_bgl) = self.expected.as_ref() {
61 if let Some(assigned_bgl) = self.assigned.as_ref() {
62 expected_bgl.is_equal(assigned_bgl)
63 } else {
64 false
65 }
66 } else {
67 true
68 }
69 }
70
71 fn is_incompatible(&self) -> bool {
72 self.expected.is_none() || !self.is_valid()
73 }
74
75 fn check(&self) -> Result<(), Error> {
76 if let Some(expected_bgl) = self.expected.as_ref() {
77 if let Some(assigned_bgl) = self.assigned.as_ref() {
78 if expected_bgl.is_equal(assigned_bgl) {
79 Ok(())
80 } else {
81 #[derive(Clone, Debug, Error)]
82 #[error(
83 "Exclusive pipelines don't match: expected {expected}, got {assigned}"
84 )]
85 struct IncompatibleExclusivePipelines {
86 expected: String,
87 assigned: String,
88 }
89
90 use crate::binding_model::ExclusivePipeline;
91 match (
92 expected_bgl.exclusive_pipeline.get().unwrap(),
93 assigned_bgl.exclusive_pipeline.get().unwrap(),
94 ) {
95 (ExclusivePipeline::None, ExclusivePipeline::None) => {}
96 (
97 ExclusivePipeline::Render(e_pipeline),
98 ExclusivePipeline::Render(a_pipeline),
99 ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
100 (
101 ExclusivePipeline::Compute(e_pipeline),
102 ExclusivePipeline::Compute(a_pipeline),
103 ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
104 (expected, assigned) => {
105 return Err(Error::Incompatible {
106 expected_bgl: expected_bgl.error_ident(),
107 assigned_bgl: assigned_bgl.error_ident(),
108 inner: MultiError::new(core::iter::once(
109 IncompatibleExclusivePipelines {
110 expected: expected.to_string(),
111 assigned: assigned.to_string(),
112 },
113 ))
114 .unwrap(),
115 });
116 }
117 }
118
119 #[derive(Clone, Debug, Error)]
120 enum EntryError {
121 #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
122 Visibility {
123 binding: u32,
124 expected: ShaderStages,
125 assigned: ShaderStages,
126 },
127 #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
128 Type {
129 binding: u32,
130 expected: BindingType,
131 assigned: BindingType,
132 },
133 #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
134 Count {
135 binding: u32,
136 expected: Option<NonZeroU32>,
137 assigned: Option<NonZeroU32>,
138 },
139 #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
140 ExtraExpected { binding: u32 },
141 #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
142 ExtraAssigned { binding: u32 },
143 }
144
145 let mut errors = Vec::new();
146
147 for (&binding, expected_entry) in expected_bgl.entries.iter() {
148 if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
149 if assigned_entry.visibility != expected_entry.visibility {
150 errors.push(EntryError::Visibility {
151 binding,
152 expected: expected_entry.visibility,
153 assigned: assigned_entry.visibility,
154 });
155 }
156 if assigned_entry.ty != expected_entry.ty {
157 errors.push(EntryError::Type {
158 binding,
159 expected: expected_entry.ty,
160 assigned: assigned_entry.ty,
161 });
162 }
163 if assigned_entry.count != expected_entry.count {
164 errors.push(EntryError::Count {
165 binding,
166 expected: expected_entry.count,
167 assigned: assigned_entry.count,
168 });
169 }
170 } else {
171 errors.push(EntryError::ExtraExpected { binding });
172 }
173 }
174
175 for (&binding, _) in assigned_bgl.entries.iter() {
176 if !expected_bgl.entries.contains_key(binding) {
177 errors.push(EntryError::ExtraAssigned { binding });
178 }
179 }
180
181 Err(Error::Incompatible {
182 expected_bgl: expected_bgl.error_ident(),
183 assigned_bgl: assigned_bgl.error_ident(),
184 inner: MultiError::new(errors.drain(..)).unwrap(),
185 })
186 }
187 } else {
188 Err(Error::Missing)
189 }
190 } else {
191 Ok(())
192 }
193 }
194 }
195
196 #[derive(Debug, Default)]
197 pub(super) struct BoundBindGroupLayouts {
198 entries: ArrayVec<Entry, { hal::MAX_BIND_GROUPS }>,
199 rebind_start: usize,
200 }
201
202 impl BoundBindGroupLayouts {
203 pub fn new() -> Self {
204 Self {
205 entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
206 rebind_start: 0,
207 }
208 }
209
210 pub fn num_valid_entries(&self) -> usize {
211 self.entries
213 .iter()
214 .position(|e| e.is_incompatible())
215 .unwrap_or(self.entries.len())
216 }
217
218 pub fn take_rebind_range(&mut self) -> Range<usize> {
220 let end = self.num_valid_entries();
221 let start = self.rebind_start;
222 self.rebind_start = end;
223 start..end.max(start)
224 }
225
226 pub fn update_start_index(&mut self, start_index: usize) {
227 self.rebind_start = self.rebind_start.min(start_index);
228 }
229
230 pub fn update_expectations(&mut self, expectations: &[Arc<BindGroupLayout>]) {
231 let start_index = self
232 .entries
233 .iter()
234 .zip(expectations)
235 .position(|(e, expect)| {
236 e.expected.is_none() || !e.expected.as_ref().unwrap().is_equal(expect)
237 })
238 .unwrap_or(expectations.len());
239 for (e, expect) in self.entries[start_index..]
240 .iter_mut()
241 .zip(expectations[start_index..].iter())
242 {
243 e.expected = Some(expect.clone());
244 }
245 for e in self.entries[expectations.len()..].iter_mut() {
246 e.expected = None;
247 }
248 self.update_start_index(start_index);
249 }
250
251 pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) {
252 self.entries[index].assigned = Some(value);
253 self.update_start_index(index);
254 }
255
256 pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
257 self.entries
258 .iter()
259 .enumerate()
260 .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
261 }
262
263 #[allow(clippy::result_large_err)]
264 pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
265 for (index, entry) in self.entries.iter().enumerate() {
266 entry.check().map_err(|e| (index, e))?;
267 }
268 Ok(())
269 }
270 }
271}
272
273#[derive(Clone, Debug, Error)]
274pub enum BinderError {
275 #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
276 MissingBindGroup {
277 index: usize,
278 pipeline: ResourceErrorIdent,
279 },
280 #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
281 IncompatibleBindGroup {
282 expected_bgl: ResourceErrorIdent,
283 assigned_bgl: ResourceErrorIdent,
284 assigned_bg: ResourceErrorIdent,
285 index: usize,
286 pipeline: ResourceErrorIdent,
287 #[source]
288 inner: crate::error::MultiError,
289 },
290}
291
292#[derive(Debug)]
293struct LateBufferBinding {
294 shader_expect_size: wgt::BufferAddress,
295 bound_size: wgt::BufferAddress,
296}
297
298#[derive(Debug, Default)]
299pub(super) struct EntryPayload {
300 pub(super) group: Option<Arc<BindGroup>>,
301 pub(super) dynamic_offsets: Vec<wgt::DynamicOffset>,
302 late_buffer_bindings: Vec<LateBufferBinding>,
303 pub(super) late_bindings_effective_count: usize,
306}
307
308impl EntryPayload {
309 fn reset(&mut self) {
310 self.group = None;
311 self.dynamic_offsets.clear();
312 self.late_buffer_bindings.clear();
313 self.late_bindings_effective_count = 0;
314 }
315}
316
317#[derive(Debug, Default)]
318pub(super) struct Binder {
319 pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
320 manager: compat::BoundBindGroupLayouts,
321 payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
322}
323
324impl Binder {
325 pub(super) fn new() -> Self {
326 Self {
327 pipeline_layout: None,
328 manager: compat::BoundBindGroupLayouts::new(),
329 payloads: Default::default(),
330 }
331 }
332 pub(super) fn reset(&mut self) {
333 self.pipeline_layout = None;
334 self.manager = compat::BoundBindGroupLayouts::new();
335 for payload in self.payloads.iter_mut() {
336 payload.reset();
337 }
338 }
339
340 pub(super) fn change_pipeline_layout<'a>(
341 &'a mut self,
342 new: &Arc<PipelineLayout>,
343 late_sized_buffer_groups: &[LateSizedBufferGroup],
344 ) {
345 let old_id_opt = self.pipeline_layout.replace(new.clone());
346
347 self.manager.update_expectations(&new.bind_group_layouts);
348
349 for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
351 payload.late_bindings_effective_count = late_group.shader_sizes.len();
352 for (late_binding, &shader_expect_size) in payload
353 .late_buffer_bindings
354 .iter_mut()
355 .zip(late_group.shader_sizes.iter())
356 {
357 late_binding.shader_expect_size = shader_expect_size;
358 }
359 if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
360 for &shader_expect_size in
361 late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
362 {
363 payload.late_buffer_bindings.push(LateBufferBinding {
364 shader_expect_size,
365 bound_size: 0,
366 });
367 }
368 }
369 }
370
371 if let Some(old) = old_id_opt {
372 if old.push_constant_ranges != new.push_constant_ranges {
374 self.manager.update_start_index(0);
375 }
376 }
377 }
378
379 pub(super) fn assign_group<'a>(
380 &'a mut self,
381 index: usize,
382 bind_group: &Arc<BindGroup>,
383 offsets: &[wgt::DynamicOffset],
384 ) {
385 let payload = &mut self.payloads[index];
386 payload.group = Some(bind_group.clone());
387 payload.dynamic_offsets.clear();
388 payload.dynamic_offsets.extend_from_slice(offsets);
389
390 for (late_binding, late_size) in payload
393 .late_buffer_bindings
394 .iter_mut()
395 .zip(bind_group.late_buffer_binding_sizes.iter())
396 {
397 late_binding.bound_size = late_size.get();
398 }
399 if bind_group.late_buffer_binding_sizes.len() > payload.late_buffer_bindings.len() {
400 for late_size in
401 bind_group.late_buffer_binding_sizes[payload.late_buffer_bindings.len()..].iter()
402 {
403 payload.late_buffer_bindings.push(LateBufferBinding {
404 shader_expect_size: 0,
405 bound_size: late_size.get(),
406 });
407 }
408 }
409
410 self.manager.assign(index, bind_group.layout.clone());
411 }
412
413 pub(super) fn take_rebind_range(&mut self) -> Range<usize> {
415 self.manager.take_rebind_range()
416 }
417
418 pub(super) fn entries(
419 &self,
420 range: Range<usize>,
421 ) -> impl ExactSizeIterator<Item = (usize, &'_ EntryPayload)> + '_ {
422 let payloads = &self.payloads[range.clone()];
423 zip(range, payloads)
424 }
425
426 pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup>> + 'a {
427 let payloads = &self.payloads;
428 self.manager
429 .list_active()
430 .map(move |index| payloads[index].group.as_ref().unwrap())
431 }
432
433 pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + 'a {
434 self.payloads
435 .iter()
436 .take(self.manager.num_valid_entries())
437 .enumerate()
438 }
439
440 pub(super) fn check_compatibility<T: Labeled>(
441 &self,
442 pipeline: &T,
443 ) -> Result<(), Box<BinderError>> {
444 self.manager.get_invalid().map_err(|(index, error)| {
445 Box::new(match error {
446 compat::Error::Incompatible {
447 expected_bgl,
448 assigned_bgl,
449 inner,
450 } => BinderError::IncompatibleBindGroup {
451 expected_bgl,
452 assigned_bgl,
453 assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
454 index,
455 pipeline: pipeline.error_ident(),
456 inner,
457 },
458 compat::Error::Missing => BinderError::MissingBindGroup {
459 index,
460 pipeline: pipeline.error_ident(),
461 },
462 })
463 })
464 }
465
466 pub(super) fn check_late_buffer_bindings(
468 &self,
469 ) -> Result<(), LateMinBufferBindingSizeMismatch> {
470 for group_index in self.manager.list_active() {
471 let payload = &self.payloads[group_index];
472 for (compact_index, late_binding) in payload.late_buffer_bindings
473 [..payload.late_bindings_effective_count]
474 .iter()
475 .enumerate()
476 {
477 if late_binding.bound_size < late_binding.shader_expect_size {
478 return Err(LateMinBufferBindingSizeMismatch {
479 group_index: group_index as u32,
480 compact_index,
481 shader_size: late_binding.shader_expect_size,
482 bound_size: late_binding.bound_size,
483 });
484 }
485 }
486 }
487 Ok(())
488 }
489}
490
491struct PushConstantChange {
492 stages: wgt::ShaderStages,
493 offset: u32,
494 enable: bool,
495}
496
497pub fn compute_nonoverlapping_ranges(
502 ranges: &[wgt::PushConstantRange],
503) -> ArrayVec<wgt::PushConstantRange, { SHADER_STAGE_COUNT * 2 }> {
504 if ranges.is_empty() {
505 return ArrayVec::new();
506 }
507 debug_assert!(ranges.len() <= SHADER_STAGE_COUNT);
508
509 let mut breaks: ArrayVec<PushConstantChange, { SHADER_STAGE_COUNT * 2 }> = ArrayVec::new();
510 for range in ranges {
511 breaks.push(PushConstantChange {
512 stages: range.stages,
513 offset: range.range.start,
514 enable: true,
515 });
516 breaks.push(PushConstantChange {
517 stages: range.stages,
518 offset: range.range.end,
519 enable: false,
520 });
521 }
522 breaks.sort_unstable_by_key(|change| change.offset);
523
524 let mut output_ranges = ArrayVec::new();
525 let mut position = 0_u32;
526 let mut stages = wgt::ShaderStages::NONE;
527
528 for bk in breaks {
529 if bk.offset - position > 0 && !stages.is_empty() {
530 output_ranges.push(wgt::PushConstantRange {
531 stages,
532 range: position..bk.offset,
533 })
534 }
535 position = bk.offset;
536 stages.set(bk.stages, bk.enable);
537 }
538
539 output_ranges
540}