wgpu_core/command/bind.rs

1use core::{iter::zip, ops::Range};
2
3use alloc::{boxed::Box, sync::Arc, vec::Vec};
4
5use arrayvec::ArrayVec;
6use thiserror::Error;
7
8use crate::{
9    binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
10    device::SHADER_STAGE_COUNT,
11    pipeline::LateSizedBufferGroup,
12    resource::{Labeled, ResourceErrorIdent},
13};
14
15mod compat {
16    use alloc::{
17        string::{String, ToString as _},
18        sync::{Arc, Weak},
19        vec::Vec,
20    };
21    use core::{num::NonZeroU32, ops::Range};
22
23    use arrayvec::ArrayVec;
24    use thiserror::Error;
25    use wgt::{BindingType, ShaderStages};
26
27    use crate::{
28        binding_model::BindGroupLayout,
29        error::MultiError,
30        resource::{Labeled, ParentDevice, ResourceErrorIdent},
31    };
32
33    pub(crate) enum Error {
34        Incompatible {
35            expected_bgl: ResourceErrorIdent,
36            assigned_bgl: ResourceErrorIdent,
37            inner: MultiError,
38        },
39        Missing,
40    }
41
42    #[derive(Debug, Clone)]
43    struct Entry {
44        assigned: Option<Arc<BindGroupLayout>>,
45        expected: Option<Arc<BindGroupLayout>>,
46    }
47
48    impl Entry {
49        fn empty() -> Self {
50            Self {
51                assigned: None,
52                expected: None,
53            }
54        }
55        fn is_active(&self) -> bool {
56            self.assigned.is_some() && self.expected.is_some()
57        }
58
59        fn is_valid(&self) -> bool {
60            if let Some(expected_bgl) = self.expected.as_ref() {
61                if let Some(assigned_bgl) = self.assigned.as_ref() {
62                    expected_bgl.is_equal(assigned_bgl)
63                } else {
64                    false
65                }
66            } else {
67                true
68            }
69        }
70
71        fn is_incompatible(&self) -> bool {
72            self.expected.is_none() || !self.is_valid()
73        }
74
75        fn check(&self) -> Result<(), Error> {
76            if let Some(expected_bgl) = self.expected.as_ref() {
77                if let Some(assigned_bgl) = self.assigned.as_ref() {
78                    if expected_bgl.is_equal(assigned_bgl) {
79                        Ok(())
80                    } else {
81                        #[derive(Clone, Debug, Error)]
82                        #[error(
83                            "Exclusive pipelines don't match: expected {expected}, got {assigned}"
84                        )]
85                        struct IncompatibleExclusivePipelines {
86                            expected: String,
87                            assigned: String,
88                        }
89
90                        use crate::binding_model::ExclusivePipeline;
91                        match (
92                            expected_bgl.exclusive_pipeline.get().unwrap(),
93                            assigned_bgl.exclusive_pipeline.get().unwrap(),
94                        ) {
95                            (ExclusivePipeline::None, ExclusivePipeline::None) => {}
96                            (
97                                ExclusivePipeline::Render(e_pipeline),
98                                ExclusivePipeline::Render(a_pipeline),
99                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
100                            (
101                                ExclusivePipeline::Compute(e_pipeline),
102                                ExclusivePipeline::Compute(a_pipeline),
103                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
104                            (expected, assigned) => {
105                                return Err(Error::Incompatible {
106                                    expected_bgl: expected_bgl.error_ident(),
107                                    assigned_bgl: assigned_bgl.error_ident(),
108                                    inner: MultiError::new(core::iter::once(
109                                        IncompatibleExclusivePipelines {
110                                            expected: expected.to_string(),
111                                            assigned: assigned.to_string(),
112                                        },
113                                    ))
114                                    .unwrap(),
115                                });
116                            }
117                        }
118
119                        #[derive(Clone, Debug, Error)]
120                        enum EntryError {
121                            #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
122                            Visibility {
123                                binding: u32,
124                                expected: ShaderStages,
125                                assigned: ShaderStages,
126                            },
127                            #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
128                            Type {
129                                binding: u32,
130                                expected: BindingType,
131                                assigned: BindingType,
132                            },
133                            #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
134                            Count {
135                                binding: u32,
136                                expected: Option<NonZeroU32>,
137                                assigned: Option<NonZeroU32>,
138                            },
139                            #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
140                            ExtraExpected { binding: u32 },
141                            #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
142                            ExtraAssigned { binding: u32 },
143                        }
144
145                        let mut errors = Vec::new();
146
147                        for (&binding, expected_entry) in expected_bgl.entries.iter() {
148                            if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
149                                if assigned_entry.visibility != expected_entry.visibility {
150                                    errors.push(EntryError::Visibility {
151                                        binding,
152                                        expected: expected_entry.visibility,
153                                        assigned: assigned_entry.visibility,
154                                    });
155                                }
156                                if assigned_entry.ty != expected_entry.ty {
157                                    errors.push(EntryError::Type {
158                                        binding,
159                                        expected: expected_entry.ty,
160                                        assigned: assigned_entry.ty,
161                                    });
162                                }
163                                if assigned_entry.count != expected_entry.count {
164                                    errors.push(EntryError::Count {
165                                        binding,
166                                        expected: expected_entry.count,
167                                        assigned: assigned_entry.count,
168                                    });
169                                }
170                            } else {
171                                errors.push(EntryError::ExtraExpected { binding });
172                            }
173                        }
174
175                        for (&binding, _) in assigned_bgl.entries.iter() {
176                            if !expected_bgl.entries.contains_key(binding) {
177                                errors.push(EntryError::ExtraAssigned { binding });
178                            }
179                        }
180
181                        Err(Error::Incompatible {
182                            expected_bgl: expected_bgl.error_ident(),
183                            assigned_bgl: assigned_bgl.error_ident(),
184                            inner: MultiError::new(errors.drain(..)).unwrap(),
185                        })
186                    }
187                } else {
188                    Err(Error::Missing)
189                }
190            } else {
191                Ok(())
192            }
193        }
194    }
195
196    #[derive(Debug, Default)]
197    pub(super) struct BoundBindGroupLayouts {
198        entries: ArrayVec<Entry, { hal::MAX_BIND_GROUPS }>,
199        rebind_start: usize,
200    }
201
202    impl BoundBindGroupLayouts {
203        pub fn new() -> Self {
204            Self {
205                entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
206                rebind_start: 0,
207            }
208        }
209
210        pub fn num_valid_entries(&self) -> usize {
211            // find first incompatible entry
212            self.entries
213                .iter()
214                .position(|e| e.is_incompatible())
215                .unwrap_or(self.entries.len())
216        }
217
218        /// Get the range of entries that needs to be rebound, and clears it.
219        pub fn take_rebind_range(&mut self) -> Range<usize> {
220            let end = self.num_valid_entries();
221            let start = self.rebind_start;
222            self.rebind_start = end;
223            start..end.max(start)
224        }
225
226        pub fn update_start_index(&mut self, start_index: usize) {
227            self.rebind_start = self.rebind_start.min(start_index);
228        }
229
230        pub fn update_expectations(&mut self, expectations: &[Arc<BindGroupLayout>]) {
231            let start_index = self
232                .entries
233                .iter()
234                .zip(expectations)
235                .position(|(e, expect)| {
236                    e.expected.is_none() || !e.expected.as_ref().unwrap().is_equal(expect)
237                })
238                .unwrap_or(expectations.len());
239            for (e, expect) in self.entries[start_index..]
240                .iter_mut()
241                .zip(expectations[start_index..].iter())
242            {
243                e.expected = Some(expect.clone());
244            }
245            for e in self.entries[expectations.len()..].iter_mut() {
246                e.expected = None;
247            }
248            self.update_start_index(start_index);
249        }
250
251        pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) {
252            self.entries[index].assigned = Some(value);
253            self.update_start_index(index);
254        }
255
256        pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
257            self.entries
258                .iter()
259                .enumerate()
260                .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
261        }
262
263        #[allow(clippy::result_large_err)]
264        pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
265            for (index, entry) in self.entries.iter().enumerate() {
266                entry.check().map_err(|e| (index, e))?;
267            }
268            Ok(())
269        }
270    }
271}
272
273#[derive(Clone, Debug, Error)]
274pub enum BinderError {
275    #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
276    MissingBindGroup {
277        index: usize,
278        pipeline: ResourceErrorIdent,
279    },
280    #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
281    IncompatibleBindGroup {
282        expected_bgl: ResourceErrorIdent,
283        assigned_bgl: ResourceErrorIdent,
284        assigned_bg: ResourceErrorIdent,
285        index: usize,
286        pipeline: ResourceErrorIdent,
287        #[source]
288        inner: crate::error::MultiError,
289    },
290}
291
292#[derive(Debug)]
293struct LateBufferBinding {
294    binding_index: u32,
295    shader_expect_size: wgt::BufferAddress,
296    bound_size: wgt::BufferAddress,
297}
298
299#[derive(Debug, Default)]
300pub(super) struct EntryPayload {
301    pub(super) group: Option<Arc<BindGroup>>,
302    pub(super) dynamic_offsets: Vec<wgt::DynamicOffset>,
303    late_buffer_bindings: Vec<LateBufferBinding>,
304    /// Since `LateBufferBinding` may contain information about the bindings
305    /// not used by the pipeline, we need to know when to stop validating.
306    pub(super) late_bindings_effective_count: usize,
307}
308
309impl EntryPayload {
310    fn reset(&mut self) {
311        self.group = None;
312        self.dynamic_offsets.clear();
313        self.late_buffer_bindings.clear();
314        self.late_bindings_effective_count = 0;
315    }
316}
317
318#[derive(Debug, Default)]
319pub(super) struct Binder {
320    pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
321    manager: compat::BoundBindGroupLayouts,
322    payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
323}
324
325impl Binder {
326    pub(super) fn new() -> Self {
327        Self {
328            pipeline_layout: None,
329            manager: compat::BoundBindGroupLayouts::new(),
330            payloads: Default::default(),
331        }
332    }
333    pub(super) fn reset(&mut self) {
334        self.pipeline_layout = None;
335        self.manager = compat::BoundBindGroupLayouts::new();
336        for payload in self.payloads.iter_mut() {
337            payload.reset();
338        }
339    }
340
341    pub(super) fn change_pipeline_layout<'a>(
342        &'a mut self,
343        new: &Arc<PipelineLayout>,
344        late_sized_buffer_groups: &[LateSizedBufferGroup],
345    ) {
346        let old_id_opt = self.pipeline_layout.replace(new.clone());
347
348        self.manager.update_expectations(&new.bind_group_layouts);
349
350        // Update the buffer binding sizes that are required by shaders.
351
352        for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
353            payload.late_bindings_effective_count = late_group.shader_sizes.len();
354            // Update entries that already exist as the bind group was bound before the pipeline
355            // was bound.
356            for (late_binding, &shader_expect_size) in payload
357                .late_buffer_bindings
358                .iter_mut()
359                .zip(late_group.shader_sizes.iter())
360            {
361                late_binding.shader_expect_size = shader_expect_size;
362            }
363            // Add new entries for the bindings that were not known when the bind group was bound.
364            if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
365                for &shader_expect_size in
366                    late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
367                {
368                    payload.late_buffer_bindings.push(LateBufferBinding {
369                        binding_index: 0,
370                        shader_expect_size,
371                        bound_size: 0,
372                    });
373                }
374            }
375        }
376
377        if let Some(old) = old_id_opt {
378            // root constants are the base compatibility property
379            if old.push_constant_ranges != new.push_constant_ranges {
380                self.manager.update_start_index(0);
381            }
382        }
383    }
384
385    pub(super) fn assign_group<'a>(
386        &'a mut self,
387        index: usize,
388        bind_group: &Arc<BindGroup>,
389        offsets: &[wgt::DynamicOffset],
390    ) {
391        let payload = &mut self.payloads[index];
392        payload.group = Some(bind_group.clone());
393        payload.dynamic_offsets.clear();
394        payload.dynamic_offsets.extend_from_slice(offsets);
395
396        // Fill out the actual binding sizes for buffers,
397        // whose layout doesn't specify `min_binding_size`.
398
399        // Update entries that already exist as the pipeline was bound before the group
400        // was bound.
401        for (late_binding, late_info) in payload
402            .late_buffer_bindings
403            .iter_mut()
404            .zip(bind_group.late_buffer_binding_infos.iter())
405        {
406            late_binding.binding_index = late_info.binding_index;
407            late_binding.bound_size = late_info.size.get();
408        }
409
410        // Add new entries for the bindings that were not known when the pipeline was bound.
411        if bind_group.late_buffer_binding_infos.len() > payload.late_buffer_bindings.len() {
412            for late_info in
413                bind_group.late_buffer_binding_infos[payload.late_buffer_bindings.len()..].iter()
414            {
415                payload.late_buffer_bindings.push(LateBufferBinding {
416                    binding_index: late_info.binding_index,
417                    shader_expect_size: 0,
418                    bound_size: late_info.size.get(),
419                });
420            }
421        }
422
423        self.manager.assign(index, bind_group.layout.clone());
424    }
425
426    /// Get the range of entries that needs to be rebound, and clears it.
427    pub(super) fn take_rebind_range(&mut self) -> Range<usize> {
428        self.manager.take_rebind_range()
429    }
430
431    pub(super) fn entries(
432        &self,
433        range: Range<usize>,
434    ) -> impl ExactSizeIterator<Item = (usize, &'_ EntryPayload)> + Clone + '_ {
435        let payloads = &self.payloads[range.clone()];
436        zip(range, payloads)
437    }
438
439    pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup>> + 'a {
440        let payloads = &self.payloads;
441        self.manager
442            .list_active()
443            .map(move |index| payloads[index].group.as_ref().unwrap())
444    }
445
446    pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + 'a {
447        self.payloads
448            .iter()
449            .take(self.manager.num_valid_entries())
450            .enumerate()
451    }
452
453    pub(super) fn check_compatibility<T: Labeled>(
454        &self,
455        pipeline: &T,
456    ) -> Result<(), Box<BinderError>> {
457        self.manager.get_invalid().map_err(|(index, error)| {
458            Box::new(match error {
459                compat::Error::Incompatible {
460                    expected_bgl,
461                    assigned_bgl,
462                    inner,
463                } => BinderError::IncompatibleBindGroup {
464                    expected_bgl,
465                    assigned_bgl,
466                    assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
467                    index,
468                    pipeline: pipeline.error_ident(),
469                    inner,
470                },
471                compat::Error::Missing => BinderError::MissingBindGroup {
472                    index,
473                    pipeline: pipeline.error_ident(),
474                },
475            })
476        })
477    }
478
479    /// Scan active buffer bindings corresponding to layouts without `min_binding_size` specified.
480    pub(super) fn check_late_buffer_bindings(
481        &self,
482    ) -> Result<(), LateMinBufferBindingSizeMismatch> {
483        for group_index in self.manager.list_active() {
484            let payload = &self.payloads[group_index];
485            for late_binding in
486                &payload.late_buffer_bindings[..payload.late_bindings_effective_count]
487            {
488                if late_binding.bound_size < late_binding.shader_expect_size {
489                    return Err(LateMinBufferBindingSizeMismatch {
490                        group_index: group_index as u32,
491                        binding_index: late_binding.binding_index,
492                        shader_size: late_binding.shader_expect_size,
493                        bound_size: late_binding.bound_size,
494                    });
495                }
496            }
497        }
498        Ok(())
499    }
500}
501
502struct PushConstantChange {
503    stages: wgt::ShaderStages,
504    offset: u32,
505    enable: bool,
506}
507
508/// Break up possibly overlapping push constant ranges into a set of
509/// non-overlapping ranges which contain all the stage flags of the
510/// original ranges. This allows us to zero out (or write any value)
511/// to every possible value.
512pub fn compute_nonoverlapping_ranges(
513    ranges: &[wgt::PushConstantRange],
514) -> ArrayVec<wgt::PushConstantRange, { SHADER_STAGE_COUNT * 2 }> {
515    if ranges.is_empty() {
516        return ArrayVec::new();
517    }
518    debug_assert!(ranges.len() <= SHADER_STAGE_COUNT);
519
520    let mut breaks: ArrayVec<PushConstantChange, { SHADER_STAGE_COUNT * 2 }> = ArrayVec::new();
521    for range in ranges {
522        breaks.push(PushConstantChange {
523            stages: range.stages,
524            offset: range.range.start,
525            enable: true,
526        });
527        breaks.push(PushConstantChange {
528            stages: range.stages,
529            offset: range.range.end,
530            enable: false,
531        });
532    }
533    breaks.sort_unstable_by_key(|change| change.offset);
534
535    let mut output_ranges = ArrayVec::new();
536    let mut position = 0_u32;
537    let mut stages = wgt::ShaderStages::NONE;
538
539    for bk in breaks {
540        if bk.offset - position > 0 && !stages.is_empty() {
541            output_ranges.push(wgt::PushConstantRange {
542                stages,
543                range: position..bk.offset,
544            })
545        }
546        position = bk.offset;
547        stages.set(bk.stages, bk.enable);
548    }
549
550    output_ranges
551}