// wgpu_core/device/mod.rs

1use alloc::{boxed::Box, string::String, vec::Vec};
2use core::{fmt, num::NonZeroU32};
3
4use crate::{
5    binding_model,
6    ray_tracing::BlasCompactReadyPendingClosure,
7    resource::{
8        Buffer, BufferAccessError, BufferAccessResult, BufferMapOperation, Labeled,
9        RawResourceAccess, ResourceErrorIdent,
10    },
11    snatch::SnatchGuard,
12    Label, DOWNLEVEL_ERROR_MESSAGE,
13};
14
15use arrayvec::ArrayVec;
16use smallvec::SmallVec;
17use thiserror::Error;
18use wgt::{
19    error::{ErrorType, WebGpuError},
20    BufferAddress, DeviceLostReason, TextureFormat,
21};
22
23pub(crate) mod bgl;
24pub mod global;
25mod life;
26pub mod queue;
27pub mod ray_tracing;
28pub mod resource;
29#[cfg(any(feature = "trace", feature = "replay"))]
30pub mod trace;
31pub use {life::WaitIdleError, resource::Device};
32
/// Maximum number of shader stages that may be bound concurrently; mirrors
/// `hal::MAX_CONCURRENT_SHADER_STAGES`.
pub const SHADER_STAGE_COUNT: usize = hal::MAX_CONCURRENT_SHADER_STAGES;
// Size of the internal zeroing buffer, in bytes (512 KiB).
// Should be large enough for the largest possible texture row. This
// value is enough for a 16k texture with float4 format.
pub(crate) const ZERO_BUFFER_SIZE: BufferAddress = 512 << 10;

// If a submission is not completed within this time, we go off into UB land.
// See https://github.com/gfx-rs/wgpu/issues/4589. 60s to reduce the chances of this.
const CLEANUP_WAIT_MS: u32 = 60000;

/// Error message reported when a shader module's entry point is invalid.
pub(crate) const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid";

