wgpu_core/command/
bind.rs

1use core::{iter::zip, ops::Range};
2
3use alloc::{boxed::Box, sync::Arc, vec::Vec};
4
5use thiserror::Error;
6
7use crate::{
8    binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
9    pipeline::LateSizedBufferGroup,
10    resource::{Labeled, ParentDevice, ResourceErrorIdent},
11};
12
13mod compat {
14    use alloc::{
15        string::{String, ToString as _},
16        sync::{Arc, Weak},
17        vec::Vec,
18    };
19    use core::{num::NonZeroU32, ops::Range};
20
21    use arrayvec::ArrayVec;
22    use thiserror::Error;
23    use wgt::{BindingType, ShaderStages};
24
25    use crate::{
26        binding_model::BindGroupLayout,
27        error::MultiError,
28        resource::{Labeled, ParentDevice, ResourceErrorIdent},
29    };
30
31    pub(crate) enum Error {
32        Incompatible {
33            expected_bgl: ResourceErrorIdent,
34            assigned_bgl: ResourceErrorIdent,
35            inner: MultiError,
36        },
37        Missing,
38    }
39
40    #[derive(Debug, Clone)]
41    struct Entry {
42        assigned: Option<Arc<BindGroupLayout>>,
43        expected: Option<Arc<BindGroupLayout>>,
44    }
45
46    impl Entry {
47        fn empty() -> Self {
48            Self {
49                assigned: None,
50                expected: None,
51            }
52        }
53        fn is_active(&self) -> bool {
54            self.assigned.is_some() && self.expected.is_some()
55        }
56
57        fn is_valid(&self) -> bool {
58            if let Some(expected_bgl) = self.expected.as_ref() {
59                if let Some(assigned_bgl) = self.assigned.as_ref() {
60                    expected_bgl.is_equal(assigned_bgl)
61                } else {
62                    false
63                }
64            } else {
65                true
66            }
67        }
68
69        fn is_incompatible(&self) -> bool {
70            self.expected.is_none() || !self.is_valid()
71        }
72
73        fn check(&self) -> Result<(), Error> {
74            if let Some(expected_bgl) = self.expected.as_ref() {
75                if let Some(assigned_bgl) = self.assigned.as_ref() {
76                    if expected_bgl.is_equal(assigned_bgl) {
77                        Ok(())
78                    } else {
79                        #[derive(Clone, Debug, Error)]
80                        #[error(
81                            "Exclusive pipelines don't match: expected {expected}, got {assigned}"
82                        )]
83                        struct IncompatibleExclusivePipelines {
84                            expected: String,
85                            assigned: String,
86                        }
87
88                        use crate::binding_model::ExclusivePipeline;
89                        match (
90                            expected_bgl.exclusive_pipeline.get().unwrap(),
91                            assigned_bgl.exclusive_pipeline.get().unwrap(),
92                        ) {
93                            (ExclusivePipeline::None, ExclusivePipeline::None) => {}
94                            (
95                                ExclusivePipeline::Render(e_pipeline),
96                                ExclusivePipeline::Render(a_pipeline),
97                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
98                            (
99                                ExclusivePipeline::Compute(e_pipeline),
100                                ExclusivePipeline::Compute(a_pipeline),
101                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
102                            (expected, assigned) => {
103                                return Err(Error::Incompatible {
104                                    expected_bgl: expected_bgl.error_ident(),
105                                    assigned_bgl: assigned_bgl.error_ident(),
106                                    inner: MultiError::new(core::iter::once(
107                                        IncompatibleExclusivePipelines {
108                                            expected: expected.to_string(),
109                                            assigned: assigned.to_string(),
110                                        },
111                                    ))
112                                    .unwrap(),
113                                });
114                            }
115                        }
116
117                        #[derive(Clone, Debug, Error)]
118                        enum EntryError {
119                            #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
120                            Visibility {
121                                binding: u32,
122                                expected: ShaderStages,
123                                assigned: ShaderStages,
124                            },
125                            #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
126                            Type {
127                                binding: u32,
128                                expected: BindingType,
129                                assigned: BindingType,
130                            },
131                            #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
132                            Count {
133                                binding: u32,
134                                expected: Option<NonZeroU32>,
135                                assigned: Option<NonZeroU32>,
136                            },
137                            #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
138                            ExtraExpected { binding: u32 },
139                            #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
140                            ExtraAssigned { binding: u32 },
141                        }
142
143                        let mut errors = Vec::new();
144
145                        for (&binding, expected_entry) in expected_bgl.entries.iter() {
146                            if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
147                                if assigned_entry.visibility != expected_entry.visibility {
148                                    errors.push(EntryError::Visibility {
149                                        binding,
150                                        expected: expected_entry.visibility,
151                                        assigned: assigned_entry.visibility,
152                                    });
153                                }
154                                if assigned_entry.ty != expected_entry.ty {
155                                    errors.push(EntryError::Type {
156                                        binding,
157                                        expected: expected_entry.ty,
158                                        assigned: assigned_entry.ty,
159                                    });
160                                }
161                                if assigned_entry.count != expected_entry.count {
162                                    errors.push(EntryError::Count {
163                                        binding,
164                                        expected: expected_entry.count,
165                                        assigned: assigned_entry.count,
166                                    });
167                                }
168                            } else {
169                                errors.push(EntryError::ExtraExpected { binding });
170                            }
171                        }
172
173                        for (&binding, _) in assigned_bgl.entries.iter() {
174                            if !expected_bgl.entries.contains_key(binding) {
175                                errors.push(EntryError::ExtraAssigned { binding });
176                            }
177                        }
178
179                        Err(Error::Incompatible {
180                            expected_bgl: expected_bgl.error_ident(),
181                            assigned_bgl: assigned_bgl.error_ident(),
182                            inner: MultiError::new(errors.drain(..)).unwrap(),
183                        })
184                    }
185                } else {
186                    Err(Error::Missing)
187                }
188            } else {
189                Ok(())
190            }
191        }
192    }
193
194    #[derive(Debug, Default)]
195    pub(super) struct BoundBindGroupLayouts {
196        entries: ArrayVec<Entry, { hal::MAX_BIND_GROUPS }>,
197        rebind_start: usize,
198    }
199
200    impl BoundBindGroupLayouts {
201        pub fn new() -> Self {
202            Self {
203                entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
204                rebind_start: 0,
205            }
206        }
207
208        pub fn num_valid_entries(&self) -> usize {
209            // find first incompatible entry
210            self.entries
211                .iter()
212                .position(|e| e.is_incompatible())
213                .unwrap_or(self.entries.len())
214        }
215
216        /// Get the range of entries that needs to be rebound, and clears it.
217        pub fn take_rebind_range(&mut self) -> Range<usize> {
218            let end = self.num_valid_entries();
219            let start = self.rebind_start;
220            self.rebind_start = end;
221            start..end.max(start)
222        }
223
224        pub fn update_start_index(&mut self, start_index: usize) {
225            self.rebind_start = self.rebind_start.min(start_index);
226        }
227
228        pub fn update_expectations(&mut self, expectations: &[Arc<BindGroupLayout>]) {
229            let start_index = self
230                .entries
231                .iter()
232                .zip(expectations)
233                .position(|(e, expect)| {
234                    e.expected.is_none() || !e.expected.as_ref().unwrap().is_equal(expect)
235                })
236                .unwrap_or(expectations.len());
237            for (e, expect) in self.entries[start_index..]
238                .iter_mut()
239                .zip(expectations[start_index..].iter())
240            {
241                e.expected = Some(expect.clone());
242            }
243            for e in self.entries[expectations.len()..].iter_mut() {
244                e.expected = None;
245            }
246            self.update_start_index(start_index);
247        }
248
249        pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) {
250            self.entries[index].assigned = Some(value);
251            self.update_start_index(index);
252        }
253
254        pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
255            self.entries
256                .iter()
257                .enumerate()
258                .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
259        }
260
261        #[allow(clippy::result_large_err)]
262        pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
263            for (index, entry) in self.entries.iter().enumerate() {
264                entry.check().map_err(|e| (index, e))?;
265            }
266            Ok(())
267        }
268    }
269}
270
/// Errors reported when the bind groups that are currently set do not satisfy
/// the bind group layouts expected by a pipeline.
#[derive(Clone, Debug, Error)]
pub enum BinderError {
    /// The pipeline's layout expects a bind group at `index`, but none is set there.
    #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
    MissingBindGroup {
        index: usize,
        pipeline: ResourceErrorIdent,
    },
    /// The bind group set at `index` has a layout (`assigned_bgl`) that is not
    /// the layout the pipeline expects there (`expected_bgl`).
    #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
    IncompatibleBindGroup {
        expected_bgl: ResourceErrorIdent,
        assigned_bgl: ResourceErrorIdent,
        assigned_bg: ResourceErrorIdent,
        index: usize,
        pipeline: ResourceErrorIdent,
        /// The individual differences found between the two layouts.
        #[source]
        inner: crate::error::MultiError,
    },
}
289
/// Size bookkeeping for one buffer binding whose layout does not specify
/// `min_binding_size`. Such bindings are checked later by
/// `Binder::check_late_buffer_bindings`.
#[derive(Debug)]
struct LateBufferBinding {
    /// Binding number inside the bind group. This is 0 for a placeholder
    /// created by a pipeline before any bind group filled it in.
    binding_index: u32,
    /// Size the current pipeline's shaders require. This is 0 until a
    /// pipeline provides it.
    shader_expect_size: wgt::BufferAddress,
    /// Size actually bound by the bind group. This is 0 until a bind group
    /// provides it.
    bound_size: wgt::BufferAddress,
}
296
/// Everything the binder remembers about one bind group slot.
#[derive(Debug, Default)]
pub(super) struct EntryPayload {
    /// The bind group assigned to this slot, if any.
    pub(super) group: Option<Arc<BindGroup>>,
    /// Dynamic offsets passed together with `group`.
    pub(super) dynamic_offsets: Vec<wgt::DynamicOffset>,
    /// Bound vs. shader-required sizes for buffer bindings without `min_binding_size`.
    late_buffer_bindings: Vec<LateBufferBinding>,
    /// Since `LateBufferBinding` may contain information about the bindings
    /// not used by the pipeline, we need to know when to stop validating.
    pub(super) late_bindings_effective_count: usize,
}
306
307impl EntryPayload {
308    fn reset(&mut self) {
309        self.group = None;
310        self.dynamic_offsets.clear();
311        self.late_buffer_bindings.clear();
312        self.late_bindings_effective_count = 0;
313    }
314}
315
/// Tracks the current pipeline layout and the bind groups assigned to each
/// slot, and validates them against each other.
#[derive(Debug, Default)]
pub(super) struct Binder {
    /// The layout of the most recently set pipeline, if any.
    pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
    /// Expected vs. assigned bind group layouts per slot, plus the rebind range.
    manager: compat::BoundBindGroupLayouts,
    /// Per-slot bind group, dynamic offsets and late buffer size info.
    payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
}
322
323impl Binder {
324    pub(super) fn new() -> Self {
325        Self {
326            pipeline_layout: None,
327            manager: compat::BoundBindGroupLayouts::new(),
328            payloads: Default::default(),
329        }
330    }
331    pub(super) fn reset(&mut self) {
332        self.pipeline_layout = None;
333        self.manager = compat::BoundBindGroupLayouts::new();
334        for payload in self.payloads.iter_mut() {
335            payload.reset();
336        }
337    }
338
339    /// Returns `true` if the pipeline layout has been changed, i.e. if the
340    /// new PL was not the same as the old PL.
341    pub(super) fn change_pipeline_layout<'a>(
342        &'a mut self,
343        new: &Arc<PipelineLayout>,
344        late_sized_buffer_groups: &[LateSizedBufferGroup],
345    ) -> bool {
346        if let Some(old) = self.pipeline_layout.as_ref() {
347            if old.is_equal(new) {
348                return false;
349            }
350        }
351
352        let old = self.pipeline_layout.replace(new.clone());
353
354        self.manager.update_expectations(&new.bind_group_layouts);
355
356        // Update the buffer binding sizes that are required by shaders.
357
358        for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
359            payload.late_bindings_effective_count = late_group.shader_sizes.len();
360            // Update entries that already exist as the bind group was bound before the pipeline
361            // was bound.
362            for (late_binding, &shader_expect_size) in payload
363                .late_buffer_bindings
364                .iter_mut()
365                .zip(late_group.shader_sizes.iter())
366            {
367                late_binding.shader_expect_size = shader_expect_size;
368            }
369            // Add new entries for the bindings that were not known when the bind group was bound.
370            if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
371                for &shader_expect_size in
372                    late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
373                {
374                    payload.late_buffer_bindings.push(LateBufferBinding {
375                        binding_index: 0,
376                        shader_expect_size,
377                        bound_size: 0,
378                    });
379                }
380            }
381        }
382
383        if let Some(old) = old {
384            // root constants are the base compatibility property
385            if old.immediate_size != new.immediate_size {
386                self.manager.update_start_index(0);
387            }
388        }
389
390        true
391    }
392
393    pub(super) fn assign_group<'a>(
394        &'a mut self,
395        index: usize,
396        bind_group: &Arc<BindGroup>,
397        offsets: &[wgt::DynamicOffset],
398    ) {
399        let payload = &mut self.payloads[index];
400        payload.group = Some(bind_group.clone());
401        payload.dynamic_offsets.clear();
402        payload.dynamic_offsets.extend_from_slice(offsets);
403
404        // Fill out the actual binding sizes for buffers,
405        // whose layout doesn't specify `min_binding_size`.
406
407        // Update entries that already exist as the pipeline was bound before the group
408        // was bound.
409        for (late_binding, late_info) in payload
410            .late_buffer_bindings
411            .iter_mut()
412            .zip(bind_group.late_buffer_binding_infos.iter())
413        {
414            late_binding.binding_index = late_info.binding_index;
415            late_binding.bound_size = late_info.size.get();
416        }
417
418        // Add new entries for the bindings that were not known when the pipeline was bound.
419        if bind_group.late_buffer_binding_infos.len() > payload.late_buffer_bindings.len() {
420            for late_info in
421                bind_group.late_buffer_binding_infos[payload.late_buffer_bindings.len()..].iter()
422            {
423                payload.late_buffer_bindings.push(LateBufferBinding {
424                    binding_index: late_info.binding_index,
425                    shader_expect_size: 0,
426                    bound_size: late_info.size.get(),
427                });
428            }
429        }
430
431        self.manager.assign(index, bind_group.layout.clone());
432    }
433
434    /// Get the range of entries that needs to be rebound, and clears it.
435    pub(super) fn take_rebind_range(&mut self) -> Range<usize> {
436        self.manager.take_rebind_range()
437    }
438
439    pub(super) fn entries(
440        &self,
441        range: Range<usize>,
442    ) -> impl ExactSizeIterator<Item = (usize, &'_ EntryPayload)> + Clone + '_ {
443        let payloads = &self.payloads[range.clone()];
444        zip(range, payloads)
445    }
446
447    pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup>> + 'a {
448        let payloads = &self.payloads;
449        self.manager
450            .list_active()
451            .map(move |index| payloads[index].group.as_ref().unwrap())
452    }
453
454    pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + 'a {
455        self.payloads
456            .iter()
457            .take(self.manager.num_valid_entries())
458            .enumerate()
459    }
460
461    pub(super) fn check_compatibility<T: Labeled>(
462        &self,
463        pipeline: &T,
464    ) -> Result<(), Box<BinderError>> {
465        self.manager.get_invalid().map_err(|(index, error)| {
466            Box::new(match error {
467                compat::Error::Incompatible {
468                    expected_bgl,
469                    assigned_bgl,
470                    inner,
471                } => BinderError::IncompatibleBindGroup {
472                    expected_bgl,
473                    assigned_bgl,
474                    assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
475                    index,
476                    pipeline: pipeline.error_ident(),
477                    inner,
478                },
479                compat::Error::Missing => BinderError::MissingBindGroup {
480                    index,
481                    pipeline: pipeline.error_ident(),
482                },
483            })
484        })
485    }
486
487    /// Scan active buffer bindings corresponding to layouts without `min_binding_size` specified.
488    pub(super) fn check_late_buffer_bindings(
489        &self,
490    ) -> Result<(), LateMinBufferBindingSizeMismatch> {
491        for group_index in self.manager.list_active() {
492            let payload = &self.payloads[group_index];
493            for late_binding in
494                &payload.late_buffer_bindings[..payload.late_bindings_effective_count]
495            {
496                if late_binding.bound_size < late_binding.shader_expect_size {
497                    return Err(LateMinBufferBindingSizeMismatch {
498                        group_index: group_index as u32,
499                        binding_index: late_binding.binding_index,
500                        shader_size: late_binding.shader_expect_size,
501                        bound_size: late_binding.bound_size,
502                    });
503                }
504            }
505        }
506        Ok(())
507    }
508}