wgpu_core/command/
bind.rs

1use core::{iter::zip, ops::Range};
2
3use alloc::{boxed::Box, sync::Arc, vec::Vec};
4
5use arrayvec::ArrayVec;
6use thiserror::Error;
7
8use crate::{
9    binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
10    device::SHADER_STAGE_COUNT,
11    pipeline::LateSizedBufferGroup,
12    resource::{Labeled, ResourceErrorIdent},
13};
14
15mod compat {
16    use alloc::{
17        string::{String, ToString as _},
18        sync::{Arc, Weak},
19        vec::Vec,
20    };
21    use core::{num::NonZeroU32, ops::Range};
22
23    use arrayvec::ArrayVec;
24    use thiserror::Error;
25    use wgt::{BindingType, ShaderStages};
26
27    use crate::{
28        binding_model::BindGroupLayout,
29        error::MultiError,
30        resource::{Labeled, ParentDevice, ResourceErrorIdent},
31    };
32
33    pub(crate) enum Error {
34        Incompatible {
35            expected_bgl: ResourceErrorIdent,
36            assigned_bgl: ResourceErrorIdent,
37            inner: MultiError,
38        },
39        Missing,
40    }
41
42    #[derive(Debug, Clone)]
43    struct Entry {
44        assigned: Option<Arc<BindGroupLayout>>,
45        expected: Option<Arc<BindGroupLayout>>,
46    }
47
48    impl Entry {
49        fn empty() -> Self {
50            Self {
51                assigned: None,
52                expected: None,
53            }
54        }
55        fn is_active(&self) -> bool {
56            self.assigned.is_some() && self.expected.is_some()
57        }
58
59        fn is_valid(&self) -> bool {
60            if let Some(expected_bgl) = self.expected.as_ref() {
61                if let Some(assigned_bgl) = self.assigned.as_ref() {
62                    expected_bgl.is_equal(assigned_bgl)
63                } else {
64                    false
65                }
66            } else {
67                true
68            }
69        }
70
71        fn is_incompatible(&self) -> bool {
72            self.expected.is_none() || !self.is_valid()
73        }
74
75        fn check(&self) -> Result<(), Error> {
76            if let Some(expected_bgl) = self.expected.as_ref() {
77                if let Some(assigned_bgl) = self.assigned.as_ref() {
78                    if expected_bgl.is_equal(assigned_bgl) {
79                        Ok(())
80                    } else {
81                        #[derive(Clone, Debug, Error)]
82                        #[error(
83                            "Exclusive pipelines don't match: expected {expected}, got {assigned}"
84                        )]
85                        struct IncompatibleExclusivePipelines {
86                            expected: String,
87                            assigned: String,
88                        }
89
90                        use crate::binding_model::ExclusivePipeline;
91                        match (
92                            expected_bgl.exclusive_pipeline.get().unwrap(),
93                            assigned_bgl.exclusive_pipeline.get().unwrap(),
94                        ) {
95                            (ExclusivePipeline::None, ExclusivePipeline::None) => {}
96                            (
97                                ExclusivePipeline::Render(e_pipeline),
98                                ExclusivePipeline::Render(a_pipeline),
99                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
100                            (
101                                ExclusivePipeline::Compute(e_pipeline),
102                                ExclusivePipeline::Compute(a_pipeline),
103                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
104                            (expected, assigned) => {
105                                return Err(Error::Incompatible {
106                                    expected_bgl: expected_bgl.error_ident(),
107                                    assigned_bgl: assigned_bgl.error_ident(),
108                                    inner: MultiError::new(core::iter::once(
109                                        IncompatibleExclusivePipelines {
110                                            expected: expected.to_string(),
111                                            assigned: assigned.to_string(),
112                                        },
113                                    ))
114                                    .unwrap(),
115                                });
116                            }
117                        }
118
119                        #[derive(Clone, Debug, Error)]
120                        enum EntryError {
121                            #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
122                            Visibility {
123                                binding: u32,
124                                expected: ShaderStages,
125                                assigned: ShaderStages,
126                            },
127                            #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
128                            Type {
129                                binding: u32,
130                                expected: BindingType,
131                                assigned: BindingType,
132                            },
133                            #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
134                            Count {
135                                binding: u32,
136                                expected: Option<NonZeroU32>,
137                                assigned: Option<NonZeroU32>,
138                            },
139                            #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
140                            ExtraExpected { binding: u32 },
141                            #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
142                            ExtraAssigned { binding: u32 },
143                        }
144
145                        let mut errors = Vec::new();
146
147                        for (&binding, expected_entry) in expected_bgl.entries.iter() {
148                            if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
149                                if assigned_entry.visibility != expected_entry.visibility {
150                                    errors.push(EntryError::Visibility {
151                                        binding,
152                                        expected: expected_entry.visibility,
153                                        assigned: assigned_entry.visibility,
154                                    });
155                                }
156                                if assigned_entry.ty != expected_entry.ty {
157                                    errors.push(EntryError::Type {
158                                        binding,
159                                        expected: expected_entry.ty,
160                                        assigned: assigned_entry.ty,
161                                    });
162                                }
163                                if assigned_entry.count != expected_entry.count {
164                                    errors.push(EntryError::Count {
165                                        binding,
166                                        expected: expected_entry.count,
167                                        assigned: assigned_entry.count,
168                                    });
169                                }
170                            } else {
171                                errors.push(EntryError::ExtraExpected { binding });
172                            }
173                        }
174
175                        for (&binding, _) in assigned_bgl.entries.iter() {
176                            if !expected_bgl.entries.contains_key(binding) {
177                                errors.push(EntryError::ExtraAssigned { binding });
178                            }
179                        }
180
181                        Err(Error::Incompatible {
182                            expected_bgl: expected_bgl.error_ident(),
183                            assigned_bgl: assigned_bgl.error_ident(),
184                            inner: MultiError::new(errors.drain(..)).unwrap(),
185                        })
186                    }
187                } else {
188                    Err(Error::Missing)
189                }
190            } else {
191                Ok(())
192            }
193        }
194    }
195
196    #[derive(Debug, Default)]
197    pub(super) struct BoundBindGroupLayouts {
198        entries: ArrayVec<Entry, { hal::MAX_BIND_GROUPS }>,
199        rebind_start: usize,
200    }
201
202    impl BoundBindGroupLayouts {
203        pub fn new() -> Self {
204            Self {
205                entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
206                rebind_start: 0,
207            }
208        }
209
210        pub fn num_valid_entries(&self) -> usize {
211            // find first incompatible entry
212            self.entries
213                .iter()
214                .position(|e| e.is_incompatible())
215                .unwrap_or(self.entries.len())
216        }
217
218        /// Get the range of entries that needs to be rebound, and clears it.
219        pub fn take_rebind_range(&mut self) -> Range<usize> {
220            let end = self.num_valid_entries();
221            let start = self.rebind_start;
222            self.rebind_start = end;
223            start..end.max(start)
224        }
225
226        pub fn update_start_index(&mut self, start_index: usize) {
227            self.rebind_start = self.rebind_start.min(start_index);
228        }
229
230        pub fn update_expectations(&mut self, expectations: &[Arc<BindGroupLayout>]) {
231            let start_index = self
232                .entries
233                .iter()
234                .zip(expectations)
235                .position(|(e, expect)| {
236                    e.expected.is_none() || !e.expected.as_ref().unwrap().is_equal(expect)
237                })
238                .unwrap_or(expectations.len());
239            for (e, expect) in self.entries[start_index..]
240                .iter_mut()
241                .zip(expectations[start_index..].iter())
242            {
243                e.expected = Some(expect.clone());
244            }
245            for e in self.entries[expectations.len()..].iter_mut() {
246                e.expected = None;
247            }
248            self.update_start_index(start_index);
249        }
250
251        pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) {
252            self.entries[index].assigned = Some(value);
253            self.update_start_index(index);
254        }
255
256        pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
257            self.entries
258                .iter()
259                .enumerate()
260                .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
261        }
262
263        #[allow(clippy::result_large_err)]
264        pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
265            for (index, entry) in self.entries.iter().enumerate() {
266                entry.check().map_err(|e| (index, e))?;
267            }
268            Ok(())
269        }
270    }
271}
272
/// Errors reported when the bind groups currently set do not satisfy the
/// pipeline that is about to use them.
#[derive(Clone, Debug, Error)]
pub enum BinderError {
    /// The pipeline's layout expects a bind group at `index`, but none is set.
    #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
    MissingBindGroup {
        index: usize,
        pipeline: ResourceErrorIdent,
    },
    /// The bind group set at `index` has a layout that is not equal to the
    /// layout the pipeline expects there; `inner` lists the differences.
    #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
    IncompatibleBindGroup {
        expected_bgl: ResourceErrorIdent,
        assigned_bgl: ResourceErrorIdent,
        assigned_bg: ResourceErrorIdent,
        index: usize,
        pipeline: ResourceErrorIdent,
        #[source]
        inner: crate::error::MultiError,
    },
}
291
/// Size information for one buffer binding whose layout entry does not
/// specify `min_binding_size`; checked by `Binder::check_late_buffer_bindings`.
#[derive(Debug)]
struct LateBufferBinding {
    /// Minimum size required by the current pipeline's shader
    /// (0 when recorded before any pipeline provided it).
    shader_expect_size: wgt::BufferAddress,
    /// Size bound by the current bind group
    /// (0 when recorded before any bind group provided it).
    bound_size: wgt::BufferAddress,
}
297
/// Per-slot data of the bind group currently set at that slot.
#[derive(Debug, Default)]
pub(super) struct EntryPayload {
    /// The bind group set at this slot, if any.
    pub(super) group: Option<Arc<BindGroup>>,
    /// Dynamic offsets passed along with `group`.
    pub(super) dynamic_offsets: Vec<wgt::DynamicOffset>,
    /// Late-sized buffer bindings, indexed by their compact index.
    late_buffer_bindings: Vec<LateBufferBinding>,
    /// Since `LateBufferBinding` may contain information about the bindings
    /// not used by the pipeline, we need to know when to stop validating.
    pub(super) late_bindings_effective_count: usize,
}
307
308impl EntryPayload {
309    fn reset(&mut self) {
310        self.group = None;
311        self.dynamic_offsets.clear();
312        self.late_buffer_bindings.clear();
313        self.late_bindings_effective_count = 0;
314    }
315}
316
/// Tracks the current pipeline layout and the bind groups set at each slot,
/// and determines which slots must be (re)bound and whether they are usable.
#[derive(Debug, Default)]
pub(super) struct Binder {
    /// The pipeline layout most recently passed to `change_pipeline_layout`.
    pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
    /// Expected vs. assigned bind group layouts for every slot.
    manager: compat::BoundBindGroupLayouts,
    /// Per-slot bind group data, indexed by bind group index.
    payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
}
323
324impl Binder {
325    pub(super) fn new() -> Self {
326        Self {
327            pipeline_layout: None,
328            manager: compat::BoundBindGroupLayouts::new(),
329            payloads: Default::default(),
330        }
331    }
332    pub(super) fn reset(&mut self) {
333        self.pipeline_layout = None;
334        self.manager = compat::BoundBindGroupLayouts::new();
335        for payload in self.payloads.iter_mut() {
336            payload.reset();
337        }
338    }
339
340    pub(super) fn change_pipeline_layout<'a>(
341        &'a mut self,
342        new: &Arc<PipelineLayout>,
343        late_sized_buffer_groups: &[LateSizedBufferGroup],
344    ) {
345        let old_id_opt = self.pipeline_layout.replace(new.clone());
346
347        self.manager.update_expectations(&new.bind_group_layouts);
348
349        // Update the buffer binding sizes that are required by shaders.
350        for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
351            payload.late_bindings_effective_count = late_group.shader_sizes.len();
352            for (late_binding, &shader_expect_size) in payload
353                .late_buffer_bindings
354                .iter_mut()
355                .zip(late_group.shader_sizes.iter())
356            {
357                late_binding.shader_expect_size = shader_expect_size;
358            }
359            if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
360                for &shader_expect_size in
361                    late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
362                {
363                    payload.late_buffer_bindings.push(LateBufferBinding {
364                        shader_expect_size,
365                        bound_size: 0,
366                    });
367                }
368            }
369        }
370
371        if let Some(old) = old_id_opt {
372            // root constants are the base compatibility property
373            if old.push_constant_ranges != new.push_constant_ranges {
374                self.manager.update_start_index(0);
375            }
376        }
377    }
378
379    pub(super) fn assign_group<'a>(
380        &'a mut self,
381        index: usize,
382        bind_group: &Arc<BindGroup>,
383        offsets: &[wgt::DynamicOffset],
384    ) {
385        let payload = &mut self.payloads[index];
386        payload.group = Some(bind_group.clone());
387        payload.dynamic_offsets.clear();
388        payload.dynamic_offsets.extend_from_slice(offsets);
389
390        // Fill out the actual binding sizes for buffers,
391        // whose layout doesn't specify `min_binding_size`.
392        for (late_binding, late_size) in payload
393            .late_buffer_bindings
394            .iter_mut()
395            .zip(bind_group.late_buffer_binding_sizes.iter())
396        {
397            late_binding.bound_size = late_size.get();
398        }
399        if bind_group.late_buffer_binding_sizes.len() > payload.late_buffer_bindings.len() {
400            for late_size in
401                bind_group.late_buffer_binding_sizes[payload.late_buffer_bindings.len()..].iter()
402            {
403                payload.late_buffer_bindings.push(LateBufferBinding {
404                    shader_expect_size: 0,
405                    bound_size: late_size.get(),
406                });
407            }
408        }
409
410        self.manager.assign(index, bind_group.layout.clone());
411    }
412
413    /// Get the range of entries that needs to be rebound, and clears it.
414    pub(super) fn take_rebind_range(&mut self) -> Range<usize> {
415        self.manager.take_rebind_range()
416    }
417
418    pub(super) fn entries(
419        &self,
420        range: Range<usize>,
421    ) -> impl ExactSizeIterator<Item = (usize, &'_ EntryPayload)> + '_ {
422        let payloads = &self.payloads[range.clone()];
423        zip(range, payloads)
424    }
425
426    pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup>> + 'a {
427        let payloads = &self.payloads;
428        self.manager
429            .list_active()
430            .map(move |index| payloads[index].group.as_ref().unwrap())
431    }
432
433    pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + 'a {
434        self.payloads
435            .iter()
436            .take(self.manager.num_valid_entries())
437            .enumerate()
438    }
439
440    pub(super) fn check_compatibility<T: Labeled>(
441        &self,
442        pipeline: &T,
443    ) -> Result<(), Box<BinderError>> {
444        self.manager.get_invalid().map_err(|(index, error)| {
445            Box::new(match error {
446                compat::Error::Incompatible {
447                    expected_bgl,
448                    assigned_bgl,
449                    inner,
450                } => BinderError::IncompatibleBindGroup {
451                    expected_bgl,
452                    assigned_bgl,
453                    assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
454                    index,
455                    pipeline: pipeline.error_ident(),
456                    inner,
457                },
458                compat::Error::Missing => BinderError::MissingBindGroup {
459                    index,
460                    pipeline: pipeline.error_ident(),
461                },
462            })
463        })
464    }
465
466    /// Scan active buffer bindings corresponding to layouts without `min_binding_size` specified.
467    pub(super) fn check_late_buffer_bindings(
468        &self,
469    ) -> Result<(), LateMinBufferBindingSizeMismatch> {
470        for group_index in self.manager.list_active() {
471            let payload = &self.payloads[group_index];
472            for (compact_index, late_binding) in payload.late_buffer_bindings
473                [..payload.late_bindings_effective_count]
474                .iter()
475                .enumerate()
476            {
477                if late_binding.bound_size < late_binding.shader_expect_size {
478                    return Err(LateMinBufferBindingSizeMismatch {
479                        group_index: group_index as u32,
480                        compact_index,
481                        shader_size: late_binding.shader_expect_size,
482                        bound_size: late_binding.bound_size,
483                    });
484                }
485            }
486        }
487        Ok(())
488    }
489}
490
/// A boundary of a push constant range, used by `compute_nonoverlapping_ranges`.
struct PushConstantChange {
    /// Stages of the range this boundary belongs to.
    stages: wgt::ShaderStages,
    /// Byte offset of the boundary (range start or range end).
    offset: u32,
    /// `true` at a range start (stages switch on), `false` at its end (off).
    enable: bool,
}
496
497/// Break up possibly overlapping push constant ranges into a set of
498/// non-overlapping ranges which contain all the stage flags of the
499/// original ranges. This allows us to zero out (or write any value)
500/// to every possible value.
501pub fn compute_nonoverlapping_ranges(
502    ranges: &[wgt::PushConstantRange],
503) -> ArrayVec<wgt::PushConstantRange, { SHADER_STAGE_COUNT * 2 }> {
504    if ranges.is_empty() {
505        return ArrayVec::new();
506    }
507    debug_assert!(ranges.len() <= SHADER_STAGE_COUNT);
508
509    let mut breaks: ArrayVec<PushConstantChange, { SHADER_STAGE_COUNT * 2 }> = ArrayVec::new();
510    for range in ranges {
511        breaks.push(PushConstantChange {
512            stages: range.stages,
513            offset: range.range.start,
514            enable: true,
515        });
516        breaks.push(PushConstantChange {
517            stages: range.stages,
518            offset: range.range.end,
519            enable: false,
520        });
521    }
522    breaks.sort_unstable_by_key(|change| change.offset);
523
524    let mut output_ranges = ArrayVec::new();
525    let mut position = 0_u32;
526    let mut stages = wgt::ShaderStages::NONE;
527
528    for bk in breaks {
529        if bk.offset - position > 0 && !stages.is_empty() {
530            output_ranges.push(wgt::PushConstantRange {
531                stages,
532                range: position..bk.offset,
533            })
534        }
535        position = bk.offset;
536        stages.set(bk.stages, bk.enable);
537    }
538
539    output_ranges
540}