/// Device descriptor specialized to this crate's [`Label`] type.
pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor<Label<'a>>;
45
/// The direction of host access requested when mapping a buffer.
#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum HostMap {
    /// The host will read from the mapped range.
    Read,
    /// The host will write to the mapped range.
    Write,
}
53
/// The set of attachments making up a render pass, generic over the
/// per-attachment payload `T` (e.g. a format or a view).
#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct AttachmentData<T> {
    // One entry per color attachment slot; `None` marks an unused slot.
    pub colors: ArrayVec<Option<T>, { hal::MAX_COLOR_ATTACHMENTS }>,
    // Resolve attachments, at most one per color attachment slot.
    pub resolves: ArrayVec<T, { hal::MAX_COLOR_ATTACHMENTS }>,
    // The depth/stencil attachment, if any.
    pub depth_stencil: Option<T>,
}
// NOTE: `Eq` is implemented with only a `T: PartialEq` bound, rather than
// requiring `T: Eq`.
impl<T: PartialEq> Eq for AttachmentData<T> {}
62
/// The attachment formats, sample count, and multiview configuration a render
/// pass (or a pipeline/bundle targeting it) was created with. Compared by
/// [`RenderPassContext::check_compatible`].
#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct RenderPassContext {
    // Formats of all attachments.
    pub attachments: AttachmentData<TextureFormat>,
    // Sample count shared by the attachments.
    pub sample_count: u32,
    // Multiview layer count, if multiview rendering is enabled.
    pub multiview: Option<NonZeroU32>,
}
70#[derive(Clone, Debug, Error)]
71#[non_exhaustive]
72pub enum RenderPassCompatibilityError {
73    #[error(
74        "Incompatible color attachments at indices {indices:?}: the RenderPass uses textures with formats {expected:?} but the {res} uses attachments with formats {actual:?}",
75    )]
76    IncompatibleColorAttachment {
77        indices: Vec<usize>,
78        expected: Vec<Option<TextureFormat>>,
79        actual: Vec<Option<TextureFormat>>,
80        res: ResourceErrorIdent,
81    },
82    #[error(
83        "Incompatible depth-stencil attachment format: the RenderPass uses a texture with format {expected:?} but the {res} uses an attachment with format {actual:?}",
84    )]
85    IncompatibleDepthStencilAttachment {
86        expected: Option<TextureFormat>,
87        actual: Option<TextureFormat>,
88        res: ResourceErrorIdent,
89    },
90    #[error(
91        "Incompatible sample count: the RenderPass uses textures with sample count {expected:?} but the {res} uses attachments with format {actual:?}",
92    )]
93    IncompatibleSampleCount {
94        expected: u32,
95        actual: u32,
96        res: ResourceErrorIdent,
97    },
98    #[error("Incompatible multiview setting: the RenderPass uses setting {expected:?} but the {res} uses setting {actual:?}")]
99    IncompatibleMultiview {
100        expected: Option<NonZeroU32>,
101        actual: Option<NonZeroU32>,
102        res: ResourceErrorIdent,
103    },
104}
105
impl WebGpuError for RenderPassCompatibilityError {
    /// All render-pass compatibility failures surface as WebGPU validation errors.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
111
112impl RenderPassContext {
113    // Assumes the renderpass only contains one subpass
114    pub(crate) fn check_compatible<T: Labeled>(
115        &self,
116        other: &Self,
117        res: &T,
118    ) -> Result<(), RenderPassCompatibilityError> {
119        if self.attachments.colors != other.attachments.colors {
120            let indices = self
121                .attachments
122                .colors
123                .iter()
124                .zip(&other.attachments.colors)
125                .enumerate()
126                .filter_map(|(idx, (left, right))| (left != right).then_some(idx))
127                .collect();
128            return Err(RenderPassCompatibilityError::IncompatibleColorAttachment {
129                indices,
130                expected: self.attachments.colors.iter().cloned().collect(),
131                actual: other.attachments.colors.iter().cloned().collect(),
132                res: res.error_ident(),
133            });
134        }
135        if self.attachments.depth_stencil != other.attachments.depth_stencil {
136            return Err(
137                RenderPassCompatibilityError::IncompatibleDepthStencilAttachment {
138                    expected: self.attachments.depth_stencil,
139                    actual: other.attachments.depth_stencil,
140                    res: res.error_ident(),
141                },
142            );
143        }
144        if self.sample_count != other.sample_count {
145            return Err(RenderPassCompatibilityError::IncompatibleSampleCount {
146                expected: self.sample_count,
147                actual: other.sample_count,
148                res: res.error_ident(),
149            });
150        }
151        if self.multiview != other.multiview {
152            return Err(RenderPassCompatibilityError::IncompatibleMultiview {
153                expected: self.multiview,
154                actual: other.multiview,
155                res: res.error_ident(),
156            });
157        }
158        Ok(())
159    }
160}
161
/// A buffer-map callback paired with the result it should be invoked with.
pub type BufferMapPendingClosure = (BufferMapOperation, BufferAccessResult);

/// Collected user callbacks (buffer mappings, BLAS compaction readiness,
/// submitted-work-done, and device-lost) gathered so they can be fired
/// together, outside of any held locks.
#[derive(Default)]
pub struct UserClosures {
    pub mappings: Vec<BufferMapPendingClosure>,
    pub blas_compact_ready: Vec<BlasCompactReadyPendingClosure>,
    pub submissions: SmallVec<[queue::SubmittedWorkDoneClosure; 1]>,
    pub device_lost_invocations: SmallVec<[DeviceLostInvocation; 1]>,
}
171
172impl UserClosures {
173    fn extend(&mut self, other: Self) {
174        self.mappings.extend(other.mappings);
175        self.blas_compact_ready.extend(other.blas_compact_ready);
176        self.submissions.extend(other.submissions);
177        self.device_lost_invocations
178            .extend(other.device_lost_invocations);
179    }
180
181    fn fire(self) {
182        // Note: this logic is specifically moved out of `handle_mapping()` in order to
183        // have nothing locked by the time we execute users callback code.
184
185        // Mappings _must_ be fired before submissions, as the spec requires all mapping callbacks that are registered before
186        // a on_submitted_work_done callback to be fired before the on_submitted_work_done callback.
187        for (mut operation, status) in self.mappings {
188            if let Some(callback) = operation.callback.take() {
189                callback(status);
190            }
191        }
192        for (mut operation, status) in self.blas_compact_ready {
193            if let Some(callback) = operation.take() {
194                callback(status);
195            }
196        }
197        for closure in self.submissions {
198            closure();
199        }
200        for invocation in self.device_lost_invocations {
201            (invocation.closure)(invocation.reason, invocation.message);
202        }
203    }
204}
205
/// Callback invoked when a device is lost, receiving the loss reason and a
/// human-readable message. `Send` is required when the `send_sync` cfg is set.
#[cfg(send_sync)]
pub type DeviceLostClosure = Box<dyn FnOnce(DeviceLostReason, String) + Send + 'static>;
/// Callback invoked when a device is lost (non-`Send` variant for targets
/// built without the `send_sync` cfg).
#[cfg(not(send_sync))]
pub type DeviceLostClosure = Box<dyn FnOnce(DeviceLostReason, String) + 'static>;

/// A pending device-lost callback together with the arguments it will be
/// called with when fired (see [`UserClosures`]).
pub struct DeviceLostInvocation {
    closure: DeviceLostClosure,
    reason: DeviceLostReason,
    message: String,
}
216
/// Map `size` bytes of `buffer` starting at `offset` for host access of the
/// given `kind`, zero-filling any not-yet-initialized portions of the range.
///
/// `offset` and `size` must be multiples of `wgt::COPY_BUFFER_ALIGNMENT`
/// (asserted below).
///
/// # Errors
///
/// Fails if the buffer's raw handle has been snatched, or if the hal
/// `map_buffer` call reports an error (routed through the device's hal error
/// handler).
pub(crate) fn map_buffer(
    buffer: &Buffer,
    offset: BufferAddress,
    size: BufferAddress,
    kind: HostMap,
    snatch_guard: &SnatchGuard,
) -> Result<hal::BufferMapping, BufferAccessError> {
    let raw_device = buffer.device.raw();
    let raw_buffer = buffer.try_raw(snatch_guard)?;
    // SAFETY: hal mapping call; `raw_buffer` is live (guarded by
    // `snatch_guard`) and any hal error is forwarded to the device.
    let mapping = unsafe {
        raw_device
            .map_buffer(raw_buffer, offset..offset + size)
            .map_err(|e| buffer.device.handle_hal_error(e))?
    };

    // For non-coherent memory, invalidate the range so GPU writes become
    // visible to the host before the caller reads from the mapping.
    if !mapping.is_coherent && kind == HostMap::Read {
        #[allow(clippy::single_range_in_vec_init)]
        unsafe {
            raw_device.invalidate_mapped_ranges(raw_buffer, &[offset..offset + size]);
        }
    }

    assert_eq!(offset % wgt::COPY_BUFFER_ALIGNMENT, 0);
    assert_eq!(size % wgt::COPY_BUFFER_ALIGNMENT, 0);
    // Zero out uninitialized parts of the mapping. (Spec dictates all resources
    // behave as if they were initialized with zero)
    //
    // If this is a read mapping, ideally we would use a `clear_buffer` command
    // before reading the data from GPU (i.e. `invalidate_range`). However, this
    // would require us to kick off and wait for a command buffer or piggy back
    // on an existing one (the later is likely the only worthwhile option). As
    // reading uninitialized memory isn't a particular important path to
    // support, we instead just initialize the memory here and make sure it is
    // GPU visible, so this happens at max only once for every buffer region.
    //
    // If this is a write mapping zeroing out the memory here is the only
    // reasonable way as all data is pushed to GPU anyways.

    // SAFETY: the hal mapping yields a pointer valid for `size` bytes.
    let mapped = unsafe { core::slice::from_raw_parts_mut(mapping.ptr.as_ptr(), size as usize) };

    // We can't call flush_mapped_ranges in this case, so we can't drain the uninitialized ranges either
    if !mapping.is_coherent
        && kind == HostMap::Read
        && !buffer.usage.contains(wgt::BufferUsages::MAP_WRITE)
    {
        // Zero-fill but leave the tracker unchanged: without MAP_WRITE there
        // is no way to flush our zeroes back to the GPU.
        for uninitialized in buffer
            .initialization_status
            .write()
            .uninitialized(offset..(size + offset))
        {
            // The mapping's pointer is already offset, however we track the
            // uninitialized range relative to the buffer's start.
            let fill_range =
                (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
            mapped[fill_range].fill(0);
        }
    } else {
        // Here the ranges can be drained (marked initialized): the zeroes
        // either get flushed below or reach the GPU via the mapping itself.
        for uninitialized in buffer
            .initialization_status
            .write()
            .drain(offset..(size + offset))
        {
            // The mapping's pointer is already offset, however we track the
            // uninitialized range relative to the buffer's start.
            let fill_range =
                (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
            mapped[fill_range].fill(0);

            // NOTE: This is only possible when MAPPABLE_PRIMARY_BUFFERS is enabled.
            if !mapping.is_coherent
                && kind == HostMap::Read
                && buffer.usage.contains(wgt::BufferUsages::MAP_WRITE)
            {
                // SAFETY: hal flush over a range produced by the
                // initialization tracker, within the mapped buffer.
                unsafe { raw_device.flush_mapped_ranges(raw_buffer, &[uninitialized]) };
            }
        }
    }

    Ok(mapping)
}
297
/// Error raised when two resources that must share a device belong to
/// different devices.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DeviceMismatch {
    // The resource whose device was checked.
    pub(super) res: ResourceErrorIdent,
    // The device `res` belongs to.
    pub(super) res_device: ResourceErrorIdent,
    // The resource `res` was used with, if identifiable.
    pub(super) target: Option<ResourceErrorIdent>,
    // The device `target` belongs to.
    pub(super) target_device: ResourceErrorIdent,
}
306
307impl fmt::Display for DeviceMismatch {
308    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
309        write!(
310            f,
311            "{} of {} doesn't match {}",
312            self.res_device, self.res, self.target_device
313        )?;
314        if let Some(target) = self.target.as_ref() {
315            write!(f, " of {target}")?;
316        }
317        Ok(())
318    }
319}
320
impl core::error::Error for DeviceMismatch {}

impl WebGpuError for DeviceMismatch {
    /// Device mismatches surface as WebGPU validation errors.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
328
/// Errors reported on behalf of a device as a whole.
#[derive(Clone, Debug, Error)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum DeviceError {
    /// The device has been lost. Unexpected hal errors also map here (see
    /// [`DeviceError::from_hal`]).
    #[error("Parent device is lost")]
    Lost,
    /// An allocation failed.
    #[error("Not enough memory left.")]
    OutOfMemory,
    /// Two resources from different devices were used together; boxed to keep
    /// `DeviceError` small.
    #[error(transparent)]
    DeviceMismatch(#[from] Box<DeviceMismatch>),
}
340
341impl WebGpuError for DeviceError {
342    fn webgpu_error_type(&self) -> ErrorType {
343        match self {
344            Self::DeviceMismatch(e) => e.webgpu_error_type(),
345            Self::Lost => ErrorType::DeviceLost,
346            Self::OutOfMemory => ErrorType::OutOfMemory,
347        }
348    }
349}
350
351impl DeviceError {
352    /// Only use this function in contexts where there is no `Device`.
353    ///
354    /// Use [`Device::handle_hal_error`] otherwise.
355    pub fn from_hal(error: hal::DeviceError) -> Self {
356        match error {
357            hal::DeviceError::Lost => Self::Lost,
358            hal::DeviceError::OutOfMemory => Self::OutOfMemory,
359            hal::DeviceError::Unexpected => Self::Lost,
360        }
361    }
362}
363
/// Error raised when an operation requires device features that were not
/// enabled at device creation.
#[derive(Clone, Debug, Error)]
#[error("Features {0:?} are required but not enabled on the device")]
pub struct MissingFeatures(pub wgt::Features);

impl WebGpuError for MissingFeatures {
    /// Missing features are WebGPU validation errors.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
373
/// Error raised when an operation requires downlevel capabilities the device
/// does not support.
#[derive(Clone, Debug, Error)]
#[error(
    "Downlevel flags {0:?} are required but not supported on the device.\n{DOWNLEVEL_ERROR_MESSAGE}",
)]
pub struct MissingDownlevelFlags(pub wgt::DownlevelFlags);

impl WebGpuError for MissingDownlevelFlags {
    /// Missing downlevel flags are WebGPU validation errors.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
385
386/// Create a validator for Naga [`Module`]s.
387///
388/// Create a Naga [`Validator`] that ensures that each [`naga::Module`]
389/// presented to it is valid, and uses no features not included in
390/// `features` and `downlevel`.
391///
392/// The validator can only catch invalid modules and feature misuse
393/// reliably when the `flags` argument includes all the flags in
394/// [`ValidationFlags::default()`].
395///
396/// [`Validator`]: naga::valid::Validator
397/// [`Module`]: naga::Module
398/// [`ValidationFlags::default()`]: naga::valid::ValidationFlags::default
399pub fn create_validator(
400    features: wgt::Features,
401    downlevel: wgt::DownlevelFlags,
402    flags: naga::valid::ValidationFlags,
403) -> naga::valid::Validator {
404    use naga::valid::Capabilities as Caps;
405    let mut caps = Caps::empty();
406    caps.set(
407        Caps::PUSH_CONSTANT,
408        features.contains(wgt::Features::PUSH_CONSTANTS),
409    );
410    caps.set(Caps::FLOAT64, features.contains(wgt::Features::SHADER_F64));
411    caps.set(
412        Caps::SHADER_FLOAT16,
413        features.contains(wgt::Features::SHADER_F16),
414    );
415    caps.set(
416        Caps::PRIMITIVE_INDEX,
417        features.contains(wgt::Features::SHADER_PRIMITIVE_INDEX),
418    );
419    caps.set(
420        Caps::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
421        features
422            .contains(wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
423    );
424    caps.set(
425        Caps::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
426        features.contains(wgt::Features::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING),
427    );
428    caps.set(
429        Caps::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
430        features.contains(wgt::Features::UNIFORM_BUFFER_BINDING_ARRAYS),
431    );
432    // TODO: This needs a proper wgpu feature
433    caps.set(
434        Caps::SAMPLER_NON_UNIFORM_INDEXING,
435        features
436            .contains(wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
437    );
438    caps.set(
439        Caps::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
440        features.contains(wgt::Features::TEXTURE_FORMAT_16BIT_NORM),
441    );
442    caps.set(Caps::MULTIVIEW, features.contains(wgt::Features::MULTIVIEW));
443    caps.set(
444        Caps::EARLY_DEPTH_TEST,
445        features.contains(wgt::Features::SHADER_EARLY_DEPTH_TEST),
446    );
447    caps.set(
448        Caps::SHADER_INT64,
449        features.contains(wgt::Features::SHADER_INT64),
450    );
451    caps.set(
452        Caps::SHADER_INT64_ATOMIC_MIN_MAX,
453        features.intersects(
454            wgt::Features::SHADER_INT64_ATOMIC_MIN_MAX | wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS,
455        ),
456    );
457    caps.set(
458        Caps::SHADER_INT64_ATOMIC_ALL_OPS,
459        features.contains(wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS),
460    );
461    caps.set(
462        Caps::TEXTURE_ATOMIC,
463        features.contains(wgt::Features::TEXTURE_ATOMIC),
464    );
465    caps.set(
466        Caps::TEXTURE_INT64_ATOMIC,
467        features.contains(wgt::Features::TEXTURE_INT64_ATOMIC),
468    );
469    caps.set(
470        Caps::SHADER_FLOAT32_ATOMIC,
471        features.contains(wgt::Features::SHADER_FLOAT32_ATOMIC),
472    );
473    caps.set(
474        Caps::MULTISAMPLED_SHADING,
475        downlevel.contains(wgt::DownlevelFlags::MULTISAMPLED_SHADING),
476    );
477    caps.set(
478        Caps::DUAL_SOURCE_BLENDING,
479        features.contains(wgt::Features::DUAL_SOURCE_BLENDING),
480    );
481    caps.set(
482        Caps::CLIP_DISTANCE,
483        features.contains(wgt::Features::CLIP_DISTANCES),
484    );
485    caps.set(
486        Caps::CUBE_ARRAY_TEXTURES,
487        downlevel.contains(wgt::DownlevelFlags::CUBE_ARRAY_TEXTURES),
488    );
489    caps.set(
490        Caps::SUBGROUP,
491        features.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX),
492    );
493    caps.set(
494        Caps::SUBGROUP_BARRIER,
495        features.intersects(wgt::Features::SUBGROUP_BARRIER),
496    );
497    caps.set(
498        Caps::RAY_QUERY,
499        features.intersects(wgt::Features::EXPERIMENTAL_RAY_QUERY),
500    );
501    caps.set(
502        Caps::SUBGROUP_VERTEX_STAGE,
503        features.contains(wgt::Features::SUBGROUP_VERTEX),
504    );
505    caps.set(
506        Caps::RAY_HIT_VERTEX_POSITION,
507        features.intersects(wgt::Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN),
508    );
509    caps.set(
510        Caps::TEXTURE_EXTERNAL,
511        features.intersects(wgt::Features::EXTERNAL_TEXTURE),
512    );
513
514    naga::valid::Validator::new(flags, caps)
515}