1use alloc::{
2 boxed::Box,
3 string::{String, ToString as _},
4 sync::Arc,
5 vec::Vec,
6};
7use core::fmt;
8
9use arrayvec::ArrayVec;
10use hashbrown::{hash_map::Entry, HashSet};
11use shader_io_deductions::{display_deductions_as_optional_list, MaxVertexShaderOutputDeduction};
12use thiserror::Error;
13use wgt::{
14 error::{ErrorType, WebGpuError},
15 BindGroupLayoutEntry, BindingType,
16};
17
18use crate::{
19 command::ColorAttachmentError, device::bgl, resource::InvalidResourceError,
20 validation::shader_io_deductions::MaxFragmentShaderInputDeduction, FastHashMap, FastHashSet,
21};
22
23pub mod shader_io_deductions;
24
/// Kind of resource a shader global refers to, as reflected from the naga module.
///
/// `Resource::check_binding_use` compares it against a bind group layout entry, and
/// `Resource::derive_binding_type` turns it into a layout entry type when the layout is
/// derived from the shaders.
#[derive(Debug)]
enum ResourceType {
    /// A uniform or storage buffer. `size` is the shader-side size that a layout entry's
    /// `min_binding_size` must not be smaller than.
    Buffer {
        size: wgt::BufferSize,
    },
    /// A sampled, depth, storage or external texture (distinguished by `class`).
    Texture {
        dim: naga::ImageDimension,
        arrayed: bool,
        class: naga::ImageClass,
    },
    /// A sampler; `comparison` is set for comparison samplers.
    Sampler {
        comparison: bool,
    },
    /// An acceleration structure; `vertex_return` must equal the layout entry's flag.
    AccelerationStructure {
        vertex_return: bool,
    },
}
42
/// Coarse binding category, used to describe kind mismatches in
/// [`BindingError::WrongType`].
///
/// Sampled and storage textures both map to `Texture`; external textures are reported
/// separately as `ExternalTexture`.
#[derive(Clone, Debug)]
pub enum BindingTypeName {
    Buffer,
    Texture,
    Sampler,
    AccelerationStructure,
    ExternalTexture,
}
51
52impl From<&ResourceType> for BindingTypeName {
53 fn from(ty: &ResourceType) -> BindingTypeName {
54 match ty {
55 ResourceType::Buffer { .. } => BindingTypeName::Buffer,
56 ResourceType::Texture {
57 class: naga::ImageClass::External,
58 ..
59 } => BindingTypeName::ExternalTexture,
60 ResourceType::Texture { .. } => BindingTypeName::Texture,
61 ResourceType::Sampler { .. } => BindingTypeName::Sampler,
62 ResourceType::AccelerationStructure { .. } => BindingTypeName::AccelerationStructure,
63 }
64 }
65}
66
67impl From<&BindingType> for BindingTypeName {
68 fn from(ty: &BindingType) -> BindingTypeName {
69 match ty {
70 BindingType::Buffer { .. } => BindingTypeName::Buffer,
71 BindingType::Texture { .. } => BindingTypeName::Texture,
72 BindingType::StorageTexture { .. } => BindingTypeName::Texture,
73 BindingType::Sampler { .. } => BindingTypeName::Sampler,
74 BindingType::AccelerationStructure { .. } => BindingTypeName::AccelerationStructure,
75 BindingType::ExternalTexture => BindingTypeName::ExternalTexture,
76 }
77 }
78}
79
/// A shader global that occupies a bind group slot, as reflected from the naga module.
#[derive(Debug)]
struct Resource {
    // Not read by the validation code visible here (hence `allow(unused)`); it still
    // appears in the derived `Debug` output.
    #[allow(unused)]
    name: Option<String>,
    /// Group/binding pair declared by the shader.
    bind: naga::ResourceBinding,
    /// What kind of resource the global is.
    ty: ResourceType,
    /// Address space of the global; for buffers it must match the layout entry's
    /// uniform/storage type and access.
    class: naga::AddressSpace,
}
88
/// Shape of a numeric shader type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum NumericDimension {
    /// A single scalar.
    Scalar,
    /// A vector with the given number of components.
    Vector(naga::VectorSize),
    /// A matrix, stored as (columns, rows).
    Matrix(naga::VectorSize, naga::VectorSize),
}
95
96impl fmt::Display for NumericDimension {
97 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
98 match *self {
99 Self::Scalar => write!(f, ""),
100 Self::Vector(size) => write!(f, "x{}", size as u8),
101 Self::Matrix(columns, rows) => write!(f, "x{}{}", columns as u8, rows as u8),
102 }
103 }
104}
105
/// A scalar, vector or matrix numeric type. Used to compare shader interface types with
/// vertex formats and texture formats.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NumericType {
    dim: NumericDimension,
    scalar: naga::Scalar,
}
111
112impl fmt::Display for NumericType {
113 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
114 write!(
115 f,
116 "{:?}{}{}",
117 self.scalar.kind,
118 self.scalar.width * 8,
119 self.dim
120 )
121 }
122}
123
/// A user-defined inter-stage variable (or a vertex attribute) as seen by validation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InterfaceVar {
    /// Numeric type of the variable.
    pub ty: NumericType,
    /// Interpolation qualifier; `None` for vertex attributes.
    interpolation: Option<naga::Interpolation>,
    /// Sampling qualifier; `None` for vertex attributes.
    sampling: Option<naga::Sampling>,
    /// Whether the variable is per-primitive; always `false` for vertex attributes.
    per_primitive: bool,
}
131
132impl InterfaceVar {
133 pub fn vertex_attribute(format: wgt::VertexFormat) -> Self {
134 InterfaceVar {
135 ty: NumericType::from_vertex_format(format),
136 interpolation: None,
137 sampling: None,
138 per_primitive: false,
139 }
140 }
141}
142
143impl fmt::Display for InterfaceVar {
144 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
145 write!(
146 f,
147 "{} interpolated as {:?} with sampling {:?}",
148 self.ty, self.interpolation, self.sampling
149 )
150 }
151}
152
/// One input or output of an entry point.
#[derive(Debug, Eq, PartialEq)]
enum Varying {
    /// A user-defined variable bound at `location`.
    Local { location: u32, iv: InterfaceVar },
    /// A built-in value.
    BuiltIn(BuiltIn),
}
158
/// Built-in values an entry point reads or writes.
///
/// Mirrors `naga::BuiltIn` one-to-one (see `BuiltIn::to_naga`), except that
/// `ClipDistances` additionally records the declared array size, which is dropped on
/// conversion.
#[derive(Clone, Debug, Eq, PartialEq)]
enum BuiltIn {
    /// Forwarded to `naga::BuiltIn::Position { invariant }`.
    Position { invariant: bool },
    ViewIndex,
    BaseInstance,
    BaseVertex,
    /// `array_size` is kept here only; naga's `ClipDistances` has no size.
    ClipDistances { array_size: u32 },
    CullDistance,
    InstanceIndex,
    PointSize,
    VertexIndex,
    DrawIndex,
    FragDepth,
    PointCoord,
    FrontFacing,
    PrimitiveIndex,
    /// Forwarded to `naga::BuiltIn::Barycentric { perspective }`.
    Barycentric { perspective: bool },
    SampleIndex,
    SampleMask,
    GlobalInvocationId,
    LocalInvocationId,
    LocalInvocationIndex,
    WorkGroupId,
    WorkGroupSize,
    NumWorkGroups,
    NumSubgroups,
    SubgroupId,
    SubgroupSize,
    SubgroupInvocationId,
    MeshTaskSize,
    CullPrimitive,
    PointIndex,
    LineIndices,
    TriangleIndices,
    VertexCount,
    Vertices,
    PrimitiveCount,
    Primitives,
    RayInvocationId,
    NumRayInvocations,
    InstanceCustomData,
    GeometryIndex,
    WorldRayOrigin,
    WorldRayDirection,
    ObjectRayOrigin,
    ObjectRayDirection,
    RayTmin,
    RayTCurrentMax,
    ObjectToWorld,
    WorldToObject,
    HitKind,
}
211
212impl BuiltIn {
213 pub fn to_naga(&self) -> naga::BuiltIn {
214 match self {
215 &Self::Position { invariant } => naga::BuiltIn::Position { invariant },
216 Self::ViewIndex => naga::BuiltIn::ViewIndex,
217 Self::BaseInstance => naga::BuiltIn::BaseInstance,
218 Self::BaseVertex => naga::BuiltIn::BaseVertex,
219 Self::ClipDistances { .. } => naga::BuiltIn::ClipDistances,
220 Self::CullDistance => naga::BuiltIn::CullDistance,
221 Self::InstanceIndex => naga::BuiltIn::InstanceIndex,
222 Self::PointSize => naga::BuiltIn::PointSize,
223 Self::VertexIndex => naga::BuiltIn::VertexIndex,
224 Self::DrawIndex => naga::BuiltIn::DrawIndex,
225 Self::FragDepth => naga::BuiltIn::FragDepth,
226 Self::PointCoord => naga::BuiltIn::PointCoord,
227 Self::FrontFacing => naga::BuiltIn::FrontFacing,
228 Self::PrimitiveIndex => naga::BuiltIn::PrimitiveIndex,
229 Self::Barycentric { perspective } => naga::BuiltIn::Barycentric {
230 perspective: *perspective,
231 },
232 Self::SampleIndex => naga::BuiltIn::SampleIndex,
233 Self::SampleMask => naga::BuiltIn::SampleMask,
234 Self::GlobalInvocationId => naga::BuiltIn::GlobalInvocationId,
235 Self::LocalInvocationId => naga::BuiltIn::LocalInvocationId,
236 Self::LocalInvocationIndex => naga::BuiltIn::LocalInvocationIndex,
237 Self::WorkGroupId => naga::BuiltIn::WorkGroupId,
238 Self::WorkGroupSize => naga::BuiltIn::WorkGroupSize,
239 Self::NumWorkGroups => naga::BuiltIn::NumWorkGroups,
240 Self::NumSubgroups => naga::BuiltIn::NumSubgroups,
241 Self::SubgroupId => naga::BuiltIn::SubgroupId,
242 Self::SubgroupSize => naga::BuiltIn::SubgroupSize,
243 Self::SubgroupInvocationId => naga::BuiltIn::SubgroupInvocationId,
244 Self::MeshTaskSize => naga::BuiltIn::MeshTaskSize,
245 Self::CullPrimitive => naga::BuiltIn::CullPrimitive,
246 Self::PointIndex => naga::BuiltIn::PointIndex,
247 Self::LineIndices => naga::BuiltIn::LineIndices,
248 Self::TriangleIndices => naga::BuiltIn::TriangleIndices,
249 Self::VertexCount => naga::BuiltIn::VertexCount,
250 Self::Vertices => naga::BuiltIn::Vertices,
251 Self::PrimitiveCount => naga::BuiltIn::PrimitiveCount,
252 Self::Primitives => naga::BuiltIn::Primitives,
253 Self::RayInvocationId => naga::BuiltIn::RayInvocationId,
254 Self::NumRayInvocations => naga::BuiltIn::NumRayInvocations,
255 Self::InstanceCustomData => naga::BuiltIn::InstanceCustomData,
256 Self::GeometryIndex => naga::BuiltIn::GeometryIndex,
257 Self::WorldRayOrigin => naga::BuiltIn::WorldRayOrigin,
258 Self::WorldRayDirection => naga::BuiltIn::WorldRayDirection,
259 Self::ObjectRayOrigin => naga::BuiltIn::ObjectRayOrigin,
260 Self::ObjectRayDirection => naga::BuiltIn::ObjectRayDirection,
261 Self::RayTmin => naga::BuiltIn::RayTmin,
262 Self::RayTCurrentMax => naga::BuiltIn::RayTCurrentMax,
263 Self::ObjectToWorld => naga::BuiltIn::ObjectToWorld,
264 Self::WorldToObject => naga::BuiltIn::WorldToObject,
265 Self::HitKind => naga::BuiltIn::HitKind,
266 }
267 }
268}
269
/// A specialization constant: its numeric `id` and type.
// Fields are not read anywhere (hence `allow(unused)`); `EntryPoint::spec_constants`,
// which holds these, is likewise marked unused.
// NOTE(review): confirm whether this type is still needed.
#[allow(unused)]
#[derive(Debug)]
struct SpecializationConstant {
    id: u32,
    ty: NumericType,
}
276
/// Output limits and topology declared by a mesh shader entry point.
#[derive(Debug)]
struct EntryPointMeshInfo {
    /// Maximum number of output vertices; presumably checked against
    /// `Limits::max_mesh_output_vertices` (see `StageError::TooManyMeshVertices`).
    max_vertices: u32,
    /// Maximum number of output primitives; presumably checked against
    /// `Limits::max_mesh_output_primitives` (see `StageError::TooManyMeshPrimitives`).
    max_primitives: u32,
    /// Topology the mesh shader declares; presumably compared with the pipeline's
    /// (see `StageError::MeshTopologyMismatch`).
    primitive_topology: wgt::PrimitiveTopology,
}
283
/// Reflection data for a single entry point of a shader module.
#[derive(Debug, Default)]
struct EntryPoint {
    /// Values the entry point reads.
    inputs: Vec<Varying>,
    /// Values the entry point writes.
    outputs: Vec<Varying>,
    /// Handles into `Interface::resources`; presumably the globals this entry point
    /// references.
    resources: Vec<naga::Handle<Resource>>,
    // Never read (hence `allow(unused)`).
    #[allow(unused)]
    spec_constants: Vec<SpecializationConstant>,
    /// Pairs of resources used together for sampling — presumably (texture, sampler);
    /// confirm against where this set is populated.
    sampling_pairs: FastHashSet<(naga::Handle<Resource>, naga::Handle<Resource>)>,
    /// Workgroup size in x, y, z.
    workgroup_size: [u32; 3],
    /// Whether the entry point supports dual-source blending
    /// (see `StageError::InvalidDualSourceBlending`).
    dual_source_blending: bool,
    /// Size of the task payload, if the entry point declares one.
    task_payload_size: Option<u32>,
    /// Mesh shader output information; presumably `None` for non-mesh entry points.
    mesh_info: Option<EntryPointMeshInfo>,
    /// Immediate-data slots required by the entry point.
    immediate_slots_required: naga::valid::ImmediateSlots,
}
298
/// Reflection data for a naga shader module, used to validate pipelines against it.
#[derive(Debug)]
pub struct Interface {
    /// Limits captured when the interface was built; presumably the device's limits.
    limits: wgt::Limits,
    /// All resource globals of the module; `EntryPoint::resources` holds handles into it.
    resources: naga::Arena<Resource>,
    /// Entry points, keyed by shader stage and entry point name.
    entry_points: FastHashMap<(naga::ShaderStage, String), EntryPoint>,
    /// Size of the module's immediate data (presumably in bytes — confirm).
    pub(crate) immediate_size: u32,
}
306
/// Metadata for a passthrough shader module, for which only entry point names are
/// tracked.
#[derive(Debug)]
pub struct PassthroughInterface {
    /// Names of the module's entry points.
    pub entry_point_names: HashSet<String>,
}
311
/// Metadata kept alongside a shader module: full reflection data, or only the entry point
/// names for passthrough modules.
// `Interface` (which holds `wgt::Limits`, an arena and a map) is much larger than
// `PassthroughInterface` (one set); the size difference is accepted rather than boxing.
#[expect(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum ShaderMetaData {
    Interface(Interface),
    Passthrough(PassthroughInterface),
}
321impl ShaderMetaData {
322 pub fn interface(&self) -> Option<&Interface> {
323 match self {
324 Self::Interface(i) => Some(i),
325 Self::Passthrough(_) => None,
326 }
327 }
328}
329
/// Reasons a shader resource is incompatible with a bind group layout entry.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum BindingError {
    /// No layout entry exists for the shader's group/binding.
    #[error("Binding is missing from the pipeline layout")]
    Missing,
    /// The entry exists but is not visible to this shader stage.
    #[error("Visibility flags don't include the shader stage")]
    Invisible,
    /// The layout declares a different kind of binding than the shader uses.
    #[error(
        "Type on the shader side ({shader:?}) does not match the pipeline binding ({binding:?})"
    )]
    WrongType {
        binding: BindingTypeName,
        shader: BindingTypeName,
    },
    /// Buffer uniform/storage class or storage access differs from the shader's.
    #[error("Storage class {binding:?} doesn't match the shader {shader:?}")]
    WrongAddressSpace {
        binding: naga::AddressSpace,
        shader: naga::AddressSpace,
    },
    /// A buffer global lives in an address space that has no buffer binding type.
    #[error("Address space {space:?} is not a valid Buffer address space")]
    WrongBufferAddressSpace { space: naga::AddressSpace },
    /// The layout's `min_binding_size` is smaller than the shader-side buffer size.
    #[error("Buffer structure size {buffer_size}, added to one element of an unbound array, if it's the last field, ended up greater than the given `min_binding_size`, which is {min_binding_size}")]
    WrongBufferSize {
        buffer_size: wgt::BufferSize,
        min_binding_size: wgt::BufferSize,
    },
    /// Texture dimension or arrayness differs from the layout's view dimension.
    #[error("View dimension {dim:?} (is array: {is_array}) doesn't match the binding {binding:?}")]
    WrongTextureViewDimension {
        dim: naga::ImageDimension,
        is_array: bool,
        binding: BindingType,
    },
    /// Sample type, multisampling, storage format/access or external-ness differs.
    #[error("Texture class {binding:?} doesn't match the shader {shader:?}")]
    WrongTextureClass {
        binding: naga::ImageClass,
        shader: naga::ImageClass,
    },
    /// The sampler comparison flag differs between layout and shader.
    #[error("Comparison flag doesn't match the shader")]
    WrongSamplerComparison,
    /// Two stages would derive different layout entries for the same binding.
    #[error("Derived bind group layout type is not consistent between stages")]
    InconsistentlyDerivedType,
    /// The storage texture format has no naga storage format equivalent.
    #[error("Texture format {0:?} is not supported for storage use")]
    BadStorageFormat(wgt::TextureFormat),
}
374
375impl WebGpuError for BindingError {
376 fn webgpu_error_type(&self) -> ErrorType {
377 ErrorType::Validation
378 }
379}
380
/// Why a texture cannot be sampled through a filtering sampler.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum FilteringError {
    /// The texture has an integer sample type.
    #[error("Integer textures can't be sampled with a filtering sampler")]
    Integer,
    /// The texture is a non-filterable float texture.
    #[error("Non-filterable float textures can't be sampled with a filtering sampler")]
    Float,
}
389
390impl WebGpuError for FilteringError {
391 fn webgpu_error_type(&self) -> ErrorType {
392 ErrorType::Validation
393 }
394}
395
/// Reasons a stage input does not match what the previous stage provides.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum InputError {
    /// Nothing is provided at this input's location.
    #[error("Input is not provided by the earlier stage in the pipeline")]
    Missing,
    /// The provided type (carried here) is not compatible with the input's type.
    #[error("Input type is not compatible with the provided {0}")]
    WrongType(NumericType),
    /// Interpolation qualifiers differ; carries the provided one.
    #[error("Input interpolation doesn't match provided {0:?}")]
    InterpolationMismatch(Option<naga::Interpolation>),
    /// Sampling qualifiers differ; carries the provided one.
    #[error("Input sampling doesn't match provided {0:?}")]
    SamplingMismatch(Option<naga::Sampling>),
    /// Per-primitive flags differ between the provided value and the shader input.
    #[error("Pipeline input has per_primitive={pipeline_input}, but shader expects per_primitive={shader}")]
    WrongPerPrimitive { pipeline_input: bool, shader: bool },
}
410
411impl WebGpuError for InputError {
412 fn webgpu_error_type(&self) -> ErrorType {
413 ErrorType::Validation
414 }
415}
416
/// Errors found while validating one programmable stage of a pipeline against its shader
/// module.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum StageError {
    #[error(transparent)]
    InvalidWorkgroupSize(#[from] InvalidWorkgroupSizeError),
    /// No entry point with the requested name exists in the module.
    #[error("Unable to find entry point '{0}'")]
    MissingEntryPoint(String),
    /// A shader global is missing from, or incompatible with, the pipeline layout.
    #[error("Shader global {0:?} is not available in the pipeline layout")]
    Binding(naga::ResourceBinding, #[source] BindingError),
    /// A texture is sampled through a sampler it cannot be filtered by.
    #[error("Unable to filter the texture ({texture:?}) by the sampler ({sampler:?})")]
    Filtering {
        texture: naga::ResourceBinding,
        sampler: naga::ResourceBinding,
        #[source]
        error: FilteringError,
    },
    /// A stage input does not match the previous stage's outputs.
    #[error("Location[{location}] {var} is not provided by the previous stage outputs")]
    Input {
        location: wgt::ShaderLocation,
        var: InterfaceVar,
        #[source]
        error: InputError,
    },
    #[error(
        "Unable to select an entry point: no entry point was found in the provided shader module"
    )]
    NoEntryPointFound,
    #[error(
        "Unable to select an entry point: \
        multiple entry points were found in the provided shader module, \
        but no entry point was specified"
    )]
    MultipleEntryPointsFound,
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
    // `limit` is a count of usable locations; `limit - 1` reports the largest valid
    // 0-based location. NOTE(review): formatting would underflow (panic in debug builds)
    // if `limit` were ever 0 — confirm callers cannot construct that.
    #[error(
        "vertex shader output location Location[{location}] ({var}) exceeds the \
        `max_inter_stage_shader_variables` limit ({}, 0-based){}",
        limit - 1,
        display_deductions_as_optional_list(deductions, |d| d.for_location())
    )]
    VertexOutputLocationTooLarge {
        location: u32,
        var: InterfaceVar,
        limit: u32,
        deductions: Vec<MaxVertexShaderOutputDeduction>,
    },
    #[error(
        "found {num_found} user-defined vertex shader output variables, which exceeds the \
        `max_inter_stage_shader_variables` limit ({limit}){}",
        display_deductions_as_optional_list(deductions, |d| d.for_variables())
    )]
    TooManyUserDefinedVertexOutputs {
        num_found: u32,
        limit: u32,
        deductions: Vec<MaxVertexShaderOutputDeduction>,
    },
    // Same `limit - 1` convention (and the same 0-underflow caveat) as
    // `VertexOutputLocationTooLarge`. NOTE(review): this location error formats deductions
    // with `for_variables`, whereas the vertex variant uses `for_location` — confirm intended.
    #[error(
        "fragment shader input location Location[{location}] ({var}) exceeds the \
        `max_inter_stage_shader_variables` limit ({}, 0-based){}",
        limit - 1,
        display_deductions_as_optional_list(deductions, |d| d.for_variables())
    )]
    FragmentInputLocationTooLarge {
        location: u32,
        var: InterfaceVar,
        limit: u32,
        deductions: Vec<MaxFragmentShaderInputDeduction>,
    },
    #[error(
        "found {num_found} user-defined fragment shader input variables, which exceeds the \
        `max_inter_stage_shader_variables` limit ({limit}){}",
        display_deductions_as_optional_list(deductions, |d| d.for_variables())
    )]
    TooManyUserDefinedFragmentInputs {
        num_found: u32,
        limit: u32,
        deductions: Vec<MaxFragmentShaderInputDeduction>,
    },
    #[error(
        "Location[{location}] {var}'s index exceeds the `max_color_attachments` limit ({limit})"
    )]
    ColorAttachmentLocationTooLarge {
        location: u32,
        var: InterfaceVar,
        limit: u32,
    },
    #[error("Mesh shaders are limited to {limit} output vertices by `Limits::max_mesh_output_vertices`, but the shader has a maximum number of {value}")]
    TooManyMeshVertices { limit: u32, value: u32 },
    #[error("Mesh shaders are limited to {limit} output primitives by `Limits::max_mesh_output_primitives`, but the shader has a maximum number of {value}")]
    TooManyMeshPrimitives { limit: u32, value: u32 },
    #[error("Mesh or task shaders are limited to {limit} bytes of task payload by `Limits::max_task_payload_size`, but the shader has a task payload of size {value}")]
    TaskPayloadTooLarge { limit: u32, value: u32 },
    #[error("Mesh shader's task payload has size ({shader:?}), which doesn't match the payload declared in the task stage ({input:?})")]
    TaskPayloadMustMatch {
        input: Option<u32>,
        shader: Option<u32>,
    },
    #[error("Primitive index can only be used in a fragment shader if the preceding shader was a vertex shader or a mesh shader that writes to primitive index.")]
    InvalidPrimitiveIndex,
    #[error("If a mesh shader writes to primitive index, it must be read by the fragment shader.")]
    MissingPrimitiveIndex,
    #[error("DrawId cannot be used in a mesh shader in a pipeline with a task shader")]
    DrawIdError,
    #[error("Pipeline uses dual-source blending, but the shader does not support it")]
    InvalidDualSourceBlending,
    #[error("Fragment shader writes depth, but pipeline does not have a depth attachment")]
    MissingFragDepthAttachment,
    #[error("Per vertex fragment inputs can only be used in triangle primitive pipelines")]
    PerVertexNotTriangles,
    #[error("Mesh shader pipelines must have primitive topology of TriangleList, LineList or PointList, and this must match with what the mesh shader declares.")]
    MeshTopologyMismatch,
}
536
537impl WebGpuError for StageError {
538 fn webgpu_error_type(&self) -> ErrorType {
539 match self {
540 Self::Binding(_, e) => e.webgpu_error_type(),
541 Self::InvalidResource(e) => e.webgpu_error_type(),
542 Self::Filtering {
543 texture: _,
544 sampler: _,
545 error,
546 } => error.webgpu_error_type(),
547 Self::Input {
548 location: _,
549 var: _,
550 error,
551 } => error.webgpu_error_type(),
552 Self::InvalidWorkgroupSize { .. }
553 | Self::MissingEntryPoint(..)
554 | Self::NoEntryPointFound
555 | Self::MultipleEntryPointsFound
556 | Self::VertexOutputLocationTooLarge { .. }
557 | Self::TooManyUserDefinedVertexOutputs { .. }
558 | Self::FragmentInputLocationTooLarge { .. }
559 | Self::TooManyUserDefinedFragmentInputs { .. }
560 | Self::ColorAttachmentLocationTooLarge { .. }
561 | Self::TooManyMeshVertices { .. }
562 | Self::TooManyMeshPrimitives { .. }
563 | Self::TaskPayloadTooLarge { .. }
564 | Self::TaskPayloadMustMatch { .. }
565 | Self::InvalidPrimitiveIndex
566 | Self::MissingPrimitiveIndex
567 | Self::DrawIdError
568 | Self::InvalidDualSourceBlending
569 | Self::MissingFragDepthAttachment
570 | Self::PerVertexNotTriangles
571 | Self::MeshTopologyMismatch => ErrorType::Validation,
572 }
573 }
574}
575
576pub use wgpu_naga_bridge::map_storage_format_from_naga;
577pub use wgpu_naga_bridge::map_storage_format_to_naga;
578
579impl Resource {
580 fn check_binding_use(&self, entry: &BindGroupLayoutEntry) -> Result<(), BindingError> {
581 match self.ty {
582 ResourceType::Buffer { size } => {
583 let min_size = match entry.ty {
584 BindingType::Buffer {
585 ty,
586 has_dynamic_offset: _,
587 min_binding_size,
588 } => {
589 let class = match ty {
590 wgt::BufferBindingType::Uniform => naga::AddressSpace::Uniform,
591 wgt::BufferBindingType::Storage { read_only } => {
592 let mut naga_access = naga::StorageAccess::LOAD;
593 naga_access.set(naga::StorageAccess::STORE, !read_only);
594 naga::AddressSpace::Storage {
595 access: naga_access,
596 }
597 }
598 };
599 if self.class != class {
600 return Err(BindingError::WrongAddressSpace {
601 binding: class,
602 shader: self.class,
603 });
604 }
605 min_binding_size
606 }
607 _ => {
608 return Err(BindingError::WrongType {
609 binding: (&entry.ty).into(),
610 shader: (&self.ty).into(),
611 })
612 }
613 };
614 match min_size {
615 Some(non_zero) if non_zero < size => {
616 return Err(BindingError::WrongBufferSize {
617 buffer_size: size,
618 min_binding_size: non_zero,
619 })
620 }
621 _ => (),
622 }
623 }
624 ResourceType::Sampler { comparison } => match entry.ty {
625 BindingType::Sampler(ty) => {
626 if (ty == wgt::SamplerBindingType::Comparison) != comparison {
627 return Err(BindingError::WrongSamplerComparison);
628 }
629 }
630 _ => {
631 return Err(BindingError::WrongType {
632 binding: (&entry.ty).into(),
633 shader: (&self.ty).into(),
634 })
635 }
636 },
637 ResourceType::Texture {
638 dim,
639 arrayed,
640 class: shader_class,
641 } => {
642 let view_dimension = match entry.ty {
643 BindingType::Texture { view_dimension, .. }
644 | BindingType::StorageTexture { view_dimension, .. } => view_dimension,
645 BindingType::ExternalTexture => wgt::TextureViewDimension::D2,
646 _ => {
647 return Err(BindingError::WrongTextureViewDimension {
648 dim,
649 is_array: false,
650 binding: entry.ty,
651 })
652 }
653 };
654 if arrayed {
655 match (dim, view_dimension) {
656 (naga::ImageDimension::D2, wgt::TextureViewDimension::D2Array) => (),
657 (naga::ImageDimension::Cube, wgt::TextureViewDimension::CubeArray) => (),
658 _ => {
659 return Err(BindingError::WrongTextureViewDimension {
660 dim,
661 is_array: true,
662 binding: entry.ty,
663 })
664 }
665 }
666 } else {
667 match (dim, view_dimension) {
668 (naga::ImageDimension::D1, wgt::TextureViewDimension::D1) => (),
669 (naga::ImageDimension::D2, wgt::TextureViewDimension::D2) => (),
670 (naga::ImageDimension::D3, wgt::TextureViewDimension::D3) => (),
671 (naga::ImageDimension::Cube, wgt::TextureViewDimension::Cube) => (),
672 _ => {
673 return Err(BindingError::WrongTextureViewDimension {
674 dim,
675 is_array: false,
676 binding: entry.ty,
677 })
678 }
679 }
680 }
681 match entry.ty {
682 BindingType::Texture {
683 sample_type,
684 view_dimension: _,
685 multisampled: multi,
686 } => {
687 let binding_class = match sample_type {
688 wgt::TextureSampleType::Float { .. } => naga::ImageClass::Sampled {
689 kind: naga::ScalarKind::Float,
690 multi,
691 },
692 wgt::TextureSampleType::Sint => naga::ImageClass::Sampled {
693 kind: naga::ScalarKind::Sint,
694 multi,
695 },
696 wgt::TextureSampleType::Uint => naga::ImageClass::Sampled {
697 kind: naga::ScalarKind::Uint,
698 multi,
699 },
700 wgt::TextureSampleType::Depth => naga::ImageClass::Depth { multi },
701 };
702 if shader_class == binding_class {
703 Ok(())
704 } else {
705 Err(binding_class)
706 }
707 }
708 BindingType::StorageTexture {
709 access: wgt_binding_access,
710 format: wgt_binding_format,
711 view_dimension: _,
712 } => {
713 const LOAD_STORE: naga::StorageAccess =
714 naga::StorageAccess::LOAD.union(naga::StorageAccess::STORE);
715 let binding_format = map_storage_format_to_naga(wgt_binding_format)
716 .ok_or(BindingError::BadStorageFormat(wgt_binding_format))?;
717 let binding_access = match wgt_binding_access {
718 wgt::StorageTextureAccess::ReadOnly => naga::StorageAccess::LOAD,
719 wgt::StorageTextureAccess::WriteOnly => naga::StorageAccess::STORE,
720 wgt::StorageTextureAccess::ReadWrite => LOAD_STORE,
721 wgt::StorageTextureAccess::Atomic => {
722 naga::StorageAccess::ATOMIC | LOAD_STORE
723 }
724 };
725 match shader_class {
726 naga::ImageClass::Storage {
729 format: shader_format,
730 access: shader_access,
731 } if shader_format == binding_format
732 && (shader_access == binding_access
733 || shader_access == naga::StorageAccess::STORE
734 && binding_access == LOAD_STORE) =>
735 {
736 Ok(())
737 }
738 _ => Err(naga::ImageClass::Storage {
739 format: binding_format,
740 access: binding_access,
741 }),
742 }
743 }
744 BindingType::ExternalTexture => {
745 let binding_class = naga::ImageClass::External;
746 if shader_class == binding_class {
747 Ok(())
748 } else {
749 Err(binding_class)
750 }
751 }
752 _ => {
753 return Err(BindingError::WrongType {
754 binding: (&entry.ty).into(),
755 shader: (&self.ty).into(),
756 })
757 }
758 }
759 .map_err(|binding_class| BindingError::WrongTextureClass {
760 binding: binding_class,
761 shader: shader_class,
762 })?;
763 }
764 ResourceType::AccelerationStructure { vertex_return } => match entry.ty {
765 BindingType::AccelerationStructure {
766 vertex_return: entry_vertex_return,
767 } if vertex_return == entry_vertex_return => (),
768 _ => {
769 return Err(BindingError::WrongType {
770 binding: (&entry.ty).into(),
771 shader: (&self.ty).into(),
772 })
773 }
774 },
775 };
776
777 Ok(())
778 }
779
780 fn derive_binding_type(
781 &self,
782 is_reffed_by_sampler_in_entrypoint: bool,
783 ) -> Result<BindingType, BindingError> {
784 Ok(match self.ty {
785 ResourceType::Buffer { size } => BindingType::Buffer {
786 ty: match self.class {
787 naga::AddressSpace::Uniform => wgt::BufferBindingType::Uniform,
788 naga::AddressSpace::Storage { access } => wgt::BufferBindingType::Storage {
789 read_only: access == naga::StorageAccess::LOAD,
790 },
791 _ => return Err(BindingError::WrongBufferAddressSpace { space: self.class }),
792 },
793 has_dynamic_offset: false,
794 min_binding_size: Some(size),
795 },
796 ResourceType::Sampler { comparison } => BindingType::Sampler(if comparison {
797 wgt::SamplerBindingType::Comparison
798 } else {
799 wgt::SamplerBindingType::Filtering
800 }),
801 ResourceType::Texture {
802 dim,
803 arrayed,
804 class,
805 } => {
806 let view_dimension = match dim {
807 naga::ImageDimension::D1 => wgt::TextureViewDimension::D1,
808 naga::ImageDimension::D2 if arrayed => wgt::TextureViewDimension::D2Array,
809 naga::ImageDimension::D2 => wgt::TextureViewDimension::D2,
810 naga::ImageDimension::D3 => wgt::TextureViewDimension::D3,
811 naga::ImageDimension::Cube if arrayed => wgt::TextureViewDimension::CubeArray,
812 naga::ImageDimension::Cube => wgt::TextureViewDimension::Cube,
813 };
814 match class {
815 naga::ImageClass::Sampled { multi, kind } => BindingType::Texture {
816 sample_type: match kind {
817 naga::ScalarKind::Float => wgt::TextureSampleType::Float {
818 filterable: is_reffed_by_sampler_in_entrypoint,
819 },
820 naga::ScalarKind::Sint => wgt::TextureSampleType::Sint,
821 naga::ScalarKind::Uint => wgt::TextureSampleType::Uint,
822 naga::ScalarKind::AbstractInt
823 | naga::ScalarKind::AbstractFloat
824 | naga::ScalarKind::Bool => unreachable!(),
825 },
826 view_dimension,
827 multisampled: multi,
828 },
829 naga::ImageClass::Depth { multi } => BindingType::Texture {
830 sample_type: wgt::TextureSampleType::Depth,
831 view_dimension,
832 multisampled: multi,
833 },
834 naga::ImageClass::Storage { format, access } => BindingType::StorageTexture {
835 access: {
836 const LOAD_STORE: naga::StorageAccess =
837 naga::StorageAccess::LOAD.union(naga::StorageAccess::STORE);
838 match access {
839 naga::StorageAccess::LOAD => wgt::StorageTextureAccess::ReadOnly,
840 naga::StorageAccess::STORE => wgt::StorageTextureAccess::WriteOnly,
841 LOAD_STORE => wgt::StorageTextureAccess::ReadWrite,
842 _ if access.contains(naga::StorageAccess::ATOMIC) => {
843 wgt::StorageTextureAccess::Atomic
844 }
845 _ => unreachable!(),
846 }
847 },
848 view_dimension,
849 format: {
850 let f = map_storage_format_from_naga(format);
851 let original = map_storage_format_to_naga(f)
852 .ok_or(BindingError::BadStorageFormat(f))?;
853 debug_assert_eq!(format, original);
854 f
855 },
856 },
857 naga::ImageClass::External => BindingType::ExternalTexture,
858 }
859 }
860 ResourceType::AccelerationStructure { vertex_return } => {
861 BindingType::AccelerationStructure { vertex_return }
862 }
863 })
864 }
865}
866
impl NumericType {
    /// Returns the shader-side numeric type that a vertex attribute of `format` is read as.
    ///
    /// 8/16/32-bit integer formats map to 32-bit integers of the same signedness.
    /// Normalized, 16/32-bit float and packed (`Unorm10_10_10_2`, `Unorm8x4Bgra`) formats
    /// map to `f32`. `Float64*` formats map to `f64`.
    fn from_vertex_format(format: wgt::VertexFormat) -> Self {
        use naga::{Scalar, VectorSize as Vs};
        use wgt::VertexFormat as Vf;

        let (dim, scalar) = match format {
            Vf::Uint8 | Vf::Uint16 | Vf::Uint32 => (NumericDimension::Scalar, Scalar::U32),
            Vf::Uint8x2 | Vf::Uint16x2 | Vf::Uint32x2 => {
                (NumericDimension::Vector(Vs::Bi), Scalar::U32)
            }
            Vf::Uint32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::U32),
            Vf::Uint8x4 | Vf::Uint16x4 | Vf::Uint32x4 => {
                (NumericDimension::Vector(Vs::Quad), Scalar::U32)
            }
            Vf::Sint8 | Vf::Sint16 | Vf::Sint32 => (NumericDimension::Scalar, Scalar::I32),
            Vf::Sint8x2 | Vf::Sint16x2 | Vf::Sint32x2 => {
                (NumericDimension::Vector(Vs::Bi), Scalar::I32)
            }
            Vf::Sint32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::I32),
            Vf::Sint8x4 | Vf::Sint16x4 | Vf::Sint32x4 => {
                (NumericDimension::Vector(Vs::Quad), Scalar::I32)
            }
            Vf::Unorm8 | Vf::Unorm16 | Vf::Snorm8 | Vf::Snorm16 | Vf::Float16 | Vf::Float32 => {
                (NumericDimension::Scalar, Scalar::F32)
            }
            Vf::Unorm8x2
            | Vf::Snorm8x2
            | Vf::Unorm16x2
            | Vf::Snorm16x2
            | Vf::Float16x2
            | Vf::Float32x2 => (NumericDimension::Vector(Vs::Bi), Scalar::F32),
            Vf::Float32x3 => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
            Vf::Unorm8x4
            | Vf::Snorm8x4
            | Vf::Unorm16x4
            | Vf::Snorm16x4
            | Vf::Float16x4
            | Vf::Float32x4
            | Vf::Unorm10_10_10_2
            | Vf::Unorm8x4Bgra => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
            Vf::Float64 => (NumericDimension::Scalar, Scalar::F64),
            Vf::Float64x2 => (NumericDimension::Vector(Vs::Bi), Scalar::F64),
            Vf::Float64x3 => (NumericDimension::Vector(Vs::Tri), Scalar::F64),
            Vf::Float64x4 => (NumericDimension::Vector(Vs::Quad), Scalar::F64),
        };

        NumericType {
            dim,
            scalar,
        }
    }

    /// Returns the numeric type a texel of `format` corresponds to in a shader.
    ///
    /// Normalized, float and compressed formats map to `f32`; integer formats map to 32-bit
    /// integers of the same signedness, except `R64Uint`, which maps to `u64`.
    ///
    /// # Panics
    ///
    /// Panics for depth/stencil formats and for `NV12` and `P010`.
    fn from_texture_format(format: wgt::TextureFormat) -> Self {
        use naga::{Scalar, VectorSize as Vs};
        use wgt::TextureFormat as Tf;

        let (dim, scalar) = match format {
            Tf::R8Unorm | Tf::R8Snorm | Tf::R16Float | Tf::R32Float => {
                (NumericDimension::Scalar, Scalar::F32)
            }
            Tf::R8Uint | Tf::R16Uint | Tf::R32Uint => (NumericDimension::Scalar, Scalar::U32),
            Tf::R8Sint | Tf::R16Sint | Tf::R32Sint => (NumericDimension::Scalar, Scalar::I32),
            Tf::Rg8Unorm | Tf::Rg8Snorm | Tf::Rg16Float | Tf::Rg32Float => {
                (NumericDimension::Vector(Vs::Bi), Scalar::F32)
            }
            Tf::R64Uint => (NumericDimension::Scalar, Scalar::U64),
            Tf::Rg8Uint | Tf::Rg16Uint | Tf::Rg32Uint => {
                (NumericDimension::Vector(Vs::Bi), Scalar::U32)
            }
            Tf::Rg8Sint | Tf::Rg16Sint | Tf::Rg32Sint => {
                (NumericDimension::Vector(Vs::Bi), Scalar::I32)
            }
            Tf::R16Snorm | Tf::R16Unorm => (NumericDimension::Scalar, Scalar::F32),
            Tf::Rg16Snorm | Tf::Rg16Unorm => (NumericDimension::Vector(Vs::Bi), Scalar::F32),
            Tf::Rgba16Snorm | Tf::Rgba16Unorm => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
            Tf::Rgba8Unorm
            | Tf::Rgba8UnormSrgb
            | Tf::Rgba8Snorm
            | Tf::Bgra8Unorm
            | Tf::Bgra8UnormSrgb
            | Tf::Rgb10a2Unorm
            | Tf::Rgba16Float
            | Tf::Rgba32Float => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
            Tf::Rgba8Uint | Tf::Rgba16Uint | Tf::Rgba32Uint | Tf::Rgb10a2Uint => {
                (NumericDimension::Vector(Vs::Quad), Scalar::U32)
            }
            Tf::Rgba8Sint | Tf::Rgba16Sint | Tf::Rgba32Sint => {
                (NumericDimension::Vector(Vs::Quad), Scalar::I32)
            }
            Tf::Rg11b10Ufloat => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
            Tf::Stencil8
            | Tf::Depth16Unorm
            | Tf::Depth32Float
            | Tf::Depth32FloatStencil8
            | Tf::Depth24Plus
            | Tf::Depth24PlusStencil8 => {
                panic!("Unexpected depth format")
            }
            Tf::NV12 => panic!("Unexpected nv12 format"),
            Tf::P010 => panic!("Unexpected p010 format"),
            Tf::Rgb9e5Ufloat => (NumericDimension::Vector(Vs::Tri), Scalar::F32),
            Tf::Bc1RgbaUnorm
            | Tf::Bc1RgbaUnormSrgb
            | Tf::Bc2RgbaUnorm
            | Tf::Bc2RgbaUnormSrgb
            | Tf::Bc3RgbaUnorm
            | Tf::Bc3RgbaUnormSrgb
            | Tf::Bc7RgbaUnorm
            | Tf::Bc7RgbaUnormSrgb
            | Tf::Etc2Rgb8A1Unorm
            | Tf::Etc2Rgb8A1UnormSrgb
            | Tf::Etc2Rgba8Unorm
            | Tf::Etc2Rgba8UnormSrgb => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
            Tf::Bc4RUnorm | Tf::Bc4RSnorm | Tf::EacR11Unorm | Tf::EacR11Snorm => {
                (NumericDimension::Scalar, Scalar::F32)
            }
            Tf::Bc5RgUnorm | Tf::Bc5RgSnorm | Tf::EacRg11Unorm | Tf::EacRg11Snorm => {
                (NumericDimension::Vector(Vs::Bi), Scalar::F32)
            }
            Tf::Bc6hRgbUfloat | Tf::Bc6hRgbFloat | Tf::Etc2Rgb8Unorm | Tf::Etc2Rgb8UnormSrgb => {
                (NumericDimension::Vector(Vs::Tri), Scalar::F32)
            }
            Tf::Astc {
                block: _,
                channel: _,
            } => (NumericDimension::Vector(Vs::Quad), Scalar::F32),
        };

        NumericType {
            dim,
            scalar,
        }
    }

    /// Returns whether `self` fits into `other`.
    ///
    /// Both must have the same scalar kind, and `self`'s scalar width must not exceed
    /// `other`'s. A scalar fits a scalar or any vector. A vector fits a vector with at
    /// least as many components. Matrices must have identical columns and rows. Every
    /// other combination is rejected.
    fn is_subtype_of(&self, other: &NumericType) -> bool {
        if self.scalar.width > other.scalar.width {
            return false;
        }
        if self.scalar.kind != other.scalar.kind {
            return false;
        }
        match (self.dim, other.dim) {
            (NumericDimension::Scalar, NumericDimension::Scalar) => true,
            (NumericDimension::Scalar, NumericDimension::Vector(_)) => true,
            (NumericDimension::Vector(s0), NumericDimension::Vector(s1)) => s0 <= s1,
            (NumericDimension::Matrix(c0, r0), NumericDimension::Matrix(c1, r1)) => {
                c0 == c1 && r0 == r1
            }
            _ => false,
        }
    }
}
1023
1024pub fn check_texture_format(
1026 format: wgt::TextureFormat,
1027 output: &NumericType,
1028) -> Result<(), NumericType> {
1029 let nt = NumericType::from_texture_format(format);
1030 if nt.is_subtype_of(output) {
1031 Ok(())
1032 } else {
1033 Err(nt)
1034 }
1035}
1036
/// Where the bind group layouts used while validating shader stages come from.
pub enum BindingLayoutSource {
    /// Layouts are derived from the shaders' own resource usage.
    ///
    /// Holds one entry map per bind group. `Interface::check_stage` inserts an entry for
    /// each resource an entry point uses, or widens an existing entry's visibility.
    Derived(Box<ArrayVec<bgl::EntryMap, { hal::MAX_BIND_GROUPS }>>),
    /// Layouts come from an explicit pipeline layout; shader resources are checked against it.
    Provided(Arc<crate::binding_model::PipelineLayout>),
}
1047
1048impl BindingLayoutSource {
1049 pub fn new_derived(limits: &wgt::Limits) -> Self {
1050 let mut array = ArrayVec::new();
1051 for _ in 0..limits.max_bind_groups {
1052 array.push(Default::default());
1053 }
1054 BindingLayoutSource::Derived(Box::new(array))
1055 }
1056}
1057
/// Interface data passed from one shader stage to the next during validation.
///
/// `Interface::check_stage` takes the previous stage's `StageIo` as its `inputs`. It returns
/// the validated stage's outputs in the same form.
#[derive(Debug, Clone, Default)]
pub struct StageIo {
    /// User-defined varyings, keyed by their `@location`.
    pub varyings: FastHashMap<wgt::ShaderLocation, InterfaceVar>,
    /// Size of the task payload declared by the stage's entry point, if any.
    pub task_payload_size: Option<u32>,
    /// For a mesh stage: whether it outputs `primitive_index`. `None` for other stages.
    pub primitive_index: Option<bool>,
}
1070
impl Interface {
    /// Appends the varyings described by a value of type `ty` with `binding` to `list`.
    ///
    /// - Scalars, vectors and matrices become one varying, built from `binding`.
    /// - Structs are flattened by recursing into each member with that member's own binding.
    /// - An array bound to `clip_distances` is recorded as `BuiltIn::ClipDistances`.
    /// - Any other type is logged at debug level and skipped.
    /// - A numeric value with no binding is logged as an error and skipped.
    fn populate(
        list: &mut Vec<Varying>,
        binding: Option<&naga::Binding>,
        ty: naga::Handle<naga::Type>,
        arena: &naga::UniqueArena<naga::Type>,
    ) {
        let numeric_ty = match arena[ty].inner {
            naga::TypeInner::Scalar(scalar) => NumericType {
                dim: NumericDimension::Scalar,
                scalar,
            },
            naga::TypeInner::Vector { size, scalar } => NumericType {
                dim: NumericDimension::Vector(size),
                scalar,
            },
            naga::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => NumericType {
                dim: NumericDimension::Matrix(columns, rows),
                scalar,
            },
            // Struct members carry their own bindings, so flatten them recursively.
            naga::TypeInner::Struct { ref members, .. } => {
                for member in members {
                    Self::populate(list, member.binding.as_ref(), member.ty, arena);
                }
                return;
            }
            // The `clip_distances` built-in is recorded by its element count rather than
            // as a numeric type.
            naga::TypeInner::Array { base, size, stride }
                if matches!(
                    binding,
                    Some(naga::Binding::BuiltIn(naga::BuiltIn::ClipDistances)),
                ) =>
            {
                // Expected to be an array of `f32` with a 4-byte stride (debug-checked only).
                debug_assert_eq!(
                    &arena[base].inner,
                    &naga::TypeInner::Scalar(naga::Scalar::F32)
                );
                debug_assert_eq!(stride, 4);

                let naga::ArraySize::Constant(array_size) = size else {
                    unreachable!("non-constant array size for `clip_distances`")
                };
                let array_size = array_size.get();

                list.push(Varying::BuiltIn(BuiltIn::ClipDistances { array_size }));
                return;
            }
            ref other => {
                log::debug!("Unexpected varying type: {other:?}");
                return;
            }
        };

        let varying = match binding {
            // User-defined varying; `blend_src` is not recorded here.
            Some(&naga::Binding::Location {
                location,
                interpolation,
                sampling,
                per_primitive,
                blend_src: _,
            }) => Varying::Local {
                location,
                iv: InterfaceVar {
                    ty: numeric_ty,
                    interpolation,
                    sampling,
                    per_primitive,
                },
            },
            // One-to-one translation of naga built-ins into this module's `BuiltIn`.
            Some(&naga::Binding::BuiltIn(built_in)) => Varying::BuiltIn(match built_in {
                naga::BuiltIn::Position { invariant } => BuiltIn::Position { invariant },
                naga::BuiltIn::ViewIndex => BuiltIn::ViewIndex,
                naga::BuiltIn::BaseInstance => BuiltIn::BaseInstance,
                naga::BuiltIn::BaseVertex => BuiltIn::BaseVertex,
                // `clip_distances` arrays are handled by the array arm above.
                naga::BuiltIn::ClipDistances => unreachable!(),
                naga::BuiltIn::CullDistance => BuiltIn::CullDistance,
                naga::BuiltIn::InstanceIndex => BuiltIn::InstanceIndex,
                naga::BuiltIn::PointSize => BuiltIn::PointSize,
                naga::BuiltIn::VertexIndex => BuiltIn::VertexIndex,
                naga::BuiltIn::DrawIndex => BuiltIn::DrawIndex,
                naga::BuiltIn::FragDepth => BuiltIn::FragDepth,
                naga::BuiltIn::PointCoord => BuiltIn::PointCoord,
                naga::BuiltIn::FrontFacing => BuiltIn::FrontFacing,
                naga::BuiltIn::PrimitiveIndex => BuiltIn::PrimitiveIndex,
                naga::BuiltIn::Barycentric { perspective } => BuiltIn::Barycentric { perspective },
                naga::BuiltIn::SampleIndex => BuiltIn::SampleIndex,
                naga::BuiltIn::SampleMask => BuiltIn::SampleMask,
                naga::BuiltIn::GlobalInvocationId => BuiltIn::GlobalInvocationId,
                naga::BuiltIn::LocalInvocationId => BuiltIn::LocalInvocationId,
                naga::BuiltIn::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
                naga::BuiltIn::WorkGroupId => BuiltIn::WorkGroupId,
                naga::BuiltIn::WorkGroupSize => BuiltIn::WorkGroupSize,
                naga::BuiltIn::NumWorkGroups => BuiltIn::NumWorkGroups,
                naga::BuiltIn::NumSubgroups => BuiltIn::NumSubgroups,
                naga::BuiltIn::SubgroupId => BuiltIn::SubgroupId,
                naga::BuiltIn::SubgroupSize => BuiltIn::SubgroupSize,
                naga::BuiltIn::SubgroupInvocationId => BuiltIn::SubgroupInvocationId,
                naga::BuiltIn::MeshTaskSize => BuiltIn::MeshTaskSize,
                naga::BuiltIn::CullPrimitive => BuiltIn::CullPrimitive,
                naga::BuiltIn::PointIndex => BuiltIn::PointIndex,
                naga::BuiltIn::LineIndices => BuiltIn::LineIndices,
                naga::BuiltIn::TriangleIndices => BuiltIn::TriangleIndices,
                naga::BuiltIn::VertexCount => BuiltIn::VertexCount,
                naga::BuiltIn::Vertices => BuiltIn::Vertices,
                naga::BuiltIn::PrimitiveCount => BuiltIn::PrimitiveCount,
                naga::BuiltIn::Primitives => BuiltIn::Primitives,
                naga::BuiltIn::RayInvocationId => BuiltIn::RayInvocationId,
                naga::BuiltIn::NumRayInvocations => BuiltIn::NumRayInvocations,
                naga::BuiltIn::InstanceCustomData => BuiltIn::InstanceCustomData,
                naga::BuiltIn::GeometryIndex => BuiltIn::GeometryIndex,
                naga::BuiltIn::WorldRayOrigin => BuiltIn::WorldRayOrigin,
                naga::BuiltIn::WorldRayDirection => BuiltIn::WorldRayDirection,
                naga::BuiltIn::ObjectRayOrigin => BuiltIn::ObjectRayOrigin,
                naga::BuiltIn::ObjectRayDirection => BuiltIn::ObjectRayDirection,
                naga::BuiltIn::RayTmin => BuiltIn::RayTmin,
                naga::BuiltIn::RayTCurrentMax => BuiltIn::RayTCurrentMax,
                naga::BuiltIn::ObjectToWorld => BuiltIn::ObjectToWorld,
                naga::BuiltIn::WorldToObject => BuiltIn::WorldToObject,
                naga::BuiltIn::HitKind => BuiltIn::HitKind,
            }),
            None => {
                log::error!("Missing binding for a varying");
                return;
            }
        };
        list.push(varying);
    }

