// wgpu_core/command/bind.rs

1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2
3use thiserror::Error;
4
5use crate::{
6    binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
7    pipeline::LateSizedBufferGroup,
8    resource::{Labeled, ParentDevice, ResourceErrorIdent},
9};
10
11mod compat {
12    use alloc::{
13        string::{String, ToString as _},
14        sync::{Arc, Weak},
15        vec::Vec,
16    };
17    use core::num::NonZeroU32;
18
19    use thiserror::Error;
20    use wgt::{BindingType, ShaderStages};
21
22    use crate::{
23        binding_model::BindGroupLayout,
24        error::MultiError,
25        resource::{Labeled, ParentDevice, ResourceErrorIdent},
26    };
27
28    pub(crate) enum Error {
29        Incompatible {
30            expected_bgl: ResourceErrorIdent,
31            assigned_bgl: ResourceErrorIdent,
32            inner: MultiError,
33        },
34        Missing,
35    }
36
37    #[derive(Debug, Clone)]
38    struct Entry {
39        assigned: Option<Arc<BindGroupLayout>>,
40        expected: Option<Arc<BindGroupLayout>>,
41    }
42
43    impl Entry {
44        const fn empty() -> Self {
45            Self {
46                assigned: None,
47                expected: None,
48            }
49        }
50
51        fn is_active(&self) -> bool {
52            self.assigned.is_some() && self.expected.is_some()
53        }
54
55        fn is_valid(&self) -> bool {
56            if let Some(expected_bgl) = self.expected.as_ref() {
57                if let Some(assigned_bgl) = self.assigned.as_ref() {
58                    expected_bgl.is_equal(assigned_bgl)
59                } else {
60                    false
61                }
62            } else {
63                false
64            }
65        }
66
67        fn check(&self) -> Result<(), Error> {
68            if let Some(expected_bgl) = self.expected.as_ref() {
69                if let Some(assigned_bgl) = self.assigned.as_ref() {
70                    if expected_bgl.is_equal(assigned_bgl) {
71                        Ok(())
72                    } else {
73                        #[derive(Clone, Debug, Error)]
74                        #[error(
75                            "Exclusive pipelines don't match: expected {expected}, got {assigned}"
76                        )]
77                        struct IncompatibleExclusivePipelines {
78                            expected: String,
79                            assigned: String,
80                        }
81
82                        use crate::binding_model::ExclusivePipeline;
83                        match (
84                            expected_bgl.exclusive_pipeline.get().unwrap(),
85                            assigned_bgl.exclusive_pipeline.get().unwrap(),
86                        ) {
87                            (ExclusivePipeline::None, ExclusivePipeline::None) => {}
88                            (
89                                ExclusivePipeline::Render(e_pipeline),
90                                ExclusivePipeline::Render(a_pipeline),
91                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
92                            (
93                                ExclusivePipeline::Compute(e_pipeline),
94                                ExclusivePipeline::Compute(a_pipeline),
95                            ) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
96                            (expected, assigned) => {
97                                return Err(Error::Incompatible {
98                                    expected_bgl: expected_bgl.error_ident(),
99                                    assigned_bgl: assigned_bgl.error_ident(),
100                                    inner: MultiError::new(core::iter::once(
101                                        IncompatibleExclusivePipelines {
102                                            expected: expected.to_string(),
103                                            assigned: assigned.to_string(),
104                                        },
105                                    ))
106                                    .unwrap(),
107                                });
108                            }
109                        }
110
111                        #[derive(Clone, Debug, Error)]
112                        enum EntryError {
113                            #[error("Entries with binding {binding} differ in visibility: expected {expected:?}, got {assigned:?}")]
114                            Visibility {
115                                binding: u32,
116                                expected: ShaderStages,
117                                assigned: ShaderStages,
118                            },
119                            #[error("Entries with binding {binding} differ in type: expected {expected:?}, got {assigned:?}")]
120                            Type {
121                                binding: u32,
122                                expected: BindingType,
123                                assigned: BindingType,
124                            },
125                            #[error("Entries with binding {binding} differ in count: expected {expected:?}, got {assigned:?}")]
126                            Count {
127                                binding: u32,
128                                expected: Option<NonZeroU32>,
129                                assigned: Option<NonZeroU32>,
130                            },
131                            #[error("Expected entry with binding {binding} not found in assigned bind group layout")]
132                            ExtraExpected { binding: u32 },
133                            #[error("Assigned entry with binding {binding} not found in expected bind group layout")]
134                            ExtraAssigned { binding: u32 },
135                        }
136
137                        let mut errors = Vec::new();
138
139                        for (&binding, expected_entry) in expected_bgl.entries.iter() {
140                            if let Some(assigned_entry) = assigned_bgl.entries.get(binding) {
141                                if assigned_entry.visibility != expected_entry.visibility {
142                                    errors.push(EntryError::Visibility {
143                                        binding,
144                                        expected: expected_entry.visibility,
145                                        assigned: assigned_entry.visibility,
146                                    });
147                                }
148                                if assigned_entry.ty != expected_entry.ty {
149                                    errors.push(EntryError::Type {
150                                        binding,
151                                        expected: expected_entry.ty,
152                                        assigned: assigned_entry.ty,
153                                    });
154                                }
155                                if assigned_entry.count != expected_entry.count {
156                                    errors.push(EntryError::Count {
157                                        binding,
158                                        expected: expected_entry.count,
159                                        assigned: assigned_entry.count,
160                                    });
161                                }
162                            } else {
163                                errors.push(EntryError::ExtraExpected { binding });
164                            }
165                        }
166
167                        for (&binding, _) in assigned_bgl.entries.iter() {
168                            if !expected_bgl.entries.contains_key(binding) {
169                                errors.push(EntryError::ExtraAssigned { binding });
170                            }
171                        }
172
173                        Err(Error::Incompatible {
174                            expected_bgl: expected_bgl.error_ident(),
175                            assigned_bgl: assigned_bgl.error_ident(),
176                            inner: MultiError::new(errors.drain(..)).unwrap(),
177                        })
178                    }
179                } else {
180                    Err(Error::Missing)
181                }
182            } else {
183                Ok(())
184            }
185        }
186    }
187
188    #[derive(Debug)]
189    pub(super) struct BoundBindGroupLayouts {
190        entries: [Entry; hal::MAX_BIND_GROUPS],
191        rebind_start: usize,
192    }
193
194    impl BoundBindGroupLayouts {
195        pub fn new() -> Self {
196            Self {
197                entries: [const { Entry::empty() }; hal::MAX_BIND_GROUPS],
198                rebind_start: 0,
199            }
200        }
201
202        /// Takes the start index of the bind group range to be rebound, and clears it.
203        pub fn take_rebind_start_index(&mut self) -> usize {
204            let start = self.rebind_start;
205            self.rebind_start = self.entries.len();
206            start
207        }
208
209        pub fn update_rebind_start_index(&mut self, start_index: usize) {
210            self.rebind_start = self.rebind_start.min(start_index);
211        }
212
213        pub fn update_expectations(&mut self, expectations: &[Option<Arc<BindGroupLayout>>]) {
214            let mut rebind_start_index = None;
215
216            for (i, (e, new_expected_bgl)) in self
217                .entries
218                .iter_mut()
219                .zip(expectations.iter().chain(core::iter::repeat(&None)))
220                .enumerate()
221            {
222                let (must_set, must_rebind) = match (&mut e.expected, new_expected_bgl) {
223                    (None, None) => (false, false),
224                    (None, Some(_)) => (true, true),
225                    (Some(_), None) => (true, false),
226                    (Some(old_expected_bgl), Some(new_expected_bgl)) => {
227                        let is_different = !old_expected_bgl.is_equal(new_expected_bgl);
228                        (is_different, is_different)
229                    }
230                };
231                if must_set {
232                    e.expected = new_expected_bgl.clone();
233                }
234                if must_rebind && rebind_start_index.is_none() {
235                    rebind_start_index = Some(i);
236                }
237            }
238
239            if let Some(rebind_start_index) = rebind_start_index {
240                self.update_rebind_start_index(rebind_start_index);
241            }
242        }
243
244        pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) {
245            self.entries[index].assigned = Some(value);
246            self.update_rebind_start_index(index);
247        }
248
249        pub fn clear(&mut self, index: usize) {
250            self.entries[index].assigned = None;
251        }
252
253        pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
254            self.entries
255                .iter()
256                .enumerate()
257                .filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
258        }
259
260        pub fn list_valid(&self) -> impl Iterator<Item = usize> + '_ {
261            self.entries
262                .iter()
263                .enumerate()
264                .filter_map(|(i, e)| if e.is_valid() { Some(i) } else { None })
265        }
266
267        #[allow(clippy::result_large_err)]
268        pub fn get_invalid(&self) -> Result<(), (usize, Error)> {
269            for (index, entry) in self.entries.iter().enumerate() {
270                entry.check().map_err(|e| (index, e))?;
271            }
272            Ok(())
273        }
274    }
275}
276
277#[derive(Clone, Debug, Error)]
278pub enum BinderError {
279    #[error("The current set {pipeline} expects a BindGroup to be set at index {index}")]
280    MissingBindGroup {
281        index: usize,
282        pipeline: ResourceErrorIdent,
283    },
284    #[error("The {assigned_bgl} of current set {assigned_bg} at index {index} is not compatible with the corresponding {expected_bgl} of {pipeline}")]
285    IncompatibleBindGroup {
286        expected_bgl: ResourceErrorIdent,
287        assigned_bgl: ResourceErrorIdent,
288        assigned_bg: ResourceErrorIdent,
289        index: usize,
290        pipeline: ResourceErrorIdent,
291        #[source]
292        inner: crate::error::MultiError,
293    },
294}
295
296#[derive(Debug)]
297struct LateBufferBinding {
298    binding_index: u32,
299    shader_expect_size: wgt::BufferAddress,
300    bound_size: wgt::BufferAddress,
301}
302
303#[derive(Debug, Default)]
304struct EntryPayload {
305    group: Option<Arc<BindGroup>>,
306    dynamic_offsets: Vec<wgt::DynamicOffset>,
307    late_buffer_bindings: Vec<LateBufferBinding>,
308    /// Since `LateBufferBinding` may contain information about the bindings
309    /// not used by the pipeline, we need to know when to stop validating.
310    late_bindings_effective_count: usize,
311}
312
313impl EntryPayload {
314    fn reset(&mut self) {
315        self.group = None;
316        self.dynamic_offsets.clear();
317        self.late_buffer_bindings.clear();
318        self.late_bindings_effective_count = 0;
319    }
320}
321
322#[derive(Debug)]
323pub(super) struct Binder {
324    pub(super) pipeline_layout: Option<Arc<PipelineLayout>>,
325    manager: compat::BoundBindGroupLayouts,
326    payloads: [EntryPayload; hal::MAX_BIND_GROUPS],
327}
328
329impl Binder {
330    pub(super) fn new() -> Self {
331        Self {
332            pipeline_layout: None,
333            manager: compat::BoundBindGroupLayouts::new(),
334            payloads: Default::default(),
335        }
336    }
337    pub(super) fn reset(&mut self) {
338        self.pipeline_layout = None;
339        self.manager = compat::BoundBindGroupLayouts::new();
340        for payload in self.payloads.iter_mut() {
341            payload.reset();
342        }
343    }
344
345    /// Returns `true` if the pipeline layout has been changed, i.e. if the
346    /// new PL was not the same as the old PL.
347    pub(super) fn change_pipeline_layout<'a>(
348        &'a mut self,
349        new: &Arc<PipelineLayout>,
350        late_sized_buffer_groups: &[LateSizedBufferGroup],
351    ) -> bool {
352        if let Some(old) = self.pipeline_layout.as_ref() {
353            if old.is_equal(new) {
354                return false;
355            }
356        }
357
358        let old = self.pipeline_layout.replace(new.clone());
359
360        self.manager.update_expectations(&new.bind_group_layouts);
361
362        // Update the buffer binding sizes that are required by shaders.
363
364        for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
365            payload.late_bindings_effective_count = late_group.shader_sizes.len();
366            // Update entries that already exist as the bind group was bound before the pipeline
367            // was bound.
368            for (late_binding, &shader_expect_size) in payload
369                .late_buffer_bindings
370                .iter_mut()
371                .zip(late_group.shader_sizes.iter())
372            {
373                late_binding.shader_expect_size = shader_expect_size;
374            }
375            // Add new entries for the bindings that were not known when the bind group was bound.
376            if late_group.shader_sizes.len() > payload.late_buffer_bindings.len() {
377                for &shader_expect_size in
378                    late_group.shader_sizes[payload.late_buffer_bindings.len()..].iter()
379                {
380                    payload.late_buffer_bindings.push(LateBufferBinding {
381                        binding_index: 0,
382                        shader_expect_size,
383                        bound_size: 0,
384                    });
385                }
386            }
387        }
388
389        if let Some(old) = old {
390            // root constants are the base compatibility property
391            if old.immediate_size != new.immediate_size {
392                self.manager.update_rebind_start_index(0);
393            }
394        }
395
396        true
397    }
398
399    pub(super) fn assign_group<'a>(
400        &'a mut self,
401        index: usize,
402        bind_group: &Arc<BindGroup>,
403        offsets: &[wgt::DynamicOffset],
404    ) {
405        let payload = &mut self.payloads[index];
406        payload.group = Some(bind_group.clone());
407        payload.dynamic_offsets.clear();
408        payload.dynamic_offsets.extend_from_slice(offsets);
409
410        // Fill out the actual binding sizes for buffers,
411        // whose layout doesn't specify `min_binding_size`.
412
413        // Update entries that already exist as the pipeline was bound before the group
414        // was bound.
415        for (late_binding, late_info) in payload
416            .late_buffer_bindings
417            .iter_mut()
418            .zip(bind_group.late_buffer_binding_infos.iter())
419        {
420            late_binding.binding_index = late_info.binding_index;
421            late_binding.bound_size = late_info.size.get();
422        }
423
424        // Add new entries for the bindings that were not known when the pipeline was bound.
425        if bind_group.late_buffer_binding_infos.len() > payload.late_buffer_bindings.len() {
426            for late_info in
427                bind_group.late_buffer_binding_infos[payload.late_buffer_bindings.len()..].iter()
428            {
429                payload.late_buffer_bindings.push(LateBufferBinding {
430                    binding_index: late_info.binding_index,
431                    shader_expect_size: 0,
432                    bound_size: late_info.size.get(),
433                });
434            }
435        }
436
437        self.manager.assign(index, bind_group.layout.clone());
438    }
439
440    pub(super) fn clear_group(&mut self, index: usize) {
441        self.payloads[index].reset();
442        self.manager.clear(index);
443    }
444
445    /// Takes the start index of the bind group range to be rebound, and clears it.
446    pub(super) fn take_rebind_start_index(&mut self) -> usize {
447        self.manager.take_rebind_start_index()
448    }
449
450    pub(super) fn list_valid_with_start(
451        &self,
452        start: usize,
453    ) -> impl Iterator<Item = (usize, &Arc<BindGroup>, &[wgt::DynamicOffset])> + '_ {
454        let payloads = &self.payloads;
455        self.manager
456            .list_valid()
457            .filter(move |i| *i >= start)
458            .map(move |index| {
459                (
460                    index,
461                    payloads[index].group.as_ref().unwrap(),
462                    payloads[index].dynamic_offsets.as_slice(),
463                )
464            })
465    }
466
467    pub(super) fn list_active(&self) -> impl Iterator<Item = &Arc<BindGroup>> + '_ {
468        let payloads = &self.payloads;
469        self.manager
470            .list_active()
471            .map(move |index| payloads[index].group.as_ref().unwrap())
472    }
473
474    pub(super) fn list_valid(
475        &self,
476    ) -> impl Iterator<Item = (usize, &Arc<BindGroup>, &[wgt::DynamicOffset])> + '_ {
477        self.list_valid_with_start(0)
478    }
479
480    pub(super) fn check_compatibility<T: Labeled>(
481        &self,
482        pipeline: &T,
483    ) -> Result<(), Box<BinderError>> {
484        self.manager.get_invalid().map_err(|(index, error)| {
485            Box::new(match error {
486                compat::Error::Incompatible {
487                    expected_bgl,
488                    assigned_bgl,
489                    inner,
490                } => BinderError::IncompatibleBindGroup {
491                    expected_bgl,
492                    assigned_bgl,
493                    assigned_bg: self.payloads[index].group.as_ref().unwrap().error_ident(),
494                    index,
495                    pipeline: pipeline.error_ident(),
496                    inner,
497                },
498                compat::Error::Missing => BinderError::MissingBindGroup {
499                    index,
500                    pipeline: pipeline.error_ident(),
501                },
502            })
503        })
504    }
505
506    /// Scan active buffer bindings corresponding to layouts without `min_binding_size` specified.
507    pub(super) fn check_late_buffer_bindings(
508        &self,
509    ) -> Result<(), LateMinBufferBindingSizeMismatch> {
510        for group_index in self.manager.list_active() {
511            let payload = &self.payloads[group_index];
512            for late_binding in
513                &payload.late_buffer_bindings[..payload.late_bindings_effective_count]
514            {
515                if late_binding.bound_size < late_binding.shader_expect_size {
516                    return Err(LateMinBufferBindingSizeMismatch {
517                        group_index: group_index as u32,
518                        binding_index: late_binding.binding_index,
519                        shader_size: late_binding.shader_expect_size,
520                        bound_size: late_binding.bound_size,
521                    });
522                }
523            }
524        }
525        Ok(())
526    }
527}