wgpu_core/
pipeline.rs

1use alloc::{
2    borrow::{Cow, ToOwned},
3    boxed::Box,
4    string::String,
5    sync::Arc,
6    vec::Vec,
7};
8use core::{marker::PhantomData, mem::ManuallyDrop, num::NonZeroU32};
9
10use arrayvec::ArrayVec;
11use naga::error::ShaderError;
12use thiserror::Error;
13use wgt::error::{ErrorType, WebGpuError};
14
15pub use crate::pipeline_cache::PipelineCacheValidationError;
16use crate::{
17    api_log,
18    binding_model::{
19        BindGroupLayout, CreateBindGroupLayoutError, CreatePipelineLayoutError,
20        GetBindGroupLayoutError, PipelineLayout,
21    },
22    command::ColorAttachmentError,
23    device::{
24        AttachmentData, Device, DeviceError, MissingDownlevelFlags, MissingFeatures,
25        RenderPassContext,
26    },
27    id::{PipelineCacheId, PipelineLayoutId, ShaderModuleId},
28    pipeline_cache,
29    resource::{InvalidResourceError, Labeled, ResourceState, TrackingData},
30    resource_log,
31    validation::{self, ShaderMetaData},
32    Label, LabelHelpers as _,
33};
34
35/// Information about buffer bindings, which
36/// is validated against the shader (and pipeline)
37/// at draw time as opposed to initialization time.
38#[derive(Debug, Default)]
39pub(crate) struct LateSizedBufferGroup {
40    // The order has to match `BindGroup::late_buffer_binding_sizes`.
41    pub(crate) shader_sizes: Vec<wgt::BufferAddress>,
42}
43
44#[allow(clippy::large_enum_variant)]
45pub enum ShaderModuleSource<'a> {
46    #[cfg(feature = "wgsl")]
47    Wgsl(Cow<'a, str>),
48    #[cfg(feature = "glsl")]
49    Glsl(Cow<'a, str>, naga::front::glsl::Options),
50    #[cfg(feature = "spirv")]
51    SpirV(Cow<'a, [u32]>, naga::front::spv::Options),
52    Naga(Cow<'static, naga::Module>),
53    /// Dummy variant because `Naga` doesn't have a lifetime and without enough active features it
54    /// could be the last one active.
55    #[doc(hidden)]
56    Dummy(PhantomData<&'a ()>),
57}
58
59#[derive(Clone, Debug)]
60#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
61pub struct ShaderModuleDescriptor<'a> {
62    pub label: Label<'a>,
63    #[cfg_attr(feature = "serde", serde(default))]
64    pub runtime_checks: wgt::ShaderRuntimeChecks,
65}
66
67pub type ShaderModuleDescriptorPassthrough<'a> =
68    wgt::CreateShaderModuleDescriptorPassthrough<'a, Label<'a>>;
69
70#[derive(Debug)]
71pub(crate) struct ShaderModuleState {
72    pub(crate) raw: Box<dyn hal::DynShaderModule>,
73    pub(crate) interface: ShaderMetaData,
74}
75
76#[derive(Debug)]
77pub struct ShaderModule {
78    pub(crate) state: ResourceState<ShaderModuleState>,
79    pub(crate) device: Arc<Device>,
80    /// The `label` from the descriptor used to create the resource.
81    pub(crate) label: String,
82}
83
84impl Drop for ShaderModule {
85    fn drop(&mut self) {
86        resource_log!("Destroy raw {}", self.error_ident());
87        #[cfg(feature = "trace")]
88        if let Some(t) = self.device.trace.lock().as_mut() {
89            use crate::device::trace::{to_trace, Action};
90
91            t.add(Action::DropShaderModule(unsafe { to_trace(self) }));
92        }
93        let ResourceState::Valid(state) =
94            core::mem::replace(&mut self.state, ResourceState::Invalid)
95        else {
96            return;
97        };
98        unsafe {
99            self.device.raw().destroy_shader_module(state.raw);
100        }
101    }
102}
103
104crate::impl_resource_type!(ShaderModule);
105crate::impl_labeled!(ShaderModule);
106crate::impl_parent_device!(ShaderModule);
107crate::impl_storage_item!(ShaderModule);
108
109impl ShaderModule {
110    pub(crate) fn state(&self) -> Result<&ShaderModuleState, InvalidResourceError> {
111        let ResourceState::Valid(state) = &self.state else {
112            return Err(InvalidResourceError(self.error_ident()));
113        };
114        Ok(state)
115    }
116
117    pub(crate) fn invalid(device: Arc<Device>, label: String) -> Arc<Self> {
118        Arc::new(Self {
119            state: ResourceState::Invalid,
120            device,
121            label,
122        })
123    }
124
125    pub(crate) fn finalize_entry_point_name(
126        &self,
127        stage: naga::ShaderStage,
128        entry_point: Option<&str>,
129    ) -> Result<String, validation::StageError> {
130        let state = self.state()?;
131        match state.interface {
132            ShaderMetaData::Interface(ref interface) => {
133                interface.finalize_entry_point_name(stage, entry_point)
134            }
135            ShaderMetaData::Passthrough(ref interface) => {
136                if let Some(ep) = entry_point {
137                    if interface.entry_point_names.contains(ep) {
138                        Ok(ep.to_owned())
139                    } else {
140                        Err(validation::StageError::MissingEntryPoint(ep.to_owned()))
141                    }
142                } else {
143                    if interface.entry_point_names.len() != 1 {
144                        return Err(validation::StageError::MultipleEntryPointsFound);
145                    }
146                    Ok(interface
147                        .entry_point_names
148                        .iter()
149                        .next()
150                        .unwrap()
151                        .to_owned())
152                }
153            }
154        }
155    }
156}
157
158//Note: `Clone` would require `WithSpan: Clone`.
159#[derive(Clone, Debug, Error)]
160#[non_exhaustive]
161pub enum CreateShaderModuleError {
162    #[cfg(feature = "wgsl")]
163    #[error(transparent)]
164    Parsing(#[from] ShaderError<naga::front::wgsl::ParseError>),
165    #[cfg(feature = "glsl")]
166    #[error(transparent)]
167    ParsingGlsl(#[from] ShaderError<naga::front::glsl::ParseErrors>),
168    #[cfg(feature = "spirv")]
169    #[error(transparent)]
170    ParsingSpirV(#[from] ShaderError<naga::front::spv::Error>),
171    #[error("Failed to generate the backend-specific code")]
172    Generation,
173    #[error(transparent)]
174    Device(#[from] DeviceError),
175    #[error(transparent)]
176    Validation(#[from] ShaderError<naga::WithSpan<naga::valid::ValidationError>>),
177    #[error(transparent)]
178    MissingFeatures(#[from] MissingFeatures),
179    #[error(
180        "Shader global {bind:?} uses a group index {group} that exceeds the max_bind_groups limit of {limit}."
181    )]
182    InvalidGroupIndex {
183        bind: naga::ResourceBinding,
184        group: u32,
185        limit: u32,
186    },
187    #[error("Generic shader passthrough does not contain any code compatible with this backend.")]
188    NotCompiledForBackend,
189    #[error(
190        "Generic passthrough shaders which use GLSL or DXIL must contain exactly one entry point."
191    )]
192    IncorrectPassthroughEntryPointCount,
193}
194
195impl WebGpuError for CreateShaderModuleError {
196    fn webgpu_error_type(&self) -> ErrorType {
197        match self {
198            Self::Device(e) => e.webgpu_error_type(),
199            Self::MissingFeatures(e) => e.webgpu_error_type(),
200
201            Self::Generation => ErrorType::Internal,
202
203            Self::Validation(..)
204            | Self::InvalidGroupIndex { .. }
205            | Self::IncorrectPassthroughEntryPointCount
206            | Self::NotCompiledForBackend => ErrorType::Validation,
207            #[cfg(feature = "wgsl")]
208            Self::Parsing(..) => ErrorType::Validation,
209            #[cfg(feature = "glsl")]
210            Self::ParsingGlsl(..) => ErrorType::Validation,
211            #[cfg(feature = "spirv")]
212            Self::ParsingSpirV(..) => ErrorType::Validation,
213        }
214    }
215}
216
217/// Describes a programmable pipeline stage.
218#[derive(Clone, Debug)]
219#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
220pub struct ProgrammableStageDescriptor<'a, SM = ShaderModuleId> {
221    /// The compiled shader module for this stage.
222    pub module: SM,
223    /// The name of the entry point in the compiled shader. The name is selected using the
224    /// following logic:
225    ///
226    /// * If `Some(name)` is specified, there must be a function with this name in the shader.
227    /// * If a single entry point associated with this stage must be in the shader, then proceed as
228    ///   if `Some(…)` was specified with that entry point's name.
229    pub entry_point: Option<Cow<'a, str>>,
230    /// Specifies the values of pipeline-overridable constants in the shader module.
231    ///
232    /// If an `@id` attribute was specified on the declaration,
233    /// the key must be the pipeline constant ID as a decimal ASCII number; if not,
234    /// the key must be the constant's identifier name.
235    ///
236    /// The value may represent any of WGSL's concrete scalar types.
237    pub constants: naga::back::PipelineConstants,
238    /// Whether workgroup scoped memory will be initialized with zero values for this stage.
239    ///
240    /// This is required by the WebGPU spec, but may have overhead which can be avoided
241    /// for cross-platform applications
242    pub zero_initialize_workgroup_memory: bool,
243}
244
245/// cbindgen:ignore
246pub type ResolvedProgrammableStageDescriptor<'a> =
247    ProgrammableStageDescriptor<'a, Arc<ShaderModule>>;
248
249/// Number of implicit bind groups derived at pipeline creation.
250pub type ImplicitBindGroupCount = u8;
251
252#[derive(Clone, Debug, Error)]
253#[non_exhaustive]
254pub enum ImplicitLayoutError {
255    #[error("Unable to reflect the shader {0:?} interface")]
256    ReflectionError(wgt::ShaderStages),
257    #[error(transparent)]
258    BindGroup(#[from] CreateBindGroupLayoutError),
259    #[error(transparent)]
260    Pipeline(#[from] CreatePipelineLayoutError),
261    #[error("Unable to create implicit pipeline layout from passthrough shader stage: {0:?}")]
262    Passthrough(wgt::ShaderStages),
263}
264
265impl WebGpuError for ImplicitLayoutError {
266    fn webgpu_error_type(&self) -> ErrorType {
267        match self {
268            Self::ReflectionError(_) => ErrorType::Validation,
269            Self::BindGroup(e) => e.webgpu_error_type(),
270            Self::Pipeline(e) => e.webgpu_error_type(),
271            Self::Passthrough(_) => ErrorType::Validation,
272        }
273    }
274}
275
276/// Describes a compute pipeline.
277#[derive(Clone, Debug)]
278#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
279pub struct ComputePipelineDescriptor<
280    'a,
281    PLL = PipelineLayoutId,
282    SM = ShaderModuleId,
283    PLC = PipelineCacheId,
284> {
285    pub label: Label<'a>,
286    /// The layout of bind groups for this pipeline.
287    pub layout: Option<PLL>,
288    /// The compiled compute stage and its entry point.
289    pub stage: ProgrammableStageDescriptor<'a, SM>,
290    /// The pipeline cache to use when creating this pipeline.
291    pub cache: Option<PLC>,
292}
293
294/// cbindgen:ignore
295pub type ResolvedComputePipelineDescriptor<'a> =
296    ComputePipelineDescriptor<'a, Arc<PipelineLayout>, Arc<ShaderModule>, Arc<PipelineCache>>;
297
298#[derive(Clone, Debug, Error)]
299#[non_exhaustive]
300pub enum CreateComputePipelineError {
301    #[error(transparent)]
302    Device(#[from] DeviceError),
303    #[error("Unable to derive an implicit layout")]
304    Implicit(#[from] ImplicitLayoutError),
305    #[error("Error matching shader requirements against the pipeline")]
306    Stage(#[from] validation::StageError),
307    #[error("Internal error: {0}")]
308    Internal(String),
309    #[error("Pipeline constant error: {0}")]
310    PipelineConstants(String),
311    #[error(transparent)]
312    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
313    #[error(transparent)]
314    InvalidResource(#[from] InvalidResourceError),
315}
316
317impl WebGpuError for CreateComputePipelineError {
318    fn webgpu_error_type(&self) -> ErrorType {
319        match self {
320            Self::Device(e) => e.webgpu_error_type(),
321            Self::InvalidResource(e) => e.webgpu_error_type(),
322            Self::MissingDownlevelFlags(e) => e.webgpu_error_type(),
323            Self::Implicit(e) => e.webgpu_error_type(),
324            Self::Stage(e) => e.webgpu_error_type(),
325            Self::Internal(_) => ErrorType::Internal,
326            Self::PipelineConstants(_) => ErrorType::Validation,
327        }
328    }
329}
330
331#[derive(Debug)]
332pub struct ComputePipelineState {
333    pub(crate) raw: ManuallyDrop<Box<dyn hal::DynComputePipeline>>,
334    pub(crate) layout: Arc<PipelineLayout>,
335    pub(crate) _shader_module: Arc<ShaderModule>,
336}
337
338#[derive(Debug)]
339pub struct ComputePipeline {
340    pub(crate) state: ResourceState<ComputePipelineState>,
341    pub(crate) device: Arc<Device>,
342    pub(crate) late_sized_buffer_groups: ArrayVec<LateSizedBufferGroup, { hal::MAX_BIND_GROUPS }>,
343    pub(crate) immediate_slots_required: naga::valid::ImmediateSlots,
344    /// The `label` from the descriptor used to create the resource.
345    pub(crate) label: String,
346    pub(crate) tracking_data: TrackingData,
347}
348
349impl Drop for ComputePipeline {
350    fn drop(&mut self) {
351        resource_log!("Destroy raw {}", self.error_ident());
352        #[cfg(feature = "trace")]
353        {
354            use crate::device::trace;
355            if let Some(t) = self.device.trace.lock().as_mut() {
356                t.add(trace::Action::DropComputePipeline(unsafe {
357                    trace::to_trace(self)
358                }));
359            }
360        }
361        let ResourceState::Valid(state) = &mut self.state else {
362            return;
363        };
364        // SAFETY: We are in the Drop impl and we don't use self.raw anymore after this point.
365        let raw = unsafe { ManuallyDrop::take(&mut state.raw) };
366        unsafe {
367            self.device.raw().destroy_compute_pipeline(raw);
368        }
369    }
370}
371
372crate::impl_resource_type!(ComputePipeline);
373crate::impl_labeled!(ComputePipeline);
374crate::impl_parent_device!(ComputePipeline);
375crate::impl_storage_item!(ComputePipeline);
376crate::impl_trackable!(ComputePipeline);
377
378impl ComputePipeline {
379    pub(crate) fn raw(&self) -> Result<&dyn hal::DynComputePipeline, InvalidResourceError> {
380        let ResourceState::Valid(state) = &self.state else {
381            return Err(InvalidResourceError(self.error_ident()));
382        };
383        Ok(state.raw.as_ref())
384    }
385
386    pub(crate) fn layout(&self) -> Result<&Arc<PipelineLayout>, InvalidResourceError> {
387        let ResourceState::Valid(state) = &self.state else {
388            return Err(InvalidResourceError(self.error_ident()));
389        };
390        Ok(&state.layout)
391    }
392
393    pub(crate) fn check_valid(&self) -> Result<(), InvalidResourceError> {
394        let ResourceState::Valid(_) = &self.state else {
395            return Err(InvalidResourceError(self.error_ident()));
396        };
397        Ok(())
398    }
399
400    pub(crate) fn invalid(device: Arc<Device>, label: String) -> Arc<Self> {
401        Arc::new(Self {
402            tracking_data: TrackingData::new(device.tracker_indices.compute_pipelines.clone()),
403            state: ResourceState::Invalid,
404            device,
405            late_sized_buffer_groups: ArrayVec::new(),
406            immediate_slots_required: naga::valid::ImmediateSlots::default(),
407            label,
408        })
409    }
410
411    pub fn get_bind_group_layout_inner(
412        self: &Arc<Self>,
413        index: u32,
414    ) -> Result<Arc<BindGroupLayout>, GetBindGroupLayoutError> {
415        self.layout()?.get_bind_group_layout(index, self.into())
416    }
417
418    pub fn get_bind_group_layout(
419        self: &Arc<Self>,
420        index: u32,
421    ) -> (Arc<BindGroupLayout>, Option<GetBindGroupLayoutError>) {
422        let (bgl, error) = match self.get_bind_group_layout_inner(index) {
423            Ok(bgl) => (bgl, None),
424            Err(e) => (
425                BindGroupLayout::invalid(&self.device, String::new()),
426                Some(e),
427            ),
428        };
429        #[cfg(feature = "trace")]
430        if let Some(ref mut trace) = *self.device.trace.lock() {
431            use crate::device::trace;
432            use trace::IntoTrace;
433            trace.add(trace::Action::GetComputePipelineBindGroupLayout {
434                id: bgl.to_trace(),
435                pipeline: self.to_trace(),
436                index,
437            });
438        };
439        (bgl, error)
440    }
441}
442
443#[derive(Clone, Debug, Error)]
444#[non_exhaustive]
445pub enum CreatePipelineCacheError {
446    #[error(transparent)]
447    Device(#[from] DeviceError),
448    #[error("Pipeline cache validation failed")]
449    Validation(#[from] PipelineCacheValidationError),
450    #[error(transparent)]
451    MissingFeatures(#[from] MissingFeatures),
452}
453
454impl WebGpuError for CreatePipelineCacheError {
455    fn webgpu_error_type(&self) -> ErrorType {
456        match self {
457            Self::Device(e) => e.webgpu_error_type(),
458            Self::Validation(e) => e.webgpu_error_type(),
459            Self::MissingFeatures(e) => e.webgpu_error_type(),
460        }
461    }
462}
463
464#[derive(Debug)]
465pub struct PipelineCache {
466    pub(crate) raw: ResourceState<Box<dyn hal::DynPipelineCache>>,
467    pub(crate) device: Arc<Device>,
468    /// The `label` from the descriptor used to create the resource.
469    pub(crate) label: String,
470}
471
472impl Drop for PipelineCache {
473    #[allow(trivial_casts)]
474    fn drop(&mut self) {
475        profiling::scope!("PipelineCache::drop");
476        api_log!("PipelineCache::drop {:?}", self as *const _);
477        #[cfg(feature = "trace")]
478        if let Some(t) = self.device.trace.lock().as_mut() {
479            use crate::device::trace::{to_trace, Action};
480            t.add(Action::DropPipelineCache(unsafe { to_trace(self) }));
481        }
482        resource_log!("Destroy raw {}", self.error_ident());
483        if let ResourceState::Valid(raw) = core::mem::replace(&mut self.raw, ResourceState::Invalid)
484        {
485            unsafe {
486                self.device.raw().destroy_pipeline_cache(raw);
487            }
488        }
489    }
490}
491
492crate::impl_resource_type!(PipelineCache);
493crate::impl_labeled!(PipelineCache);
494crate::impl_parent_device!(PipelineCache);
495crate::impl_storage_item!(PipelineCache);
496
497impl PipelineCache {
498    pub(crate) fn raw(&self) -> Result<&dyn hal::DynPipelineCache, InvalidResourceError> {
499        self.raw
500            .as_ref()
501            .valid()
502            .map(|raw| raw.as_ref())
503            .ok_or_else(|| InvalidResourceError(self.error_ident()))
504    }
505
506    pub(crate) fn check_is_valid(&self) -> Result<(), InvalidResourceError> {
507        self.raw().map(|_| ())
508    }
509
510    pub(crate) fn invalid(device: Arc<Device>, desc: &PipelineCacheDescriptor) -> Arc<Self> {
511        Arc::new(Self {
512            raw: ResourceState::Invalid,
513            device,
514            label: desc.label.to_string(),
515        })
516    }
517
518    pub fn get_data(self: &Arc<Self>) -> Option<Vec<u8>> {
519        api_log!("PipelineCache::get_data");
520
521        let ResourceState::Valid(raw) = &self.raw else {
522            return None;
523        };
524
525        if !self.device.is_valid() {
526            return None;
527        }
528        let mut vec = unsafe { self.device.raw().pipeline_cache_get_data(raw.as_ref()) }?;
529        let validation_key = self.device.raw().pipeline_cache_validation_key()?;
530
531        let mut header_contents = [0; pipeline_cache::HEADER_LENGTH];
532        pipeline_cache::add_cache_header(
533            &mut header_contents,
534            &vec,
535            &self.device.adapter.raw.info,
536            validation_key,
537        );
538
539        let deleted = vec.splice(..0, header_contents).collect::<Vec<_>>();
540        debug_assert!(deleted.is_empty());
541
542        Some(vec)
543    }
544}
545
546/// Describes how the vertex buffer is interpreted.
547#[derive(Clone, Debug)]
548#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
549#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
550pub struct VertexBufferLayout<'a> {
551    /// The stride, in bytes, between elements of this buffer.
552    pub array_stride: wgt::BufferAddress,
553    /// How often this vertex buffer is "stepped" forward.
554    pub step_mode: wgt::VertexStepMode,
555    /// The list of attributes which comprise a single vertex.
556    pub attributes: Cow<'a, [wgt::VertexAttribute]>,
557}
558
559/// A null vertex buffer layout that may be placed in unused slots.
560impl Default for VertexBufferLayout<'_> {
561    fn default() -> Self {
562        Self {
563            array_stride: Default::default(),
564            step_mode: Default::default(),
565            attributes: Cow::Borrowed(&[]),
566        }
567    }
568}
569
570/// Describes the vertex process in a render pipeline.
571#[derive(Clone, Debug)]
572#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
573pub struct VertexState<'a, SM = ShaderModuleId> {
574    /// The compiled vertex stage and its entry point.
575    pub stage: ProgrammableStageDescriptor<'a, SM>,
576    /// The format of any vertex buffers used with this pipeline.
577    pub buffers: Cow<'a, [Option<VertexBufferLayout<'a>>]>,
578}
579
580/// cbindgen:ignore
581pub type ResolvedVertexState<'a> = VertexState<'a, Arc<ShaderModule>>;
582
583/// Describes fragment processing in a render pipeline.
584#[derive(Clone, Debug)]
585#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
586pub struct FragmentState<'a, SM = ShaderModuleId> {
587    /// The compiled fragment stage and its entry point.
588    pub stage: ProgrammableStageDescriptor<'a, SM>,
589    /// The effect of draw calls on the color aspect of the output target.
590    pub targets: Cow<'a, [Option<wgt::ColorTargetState>]>,
591}
592
593/// cbindgen:ignore
594pub type ResolvedFragmentState<'a> = FragmentState<'a, Arc<ShaderModule>>;
595
596/// Describes the task shader in a mesh shader pipeline.
597#[derive(Clone, Debug)]
598#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
599pub struct TaskState<'a, SM = ShaderModuleId> {
600    /// The compiled task stage and its entry point.
601    pub stage: ProgrammableStageDescriptor<'a, SM>,
602}
603
604pub type ResolvedTaskState<'a> = TaskState<'a, Arc<ShaderModule>>;
605
606/// Describes the mesh shader in a mesh shader pipeline.
607#[derive(Clone, Debug)]
608#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
609pub struct MeshState<'a, SM = ShaderModuleId> {
610    /// The compiled mesh stage and its entry point.
611    pub stage: ProgrammableStageDescriptor<'a, SM>,
612}
613
614pub type ResolvedMeshState<'a> = MeshState<'a, Arc<ShaderModule>>;
615
616/// Describes a vertex processor for either a conventional or mesh shading
617/// pipeline architecture.
618///
619/// This is not a public API. It is for use by `player` only. The public APIs
620/// are [`VertexState`], [`TaskState`], and [`MeshState`].
621///
622/// cbindgen:ignore
623#[doc(hidden)]
624#[derive(Clone, Debug)]
625#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
626pub enum RenderPipelineVertexProcessor<'a, SM = ShaderModuleId> {
627    Vertex(VertexState<'a, SM>),
628    Mesh(Option<TaskState<'a, SM>>, MeshState<'a, SM>),
629}
630
631/// Describes a render (graphics) pipeline.
632#[derive(Clone, Debug)]
633#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
634pub struct RenderPipelineDescriptor<
635    'a,
636    PLL = PipelineLayoutId,
637    SM = ShaderModuleId,
638    PLC = PipelineCacheId,
639> {
640    pub label: Label<'a>,
641    /// The layout of bind groups for this pipeline.
642    pub layout: Option<PLL>,
643    /// The vertex processing state for this pipeline.
644    pub vertex: VertexState<'a, SM>,
645    /// The properties of the pipeline at the primitive assembly and rasterization level.
646    #[cfg_attr(feature = "serde", serde(default))]
647    pub primitive: wgt::PrimitiveState,
648    /// The effect of draw calls on the depth and stencil aspects of the output target, if any.
649    #[cfg_attr(feature = "serde", serde(default))]
650    pub depth_stencil: Option<wgt::DepthStencilState>,
651    /// The multi-sampling properties of the pipeline.
652    #[cfg_attr(feature = "serde", serde(default))]
653    pub multisample: wgt::MultisampleState,
654    /// The fragment processing state for this pipeline.
655    pub fragment: Option<FragmentState<'a, SM>>,
656    /// If the pipeline will be used with a multiview render pass, this indicates how many array
657    /// layers the attachments will have.
658    pub multiview_mask: Option<NonZeroU32>,
659    /// The pipeline cache to use when creating this pipeline.
660    pub cache: Option<PLC>,
661}
662/// Describes a mesh shader pipeline.
663#[derive(Clone, Debug)]
664#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
665pub struct MeshPipelineDescriptor<
666    'a,
667    PLL = PipelineLayoutId,
668    SM = ShaderModuleId,
669    PLC = PipelineCacheId,
670> {
671    pub label: Label<'a>,
672    /// The layout of bind groups for this pipeline.
673    pub layout: Option<PLL>,
674    /// The task processing state for this pipeline.
675    pub task: Option<TaskState<'a, SM>>,
676    /// The mesh processing state for this pipeline
677    pub mesh: MeshState<'a, SM>,
678    /// The properties of the pipeline at the primitive assembly and rasterization level.
679    #[cfg_attr(feature = "serde", serde(default))]
680    pub primitive: wgt::PrimitiveState,
681    /// The effect of draw calls on the depth and stencil aspects of the output target, if any.
682    #[cfg_attr(feature = "serde", serde(default))]
683    pub depth_stencil: Option<wgt::DepthStencilState>,
684    /// The multi-sampling properties of the pipeline.
685    #[cfg_attr(feature = "serde", serde(default))]
686    pub multisample: wgt::MultisampleState,
687    /// The fragment processing state for this pipeline.
688    pub fragment: Option<FragmentState<'a, SM>>,
689    /// If the pipeline will be used with a multiview render pass, this indicates how many array
690    /// layers the attachments will have.
691    pub multiview: Option<NonZeroU32>,
692    /// The pipeline cache to use when creating this pipeline.
693    pub cache: Option<PLC>,
694}
695
696/// Describes a render (graphics) pipeline, with either conventional or mesh
697/// shading architecture.
698///
699/// This is not a public API. It is for use by `player` only. The public APIs
700/// are [`RenderPipelineDescriptor`] and [`MeshPipelineDescriptor`].
701///
702/// cbindgen:ignore
703#[doc(hidden)]
704#[derive(Clone, Debug)]
705#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
706pub struct GeneralRenderPipelineDescriptor<
707    'a,
708    PLL = PipelineLayoutId,
709    SM = ShaderModuleId,
710    PLC = PipelineCacheId,
711> {
712    pub label: Label<'a>,
713    /// The layout of bind groups for this pipeline.
714    pub layout: Option<PLL>,
715    /// The vertex processing state for this pipeline.
716    pub vertex: RenderPipelineVertexProcessor<'a, SM>,
717    /// The properties of the pipeline at the primitive assembly and rasterization level.
718    #[cfg_attr(feature = "serde", serde(default))]
719    pub primitive: wgt::PrimitiveState,
720    /// The effect of draw calls on the depth and stencil aspects of the output target, if any.
721    #[cfg_attr(feature = "serde", serde(default))]
722    pub depth_stencil: Option<wgt::DepthStencilState>,
723    /// The multi-sampling properties of the pipeline.
724    #[cfg_attr(feature = "serde", serde(default))]
725    pub multisample: wgt::MultisampleState,
726    /// The fragment processing state for this pipeline.
727    pub fragment: Option<FragmentState<'a, SM>>,
728    /// If the pipeline will be used with a multiview render pass, this indicates how many array
729    /// layers the attachments will have.
730    pub multiview_mask: Option<NonZeroU32>,
731    /// The pipeline cache to use when creating this pipeline.
732    pub cache: Option<PLC>,
733}
734impl<'a, PLL, SM, PLC> From<RenderPipelineDescriptor<'a, PLL, SM, PLC>>
735    for GeneralRenderPipelineDescriptor<'a, PLL, SM, PLC>
736{
737    fn from(value: RenderPipelineDescriptor<'a, PLL, SM, PLC>) -> Self {
738        Self {
739            label: value.label,
740            layout: value.layout,
741            vertex: RenderPipelineVertexProcessor::Vertex(value.vertex),
742            primitive: value.primitive,
743            depth_stencil: value.depth_stencil,
744            multisample: value.multisample,
745            fragment: value.fragment,
746            multiview_mask: value.multiview_mask,
747            cache: value.cache,
748        }
749    }
750}
751impl<'a, PLL, SM, PLC> From<MeshPipelineDescriptor<'a, PLL, SM, PLC>>
752    for GeneralRenderPipelineDescriptor<'a, PLL, SM, PLC>
753{
754    fn from(value: MeshPipelineDescriptor<'a, PLL, SM, PLC>) -> Self {
755        Self {
756            label: value.label,
757            layout: value.layout,
758            vertex: RenderPipelineVertexProcessor::Mesh(value.task, value.mesh),
759            primitive: value.primitive,
760            depth_stencil: value.depth_stencil,
761            multisample: value.multisample,
762            fragment: value.fragment,
763            multiview_mask: value.multiview,
764            cache: value.cache,
765        }
766    }
767}
768
769/// Not a public API. For use by `player` only.
770///
771/// cbindgen:ignore
772pub type ResolvedGeneralRenderPipelineDescriptor<'a> =
773    GeneralRenderPipelineDescriptor<'a, Arc<PipelineLayout>, Arc<ShaderModule>, Arc<PipelineCache>>;
774
775#[derive(Clone, Debug)]
776#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
777pub struct PipelineCacheDescriptor<'a> {
778    pub label: Label<'a>,
779    pub data: Option<Cow<'a, [u8]>>,
780    pub fallback: bool,
781}
782
/// Reasons a single color target state of a render pipeline is invalid.
///
/// Wrapped by `CreateRenderPipelineError::ColorState` together with the index
/// of the offending color target.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum ColorStateError {
    /// The color target's format cannot be used as a render attachment.
    #[error("Format {0:?} is not renderable")]
    FormatNotRenderable(wgt::TextureFormat),
    /// Blending was requested for a format that does not support it.
    #[error("Format {0:?} is not blendable")]
    FormatNotBlendable(wgt::TextureFormat),
    /// The format has no color aspect (e.g. it is a depth/stencil format).
    #[error("Format {0:?} does not have a color aspect")]
    FormatNotColor(wgt::TextureFormat),
    /// Fields: requested sample count, format, sample counts guaranteed by the
    /// WebGPU spec, and sample counts this device supports with
    /// `TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES`.
    #[error("Sample count {0} is not supported by format {1:?} on this device. The WebGPU spec guarantees {2:?} samples are supported by this format. With the TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES feature your device supports {3:?}.")]
    InvalidSampleCount(u32, wgt::TextureFormat, Vec<u32>, Vec<u32>),
    /// The numeric type of the pipeline's output format does not match what the
    /// shader writes to this target.
    #[error("Output format {pipeline} is incompatible with the shader {shader}")]
    IncompatibleFormat {
        pipeline: validation::NumericType,
        shader: validation::NumericType,
    },
    /// The write mask contains invalid bits.
    #[error("Invalid write mask {0:?}")]
    InvalidWriteMask(wgt::ColorWrites),
    /// A dual-source blend factor was used on a render target other than the first.
    #[error("Using the blend factor {factor:?} for render target {target} is not possible. Only the first render target may be used when dual-source blending.")]
    BlendFactorOnUnsupportedTarget {
        factor: wgt::BlendFactor,
        target: u32,
    },
    /// A min/max blend operation was combined with a blend factor other than `one`.
    #[error(
        "Blend factor {factor:?} for render target {target} is not valid. Blend factor must be `one` when using min/max blend operations."
    )]
    InvalidMinMaxBlendFactor {
        factor: wgt::BlendFactor,
        target: u32,
    },
}
814
/// Reasons the depth/stencil state of a render pipeline is invalid.
///
/// Converted into `CreateRenderPipelineError::DepthStencilState` via `From`.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DepthStencilStateError {
    /// The depth/stencil format cannot be used as a render attachment.
    #[error("Format {0:?} is not renderable")]
    FormatNotRenderable(wgt::TextureFormat),
    /// The format has neither a depth nor a stencil aspect.
    #[error("Format {0:?} is not a depth/stencil format")]
    FormatNotDepthOrStencil(wgt::TextureFormat),
    /// Depth testing or writing is enabled, but the format has no depth aspect.
    #[error("Format {0:?} does not have a depth aspect, but depth test/write is enabled")]
    FormatNotDepth(wgt::TextureFormat),
    /// Stencil testing or writing is enabled, but the format has no stencil aspect.
    #[error("Format {0:?} does not have a stencil aspect, but stencil test/write is enabled")]
    FormatNotStencil(wgt::TextureFormat),
    /// Fields: requested sample count, format, sample counts guaranteed by the
    /// WebGPU spec, and sample counts this device supports with
    /// `TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES`.
    #[error("Sample count {0} is not supported by format {1:?} on this device. The WebGPU spec guarantees {2:?} samples are supported by this format. With the TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES feature your device supports {3:?}.")]
    InvalidSampleCount(u32, wgt::TextureFormat, Vec<u32>, Vec<u32>),
    /// Depth bias was set while using a non-triangle primitive topology.
    #[error("Depth bias is not compatible with non-triangle topology {0:?}")]
    DepthBiasWithIncompatibleTopology(wgt::PrimitiveTopology),
    /// A depth format was used without specifying the depth compare function.
    #[error("Depth compare function must be specified for depth format {0:?}")]
    MissingDepthCompare(wgt::TextureFormat),
    /// A depth format was used without specifying whether depth writes are enabled.
    #[error("Depth write enabled must be specified for depth format {0:?}")]
    MissingDepthWriteEnabled(wgt::TextureFormat),
}
835
/// Errors that can occur while creating a render pipeline.
///
/// The `WebGpuError` impl classifies `Internal` as an internal error, forwards
/// the classification of wrapped device/resource/feature errors, and reports
/// every other variant as a validation error.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum CreateRenderPipelineError {
    #[error(transparent)]
    ColorAttachment(#[from] ColorAttachmentError),
    #[error(transparent)]
    Device(#[from] DeviceError),
    /// A pipeline layout could not be derived implicitly.
    #[error("Unable to derive an implicit layout")]
    Implicit(#[from] ImplicitLayoutError),
    /// The color target state at index `.0` is invalid; `.1` gives the reason.
    #[error("Color state [{0}] is invalid")]
    ColorState(u8, #[source] ColorStateError),
    #[error("Depth/stencil state is invalid")]
    DepthStencilState(#[from] DepthStencilStateError),
    /// The pipeline's multisample count is not supported.
    #[error("Invalid sample count {0}")]
    InvalidSampleCount(u32),
    #[error("The number of vertex buffers {given} exceeds the limit {limit}")]
    TooManyVertexBuffers { given: u32, limit: u32 },
    #[error("The number of bind groups + vertex buffers {given} exceeds the limit {limit}")]
    TooManyBindGroupsPlusVertexBuffers { given: u32, limit: u32 },
    #[error("The number of vertex-stage buffers and acceleration structures {given} exceeds the limit {limit}")]
    TooManyBuffersAndAccelerationStructuresInVertexStage { given: u32, limit: u32 },
    #[error("The total number of vertex attributes {given} exceeds the limit {limit}")]
    TooManyVertexAttributes { given: u32, limit: u32 },
    #[error("Vertex attribute location {given} must be less than limit {limit}")]
    VertexAttributeLocationTooLarge { given: u32, limit: u32 },
    /// The array stride of vertex buffer `index` is above the device limit.
    #[error("Vertex buffer {index} stride {given} exceeds the limit {limit}")]
    VertexStrideTooLarge { index: u32, given: u32, limit: u32 },
    #[error("Vertex attribute at location {location} stride {given} exceeds the limit {limit}")]
    VertexAttributeStrideTooLarge {
        location: wgt::ShaderLocation,
        given: u32,
        limit: u32,
    },
    /// The array stride of vertex buffer `index` is not a multiple of `VERTEX_ALIGNMENT`.
    #[error("Vertex buffer {index} stride {stride} does not respect `VERTEX_ALIGNMENT`")]
    UnalignedVertexStride {
        index: u32,
        stride: wgt::BufferAddress,
    },
    #[error("Vertex attribute at location {location} has invalid offset {offset}")]
    InvalidVertexAttributeOffset {
        location: wgt::ShaderLocation,
        offset: wgt::BufferAddress,
    },
    /// Two or more vertex attributes share the given shader location.
    #[error("Two or more vertex attributes were assigned to the same location in the shader: {0}")]
    ShaderLocationClash(u32),
    /// A strip index format was given for a list (non-strip) topology.
    #[error("Strip index format was not set to None but to {strip_index_format:?} while using the non-strip topology {topology:?}")]
    StripIndexFormatForNonStripTopology {
        strip_index_format: Option<wgt::IndexFormat>,
        topology: wgt::PrimitiveTopology,
    },
    #[error("Conservative Rasterization is only supported for wgt::PolygonMode::Fill")]
    ConservativeRasterizationNonFillPolygonMode,
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
    /// The `stage` shader's requirements do not match the pipeline; see `error`.
    #[error("Error matching {stage:?} shader requirements against the pipeline")]
    Stage {
        stage: wgt::ShaderStages,
        #[source]
        error: validation::StageError,
    },
    /// An internal error occurred while processing the `stage` shader. Unlike the
    /// other variants it is reported as an internal, not a validation, error.
    #[error("Internal error in {stage:?} shader: {error}")]
    Internal {
        stage: wgt::ShaderStages,
        error: String,
    },
    /// An error concerning pipeline constants of the `stage` shader.
    #[error("Pipeline constant error in {stage:?} shader: {error}")]
    PipelineConstants {
        stage: wgt::ShaderStages,
        error: String,
    },
    /// A shader binding type's size is not a multiple of 16 bytes, and the device
    /// lacks `DownlevelFlags::BUFFER_BINDINGS_NOT_16_BYTE_ALIGNED`.
    #[error("In the provided shader, the type given for group {group} binding {binding} has a size of {size}. As the device does not support `DownlevelFlags::BUFFER_BINDINGS_NOT_16_BYTE_ALIGNED`, the type must have a size that is a multiple of 16 bytes.")]
    UnalignedShader { group: u32, binding: u32, size: u64 },
    #[error("Dual-source blending requires exactly one color target, but {count} color targets are present")]
    DualSourceBlendingWithMultipleColorTargets { count: usize },
    /// Neither a color target nor a depth/stencil target was specified.
    #[error("{}", concat!(
        "At least one color attachment or depth-stencil attachment was expected, ",
        "but no render target for the pipeline was specified."
    ))]
    NoTargetSpecified,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
}
920
921impl WebGpuError for CreateRenderPipelineError {
922    fn webgpu_error_type(&self) -> ErrorType {
923        match self {
924            Self::Device(e) => e.webgpu_error_type(),
925            Self::InvalidResource(e) => e.webgpu_error_type(),
926            Self::MissingFeatures(e) => e.webgpu_error_type(),
927            Self::MissingDownlevelFlags(e) => e.webgpu_error_type(),
928
929            Self::Internal { .. } => ErrorType::Internal,
930
931            Self::ColorAttachment(_)
932            | Self::Implicit(_)
933            | Self::ColorState(_, _)
934            | Self::DepthStencilState(_)
935            | Self::InvalidSampleCount(_)
936            | Self::TooManyVertexBuffers { .. }
937            | Self::TooManyBindGroupsPlusVertexBuffers { .. }
938            | Self::TooManyBuffersAndAccelerationStructuresInVertexStage { .. }
939            | Self::TooManyVertexAttributes { .. }
940            | Self::VertexAttributeLocationTooLarge { .. }
941            | Self::VertexStrideTooLarge { .. }
942            | Self::UnalignedVertexStride { .. }
943            | Self::InvalidVertexAttributeOffset { .. }
944            | Self::ShaderLocationClash(_)
945            | Self::StripIndexFormatForNonStripTopology { .. }
946            | Self::ConservativeRasterizationNonFillPolygonMode
947            | Self::Stage { .. }
948            | Self::UnalignedShader { .. }
949            | Self::DualSourceBlendingWithMultipleColorTargets { .. }
950            | Self::NoTargetSpecified
951            | Self::PipelineConstants { .. }
952            | Self::VertexAttributeStrideTooLarge { .. } => ErrorType::Validation,
953        }
954    }
955}
956
bitflags::bitflags! {
    /// Boolean properties of a render pipeline, stored in `RenderPipeline::flags`.
    // NOTE(review): the per-flag descriptions below are inferred from the names;
    // confirm against the code that sets and reads them.
    #[repr(transparent)]
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub struct PipelineFlags: u32 {
        /// Presumably: the pipeline's blending uses the blend constant.
        const BLEND_CONSTANT = 1 << 0;
        /// Presumably: the pipeline's stencil state uses the stencil reference value.
        const STENCIL_REFERENCE = 1 << 1;
        /// Presumably: the pipeline writes to the depth aspect.
        const WRITES_DEPTH = 1 << 2;
        /// Presumably: the pipeline writes to the stencil aspect.
        const WRITES_STENCIL = 1 << 3;
    }
}
967
/// How a render pipeline will retrieve attributes from a particular vertex buffer.
///
/// The `Default` value has zero stride and steps per vertex.
#[derive(Clone, Copy, Debug)]
pub struct VertexStep {
    /// The byte stride in the buffer between one attribute value and the next.
    pub stride: wgt::BufferAddress,

    /// The byte size required to fit the last vertex in the stream.
    // NOTE(review): presumably used to compute how many vertices/instances a bound
    // buffer can supply; confirm at the use site.
    pub last_stride: wgt::BufferAddress,

    /// Whether the buffer is indexed by vertex number or instance number.
    pub mode: wgt::VertexStepMode,
}
980
981impl Default for VertexStep {
982    fn default() -> Self {
983        Self {
984            stride: 0,
985            last_stride: 0,
986            mode: wgt::VertexStepMode::Vertex,
987        }
988    }
989}
990
/// Data that only exists for a valid render pipeline (`ResourceState::Valid`).
#[derive(Debug)]
pub(crate) struct RenderPipelineState {
    /// The HAL pipeline object. Wrapped in `ManuallyDrop` so that the
    /// `RenderPipeline` `Drop` impl can move it out and pass it to
    /// `destroy_render_pipeline`.
    pub(crate) raw: ManuallyDrop<Box<dyn hal::DynRenderPipeline>>,
    /// The pipeline's layout, used e.g. to look up its bind group layouts.
    pub(crate) layout: Arc<PipelineLayout>,
}
996
/// A render pipeline resource.
///
/// If `state` is `ResourceState::Invalid` (as produced by `RenderPipeline::invalid`),
/// the remaining fields hold empty placeholder values.
#[derive(Debug)]
pub struct RenderPipeline {
    /// HAL pipeline and layout when valid; `ResourceState::Invalid` otherwise.
    pub(crate) state: ResourceState<RenderPipelineState>,
    pub(crate) device: Arc<Device>,
    // Not read anywhere in this chunk; presumably held only to keep the shader
    // modules alive for as long as the pipeline — confirm.
    pub(crate) _shader_modules: ArrayVec<Arc<ShaderModule>, { hal::MAX_CONCURRENT_SHADER_STAGES }>,
    /// Attachment formats, sample count and multiview mask associated with the pipeline.
    pub(crate) pass_context: RenderPassContext,
    /// Boolean pipeline properties; see `PipelineFlags`.
    pub(crate) flags: PipelineFlags,
    pub(crate) topology: wgt::PrimitiveTopology,
    pub(crate) strip_index_format: Option<wgt::IndexFormat>,
    /// One entry per vertex buffer slot.
    // NOTE(review): `None` presumably marks a slot with no vertex buffer layout — confirm.
    pub(crate) vertex_steps: Vec<Option<VertexStep>>,
    /// Per bind group, the buffer binding sizes that must be validated at draw time.
    pub(crate) late_sized_buffer_groups: ArrayVec<LateSizedBufferGroup, { hal::MAX_BIND_GROUPS }>,
    /// Immediate-data slots required by the pipeline, as reported by naga validation.
    // NOTE(review): confirm how this is populated at creation time.
    pub(crate) immediate_slots_required: naga::valid::ImmediateSlots,
    /// The `label` from the descriptor used to create the resource.
    pub(crate) label: String,
    pub(crate) tracking_data: TrackingData,
    /// Whether this is a mesh shader pipeline
    pub(crate) is_mesh: bool,
    /// Whether the pipeline includes a task shader.
    // NOTE(review): presumably only meaningful when `is_mesh` is true — confirm.
    pub(crate) has_task_shader: bool,
}
1016
1017impl Drop for RenderPipeline {
1018    fn drop(&mut self) {
1019        resource_log!("Destroy raw {}", self.error_ident());
1020        #[cfg(feature = "trace")]
1021        {
1022            use crate::device::trace;
1023            if let Some(t) = self.device.trace.lock().as_mut() {
1024                t.add(trace::Action::DropRenderPipeline(unsafe {
1025                    trace::to_trace(self)
1026                }));
1027            }
1028        }
1029        let ResourceState::Valid(state) = &mut self.state else {
1030            return;
1031        };
1032        // SAFETY: We are in the Drop impl and we don't use self.raw anymore after this point.
1033        let raw = unsafe { ManuallyDrop::take(&mut state.raw) };
1034        unsafe {
1035            self.device.raw().destroy_render_pipeline(raw);
1036        }
1037    }
1038}
1039
// Crate-level helper macros that (judging by their names) generate the resource-type,
// labeling, parent-device, storage-item and tracking trait impls for `RenderPipeline`.
crate::impl_resource_type!(RenderPipeline);
crate::impl_labeled!(RenderPipeline);
crate::impl_parent_device!(RenderPipeline);
crate::impl_storage_item!(RenderPipeline);
crate::impl_trackable!(RenderPipeline);
1045
1046impl RenderPipeline {
1047    pub(crate) fn raw(&self) -> Result<&dyn hal::DynRenderPipeline, InvalidResourceError> {
1048        let ResourceState::Valid(state) = &self.state else {
1049            return Err(InvalidResourceError(self.error_ident()));
1050        };
1051        Ok(state.raw.as_ref())
1052    }
1053
1054    pub(crate) fn layout(&self) -> Result<&Arc<PipelineLayout>, InvalidResourceError> {
1055        let ResourceState::Valid(state) = &self.state else {
1056            return Err(InvalidResourceError(self.error_ident()));
1057        };
1058        Ok(&state.layout)
1059    }
1060
1061    pub(crate) fn check_valid(&self) -> Result<(), InvalidResourceError> {
1062        let ResourceState::Valid(_) = &self.state else {
1063            return Err(InvalidResourceError(self.error_ident()));
1064        };
1065        Ok(())
1066    }
1067
1068    pub(crate) fn invalid(device: Arc<Device>, label: String) -> Arc<Self> {
1069        Arc::new(Self {
1070            tracking_data: TrackingData::new(device.tracker_indices.render_pipelines.clone()),
1071            state: ResourceState::Invalid,
1072            device,
1073            _shader_modules: ArrayVec::new(),
1074            pass_context: RenderPassContext {
1075                attachments: AttachmentData {
1076                    colors: ArrayVec::new(),
1077                    resolves: ArrayVec::new(),
1078                    depth_stencil: None,
1079                },
1080                sample_count: 0,
1081                multiview_mask: None,
1082            },
1083            flags: PipelineFlags::empty(),
1084            topology: wgt::PrimitiveTopology::TriangleList,
1085            strip_index_format: None,
1086            vertex_steps: Vec::new(),
1087            late_sized_buffer_groups: ArrayVec::new(),
1088            immediate_slots_required: naga::valid::ImmediateSlots::default(),
1089            label,
1090            is_mesh: false,
1091            has_task_shader: false,
1092        })
1093    }
1094
1095    pub fn get_bind_group_layout_inner(
1096        self: &Arc<Self>,
1097        index: u32,
1098    ) -> Result<Arc<BindGroupLayout>, GetBindGroupLayoutError> {
1099        self.layout()?.get_bind_group_layout(index, self.into())
1100    }
1101
1102    pub fn get_bind_group_layout(
1103        self: &Arc<Self>,
1104        index: u32,
1105    ) -> (Arc<BindGroupLayout>, Option<GetBindGroupLayoutError>) {
1106        let (bgl, error) = match self.get_bind_group_layout_inner(index) {
1107            Ok(bgl) => (bgl, None),
1108            Err(e) => (
1109                BindGroupLayout::invalid(&self.device, String::new()),
1110                Some(e),
1111            ),
1112        };
1113        #[cfg(feature = "trace")]
1114        if let Some(ref mut trace) = *self.device.trace.lock() {
1115            use crate::device::trace;
1116            use trace::IntoTrace;
1117            trace.add(trace::Action::GetRenderPipelineBindGroupLayout {
1118                id: bgl.to_trace(),
1119                pipeline: self.to_trace(),
1120                index,
1121            });
1122        };
1123        (bgl, error)
1124    }
1125}