    /// Builds the interface description of every entry point in `module`.
    ///
    /// First, every global variable that has a resource binding is recorded as a resource.
    /// Then, for each entry point, it records:
    /// - input and output varyings (including mesh vertex/primitive outputs);
    /// - the resources it uses and its texture/sampler pairs;
    /// - dual-source blending, workgroup size, immediate slots, task payload size and
    ///   mesh output info.
    ///
    /// `limits` is stored for later validation in `check_stage`.
    pub fn new(module: &naga::Module, info: &naga::valid::ModuleInfo, limits: wgt::Limits) -> Self {
        let mut resources = naga::Arena::new();
        let mut resource_mapping = FastHashMap::default();
        for (var_handle, var) in module.global_variables.iter() {
            // Only globals with a resource binding are treated as shader resources.
            let bind = match var.binding {
                Some(br) => br,
                _ => continue,
            };
            let naga_ty = &module.types[var.ty].inner;

            // A binding array is classified by its element type.
            let inner_ty = match *naga_ty {
                naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner,
                ref ty => ty,
            };

            let ty = match *inner_ty {
                naga::TypeInner::Image {
                    dim,
                    arrayed,
                    class,
                } => ResourceType::Texture {
                    dim,
                    arrayed,
                    class,
                },
                naga::TypeInner::Sampler { comparison } => ResourceType::Sampler { comparison },
                naga::TypeInner::AccelerationStructure { vertex_return } => {
                    ResourceType::AccelerationStructure { vertex_return }
                }
                // Everything else is a buffer of the type's size. The `unwrap` panics if
                // `BufferSize::new` rejects the size (presumably a zero size).
                ref other => ResourceType::Buffer {
                    size: wgt::BufferSize::new(other.size(module.to_ctx()) as u64).unwrap(),
                },
            };
            let handle = resources.append(
                Resource {
                    name: var.name.clone(),
                    bind,
                    ty,
                    class: var.space,
                },
                Default::default(),
            );
            resource_mapping.insert(var_handle, handle);
        }

        let immediate_size = naga::valid::ImmediateSlots::size_for_module(module);

        let mut entry_points = FastHashMap::default();
        entry_points.reserve(module.entry_points.len());
        for (index, entry_point) in module.entry_points.iter().enumerate() {
            // Per-entry-point info; shadows the module-level `info` parameter.
            let info = info.get_entry_point(index);
            let mut ep = EntryPoint::default();
            for arg in entry_point.function.arguments.iter() {
                Self::populate(&mut ep.inputs, arg.binding.as_ref(), arg.ty, &module.types);
            }
            if let Some(ref result) = entry_point.function.result {
                Self::populate(
                    &mut ep.outputs,
                    result.binding.as_ref(),
                    result.ty,
                    &module.types,
                );
            }

            // A bound global counts as used by this entry point if its usage is non-empty.
            for (var_handle, var) in module.global_variables.iter() {
                let usage = info[var_handle];
                if !usage.is_empty() && var.binding.is_some() {
                    ep.resources.push(resource_mapping[&var_handle]);
                }
            }

            for key in info.sampling_set.iter() {
                ep.sampling_pairs
                    .insert((resource_mapping[&key.image], resource_mapping[&key.sampler]));
            }
            ep.dual_source_blending = info.dual_source_blending;
            ep.workgroup_size = entry_point.workgroup_size;
            ep.immediate_slots_required = info.immediate_slots_used;

            if let Some(task_payload) = entry_point.task_payload {
                ep.task_payload_size = Some(
                    module.types[module.global_variables[task_payload].ty]
                        .inner
                        .size(module.to_ctx()),
                );
            }
            if let Some(ref mesh_info) = entry_point.mesh_info {
                ep.mesh_info = Some(EntryPointMeshInfo {
                    max_vertices: mesh_info.max_vertices,
                    max_primitives: mesh_info.max_primitives,
                    primitive_topology: match mesh_info.topology {
                        naga::MeshOutputTopology::Triangles => wgt::PrimitiveTopology::TriangleList,
                        naga::MeshOutputTopology::Lines => wgt::PrimitiveTopology::LineList,
                        naga::MeshOutputTopology::Points => wgt::PrimitiveTopology::PointList,
                    },
                });
                // Mesh shaders also emit varyings through their vertex and primitive output
                // types; these are added to the entry point's outputs.
                Self::populate(
                    &mut ep.outputs,
                    None,
                    mesh_info.vertex_output_type,
                    &module.types,
                );
                Self::populate(
                    &mut ep.outputs,
                    None,
                    mesh_info.primitive_output_type,
                    &module.types,
                );
            }

            entry_points.insert((entry_point.stage, entry_point.name.clone()), ep);
        }

        Self {
            limits,
            resources,
            entry_points,
            immediate_size,
        }
    }

