wgpu_core/command/
bind.rs

1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2
3use arrayvec::ArrayVec;
4use thiserror::Error;
5
6use crate::{
7    binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
8    device::SHADER_STAGE_COUNT,
9    pipeline::LateSizedBufferGroup,
10    resource::{Labeled, ResourceErrorIdent},
11};
12
13mod compat {
14    use alloc::{
15        string::{String, ToString as _},
16        sync::{Arc, Weak},
17        vec::Vec,
18    };
19    use core::{num::NonZeroU32, ops::Range};
20
21    use arrayvec::ArrayVec;
22    use thiserror::Error;
23    use wgt::{BindingType, ShaderStages};
24
25    use crate::{
26        binding_model::BindGroupLayout,
27        error::MultiError,
28        resource::{Labeled, ParentDevice, ResourceErrorIdent},
29    };
30
31    pub(crate) enum Error {
32        Incompatible {
33            expected_bgl: ResourceErrorIdent,
34            assigned_bgl: ResourceErrorIdent,
35            inner: MultiError,
36        },
37        Missing,
38    }
39
40    #[derive(Debug, Clone)]
41    struct Entry {
42        assigned: Option<Arc<BindGroupLayout>>,
43        expected: Option<Arc<BindGroupLayout>>,
44    }
45
46    impl Entry {
47        fn empty() -> Self {
48            Self {
49                assigned: None,
50                expected: None,
51            }
52        }
53        fn is_active(&self) -> bool {
54            self.assigned.is_some() && self.expected.is_some()
55        }
56
57        fn is_valid(&self) -> bool {
58            if let Some(expected_bgl) = self.expected.as_ref() {
59                if let Some(assigned_bgl) = self.assigned.as_ref() {
60                    expected_bgl.is_equal(assigned_bgl)
61                } else {
62                    false
63                }
64            } else {
65                true
66            }
67        }
68
69        fn is_incompatible(&self) -> bool {
70            self.expected.is_none() || !self.is_valid()
71        }
72
73        fn check(&self) -> Result<(), Error> {
74            if let Some(expected_bgl) = self.expected.as_ref() {
75                if let Some(assigned_bgl) = self.assigned.as_ref() {
76                    if expected_bgl.is_equal(assigned_bgl) {
77                        Ok(())
78                    } else {
79                        #[derive(Clone, Debug, Error)]
80                        #[error(
81                            "Exclusive pipelines don't match: expected {expected}, got {assigned}"
82                        )]
83                        struct IncompatibleExclusivePipelines {
84                            expected: String,
85                            assigned: String,
86                        }
87
88                        use crate::binding_model::ExclusivePipeline;
89                        match (
90                            expected_bgl.exclusive_pipeline.get().unwrap(),
91                            assigned_bgl.exclusive_pipeline.get().unwrap(),
92                        ) {
93                            (ExclusivePipeline::None, ExclusivePipeline::None) => {}
94                            (
95                                ExclusivePipeline::Render(e_pipeline),
96                                ExclusivePipeline::Render(a_pipeline),
97                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
98                            (
99                                ExclusivePipeline::Compute(e_pipeline),
100                                ExclusivePipeline::Compute(a_pipeline),
101                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
102                            (expected, assigned) => {
103                                return Err(Error::Incompatible {
104                                    expected_bgl: expected_bgl.error_ident(),
105                                    assigned_bgl: assigned_bgl.error_ident(),
106                                    inner: MultiError::new(core::iter::once(
107                                        IncompatibleExclusivePipelines {
108                                            expected: expected.to_string(),
109                                            assigned: assigned.to_string(),
110                                        },
111                                    ))
112                                    .unwrap(),
113                                });
114                            }
115                        }
116
117                        #[derive(Clone, Debug, Error)]
118                        enum EntryError {
119                            #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
120                            Visibility {
121                                binding: u32,
122                                expected: ShaderStages,
123                                assigned: ShaderStages,
124                            },
125                            #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
126                            Type {
127                                binding: u32,
128                                expected: BindingType,
129                                assigned: BindingType,
130                            },
131                            #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
132                            Count {
133                                binding: u32,
134                                expected: Option<NonZeroU32>,
135                                assigned: Option<NonZeroU32>,
136                            },
137                            #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
138                            ExtraExpected { binding: u32 },
139                            #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
140                            ExtraAssigned { binding: u32 },
141                        }
142
143                        let mut errors = Vec::new();
144
145                        for (&binding, expected_entry) in expected_bgl.entries.iter() {
146                            if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
147                                if assigned_entry.visibility != expected_entry.visibility {
148                                    errors.push(EntryError::Visibility {
149                                        binding,
150                                        expected: expected_entry.visibility,
151                                        assigned: assigned_entry.visibility,
152                                    });
153                                }
154                                if assigned_entry.ty != expected_entry.ty {
155                                    errors.push(EntryError::Type {
156                                        binding,
157                                        expected: expected_entry.ty,
158                                        assigned: assigned_entry.ty,
159                                    });
160                                }
161                                if assigned_entry.count != expected_entry.count {
162                                    errors.push(EntryError::Count {
163                                        binding,
164                                        expected: expected_entry.count,
165                                        assigned: assigned_entry.count,
166                                    });
167                                }
168                            } else {
169                                errors.push(EntryError::ExtraExpected { binding });
170                            }
171                        }
172
173                        for (&binding, _) in assigned_bgl.entries.iter() {
174                            if !expected_bgl.entries.contains_key(binding) {
175                                errors.push(EntryError::ExtraAssigned { binding });
176                            }
177                        }
178
179                        Err(Error::Incompatible {
180                            expected_bgl: expected_bgl.error_ident(),
181                            assigned_bgl: assigned_bgl.error_ident(),
182                            inner: MultiError::new(errors.drain(..)).unwrap(),
183                        })
184                    }
185                } else {
186                    Err(Error::Missing)
187                }
188            } else {
189                Ok(())
190            }
191        }
192    }
193
194    #[derive(Debug, Default)]
195    pub(crate) struct BoundBindGroupLayouts {
196        entries: ArrayVec<Entry, { hal::MAX_BIND_GROUPS }>,
197    }
198
199    impl BoundBindGroupLayouts {
200        pub fn new() -> Self {
201            Self {
202                entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
203            }
204        }
205
206        pub fn num_valid_entries(&self) -> usize {
207            // find first incompatible entry
208            self.entries
209                .iter()
210                .position(|e| e.is_incompatible())
211                .unwrap_or(self.entries.len())
212        }
213
214        fn make_range(&self, start_index: usize) -> Range<usize> {
215            let end = self.num_valid_entries();
216            start_index..end.max(start_index)
217        }
218
219        pub fn update_expectations(
220            &mut self,
221            expectations: &[Arc<BindGroupLayout>],
222        ) -> Range<usize> {
223            let start_index = self
224                .entries
225                .iter()
226                .zip(expectations)
227                .position(|(e, expect)| {
228                    e.expected.is_none() || !e.expected.as_ref().unwrap().is_equal(expect)
229                })
230                .unwrap_or(expectations.len());
231            for (e, expect) in self.entries[start_index..]
232                .iter_mut()
233                .zip(expectations[start_index..].iter())
234            {
235                e.expected = Some(expect.clone());
236            }
237            for e in self.entries[expectations.len()..].iter_mut() {
238                e.expected = None;
239            }
240            self.make_range(start_index)
241        }
242
243        pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) -> Range<usize> {
244            self.entries[index].assigned = Some(value);
245            self.make_range(index)
246        }
247
248        pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
249            self.entries
250                .iter()
251                .enumerate()
252                .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
253        }
254
255        #[allow(clippy::result_large_err)]
256        pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
257            for (index, entry) in self.entries.iter().enumerate() {
258                entry.check().map_err(|e| (index, e))?;
259            }
260            Ok(())
261        }
262    }
263}
264
/// Reasons the bind groups currently set cannot be used with a pipeline.
#[derive(Clone, Debug, Error)]
pub enum BinderError {
    /// The pipeline's layout expects a bind group at `index`, but none is
    /// assigned there.
    #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
    MissingBindGroup {
        index: usize,
        pipeline: ResourceErrorIdent,
    },
    /// The bind group assigned at `index` has a layout that is not equal to
    /// the one the pipeline's layout expects at that index.
    #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
    IncompatibleBindGroup {
        expected_bgl: ResourceErrorIdent,
        assigned_bgl: ResourceErrorIdent,
        assigned_bg: ResourceErrorIdent,
        index: usize,
        pipeline: ResourceErrorIdent,
        /// The differences found between the expected and assigned layouts.
        #[source]
        inner: crate::error::MultiError,
    },
}
283
/// Sizes tracked for one buffer binding whose layout does not specify
/// `min_binding_size`, so the size check must happen when bindings are used.
#[derive(Debug)]
struct LateBufferBinding {
    /// Minimum size the current pipeline's shaders require for this binding
    /// (0 if no pipeline has provided one yet).
    shader_expect_size: wgt::BufferAddress,
    /// Size actually bound by the assigned bind group (0 if no bind group has
    /// provided one yet). It must not be smaller than `shader_expect_size`.
    bound_size: wgt::BufferAddress,
}
289
/// Per-slot state the binder keeps for one bind group index.
#[derive(Debug, Default)]
pub(super) struct EntryPayload {
    /// The bind group currently assigned at this index, if any.
    pub(super) group: Option<Arc<BindGroup>>,
    /// Dynamic offsets supplied together with `group`.
    pub(super) dynamic_offsets: Vec<wgt::DynamicOffset>,
    /// Shader-required and bound sizes for this slot's buffer bindings whose
    /// layout does not specify `min_binding_size`.
    late_buffer_bindings: Vec<LateBufferBinding>,
    /// Since `LateBufferBinding` may contain information about the bindings
    /// not used by the pipeline, we need to know when to stop validating.
    pub(super) late_bindings_effective_count: usize,
}
299
300impl EntryPayload {
301    fn reset(&mut self) {
302        self.group = None;
303        self.dynamic_offsets.clear();
304        self.late_buffer_bindings.clear();
305        self.late_bindings_effective_count = 0;
306    }
307}
308
/// Tracks the current pipeline layout and the bind group assigned to each
/// bind group index, and decides which assigned groups are compatible with
/// that layout.
///
/// NOTE(review): `Binder::new` builds `manager` with
/// `compat::BoundBindGroupLayouts::new()`. Confirm that the derived `Default`
/// produces the same state before relying on `Binder::default()`.
#[derive(Debug, Default)]
pub(super) struct Binder {
    /// The pipeline layout most recently passed to `change_pipeline_layout`.
    pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
    /// Expected vs. assigned bind group layouts per slot.
    manager: compat::BoundBindGroupLayouts,
    /// Per-slot bind group state, indexed by bind group index.
    payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
}
315
316impl Binder {
317    pub(super) fn new() -> Self {
318        Self {
319            pipeline_layout: None,
320            manager: compat::BoundBindGroupLayouts::new(),
321            payloads: Default::default(),
322        }
323    }
324    pub(super) fn reset(&mut self) {
325        self.pipeline_layout = None;
326        self.manager = compat::BoundBindGroupLayouts::new();
327        for payload in self.payloads.iter_mut() {
328            payload.reset();
329        }
330    }
331
332    pub(super) fn change_pipeline_layout<'a>(
333        &'a mut self,
334        new: &Arc<PipelineLayout>,
335        late_sized_buffer_groups: &[LateSizedBufferGroup],
336    ) -> (usize, &'a [EntryPayload]) {
337        let old_id_opt = self.pipeline_layout.replace(new.clone());
338
339        let mut bind_range = self.manager.update_expectations(&new.bind_group_layouts);
340
341        // Update the buffer binding sizes that are required by shaders.
342        for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
343            payload.late_bindings_effective_count = late_group.shader_sizes.len();
344            for (late_binding, &shader_expect_size) in payload
345                .late_buffer_bindings
346                .iter_mut()
347                .zip(late_group.shader_sizes.iter())
348            {
349                late_binding.shader_expect_size = shader_expect_size;
350            }
351            if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
352                for &shader_expect_size in
353                    late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
354                {
355                    payload.late_buffer_bindings.push(LateBufferBinding {
356                        shader_expect_size,
357                        bound_size: 0,
358                    });
359                }
360            }
361        }
362
363        if let Some(old) = old_id_opt {
364            // root constants are the base compatibility property
365            if old.push_constant_ranges != new.push_constant_ranges {
366                bind_range.start = 0;
367            }
368        }
369
370        (bind_range.start, &self.payloads[bind_range])
371    }
372
373    pub(super) fn assign_group<'a>(
374        &'a mut self,
375        index: usize,
376        bind_group: &Arc<BindGroup>,
377        offsets: &[wgt::DynamicOffset],
378    ) -> &'a [EntryPayload] {
379        let payload = &mut self.payloads[index];
380        payload.group = Some(bind_group.clone());
381        payload.dynamic_offsets.clear();
382        payload.dynamic_offsets.extend_from_slice(offsets);
383
384        // Fill out the actual binding sizes for buffers,
385        // whose layout doesn't specify `min_binding_size`.
386        for (late_binding, late_size) in payload
387            .late_buffer_bindings
388            .iter_mut()
389            .zip(bind_group.late_buffer_binding_sizes.iter())
390        {
391            late_binding.bound_size = late_size.get();
392        }
393        if bind_group.late_buffer_binding_sizes.len() > payload.late_buffer_bindings.len() {
394            for late_size in
395                bind_group.late_buffer_binding_sizes[payload.late_buffer_bindings.len()..].iter()
396            {
397                payload.late_buffer_bindings.push(LateBufferBinding {
398                    shader_expect_size: 0,
399                    bound_size: late_size.get(),
400                });
401            }
402        }
403
404        let bind_range = self.manager.assign(index, bind_group.layout.clone());
405        &self.payloads[bind_range]
406    }
407
408    pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup>> + 'a {
409        let payloads = &self.payloads;
410        self.manager
411            .list_active()
412            .map(move |index| payloads[index].group.as_ref().unwrap())
413    }
414
415    pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + 'a {
416        self.payloads
417            .iter()
418            .take(self.manager.num_valid_entries())
419            .enumerate()
420    }
421
422    pub(super) fn check_compatibility<T: Labeled>(
423        &self,
424        pipeline: &T,
425    ) -> Result<(), Box<BinderError>> {
426        self.manager.get_invalid().map_err(|(index, error)| {
427            Box::new(match error {
428                compat::Error::Incompatible {
429                    expected_bgl,
430                    assigned_bgl,
431                    inner,
432                } => BinderError::IncompatibleBindGroup {
433                    expected_bgl,
434                    assigned_bgl,
435                    assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
436                    index,
437                    pipeline: pipeline.error_ident(),
438                    inner,
439                },
440                compat::Error::Missing => BinderError::MissingBindGroup {
441                    index,
442                    pipeline: pipeline.error_ident(),
443                },
444            })
445        })
446    }
447
448    /// Scan active buffer bindings corresponding to layouts without `min_binding_size` specified.
449    pub(super) fn check_late_buffer_bindings(
450        &self,
451    ) -> Result<(), LateMinBufferBindingSizeMismatch> {
452        for group_index in self.manager.list_active() {
453            let payload = &self.payloads[group_index];
454            for (compact_index, late_binding) in payload.late_buffer_bindings
455                [..payload.late_bindings_effective_count]
456                .iter()
457                .enumerate()
458            {
459                if late_binding.bound_size < late_binding.shader_expect_size {
460                    return Err(LateMinBufferBindingSizeMismatch {
461                        group_index: group_index as u32,
462                        compact_index,
463                        shader_size: late_binding.shader_expect_size,
464                        bound_size: late_binding.bound_size,
465                    });
466                }
467            }
468        }
469        Ok(())
470    }
471}
472
/// A point at which a set of shader stages starts or stops being covered by a
/// push constant range; produced and consumed by `compute_nonoverlapping_ranges`.
struct PushConstantChange {
    /// Stages whose coverage changes at `offset`.
    stages: wgt::ShaderStages,
    /// Position of the change: a range's `start` or `end`.
    offset: u32,
    /// `true` at the start of a range, `false` at its end.
    enable: bool,
}
478
479/// Break up possibly overlapping push constant ranges into a set of
480/// non-overlapping ranges which contain all the stage flags of the
481/// original ranges. This allows us to zero out (or write any value)
482/// to every possible value.
483pub fn compute_nonoverlapping_ranges(
484    ranges: &[wgt::PushConstantRange],
485) -> ArrayVec<wgt::PushConstantRange, { SHADER_STAGE_COUNT * 2 }> {
486    if ranges.is_empty() {
487        return ArrayVec::new();
488    }
489    debug_assert!(ranges.len() <= SHADER_STAGE_COUNT);
490
491    let mut breaks: ArrayVec<PushConstantChange, { SHADER_STAGE_COUNT * 2 }> = ArrayVec::new();
492    for range in ranges {
493        breaks.push(PushConstantChange {
494            stages: range.stages,
495            offset: range.range.start,
496            enable: true,
497        });
498        breaks.push(PushConstantChange {
499            stages: range.stages,
500            offset: range.range.end,
501            enable: false,
502        });
503    }
504    breaks.sort_unstable_by_key(|change| change.offset);
505
506    let mut output_ranges = ArrayVec::new();
507    let mut position = 0_u32;
508    let mut stages = wgt::ShaderStages::NONE;
509
510    for bk in breaks {
511        if bk.offset - position > 0 && !stages.is_empty() {
512            output_ranges.push(wgt::PushConstantRange {
513                stages,
514                range: position..bk.offset,
515            })
516        }
517        position = bk.offset;
518        stages.set(bk.stages, bk.enable);
519    }
520
521    output_ranges
522}