1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2
3use arrayvec::ArrayVec;
4use thiserror::Error;
5
6use crate::{
7 binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
8 device::SHADER_STAGE_COUNT,
9 pipeline::LateSizedBufferGroup,
10 resource::{Labeled, ResourceErrorIdent},
11};
12
13mod compat {
14 use alloc::{
15 string::{String, ToString as _},
16 sync::{Arc, Weak},
17 vec::Vec,
18 };
19 use core::{num::NonZeroU32, ops::Range};
20
21 use arrayvec::ArrayVec;
22 use thiserror::Error;
23 use wgt::{BindingType, ShaderStages};
24
25 use crate::{
26 binding_model::BindGroupLayout,
27 error::MultiError,
28 resource::{Labeled, ParentDevice, ResourceErrorIdent},
29 };
30
31 pub(crate) enum Error {
32 Incompatible {
33 expected_bgl: ResourceErrorIdent,
34 assigned_bgl: ResourceErrorIdent,
35 inner: MultiError,
36 },
37 Missing,
38 }
39
40 #[derive(Debug, Clone)]
41 struct Entry {
42 assigned: Option<Arc<BindGroupLayout>>,
43 expected: Option<Arc<BindGroupLayout>>,
44 }
45
46 impl Entry {
47 fn empty() -> Self {
48 Self {
49 assigned: None,
50 expected: None,
51 }
52 }
53 fn is_active(&self) -> bool {
54 self.assigned.is_some() && self.expected.is_some()
55 }
56
57 fn is_valid(&self) -> bool {
58 if let Some(expected_bgl) = self.expected.as_ref() {
59 if let Some(assigned_bgl) = self.assigned.as_ref() {
60 expected_bgl.is_equal(assigned_bgl)
61 } else {
62 false
63 }
64 } else {
65 true
66 }
67 }
68
69 fn is_incompatible(&self) -> bool {
70 self.expected.is_none() || !self.is_valid()
71 }
72
73 fn check(&self) -> Result<(), Error> {
74 if let Some(expected_bgl) = self.expected.as_ref() {
75 if let Some(assigned_bgl) = self.assigned.as_ref() {
76 if expected_bgl.is_equal(assigned_bgl) {
77 Ok(())
78 } else {
79 #[derive(Clone, Debug, Error)]
80 #[error(
81 "Exclusive pipelines don't match: expected {expected}, got {assigned}"
82 )]
83 struct IncompatibleExclusivePipelines {
84 expected: String,
85 assigned: String,
86 }
87
88 use crate::binding_model::ExclusivePipeline;
89 match (
90 expected_bgl.exclusive_pipeline.get().unwrap(),
91 assigned_bgl.exclusive_pipeline.get().unwrap(),
92 ) {
93 (ExclusivePipeline::None, ExclusivePipeline::None) => {}
94 (
95 ExclusivePipeline::Render(e_pipeline),
96 ExclusivePipeline::Render(a_pipeline),
97 ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
98 (
99 ExclusivePipeline::Compute(e_pipeline),
100 ExclusivePipeline::Compute(a_pipeline),
101 ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
102 (expected, assigned) => {
103 return Err(Error::Incompatible {
104 expected_bgl: expected_bgl.error_ident(),
105 assigned_bgl: assigned_bgl.error_ident(),
106 inner: MultiError::new(core::iter::once(
107 IncompatibleExclusivePipelines {
108 expected: expected.to_string(),
109 assigned: assigned.to_string(),
110 },
111 ))
112 .unwrap(),
113 });
114 }
115 }
116
117 #[derive(Clone, Debug, Error)]
118 enum EntryError {
119 #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
120 Visibility {
121 binding: u32,
122 expected: ShaderStages,
123 assigned: ShaderStages,
124 },
125 #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
126 Type {
127 binding: u32,
128 expected: BindingType,
129 assigned: BindingType,
130 },
131 #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
132 Count {
133 binding: u32,
134 expected: Option<NonZeroU32>,
135 assigned: Option<NonZeroU32>,
136 },
137 #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
138 ExtraExpected { binding: u32 },
139 #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
140 ExtraAssigned { binding: u32 },
141 }
142
143 let mut errors = Vec::new();
144
145 for (&binding, expected_entry) in expected_bgl.entries.iter() {
146 if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
147 if assigned_entry.visibility != expected_entry.visibility {
148 errors.push(EntryError::Visibility {
149 binding,
150 expected: expected_entry.visibility,
151 assigned: assigned_entry.visibility,
152 });
153 }
154 if assigned_entry.ty != expected_entry.ty {
155 errors.push(EntryError::Type {
156 binding,
157 expected: expected_entry.ty,
158 assigned: assigned_entry.ty,
159 });
160 }
161 if assigned_entry.count != expected_entry.count {
162 errors.push(EntryError::Count {
163 binding,
164 expected: expected_entry.count,
165 assigned: assigned_entry.count,
166 });
167 }
168 } else {
169 errors.push(EntryError::ExtraExpected { binding });
170 }
171 }
172
173 for (&binding, _) in assigned_bgl.entries.iter() {
174 if !expected_bgl.entries.contains_key(binding) {
175 errors.push(EntryError::ExtraAssigned { binding });
176 }
177 }
178
179 Err(Error::Incompatible {
180 expected_bgl: expected_bgl.error_ident(),
181 assigned_bgl: assigned_bgl.error_ident(),
182 inner: MultiError::new(errors.drain(..)).unwrap(),
183 })
184 }
185 } else {
186 Err(Error::Missing)
187 }
188 } else {
189 Ok(())
190 }
191 }
192 }
193
194 #[derive(Debug, Default)]
195 pub(crate) struct BoundBindGroupLayouts {
196 entries: ArrayVec<Entry, { hal::MAX_BIND_GROUPS }>,
197 }
198
199 impl BoundBindGroupLayouts {
200 pub fn new() -> Self {
201 Self {
202 entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
203 }
204 }
205
206 pub fn num_valid_entries(&self) -> usize {
207 self.entries
209 .iter()
210 .position(|e| e.is_incompatible())
211 .unwrap_or(self.entries.len())
212 }
213
214 fn make_range(&self, start_index: usize) -> Range<usize> {
215 let end = self.num_valid_entries();
216 start_index..end.max(start_index)
217 }
218
219 pub fn update_expectations(
220 &mut self,
221 expectations: &[Arc<BindGroupLayout>],
222 ) -> Range<usize> {
223 let start_index = self
224 .entries
225 .iter()
226 .zip(expectations)
227 .position(|(e, expect)| {
228 e.expected.is_none() || !e.expected.as_ref().unwrap().is_equal(expect)
229 })
230 .unwrap_or(expectations.len());
231 for (e, expect) in self.entries[start_index..]
232 .iter_mut()
233 .zip(expectations[start_index..].iter())
234 {
235 e.expected = Some(expect.clone());
236 }
237 for e in self.entries[expectations.len()..].iter_mut() {
238 e.expected = None;
239 }
240 self.make_range(start_index)
241 }
242
243 pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) -> Range<usize> {
244 self.entries[index].assigned = Some(value);
245 self.make_range(index)
246 }
247
248 pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
249 self.entries
250 .iter()
251 .enumerate()
252 .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
253 }
254
255 #[allow(clippy::result_large_err)]
256 pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
257 for (index, entry) in self.entries.iter().enumerate() {
258 entry.check().map_err(|e| (index, e))?;
259 }
260 Ok(())
261 }
262 }
263}
264
/// Error reported by `Binder::check_compatibility` when the bind groups that
/// are set do not satisfy the layout expected by the given pipeline.
#[derive(Clone, Debug, Error)]
pub enum BinderError {
    /// The pipeline layout expects a bind group at `index`, but none is set.
    #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
    MissingBindGroup {
        index: usize,
        pipeline: ResourceErrorIdent,
    },
    /// The bind group set at `index` (`assigned_bg`, whose layout is
    /// `assigned_bgl`) does not match the layout `expected_bgl` that the
    /// pipeline expects there. `inner` lists the individual differences.
    #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
    IncompatibleBindGroup {
        expected_bgl: ResourceErrorIdent,
        assigned_bgl: ResourceErrorIdent,
        assigned_bg: ResourceErrorIdent,
        index: usize,
        pipeline: ResourceErrorIdent,
        #[source]
        inner: crate::error::MultiError,
    },
}
283
/// The pair of sizes compared by `Binder::check_late_buffer_bindings` for one
/// late-sized buffer binding of a bind group slot.
#[derive(Debug)]
struct LateBufferBinding {
    /// Size the current pipeline's shader expects (from
    /// `LateSizedBufferGroup::shader_sizes`). It is 0 when the slot was first
    /// created by `assign_group`.
    shader_expect_size: wgt::BufferAddress,
    /// Size provided by the bind group currently assigned (from
    /// `late_buffer_binding_sizes`). It is 0 when the slot was first created by
    /// `change_pipeline_layout`.
    bound_size: wgt::BufferAddress,
}
289
/// Everything the binder records for one bind group index.
#[derive(Debug, Default)]
pub(super) struct EntryPayload {
    /// Bind group most recently assigned at this index, if any.
    pub(super) group: Option<Arc<BindGroup>>,
    /// Dynamic offsets supplied together with `group`.
    pub(super) dynamic_offsets: Vec<wgt::DynamicOffset>,
    /// Expected/bound sizes of late-sized buffer bindings. The vector only
    /// grows, so entries past `late_bindings_effective_count` may be stale.
    late_buffer_bindings: Vec<LateBufferBinding>,
    /// Number of leading `late_buffer_bindings` entries that the current
    /// pipeline's shader sizes apply to; only these are checked.
    pub(super) late_bindings_effective_count: usize,
}
299
300impl EntryPayload {
301 fn reset(&mut self) {
302 self.group = None;
303 self.dynamic_offsets.clear();
304 self.late_buffer_bindings.clear();
305 self.late_bindings_effective_count = 0;
306 }
307}
308
/// Tracks the current pipeline layout and the bind groups set at each index,
/// and works out which indices need (re)binding when either changes.
#[derive(Debug, Default)]
pub(super) struct Binder {
    /// Pipeline layout most recently passed to `change_pipeline_layout`.
    pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
    /// Expected-versus-assigned layout compatibility per bind group index.
    manager: compat::BoundBindGroupLayouts,
    /// Per-index bind group, dynamic offsets and late-sized buffer bookkeeping.
    payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
}
315
316impl Binder {
317 pub(super) fn new() -> Self {
318 Self {
319 pipeline_layout: None,
320 manager: compat::BoundBindGroupLayouts::new(),
321 payloads: Default::default(),
322 }
323 }
324 pub(super) fn reset(&mut self) {
325 self.pipeline_layout = None;
326 self.manager = compat::BoundBindGroupLayouts::new();
327 for payload in self.payloads.iter_mut() {
328 payload.reset();
329 }
330 }
331
332 pub(super) fn change_pipeline_layout<'a>(
333 &'a mut self,
334 new: &Arc<PipelineLayout>,
335 late_sized_buffer_groups: &[LateSizedBufferGroup],
336 ) -> (usize, &'a [EntryPayload]) {
337 let old_id_opt = self.pipeline_layout.replace(new.clone());
338
339 let mut bind_range = self.manager.update_expectations(&new.bind_group_layouts);
340
341 for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
343 payload.late_bindings_effective_count = late_group.shader_sizes.len();
344 for (late_binding, &shader_expect_size) in payload
345 .late_buffer_bindings
346 .iter_mut()
347 .zip(late_group.shader_sizes.iter())
348 {
349 late_binding.shader_expect_size = shader_expect_size;
350 }
351 if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
352 for &shader_expect_size in
353 late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
354 {
355 payload.late_buffer_bindings.push(LateBufferBinding {
356 shader_expect_size,
357 bound_size: 0,
358 });
359 }
360 }
361 }
362
363 if let Some(old) = old_id_opt {
364 if old.push_constant_ranges != new.push_constant_ranges {
366 bind_range.start = 0;
367 }
368 }
369
370 (bind_range.start, &self.payloads[bind_range])
371 }
372
373 pub(super) fn assign_group<'a>(
374 &'a mut self,
375 index: usize,
376 bind_group: &Arc<BindGroup>,
377 offsets: &[wgt::DynamicOffset],
378 ) -> &'a [EntryPayload] {
379 let payload = &mut self.payloads[index];
380 payload.group = Some(bind_group.clone());
381 payload.dynamic_offsets.clear();
382 payload.dynamic_offsets.extend_from_slice(offsets);
383
384 for (late_binding, late_size) in payload
387 .late_buffer_bindings
388 .iter_mut()
389 .zip(bind_group.late_buffer_binding_sizes.iter())
390 {
391 late_binding.bound_size = late_size.get();
392 }
393 if bind_group.late_buffer_binding_sizes.len() > payload.late_buffer_bindings.len() {
394 for late_size in
395 bind_group.late_buffer_binding_sizes[payload.late_buffer_bindings.len()..].iter()
396 {
397 payload.late_buffer_bindings.push(LateBufferBinding {
398 shader_expect_size: 0,
399 bound_size: late_size.get(),
400 });
401 }
402 }
403
404 let bind_range = self.manager.assign(index, bind_group.layout.clone());
405 &self.payloads[bind_range]
406 }
407
408 pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup>> + 'a {
409 let payloads = &self.payloads;
410 self.manager
411 .list_active()
412 .map(move |index| payloads[index].group.as_ref().unwrap())
413 }
414
415 pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + 'a {
416 self.payloads
417 .iter()
418 .take(self.manager.num_valid_entries())
419 .enumerate()
420 }
421
422 pub(super) fn check_compatibility<T: Labeled>(
423 &self,
424 pipeline: &T,
425 ) -> Result<(), Box<BinderError>> {
426 self.manager.get_invalid().map_err(|(index, error)| {
427 Box::new(match error {
428 compat::Error::Incompatible {
429 expected_bgl,
430 assigned_bgl,
431 inner,
432 } => BinderError::IncompatibleBindGroup {
433 expected_bgl,
434 assigned_bgl,
435 assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
436 index,
437 pipeline: pipeline.error_ident(),
438 inner,
439 },
440 compat::Error::Missing => BinderError::MissingBindGroup {
441 index,
442 pipeline: pipeline.error_ident(),
443 },
444 })
445 })
446 }
447
448 pub(super) fn check_late_buffer_bindings(
450 &self,
451 ) -> Result<(), LateMinBufferBindingSizeMismatch> {
452 for group_index in self.manager.list_active() {
453 let payload = &self.payloads[group_index];
454 for (compact_index, late_binding) in payload.late_buffer_bindings
455 [..payload.late_bindings_effective_count]
456 .iter()
457 .enumerate()
458 {
459 if late_binding.bound_size < late_binding.shader_expect_size {
460 return Err(LateMinBufferBindingSizeMismatch {
461 group_index: group_index as u32,
462 compact_index,
463 shader_size: late_binding.shader_expect_size,
464 bound_size: late_binding.bound_size,
465 });
466 }
467 }
468 }
469 Ok(())
470 }
471}
472
/// A boundary event used by `compute_nonoverlapping_ranges`.
///
/// At byte `offset`, `stages` are switched on (`enable == true`, the start of
/// a range) or off (`enable == false`, the end of a range).
struct PushConstantChange {
    stages: wgt::ShaderStages,
    offset: u32,
    enable: bool,
}
478
479pub fn compute_nonoverlapping_ranges(
484 ranges: &[wgt::PushConstantRange],
485) -> ArrayVec<wgt::PushConstantRange, { SHADER_STAGE_COUNT * 2 }> {
486 if ranges.is_empty() {
487 return ArrayVec::new();
488 }
489 debug_assert!(ranges.len() <= SHADER_STAGE_COUNT);
490
491 let mut breaks: ArrayVec<PushConstantChange, { SHADER_STAGE_COUNT * 2 }> = ArrayVec::new();
492 for range in ranges {
493 breaks.push(PushConstantChange {
494 stages: range.stages,
495 offset: range.range.start,
496 enable: true,
497 });
498 breaks.push(PushConstantChange {
499 stages: range.stages,
500 offset: range.range.end,
501 enable: false,
502 });
503 }
504 breaks.sort_unstable_by_key(|change| change.offset);
505
506 let mut output_ranges = ArrayVec::new();
507 let mut position = 0_u32;
508 let mut stages = wgt::ShaderStages::NONE;
509
510 for bk in breaks {
511 if bk.offset - position > 0 && !stages.is_empty() {
512 output_ranges.push(wgt::PushConstantRange {
513 stages,
514 range: position..bk.offset,
515 })
516 }
517 position = bk.offset;
518 stages.set(bk.stages, bk.enable);
519 }
520
521 output_ranges
522}