    /// Returns the immediate slots used by the given entry point.
    ///
    /// If no entry point matches `stage` and `entry_point_name`, returns
    /// `ImmediateSlots::default()`.
    pub fn immediate_slots_required(
        &self,
        stage: naga::ShaderStage,
        entry_point_name: &str,
    ) -> naga::valid::ImmediateSlots {
        self.entry_points
            .get(&(stage, entry_point_name.to_string()))
            .map_or(Default::default(), |ep| ep.immediate_slots_required)
    }

    /// Resolves the entry point name to use for `stage`.
    ///
    /// An explicitly provided name is returned as-is; it is not checked for existence here.
    /// Otherwise the module must contain exactly one entry point for `stage`.
    ///
    /// # Errors
    ///
    /// - `StageError::NoEntryPointFound` if no name was given and `stage` has no entry point.
    /// - `StageError::MultipleEntryPointsFound` if no name was given and `stage` has several.
    pub fn finalize_entry_point_name(
        &self,
        stage: naga::ShaderStage,
        entry_point_name: Option<&str>,
    ) -> Result<String, StageError> {
        entry_point_name
            .map(|ep| ep.to_string())
            .map(Ok)
            .unwrap_or_else(|| {
                let mut entry_points = self
                    .entry_points
                    .keys()
                    .filter_map(|(ep_stage, name)| (ep_stage == &stage).then_some(name));
                let first = entry_points.next().ok_or(StageError::NoEntryPointFound)?;
                if entry_points.next().is_some() {
                    return Err(StageError::MultipleEntryPointsFound);
                }
                Ok(first.clone())
            })
    }

