1mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod keyword_set;
9mod layouter;
10mod namer;
11mod overloads;
12mod terminator;
13mod type_methods;
14mod typifier;
15
16pub use constant_evaluator::{
17 ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
18};
19pub use emitter::Emitter;
20pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
21pub use keyword_set::{CaseInsensitiveKeywordSet, KeywordSet};
22pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
23pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
24pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
25pub use terminator::ensure_block_returns;
26use thiserror::Error;
27pub use type_methods::{
28 concrete_int_scalars, min_max_float_representable_by, vector_size_str, vector_sizes,
29};
30pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
31
32use crate::non_max_u32::NonMaxU32;
33
34impl From<super::StorageFormat> for super::Scalar {
35 fn from(format: super::StorageFormat) -> Self {
36 use super::{ScalarKind as Sk, StorageFormat as Sf};
37 let kind = match format {
38 Sf::R8Unorm => Sk::Float,
39 Sf::R8Snorm => Sk::Float,
40 Sf::R8Uint => Sk::Uint,
41 Sf::R8Sint => Sk::Sint,
42 Sf::R16Uint => Sk::Uint,
43 Sf::R16Sint => Sk::Sint,
44 Sf::R16Float => Sk::Float,
45 Sf::Rg8Unorm => Sk::Float,
46 Sf::Rg8Snorm => Sk::Float,
47 Sf::Rg8Uint => Sk::Uint,
48 Sf::Rg8Sint => Sk::Sint,
49 Sf::R32Uint => Sk::Uint,
50 Sf::R32Sint => Sk::Sint,
51 Sf::R32Float => Sk::Float,
52 Sf::Rg16Uint => Sk::Uint,
53 Sf::Rg16Sint => Sk::Sint,
54 Sf::Rg16Float => Sk::Float,
55 Sf::Rgba8Unorm => Sk::Float,
56 Sf::Rgba8Snorm => Sk::Float,
57 Sf::Rgba8Uint => Sk::Uint,
58 Sf::Rgba8Sint => Sk::Sint,
59 Sf::Bgra8Unorm => Sk::Float,
60 Sf::Rgb10a2Uint => Sk::Uint,
61 Sf::Rgb10a2Unorm => Sk::Float,
62 Sf::Rg11b10Ufloat => Sk::Float,
63 Sf::R64Uint => Sk::Uint,
64 Sf::Rg32Uint => Sk::Uint,
65 Sf::Rg32Sint => Sk::Sint,
66 Sf::Rg32Float => Sk::Float,
67 Sf::Rgba16Uint => Sk::Uint,
68 Sf::Rgba16Sint => Sk::Sint,
69 Sf::Rgba16Float => Sk::Float,
70 Sf::Rgba32Uint => Sk::Uint,
71 Sf::Rgba32Sint => Sk::Sint,
72 Sf::Rgba32Float => Sk::Float,
73 Sf::R16Unorm => Sk::Float,
74 Sf::R16Snorm => Sk::Float,
75 Sf::Rg16Unorm => Sk::Float,
76 Sf::Rg16Snorm => Sk::Float,
77 Sf::Rgba16Unorm => Sk::Float,
78 Sf::Rgba16Snorm => Sk::Float,
79 };
80 let width = match format {
81 Sf::R64Uint => 8,
82 _ => 4,
83 };
84 super::Scalar { kind, width }
85 }
86}
87
/// A hashable, totally comparable stand-in for [`crate::Literal`].
///
/// Floating-point payloads (`F64`, `F32`, `F16`, `AbstractFloat`) hold the raw
/// bit pattern produced by `to_bits` in the `From<crate::Literal>` conversion.
/// That is what lets this type derive `Eq` and `Hash`. As a consequence, two
/// float literals are equal here only when their bit patterns match (e.g.
/// `0.0` and `-0.0` are distinct).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HashableLiteral {
    /// Bits of an `f64` literal.
    F64(u64),
    /// Bits of an `f32` literal.
    F32(u32),
    /// Bits of an `f16` literal.
    F16(u16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
    Bool(bool),
    AbstractInt(i64),
    /// Bits of an abstract-float literal.
    AbstractFloat(u64),
}
101
102impl From<crate::Literal> for HashableLiteral {
103 fn from(l: crate::Literal) -> Self {
104 match l {
105 crate::Literal::F64(v) => Self::F64(v.to_bits()),
106 crate::Literal::F32(v) => Self::F32(v.to_bits()),
107 crate::Literal::F16(v) => Self::F16(v.to_bits()),
108 crate::Literal::U32(v) => Self::U32(v),
109 crate::Literal::I32(v) => Self::I32(v),
110 crate::Literal::U64(v) => Self::U64(v),
111 crate::Literal::I64(v) => Self::I64(v),
112 crate::Literal::Bool(v) => Self::Bool(v),
113 crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
114 crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
115 }
116 }
117}
118
119impl crate::Literal {
120 pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
121 match (value, scalar.kind, scalar.width) {
122 (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
123 (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
124 (value, crate::ScalarKind::Float, 2) => {
125 Some(Self::F16(half::f16::from_f32_const(value as _)))
126 }
127 (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
128 (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
129 (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
130 (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
131 (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
132 (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
133 (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
134 (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
135 _ => None,
136 }
137 }
138
139 pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
140 Self::new(0, scalar)
141 }
142
143 pub const fn one(scalar: crate::Scalar) -> Option<Self> {
144 Self::new(1, scalar)
145 }
146
147 pub const fn minus_one(scalar: crate::Scalar) -> Option<Self> {
148 match (scalar.kind, scalar.width) {
149 (crate::ScalarKind::Float, 8) => Some(Self::F64(-1.0)),
150 (crate::ScalarKind::Float, 4) => Some(Self::F32(-1.0)),
151 (crate::ScalarKind::Float, 2) => Some(Self::F16(half::f16::from_f32_const(-1.0))),
152 (crate::ScalarKind::Sint, 8) => Some(Self::I64(-1)),
153 (crate::ScalarKind::Sint, 4) => Some(Self::I32(-1)),
154 (crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(-1)),
155 _ => None,
156 }
157 }
158
159 pub const fn width(&self) -> crate::Bytes {
160 match *self {
161 Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
162 Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
163 Self::F16(_) => 2,
164 Self::Bool(_) => crate::BOOL_WIDTH,
165 Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
166 }
167 }
168 pub const fn scalar(&self) -> crate::Scalar {
169 match *self {
170 Self::F64(_) => crate::Scalar::F64,
171 Self::F32(_) => crate::Scalar::F32,
172 Self::F16(_) => crate::Scalar::F16,
173 Self::U32(_) => crate::Scalar::U32,
174 Self::I32(_) => crate::Scalar::I32,
175 Self::U64(_) => crate::Scalar::U64,
176 Self::I64(_) => crate::Scalar::I64,
177 Self::Bool(_) => crate::Scalar::BOOL,
178 Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
179 Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
180 }
181 }
182 pub const fn scalar_kind(&self) -> crate::ScalarKind {
183 self.scalar().kind
184 }
185 pub const fn ty_inner(&self) -> crate::TypeInner {
186 crate::TypeInner::Scalar(self.scalar())
187 }
188}
189
190impl TryFrom<crate::Literal> for u32 {
191 type Error = ConstValueError;
192
193 fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
194 match value {
195 crate::Literal::U32(value) => Ok(value),
196 crate::Literal::I32(value) => value.try_into().map_err(|_| ConstValueError::Negative),
197 _ => Err(ConstValueError::InvalidType),
198 }
199 }
200}
201
202impl TryFrom<crate::Literal> for bool {
203 type Error = ConstValueError;
204
205 fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
206 match value {
207 crate::Literal::Bool(value) => Ok(value),
208 _ => Err(ConstValueError::InvalidType),
209 }
210 }
211}
212
213impl super::AddressSpace {
214 pub fn access(self) -> crate::StorageAccess {
215 use crate::StorageAccess as Sa;
216 match self {
217 crate::AddressSpace::Function
218 | crate::AddressSpace::Private
219 | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
220 crate::AddressSpace::Uniform => Sa::LOAD,
221 crate::AddressSpace::Storage { access } => access,
222 crate::AddressSpace::Handle => Sa::LOAD,
223 crate::AddressSpace::Immediate => Sa::LOAD,
224 crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE,
227 crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
228 Sa::LOAD | Sa::STORE
229 }
230 }
231 }
232}
233
234impl super::MathFunction {
235 pub const fn argument_count(&self) -> usize {
236 match *self {
237 Self::Abs => 1,
239 Self::Min => 2,
240 Self::Max => 2,
241 Self::Clamp => 3,
242 Self::Saturate => 1,
243 Self::Cos => 1,
245 Self::Cosh => 1,
246 Self::Sin => 1,
247 Self::Sinh => 1,
248 Self::Tan => 1,
249 Self::Tanh => 1,
250 Self::Acos => 1,
251 Self::Asin => 1,
252 Self::Atan => 1,
253 Self::Atan2 => 2,
254 Self::Asinh => 1,
255 Self::Acosh => 1,
256 Self::Atanh => 1,
257 Self::Radians => 1,
258 Self::Degrees => 1,
259 Self::Ceil => 1,
261 Self::Floor => 1,
262 Self::Round => 1,
263 Self::Fract => 1,
264 Self::Trunc => 1,
265 Self::Modf => 1,
266 Self::Frexp => 1,
267 Self::Ldexp => 2,
268 Self::Exp => 1,
270 Self::Exp2 => 1,
271 Self::Log => 1,
272 Self::Log2 => 1,
273 Self::Pow => 2,
274 Self::Dot => 2,
276 Self::Dot4I8Packed => 2,
277 Self::Dot4U8Packed => 2,
278 Self::Outer => 2,
279 Self::Cross => 2,
280 Self::Distance => 2,
281 Self::Length => 1,
282 Self::Normalize => 1,
283 Self::FaceForward => 3,
284 Self::Reflect => 2,
285 Self::Refract => 3,
286 Self::Sign => 1,
288 Self::Fma => 3,
289 Self::Mix => 3,
290 Self::Step => 2,
291 Self::SmoothStep => 3,
292 Self::Sqrt => 1,
293 Self::InverseSqrt => 1,
294 Self::Inverse => 1,
295 Self::Transpose => 1,
296 Self::Determinant => 1,
297 Self::QuantizeToF16 => 1,
298 Self::CountTrailingZeros => 1,
300 Self::CountLeadingZeros => 1,
301 Self::CountOneBits => 1,
302 Self::ReverseBits => 1,
303 Self::ExtractBits => 3,
304 Self::InsertBits => 4,
305 Self::FirstTrailingBit => 1,
306 Self::FirstLeadingBit => 1,
307 Self::Pack4x8snorm => 1,
309 Self::Pack4x8unorm => 1,
310 Self::Pack2x16snorm => 1,
311 Self::Pack2x16unorm => 1,
312 Self::Pack2x16float => 1,
313 Self::Pack4xI8 => 1,
314 Self::Pack4xU8 => 1,
315 Self::Pack4xI8Clamp => 1,
316 Self::Pack4xU8Clamp => 1,
317 Self::Unpack4x8snorm => 1,
319 Self::Unpack4x8unorm => 1,
320 Self::Unpack2x16snorm => 1,
321 Self::Unpack2x16unorm => 1,
322 Self::Unpack2x16float => 1,
323 Self::Unpack4xI8 => 1,
324 Self::Unpack4xU8 => 1,
325 }
326 }
327}
328
329impl crate::Expression {
330 pub const fn needs_pre_emit(&self) -> bool {
332 match *self {
333 Self::Literal(_)
334 | Self::Constant(_)
335 | Self::Override(_)
336 | Self::ZeroValue(_)
337 | Self::FunctionArgument(_)
338 | Self::GlobalVariable(_)
339 | Self::LocalVariable(_) => true,
340 _ => false,
341 }
342 }
343
344 pub const fn is_dynamic_index(&self) -> bool {
358 match *self {
359 Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
360 _ => true,
361 }
362 }
363}
364
365impl crate::Function {
366 pub fn originating_global(
375 &self,
376 mut pointer: crate::Handle<crate::Expression>,
377 ) -> Option<crate::Handle<crate::GlobalVariable>> {
378 loop {
379 pointer = match self.expressions[pointer] {
380 crate::Expression::Access { base, .. } => base,
381 crate::Expression::AccessIndex { base, .. } => base,
382 crate::Expression::GlobalVariable(handle) => return Some(handle),
383 crate::Expression::LocalVariable(_) => return None,
384 crate::Expression::FunctionArgument(_) => return None,
385 _ => unreachable!(),
387 }
388 }
389 }
390}
391
392impl crate::SampleLevel {
393 pub const fn implicit_derivatives(&self) -> bool {
394 match *self {
395 Self::Auto | Self::Bias(_) => true,
396 Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
397 }
398 }
399}
400
401impl crate::Binding {
402 pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
403 match *self {
404 crate::Binding::BuiltIn(built_in) => Some(built_in),
405 Self::Location { .. } => None,
406 }
407 }
408}
409
410impl super::SwizzleComponent {
411 pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
412
413 pub const fn index(&self) -> u32 {
414 match *self {
415 Self::X => 0,
416 Self::Y => 1,
417 Self::Z => 2,
418 Self::W => 3,
419 }
420 }
421 pub const fn from_index(idx: u32) -> Self {
422 match idx {
423 0 => Self::X,
424 1 => Self::Y,
425 2 => Self::Z,
426 _ => Self::W,
427 }
428 }
429}
430
431impl super::ImageClass {
432 pub const fn is_multisampled(self) -> bool {
433 match self {
434 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
435 crate::ImageClass::Storage { .. } => false,
436 crate::ImageClass::External => false,
437 }
438 }
439
440 pub const fn is_mipmapped(self) -> bool {
441 match self {
442 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
443 crate::ImageClass::Storage { .. } => false,
444 crate::ImageClass::External => false,
445 }
446 }
447
448 pub const fn is_depth(self) -> bool {
449 matches!(self, crate::ImageClass::Depth { .. })
450 }
451}
452
453impl crate::Module {
454 pub const fn to_ctx(&self) -> GlobalCtx<'_> {
455 GlobalCtx {
456 types: &self.types,
457 constants: &self.constants,
458 overrides: &self.overrides,
459 global_expressions: &self.global_expressions,
460 }
461 }
462
463 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
464 compare_types(lhs, rhs, &self.types)
465 }
466}
467
/// Reasons a constant expression could not be turned into a concrete value.
#[derive(Debug)]
pub enum ConstValueError {
    /// The expression is not a literal, a scalar zero value, or a named
    /// constant whose initializer is one of those.
    NonConst,
    /// A signed value was negative where an unsigned one was required.
    Negative,
    /// The literal's type cannot be converted to the requested type.
    InvalidType,
}
474
475impl From<core::convert::Infallible> for ConstValueError {
476 fn from(_: core::convert::Infallible) -> Self {
477 unreachable!()
478 }
479}
480
/// A cheap, copyable bundle of borrows of the module-level arenas needed for
/// type and constant queries. Built from a module by `crate::Module::to_ctx`.
#[derive(Clone, Copy)]
pub struct GlobalCtx<'a> {
    /// The module's type arena.
    pub types: &'a crate::UniqueArena<crate::Type>,
    /// The module's named constants.
    pub constants: &'a crate::Arena<crate::Constant>,
    /// The module's pipeline-overridable constants.
    pub overrides: &'a crate::Arena<crate::Override>,
    /// The module's global (constant/override) expression arena.
    pub global_expressions: &'a crate::Arena<crate::Expression>,
}
488
489impl GlobalCtx<'_> {
490 #[cfg_attr(
497 not(any(
498 feature = "glsl-in",
499 feature = "spv-in",
500 feature = "wgsl-in",
501 glsl_out,
502 hlsl_out,
503 msl_out,
504 wgsl_out
505 )),
506 allow(dead_code)
507 )]
508 pub(super) fn get_const_val<T, E>(
509 &self,
510 handle: crate::Handle<crate::Expression>,
511 ) -> Result<T, ConstValueError>
512 where
513 T: TryFrom<crate::Literal, Error = E>,
514 E: Into<ConstValueError>,
515 {
516 self.get_const_val_from(handle, self.global_expressions)
517 }
518
519 pub(super) fn get_const_val_from<T, E>(
520 &self,
521 handle: crate::Handle<crate::Expression>,
522 arena: &crate::Arena<crate::Expression>,
523 ) -> Result<T, ConstValueError>
524 where
525 T: TryFrom<crate::Literal, Error = E>,
526 E: Into<ConstValueError>,
527 {
528 fn get(
529 gctx: GlobalCtx,
530 handle: crate::Handle<crate::Expression>,
531 arena: &crate::Arena<crate::Expression>,
532 ) -> Option<crate::Literal> {
533 match arena[handle] {
534 crate::Expression::Literal(literal) => Some(literal),
535 crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
536 crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
537 _ => None,
538 },
539 _ => None,
540 }
541 }
542 let value = match arena[handle] {
543 crate::Expression::Constant(c) => {
544 get(*self, self.constants[c].init, self.global_expressions)
545 }
546 _ => get(*self, handle, arena),
547 };
548 match value {
549 Some(v) => v.try_into().map_err(Into::into),
550 None => Err(ConstValueError::NonConst),
551 }
552 }
553
554 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
555 compare_types(lhs, rhs, self.types)
556 }
557}
558
/// Errors from resolving an array's element count (see `crate::ArraySize::resolve`).
#[derive(Error, Debug, Clone, Copy, PartialEq)]
pub enum ResolveArraySizeError {
    /// The count evaluated to zero, to a negative value, or to a non-integer type.
    #[error("array element count must be positive (> 0)")]
    ExpectedPositiveArrayLength,
    /// The override-sized count has no initializer, or its initializer is not
    /// a constant scalar.
    #[error("internal: array size override has not been resolved")]
    NonConstArrayLength,
}
566
567impl crate::ArraySize {
568 pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
578 match *self {
579 crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
580 crate::ArraySize::Pending(handle) => {
581 let Some(expr) = gctx.overrides[handle].init else {
582 return Err(ResolveArraySizeError::NonConstArrayLength);
583 };
584 let length = gctx.get_const_val(expr).map_err(|err| match err {
585 ConstValueError::NonConst => ResolveArraySizeError::NonConstArrayLength,
586 ConstValueError::Negative | ConstValueError::InvalidType => {
587 ResolveArraySizeError::ExpectedPositiveArrayLength
588 }
589 })?;
590
591 if length == 0 {
592 return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
593 }
594
595 Ok(IndexableLength::Known(length))
596 }
597 crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
598 }
599 }
600}
601
602pub fn flatten_compose<'arenas>(
615 ty: crate::Handle<crate::Type>,
616 components: &'arenas [crate::Handle<crate::Expression>],
617 expressions: &'arenas crate::Arena<crate::Expression>,
618 types: &'arenas crate::UniqueArena<crate::Type>,
619) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
620 let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
626 (size as usize, true)
627 } else {
628 (components.len(), false)
629 };
630
631 fn flatten_compose<'c>(
633 component: &'c crate::Handle<crate::Expression>,
634 is_vector: bool,
635 expressions: &'c crate::Arena<crate::Expression>,
636 ) -> &'c [crate::Handle<crate::Expression>] {
637 if is_vector {
638 if let crate::Expression::Compose {
639 ty: _,
640 components: ref subcomponents,
641 } = expressions[*component]
642 {
643 return subcomponents;
644 }
645 }
646 core::slice::from_ref(component)
647 }
648
649 fn flatten_splat<'c>(
651 component: &'c crate::Handle<crate::Expression>,
652 is_vector: bool,
653 expressions: &'c crate::Arena<crate::Expression>,
654 ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
655 let mut expr = *component;
656 let mut count = 1;
657 if is_vector {
658 if let crate::Expression::Splat { size, value } = expressions[expr] {
659 expr = value;
660 count = size as usize;
661 }
662 }
663 core::iter::repeat_n(expr, count)
664 }
665
666 components
673 .iter()
674 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
675 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
676 .flat_map(move |component| flatten_splat(component, is_vector, expressions))
677 .take(size)
678}
679
680impl super::ShaderStage {
681 pub const fn compute_like(self) -> bool {
682 match self {
683 Self::Vertex | Self::Fragment => false,
684 Self::Compute | Self::Task | Self::Mesh => true,
685 Self::RayGeneration | Self::AnyHit | Self::ClosestHit | Self::Miss => false,
686 }
687 }
688
689 pub const fn mesh_like(self) -> bool {
691 match self {
692 Self::Task | Self::Mesh => true,
693 _ => false,
694 }
695 }
696}
697
#[test]
fn test_matrix_size() {
    let module = crate::Module::default();
    let mat3x3_f32 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    // Expected 48 rather than 36 (3 * 3 * 4). Presumably each 3-row column is
    // padded to 16 bytes.
    let expected = 48;
    assert_eq!(mat3x3_f32.size(module.to_ctx()), expected);
}
711
impl crate::Module {
    /// Extracts mesh-stage information from the mesh output global variable `gv`.
    ///
    /// Only the output variable's own type is inspected. It must be a struct,
    /// and every member must be bound to one of the built-ins `VertexCount`,
    /// `PrimitiveCount`, `Vertices` or `Primitives`, each at most once. Any
    /// other member binding (including no binding) is an error. The two count
    /// members must have scalar type `u32`. `Vertices` and `Primitives` must be
    /// arrays of constant or override size. The output topology comes from the
    /// index built-in found in the `Primitives` element struct; it defaults to
    /// triangles.
    ///
    /// Returns a tuple of:
    /// 1. the extracted [`crate::MeshStageInfo`]. Its `max_vertices_override`
    ///    and `max_primitives_override` fields are always `None` here.
    /// 2. the override handles for the vertex and primitive array sizes, in
    ///    that order (`None` when the size is a constant or undetermined).
    /// 3. the first error encountered, if any, spanned on `gv`'s type.
    ///
    /// NOTE(review): when an error is returned, the info may still contain
    /// the placeholder type handle (index 0). Callers presumably ignore the
    /// info in that case; confirm.
    #[allow(clippy::type_complexity)]
    pub fn analyze_mesh_shader_info(
        &self,
        gv: crate::Handle<crate::GlobalVariable>,
    ) -> (
        crate::MeshStageInfo,
        [Option<crate::Handle<crate::Override>>; 2],
        Option<crate::WithSpan<crate::valid::EntryPointError>>,
    ) {
        use crate::span::AddSpan;
        use crate::valid::EntryPointError;
        // Holds at most one error: only the first one reported is kept, so
        // the order of the `set` calls below decides which error surfaces.
        #[derive(Default)]
        struct OutError {
            pub inner: Option<EntryPointError>,
        }
        impl OutError {
            pub fn set(&mut self, err: EntryPointError) {
                if self.inner.is_none() {
                    self.inner = Some(err);
                }
            }
        }

        // Placeholder type handle (index 0) used until a real element type is found.
        let null_type = crate::Handle::new(NonMaxU32::new(0).unwrap());
        let mut output = crate::MeshStageInfo {
            topology: crate::MeshOutputTopology::Triangles,
            max_vertices: 0,
            max_vertices_override: None,
            max_primitives: 0,
            max_primitives_override: None,
            vertex_output_type: null_type,
            primitive_output_type: null_type,
            output_variable: gv,
        };
        let mut error = OutError::default();
        let r#type = &self.types[self.global_variables[gv].ty].inner;

        let mut topology = output.topology;
        // Each tuple is (max count, max-count override, element type).
        let mut vertex_info = (0, None, null_type);
        let mut primitive_info = (0, None, null_type);

        match r#type {
            &crate::TypeInner::Struct { ref members, .. } => {
                // Built-ins seen so far; used to reject duplicates.
                let mut builtins = crate::FastHashSet::default();
                for member in members {
                    match member.binding {
                        Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexCount)) => {
                            // The count's scalar type must be `u32`.
                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
                                error.set(EntryPointError::BadMeshOutputVariableField);
                            }
                            // Each built-in may appear only once.
                            if builtins.contains(&crate::BuiltIn::VertexCount) {
                                error.set(EntryPointError::BadMeshOutputVariableType);
                            }
                            builtins.insert(crate::BuiltIn::VertexCount);
                        }
                        Some(crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveCount)) => {
                            // The count's scalar type must be `u32`.
                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
                                error.set(EntryPointError::BadMeshOutputVariableField);
                            }
                            // Each built-in may appear only once.
                            if builtins.contains(&crate::BuiltIn::PrimitiveCount) {
                                error.set(EntryPointError::BadMeshOutputVariableType);
                            }
                            builtins.insert(crate::BuiltIn::PrimitiveCount);
                        }
                        Some(crate::Binding::BuiltIn(
                            crate::BuiltIn::Vertices | crate::BuiltIn::Primitives,
                        )) => {
                            let ty = &self.types[member.ty].inner;
                            // Both must be arrays. Extract (max count, override, element type).
                            let (a, b, c) = match ty {
                                &crate::TypeInner::Array { base, size, .. } => {
                                    let ty = base;
                                    let (max, max_override) = match size {
                                        crate::ArraySize::Constant(a) => (a.get(), None),
                                        // Override-sized: the count is reported
                                        // through the override handle instead.
                                        crate::ArraySize::Pending(o) => (0, Some(o)),
                                        // Runtime-sized arrays are not allowed here.
                                        crate::ArraySize::Dynamic => {
                                            error.set(EntryPointError::BadMeshOutputVariableField);
                                            (0, None)
                                        }
                                    };
                                    (max, max_override, ty)
                                }
                                _ => {
                                    error.set(EntryPointError::BadMeshOutputVariableField);
                                    (0, None, null_type)
                                }
                            };
                            if matches!(
                                member.binding,
                                Some(crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
                            ) {
                                primitive_info = (a, b, c);
                                // The index built-in on the primitive element
                                // struct selects the topology. If several are
                                // present, the last one wins.
                                match self.types[c].inner {
                                    crate::TypeInner::Struct { ref members, .. } => {
                                        for member in members {
                                            match member.binding {
                                                Some(crate::Binding::BuiltIn(
                                                    crate::BuiltIn::PointIndex,
                                                )) => {
                                                    topology = crate::MeshOutputTopology::Points;
                                                }
                                                Some(crate::Binding::BuiltIn(
                                                    crate::BuiltIn::LineIndices,
                                                )) => {
                                                    topology = crate::MeshOutputTopology::Lines;
                                                }
                                                Some(crate::Binding::BuiltIn(
                                                    crate::BuiltIn::TriangleIndices,
                                                )) => {
                                                    topology = crate::MeshOutputTopology::Triangles;
                                                }
                                                _ => (),
                                            }
                                        }
                                    }
                                    _ => (),
                                }
                                // Each built-in may appear only once.
                                if builtins.contains(&crate::BuiltIn::Primitives) {
                                    error.set(EntryPointError::BadMeshOutputVariableType);
                                }
                                builtins.insert(crate::BuiltIn::Primitives);
                            } else {
                                vertex_info = (a, b, c);
                                // Each built-in may appear only once.
                                if builtins.contains(&crate::BuiltIn::Vertices) {
                                    error.set(EntryPointError::BadMeshOutputVariableType);
                                }
                                builtins.insert(crate::BuiltIn::Vertices);
                            }
                        }
                        // Any other (or missing) binding is not allowed on the output struct.
                        _ => error.set(EntryPointError::BadMeshOutputVariableType),
                    }
                }
                output = crate::MeshStageInfo {
                    topology,
                    max_vertices: vertex_info.0,
                    // Overrides are returned separately in the result tuple.
                    max_vertices_override: None,
                    vertex_output_type: vertex_info.2,
                    max_primitives: primitive_info.0,
                    max_primitives_override: None,
                    primitive_output_type: primitive_info.2,
                    ..output
                }
            }
            // The output variable itself must be a struct.
            _ => error.set(EntryPointError::BadMeshOutputVariableType),
        }
        (
            output,
            [vertex_info.1, primitive_info.1],
            error
                .inner
                .map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)),
        )
    }

    /// Whether the module uses any mesh-shader feature.
    ///
    /// Returns true if any of the following holds:
    /// - a struct type has a member with a mesh-related built-in binding or a
    ///   per-primitive location binding;
    /// - an entry point is a `Mesh` or `Task` stage, or has such a binding on
    ///   an argument or its result;
    /// - a global variable lives in the `TaskPayload` address space.
    pub fn uses_mesh_shaders(&self) -> bool {
        // Built-ins and location flags that only exist for mesh/task shaders.
        let binding_uses_mesh = |b: &crate::Binding| {
            matches!(
                b,
                crate::Binding::BuiltIn(
                    crate::BuiltIn::MeshTaskSize
                        | crate::BuiltIn::CullPrimitive
                        | crate::BuiltIn::PointIndex
                        | crate::BuiltIn::LineIndices
                        | crate::BuiltIn::TriangleIndices
                        | crate::BuiltIn::VertexCount
                        | crate::BuiltIn::Vertices
                        | crate::BuiltIn::PrimitiveCount
                        | crate::BuiltIn::Primitives,
                ) | crate::Binding::Location {
                    per_primitive: true,
                    ..
                }
            )
        };
        // Bindings on any struct member.
        for (_, ty) in self.types.iter() {
            match ty.inner {
                crate::TypeInner::Struct { ref members, .. } => {
                    for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
                        if binding_uses_mesh(binding) {
                            return true;
                        }
                    }
                }
                _ => (),
            }
        }
        // Entry-point stages and their argument/result bindings.
        for ep in &self.entry_points {
            if matches!(
                ep.stage,
                crate::ShaderStage::Mesh | crate::ShaderStage::Task
            ) {
                return true;
            }
            for binding in ep
                .function
                .arguments
                .iter()
                .filter_map(|arg| arg.binding.as_ref())
                .chain(
                    ep.function
                        .result
                        .iter()
                        .filter_map(|res| res.binding.as_ref()),
                )
            {
                if binding_uses_mesh(binding) {
                    return true;
                }
            }
        }
        // Task payload globals.
        if self
            .global_variables
            .iter()
            .any(|gv| gv.1.space == crate::AddressSpace::TaskPayload)
        {
            return true;
        }
        false
    }
}
952
953impl crate::MeshOutputTopology {
954 pub const fn to_builtin(self) -> crate::BuiltIn {
955 match self {
956 Self::Points => crate::BuiltIn::PointIndex,
957 Self::Lines => crate::BuiltIn::LineIndices,
958 Self::Triangles => crate::BuiltIn::TriangleIndices,
959 }
960 }
961}
962
963impl crate::AddressSpace {
964 pub const fn is_workgroup_like(self) -> bool {
965 matches!(self, Self::WorkGroup | Self::TaskPayload)
966 }
967}