1use core::{iter::zip, ops::Range};
2
3use alloc::{boxed::Box, sync::Arc, vec::Vec};
4
5use arrayvec::ArrayVec;
6use thiserror::Error;
7
8use crate::{
9 binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
10 device::SHADER_STAGE_COUNT,
11 pipeline::LateSizedBufferGroup,
12 resource::{Labeled, ResourceErrorIdent},
13};
14
15mod compat {
16 use alloc::{
17 string::{String, ToString as _},
18 sync::{Arc, Weak},
19 vec::Vec,
20 };
21 use core::{num::NonZeroU32, ops::Range};
22
23 use arrayvec::ArrayVec;
24 use thiserror::Error;
25 use wgt::{BindingType, ShaderStages};
26
27 use crate::{
28 binding_model::BindGroupLayout,
29 error::MultiError,
30 resource::{Labeled, ParentDevice, ResourceErrorIdent},
31 };
32
33 pub(crate) enum Error {
34 Incompatible {
35 expected_bgl: ResourceErrorIdent,
36 assigned_bgl: ResourceErrorIdent,
37 inner: MultiError,
38 },
39 Missing,
40 }
41
42 #[derive(Debug, Clone)]
43 struct Entry {
44 assigned: Option<Arc<BindGroupLayout>>,
45 expected: Option<Arc<BindGroupLayout>>,
46 }
47
48 impl Entry {
49 fn empty() -> Self {
50 Self {
51 assigned: None,
52 expected: None,
53 }
54 }
55 fn is_active(&self) -> bool {
56 self.assigned.is_some() && self.expected.is_some()
57 }
58
59 fn is_valid(&self) -> bool {
60 if let Some(expected_bgl) = self.expected.as_ref() {
61 if let Some(assigned_bgl) = self.assigned.as_ref() {
62 expected_bgl.is_equal(assigned_bgl)
63 } else {
64 false
65 }
66 } else {
67 true
68 }
69 }
70
71 fn is_incompatible(&self) -> bool {
72 self.expected.is_none() || !self.is_valid()
73 }
74
75 fn check(&self) -> Result<(), Error> {
76 if let Some(expected_bgl) = self.expected.as_ref() {
77 if let Some(assigned_bgl) = self.assigned.as_ref() {
78 if expected_bgl.is_equal(assigned_bgl) {
79 Ok(())
80 } else {
81 #[derive(Clone, Debug, Error)]
82 #[error(
83 "Exclusive pipelines don't match: expected {expected}, got {assigned}"
84 )]
85 struct IncompatibleExclusivePipelines {
86 expected: String,
87 assigned: String,
88 }
89
90 use crate::binding_model::ExclusivePipeline;
91 match (
92 expected_bgl.exclusive_pipeline.get().unwrap(),
93 assigned_bgl.exclusive_pipeline.get().unwrap(),
94 ) {
95 (ExclusivePipeline::None, ExclusivePipeline::None) => {}
96 (
97 ExclusivePipeline::Render(e_pipeline),
98 ExclusivePipeline::Render(a_pipeline),
99 ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
100 (
101 ExclusivePipeline::Compute(e_pipeline),
102 ExclusivePipeline::Compute(a_pipeline),
103 ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
104 (expected, assigned) => {
105 return Err(Error::Incompatible {
106 expected_bgl: expected_bgl.error_ident(),
107 assigned_bgl: assigned_bgl.error_ident(),
108 inner: MultiError::new(core::iter::once(
109 IncompatibleExclusivePipelines {
110 expected: expected.to_string(),
111 assigned: assigned.to_string(),
112 },
113 ))
114 .unwrap(),
115 });
116 }
117 }
118
119 #[derive(Clone, Debug, Error)]
120 enum EntryError {
121 #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
122 Visibility {
123 binding: u32,
124 expected: ShaderStages,
125 assigned: ShaderStages,
126 },
127 #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
128 Type {
129 binding: u32,
130 expected: BindingType,
131 assigned: BindingType,
132 },
133 #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
134 Count {
135 binding: u32,
136 expected: Option<NonZeroU32>,
137 assigned: Option<NonZeroU32>,
138 },
139 #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
140 ExtraExpected { binding: u32 },
141 #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
142 ExtraAssigned { binding: u32 },
143 }
144
145 let mut errors = Vec::new();
146
147 for (&binding, expected_entry) in expected_bgl.entries.iter() {
148 if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
149 if assigned_entry.visibility != expected_entry.visibility {
150 errors.push(EntryError::Visibility {
151 binding,
152 expected: expected_entry.visibility,
153 assigned: assigned_entry.visibility,
154 });
155 }
156 if assigned_entry.ty != expected_entry.ty {
157 errors.push(EntryError::Type {
158 binding,
159 expected: expected_entry.ty,
160 assigned: assigned_entry.ty,
161 });
162 }
163 if assigned_entry.count != expected_entry.count {
164 errors.push(EntryError::Count {
165 binding,
166 expected: expected_entry.count,
167 assigned: assigned_entry.count,
168 });
169 }
170 } else {
171 errors.push(EntryError::ExtraExpected { binding });
172 }
173 }
174
175 for (&binding, _) in assigned_bgl.entries.iter() {
176 if !expected_bgl.entries.contains_key(binding) {
177 errors.push(EntryError::ExtraAssigned { binding });
178 }
179 }
180
181 Err(Error::Incompatible {
182 expected_bgl: expected_bgl.error_ident(),
183 assigned_bgl: assigned_bgl.error_ident(),
184 inner: MultiError::new(errors.drain(..)).unwrap(),
185 })
186 }
187 } else {
188 Err(Error::Missing)
189 }
190 } else {
191 Ok(())
192 }
193 }
194 }
195
196 #[derive(Debug, Default)]
197 pub(super) struct BoundBindGroupLayouts {
198 entries: ArrayVec<Entry, { hal::MAX_BIND_GROUPS }>,
199 rebind_start: usize,
200 }
201
202 impl BoundBindGroupLayouts {
203 pub fn new() -> Self {
204 Self {
205 entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
206 rebind_start: 0,
207 }
208 }
209
210 pub fn num_valid_entries(&self) -> usize {
211 self.entries
213 .iter()
214 .position(|e| e.is_incompatible())
215 .unwrap_or(self.entries.len())
216 }
217
218 pub fn take_rebind_range(&mut self) -> Range<usize> {
220 let end = self.num_valid_entries();
221 let start = self.rebind_start;
222 self.rebind_start = end;
223 start..end.max(start)
224 }
225
226 pub fn update_start_index(&mut self, start_index: usize) {
227 self.rebind_start = self.rebind_start.min(start_index);
228 }
229
230 pub fn update_expectations(&mut self, expectations: &[Arc<BindGroupLayout>]) {
231 let start_index = self
232 .entries
233 .iter()
234 .zip(expectations)
235 .position(|(e, expect)| {
236 e.expected.is_none() || !e.expected.as_ref().unwrap().is_equal(expect)
237 })
238 .unwrap_or(expectations.len());
239 for (e, expect) in self.entries[start_index..]
240 .iter_mut()
241 .zip(expectations[start_index..].iter())
242 {
243 e.expected = Some(expect.clone());
244 }
245 for e in self.entries[expectations.len()..].iter_mut() {
246 e.expected = None;
247 }
248 self.update_start_index(start_index);
249 }
250
251 pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) {
252 self.entries[index].assigned = Some(value);
253 self.update_start_index(index);
254 }
255
256 pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
257 self.entries
258 .iter()
259 .enumerate()
260 .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
261 }
262
263 #[allow(clippy::result_large_err)]
264 pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
265 for (index, entry) in self.entries.iter().enumerate() {
266 entry.check().map_err(|e| (index, e))?;
267 }
268 Ok(())
269 }
270 }
271}
272
/// Errors reported when the bind groups currently set do not satisfy the
/// bind group layouts expected by the current pipeline.
#[derive(Clone, Debug, Error)]
pub enum BinderError {
    /// The pipeline's layout expects a bind group at `index`, but none is set.
    #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
    MissingBindGroup {
        index: usize,
        pipeline: ResourceErrorIdent,
    },
    /// The bind group set at `index` has a layout that differs from the one
    /// the pipeline's layout expects at that index.
    #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
    IncompatibleBindGroup {
        expected_bgl: ResourceErrorIdent,
        assigned_bgl: ResourceErrorIdent,
        assigned_bg: ResourceErrorIdent,
        index: usize,
        pipeline: ResourceErrorIdent,
        /// The individual differences found between the two layouts.
        #[source]
        inner: crate::error::MultiError,
    },
}
291
/// A buffer binding whose bound size is compared against the size required by
/// the current pipeline's shader (see `Binder::check_late_buffer_bindings`).
#[derive(Debug)]
struct LateBufferBinding {
    /// Binding index within the bind group; 0 until a bind group supplies it.
    binding_index: u32,
    /// Minimum size required by the shader; 0 until a pipeline supplies it.
    shader_expect_size: wgt::BufferAddress,
    /// Size actually bound; 0 until a bind group supplies it.
    bound_size: wgt::BufferAddress,
}
298
/// Per-slot data for the bind group set at that slot index.
#[derive(Debug, Default)]
pub(super) struct EntryPayload {
    /// The bind group set in this slot, if any.
    pub(super) group: Option<Arc<BindGroup>>,
    /// Dynamic offsets supplied together with `group`.
    pub(super) dynamic_offsets: Vec<wgt::DynamicOffset>,
    /// Buffer bindings whose bound size is validated against the shader;
    /// only the first `late_bindings_effective_count` entries are checked.
    late_buffer_bindings: Vec<LateBufferBinding>,
    /// Number of shader sizes the current pipeline supplies for this slot.
    pub(super) late_bindings_effective_count: usize,
}
308
309impl EntryPayload {
310 fn reset(&mut self) {
311 self.group = None;
312 self.dynamic_offsets.clear();
313 self.late_buffer_bindings.clear();
314 self.late_bindings_effective_count = 0;
315 }
316}
317
/// Tracks the bind groups currently set, together with the pipeline layout
/// they are validated against.
///
/// NOTE(review): the derived `Default` matches `Binder::new()` only if
/// `compat::BoundBindGroupLayouts::default()` populates every slot the way
/// `BoundBindGroupLayouts::new()` does.
#[derive(Debug, Default)]
pub(super) struct Binder {
    /// The most recently set pipeline layout, if any.
    pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
    /// Assigned vs. expected bind group layouts per slot, plus rebind tracking.
    manager: compat::BoundBindGroupLayouts,
    /// Per-slot data for the currently set bind groups.
    payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
}
324
325impl Binder {
326 pub(super) fn new() -> Self {
327 Self {
328 pipeline_layout: None,
329 manager: compat::BoundBindGroupLayouts::new(),
330 payloads: Default::default(),
331 }
332 }
333 pub(super) fn reset(&mut self) {
334 self.pipeline_layout = None;
335 self.manager = compat::BoundBindGroupLayouts::new();
336 for payload in self.payloads.iter_mut() {
337 payload.reset();
338 }
339 }
340
341 pub(super) fn change_pipeline_layout<'a>(
342 &'a mut self,
343 new: &Arc<PipelineLayout>,
344 late_sized_buffer_groups: &[LateSizedBufferGroup],
345 ) {
346 let old_id_opt = self.pipeline_layout.replace(new.clone());
347
348 self.manager.update_expectations(&new.bind_group_layouts);
349
350 for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
353 payload.late_bindings_effective_count = late_group.shader_sizes.len();
354 for (late_binding, &shader_expect_size) in payload
357 .late_buffer_bindings
358 .iter_mut()
359 .zip(late_group.shader_sizes.iter())
360 {
361 late_binding.shader_expect_size = shader_expect_size;
362 }
363 if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
365 for &shader_expect_size in
366 late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
367 {
368 payload.late_buffer_bindings.push(LateBufferBinding {
369 binding_index: 0,
370 shader_expect_size,
371 bound_size: 0,
372 });
373 }
374 }
375 }
376
377 if let Some(old) = old_id_opt {
378 if old.push_constant_ranges != new.push_constant_ranges {
380 self.manager.update_start_index(0);
381 }
382 }
383 }
384
385 pub(super) fn assign_group<'a>(
386 &'a mut self,
387 index: usize,
388 bind_group: &Arc<BindGroup>,
389 offsets: &[wgt::DynamicOffset],
390 ) {
391 let payload = &mut self.payloads[index];
392 payload.group = Some(bind_group.clone());
393 payload.dynamic_offsets.clear();
394 payload.dynamic_offsets.extend_from_slice(offsets);
395
396 for (late_binding, late_info) in payload
402 .late_buffer_bindings
403 .iter_mut()
404 .zip(bind_group.late_buffer_binding_infos.iter())
405 {
406 late_binding.binding_index = late_info.binding_index;
407 late_binding.bound_size = late_info.size.get();
408 }
409
410 if bind_group.late_buffer_binding_infos.len() > payload.late_buffer_bindings.len() {
412 for late_info in
413 bind_group.late_buffer_binding_infos[payload.late_buffer_bindings.len()..].iter()
414 {
415 payload.late_buffer_bindings.push(LateBufferBinding {
416 binding_index: late_info.binding_index,
417 shader_expect_size: 0,
418 bound_size: late_info.size.get(),
419 });
420 }
421 }
422
423 self.manager.assign(index, bind_group.layout.clone());
424 }
425
426 pub(super) fn take_rebind_range(&mut self) -> Range<usize> {
428 self.manager.take_rebind_range()
429 }
430
431 pub(super) fn entries(
432 &self,
433 range: Range<usize>,
434 ) -> impl ExactSizeIterator<Item = (usize, &'_ EntryPayload)> + Clone + '_ {
435 let payloads = &self.payloads[range.clone()];
436 zip(range, payloads)
437 }
438
439 pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup>> + 'a {
440 let payloads = &self.payloads;
441 self.manager
442 .list_active()
443 .map(move |index| payloads[index].group.as_ref().unwrap())
444 }
445
446 pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + 'a {
447 self.payloads
448 .iter()
449 .take(self.manager.num_valid_entries())
450 .enumerate()
451 }
452
453 pub(super) fn check_compatibility<T: Labeled>(
454 &self,
455 pipeline: &T,
456 ) -> Result<(), Box<BinderError>> {
457 self.manager.get_invalid().map_err(|(index, error)| {
458 Box::new(match error {
459 compat::Error::Incompatible {
460 expected_bgl,
461 assigned_bgl,
462 inner,
463 } => BinderError::IncompatibleBindGroup {
464 expected_bgl,
465 assigned_bgl,
466 assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
467 index,
468 pipeline: pipeline.error_ident(),
469 inner,
470 },
471 compat::Error::Missing => BinderError::MissingBindGroup {
472 index,
473 pipeline: pipeline.error_ident(),
474 },
475 })
476 })
477 }
478
479 pub(super) fn check_late_buffer_bindings(
481 &self,
482 ) -> Result<(), LateMinBufferBindingSizeMismatch> {
483 for group_index in self.manager.list_active() {
484 let payload = &self.payloads[group_index];
485 for late_binding in
486 &payload.late_buffer_bindings[..payload.late_bindings_effective_count]
487 {
488 if late_binding.bound_size < late_binding.shader_expect_size {
489 return Err(LateMinBufferBindingSizeMismatch {
490 group_index: group_index as u32,
491 binding_index: late_binding.binding_index,
492 shader_size: late_binding.shader_expect_size,
493 bound_size: late_binding.bound_size,
494 });
495 }
496 }
497 }
498 Ok(())
499 }
500}
501
/// A point at which one push constant range starts or ends; used by
/// `compute_nonoverlapping_ranges` as a sweep-line event.
struct PushConstantChange {
    /// Stages of the range this change belongs to.
    stages: wgt::ShaderStages,
    /// Offset at which the range starts (`enable`) or ends (`!enable`).
    offset: u32,
    /// `true` at the range's start, `false` at its end.
    enable: bool,
}
507
508pub fn compute_nonoverlapping_ranges(
513 ranges: &[wgt::PushConstantRange],
514) -> ArrayVec<wgt::PushConstantRange, { SHADER_STAGE_COUNT * 2 }> {
515 if ranges.is_empty() {
516 return ArrayVec::new();
517 }
518 debug_assert!(ranges.len() <= SHADER_STAGE_COUNT);
519
520 let mut breaks: ArrayVec<PushConstantChange, { SHADER_STAGE_COUNT * 2 }> = ArrayVec::new();
521 for range in ranges {
522 breaks.push(PushConstantChange {
523 stages: range.stages,
524 offset: range.range.start,
525 enable: true,
526 });
527 breaks.push(PushConstantChange {
528 stages: range.stages,
529 offset: range.range.end,
530 enable: false,
531 });
532 }
533 breaks.sort_unstable_by_key(|change| change.offset);
534
535 let mut output_ranges = ArrayVec::new();
536 let mut position = 0_u32;
537 let mut stages = wgt::ShaderStages::NONE;
538
539 for bk in breaks {
540 if bk.offset - position > 0 && !stages.is_empty() {
541 output_ranges.push(wgt::PushConstantRange {
542 stages,
543 range: position..bk.offset,
544 })
545 }
546 position = bk.offset;
547 stages.set(bk.stages, bk.enable);
548 }
549
550 output_ranges
551}