    /// Validates entry point `entry_point_name` of `shader_stage` against the pipeline.
    ///
    /// Checks run in this order, and the first failure is returned:
    /// 1. Resource bindings against `layouts`. With `Provided` layouts, buffer sizes are
    ///    also recorded in `shader_binding_sizes`. With `Derived` layouts, bind group
    ///    entries are created or widened.
    /// 2. Texture/sampler filtering compatibility (`Provided` layouts only).
    /// 3. Workgroup size limits (compute-like stages).
    /// 4. User-defined inputs against `inputs`.
    /// 5. Stage-specific output/input limits.
    /// 6. Mesh and task limits, primitive-index and draw-index rules, and per-vertex
    ///    interpolation vs. `primitive_topology`.
    ///
    /// On success, returns this stage's user-defined outputs, its task payload size, and
    /// (for mesh stages) whether it outputs `primitive_index`. The next stage is validated
    /// against this value.
    pub fn check_stage(
        &self,
        layouts: &mut BindingLayoutSource,
        shader_binding_sizes: &mut FastHashMap<naga::ResourceBinding, wgt::BufferSize>,
        entry_point_name: &str,
        shader_stage: ShaderStageForValidation,
        inputs: StageIo,
        primitive_topology: Option<wgt::PrimitiveTopology>,
    ) -> Result<StageIo, StageError> {
        // Entry points are keyed by (stage, name).
        let pair = (shader_stage.to_naga(), entry_point_name.to_string());
        let entry_point = match self.entry_points.get(&pair) {
            Some(some) => some,
            None => return Err(StageError::MissingEntryPoint(pair.1)),
        };
        let (_, entry_point_name) = pair;

        let stage_bit = shader_stage.to_wgt_bit();

        for &handle in entry_point.resources.iter() {
            let res = &self.resources[handle];
            // `break 'err Err(..)` jumps straight to the error conversion below.
            let result = 'err: {
                match layouts {
                    BindingLayoutSource::Provided(pipeline_layout) => {
                        // Keep the largest size required for this buffer binding by any stage.
                        if let ResourceType::Buffer { size } = res.ty {
                            match shader_binding_sizes.entry(res.bind) {
                                Entry::Occupied(e) => {
                                    // The right-hand side (reading the old value) is evaluated
                                    // before `into_mut` consumes `e`.
                                    *e.into_mut() = size.max(*e.get());
                                }
                                Entry::Vacant(e) => {
                                    e.insert(size);
                                }
                            }
                        }

                        let Some(entry) =
                            pipeline_layout.get_bgl_entry(res.bind.group, res.bind.binding)
                        else {
                            break 'err Err(BindingError::Missing);
                        };

                        if !entry.visibility.contains(stage_bit) {
                            break 'err Err(BindingError::Invisible);
                        }

                        res.check_binding_use(entry)
                    }
                    BindingLayoutSource::Derived(layouts) => {
                        let Some(map) = layouts.get_mut(res.bind.group as usize) else {
                            break 'err Err(BindingError::Missing);
                        };

                        // The argument reports whether this resource is the image of any
                        // texture/sampler pair in this entry point.
                        let ty = match res.derive_binding_type(
                            entry_point
                                .sampling_pairs
                                .iter()
                                .any(|&(im, _samp)| im == handle),
                        ) {
                            Ok(ty) => ty,
                            Err(error) => break 'err Err(error),
                        };

                        // A binding shared across stages must derive the same type; its
                        // visibility is then widened to include this stage.
                        match map.entry(res.bind.binding) {
                            indexmap::map::Entry::Occupied(e) if e.get().ty != ty => {
                                break 'err Err(BindingError::InconsistentlyDerivedType)
                            }
                            indexmap::map::Entry::Occupied(e) => {
                                e.into_mut().visibility |= stage_bit;
                            }
                            indexmap::map::Entry::Vacant(e) => {
                                e.insert(BindGroupLayoutEntry {
                                    binding: res.bind.binding,
                                    ty,
                                    visibility: stage_bit,
                                    count: None,
                                });
                            }
                        }
                        Ok(())
                    }
                }
            };
            if let Err(error) = result {
                return Err(StageError::Binding(res.bind, error));
            }
        }

