1mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod keyword_set;
9mod layouter;
10mod namer;
11mod overloads;
12mod terminator;
13mod type_methods;
14mod typifier;
15
16pub use constant_evaluator::{
17 ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
18};
19pub use emitter::Emitter;
20pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
21pub use keyword_set::{CaseInsensitiveKeywordSet, KeywordSet};
22pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
23pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
24pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
25pub use terminator::ensure_block_returns;
26use thiserror::Error;
27pub use type_methods::{
28 concrete_int_scalars, min_max_float_representable_by, vector_size_str, vector_sizes,
29};
30pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
31
32use crate::non_max_u32::NonMaxU32;
33
34impl From<super::StorageFormat> for super::Scalar {
35 fn from(format: super::StorageFormat) -> Self {
36 use super::{ScalarKind as Sk, StorageFormat as Sf};
37 let kind = match format {
38 Sf::R8Unorm => Sk::Float,
39 Sf::R8Snorm => Sk::Float,
40 Sf::R8Uint => Sk::Uint,
41 Sf::R8Sint => Sk::Sint,
42 Sf::R16Uint => Sk::Uint,
43 Sf::R16Sint => Sk::Sint,
44 Sf::R16Float => Sk::Float,
45 Sf::Rg8Unorm => Sk::Float,
46 Sf::Rg8Snorm => Sk::Float,
47 Sf::Rg8Uint => Sk::Uint,
48 Sf::Rg8Sint => Sk::Sint,
49 Sf::R32Uint => Sk::Uint,
50 Sf::R32Sint => Sk::Sint,
51 Sf::R32Float => Sk::Float,
52 Sf::Rg16Uint => Sk::Uint,
53 Sf::Rg16Sint => Sk::Sint,
54 Sf::Rg16Float => Sk::Float,
55 Sf::Rgba8Unorm => Sk::Float,
56 Sf::Rgba8Snorm => Sk::Float,
57 Sf::Rgba8Uint => Sk::Uint,
58 Sf::Rgba8Sint => Sk::Sint,
59 Sf::Bgra8Unorm => Sk::Float,
60 Sf::Rgb10a2Uint => Sk::Uint,
61 Sf::Rgb10a2Unorm => Sk::Float,
62 Sf::Rg11b10Ufloat => Sk::Float,
63 Sf::R64Uint => Sk::Uint,
64 Sf::Rg32Uint => Sk::Uint,
65 Sf::Rg32Sint => Sk::Sint,
66 Sf::Rg32Float => Sk::Float,
67 Sf::Rgba16Uint => Sk::Uint,
68 Sf::Rgba16Sint => Sk::Sint,
69 Sf::Rgba16Float => Sk::Float,
70 Sf::Rgba32Uint => Sk::Uint,
71 Sf::Rgba32Sint => Sk::Sint,
72 Sf::Rgba32Float => Sk::Float,
73 Sf::R16Unorm => Sk::Float,
74 Sf::R16Snorm => Sk::Float,
75 Sf::Rg16Unorm => Sk::Float,
76 Sf::Rg16Snorm => Sk::Float,
77 Sf::Rgba16Unorm => Sk::Float,
78 Sf::Rgba16Snorm => Sk::Float,
79 };
80 let width = match format {
81 Sf::R64Uint => 8,
82 _ => 4,
83 };
84 super::Scalar { kind, width }
85 }
86}
87
88#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
89pub enum HashableLiteral {
90 F64(u64),
91 F32(u32),
92 F16(u16),
93 U32(u32),
94 I32(i32),
95 U64(u64),
96 I64(i64),
97 Bool(bool),
98 AbstractInt(i64),
99 AbstractFloat(u64),
100}
101
102impl From<crate::Literal> for HashableLiteral {
103 fn from(l: crate::Literal) -> Self {
104 match l {
105 crate::Literal::F64(v) => Self::F64(v.to_bits()),
106 crate::Literal::F32(v) => Self::F32(v.to_bits()),
107 crate::Literal::F16(v) => Self::F16(v.to_bits()),
108 crate::Literal::U32(v) => Self::U32(v),
109 crate::Literal::I32(v) => Self::I32(v),
110 crate::Literal::U64(v) => Self::U64(v),
111 crate::Literal::I64(v) => Self::I64(v),
112 crate::Literal::Bool(v) => Self::Bool(v),
113 crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
114 crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
115 }
116 }
117}
118
119impl crate::Literal {
120 pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
121 match (value, scalar.kind, scalar.width) {
122 (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
123 (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
124 (value, crate::ScalarKind::Float, 2) => {
125 Some(Self::F16(half::f16::from_f32_const(value as _)))
126 }
127 (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
128 (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
129 (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
130 (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
131 (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
132 (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
133 (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
134 (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
135 _ => None,
136 }
137 }
138
139 pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
140 Self::new(0, scalar)
141 }
142
143 pub const fn one(scalar: crate::Scalar) -> Option<Self> {
144 Self::new(1, scalar)
145 }
146
147 pub const fn width(&self) -> crate::Bytes {
148 match *self {
149 Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
150 Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
151 Self::F16(_) => 2,
152 Self::Bool(_) => crate::BOOL_WIDTH,
153 Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
154 }
155 }
156 pub const fn scalar(&self) -> crate::Scalar {
157 match *self {
158 Self::F64(_) => crate::Scalar::F64,
159 Self::F32(_) => crate::Scalar::F32,
160 Self::F16(_) => crate::Scalar::F16,
161 Self::U32(_) => crate::Scalar::U32,
162 Self::I32(_) => crate::Scalar::I32,
163 Self::U64(_) => crate::Scalar::U64,
164 Self::I64(_) => crate::Scalar::I64,
165 Self::Bool(_) => crate::Scalar::BOOL,
166 Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
167 Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
168 }
169 }
170 pub const fn scalar_kind(&self) -> crate::ScalarKind {
171 self.scalar().kind
172 }
173 pub const fn ty_inner(&self) -> crate::TypeInner {
174 crate::TypeInner::Scalar(self.scalar())
175 }
176}
177
178impl TryFrom<crate::Literal> for u32 {
179 type Error = ConstValueError;
180
181 fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
182 match value {
183 crate::Literal::U32(value) => Ok(value),
184 crate::Literal::I32(value) => value.try_into().map_err(|_| ConstValueError::Negative),
185 _ => Err(ConstValueError::InvalidType),
186 }
187 }
188}
189
190impl TryFrom<crate::Literal> for bool {
191 type Error = ConstValueError;
192
193 fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
194 match value {
195 crate::Literal::Bool(value) => Ok(value),
196 _ => Err(ConstValueError::InvalidType),
197 }
198 }
199}
200
201impl super::AddressSpace {
202 pub fn access(self) -> crate::StorageAccess {
203 use crate::StorageAccess as Sa;
204 match self {
205 crate::AddressSpace::Function
206 | crate::AddressSpace::Private
207 | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
208 crate::AddressSpace::Uniform => Sa::LOAD,
209 crate::AddressSpace::Storage { access } => access,
210 crate::AddressSpace::Handle => Sa::LOAD,
211 crate::AddressSpace::Immediate => Sa::LOAD,
212 crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE,
215 crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
216 Sa::LOAD | Sa::STORE
217 }
218 }
219 }
220}
221
222impl super::MathFunction {
223 pub const fn argument_count(&self) -> usize {
224 match *self {
225 Self::Abs => 1,
227 Self::Min => 2,
228 Self::Max => 2,
229 Self::Clamp => 3,
230 Self::Saturate => 1,
231 Self::Cos => 1,
233 Self::Cosh => 1,
234 Self::Sin => 1,
235 Self::Sinh => 1,
236 Self::Tan => 1,
237 Self::Tanh => 1,
238 Self::Acos => 1,
239 Self::Asin => 1,
240 Self::Atan => 1,
241 Self::Atan2 => 2,
242 Self::Asinh => 1,
243 Self::Acosh => 1,
244 Self::Atanh => 1,
245 Self::Radians => 1,
246 Self::Degrees => 1,
247 Self::Ceil => 1,
249 Self::Floor => 1,
250 Self::Round => 1,
251 Self::Fract => 1,
252 Self::Trunc => 1,
253 Self::Modf => 1,
254 Self::Frexp => 1,
255 Self::Ldexp => 2,
256 Self::Exp => 1,
258 Self::Exp2 => 1,
259 Self::Log => 1,
260 Self::Log2 => 1,
261 Self::Pow => 2,
262 Self::Dot => 2,
264 Self::Dot4I8Packed => 2,
265 Self::Dot4U8Packed => 2,
266 Self::Outer => 2,
267 Self::Cross => 2,
268 Self::Distance => 2,
269 Self::Length => 1,
270 Self::Normalize => 1,
271 Self::FaceForward => 3,
272 Self::Reflect => 2,
273 Self::Refract => 3,
274 Self::Sign => 1,
276 Self::Fma => 3,
277 Self::Mix => 3,
278 Self::Step => 2,
279 Self::SmoothStep => 3,
280 Self::Sqrt => 1,
281 Self::InverseSqrt => 1,
282 Self::Inverse => 1,
283 Self::Transpose => 1,
284 Self::Determinant => 1,
285 Self::QuantizeToF16 => 1,
286 Self::CountTrailingZeros => 1,
288 Self::CountLeadingZeros => 1,
289 Self::CountOneBits => 1,
290 Self::ReverseBits => 1,
291 Self::ExtractBits => 3,
292 Self::InsertBits => 4,
293 Self::FirstTrailingBit => 1,
294 Self::FirstLeadingBit => 1,
295 Self::Pack4x8snorm => 1,
297 Self::Pack4x8unorm => 1,
298 Self::Pack2x16snorm => 1,
299 Self::Pack2x16unorm => 1,
300 Self::Pack2x16float => 1,
301 Self::Pack4xI8 => 1,
302 Self::Pack4xU8 => 1,
303 Self::Pack4xI8Clamp => 1,
304 Self::Pack4xU8Clamp => 1,
305 Self::Unpack4x8snorm => 1,
307 Self::Unpack4x8unorm => 1,
308 Self::Unpack2x16snorm => 1,
309 Self::Unpack2x16unorm => 1,
310 Self::Unpack2x16float => 1,
311 Self::Unpack4xI8 => 1,
312 Self::Unpack4xU8 => 1,
313 }
314 }
315}
316
317impl crate::Expression {
318 pub const fn needs_pre_emit(&self) -> bool {
320 match *self {
321 Self::Literal(_)
322 | Self::Constant(_)
323 | Self::Override(_)
324 | Self::ZeroValue(_)
325 | Self::FunctionArgument(_)
326 | Self::GlobalVariable(_)
327 | Self::LocalVariable(_) => true,
328 _ => false,
329 }
330 }
331
332 pub const fn is_dynamic_index(&self) -> bool {
346 match *self {
347 Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
348 _ => true,
349 }
350 }
351}
352
353impl crate::Function {
354 pub fn originating_global(
363 &self,
364 mut pointer: crate::Handle<crate::Expression>,
365 ) -> Option<crate::Handle<crate::GlobalVariable>> {
366 loop {
367 pointer = match self.expressions[pointer] {
368 crate::Expression::Access { base, .. } => base,
369 crate::Expression::AccessIndex { base, .. } => base,
370 crate::Expression::GlobalVariable(handle) => return Some(handle),
371 crate::Expression::LocalVariable(_) => return None,
372 crate::Expression::FunctionArgument(_) => return None,
373 _ => unreachable!(),
375 }
376 }
377 }
378}
379
380impl crate::SampleLevel {
381 pub const fn implicit_derivatives(&self) -> bool {
382 match *self {
383 Self::Auto | Self::Bias(_) => true,
384 Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
385 }
386 }
387}
388
389impl crate::Binding {
390 pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
391 match *self {
392 crate::Binding::BuiltIn(built_in) => Some(built_in),
393 Self::Location { .. } => None,
394 }
395 }
396}
397
398impl super::SwizzleComponent {
399 pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
400
401 pub const fn index(&self) -> u32 {
402 match *self {
403 Self::X => 0,
404 Self::Y => 1,
405 Self::Z => 2,
406 Self::W => 3,
407 }
408 }
409 pub const fn from_index(idx: u32) -> Self {
410 match idx {
411 0 => Self::X,
412 1 => Self::Y,
413 2 => Self::Z,
414 _ => Self::W,
415 }
416 }
417}
418
419impl super::ImageClass {
420 pub const fn is_multisampled(self) -> bool {
421 match self {
422 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
423 crate::ImageClass::Storage { .. } => false,
424 crate::ImageClass::External => false,
425 }
426 }
427
428 pub const fn is_mipmapped(self) -> bool {
429 match self {
430 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
431 crate::ImageClass::Storage { .. } => false,
432 crate::ImageClass::External => false,
433 }
434 }
435
436 pub const fn is_depth(self) -> bool {
437 matches!(self, crate::ImageClass::Depth { .. })
438 }
439}
440
441impl crate::Module {
442 pub const fn to_ctx(&self) -> GlobalCtx<'_> {
443 GlobalCtx {
444 types: &self.types,
445 constants: &self.constants,
446 overrides: &self.overrides,
447 global_expressions: &self.global_expressions,
448 }
449 }
450
451 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
452 compare_types(lhs, rhs, &self.types)
453 }
454}
455
/// Reasons why an expression could not be read as a constant of the
/// requested type.
#[derive(Debug)]
pub enum ConstValueError {
    /// The expression did not reduce to a literal.
    NonConst,
    /// A signed literal could not be converted to the requested unsigned type.
    Negative,
    /// The literal has a type that cannot be converted to the requested type.
    InvalidType,
}

/// Lets conversions whose error type is `Infallible` satisfy an
/// `Into<ConstValueError>` bound.
impl From<core::convert::Infallible> for ConstValueError {
    fn from(never: core::convert::Infallible) -> Self {
        // `Infallible` has no values, so this function can never be called.
        match never {}
    }
}
468
469#[derive(Clone, Copy)]
470pub struct GlobalCtx<'a> {
471 pub types: &'a crate::UniqueArena<crate::Type>,
472 pub constants: &'a crate::Arena<crate::Constant>,
473 pub overrides: &'a crate::Arena<crate::Override>,
474 pub global_expressions: &'a crate::Arena<crate::Expression>,
475}
476
477impl GlobalCtx<'_> {
478 #[cfg_attr(
485 not(any(
486 feature = "glsl-in",
487 feature = "spv-in",
488 feature = "wgsl-in",
489 glsl_out,
490 hlsl_out,
491 msl_out,
492 wgsl_out
493 )),
494 allow(dead_code)
495 )]
496 pub(super) fn get_const_val<T, E>(
497 &self,
498 handle: crate::Handle<crate::Expression>,
499 ) -> Result<T, ConstValueError>
500 where
501 T: TryFrom<crate::Literal, Error = E>,
502 E: Into<ConstValueError>,
503 {
504 self.get_const_val_from(handle, self.global_expressions)
505 }
506
507 pub(super) fn get_const_val_from<T, E>(
508 &self,
509 handle: crate::Handle<crate::Expression>,
510 arena: &crate::Arena<crate::Expression>,
511 ) -> Result<T, ConstValueError>
512 where
513 T: TryFrom<crate::Literal, Error = E>,
514 E: Into<ConstValueError>,
515 {
516 fn get(
517 gctx: GlobalCtx,
518 handle: crate::Handle<crate::Expression>,
519 arena: &crate::Arena<crate::Expression>,
520 ) -> Option<crate::Literal> {
521 match arena[handle] {
522 crate::Expression::Literal(literal) => Some(literal),
523 crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
524 crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
525 _ => None,
526 },
527 _ => None,
528 }
529 }
530 let value = match arena[handle] {
531 crate::Expression::Constant(c) => {
532 get(*self, self.constants[c].init, self.global_expressions)
533 }
534 _ => get(*self, handle, arena),
535 };
536 match value {
537 Some(v) => v.try_into().map_err(Into::into),
538 None => Err(ConstValueError::NonConst),
539 }
540 }
541
542 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
543 compare_types(lhs, rhs, self.types)
544 }
545}
546
547#[derive(Error, Debug, Clone, Copy, PartialEq)]
548pub enum ResolveArraySizeError {
549 #[error("array element count must be positive (> 0)")]
550 ExpectedPositiveArrayLength,
551 #[error("internal: array size override has not been resolved")]
552 NonConstArrayLength,
553}
554
555impl crate::ArraySize {
556 pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
566 match *self {
567 crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
568 crate::ArraySize::Pending(handle) => {
569 let Some(expr) = gctx.overrides[handle].init else {
570 return Err(ResolveArraySizeError::NonConstArrayLength);
571 };
572 let length = gctx.get_const_val(expr).map_err(|err| match err {
573 ConstValueError::NonConst => ResolveArraySizeError::NonConstArrayLength,
574 ConstValueError::Negative | ConstValueError::InvalidType => {
575 ResolveArraySizeError::ExpectedPositiveArrayLength
576 }
577 })?;
578
579 if length == 0 {
580 return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
581 }
582
583 Ok(IndexableLength::Known(length))
584 }
585 crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
586 }
587 }
588}
589
590pub fn flatten_compose<'arenas>(
603 ty: crate::Handle<crate::Type>,
604 components: &'arenas [crate::Handle<crate::Expression>],
605 expressions: &'arenas crate::Arena<crate::Expression>,
606 types: &'arenas crate::UniqueArena<crate::Type>,
607) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
608 let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
614 (size as usize, true)
615 } else {
616 (components.len(), false)
617 };
618
619 fn flatten_compose<'c>(
621 component: &'c crate::Handle<crate::Expression>,
622 is_vector: bool,
623 expressions: &'c crate::Arena<crate::Expression>,
624 ) -> &'c [crate::Handle<crate::Expression>] {
625 if is_vector {
626 if let crate::Expression::Compose {
627 ty: _,
628 components: ref subcomponents,
629 } = expressions[*component]
630 {
631 return subcomponents;
632 }
633 }
634 core::slice::from_ref(component)
635 }
636
637 fn flatten_splat<'c>(
639 component: &'c crate::Handle<crate::Expression>,
640 is_vector: bool,
641 expressions: &'c crate::Arena<crate::Expression>,
642 ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
643 let mut expr = *component;
644 let mut count = 1;
645 if is_vector {
646 if let crate::Expression::Splat { size, value } = expressions[expr] {
647 expr = value;
648 count = size as usize;
649 }
650 }
651 core::iter::repeat_n(expr, count)
652 }
653
654 components
661 .iter()
662 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
663 .flat_map(move |component| flatten_compose(component, is_vector, expressions))
664 .flat_map(move |component| flatten_splat(component, is_vector, expressions))
665 .take(size)
666}
667
668impl super::ShaderStage {
669 pub const fn compute_like(self) -> bool {
670 match self {
671 Self::Vertex | Self::Fragment => false,
672 Self::Compute | Self::Task | Self::Mesh => true,
673 Self::RayGeneration | Self::AnyHit | Self::ClosestHit | Self::Miss => false,
674 }
675 }
676
677 pub const fn mesh_like(self) -> bool {
679 match self {
680 Self::Task | Self::Mesh => true,
681 _ => false,
682 }
683 }
684}
685
/// A 3x3 `f32` matrix must report a size of 48 bytes.
#[test]
fn test_matrix_size() {
    let module = crate::Module::default();
    let mat3x3 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    let size = mat3x3.size(module.to_ctx());
    assert_eq!(size, 48);
}
699
700impl crate::Module {
701 #[allow(clippy::type_complexity)]
711 pub fn analyze_mesh_shader_info(
712 &self,
713 gv: crate::Handle<crate::GlobalVariable>,
714 ) -> (
715 crate::MeshStageInfo,
716 [Option<crate::Handle<crate::Override>>; 2],
717 Option<crate::WithSpan<crate::valid::EntryPointError>>,
718 ) {
719 use crate::span::AddSpan;
720 use crate::valid::EntryPointError;
721 #[derive(Default)]
722 struct OutError {
723 pub inner: Option<EntryPointError>,
724 }
725 impl OutError {
726 pub fn set(&mut self, err: EntryPointError) {
727 if self.inner.is_none() {
728 self.inner = Some(err);
729 }
730 }
731 }
732
733 let null_type = crate::Handle::new(NonMaxU32::new(0).unwrap());
735 let mut output = crate::MeshStageInfo {
736 topology: crate::MeshOutputTopology::Triangles,
737 max_vertices: 0,
738 max_vertices_override: None,
739 max_primitives: 0,
740 max_primitives_override: None,
741 vertex_output_type: null_type,
742 primitive_output_type: null_type,
743 output_variable: gv,
744 };
745 let mut error = OutError::default();
747 let r#type = &self.types[self.global_variables[gv].ty].inner;
748
749 let mut topology = output.topology;
750 let mut vertex_info = (0, None, null_type);
752 let mut primitive_info = (0, None, null_type);
753
754 match r#type {
755 &crate::TypeInner::Struct { ref members, .. } => {
756 let mut builtins = crate::FastHashSet::default();
757 for member in members {
758 match member.binding {
759 Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexCount)) => {
760 if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
762 error.set(EntryPointError::BadMeshOutputVariableField);
763 }
764 if builtins.contains(&crate::BuiltIn::VertexCount) {
766 error.set(EntryPointError::BadMeshOutputVariableType);
767 }
768 builtins.insert(crate::BuiltIn::VertexCount);
769 }
770 Some(crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveCount)) => {
771 if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
773 error.set(EntryPointError::BadMeshOutputVariableField);
774 }
775 if builtins.contains(&crate::BuiltIn::PrimitiveCount) {
777 error.set(EntryPointError::BadMeshOutputVariableType);
778 }
779 builtins.insert(crate::BuiltIn::PrimitiveCount);
780 }
781 Some(crate::Binding::BuiltIn(
782 crate::BuiltIn::Vertices | crate::BuiltIn::Primitives,
783 )) => {
784 let ty = &self.types[member.ty].inner;
785 let (a, b, c) = match ty {
787 &crate::TypeInner::Array { base, size, .. } => {
788 let ty = base;
789 let (max, max_override) = match size {
790 crate::ArraySize::Constant(a) => (a.get(), None),
791 crate::ArraySize::Pending(o) => (0, Some(o)),
792 crate::ArraySize::Dynamic => {
793 error.set(EntryPointError::BadMeshOutputVariableField);
794 (0, None)
795 }
796 };
797 (max, max_override, ty)
798 }
799 _ => {
800 error.set(EntryPointError::BadMeshOutputVariableField);
801 (0, None, null_type)
802 }
803 };
804 if matches!(
805 member.binding,
806 Some(crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
807 ) {
808 primitive_info = (a, b, c);
810 match self.types[c].inner {
811 crate::TypeInner::Struct { ref members, .. } => {
812 for member in members {
813 match member.binding {
814 Some(crate::Binding::BuiltIn(
815 crate::BuiltIn::PointIndex,
816 )) => {
817 topology = crate::MeshOutputTopology::Points;
818 }
819 Some(crate::Binding::BuiltIn(
820 crate::BuiltIn::LineIndices,
821 )) => {
822 topology = crate::MeshOutputTopology::Lines;
823 }
824 Some(crate::Binding::BuiltIn(
825 crate::BuiltIn::TriangleIndices,
826 )) => {
827 topology = crate::MeshOutputTopology::Triangles;
828 }
829 _ => (),
830 }
831 }
832 }
833 _ => (),
834 }
835 if builtins.contains(&crate::BuiltIn::Primitives) {
837 error.set(EntryPointError::BadMeshOutputVariableType);
838 }
839 builtins.insert(crate::BuiltIn::Primitives);
840 } else {
841 vertex_info = (a, b, c);
842 if builtins.contains(&crate::BuiltIn::Vertices) {
844 error.set(EntryPointError::BadMeshOutputVariableType);
845 }
846 builtins.insert(crate::BuiltIn::Vertices);
847 }
848 }
849 _ => error.set(EntryPointError::BadMeshOutputVariableType),
850 }
851 }
852 output = crate::MeshStageInfo {
853 topology,
854 max_vertices: vertex_info.0,
855 max_vertices_override: None,
856 vertex_output_type: vertex_info.2,
857 max_primitives: primitive_info.0,
858 max_primitives_override: None,
859 primitive_output_type: primitive_info.2,
860 ..output
861 }
862 }
863 _ => error.set(EntryPointError::BadMeshOutputVariableType),
864 }
865 (
866 output,
867 [vertex_info.1, primitive_info.1],
868 error
869 .inner
870 .map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)),
871 )
872 }
873
874 pub fn uses_mesh_shaders(&self) -> bool {
875 let binding_uses_mesh = |b: &crate::Binding| {
876 matches!(
877 b,
878 crate::Binding::BuiltIn(
879 crate::BuiltIn::MeshTaskSize
880 | crate::BuiltIn::CullPrimitive
881 | crate::BuiltIn::PointIndex
882 | crate::BuiltIn::LineIndices
883 | crate::BuiltIn::TriangleIndices
884 | crate::BuiltIn::VertexCount
885 | crate::BuiltIn::Vertices
886 | crate::BuiltIn::PrimitiveCount
887 | crate::BuiltIn::Primitives,
888 ) | crate::Binding::Location {
889 per_primitive: true,
890 ..
891 }
892 )
893 };
894 for (_, ty) in self.types.iter() {
895 match ty.inner {
896 crate::TypeInner::Struct { ref members, .. } => {
897 for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
898 if binding_uses_mesh(binding) {
899 return true;
900 }
901 }
902 }
903 _ => (),
904 }
905 }
906 for ep in &self.entry_points {
907 if matches!(
908 ep.stage,
909 crate::ShaderStage::Mesh | crate::ShaderStage::Task
910 ) {
911 return true;
912 }
913 for binding in ep
914 .function
915 .arguments
916 .iter()
917 .filter_map(|arg| arg.binding.as_ref())
918 .chain(
919 ep.function
920 .result
921 .iter()
922 .filter_map(|res| res.binding.as_ref()),
923 )
924 {
925 if binding_uses_mesh(binding) {
926 return true;
927 }
928 }
929 }
930 if self
931 .global_variables
932 .iter()
933 .any(|gv| gv.1.space == crate::AddressSpace::TaskPayload)
934 {
935 return true;
936 }
937 false
938 }
939}
940
941impl crate::MeshOutputTopology {
942 pub const fn to_builtin(self) -> crate::BuiltIn {
943 match self {
944 Self::Points => crate::BuiltIn::PointIndex,
945 Self::Lines => crate::BuiltIn::LineIndices,
946 Self::Triangles => crate::BuiltIn::TriangleIndices,
947 }
948 }
949}
950
951impl crate::AddressSpace {
952 pub const fn is_workgroup_like(self) -> bool {
953 matches!(self, Self::WorkGroup | Self::TaskPayload)
954 }
955}