wgpu_core/command/
bind.rs

1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2
3use thiserror::Error;
4
5use crate::{
6    binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
7    pipeline::LateSizedBufferGroup,
8    resource::{Labeled, ParentDevice, ResourceErrorIdent},
9};
10
11mod compat {
12    use alloc::{
13        string::{String, ToString as _},
14        sync::{Arc, Weak},
15        vec::Vec,
16    };
17    use core::num::NonZeroU32;
18
19    use thiserror::Error;
20    use wgt::{BindingType, ShaderStages};
21
22    use crate::{
23        binding_model::BindGroupLayout,
24        error::MultiError,
25        resource::{Labeled, ParentDevice, ResourceErrorIdent},
26    };
27
28    pub(crate) enum Error {
29        Incompatible {
30            expected_bgl: ResourceErrorIdent,
31            assigned_bgl: ResourceErrorIdent,
32            inner: MultiError,
33        },
34        Missing,
35    }
36
37    #[derive(Debug, Clone)]
38    struct Entry {
39        assigned: Option<Arc<BindGroupLayout>>,
40        expected: Option<Arc<BindGroupLayout>>,
41    }
42
43    impl Entry {
44        const fn empty() -> Self {
45            Self {
46                assigned: None,
47                expected: None,
48            }
49        }
50
51        fn is_assigned(&self) -> bool {
52            self.assigned.is_some()
53        }
54
55        fn is_active(&self) -> bool {
56            self.assigned.is_some() && self.expected.is_some()
57        }
58
59        fn is_valid(&self) -> bool {
60            if let Some(expected_bgl) = self.expected.as_ref() {
61                if let Some(assigned_bgl) = self.assigned.as_ref() {
62                    expected_bgl.is_equal(assigned_bgl)
63                } else {
64                    false
65                }
66            } else {
67                false
68            }
69        }
70
71        fn check(&self) -> Result<(), Error> {
72            if let Some(expected_bgl) = self.expected.as_ref() {
73                if let Some(assigned_bgl) = self.assigned.as_ref() {
74                    if expected_bgl.is_equal(assigned_bgl) {
75                        Ok(())
76                    } else {
77                        #[derive(Clone, Debug, Error)]
78                        #[error(
79                            "Exclusive pipelines don't match: expected {expected}, got {assigned}"
80                        )]
81                        struct IncompatibleExclusivePipelines {
82                            expected: String,
83                            assigned: String,
84                        }
85
86                        use crate::binding_model::ExclusivePipeline;
87                        match (
88                            expected_bgl.exclusive_pipeline.get().unwrap(),
89                            assigned_bgl.exclusive_pipeline.get().unwrap(),
90                        ) {
91                            (ExclusivePipeline::None, ExclusivePipeline::None) => {}
92                            (
93                                ExclusivePipeline::Render(e_pipeline),
94                                ExclusivePipeline::Render(a_pipeline),
95                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
96                            (
97                                ExclusivePipeline::Compute(e_pipeline),
98                                ExclusivePipeline::Compute(a_pipeline),
99                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
100                            (expected, assigned) => {
101                                return Err(Error::Incompatible {
102                                    expected_bgl: expected_bgl.error_ident(),
103                                    assigned_bgl: assigned_bgl.error_ident(),
104                                    inner: MultiError::new(core::iter::once(
105                                        IncompatibleExclusivePipelines {
106                                            expected: expected.to_string(),
107                                            assigned: assigned.to_string(),
108                                        },
109                                    ))
110                                    .unwrap(),
111                                });
112                            }
113                        }
114
115                        #[derive(Clone, Debug, Error)]
116                        enum EntryError {
117                            #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
118                            Visibility {
119                                binding: u32,
120                                expected: ShaderStages,
121                                assigned: ShaderStages,
122                            },
123                            #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
124                            Type {
125                                binding: u32,
126                                expected: BindingType,
127                                assigned: BindingType,
128                            },
129                            #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
130                            Count {
131                                binding: u32,
132                                expected: Option<NonZeroU32>,
133                                assigned: Option<NonZeroU32>,
134                            },
135                            #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
136                            ExtraExpected { binding: u32 },
137                            #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
138                            ExtraAssigned { binding: u32 },
139                        }
140
141                        let mut errors = Vec::new();
142
143                        for (&binding, expected_entry) in expected_bgl.entries.iter() {
144                            if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
145                                if assigned_entry.visibility != expected_entry.visibility {
146                                    errors.push(EntryError::Visibility {
147                                        binding,
148                                        expected: expected_entry.visibility,
149                                        assigned: assigned_entry.visibility,
150                                    });
151                                }
152                                if assigned_entry.ty != expected_entry.ty {
153                                    errors.push(EntryError::Type {
154                                        binding,
155                                        expected: expected_entry.ty,
156                                        assigned: assigned_entry.ty,
157                                    });
158                                }
159                                if assigned_entry.count != expected_entry.count {
160                                    errors.push(EntryError::Count {
161                                        binding,
162                                        expected: expected_entry.count,
163                                        assigned: assigned_entry.count,
164                                    });
165                                }
166                            } else {
167                                errors.push(EntryError::ExtraExpected { binding });
168                            }
169                        }
170
171                        for (&binding, _) in assigned_bgl.entries.iter() {
172                            if !expected_bgl.entries.contains_key(binding) {
173                                errors.push(EntryError::ExtraAssigned { binding });
174                            }
175                        }
176
177                        Err(Error::Incompatible {
178                            expected_bgl: expected_bgl.error_ident(),
179                            assigned_bgl: assigned_bgl.error_ident(),
180                            inner: MultiError::new(errors.drain(..)).unwrap(),
181                        })
182                    }
183                } else {
184                    Err(Error::Missing)
185                }
186            } else {
187                Ok(())
188            }
189        }
190    }
191
192    #[derive(Debug)]
193    pub(super) struct BoundBindGroupLayouts {
194        entries: [Entry; hal::MAX_BIND_GROUPS],
195        rebind_start: usize,
196    }
197
198    impl BoundBindGroupLayouts {
199        pub fn new() -> Self {
200            Self {
201                entries: [const { Entry::empty() }; hal::MAX_BIND_GROUPS],
202                rebind_start: 0,
203            }
204        }
205
206        /// Takes the start index of the bind group range to be rebound, and clears it.
207        pub fn take_rebind_start_index(&mut self) -> usize {
208            let start = self.rebind_start;
209            self.rebind_start = self.entries.len();
210            start
211        }
212
213        pub fn update_rebind_start_index(&mut self, start_index: usize) {
214            self.rebind_start = self.rebind_start.min(start_index);
215        }
216
217        pub fn update_expectations(&mut self, expectations: &[Option<Arc<BindGroupLayout>>]) {
218            let mut rebind_start_index = None;
219
220            for (i, (e, new_expected_bgl)) in self
221                .entries
222                .iter_mut()
223                .zip(expectations.iter().chain(core::iter::repeat(&None)))
224                .enumerate()
225            {
226                let (must_set, must_rebind) = match (&mut e.expected, new_expected_bgl) {
227                    (None, None) => (false, false),
228                    (None, Some(_)) => (true, true),
229                    (Some(_), None) => (true, false),
230                    (Some(old_expected_bgl), Some(new_expected_bgl)) => {
231                        let is_different = !old_expected_bgl.is_equal(new_expected_bgl);
232                        (is_different, is_different)
233                    }
234                };
235                if must_set {
236                    e.expected = new_expected_bgl.clone();
237                }
238                if must_rebind && rebind_start_index.is_none() {
239                    rebind_start_index = Some(i);
240                }
241            }
242
243            if let Some(rebind_start_index) = rebind_start_index {
244                self.update_rebind_start_index(rebind_start_index);
245            }
246        }
247
248        pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) {
249            self.entries[index].assigned = Some(value);
250            self.update_rebind_start_index(index);
251        }
252
253        pub fn clear(&mut self, index: usize) {
254            self.entries[index].assigned = None;
255        }
256
257        pub fn list_assigned(&self) -> impl Iterator<Item = usize> + '_ {
258            self.entries
259                .iter()
260                .enumerate()
261                .filter_map(|(i, e)| if e.is_assigned() { Some(i) } else { None })
262        }
263
264        pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
265            self.entries
266                .iter()
267                .enumerate()
268                .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
269        }
270
271        pub fn list_valid(&self) -> impl Iterator<Item = usize> + '_ {
272            self.entries
273                .iter()
274                .enumerate()
275                .filter_map(|(i, e)| if e.is_valid() { Some(i) } else { None })
276        }
277
278        #[allow(clippy::result_large_err)]
279        pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
280            for (index, entry) in self.entries.iter().enumerate() {
281                entry.check().map_err(|e| (index, e))?;
282            }
283            Ok(())
284        }
285    }
286}
287
288#[derive(Clone, Debug, Error)]
289pub enum BinderError {
290    #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
291    MissingBindGroup {
292        index: usize,
293        pipeline: ResourceErrorIdent,
294    },
295    #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
296    IncompatibleBindGroup {
297        expected_bgl: ResourceErrorIdent,
298        assigned_bgl: ResourceErrorIdent,
299        assigned_bg: ResourceErrorIdent,
300        index: usize,
301        pipeline: ResourceErrorIdent,
302        #[source]
303        inner: crate::error::MultiError,
304    },
305}
306
307#[derive(Debug)]
308struct LateBufferBinding {
309    binding_index: u32,
310    shader_expect_size: wgt::BufferAddress,
311    bound_size: wgt::BufferAddress,
312}
313
314#[derive(Debug, Default)]
315struct EntryPayload {
316    group: Option<Arc<BindGroup>>,
317    dynamic_offsets: Vec<wgt::DynamicOffset>,
318    late_buffer_bindings: Vec<LateBufferBinding>,
319    /// Since `LateBufferBinding` may contain information about the bindings
320    /// not used by the pipeline, we need to know when to stop validating.
321    late_bindings_effective_count: usize,
322}
323
324impl EntryPayload {
325    fn reset(&mut self) {
326        self.group = None;
327        self.dynamic_offsets.clear();
328        self.late_buffer_bindings.clear();
329        self.late_bindings_effective_count = 0;
330    }
331}
332
333#[derive(Debug)]
334pub(super) struct Binder {
335    pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
336    manager: compat::BoundBindGroupLayouts,
337    payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
338}
339
340impl Binder {
341    pub(super) fn new() -> Self {
342        Self {
343            pipeline_layout: None,
344            manager: compat::BoundBindGroupLayouts::new(),
345            payloads: Default::default(),
346        }
347    }
348    pub(super) fn reset(&mut self) {
349        self.pipeline_layout = None;
350        self.manager = compat::BoundBindGroupLayouts::new();
351        for payload in self.payloads.iter_mut() {
352            payload.reset();
353        }
354    }
355
356    /// Returns `true` if the pipeline layout has been changed, i.e. if the
357    /// new PL was not the same as the old PL.
358    pub(super) fn change_pipeline_layout<'a>(
359        &'a mut self,
360        new: &Arc<PipelineLayout>,
361        late_sized_buffer_groups: &[LateSizedBufferGroup],
362    ) -> bool {
363        self.update_late_buffer_bindings(late_sized_buffer_groups);
364
365        if let Some(old) = self.pipeline_layout.as_ref() {
366            if old.is_equal(new) {
367                return false;
368            }
369        }
370
371        let old = self.pipeline_layout.replace(new.clone());
372
373        self.manager.update_expectations(&new.bind_group_layouts);
374
375        if let Some(old) = old {
376            // root constants are the base compatibility property
377            if old.immediate_size != new.immediate_size {
378                self.manager.update_rebind_start_index(0);
379            }
380        }
381
382        true
383    }
384
385    pub(super) fn assign_group<'a>(
386        &'a mut self,
387        index: usize,
388        bind_group: &Arc<BindGroup>,
389        offsets: &[wgt::DynamicOffset],
390    ) {
391        let payload = &mut self.payloads[index];
392        payload.group = Some(bind_group.clone());
393        payload.dynamic_offsets.clear();
394        payload.dynamic_offsets.extend_from_slice(offsets);
395
396        // Fill out the actual binding sizes for buffers,
397        // whose layout doesn't specify `min_binding_size`.
398
399        // Update entries that already exist as the pipeline was bound before the group
400        // was bound.
401        for (late_binding, late_info) in payload
402            .late_buffer_bindings
403            .iter_mut()
404            .zip(bind_group.late_buffer_binding_infos.iter())
405        {
406            late_binding.binding_index = late_info.binding_index;
407            late_binding.bound_size = late_info.size.get();
408        }
409
410        // Add new entries for the bindings that were not known when the pipeline was bound.
411        if bind_group.late_buffer_binding_infos.len() > payload.late_buffer_bindings.len() {
412            for late_info in
413                bind_group.late_buffer_binding_infos[payload.late_buffer_bindings.len()..].iter()
414            {
415                payload.late_buffer_bindings.push(LateBufferBinding {
416                    binding_index: late_info.binding_index,
417                    shader_expect_size: 0,
418                    bound_size: late_info.size.get(),
419                });
420            }
421        }
422
423        self.manager.assign(index, bind_group.layout.clone());
424    }
425
426    pub(super) fn clear_group(&mut self, index: usize) {
427        self.payloads[index].reset();
428        self.manager.clear(index);
429    }
430
431    /// Takes the start index of the bind group range to be rebound, and clears it.
432    pub(super) fn take_rebind_start_index(&mut self) -> usize {
433        self.manager.take_rebind_start_index()
434    }
435
436    pub(super) fn list_valid_with_start(
437        &self,
438        start: usize,
439    ) -> impl Iterator<Item = (usize, &Arc<BindGroup>, &[wgt::DynamicOffset])> + '_ {
440        let payloads = &self.payloads;
441        self.manager
442            .list_valid()
443            .filter(move |i| *i >= start)
444            .map(move |index| {
445                (
446                    index,
447                    payloads[index].group.as_ref().unwrap(),
448                    payloads[index].dynamic_offsets.as_slice(),
449                )
450            })
451    }
452
453    pub(super) fn last_assigned_index(&self) -> Option<usize> {
454        self.manager.list_assigned().last()
455    }
456
457    pub(super) fn list_active(&self) -> impl Iterator<Item = &Arc<BindGroup>> + '_ {
458        let payloads = &self.payloads;
459        self.manager
460            .list_active()
461            .map(move |index| payloads[index].group.as_ref().unwrap())
462    }
463
464    pub(super) fn list_valid(
465        &self,
466    ) -> impl Iterator<Item = (usize, &Arc<BindGroup>, &[wgt::DynamicOffset])> + '_ {
467        self.list_valid_with_start(0)
468    }
469
470    pub(super) fn check_compatibility<T: Labeled>(
471        &self,
472        pipeline: &T,
473    ) -> Result<(), Box<BinderError>> {
474        self.manager.get_invalid().map_err(|(index, error)| {
475            Box::new(match error {
476                compat::Error::Incompatible {
477                    expected_bgl,
478                    assigned_bgl,
479                    inner,
480                } => BinderError::IncompatibleBindGroup {
481                    expected_bgl,
482                    assigned_bgl,
483                    assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
484                    index,
485                    pipeline: pipeline.error_ident(),
486                    inner,
487                },
488                compat::Error::Missing => BinderError::MissingBindGroup {
489                    index,
490                    pipeline: pipeline.error_ident(),
491                },
492            })
493        })
494    }
495
496    /// Scan active buffer bindings corresponding to layouts without `min_binding_size` specified.
497    pub(super) fn check_late_buffer_bindings(
498        &self,
499    ) -> Result<(), LateMinBufferBindingSizeMismatch> {
500        for group_index in self.manager.list_active() {
501            let payload = &self.payloads[group_index];
502            for late_binding in
503                &payload.late_buffer_bindings[..payload.late_bindings_effective_count]
504            {
505                if late_binding.bound_size < late_binding.shader_expect_size {
506                    return Err(LateMinBufferBindingSizeMismatch {
507                        group_index: group_index as u32,
508                        binding_index: late_binding.binding_index,
509                        shader_size: late_binding.shader_expect_size,
510                        bound_size: late_binding.bound_size,
511                    });
512                }
513            }
514        }
515        Ok(())
516    }
517
518    /// This must be called even when a new pipeline has the same layout
519    /// as the previous one, because different pipelines can have different
520    /// shader-expected buffer sizes even with identical layouts.
521    fn update_late_buffer_bindings(&mut self, late_sized_buffer_groups: &[LateSizedBufferGroup]) {
522        for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
523            payload.late_bindings_effective_count = late_group.shader_sizes.len();
524
525            // Update entries that already exist as the bind group was bound before the pipeline
526            // was bound.
527            for (late_binding, &shader_expect_size) in payload
528                .late_buffer_bindings
529                .iter_mut()
530                .zip(late_group.shader_sizes.iter())
531            {
532                late_binding.shader_expect_size = shader_expect_size;
533            }
534
535            // Add new entries for the bindings that were not known when the bind group was bound.
536            if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
537                for &shader_expect_size in
538                    late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
539                {
540                    payload.late_buffer_bindings.push(LateBufferBinding {
541                        binding_index: 0,
542                        shader_expect_size,
543                        bound_size: 0,
544                    });
545                }
546            }
547        }
548    }
549}