        // With an explicit layout, reject filtering samplers paired with non-filterable
        // textures. The `unwrap`s and `assert!`s rely on the binding loop above having
        // accepted both resources (assumes both are among `entry_point.resources`).
        if let BindingLayoutSource::Provided(pipeline_layout) = layouts {
            for &(texture_handle, sampler_handle) in entry_point.sampling_pairs.iter() {
                let texture_bind = &self.resources[texture_handle].bind;
                let sampler_bind = &self.resources[sampler_handle].bind;
                let texture_layout = pipeline_layout
                    .get_bgl_entry(texture_bind.group, texture_bind.binding)
                    .unwrap();
                let sampler_layout = pipeline_layout
                    .get_bgl_entry(sampler_bind.group, sampler_bind.binding)
                    .unwrap();
                assert!(texture_layout.visibility.contains(stage_bit));
                assert!(sampler_layout.visibility.contains(stage_bit));

                let sampler_filtering = matches!(
                    sampler_layout.ty,
                    BindingType::Sampler(wgt::SamplerBindingType::Filtering)
                );
                // External textures are treated as filterable float textures.
                let texture_sample_type = match texture_layout.ty {
                    BindingType::Texture { sample_type, .. } => sample_type,
                    BindingType::ExternalTexture => {
                        wgt::TextureSampleType::Float { filterable: true }
                    }
                    _ => unreachable!(),
                };

                let error = match (sampler_filtering, texture_sample_type) {
                    (true, wgt::TextureSampleType::Float { filterable: false }) => {
                        Some(FilteringError::Float)
                    }
                    (true, wgt::TextureSampleType::Sint) => Some(FilteringError::Integer),
                    (true, wgt::TextureSampleType::Uint) => Some(FilteringError::Integer),
                    _ => None,
                };

                if let Some(error) = error {
                    return Err(StageError::Filtering {
                        texture: *texture_bind,
                        sampler: *sampler_bind,
                        error,
                    });
                }
            }
        }

