1mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod keyword_set;
9mod layouter;
10mod namer;
11mod overloads;
12mod terminator;
13mod type_methods;
14mod typifier;
15
16pub use constant_evaluator::{
17 ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
18};
19pub use emitter::Emitter;
20pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
21pub use keyword_set::{CaseInsensitiveKeywordSet, KeywordSet};
22pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
23pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
24pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
25pub use terminator::ensure_block_returns;
26use thiserror::Error;
27pub use type_methods::{
28 concrete_int_scalars, min_max_float_representable_by, vector_size_str, vector_sizes,
29};
30pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
31
32use crate::non_max_u32::NonMaxU32;
33
34impl From<super::StorageFormat> for super::Scalar {
35 fn from(format: super::StorageFormat) -> Self {
36 use super::{ScalarKind as Sk, StorageFormat as Sf};
37 let kind = match format {
38 Sf::R8Unorm => Sk::Float,
39 Sf::R8Snorm => Sk::Float,
40 Sf::R8Uint => Sk::Uint,
41 Sf::R8Sint => Sk::Sint,
42 Sf::R16Uint => Sk::Uint,
43 Sf::R16Sint => Sk::Sint,
44 Sf::R16Float => Sk::Float,
45 Sf::Rg8Unorm => Sk::Float,
46 Sf::Rg8Snorm => Sk::Float,
47 Sf::Rg8Uint => Sk::Uint,
48 Sf::Rg8Sint => Sk::Sint,
49 Sf::R32Uint => Sk::Uint,
50 Sf::R32Sint => Sk::Sint,
51 Sf::R32Float => Sk::Float,
52 Sf::Rg16Uint => Sk::Uint,
53 Sf::Rg16Sint => Sk::Sint,
54 Sf::Rg16Float => Sk::Float,
55 Sf::Rgba8Unorm => Sk::Float,
56 Sf::Rgba8Snorm => Sk::Float,
57 Sf::Rgba8Uint => Sk::Uint,
58 Sf::Rgba8Sint => Sk::Sint,
59 Sf::Bgra8Unorm => Sk::Float,
60 Sf::Rgb10a2Uint => Sk::Uint,
61 Sf::Rgb10a2Unorm => Sk::Float,
62 Sf::Rg11b10Ufloat => Sk::Float,
63 Sf::R64Uint => Sk::Uint,
64 Sf::Rg32Uint => Sk::Uint,
65 Sf::Rg32Sint => Sk::Sint,
66 Sf::Rg32Float => Sk::Float,
67 Sf::Rgba16Uint => Sk::Uint,
68 Sf::Rgba16Sint => Sk::Sint,
69 Sf::Rgba16Float => Sk::Float,
70 Sf::Rgba32Uint => Sk::Uint,
71 Sf::Rgba32Sint => Sk::Sint,
72 Sf::Rgba32Float => Sk::Float,
73 Sf::R16Unorm => Sk::Float,
74 Sf::R16Snorm => Sk::Float,
75 Sf::Rg16Unorm => Sk::Float,
76 Sf::Rg16Snorm => Sk::Float,
77 Sf::Rgba16Unorm => Sk::Float,
78 Sf::Rgba16Snorm => Sk::Float,
79 };
80 let width = match format {
81 Sf::R64Uint => 8,
82 _ => 4,
83 };
84 super::Scalar { kind, width }
85 }
86}
87
/// A hashable, equality-comparable mirror of `crate::Literal`.
///
/// Floating-point payloads (`F64`, `F32`, `F16`, `AbstractFloat`) are stored as
/// their raw bit patterns (the `From<crate::Literal>` impl calls `to_bits`), so
/// the derived `Eq`/`Hash` compare bit patterns rather than numeric values:
/// `0.0` and `-0.0` are distinct, and NaNs with identical bits are equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HashableLiteral {
    /// Bits of an `f64` literal.
    F64(u64),
    /// Bits of an `f32` literal.
    F32(u32),
    /// Bits of an `f16` literal.
    F16(u16),
    U16(u16),
    I16(i16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
    Bool(bool),
    AbstractInt(i64),
    /// Bits of an abstract-float literal.
    AbstractFloat(u64),
}
103
104impl From<crate::Literal> for HashableLiteral {
105 fn from(l: crate::Literal) -> Self {
106 match l {
107 crate::Literal::F64(v) => Self::F64(v.to_bits()),
108 crate::Literal::F32(v) => Self::F32(v.to_bits()),
109 crate::Literal::F16(v) => Self::F16(v.to_bits()),
110 crate::Literal::U16(v) => Self::U16(v),
111 crate::Literal::I16(v) => Self::I16(v),
112 crate::Literal::U32(v) => Self::U32(v),
113 crate::Literal::I32(v) => Self::I32(v),
114 crate::Literal::U64(v) => Self::U64(v),
115 crate::Literal::I64(v) => Self::I64(v),
116 crate::Literal::Bool(v) => Self::Bool(v),
117 crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
118 crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
119 }
120 }
121}
122
123impl crate::Literal {
124 pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
125 match (value, scalar.kind, scalar.width) {
126 (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
127 (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
128 (value, crate::ScalarKind::Float, 2) => {
129 Some(Self::F16(half::f16::from_f32_const(value as _)))
130 }
131 (value, crate::ScalarKind::Uint, 2) => Some(Self::U16(value as _)),
132 (value, crate::ScalarKind::Sint, 2) => Some(Self::I16(value as _)),
133 (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
134 (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
135 (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
136 (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
137 (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
138 (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
139 (value, crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(value as _)),
140 (value, crate::ScalarKind::AbstractFloat, 8) => Some(Self::AbstractFloat(value as _)),
141 _ => None,
142 }
143 }
144
145 pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
146 Self::new(0, scalar)
147 }
148
149 pub const fn one(scalar: crate::Scalar) -> Option<Self> {
150 Self::new(1, scalar)
151 }
152
153 pub const fn minus_one(scalar: crate::Scalar) -> Option<Self> {
154 match (scalar.kind, scalar.width) {
155 (crate::ScalarKind::Float, 8) => Some(Self::F64(-1.0)),
156 (crate::ScalarKind::Float, 4) => Some(Self::F32(-1.0)),
157 (crate::ScalarKind::Float, 2) => Some(Self::F16(half::f16::from_f32_const(-1.0))),
158 (crate::ScalarKind::Sint, 8) => Some(Self::I64(-1)),
159 (crate::ScalarKind::Sint, 4) => Some(Self::I32(-1)),
160 (crate::ScalarKind::Sint, 2) => Some(Self::I16(-1)),
161 (crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(-1)),
162 _ => None,
163 }
164 }
165
166 pub const fn width(&self) -> crate::Bytes {
167 match *self {
168 Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
169 Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
170 Self::F16(_) | Self::U16(_) | Self::I16(_) => 2,
171 Self::Bool(_) => crate::BOOL_WIDTH,
172 Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
173 }
174 }
175 pub const fn scalar(&self) -> crate::Scalar {
176 match *self {
177 Self::F64(_) => crate::Scalar::F64,
178 Self::F32(_) => crate::Scalar::F32,
179 Self::F16(_) => crate::Scalar::F16,
180 Self::U16(_) => crate::Scalar::U16,
181 Self::I16(_) => crate::Scalar::I16,
182 Self::U32(_) => crate::Scalar::U32,
183 Self::I32(_) => crate::Scalar::I32,
184 Self::U64(_) => crate::Scalar::U64,
185 Self::I64(_) => crate::Scalar::I64,
186 Self::Bool(_) => crate::Scalar::BOOL,
187 Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
188 Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
189 }
190 }
191 pub const fn scalar_kind(&self) -> crate::ScalarKind {
192 self.scalar().kind
193 }
194 pub const fn ty_inner(&self) -> crate::TypeInner {
195 crate::TypeInner::Scalar(self.scalar())
196 }
197}
198
199impl TryFrom<crate::Literal> for u32 {
200 type Error = ConstValueError;
201
202 fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
203 match value {
204 crate::Literal::U16(value) => Ok(value as u32),
205 crate::Literal::I16(value) => value.try_into().map_err(|_| ConstValueError::Negative),
206 crate::Literal::U32(value) => Ok(value),
207 crate::Literal::I32(value) => value.try_into().map_err(|_| ConstValueError::Negative),
208 _ => Err(ConstValueError::InvalidType),
209 }
210 }
211}
212
213impl TryFrom<crate::Literal> for bool {
214 type Error = ConstValueError;
215
216 fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
217 match value {
218 crate::Literal::Bool(value) => Ok(value),
219 _ => Err(ConstValueError::InvalidType),
220 }
221 }
222}
223
224impl super::AddressSpace {
225 pub fn access(self) -> crate::StorageAccess {
226 use crate::StorageAccess as Sa;
227 match self {
228 crate::AddressSpace::Function
229 | crate::AddressSpace::Private
230 | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
231 crate::AddressSpace::Uniform => Sa::LOAD,
232 crate::AddressSpace::Storage { access } => access,
233 crate::AddressSpace::Handle => Sa::LOAD,
234 crate::AddressSpace::Immediate => Sa::LOAD,
235 crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE,
238 crate::AddressSpace::RayPayload | crate::AddressSpace::IncomingRayPayload => {
239 Sa::LOAD | Sa::STORE
240 }
241 }
242 }
243}
244
245impl super::MathFunction {
246 pub const fn argument_count(&self) -> usize {
247 match *self {
248 Self::Abs => 1,
250 Self::Min => 2,
251 Self::Max => 2,
252 Self::Clamp => 3,
253 Self::Saturate => 1,
254 Self::Cos => 1,
256 Self::Cosh => 1,
257 Self::Sin => 1,
258 Self::Sinh => 1,
259 Self::Tan => 1,
260 Self::Tanh => 1,
261 Self::Acos => 1,
262 Self::Asin => 1,
263 Self::Atan => 1,
264 Self::Atan2 => 2,
265 Self::Asinh => 1,
266 Self::Acosh => 1,
267 Self::Atanh => 1,
268 Self::Radians => 1,
269 Self::Degrees => 1,
270 Self::Ceil => 1,
272 Self::Floor => 1,
273 Self::Round => 1,
274 Self::Fract => 1,
275 Self::Trunc => 1,
276 Self::Modf => 1,
277 Self::Frexp => 1,
278 Self::Ldexp => 2,
279 Self::Exp => 1,
281 Self::Exp2 => 1,
282 Self::Log => 1,
283 Self::Log2 => 1,
284 Self::Pow => 2,
285 Self::Dot => 2,
287 Self::Dot4I8Packed => 2,
288 Self::Dot4U8Packed => 2,
289 Self::Outer => 2,
290 Self::Cross => 2,
291 Self::Distance => 2,
292 Self::Length => 1,
293 Self::Normalize => 1,
294 Self::FaceForward => 3,
295 Self::Reflect => 2,
296 Self::Refract => 3,
297 Self::Sign => 1,
299 Self::Fma => 3,
300 Self::Mix => 3,
301 Self::Step => 2,
302 Self::SmoothStep => 3,
303 Self::Sqrt => 1,
304 Self::InverseSqrt => 1,
305 Self::Inverse => 1,
306 Self::Transpose => 1,
307 Self::Determinant => 1,
308 Self::QuantizeToF16 => 1,
309 Self::CountTrailingZeros => 1,
311 Self::CountLeadingZeros => 1,
312 Self::CountOneBits => 1,
313 Self::ReverseBits => 1,
314 Self::ExtractBits => 3,
315 Self::InsertBits => 4,
316 Self::FirstTrailingBit => 1,
317 Self::FirstLeadingBit => 1,
318 Self::Pack4x8snorm => 1,
320 Self::Pack4x8unorm => 1,
321 Self::Pack2x16snorm => 1,
322 Self::Pack2x16unorm => 1,
323 Self::Pack2x16float => 1,
324 Self::Pack4xI8 => 1,
325 Self::Pack4xU8 => 1,
326 Self::Pack4xI8Clamp => 1,
327 Self::Pack4xU8Clamp => 1,
328 Self::Unpack4x8snorm => 1,
330 Self::Unpack4x8unorm => 1,
331 Self::Unpack2x16snorm => 1,
332 Self::Unpack2x16unorm => 1,
333 Self::Unpack2x16float => 1,
334 Self::Unpack4xI8 => 1,
335 Self::Unpack4xU8 => 1,
336 }
337 }
338}
339
340impl crate::Expression {
341 pub const fn needs_pre_emit(&self) -> bool {
343 match *self {
344 Self::Literal(_)
345 | Self::Constant(_)
346 | Self::Override(_)
347 | Self::ZeroValue(_)
348 | Self::FunctionArgument(_)
349 | Self::GlobalVariable(_)
350 | Self::LocalVariable(_) => true,
351 _ => false,
352 }
353 }
354
355 pub const fn is_dynamic_index(&self) -> bool {
369 match *self {
370 Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
371 _ => true,
372 }
373 }
374}
375
376impl crate::Function {
377 pub fn originating_global(
386 &self,
387 mut pointer: crate::Handle<crate::Expression>,
388 ) -> Option<crate::Handle<crate::GlobalVariable>> {
389 loop {
390 pointer = match self.expressions[pointer] {
391 crate::Expression::Access { base, .. } => base,
392 crate::Expression::AccessIndex { base, .. } => base,
393 crate::Expression::GlobalVariable(handle) => return Some(handle),
394 _ => return None,
396 }
397 }
398 }
399}
400
401impl crate::SampleLevel {
402 pub const fn implicit_derivatives(&self) -> bool {
403 match *self {
404 Self::Auto | Self::Bias(_) => true,
405 Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
406 }
407 }
408}
409
410impl crate::Binding {
411 pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
412 match *self {
413 crate::Binding::BuiltIn(built_in) => Some(built_in),
414 Self::Location { .. } => None,
415 }
416 }
417}
418
419impl super::SwizzleComponent {
420 pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
421
422 pub const fn index(&self) -> u32 {
423 match *self {
424 Self::X => 0,
425 Self::Y => 1,
426 Self::Z => 2,
427 Self::W => 3,
428 }
429 }
430 pub const fn from_index(idx: u32) -> Self {
431 match idx {
432 0 => Self::X,
433 1 => Self::Y,
434 2 => Self::Z,
435 _ => Self::W,
436 }
437 }
438}
439
440impl super::ImageClass {
441 pub const fn is_multisampled(self) -> bool {
442 match self {
443 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
444 crate::ImageClass::Storage { .. } => false,
445 crate::ImageClass::External => false,
446 }
447 }
448
449 pub const fn is_mipmapped(self) -> bool {
450 match self {
451 crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
452 crate::ImageClass::Storage { .. } => false,
453 crate::ImageClass::External => false,
454 }
455 }
456
457 pub const fn is_depth(self) -> bool {
458 matches!(self, crate::ImageClass::Depth { .. })
459 }
460}
461
462impl crate::Module {
463 pub const fn to_ctx(&self) -> GlobalCtx<'_> {
464 GlobalCtx {
465 types: &self.types,
466 constants: &self.constants,
467 overrides: &self.overrides,
468 global_expressions: &self.global_expressions,
469 }
470 }
471
472 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
473 compare_types(lhs, rhs, &self.types)
474 }
475}
476
/// Reasons a constant value could not be extracted from an expression.
#[derive(Debug)]
pub enum ConstValueError {
    /// The expression is not a literal, a scalar zero value, or a constant
    /// whose initializer is one of those.
    NonConst,
    /// The literal is negative but a non-negative value was requested.
    Negative,
    /// The literal's type cannot be converted to the requested type.
    InvalidType,
}
483
484impl From<core::convert::Infallible> for ConstValueError {
485 fn from(_: core::convert::Infallible) -> Self {
486 unreachable!()
487 }
488}
489
/// A borrowed view of the module-level arenas used to resolve types and
/// constant expressions, without borrowing the whole `crate::Module`.
#[derive(Clone, Copy, Debug)]
pub struct GlobalCtx<'a> {
    /// The module's type arena.
    pub types: &'a crate::UniqueArena<crate::Type>,
    /// The module's constants.
    pub constants: &'a crate::Arena<crate::Constant>,
    /// The module's overrides.
    pub overrides: &'a crate::Arena<crate::Override>,
    /// Module-scope expressions, including the initializers of `constants`.
    pub global_expressions: &'a crate::Arena<crate::Expression>,
}
497
498impl GlobalCtx<'_> {
499 #[cfg_attr(
506 not(any(
507 feature = "glsl-in",
508 feature = "spv-in",
509 feature = "wgsl-in",
510 glsl_out,
511 hlsl_out,
512 msl_out,
513 wgsl_out
514 )),
515 allow(dead_code)
516 )]
517 pub(super) fn get_const_val<T, E>(
518 &self,
519 handle: crate::Handle<crate::Expression>,
520 ) -> Result<T, ConstValueError>
521 where
522 T: TryFrom<crate::Literal, Error = E>,
523 E: Into<ConstValueError>,
524 {
525 self.get_const_val_from(handle, self.global_expressions)
526 }
527
528 pub(super) fn get_const_val_from<T, E>(
529 &self,
530 handle: crate::Handle<crate::Expression>,
531 arena: &crate::Arena<crate::Expression>,
532 ) -> Result<T, ConstValueError>
533 where
534 T: TryFrom<crate::Literal, Error = E>,
535 E: Into<ConstValueError>,
536 {
537 fn get(
538 gctx: GlobalCtx,
539 handle: crate::Handle<crate::Expression>,
540 arena: &crate::Arena<crate::Expression>,
541 ) -> Option<crate::Literal> {
542 match arena[handle] {
543 crate::Expression::Literal(literal) => Some(literal),
544 crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
545 crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
546 _ => None,
547 },
548 _ => None,
549 }
550 }
551 let value = match arena[handle] {
552 crate::Expression::Constant(c) => {
553 get(*self, self.constants[c].init, self.global_expressions)
554 }
555 _ => get(*self, handle, arena),
556 };
557 match value {
558 Some(v) => v.try_into().map_err(Into::into),
559 None => Err(ConstValueError::NonConst),
560 }
561 }
562
563 pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
564 compare_types(lhs, rhs, self.types)
565 }
566}
567
/// Errors produced when resolving a [`crate::ArraySize`] to a concrete length.
#[derive(Error, Debug, Clone, Copy, PartialEq)]
pub enum ResolveArraySizeError {
    /// The length evaluated to zero, to a negative value, or to a type that
    /// cannot be converted to `u32`.
    #[error("array element count must be positive (> 0)")]
    ExpectedPositiveArrayLength,
    /// The override supplying the length has no initializer, or its value is
    /// not yet a constant literal.
    #[error("internal: array size override has not been resolved")]
    NonConstArrayLength,
}
575
576impl crate::ArraySize {
577 pub fn resolve(&self, gctx: GlobalCtx) -> Result<IndexableLength, ResolveArraySizeError> {
587 match *self {
588 crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())),
589 crate::ArraySize::Pending(handle) => {
590 let Some(expr) = gctx.overrides[handle].init else {
591 return Err(ResolveArraySizeError::NonConstArrayLength);
592 };
593 let length = gctx.get_const_val(expr).map_err(|err| match err {
594 ConstValueError::NonConst => ResolveArraySizeError::NonConstArrayLength,
595 ConstValueError::Negative | ConstValueError::InvalidType => {
596 ResolveArraySizeError::ExpectedPositiveArrayLength
597 }
598 })?;
599
600 if length == 0 {
601 return Err(ResolveArraySizeError::ExpectedPositiveArrayLength);
602 }
603
604 Ok(IndexableLength::Known(length))
605 }
606 crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic),
607 }
608 }
609}
610
/// Returns an iterator over the individual values assembled by a `Compose`
/// expression with result type `ty` and operands `components`.
///
/// When `ty` is not a vector, this simply yields `components` unchanged.
/// When `ty` is a vector, operands that are themselves vectors are expanded in
/// place: nested `Compose` expressions are flattened (up to two levels deep),
/// and `Splat` expressions are expanded into repeated copies of their scalar
/// operand. The result is truncated to the vector's size.
pub fn flatten_compose<'arenas>(
    ty: crate::Handle<crate::Type>,
    components: &'arenas [crate::Handle<crate::Expression>],
    expressions: &'arenas crate::Arena<crate::Expression>,
    types: &'arenas crate::UniqueArena<crate::Type>,
) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
    // One iterator chain serves both cases, because returning `impl Iterator`
    // requires a single concrete type; `is_vector` switches the flattening
    // steps on or off.
    let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
        (size as usize, true)
    } else {
        (components.len(), false)
    };

    /// If `is_vector` and `component` is a `Compose`, returns its operands;
    /// otherwise returns `component` alone.
    fn flatten_compose<'c>(
        component: &'c crate::Handle<crate::Expression>,
        is_vector: bool,
        expressions: &'c crate::Arena<crate::Expression>,
    ) -> &'c [crate::Handle<crate::Expression>] {
        if is_vector {
            if let crate::Expression::Compose {
                ty: _,
                components: ref subcomponents,
            } = expressions[*component]
            {
                return subcomponents;
            }
        }
        core::slice::from_ref(component)
    }

    /// If `is_vector` and `component` is a `Splat`, yields its operand `size`
    /// times; otherwise yields `component` once.
    fn flatten_splat<'c>(
        component: &'c crate::Handle<crate::Expression>,
        is_vector: bool,
        expressions: &'c crate::Arena<crate::Expression>,
    ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
        let mut expr = *component;
        let mut count = 1;
        if is_vector {
            if let crate::Expression::Splat { size, value } = expressions[expr] {
                expr = value;
                count = size as usize;
            }
        }
        core::iter::repeat_n(expr, count)
    }

    // Two `flatten_compose` passes handle operands like
    // `vec4(vec3(vec2(6, 7), 8), 9)`. A final `flatten_splat` pass handles
    // operands like `vec4(vec3(1.0), 1.0)`; splats are expanded only once, so
    // this assumes a `Splat` operand is itself a scalar.
    components
        .iter()
        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
        .flat_map(move |component| flatten_splat(component, is_vector, expressions))
        .take(size)
}
688
#[test]
fn test_matrix_size() {
    // mat3x3<f32>: three vec3<f32> columns, each padded to 16 bytes, for 48 in total.
    let mat3x3_f32 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    let module = crate::Module::default();
    let size = mat3x3_f32.size(module.to_ctx());
    assert_eq!(size, 48);
}
702
impl crate::Module {
    /// Extracts mesh-shader output information from the mesh output global
    /// variable `gv`.
    ///
    /// The global's type is expected to be a struct. Members bound to the
    /// `VertexCount` and `PrimitiveCount` built-ins must be `u32`. Members
    /// bound to `Vertices` and `Primitives` must be arrays: the element types
    /// become the vertex and primitive output types, and the lengths become
    /// the maximum counts. Each of these built-ins may appear at most once.
    /// Any member with another (or no) binding is an error.
    ///
    /// The primitive topology is inferred from the `PointIndex`, `LineIndices`
    /// or `TriangleIndices` member of the primitive element struct, defaulting
    /// to `Triangles`.
    ///
    /// Returns a tuple of:
    /// - the extracted `MeshStageInfo`, with both `*_override` fields left `None`;
    /// - the override handles for the vertex and primitive array lengths (in
    ///   that order), set when those lengths are pending overrides;
    /// - the first error encountered, if any, spanned to the global's type.
    // The return tuple's shape is documented above; a type alias would add little.
    #[allow(clippy::type_complexity)]
    pub fn analyze_mesh_shader_info(
        &self,
        gv: crate::Handle<crate::GlobalVariable>,
    ) -> (
        crate::MeshStageInfo,
        [Option<crate::Handle<crate::Override>>; 2],
        Option<crate::WithSpan<crate::valid::EntryPointError>>,
    ) {
        use crate::span::AddSpan;
        use crate::valid::EntryPointError;
        // Keeps only the first error reported; later `set` calls are ignored.
        #[derive(Default)]
        struct OutError {
            pub inner: Option<EntryPointError>,
        }
        impl OutError {
            pub fn set(&mut self, err: EntryPointError) {
                if self.inner.is_none() {
                    self.inner = Some(err);
                }
            }
        }

        // Placeholder type handle (index 0), used until a real output type is found.
        // NOTE(review): if the struct lacks a `Vertices`/`Primitives` member, this
        // placeholder is returned without an error being recorded; confirm that
        // callers validate this.
        let null_type = crate::Handle::new(NonMaxU32::new(0).unwrap());
        let mut output = crate::MeshStageInfo {
            topology: crate::MeshOutputTopology::Triangles,
            max_vertices: 0,
            max_vertices_override: None,
            max_primitives: 0,
            max_primitives_override: None,
            vertex_output_type: null_type,
            primitive_output_type: null_type,
            output_variable: gv,
        };
        let mut error = OutError::default();
        let r#type = &self.types[self.global_variables[gv].ty].inner;

        let mut topology = output.topology;
        // (max count, max count override, element type) for each output array.
        let mut vertex_info = (0, None, null_type);
        let mut primitive_info = (0, None, null_type);

        match r#type {
            &crate::TypeInner::Struct { ref members, .. } => {
                // Built-ins seen so far, used to detect duplicates.
                let mut builtins = crate::FastHashSet::default();
                for member in members {
                    match member.binding {
                        Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexCount)) => {
                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
                                error.set(EntryPointError::BadMeshOutputVariableField);
                            }
                            if builtins.contains(&crate::BuiltIn::VertexCount) {
                                error.set(EntryPointError::BadMeshOutputVariableType);
                            }
                            builtins.insert(crate::BuiltIn::VertexCount);
                        }
                        Some(crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveCount)) => {
                            if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) {
                                error.set(EntryPointError::BadMeshOutputVariableField);
                            }
                            if builtins.contains(&crate::BuiltIn::PrimitiveCount) {
                                error.set(EntryPointError::BadMeshOutputVariableType);
                            }
                            builtins.insert(crate::BuiltIn::PrimitiveCount);
                        }
                        Some(crate::Binding::BuiltIn(
                            crate::BuiltIn::Vertices | crate::BuiltIn::Primitives,
                        )) => {
                            let ty = &self.types[member.ty].inner;
                            // (max count, override, element type) of this output array.
                            let (a, b, c) = match ty {
                                &crate::TypeInner::Array { base, size, .. } => {
                                    let ty = base;
                                    let (max, max_override) = match size {
                                        crate::ArraySize::Constant(a) => (a.get(), None),
                                        crate::ArraySize::Pending(o) => (0, Some(o)),
                                        crate::ArraySize::Dynamic => {
                                            error.set(EntryPointError::BadMeshOutputVariableField);
                                            (0, None)
                                        }
                                    };
                                    (max, max_override, ty)
                                }
                                _ => {
                                    error.set(EntryPointError::BadMeshOutputVariableField);
                                    (0, None, null_type)
                                }
                            };
                            if matches!(
                                member.binding,
                                Some(crate::Binding::BuiltIn(crate::BuiltIn::Primitives))
                            ) {
                                primitive_info = (a, b, c);
                                // Infer the topology from the index built-in in the
                                // primitive struct; if several appear, the last one wins.
                                // (When `c` is the placeholder, an error is already set.)
                                match self.types[c].inner {
                                    crate::TypeInner::Struct { ref members, .. } => {
                                        for member in members {
                                            match member.binding {
                                                Some(crate::Binding::BuiltIn(
                                                    crate::BuiltIn::PointIndex,
                                                )) => {
                                                    topology = crate::MeshOutputTopology::Points;
                                                }
                                                Some(crate::Binding::BuiltIn(
                                                    crate::BuiltIn::LineIndices,
                                                )) => {
                                                    topology = crate::MeshOutputTopology::Lines;
                                                }
                                                Some(crate::Binding::BuiltIn(
                                                    crate::BuiltIn::TriangleIndices,
                                                )) => {
                                                    topology = crate::MeshOutputTopology::Triangles;
                                                }
                                                _ => (),
                                            }
                                        }
                                    }
                                    _ => (),
                                }
                                if builtins.contains(&crate::BuiltIn::Primitives) {
                                    error.set(EntryPointError::BadMeshOutputVariableType);
                                }
                                builtins.insert(crate::BuiltIn::Primitives);
                            } else {
                                vertex_info = (a, b, c);
                                if builtins.contains(&crate::BuiltIn::Vertices) {
                                    error.set(EntryPointError::BadMeshOutputVariableType);
                                }
                                builtins.insert(crate::BuiltIn::Vertices);
                            }
                        }
                        _ => error.set(EntryPointError::BadMeshOutputVariableType),
                    }
                }
                // Override handles are returned separately, not stored in the info.
                output = crate::MeshStageInfo {
                    topology,
                    max_vertices: vertex_info.0,
                    max_vertices_override: None,
                    vertex_output_type: vertex_info.2,
                    max_primitives: primitive_info.0,
                    max_primitives_override: None,
                    primitive_output_type: primitive_info.2,
                    ..output
                }
            }
            _ => error.set(EntryPointError::BadMeshOutputVariableType),
        }
        (
            output,
            [vertex_info.1, primitive_info.1],
            error
                .inner
                .map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)),
        )
    }

    /// Returns `true` if the module uses any mesh-shading feature. That is the
    /// case when any of the following holds:
    /// - a struct member, entry-point argument or entry-point result carries a
    ///   mesh-related built-in binding or a `per_primitive` location;
    /// - there is a mesh or task entry point;
    /// - a global variable lives in the `TaskPayload` address space.
    pub fn uses_mesh_shaders(&self) -> bool {
        let binding_uses_mesh = |b: &crate::Binding| {
            matches!(
                b,
                crate::Binding::BuiltIn(
                    crate::BuiltIn::MeshTaskSize
                        | crate::BuiltIn::CullPrimitive
                        | crate::BuiltIn::PointIndex
                        | crate::BuiltIn::LineIndices
                        | crate::BuiltIn::TriangleIndices
                        | crate::BuiltIn::VertexCount
                        | crate::BuiltIn::Vertices
                        | crate::BuiltIn::PrimitiveCount
                        | crate::BuiltIn::Primitives,
                ) | crate::Binding::Location {
                    per_primitive: true,
                    ..
                }
            )
        };
        for (_, ty) in self.types.iter() {
            match ty.inner {
                crate::TypeInner::Struct { ref members, .. } => {
                    for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
                        if binding_uses_mesh(binding) {
                            return true;
                        }
                    }
                }
                _ => (),
            }
        }
        for ep in &self.entry_points {
            if matches!(
                ep.stage,
                crate::ShaderStage::Mesh | crate::ShaderStage::Task
            ) {
                return true;
            }
            for binding in ep
                .function
                .arguments
                .iter()
                .filter_map(|arg| arg.binding.as_ref())
                .chain(
                    ep.function
                        .result
                        .iter()
                        .filter_map(|res| res.binding.as_ref()),
                )
            {
                if binding_uses_mesh(binding) {
                    return true;
                }
            }
        }
        if self
            .global_variables
            .iter()
            .any(|gv| gv.1.space == crate::AddressSpace::TaskPayload)
        {
            return true;
        }
        false
    }

    /// Reports which ray-tracing features the module relies on.
    ///
    /// When `ep_index` is `Some`, only that entry point is considered during
    /// the entry-point scan; module-wide types are always considered.
    /// - `pipelines` is set if a considered entry point is a ray-generation,
    ///   any-hit, closest-hit or miss stage.
    /// - `queries` is set if the module has a ray-intersection special type or
    ///   a `RayQuery` type. It is also set if a considered entry point of any
    ///   other stage exists while the module has a ray-desc special type or an
    ///   acceleration structure type.
    pub fn uses_ray_tracing(&self, ep_index: Option<usize>) -> RayTracingUses {
        let mut uses = RayTracingUses::default();
        // Whether module-wide ray-tracing resources exist (ray desc / acceleration structures).
        let mut uses_ray_tracing = self.special_types.ray_desc.is_some();

        uses.queries |= self.special_types.ray_intersection.is_some();

        for (_, &crate::Type { ref inner, .. }) in self.types.iter() {
            match *inner {
                crate::TypeInner::AccelerationStructure { .. } => {
                    uses_ray_tracing = true;
                }
                crate::TypeInner::RayQuery { .. } => uses.queries = true,
                _ => {}
            }
        }

        for (index, ep) in self.entry_points.iter().enumerate() {
            if ep_index.is_some() && ep_index != Some(index) {
                continue;
            }

            if matches!(
                ep.stage,
                crate::ShaderStage::RayGeneration
                    | crate::ShaderStage::AnyHit
                    | crate::ShaderStage::ClosestHit
                    | crate::ShaderStage::Miss
            ) {
                uses.pipelines = true;
            } else {
                uses.queries |= uses_ray_tracing;
            }
        }

        uses
    }
}
986
/// Ray-tracing features used by a module, as reported by
/// `crate::Module::uses_ray_tracing`.
#[derive(Copy, Clone, Debug, Default)]
pub struct RayTracingUses {
    /// A ray-generation, any-hit, closest-hit or miss entry point is present.
    pub pipelines: bool,
    /// Ray queries are used, or ray-tracing resources are reachable from a
    /// non-ray-pipeline entry point.
    pub queries: bool,
}
992
993impl crate::MeshOutputTopology {
994 pub const fn to_builtin(self) -> crate::BuiltIn {
995 match self {
996 Self::Points => crate::BuiltIn::PointIndex,
997 Self::Lines => crate::BuiltIn::LineIndices,
998 Self::Triangles => crate::BuiltIn::TriangleIndices,
999 }
1000 }
1001}
1002
1003impl crate::AddressSpace {
1004 pub const fn is_workgroup_like(self) -> bool {
1005 matches!(self, Self::WorkGroup | Self::TaskPayload)
1006 }
1007}
1008
1009impl TryFrom<crate::ScalarKind> for nt::glsl::GlslScalarKind {
1010 type Error = ();
1011
1012 fn try_from(value: crate::ScalarKind) -> Result<Self, Self::Error> {
1013 Ok(match value {
1014 crate::ScalarKind::Sint => nt::glsl::GlslScalarKind::Sint,
1015 crate::ScalarKind::Uint => nt::glsl::GlslScalarKind::Uint,
1016 crate::ScalarKind::Float => nt::glsl::GlslScalarKind::Float,
1017 _ => return Err(()),
1018 })
1019 }
1020}
1021
1022impl From<crate::VectorSize> for nt::glsl::GlslVectorSize {
1023 fn from(val: crate::VectorSize) -> Self {
1024 match val {
1025 crate::VectorSize::Bi => nt::glsl::GlslVectorSize::Bi,
1026 crate::VectorSize::Tri => nt::glsl::GlslVectorSize::Tri,
1027 crate::VectorSize::Quad => nt::glsl::GlslVectorSize::Quad,
1028 }
1029 }
1030}
1031
1032impl TryFrom<crate::Scalar> for nt::glsl::GlslScalar {
1033 type Error = ();
1034
1035 fn try_from(value: crate::Scalar) -> Result<Self, Self::Error> {
1036 Ok(nt::glsl::GlslScalar {
1037 kind: value.kind.try_into()?,
1038 width: value.width,
1039 })
1040 }
1041}
1042
1043impl TryFrom<&crate::TypeInner> for nt::glsl::GlslUniformType {
1044 type Error = ();
1045 fn try_from(value: &crate::TypeInner) -> Result<Self, Self::Error> {
1046 match *value {
1047 crate::TypeInner::Scalar(scalar) => {
1048 Ok(nt::glsl::GlslUniformType::Scalar(scalar.try_into()?))
1049 }
1050 crate::TypeInner::Vector { size, scalar } => Ok(nt::glsl::GlslUniformType::Vector {
1051 size: size.into(),
1052 scalar: scalar.try_into()?,
1053 }),
1054 crate::TypeInner::Matrix {
1055 columns,
1056 rows,
1057 scalar,
1058 } => Ok(nt::glsl::GlslUniformType::Matrix {
1059 columns: columns.into(),
1060 rows: rows.into(),
1061 scalar: scalar.try_into()?,
1062 }),
1063 _ => Err(()),
1064 }
1065 }
1066}