        if shader_stage.to_naga().compute_like() {
            // Each compute-like stage is checked against its own set of limits.
            let workgroup_size_check = match shader_stage.to_naga() {
                naga::ShaderStage::Compute => WorkgroupSizeCheck {
                    dimensions: &entry_point.workgroup_size,
                    per_dimension_limits: &[
                        self.limits.max_compute_workgroup_size_x,
                        self.limits.max_compute_workgroup_size_y,
                        self.limits.max_compute_workgroup_size_z,
                    ],
                    per_dimension_limits_desc: "max_compute_workgroup_size_*",

                    total_limit: self.limits.max_compute_invocations_per_workgroup,
                    total_limit_desc: "max_compute_invocations_per_workgroup",
                },
                naga::ShaderStage::Task => WorkgroupSizeCheck {
                    dimensions: &entry_point.workgroup_size,
                    per_dimension_limits: &[self.limits.max_task_invocations_per_dimension; 3],
                    per_dimension_limits_desc: "max_task_invocations_per_dimension",

                    total_limit: self.limits.max_task_invocations_per_workgroup,
                    total_limit_desc: "max_task_invocations_per_workgroup",
                },
                naga::ShaderStage::Mesh => WorkgroupSizeCheck {
                    dimensions: &entry_point.workgroup_size,
                    per_dimension_limits: &[self.limits.max_mesh_invocations_per_dimension; 3],
                    per_dimension_limits_desc: "max_mesh_invocations_per_dimension",

                    total_limit: self.limits.max_mesh_invocations_per_workgroup,
                    total_limit_desc: "max_mesh_invocations_per_workgroup",
                },
                _ => unreachable!(),
            };
            let total = workgroup_size_check.check_and_compute_total_invocations()?;
            // A zero in any dimension makes the (saturating) product zero.
            if total == 0 {
                return Err(StageError::InvalidWorkgroupSize(
                    InvalidWorkgroupSizeError::Zero {
                        dimensions: entry_point.workgroup_size,
                    },
                ));
            }
        }

        let mut this_stage_primitive_index = false;
        let mut has_draw_id = false;
        let mut has_per_vertex = false;

        // Match each user-defined input against what `inputs` provides, and note which
        // relevant built-ins this stage reads.
        for input in entry_point.inputs.iter() {
            match *input {
                Varying::Local { location, ref iv } => {
                    let result = inputs
                        .varyings
                        .get(&location)
                        .ok_or(InputError::Missing)
                        .and_then(|provided| {
                            let (compatible, per_primitive_correct) = match shader_stage.to_naga() {
                                // Vertex inputs: only the scalar kind must match, and they
                                // must not be per-primitive.
                                naga::ShaderStage::Vertex => {
                                    let is_compatible =
                                        iv.ty.scalar.kind == provided.ty.scalar.kind;
                                    (is_compatible, !iv.per_primitive)
                                }
                                // Fragment inputs: interpolation and sampling must match
                                // exactly, the type must be a subtype, and per-primitive
                                // must agree.
                                naga::ShaderStage::Fragment => {
                                    if iv.interpolation != provided.interpolation {
                                        return Err(InputError::InterpolationMismatch(
                                            provided.interpolation,
                                        ));
                                    }
                                    if iv.sampling != provided.sampling {
                                        return Err(InputError::SamplingMismatch(
                                            provided.sampling,
                                        ));
                                    }
                                    (
                                        iv.ty.is_subtype_of(&provided.ty),
                                        iv.per_primitive == provided.per_primitive,
                                    )
                                }
                                // Location inputs are not accepted for these stages; any
                                // such input is reported as a type mismatch.
                                naga::ShaderStage::Compute
                                | naga::ShaderStage::Task
                                | naga::ShaderStage::Mesh => (false, false),
                                naga::ShaderStage::RayGeneration
                                | naga::ShaderStage::AnyHit
                                | naga::ShaderStage::ClosestHit
                                | naga::ShaderStage::Miss => {
                                    unreachable!()
                                }
                            };
                            if !compatible {
                                return Err(InputError::WrongType(provided.ty));
                            } else if !per_primitive_correct {
                                return Err(InputError::WrongPerPrimitive {
                                    pipeline_input: provided.per_primitive,
                                    shader: iv.per_primitive,
                                });
                            }
                            Ok(())
                        });

                    if let Err(error) = result {
                        return Err(StageError::Input {
                            location,
                            var: iv.clone(),
                            error,
                        });
                    }
                    has_per_vertex |= iv.interpolation == Some(naga::Interpolation::PerVertex);
                }
                Varying::BuiltIn(BuiltIn::PrimitiveIndex) => {
                    this_stage_primitive_index = true;
                }
                Varying::BuiltIn(BuiltIn::DrawIndex) => {
                    has_draw_id = true;
                }
                Varying::BuiltIn(_) => {}
            }
        }

        match shader_stage {
            ShaderStageForValidation::Vertex {
                topology,
                compare_function,
            } => {
                // Start from the inter-stage limit. Each deduction (point-list topology,
                // `clip_distances`) lowers both the variable budget and the max location.
                let mut max_vertex_shader_output_variables =
                    self.limits.max_inter_stage_shader_variables;
                let mut max_vertex_shader_output_location = max_vertex_shader_output_variables - 1;

                let point_list_deduction = if topology == wgt::PrimitiveTopology::PointList {
                    Some(MaxVertexShaderOutputDeduction::PointListPrimitiveTopology)
                } else {
                    None
                };

                let clip_distance_deductions = entry_point.outputs.iter().filter_map(|output| {
                    if let &Varying::BuiltIn(BuiltIn::ClipDistances { array_size }) = output {
                        Some(MaxVertexShaderOutputDeduction::ClipDistances { array_size })
                    } else {
                        None
                    }
                });
                debug_assert!(
                    clip_distance_deductions.clone().count() <= 1,
                    "multiple `clip_distances` outputs found"
                );

                let deductions = point_list_deduction
                    .into_iter()
                    .chain(clip_distance_deductions);

                for deduction in deductions.clone() {
                    // `unwrap` panics if a deduction exceeds the remaining budget.
                    max_vertex_shader_output_variables = max_vertex_shader_output_variables
                        .checked_sub(deduction.for_variables())
                        .unwrap();
                    max_vertex_shader_output_location = max_vertex_shader_output_location
                        .checked_sub(deduction.for_location())
                        .unwrap();
                }

                let mut num_user_defined_outputs = 0;

                for output in entry_point.outputs.iter() {
                    match *output {
                        Varying::Local { ref iv, location } => {
                            if location > max_vertex_shader_output_location {
                                return Err(StageError::VertexOutputLocationTooLarge {
                                    location,
                                    var: iv.clone(),
                                    limit: self.limits.max_inter_stage_shader_variables,
                                    deductions: deductions.collect(),
                                });
                            }
                            num_user_defined_outputs += 1;
                        }
                        Varying::BuiltIn(_) => {}
                    };

                    // Warn (but don't fail) about a non-invariant position combined with an
                    // equality-based depth comparison.
                    if let Some(
                        cmp @ wgt::CompareFunction::Equal | cmp @ wgt::CompareFunction::NotEqual,
                    ) = compare_function
                    {
                        if let Varying::BuiltIn(BuiltIn::Position { invariant: false }) = *output {
                            log::warn!(
                                concat!(
                                    "Vertex shader with entry point {} outputs a ",
                                    "@builtin(position) without the @invariant attribute and ",
                                    "is used in a pipeline with {cmp:?}. On some machines, ",
                                    "this can cause bad artifacting as {cmp:?} assumes the ",
                                    "values output from the vertex shader exactly match the ",
                                    "value in the depth buffer. The @invariant attribute on the ",
                                    "@builtin(position) vertex output ensures that the exact ",
                                    "same pixel depths are used every render."
                                ),
                                entry_point_name,
                                cmp = cmp
                            );
                        }
                    }
                }

                if num_user_defined_outputs > max_vertex_shader_output_variables {
                    return Err(StageError::TooManyUserDefinedVertexOutputs {
                        num_found: num_user_defined_outputs,
                        limit: self.limits.max_inter_stage_shader_variables,
                        deductions: deductions.collect(),
                    });
                }
            }
            ShaderStageForValidation::Fragment {
                dual_source_blending,
                has_depth_attachment,
            } => {
                let mut max_fragment_shader_input_variables =
                    self.limits.max_inter_stage_shader_variables;

                // Each built-in input yields a deduction (the closure parameter iterates the
                // stage's inputs). A built-in that `from_inter_stage_builtin` does not accept
                // is treated as a bug.
                let deductions = entry_point.inputs.iter().filter_map(|output| match output {
                    Varying::Local { .. } => None,
                    Varying::BuiltIn(builtin) => {
                        MaxFragmentShaderInputDeduction::from_inter_stage_builtin(builtin.to_naga())
                            .or_else(|| {
                                unreachable!(
                                    concat!(
                                        "unexpected built-in provided; ",
                                        "{:?} is not used for fragment stage input",
                                    ),
                                    builtin
                                )
                            })
                    }
                });

                // Deductions only reduce the variable budget here; the location check below
                // uses the full limit.
                for deduction in deductions.clone() {
                    max_fragment_shader_input_variables = max_fragment_shader_input_variables
                        .checked_sub(deduction.for_variables())
                        .unwrap();
                }

                let mut num_user_defined_inputs = 0;

                for output in entry_point.inputs.iter() {
                    match *output {
                        Varying::Local { ref iv, location } => {
                            if location >= self.limits.max_inter_stage_shader_variables {
                                return Err(StageError::FragmentInputLocationTooLarge {
                                    location,
                                    var: iv.clone(),
                                    limit: self.limits.max_inter_stage_shader_variables,
                                    deductions: deductions.collect(),
                                });
                            }
                            num_user_defined_inputs += 1;
                        }
                        Varying::BuiltIn(_) => {}
                    };
                }

                if num_user_defined_inputs > max_fragment_shader_input_variables {
                    return Err(StageError::TooManyUserDefinedFragmentInputs {
                        num_found: num_user_defined_inputs,
                        limit: self.limits.max_inter_stage_shader_variables,
                        deductions: deductions.collect(),
                    });
                }

                // Each user-defined fragment output location must be below the color
                // attachment limit.
                for output in &entry_point.outputs {
                    let &Varying::Local { location, ref iv } = output else {
                        continue;
                    };
                    if location >= self.limits.max_color_attachments {
                        return Err(StageError::ColorAttachmentLocationTooLarge {
                            location,
                            var: iv.clone(),
                            limit: self.limits.max_color_attachments,
                        });
                    }
                }

                // The pipeline may only request dual-source blending if the shader uses it.
                if dual_source_blending && !entry_point.dual_source_blending {
                    return Err(StageError::InvalidDualSourceBlending);
                }

                // Writing `frag_depth` requires a depth attachment.
                if entry_point
                    .outputs
                    .contains(&Varying::BuiltIn(BuiltIn::FragDepth))
                    && !has_depth_attachment
                {
                    return Err(StageError::MissingFragDepthAttachment);
                }
            }
            ShaderStageForValidation::Mesh => {
                // For mesh shaders, `primitive_index` is an output rather than an input.
                for output in &entry_point.outputs {
                    if matches!(output, Varying::BuiltIn(BuiltIn::PrimitiveIndex)) {
                        this_stage_primitive_index = true;
                    }
                }
            }
            _ => (),
        }

        // Only entry points with mesh info (mesh shaders) are checked here.
        if let Some(ref mesh_info) = entry_point.mesh_info {
            if mesh_info.max_vertices > self.limits.max_mesh_output_vertices {
                return Err(StageError::TooManyMeshVertices {
                    limit: self.limits.max_mesh_output_vertices,
                    value: mesh_info.max_vertices,
                });
            }
            if mesh_info.max_primitives > self.limits.max_mesh_output_primitives {
                return Err(StageError::TooManyMeshPrimitives {
                    limit: self.limits.max_mesh_output_primitives,
                    value: mesh_info.max_primitives,
                });
            }
            if primitive_topology != Some(mesh_info.primitive_topology) {
                return Err(StageError::MeshTopologyMismatch);
            }
        }
        if let Some(task_payload_size) = entry_point.task_payload_size {
            if task_payload_size > self.limits.max_task_payload_size {
                return Err(StageError::TaskPayloadTooLarge {
                    limit: self.limits.max_task_payload_size,
                    value: task_payload_size,
                });
            }
        }
        // A mesh shader's task payload must match the one provided by the previous stage
        // (both absent, or both present with the same size).
        if shader_stage.to_naga() == naga::ShaderStage::Mesh
            && entry_point.task_payload_size != inputs.task_payload_size
        {
            return Err(StageError::TaskPayloadMustMatch {
                input: inputs.task_payload_size,
                shader: entry_point.task_payload_size,
            });
        }

        // When the previous stage reports its `primitive_index` output (`Some`), the fragment
        // shader must read it exactly when it is written.
        if shader_stage.to_naga() == naga::ShaderStage::Fragment
            && this_stage_primitive_index
            && inputs.primitive_index == Some(false)
        {
            return Err(StageError::InvalidPrimitiveIndex);
        } else if shader_stage.to_naga() == naga::ShaderStage::Fragment
            && !this_stage_primitive_index
            && inputs.primitive_index == Some(true)
        {
            return Err(StageError::MissingPrimitiveIndex);
        }
        // `draw_index` is rejected in a mesh shader whose input includes a task payload.
        if shader_stage.to_naga() == naga::ShaderStage::Mesh
            && inputs.task_payload_size.is_some()
            && has_draw_id
        {
            return Err(StageError::DrawIdError);
        }

        // `per_vertex` interpolation requires a known triangle topology.
        if primitive_topology.is_none_or(|e| !e.is_triangles()) && has_per_vertex {
            return Err(StageError::PerVertexNotTriangles);
        }

        // Only user-defined outputs are forwarded to the next stage.
        let outputs = entry_point
            .outputs
            .iter()
            .filter_map(|output| match *output {
                Varying::Local { location, ref iv } => Some((location, iv.clone())),
                Varying::BuiltIn(_) => None,
            })
            .collect();

        Ok(StageIo {
            task_payload_size: entry_point.task_payload_size,
            varyings: outputs,
            primitive_index: if shader_stage.to_naga() == naga::ShaderStage::Mesh {
                Some(this_stage_primitive_index)
            } else {
                None
            },
        })
    }

    /// Returns whether the fragment entry point `entry_point_name` uses dual-source blending.
    ///
    /// # Errors
    ///
    /// Returns `StageError::MissingEntryPoint` if there is no fragment entry point with
    /// that name.
    pub fn fragment_uses_dual_source_blending(
        &self,
        entry_point_name: &str,
    ) -> Result<bool, StageError> {
        let pair = (naga::ShaderStage::Fragment, entry_point_name.to_string());
        self.entry_points
            .get(&pair)
            .ok_or(StageError::MissingEntryPoint(pair.1))
            .map(|ep| ep.dual_source_blending)
    }
}
1904
1905pub fn check_color_attachment_count(
1906 num_attachments: usize,
1907 limit: u32,
1908) -> Result<(), ColorAttachmentError> {
1909 let limit = usize::try_from(limit).unwrap();
1910 if num_attachments > limit {
1911 return Err(ColorAttachmentError::TooMany {
1912 given: num_attachments,
1913 limit,
1914 });
1915 }
1916
1917 Ok(())
1918}
1919
1920pub fn validate_color_attachment_bytes_per_sample(
1926 attachment_formats: impl IntoIterator<Item = wgt::TextureFormat>,
1927 limit: u32,
1928) -> Result<(), ColorAttachmentError> {
1929 let mut total_bytes_per_sample: u32 = 0;
1930 for format in attachment_formats {
1931 let byte_cost = format.target_pixel_byte_cost().unwrap();
1932 let alignment = format.target_component_alignment().unwrap();
1933
1934 total_bytes_per_sample = total_bytes_per_sample.next_multiple_of(alignment);
1935 total_bytes_per_sample += byte_cost;
1936 }
1937
1938 if total_bytes_per_sample > limit {
1939 return Err(ColorAttachmentError::TooManyBytesPerSample {
1940 total: total_bytes_per_sample,
1941 limit,
1942 });
1943 }
1944
1945 Ok(())
1946}
1947
/// A compute-like shader's workgroup size failed validation.
#[derive(Clone, Debug, Error)]
pub enum InvalidWorkgroupSizeError {
    /// A dimension exceeds its per-dimension limit, or the total invocation count exceeds
    /// the total limit.
    #[error(
        "Workgroup size {dimensions:?} ({total} total invocations) must be less or equal to \
        the per-dimension limit `Limits::{per_dimension_limits_desc}` of {per_dimension_limits:?} \
        and the total invocation limit `Limits::{total_limit_desc}` of {total_limit}"
    )]
    LimitExceeded {
        /// The workgroup size that was checked.
        dimensions: [u32; 3],
        /// The per-dimension limits that applied.
        per_dimension_limits: [u32; 3],
        /// Name of the per-dimension limit, for the error message.
        per_dimension_limits_desc: &'static str,
        /// Product of `dimensions`, saturating at `u32::MAX`.
        total: u32,
        /// The total invocation limit that applied.
        total_limit: u32,
        /// Name of the total invocation limit, for the error message.
        total_limit_desc: &'static str,
    },
    /// At least one dimension is zero, so the workgroup has no invocations.
    #[error("Workgroup sizes {dimensions:?} must be positive")]
    Zero { dimensions: [u32; 3] },
}
1966
/// Inputs for validating a workgroup size against a set of limits.
///
/// The `*_desc` fields name the limits. They are only used to build
/// [`InvalidWorkgroupSizeError::LimitExceeded`].
#[derive(Clone, Debug)]
pub(crate) struct WorkgroupSizeCheck<'a> {
    /// The workgroup size to validate (x, y, z).
    pub dimensions: &'a [u32; 3],
    /// Maximum allowed size for each dimension (x, y, z).
    pub per_dimension_limits: &'a [u32; 3],
    /// Name of the per-dimension limit(s).
    pub per_dimension_limits_desc: &'static str,
    /// Maximum allowed product of all three dimensions.
    pub total_limit: u32,
    /// Name of the total invocation limit.
    pub total_limit_desc: &'static str,
}
1977
1978impl WorkgroupSizeCheck<'_> {
1979 pub(crate) fn check_and_compute_total_invocations(
1985 self,
1986 ) -> Result<u32, InvalidWorkgroupSizeError> {
1987 let Self {
1988 dimensions,
1989 per_dimension_limits,
1990 per_dimension_limits_desc,
1991 total_limit,
1992 total_limit_desc,
1993 } = self;
1994
1995 let total = dimensions
1996 .iter()
1997 .fold(1u32, |total, &dim| total.saturating_mul(dim));
1998
1999 let invalid_total_invocations = total > total_limit;
2000
2001 let dimension_too_large = dimensions
2002 .iter()
2003 .zip(per_dimension_limits.iter())
2004 .any(|(dim, limit)| dim > limit);
2005
2006 if invalid_total_invocations || dimension_too_large {
2007 Err(InvalidWorkgroupSizeError::LimitExceeded {
2008 dimensions: *dimensions,
2009 per_dimension_limits: *per_dimension_limits,
2010 per_dimension_limits_desc,
2011 total,
2012 total_limit,
2013 total_limit_desc,
2014 })
2015 } else {
2016 Ok(total)
2017 }
2018 }
2019}
2020
/// A shader stage plus the pipeline state needed to validate it.
pub enum ShaderStageForValidation {
    /// Vertex stage.
    Vertex {
        /// The pipeline's primitive topology. A point list reduces the vertex output budget.
        topology: wgt::PrimitiveTopology,
        /// The depth comparison function, if any. `Equal`/`NotEqual` trigger a warning when
        /// `@builtin(position)` is not `@invariant`.
        compare_function: Option<wgt::CompareFunction>,
    },
    /// Mesh stage.
    Mesh,
    /// Fragment stage.
    Fragment {
        /// Whether the pipeline requests dual-source blending.
        dual_source_blending: bool,
        /// Whether the pipeline has a depth attachment (required if the shader writes
        /// `frag_depth`).
        has_depth_attachment: bool,
    },
    /// Compute stage.
    Compute,
    /// Task stage.
    Task,
}
2034
2035impl ShaderStageForValidation {
2036 pub fn to_naga(&self) -> naga::ShaderStage {
2037 match self {
2038 Self::Vertex { .. } => naga::ShaderStage::Vertex,
2039 Self::Mesh => naga::ShaderStage::Mesh,
2040 Self::Fragment { .. } => naga::ShaderStage::Fragment,
2041 Self::Compute => naga::ShaderStage::Compute,
2042 Self::Task => naga::ShaderStage::Task,
2043 }
2044 }
2045
2046 pub fn to_wgt_bit(&self) -> wgt::ShaderStages {
2047 match self {
2048 Self::Vertex { .. } => wgt::ShaderStages::VERTEX,
2049 Self::Mesh => wgt::ShaderStages::MESH,
2050 Self::Fragment { .. } => wgt::ShaderStages::FRAGMENT,
2051 Self::Compute => wgt::ShaderStages::COMPUTE,
2052 Self::Task => wgt::ShaderStages::TASK,
2053 }
2054 }
2055}