1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::{
8 cmp::Ordering,
9 fmt::{Display, Error as FmtError, Formatter, Write},
10 iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{
17 ray::RT_NAMESPACE, sampler as sm, Error, LocationMode, Options, PipelineOptions,
18 TranslationInfo, NAMESPACE, WRAPPED_ARRAY_FIELD,
19};
20use crate::{
21 arena::{Handle, HandleSet},
22 back::{
23 self, get_entry_points,
24 msl::{mesh_shader::NestedFunctionInfo, BackendResult, EntryPointArgument},
25 Baked,
26 },
27 common,
28 proc::{
29 self, concrete_int_scalars,
30 index::{self, BoundsCheck},
31 ExternalTextureNameKey, NameKey, TypeResolution,
32 },
33 valid, FastHashMap, FastHashSet,
34};
35
/// Reference sigil `&`. NOTE(review): not used in this excerpt; presumably
/// appended to atomic pointer types elsewhere in the writer — confirm.
const ATOMIC_REFERENCE: &str = "&";

// Identifiers of helper functions and helper types that appear in the
// generated MSL. Within this excerpt, `IMAGE_SIZE_EXTERNAL_FUNCTION` is written
// as a call, and `EXTERNAL_TEXTURE_WRAPPER_STRUCT` / `ARGUMENT_BUFFER_WRAPPER_STRUCT`
// are written as type names. The other names are presumably emitted and used by
// other parts of the backend — confirm at their use sites.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
/// Prefix only. NOTE(review): presumably completed with a per-type suffix.
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
// Float-to-integer conversion helpers, one per integer result type.
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
// Helpers operating on external textures and sampling.
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// Generic wrapper type used when printing binding-array types.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Type name printed for `ImageClass::External` images.
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
// Cooperative-matrix helpers.
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
72
73fn put_numeric_type(
82 out: &mut impl Write,
83 scalar: crate::Scalar,
84 sizes: &[crate::VectorSize],
85) -> Result<(), FmtError> {
86 match (scalar, sizes) {
87 (scalar, &[]) => {
88 write!(out, "{}", scalar.to_msl_name())
89 }
90 (scalar, &[rows]) => {
91 write!(
92 out,
93 "{}::{}{}",
94 NAMESPACE,
95 scalar.to_msl_name(),
96 common::vector_size_str(rows)
97 )
98 }
99 (scalar, &[rows, columns]) => {
100 write!(
101 out,
102 "{}::{}{}x{}",
103 NAMESPACE,
104 scalar.to_msl_name(),
105 common::vector_size_str(columns),
106 common::vector_size_str(rows)
107 )
108 }
109 (_, _) => Ok(()), }
111}
112
113const fn scalar_is_int(scalar: crate::Scalar) -> bool {
114 use crate::ScalarKind::*;
115 match scalar.kind {
116 Sint | Uint | AbstractInt | Bool => true,
117 Float | AbstractFloat => false,
118 }
119}
120
/// Prefix of the identifier that `ClampedLod` displays.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix of the identifier that `Reinterpreted` displays.
const REINTERPRET_PREFIX: &str = "reinterpreted_";
126
127struct ClampedLod(Handle<crate::Expression>);
133
134impl Display for ClampedLod {
135 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
136 self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
137 }
138}
139
140struct ArraySizeMember(Handle<crate::GlobalVariable>);
155
156impl Display for ArraySizeMember {
157 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
158 self.0.write_prefixed(f, "size")
159 }
160}
161
162#[derive(Clone, Copy)]
167struct Reinterpreted<'a> {
168 target_type: &'a str,
169 orig: Handle<crate::Expression>,
170}
171
172impl<'a> Reinterpreted<'a> {
173 const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
174 Self { target_type, orig }
175 }
176}
177
178impl Display for Reinterpreted<'_> {
179 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
180 f.write_str(REINTERPRET_PREFIX)?;
181 f.write_str(self.target_type)?;
182 self.orig.write_prefixed(f, "_e")
183 }
184}
185
/// Everything needed to print the MSL spelling of a type through `Display`.
pub(super) struct TypeContext<'a> {
    /// The type to print.
    pub handle: Handle<crate::Type>,
    /// Module-wide context used to look up `handle` and its component types.
    pub gctx: proc::GlobalCtx<'a>,
    /// Names assigned to module items; used when a type is printed by alias.
    pub names: &'a FastHashMap<NameKey, String>,
    /// Access flags that select the `access::` qualifier of storage images.
    pub access: crate::StorageAccess,
    /// When `false`, types that `needs_alias()` print their registered name
    /// instead of their full spelling.
    pub first_time: bool,
}
193
194impl TypeContext<'_> {
195 fn scalar(&self) -> Option<crate::Scalar> {
196 let ty = &self.gctx.types[self.handle];
197 ty.inner.scalar()
198 }
199
200 fn vector_size(&self) -> Option<crate::VectorSize> {
201 let ty = &self.gctx.types[self.handle];
202 match ty.inner {
203 crate::TypeInner::Vector { size, .. } => Some(size),
204 _ => None,
205 }
206 }
207
208 fn unwrap_array(self) -> Self {
209 match self.gctx.types[self.handle].inner {
210 crate::TypeInner::Array { base, .. } => Self {
211 handle: base,
212 ..self
213 },
214 _ => self,
215 }
216 }
217}
218
impl Display for TypeContext<'_> {
    /// Writes the MSL spelling of `self.handle`.
    ///
    /// Types that `needs_alias()` are written by their registered name unless
    /// `first_time` is set. Struct types are expected to always take that
    /// aliased path; reaching the `Struct` arm panics.
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            // Cooperative matrices map to `simdgroup_<scalar><columns>x<rows>`.
            crate::TypeInner::CooperativeMatrix {
                columns,
                rows,
                scalar,
                role: _,
            } => {
                write!(
                    out,
                    "{NAMESPACE}::simdgroup_{}{}x{}",
                    scalar.to_msl_name(),
                    columns as u32,
                    rows as u32,
                )
            }
            // `<space> <pointee>&`. Address spaces without an MSL keyword
            // (`Handle`) write nothing.
            crate::TypeInner::Pointer { base, space } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                write!(out, "{space_name} {sub}&")
            }
            // Same as `Pointer`, but the pointee is an inline scalar or vector.
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            // Only the element type is written here. NOTE(review): the array
            // length is presumably emitted by the caller; confirm.
            crate::TypeInner::Array { base, .. } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                write!(out, "{sub}")
            }
            crate::TypeInner::Struct { .. } => unreachable!(),
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                // Single-sampled sampled/depth images get `access::sample`;
                // multisampled ones get the `_ms` suffix and `access::read`.
                let (texture_str, msaa_str, scalar, access) = match class {
                    crate::ImageClass::Sampled { kind, multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    // Storage images take their access from `self.access`.
                    // Empty access is treated as an invalid module (panic).
                    crate::ImageClass::Storage { format, .. } => {
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    // External textures are written as the wrapper struct name.
                    crate::ImageClass::External => {
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{}", super::ray::metal_intersector_ty())
            }
            // Binding arrays are pointers to the argument-buffer wrapper:
            // `device` pointers to `device` struct pointers for struct
            // elements, and `constant` pointers for everything else.
            crate::TypeInner::BindingArray { base, .. } => {
                let base_inner = &self.gctx.types[base].inner;
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                match *base_inner {
                    crate::TypeInner::Struct { .. } => {
                        write!(
                            out,
                            "device {ARGUMENT_BUFFER_WRAPPER_STRUCT}<device {base_tyname}*>*"
                        )
                    }
                    _ => {
                        write!(
                            out,
                            "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                        )
                    }
                }
            }
        }
    }
}
399
/// A global variable declaration (qualified type followed by name), written
/// with `try_fmt`.
pub(super) struct TypedGlobalVariable<'a> {
    pub module: &'a crate::Module,
    pub names: &'a FastHashMap<NameKey, String>,
    pub handle: Handle<crate::GlobalVariable>,
    /// How the global is used. A global that is never written, in a space that
    /// `needs_access_qualifier()`, is declared `const`.
    pub usage: valid::GlobalUse,
    /// Whether to declare a reference (`&`) rather than a value.
    pub reference: bool,
}
407
/// The two halves of a `TypedGlobalVariable` declaration.
struct TypedGlobalVariableParts {
    /// Full type text, including any coherent, address-space, access, and
    /// reference qualifiers.
    ty_name: String,
    /// The global's registered name.
    var_name: String,
}
412
413impl TypedGlobalVariable<'_> {
414 fn to_parts(&self) -> Result<TypedGlobalVariableParts, Error> {
415 let var = &self.module.global_variables[self.handle];
416 let name = &self.names[&NameKey::GlobalVariable(self.handle)];
417
418 let storage_access = match var.space {
419 crate::AddressSpace::Storage { access } => access,
420 _ => match self.module.types[var.ty].inner {
421 crate::TypeInner::Image {
422 class: crate::ImageClass::Storage { access, .. },
423 ..
424 } => access,
425 crate::TypeInner::BindingArray { base, .. } => {
426 match self.module.types[base].inner {
427 crate::TypeInner::Image {
428 class: crate::ImageClass::Storage { access, .. },
429 ..
430 } => access,
431 _ => crate::StorageAccess::default(),
432 }
433 }
434 _ => crate::StorageAccess::default(),
435 },
436 };
437 let ty_name = TypeContext {
438 handle: var.ty,
439 gctx: self.module.to_ctx(),
440 names: self.names,
441 access: storage_access,
442 first_time: false,
443 };
444
445 let (coherent, space, access, reference) = if matches!(
446 self.module.types[var.ty].inner,
447 crate::TypeInner::BindingArray { .. }
448 ) {
449 ("", "", "", "")
450 } else {
451 let access = if var.space.needs_access_qualifier()
452 && !self.usage.intersects(valid::GlobalUse::WRITE)
453 {
454 "const"
455 } else {
456 ""
457 };
458 match (var.space.to_msl_name(), var.space) {
459 (Some(space), crate::AddressSpace::WorkGroup) => {
460 ("", space, access, if self.reference { "&" } else { "" })
461 }
462 (Some(space), _) if self.reference => {
463 let coherent = if var
464 .memory_decorations
465 .contains(crate::MemoryDecorations::COHERENT)
466 {
467 "coherent "
468 } else {
469 ""
470 };
471 (coherent, space, access, "&")
472 }
473 _ => ("", "", "", ""),
474 }
475 };
476
477 let ty = format!(
478 "{coherent}{space}{}{ty_name}{}{access}{reference}",
479 if space.is_empty() { "" } else { " " },
480 if access.is_empty() { "" } else { " " },
481 );
482
483 Ok(TypedGlobalVariableParts {
484 ty_name: ty,
485 var_name: name.clone(),
486 })
487 }
488 pub(super) fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
489 let parts = self.to_parts()?;
490
491 Ok(write!(out, "{} {}", parts.ty_name, parts.var_name)?)
492 }
493}
494
/// Key that identifies a wrapper or helper function by everything that
/// distinguishes one instance from another.
///
/// `Writer::wrapped_functions` stores these in a hash set. NOTE(review): the
/// set presumably prevents emitting the same helper twice; confirm at the
/// insertion sites.
#[derive(Eq, PartialEq, Hash)]
pub(super) enum WrappedFunction {
    /// Unary operator applied to a scalar or vector (`None` size = scalar).
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Binary operator with the given operand types (`None` size = scalar).
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Math function with the given argument type.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Conversion from `src_scalar` to `dst_scalar`, optionally per vector lane.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Image load for the given image class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Image sample for the given image class and clamp-to-edge mode.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Image size query for the given image class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
    /// Cooperative-matrix load from the named address space.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// Cooperative-matrix multiply-add with the given shapes and scalar types.
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        ab_scalar: crate::Scalar,
        c_scalar: crate::Scalar,
    },
    /// Ray-query intersection getter, for committed or candidate hits.
    RayQueryGetIntersection {
        committed: bool,
    },
}
543
/// Writes Metal Shading Language source text to `out`.
#[expect(missing_debug_implementations, reason = "would be way too verbose?")]
pub struct Writer<W> {
    /// Destination for the generated text; handed back by `finish`.
    pub(super) out: W,
    /// Identifiers assigned to module items, keyed by `NameKey`.
    pub(super) names: FastHashMap<NameKey, String>,
    pub(super) named_expressions: crate::NamedExpressions,
    /// NOTE(review): presumably expressions that must be baked into
    /// temporaries before use; confirm.
    need_bake_expressions: back::NeedBakeExpressions,
    /// Produces unique identifiers (e.g. `loop_bound`, `oob`).
    pub(super) namer: proc::Namer,
    /// Set of wrapper/helper function keys; see `WrappedFunction`.
    pub(super) wrapped_functions: FastHashSet<WrappedFunction>,
    /// Set to `true` by `Writer::new`. NOTE(review): presumably controls
    /// whether integer division is guarded; confirm.
    emit_int_div_checks: bool,
    /// NOTE(review): presumably (struct type, member index) pairs that need
    /// explicit padding; confirm.
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
    /// Starts as `false`. NOTE(review): presumably set when object-memory
    /// barriers must be emitted; confirm.
    needs_object_memory_barriers: bool,
}
559
560impl crate::Scalar {
561 pub(super) fn to_msl_name(self) -> &'static str {
562 use crate::ScalarKind as Sk;
563 match self {
564 Self {
565 kind: Sk::Float,
566 width: 4,
567 } => "float",
568 Self {
569 kind: Sk::Float,
570 width: 2,
571 } => "half",
572 Self {
573 kind: Sk::Sint,
574 width: 2,
575 } => "short",
576 Self {
577 kind: Sk::Uint,
578 width: 2,
579 } => "ushort",
580 Self {
581 kind: Sk::Sint,
582 width: 4,
583 } => "int",
584 Self {
585 kind: Sk::Uint,
586 width: 4,
587 } => "uint",
588 Self {
589 kind: Sk::Sint,
590 width: 8,
591 } => "long",
592 Self {
593 kind: Sk::Uint,
594 width: 8,
595 } => "ulong",
596 Self {
597 kind: Sk::Bool,
598 width: _,
599 } => "bool",
600 Self {
601 kind: Sk::AbstractInt | Sk::AbstractFloat,
602 width: _,
603 } => unreachable!("Found Abstract scalar kind"),
604 _ => unreachable!("Unsupported scalar kind: {:?}", self),
605 }
606 }
607}
608
/// Separator to place after a list item: `","` when `need_separator` is
/// set, otherwise the empty string.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
616
617fn should_pack_struct_member(
618 members: &[crate::StructMember],
619 span: u32,
620 index: usize,
621 module: &crate::Module,
622) -> Option<crate::Scalar> {
623 let member = &members[index];
624
625 let ty_inner = &module.types[member.ty].inner;
626 let last_offset = member.offset + ty_inner.size(module.to_ctx());
627 let next_offset = match members.get(index + 1) {
628 Some(next) => next.offset,
629 None => span,
630 };
631 let is_tight = next_offset == last_offset;
632
633 match *ty_inner {
634 crate::TypeInner::Vector {
635 size: crate::VectorSize::Tri,
636 scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
637 } if is_tight => Some(scalar),
638 _ => None,
639 }
640}
641
642impl crate::AddressSpace {
643 const fn needs_pass_through(&self) -> bool {
647 match *self {
648 Self::Uniform
649 | Self::Storage { .. }
650 | Self::Private
651 | Self::WorkGroup
652 | Self::Immediate
653 | Self::Handle
654 | Self::TaskPayload => true,
655 Self::Function => false,
656 Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
657 }
658 }
659
660 const fn needs_access_qualifier(&self) -> bool {
662 match *self {
663 Self::Storage { .. } => true,
668 Self::TaskPayload => true,
669 Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
670 Self::Private | Self::WorkGroup => false,
672 Self::Uniform | Self::Immediate => false,
674 Self::Handle | Self::Function => false,
676 }
677 }
678
679 const fn to_msl_name(self) -> Option<&'static str> {
680 match self {
681 Self::Handle => None,
682 Self::Uniform | Self::Immediate => Some("constant"),
683 Self::Storage { .. } => Some("device"),
684 Self::Private | Self::Function | Self::RayPayload => Some("thread"),
688 Self::WorkGroup => Some("threadgroup"),
689 Self::TaskPayload => Some("object_data"),
690 Self::IncomingRayPayload => Some("ray_data"),
691 }
692 }
693}
694
695impl crate::Type {
696 const fn needs_alias(&self) -> bool {
698 use crate::TypeInner as Ti;
699
700 match self.inner {
701 Ti::Scalar(_)
703 | Ti::Vector { .. }
704 | Ti::Matrix { .. }
705 | Ti::CooperativeMatrix { .. }
706 | Ti::Atomic(_)
707 | Ti::Pointer { .. }
708 | Ti::ValuePointer { .. } => self.name.is_some(),
709 Ti::Struct { .. } | Ti::Array { .. } => true,
711 Ti::Image { .. }
713 | Ti::Sampler { .. }
714 | Ti::AccelerationStructure { .. }
715 | Ti::RayQuery { .. }
716 | Ti::BindingArray { .. } => false,
717 }
718 }
719}
720
/// The function that owns a local variable: either a regular function or an
/// entry point. `NameKeyExt` uses it to choose the matching `NameKey` variant.
#[derive(Clone, Copy)]
pub(super) enum FunctionOrigin {
    Handle(Handle<crate::Function>),
    EntryPoint(proc::EntryPointIndex),
}
726
727pub(super) trait NameKeyExt {
728 fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
729 match origin {
730 FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
731 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
732 }
733 }
734
735 fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
740 match origin {
741 FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
742 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
743 }
744 }
745}
746
747impl NameKeyExt for NameKey {}
748
/// How a level-of-detail operand is written (see `put_level_of_detail`).
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Written as the expression itself.
    Direct(Handle<crate::Expression>),
    /// Written as the `ClampedLod` identifier (`clamped_lod_e...`) derived
    /// from this expression. NOTE(review): presumably names a previously
    /// emitted clamped copy; confirm.
    Restricted(Handle<crate::Expression>),
}
763
/// Components of a texel address for an image access, as used by
/// `put_restricted_texel_address` and `put_image_access_bounds_check`.
struct TexelAddress {
    /// Texel coordinate (scalar or vector).
    coordinate: Handle<crate::Expression>,
    /// Array layer; checked against `get_array_size()`.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index; checked against `get_num_samples()`.
    sample: Option<Handle<crate::Expression>>,
    /// Mip level, if the access takes one.
    level: Option<LevelOfDetail>,
}
779
/// Per-function state used while writing expressions.
pub(super) struct ExpressionContext<'a> {
    /// Function whose expressions are being written.
    pub(super) function: &'a crate::Function,
    /// Whether `function` is a regular function or an entry point; used to
    /// build local `NameKey`s.
    pub(super) origin: FunctionOrigin,
    /// Validation info for `function`; used for type resolution.
    pub(super) info: &'a valid::FunctionInfo,
    pub(super) module: &'a crate::Module,
    pub(super) mod_info: &'a valid::ModuleInfo,
    pub(super) pipeline_options: &'a PipelineOptions,
    /// NOTE(review): presumably the target MSL version as (major, minor); confirm.
    pub(super) lang_version: (u8, u8),
    /// Policies consulted by `choose_bounds_check_policy` and `oob_local_types`.
    pub(super) policies: index::BoundsCheckPolicies,

    /// NOTE(review): presumably index expressions already known to be in
    /// bounds; confirm.
    pub(super) guarded_indices: HandleSet<crate::Expression>,
    /// When set, loops get an explicit iteration bound
    /// (see `gen_force_bounded_loop_statements`).
    pub(super) force_loop_bounding: bool,
    emit_int_div_checks: bool,
    /// When set, `put_locals` declares tracker variables next to each
    /// ray-query local.
    pub(super) ray_query_initialization_tracking: bool,
}
800
801impl<'a> ExpressionContext<'a> {
802 fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
803 self.info[handle].ty.inner_with(&self.module.types)
804 }
805
806 fn binding_array_index_from_chain(
809 &self,
810 mut expr: Handle<crate::Expression>,
811 global: Handle<crate::GlobalVariable>,
812 ) -> Option<index::GuardedIndex> {
813 let expressions = &self.function.expressions;
814 loop {
815 match expressions[expr] {
816 crate::Expression::Load { pointer } => expr = pointer,
817 crate::Expression::Access { base, index } => {
818 if matches!(
819 expressions[base],
820 crate::Expression::GlobalVariable(g) if g == global
821 ) {
822 return Some(index::GuardedIndex::Expression(index));
823 }
824 expr = base;
825 }
826 crate::Expression::AccessIndex { base, index } => {
827 if matches!(
828 expressions[base],
829 crate::Expression::GlobalVariable(g) if g == global
830 ) {
831 return Some(index::GuardedIndex::Known(index));
832 }
833 expr = base;
834 }
835 crate::Expression::GlobalVariable(_) => return None,
836 _ => return None,
837 }
838 }
839 }
840
841 fn is_global_access_chain(&self, expr: Handle<crate::Expression>) -> bool {
843 let expressions = &self.function.expressions;
844 match expressions[expr] {
845 crate::Expression::Access { base, .. } => match expressions[base] {
846 crate::Expression::GlobalVariable(_) => true,
847 crate::Expression::Access { .. } => self.is_global_access_chain(base),
848 _ => false,
849 },
850 crate::Expression::AccessIndex { base, .. } => {
851 matches!(expressions[base], crate::Expression::GlobalVariable(_))
852 }
853 _ => false,
854 }
855 }
856
857 fn struct_member_needs_arrow(
858 &self,
859 base: Handle<crate::Expression>,
860 originating_global_ty: impl FnOnce(&crate::TypeInner) -> bool,
861 ) -> bool {
862 let originating_matches = match self.function.originating_global(base) {
863 Some(gv) => {
864 originating_global_ty(&self.module.types[self.module.global_variables[gv].ty].inner)
865 }
866 None => false,
867 };
868 originating_matches && self.is_global_access_chain(base)
869 }
870
871 fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
878 let image_ty = self.resolve_type(image);
879 if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
880 class.is_mipmapped() && dim != crate::ImageDimension::D1
881 } else {
882 false
883 }
884 }
885
886 fn choose_bounds_check_policy(
887 &self,
888 pointer: Handle<crate::Expression>,
889 ) -> index::BoundsCheckPolicy {
890 self.policies
891 .choose_policy(pointer, &self.module.types, self.info)
892 }
893
894 fn access_needs_check(
896 &self,
897 base: Handle<crate::Expression>,
898 index: index::GuardedIndex,
899 ) -> Option<index::IndexableLength> {
900 index::access_needs_check(
901 base,
902 index,
903 self.module,
904 &self.function.expressions,
905 self.info,
906 )
907 }
908
909 fn bounds_check_iter(
911 &self,
912 chain: Handle<crate::Expression>,
913 ) -> impl Iterator<Item = BoundsCheck> + '_ {
914 index::bounds_check_iter(chain, self.module, self.function, self.info)
915 }
916
917 fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
919 index::oob_local_types(self.module, self.function, self.info, self.policies)
920 }
921
922 fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
923 match self.function.expressions[expr_handle] {
924 crate::Expression::AccessIndex { base, index } => {
925 let ty = match *self.resolve_type(base) {
926 crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
927 ref ty => ty,
928 };
929 match *ty {
930 crate::TypeInner::Struct {
931 ref members, span, ..
932 } => should_pack_struct_member(members, span, index as usize, self.module),
933 _ => None,
934 }
935 }
936 _ => None,
937 }
938 }
939}
940
/// State used while writing statements: the expression context plus an
/// optional result-struct name.
pub(super) struct StatementContext<'a> {
    pub(super) expression: ExpressionContext<'a>,
    /// NOTE(review): presumably the name of the struct an entry point returns
    /// its outputs in; confirm.
    pub(super) result_struct: Option<&'a str>,
}
945
946impl<W: Write> Writer<W> {
947 pub fn new(out: W) -> Self {
949 Writer {
950 out,
951 names: FastHashMap::default(),
952 named_expressions: Default::default(),
953 need_bake_expressions: Default::default(),
954 namer: proc::Namer::default(),
955 wrapped_functions: FastHashSet::default(),
956 emit_int_div_checks: true,
957 struct_member_pads: FastHashSet::default(),
958 needs_object_memory_barriers: false,
959 }
960 }
961
962 pub fn finish(self) -> W {
965 self.out
966 }
967
968 fn gen_force_bounded_loop_statements(
1075 &mut self,
1076 level: back::Level,
1077 context: &StatementContext,
1078 ) -> Option<(String, String)> {
1079 if !context.expression.force_loop_bounding {
1080 return None;
1081 }
1082
1083 let loop_bound_name = self.namer.call("loop_bound");
1084 let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
1087 let level = level.next();
1088 let break_and_inc = format!(
1089 "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
1090{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
1091 );
1092
1093 Some((decl, break_and_inc))
1094 }
1095
1096 fn put_call_parameters(
1097 &mut self,
1098 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1099 context: &ExpressionContext,
1100 ) -> BackendResult {
1101 self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1102 writer.put_expression(expr, context, true)
1103 })
1104 }
1105
1106 fn put_call_parameters_impl<C, E>(
1107 &mut self,
1108 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1109 ctx: &C,
1110 put_expression: E,
1111 ) -> BackendResult
1112 where
1113 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1114 {
1115 write!(self.out, "(")?;
1116 for (i, handle) in parameters.enumerate() {
1117 if i != 0 {
1118 write!(self.out, ", ")?;
1119 }
1120 put_expression(self, ctx, handle)?;
1121 }
1122 write!(self.out, ")")?;
1123 Ok(())
1124 }
1125
1126 fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1132 let oob_local_types = context.oob_local_types();
1133 for &ty in oob_local_types.iter() {
1134 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1135 self.names.insert(name_key, self.namer.call("oob"));
1136 }
1137
1138 for (name_key, ty, init) in context
1139 .function
1140 .local_variables
1141 .iter()
1142 .map(|(local_handle, local)| {
1143 let name_key = NameKey::local(context.origin, local_handle);
1144 (name_key, local.ty, local.init)
1145 })
1146 .chain(oob_local_types.iter().map(|&ty| {
1147 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1148 (name_key, ty, None)
1149 }))
1150 {
1151 let ty_name = TypeContext {
1152 handle: ty,
1153 gctx: context.module.to_ctx(),
1154 names: &self.names,
1155 access: crate::StorageAccess::empty(),
1156 first_time: false,
1157 };
1158 write!(
1159 self.out,
1160 "{}{} {}",
1161 back::INDENT,
1162 ty_name,
1163 self.names[&name_key]
1164 )?;
1165 match init {
1166 Some(value) => {
1167 write!(self.out, " = ")?;
1168 self.put_expression(value, context, true)?;
1169 }
1170 None => {
1171 write!(self.out, " = {{}}")?;
1172 }
1173 };
1174 writeln!(self.out, ";")?;
1175
1176 if context.ray_query_initialization_tracking {
1178 if let crate::TypeInner::RayQuery { .. } = context.module.types[ty].inner {
1179 writeln!(
1180 self.out,
1181 "{}uint {}{} = 0u;",
1182 back::INDENT,
1183 super::ray::RAY_QUERY_TRACKER_VARIABLE_PREFIX,
1184 self.names[&name_key]
1185 )?;
1186
1187 writeln!(
1188 self.out,
1189 "{}float {}{} = 0.0;",
1190 back::INDENT,
1191 super::ray::RAY_QUERY_T_MAX_TRACKER_VARIABLE_PREFIX,
1192 self.names[&name_key]
1193 )?;
1194 }
1195 }
1196 }
1197 Ok(())
1198 }
1199
1200 fn put_level_of_detail(
1201 &mut self,
1202 level: LevelOfDetail,
1203 context: &ExpressionContext,
1204 ) -> BackendResult {
1205 match level {
1206 LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1207 LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1208 }
1209 Ok(())
1210 }
1211
1212 fn put_image_query(
1213 &mut self,
1214 image: Handle<crate::Expression>,
1215 query: &str,
1216 level: Option<LevelOfDetail>,
1217 context: &ExpressionContext,
1218 ) -> BackendResult {
1219 self.put_expression(image, context, false)?;
1220 write!(self.out, ".get_{query}(")?;
1221 if let Some(level) = level {
1222 self.put_level_of_detail(level, context)?;
1223 }
1224 write!(self.out, ")")?;
1225 Ok(())
1226 }
1227
1228 fn put_image_size_query(
1229 &mut self,
1230 image: Handle<crate::Expression>,
1231 level: Option<LevelOfDetail>,
1232 kind: crate::ScalarKind,
1233 context: &ExpressionContext,
1234 ) -> BackendResult {
1235 if let crate::TypeInner::Image {
1236 class: crate::ImageClass::External,
1237 ..
1238 } = *context.resolve_type(image)
1239 {
1240 write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1241 self.put_expression(image, context, true)?;
1242 write!(self.out, ")")?;
1243 return Ok(());
1244 }
1245
1246 let dim = match *context.resolve_type(image) {
1249 crate::TypeInner::Image { dim, .. } => dim,
1250 ref other => unreachable!("Unexpected type {:?}", other),
1251 };
1252 let scalar = crate::Scalar { kind, width: 4 };
1253 let coordinate_type = scalar.to_msl_name();
1254 match dim {
1255 crate::ImageDimension::D1 => {
1256 if kind == crate::ScalarKind::Uint {
1260 self.put_image_query(image, "width", None, context)?;
1262 } else {
1263 write!(self.out, "int(")?;
1265 self.put_image_query(image, "width", None, context)?;
1266 write!(self.out, ")")?;
1267 }
1268 }
1269 crate::ImageDimension::D2 => {
1270 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1271 self.put_image_query(image, "width", level, context)?;
1272 write!(self.out, ", ")?;
1273 self.put_image_query(image, "height", level, context)?;
1274 write!(self.out, ")")?;
1275 }
1276 crate::ImageDimension::D3 => {
1277 write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1278 self.put_image_query(image, "width", level, context)?;
1279 write!(self.out, ", ")?;
1280 self.put_image_query(image, "height", level, context)?;
1281 write!(self.out, ", ")?;
1282 self.put_image_query(image, "depth", level, context)?;
1283 write!(self.out, ")")?;
1284 }
1285 crate::ImageDimension::Cube => {
1286 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1287 self.put_image_query(image, "width", level, context)?;
1288 write!(self.out, ")")?;
1289 }
1290 }
1291 Ok(())
1292 }
1293
1294 fn put_cast_to_uint_scalar_or_vector(
1295 &mut self,
1296 expr: Handle<crate::Expression>,
1297 context: &ExpressionContext,
1298 ) -> BackendResult {
1299 match *context.resolve_type(expr) {
1301 crate::TypeInner::Scalar(_) => {
1302 put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1303 }
1304 crate::TypeInner::Vector { size, .. } => {
1305 put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1306 }
1307 _ => {
1308 return Err(Error::GenericValidation(
1309 "Invalid type for image coordinate".into(),
1310 ))
1311 }
1312 };
1313
1314 write!(self.out, "(")?;
1315 self.put_expression(expr, context, true)?;
1316 write!(self.out, ")")?;
1317 Ok(())
1318 }
1319
1320 fn put_image_sample_level(
1321 &mut self,
1322 image: Handle<crate::Expression>,
1323 level: crate::SampleLevel,
1324 context: &ExpressionContext,
1325 ) -> BackendResult {
1326 let has_levels = context.image_needs_lod(image);
1327 match level {
1328 crate::SampleLevel::Auto => {}
1329 crate::SampleLevel::Zero => {
1330 }
1332 _ if !has_levels => {
1333 log::warn!("1D image can't be sampled with level {level:?}");
1334 }
1335 crate::SampleLevel::Exact(h) => {
1336 write!(self.out, ", {NAMESPACE}::level(")?;
1337 self.put_expression(h, context, true)?;
1338 write!(self.out, ")")?;
1339 }
1340 crate::SampleLevel::Bias(h) => {
1341 write!(self.out, ", {NAMESPACE}::bias(")?;
1342 self.put_expression(h, context, true)?;
1343 write!(self.out, ")")?;
1344 }
1345 crate::SampleLevel::Gradient { x, y } => {
1346 write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1347 self.put_expression(x, context, true)?;
1348 write!(self.out, ", ")?;
1349 self.put_expression(y, context, true)?;
1350 write!(self.out, ")")?;
1351 }
1352 }
1353 Ok(())
1354 }
1355
1356 fn put_image_coordinate_limits(
1357 &mut self,
1358 image: Handle<crate::Expression>,
1359 level: Option<LevelOfDetail>,
1360 context: &ExpressionContext,
1361 ) -> BackendResult {
1362 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1363 write!(self.out, " - 1")?;
1364 Ok(())
1365 }
1366
1367 fn put_restricted_scalar_image_index(
1385 &mut self,
1386 image: Handle<crate::Expression>,
1387 index: Handle<crate::Expression>,
1388 limit_method: &str,
1389 context: &ExpressionContext,
1390 ) -> BackendResult {
1391 write!(self.out, "{NAMESPACE}::min(uint(")?;
1392 self.put_expression(index, context, true)?;
1393 write!(self.out, "), ")?;
1394 self.put_expression(image, context, false)?;
1395 write!(self.out, ".{limit_method}() - 1)")?;
1396 Ok(())
1397 }
1398
1399 fn put_restricted_texel_address(
1400 &mut self,
1401 image: Handle<crate::Expression>,
1402 address: &TexelAddress,
1403 context: &ExpressionContext,
1404 ) -> BackendResult {
1405 write!(self.out, "{NAMESPACE}::min(")?;
1407 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1408 write!(self.out, ", ")?;
1409 self.put_image_coordinate_limits(image, address.level, context)?;
1410 write!(self.out, ")")?;
1411
1412 if let Some(array_index) = address.array_index {
1414 write!(self.out, ", ")?;
1415 self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1416 }
1417
1418 if let Some(sample) = address.sample {
1420 write!(self.out, ", ")?;
1421 self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1422 }
1423
1424 if let Some(level) = address.level {
1427 write!(self.out, ", ")?;
1428 self.put_level_of_detail(level, context)?;
1429 }
1430
1431 Ok(())
1432 }
1433
1434 fn put_image_access_bounds_check(
1436 &mut self,
1437 image: Handle<crate::Expression>,
1438 address: &TexelAddress,
1439 context: &ExpressionContext,
1440 ) -> BackendResult {
1441 let mut conjunction = "";
1442
1443 let level = if let Some(level) = address.level {
1446 write!(self.out, "uint(")?;
1447 self.put_level_of_detail(level, context)?;
1448 write!(self.out, ") < ")?;
1449 self.put_expression(image, context, true)?;
1450 write!(self.out, ".get_num_mip_levels()")?;
1451 conjunction = " && ";
1452 Some(level)
1453 } else {
1454 None
1455 };
1456
1457 if let Some(sample) = address.sample {
1459 write!(self.out, "uint(")?;
1460 self.put_expression(sample, context, true)?;
1461 write!(self.out, ") < ")?;
1462 self.put_expression(image, context, true)?;
1463 write!(self.out, ".get_num_samples()")?;
1464 conjunction = " && ";
1465 }
1466
1467 if let Some(array_index) = address.array_index {
1469 write!(self.out, "{conjunction}uint(")?;
1470 self.put_expression(array_index, context, true)?;
1471 write!(self.out, ") < ")?;
1472 self.put_expression(image, context, true)?;
1473 write!(self.out, ".get_array_size()")?;
1474 conjunction = " && ";
1475 }
1476
1477 let coord_is_vector = match *context.resolve_type(address.coordinate) {
1479 crate::TypeInner::Vector { .. } => true,
1480 _ => false,
1481 };
1482 write!(self.out, "{conjunction}")?;
1483 if coord_is_vector {
1484 write!(self.out, "{NAMESPACE}::all(")?;
1485 }
1486 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1487 write!(self.out, " < ")?;
1488 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1489 if coord_is_vector {
1490 write!(self.out, ")")?;
1491 }
1492
1493 Ok(())
1494 }
1495
1496 fn put_image_load(
1497 &mut self,
1498 load: Handle<crate::Expression>,
1499 image: Handle<crate::Expression>,
1500 mut address: TexelAddress,
1501 context: &ExpressionContext,
1502 ) -> BackendResult {
1503 if let crate::TypeInner::Image {
1504 class: crate::ImageClass::External,
1505 ..
1506 } = *context.resolve_type(image)
1507 {
1508 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1509 self.put_expression(image, context, true)?;
1510 write!(self.out, ", ")?;
1511 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1512 write!(self.out, ")")?;
1513 return Ok(());
1514 }
1515
1516 match context.policies.image_load {
1517 proc::BoundsCheckPolicy::Restrict => {
1518 if address.level.is_some() {
1521 address.level = if context.image_needs_lod(image) {
1522 Some(LevelOfDetail::Restricted(load))
1523 } else {
1524 None
1525 }
1526 }
1527
1528 self.put_expression(image, context, false)?;
1529 write!(self.out, ".read(")?;
1530 self.put_restricted_texel_address(image, &address, context)?;
1531 write!(self.out, ")")?;
1532 }
1533 proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1534 write!(self.out, "(")?;
1535 self.put_image_access_bounds_check(image, &address, context)?;
1536 write!(self.out, " ? ")?;
1537 self.put_unchecked_image_load(image, &address, context)?;
1538 write!(self.out, ": DefaultConstructible())")?;
1539 }
1540 proc::BoundsCheckPolicy::Unchecked => {
1541 self.put_unchecked_image_load(image, &address, context)?;
1542 }
1543 }
1544
1545 Ok(())
1546 }
1547
1548 fn put_unchecked_image_load(
1549 &mut self,
1550 image: Handle<crate::Expression>,
1551 address: &TexelAddress,
1552 context: &ExpressionContext,
1553 ) -> BackendResult {
1554 self.put_expression(image, context, false)?;
1555 write!(self.out, ".read(")?;
1556 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1558 if let Some(expr) = address.array_index {
1559 write!(self.out, ", ")?;
1560 self.put_expression(expr, context, true)?;
1561 }
1562 if let Some(sample) = address.sample {
1563 write!(self.out, ", ")?;
1564 self.put_expression(sample, context, true)?;
1565 }
1566 if let Some(level) = address.level {
1567 if context.image_needs_lod(image) {
1568 write!(self.out, ", ")?;
1569 self.put_level_of_detail(level, context)?;
1570 }
1571 }
1572 write!(self.out, ")")?;
1573
1574 Ok(())
1575 }
1576
1577 fn put_image_atomic(
1578 &mut self,
1579 level: back::Level,
1580 image: Handle<crate::Expression>,
1581 address: &TexelAddress,
1582 fun: crate::AtomicFunction,
1583 value: Handle<crate::Expression>,
1584 context: &StatementContext,
1585 ) -> BackendResult {
1586 write!(self.out, "{level}")?;
1587 self.put_expression(image, &context.expression, false)?;
1588 let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1589 fun.to_msl_64_bit()?
1590 } else {
1591 fun.to_msl()
1592 };
1593 write!(self.out, ".atomic_{op}(")?;
1594 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1596 write!(self.out, ", ")?;
1597 self.put_expression(value, &context.expression, true)?;
1598 writeln!(self.out, ");")?;
1599
1600 let value_ty = context.expression.resolve_type(value);
1606 let zero_value = match (value_ty.scalar_kind(), value_ty.scalar_width()) {
1607 (Some(crate::ScalarKind::Sint), _) => "int4(0)",
1608 (_, Some(8)) => "ulong4(0uL)",
1609 _ => "uint4(0u)",
1610 };
1611 let coord_ty = context.expression.resolve_type(address.coordinate);
1612 let x = if matches!(coord_ty, crate::TypeInner::Scalar(_)) {
1613 ""
1614 } else {
1615 ".x"
1616 };
1617 write!(self.out, "{level}if (")?;
1618 self.put_expression(address.coordinate, &context.expression, true)?;
1619 write!(self.out, "{x} == -99999) {{ ")?;
1620 self.put_expression(image, &context.expression, false)?;
1621 write!(self.out, ".write({zero_value}, ")?;
1622 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1623 if let Some(array_index) = address.array_index {
1624 write!(self.out, ", ")?;
1625 self.put_expression(array_index, &context.expression, true)?;
1626 }
1627 writeln!(self.out, "); }}")?;
1628
1629 Ok(())
1630 }
1631
1632 fn put_image_store(
1633 &mut self,
1634 level: back::Level,
1635 image: Handle<crate::Expression>,
1636 address: &TexelAddress,
1637 value: Handle<crate::Expression>,
1638 context: &StatementContext,
1639 ) -> BackendResult {
1640 write!(self.out, "{level}")?;
1641 self.put_expression(image, &context.expression, false)?;
1642 write!(self.out, ".write(")?;
1643 self.put_expression(value, &context.expression, true)?;
1644 write!(self.out, ", ")?;
1645 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1647 if let Some(expr) = address.array_index {
1648 write!(self.out, ", ")?;
1649 self.put_expression(expr, &context.expression, true)?;
1650 }
1651 writeln!(self.out, ");")?;
1652
1653 Ok(())
1654 }
1655
1656 fn binding_array_layout_count(
1671 module: &crate::Module,
1672 pipeline_options: &PipelineOptions,
1673 global: Handle<crate::GlobalVariable>,
1674 ) -> u32 {
1675 let var = &module.global_variables[global];
1676 let crate::TypeInner::BindingArray { size, .. } = module.types[var.ty].inner else {
1677 unreachable!("binding_array_layout_count called on non-binding-array global");
1678 };
1679 let from_shader = match size {
1680 crate::ArraySize::Constant(n) => n.get(),
1681 crate::ArraySize::Pending(_) | crate::ArraySize::Dynamic => 0,
1682 };
1683 let from_layout = var
1684 .binding
1685 .and_then(|br| pipeline_options.binding_array_length_map.get(&br))
1686 .copied()
1687 .unwrap_or(0);
1688 from_shader.max(from_layout).max(1)
1689 }
1690
1691 fn put_binding_array_size_member_index(
1692 &mut self,
1693 index: index::GuardedIndex,
1694 context: &ExpressionContext,
1695 ) -> BackendResult {
1696 match index {
1697 index::GuardedIndex::Expression(expr) => {
1698 write!(self.out, "unsigned(")?;
1699 self.put_expression(expr, context, true)?;
1700 write!(self.out, ")")?;
1701 }
1702 index::GuardedIndex::Known(value) => write!(self.out, "{value}u")?,
1703 }
1704 Ok(())
1705 }
1706
1707 fn put_dynamic_array_max_index(
1708 &mut self,
1709 handle: Handle<crate::GlobalVariable>,
1710 chain_expr: Handle<crate::Expression>,
1711 context: &ExpressionContext,
1712 ) -> BackendResult {
1713 let global = &context.module.global_variables[handle];
1714 let (offset, array_ty) = match context.module.types[global.ty].inner {
1715 crate::TypeInner::Struct { ref members, .. } => match members.last() {
1716 Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1717 None => return Err(Error::GenericValidation("Struct has no members".into())),
1718 },
1719 crate::TypeInner::BindingArray { base, .. } => match context.module.types[base].inner {
1720 crate::TypeInner::Struct { ref members, .. } => match members.last() {
1721 Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1722 None => return Err(Error::GenericValidation("Struct has no members".into())),
1723 },
1724 _ => {
1725 return Err(Error::GenericValidation(
1726 "binding_array element must be a struct with a runtime-sized array field"
1727 .into(),
1728 ))
1729 }
1730 },
1731 crate::TypeInner::Array {
1732 size: crate::ArraySize::Dynamic,
1733 ..
1734 } => (0, global.ty),
1735 ref ty => {
1736 return Err(Error::GenericValidation(format!(
1737 "Expected type with dynamic array, got {ty:?}"
1738 )))
1739 }
1740 };
1741
1742 let (size, stride) = match context.module.types[array_ty].inner {
1743 crate::TypeInner::Array { base, stride, .. } => (
1744 context.module.types[base]
1745 .inner
1746 .size(context.module.to_ctx()),
1747 stride,
1748 ),
1749 ref ty => {
1750 return Err(Error::GenericValidation(format!(
1751 "Expected array type, got {ty:?}"
1752 )))
1753 }
1754 };
1755
1756 write!(
1769 self.out,
1770 "(_buffer_sizes.{member}",
1771 member = ArraySizeMember(handle),
1772 )?;
1773 if let crate::TypeInner::BindingArray { .. } = context.module.types[global.ty].inner {
1774 let Some(array_index) = context.binding_array_index_from_chain(chain_expr, handle)
1775 else {
1776 return Err(Error::GenericValidation(
1777 "Could not find binding_array index for buffer size".into(),
1778 ));
1779 };
1780 write!(self.out, "[")?;
1781 match array_index {
1782 index::GuardedIndex::Expression(expr) => {
1783 write!(self.out, "unsigned(")?;
1784 self.put_expression(expr, context, true)?;
1785 write!(self.out, ")")?;
1786 }
1787 index::GuardedIndex::Known(i) => {
1788 write!(self.out, "{i}u")?;
1789 }
1790 }
1791 write!(self.out, "]")?;
1792 }
1793 write!(
1794 self.out,
1795 " - {offset} - {size}) / {stride}",
1796 offset = offset,
1797 size = size,
1798 stride = stride,
1799 )?;
1800 Ok(())
1801 }
1802
1803 fn put_dot_product<T: Copy>(
1808 &mut self,
1809 arg: T,
1810 arg1: T,
1811 size: usize,
1812 extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1813 ) -> BackendResult {
1814 write!(self.out, "(")?;
1817
1818 for index in 0..size {
1820 write!(self.out, " + ")?;
1823 extractor(self, arg, index)?;
1824 write!(self.out, " * ")?;
1825 extractor(self, arg1, index)?;
1826 }
1827
1828 write!(self.out, ")")?;
1829 Ok(())
1830 }
1831
1832 fn put_pack4x8(
1834 &mut self,
1835 arg: Handle<crate::Expression>,
1836 context: &ExpressionContext<'_>,
1837 was_signed: bool,
1838 clamp_bounds: Option<(&str, &str)>,
1839 ) -> Result<(), Error> {
1840 let write_arg = |this: &mut Self| -> BackendResult {
1841 if let Some((min, max)) = clamp_bounds {
1842 write!(this.out, "{NAMESPACE}::clamp(")?;
1844 this.put_expression(arg, context, true)?;
1845 write!(this.out, ", {min}, {max})")?;
1846 } else {
1847 this.put_expression(arg, context, true)?;
1848 }
1849 Ok(())
1850 };
1851
1852 if context.lang_version >= (2, 1) {
1853 let packed_type = if was_signed {
1854 "packed_char4"
1855 } else {
1856 "packed_uchar4"
1857 };
1858 write!(self.out, "as_type<uint>({packed_type}(")?;
1860 write_arg(self)?;
1861 write!(self.out, "))")?;
1862 } else {
1863 if was_signed {
1865 write!(self.out, "uint(")?;
1866 }
1867 write!(self.out, "(")?;
1868 write_arg(self)?;
1869 write!(self.out, "[0] & 0xFF) | ((")?;
1870 write_arg(self)?;
1871 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1872 write_arg(self)?;
1873 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1874 write_arg(self)?;
1875 write!(self.out, "[3] & 0xFF) << 24)")?;
1876 if was_signed {
1877 write!(self.out, ")")?;
1878 }
1879 }
1880
1881 Ok(())
1882 }
1883
1884 fn put_isign(
1887 &mut self,
1888 arg: Handle<crate::Expression>,
1889 context: &ExpressionContext,
1890 ) -> BackendResult {
1891 write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1892 let scalar = context
1893 .resolve_type(arg)
1894 .scalar()
1895 .expect("put_isign should only be called for args which have an integer scalar type")
1896 .to_msl_name();
1897 match context.resolve_type(arg) {
1898 &crate::TypeInner::Vector { size, .. } => {
1899 let size = common::vector_size_str(size);
1900 write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1901 }
1902 _ => {
1903 write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1904 }
1905 }
1906 write!(self.out, ", (")?;
1907 self.put_expression(arg, context, true)?;
1908 write!(self.out, " > 0)), {scalar}(0), (")?;
1909 self.put_expression(arg, context, true)?;
1910 write!(self.out, " == 0))")?;
1911 Ok(())
1912 }
1913
1914 pub(super) fn put_const_expression(
1915 &mut self,
1916 expr_handle: Handle<crate::Expression>,
1917 module: &crate::Module,
1918 mod_info: &valid::ModuleInfo,
1919 arena: &crate::Arena<crate::Expression>,
1920 ) -> BackendResult {
1921 self.put_possibly_const_expression(
1922 expr_handle,
1923 arena,
1924 module,
1925 mod_info,
1926 &(module, mod_info),
1927 |&(_, mod_info), expr| &mod_info[expr],
1928 |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1929 )
1930 }
1931
1932 fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1933 match literal {
1934 crate::Literal::F64(_) => {
1935 return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1936 }
1937 crate::Literal::F16(value) => {
1938 if value.is_infinite() {
1939 let sign = if value.is_sign_negative() { "-" } else { "" };
1940 write!(self.out, "{sign}INFINITY")?;
1941 } else if value.is_nan() {
1942 write!(self.out, "NAN")?;
1943 } else {
1944 let suffix = if value.fract() == f16::from_f32(0.0) {
1945 ".0h"
1946 } else {
1947 "h"
1948 };
1949 write!(self.out, "{value}{suffix}")?;
1950 }
1951 }
1952 crate::Literal::F32(value) => {
1953 if value.is_infinite() {
1954 let sign = if value.is_sign_negative() { "-" } else { "" };
1955 write!(self.out, "{sign}INFINITY")?;
1956 } else if value.is_nan() {
1957 write!(self.out, "NAN")?;
1958 } else {
1959 let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1960 write!(self.out, "{value}{suffix}")?;
1961 }
1962 }
1963 crate::Literal::U16(value) => {
1964 write!(self.out, "static_cast<ushort>({value})")?;
1965 }
1966 crate::Literal::I16(value) => {
1967 write!(self.out, "static_cast<short>({value})")?;
1968 }
1969 crate::Literal::U32(value) => {
1970 write!(self.out, "{value}u")?;
1971 }
1972 crate::Literal::I32(value) => {
1973 if value == i32::MIN {
1978 write!(self.out, "({} - 1)", value + 1)?;
1979 } else {
1980 write!(self.out, "{value}")?;
1981 }
1982 }
1983 crate::Literal::U64(value) => {
1984 write!(self.out, "{value}uL")?;
1985 }
1986 crate::Literal::I64(value) => {
1987 if value == i64::MIN {
1994 write!(self.out, "({}L - 1L)", value + 1)?;
1995 } else {
1996 write!(self.out, "{value}L")?;
1997 }
1998 }
1999 crate::Literal::Bool(value) => {
2000 write!(self.out, "{value}")?;
2001 }
2002 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2003 return Err(Error::GenericValidation(
2004 "Unsupported abstract literal".into(),
2005 ));
2006 }
2007 }
2008 Ok(())
2009 }
2010
/// Writes an expression whose form is shared between function bodies and
/// constant (global) expressions: `Literal`, `Constant`, `ZeroValue`,
/// `Compose` and `Splat`. Any other expression kind yields `Error::Override`.
///
/// - `expressions`: the arena that contains `expr_handle`.
/// - `ctx`: caller-specific context, passed through unchanged to
///   `get_expr_ty` and `put_expression`.
/// - `get_expr_ty`: resolves the type of an operand (used for `Splat`).
/// - `put_expression`: writes an operand (the components of `Compose` and
///   the value of `Splat`).
///
/// # Errors
/// - Errors from `put_literal` (e.g. `f64` or abstract literals).
/// - `Error::UnsupportedCompose` when composing a type other than a
///   scalar, vector, matrix, array or struct.
/// - `Error::GenericValidation` when a splat value is not a scalar.
/// - `Error::Override` for any other expression kind.
#[allow(clippy::too_many_arguments)]
fn put_possibly_const_expression<C, I, E>(
    &mut self,
    expr_handle: Handle<crate::Expression>,
    expressions: &crate::Arena<crate::Expression>,
    module: &crate::Module,
    mod_info: &valid::ModuleInfo,
    ctx: &C,
    get_expr_ty: I,
    put_expression: E,
) -> BackendResult
where
    I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
    E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
{
    match expressions[expr_handle] {
        crate::Expression::Literal(literal) => {
            self.put_literal(literal)?;
        }
        crate::Expression::Constant(handle) => {
            let constant = &module.constants[handle];
            // Named constants are referenced by their declared name; unnamed
            // ones are inlined by writing their initializer from the
            // module's global expression arena.
            if constant.name.is_some() {
                write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
            } else {
                self.put_const_expression(
                    constant.init,
                    module,
                    mod_info,
                    &module.global_expressions,
                )?;
            }
        }
        crate::Expression::ZeroValue(ty) => {
            // Zero values are written as value-initialization: `Type {}`.
            let ty_name = TypeContext {
                handle: ty,
                gctx: module.to_ctx(),
                names: &self.names,
                access: crate::StorageAccess::empty(),
                first_time: false,
            };
            write!(self.out, "{ty_name} {{}}")?;
        }
        crate::Expression::Compose { ty, ref components } => {
            let ty_name = TypeContext {
                handle: ty,
                gctx: module.to_ctx(),
                names: &self.names,
                access: crate::StorageAccess::empty(),
                first_time: false,
            };
            write!(self.out, "{ty_name}")?;
            match module.types[ty].inner {
                // Scalars, vectors and matrices use constructor-call syntax,
                // written by `put_call_parameters_impl` (presumably as
                // `(c0, c1, ...)`; confirm against that helper).
                crate::TypeInner::Scalar(_)
                | crate::TypeInner::Vector { .. }
                | crate::TypeInner::Matrix { .. } => {
                    self.put_call_parameters_impl(
                        components.iter().copied(),
                        ctx,
                        put_expression,
                    )?;
                }
                // Arrays get doubled braces: ` {{c0, c1, ...}}`.
                // NOTE(review): presumably the outer pair initializes the
                // wrapper struct this backend emits around arrays (cf.
                // `WRAPPED_ARRAY_FIELD`) and the inner pair its array
                // field; confirm.
                crate::TypeInner::Array { .. } => {
                    write!(self.out, " {{{{")?;
                    for (index, &component) in components.iter().enumerate() {
                        if index != 0 {
                            write!(self.out, ", ")?;
                        }
                        put_expression(self, ctx, component)?;
                    }
                    write!(self.out, "}}}}")?;
                }
                // Structs use brace initialization: ` {m0, m1, ...}`.
                crate::TypeInner::Struct { .. } => {
                    write!(self.out, " {{")?;
                    for (index, &component) in components.iter().enumerate() {
                        if index != 0 {
                            write!(self.out, ", ")?;
                        }
                        // When `(ty, index)` is recorded in
                        // `struct_member_pads`, an empty `{}` initializer is
                        // written before this component. Presumably it fills
                        // a padding member the backend inserted ahead of the
                        // member; verify against where the pads are recorded.
                        if self.struct_member_pads.contains(&(ty, index as u32)) {
                            write!(self.out, "{{}}, ")?;
                        }
                        put_expression(self, ctx, component)?;
                    }
                    write!(self.out, "}}")?;
                }
                _ => return Err(Error::UnsupportedCompose(ty)),
            }
        }
        crate::Expression::Splat { size, value } => {
            // A splat is written as a vector constructor with one scalar
            // argument, e.g. `float4(v)`.
            let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
                crate::TypeInner::Scalar(scalar) => scalar,
                ref ty => {
                    return Err(Error::GenericValidation(format!(
                        "Expected splat value type must be a scalar, got {ty:?}",
                    )))
                }
            };
            put_numeric_type(&mut self.out, scalar, &[size])?;
            write!(self.out, "(")?;
            put_expression(self, ctx, value)?;
            write!(self.out, ")")?;
        }
        _ => {
            return Err(Error::Override);
        }
    }

    Ok(())
}
2122
2123 pub(super) fn put_expression(
2135 &mut self,
2136 expr_handle: Handle<crate::Expression>,
2137 context: &ExpressionContext,
2138 is_scoped: bool,
2139 ) -> BackendResult {
2140 if let Some(name) = self.named_expressions.get(&expr_handle) {
2141 write!(self.out, "{name}")?;
2142 return Ok(());
2143 }
2144
2145 let expression = &context.function.expressions[expr_handle];
2146 match *expression {
2147 crate::Expression::Literal(_)
2148 | crate::Expression::Constant(_)
2149 | crate::Expression::ZeroValue(_)
2150 | crate::Expression::Compose { .. }
2151 | crate::Expression::Splat { .. } => {
2152 self.put_possibly_const_expression(
2153 expr_handle,
2154 &context.function.expressions,
2155 context.module,
2156 context.mod_info,
2157 context,
2158 |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
2159 |writer, context, expr| writer.put_expression(expr, context, true),
2160 )?;
2161 }
2162 crate::Expression::Override(_) => return Err(Error::Override),
2163 crate::Expression::Access { base, .. }
2164 | crate::Expression::AccessIndex { base, .. } => {
2165 let policy = context.choose_bounds_check_policy(base);
2171 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
2172 && self.put_bounds_checks(
2173 expr_handle,
2174 context,
2175 back::Level(0),
2176 if is_scoped { "" } else { "(" },
2177 )?
2178 {
2179 write!(self.out, " ? ")?;
2180 self.put_access_chain(expr_handle, policy, context)?;
2181 write!(self.out, " : ")?;
2182
2183 if context.resolve_type(base).pointer_space().is_some() {
2184 let result_ty = context.info[expr_handle]
2188 .ty
2189 .inner_with(&context.module.types)
2190 .pointer_base_type();
2191 let result_ty_handle = match result_ty {
2192 Some(TypeResolution::Handle(handle)) => handle,
2193 Some(TypeResolution::Value(_)) => {
2194 unreachable!(
2202 "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
2203 );
2204 }
2205 None => {
2206 unreachable!(
2207 "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
2208 )
2209 }
2210 };
2211 let name_key =
2212 NameKey::oob_local_for_type(context.origin, result_ty_handle);
2213 self.out.write_str(&self.names[&name_key])?;
2214 } else {
2215 write!(self.out, "DefaultConstructible()")?;
2216 }
2217
2218 if !is_scoped {
2219 write!(self.out, ")")?;
2220 }
2221 } else {
2222 self.put_access_chain(expr_handle, policy, context)?;
2223 }
2224 }
2225 crate::Expression::Swizzle {
2226 size,
2227 vector,
2228 pattern,
2229 } => {
2230 self.put_wrapped_expression_for_packed_vec3_access(
2231 vector,
2232 context,
2233 false,
2234 &Self::put_expression,
2235 )?;
2236 write!(self.out, ".")?;
2237 for &sc in pattern[..size as usize].iter() {
2238 write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2239 }
2240 }
2241 crate::Expression::FunctionArgument(index) => {
2242 let name_key = match context.origin {
2243 FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2244 FunctionOrigin::EntryPoint(ep_index) => {
2245 NameKey::EntryPointArgument(ep_index, index)
2246 }
2247 };
2248 let name = &self.names[&name_key];
2249 write!(self.out, "{name}")?;
2250 }
2251 crate::Expression::GlobalVariable(handle) => {
2252 let name = &self.names[&NameKey::GlobalVariable(handle)];
2253 write!(self.out, "{name}")?;
2254 }
2255 crate::Expression::LocalVariable(handle) => {
2256 let name_key = NameKey::local(context.origin, handle);
2257 let name = &self.names[&name_key];
2258 write!(self.out, "{name}")?;
2259 }
2260 crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2261 crate::Expression::ImageSample {
2262 coordinate,
2263 image,
2264 sampler,
2265 clamp_to_edge: true,
2266 gather: None,
2267 array_index: None,
2268 offset: None,
2269 level: crate::SampleLevel::Zero,
2270 depth_ref: None,
2271 } => {
2272 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2273 self.put_expression(image, context, true)?;
2274 write!(self.out, ", ")?;
2275 self.put_expression(sampler, context, true)?;
2276 write!(self.out, ", ")?;
2277 self.put_expression(coordinate, context, true)?;
2278 write!(self.out, ")")?;
2279 }
2280 crate::Expression::ImageSample {
2281 image,
2282 sampler,
2283 gather,
2284 coordinate,
2285 array_index,
2286 offset,
2287 level,
2288 depth_ref,
2289 clamp_to_edge,
2290 } => {
2291 if clamp_to_edge {
2292 return Err(Error::GenericValidation(
2293 "ImageSample::clamp_to_edge should have been validated out".to_string(),
2294 ));
2295 }
2296
2297 let main_op = match gather {
2298 Some(_) => "gather",
2299 None => "sample",
2300 };
2301 let comparison_op = match depth_ref {
2302 Some(_) => "_compare",
2303 None => "",
2304 };
2305 self.put_expression(image, context, false)?;
2306 write!(self.out, ".{main_op}{comparison_op}(")?;
2307 self.put_expression(sampler, context, true)?;
2308 write!(self.out, ", ")?;
2309 self.put_expression(coordinate, context, true)?;
2310 if let Some(expr) = array_index {
2311 write!(self.out, ", ")?;
2312 self.put_expression(expr, context, true)?;
2313 }
2314 if let Some(dref) = depth_ref {
2315 write!(self.out, ", ")?;
2316 self.put_expression(dref, context, true)?;
2317 }
2318
2319 self.put_image_sample_level(image, level, context)?;
2320
2321 if let Some(offset) = offset {
2322 write!(self.out, ", ")?;
2323 self.put_expression(offset, context, true)?;
2324 }
2325
2326 match gather {
2327 None | Some(crate::SwizzleComponent::X) => {}
2328 Some(component) => {
2329 let is_cube_map = match *context.resolve_type(image) {
2330 crate::TypeInner::Image {
2331 dim: crate::ImageDimension::Cube,
2332 ..
2333 } => true,
2334 _ => false,
2335 };
2336 if offset.is_none() && !is_cube_map {
2339 write!(self.out, ", {NAMESPACE}::int2(0)")?;
2340 }
2341 let letter = back::COMPONENTS[component as usize];
2342 write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2343 }
2344 }
2345 write!(self.out, ")")?;
2346 }
2347 crate::Expression::ImageLoad {
2348 image,
2349 coordinate,
2350 array_index,
2351 sample,
2352 level,
2353 } => {
2354 let address = TexelAddress {
2355 coordinate,
2356 array_index,
2357 sample,
2358 level: level.map(LevelOfDetail::Direct),
2359 };
2360 self.put_image_load(expr_handle, image, address, context)?;
2361 }
2362 crate::Expression::ImageQuery { image, query } => match query {
2365 crate::ImageQuery::Size { level } => {
2366 self.put_image_size_query(
2367 image,
2368 level.map(LevelOfDetail::Direct),
2369 crate::ScalarKind::Uint,
2370 context,
2371 )?;
2372 }
2373 crate::ImageQuery::NumLevels => {
2374 self.put_expression(image, context, false)?;
2375 write!(self.out, ".get_num_mip_levels()")?;
2376 }
2377 crate::ImageQuery::NumLayers => {
2378 self.put_expression(image, context, false)?;
2379 write!(self.out, ".get_array_size()")?;
2380 }
2381 crate::ImageQuery::NumSamples => {
2382 self.put_expression(image, context, false)?;
2383 write!(self.out, ".get_num_samples()")?;
2384 }
2385 },
2386 crate::Expression::Unary { op, expr } => {
2387 let op_str = match op {
2388 crate::UnaryOperator::Negate => {
2389 match context.resolve_type(expr).scalar_kind() {
2390 Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2391 _ => "-",
2392 }
2393 }
2394 crate::UnaryOperator::LogicalNot => "!",
2395 crate::UnaryOperator::BitwiseNot => "~",
2396 };
2397 write!(self.out, "{op_str}(")?;
2398 self.put_expression(expr, context, false)?;
2399 write!(self.out, ")")?;
2400 }
2401 crate::Expression::Binary { op, left, right } => {
2402 let kind = context
2403 .resolve_type(left)
2404 .scalar_kind()
2405 .ok_or(Error::UnsupportedBinaryOp(op))?;
2406
2407 if op == crate::BinaryOperator::Divide
2408 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2409 && context.emit_int_div_checks
2410 {
2411 write!(self.out, "{DIV_FUNCTION}(")?;
2412 self.put_expression(left, context, true)?;
2413 write!(self.out, ", ")?;
2414 self.put_expression(right, context, true)?;
2415 write!(self.out, ")")?;
2416 } else if op == crate::BinaryOperator::Modulo
2417 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2418 && context.emit_int_div_checks
2419 {
2420 write!(self.out, "{MOD_FUNCTION}(")?;
2421 self.put_expression(left, context, true)?;
2422 write!(self.out, ", ")?;
2423 self.put_expression(right, context, true)?;
2424 write!(self.out, ")")?;
2425 } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2426 write!(self.out, "{NAMESPACE}::fmod(")?;
2431 self.put_expression(left, context, true)?;
2432 write!(self.out, ", ")?;
2433 self.put_expression(right, context, true)?;
2434 write!(self.out, ")")?;
2435 } else if (op == crate::BinaryOperator::Add
2436 || op == crate::BinaryOperator::Subtract
2437 || op == crate::BinaryOperator::Multiply)
2438 && kind == crate::ScalarKind::Sint
2439 {
2440 let to_unsigned = |ty: &crate::TypeInner| match *ty {
2441 crate::TypeInner::Scalar(scalar) => {
2442 Ok(crate::TypeInner::Scalar(crate::Scalar {
2443 kind: crate::ScalarKind::Uint,
2444 ..scalar
2445 }))
2446 }
2447 crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2448 size,
2449 scalar: crate::Scalar {
2450 kind: crate::ScalarKind::Uint,
2451 ..scalar
2452 },
2453 }),
2454 _ => Err(Error::UnsupportedBitCast(ty.clone())),
2455 };
2456
2457 self.put_bitcasted_expression(
2462 context.resolve_type(expr_handle),
2463 expr_handle,
2464 context,
2465 &|writer, context, is_scoped| {
2466 writer.put_binop(
2467 op,
2468 left,
2469 right,
2470 context,
2471 is_scoped,
2472 &|writer, expr, context, _is_scoped| {
2473 writer.put_bitcasted_expression(
2474 &to_unsigned(context.resolve_type(expr))?,
2475 expr,
2476 context,
2477 &|writer, context, is_scoped| {
2478 writer.put_expression(expr, context, is_scoped)
2479 },
2480 )
2481 },
2482 )
2483 },
2484 )?;
2485 } else {
2486 self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2487 }
2488 }
2489 crate::Expression::Select {
2490 condition,
2491 accept,
2492 reject,
2493 } => match *context.resolve_type(condition) {
2494 crate::TypeInner::Scalar(crate::Scalar {
2495 kind: crate::ScalarKind::Bool,
2496 ..
2497 }) => {
2498 if !is_scoped {
2499 write!(self.out, "(")?;
2500 }
2501 self.put_expression(condition, context, false)?;
2502 write!(self.out, " ? ")?;
2503 self.put_expression(accept, context, false)?;
2504 write!(self.out, " : ")?;
2505 self.put_expression(reject, context, false)?;
2506 if !is_scoped {
2507 write!(self.out, ")")?;
2508 }
2509 }
2510 crate::TypeInner::Vector {
2511 scalar:
2512 crate::Scalar {
2513 kind: crate::ScalarKind::Bool,
2514 ..
2515 },
2516 ..
2517 } => {
2518 write!(self.out, "{NAMESPACE}::select(")?;
2519 self.put_expression(reject, context, true)?;
2520 write!(self.out, ", ")?;
2521 self.put_expression(accept, context, true)?;
2522 write!(self.out, ", ")?;
2523 self.put_expression(condition, context, true)?;
2524 write!(self.out, ")")?;
2525 }
2526 ref ty => {
2527 return Err(Error::GenericValidation(format!(
2528 "Expected select condition to be a non-bool type, got {ty:?}",
2529 )))
2530 }
2531 },
2532 crate::Expression::Derivative { axis, expr, .. } => {
2533 use crate::DerivativeAxis as Axis;
2534 let op = match axis {
2535 Axis::X => "dfdx",
2536 Axis::Y => "dfdy",
2537 Axis::Width => "fwidth",
2538 };
2539 write!(self.out, "{NAMESPACE}::{op}")?;
2540 self.put_call_parameters(iter::once(expr), context)?;
2541 }
2542 crate::Expression::Relational { fun, argument } => {
2543 let op = match fun {
2544 crate::RelationalFunction::Any => "any",
2545 crate::RelationalFunction::All => "all",
2546 crate::RelationalFunction::IsNan => "isnan",
2547 crate::RelationalFunction::IsInf => "isinf",
2548 };
2549 write!(self.out, "{NAMESPACE}::{op}")?;
2550 self.put_call_parameters(iter::once(argument), context)?;
2551 }
2552 crate::Expression::Math {
2553 fun,
2554 arg,
2555 arg1,
2556 arg2,
2557 arg3,
2558 } => {
2559 use crate::MathFunction as Mf;
2560
2561 let arg_type = context.resolve_type(arg);
2562 let scalar_argument = match arg_type {
2563 &crate::TypeInner::Scalar(_) => true,
2564 _ => false,
2565 };
2566
2567 let fun_name = match fun {
2568 Mf::Abs => "abs",
2570 Mf::Min => "min",
2571 Mf::Max => "max",
2572 Mf::Clamp => "clamp",
2573 Mf::Saturate => "saturate",
2574 Mf::Cos => "cos",
2576 Mf::Cosh => "cosh",
2577 Mf::Sin => "sin",
2578 Mf::Sinh => "sinh",
2579 Mf::Tan => "tan",
2580 Mf::Tanh => "tanh",
2581 Mf::Acos => "acos",
2582 Mf::Asin => "asin",
2583 Mf::Atan => "atan",
2584 Mf::Atan2 => "atan2",
2585 Mf::Asinh => "asinh",
2586 Mf::Acosh => "acosh",
2587 Mf::Atanh => "atanh",
2588 Mf::Radians => "",
2589 Mf::Degrees => "",
2590 Mf::Ceil => "ceil",
2592 Mf::Floor => "floor",
2593 Mf::Round => "rint",
2594 Mf::Fract => "fract",
2595 Mf::Trunc => "trunc",
2596 Mf::Modf => MODF_FUNCTION,
2597 Mf::Frexp => FREXP_FUNCTION,
2598 Mf::Ldexp => "ldexp",
2599 Mf::Exp => "exp",
2601 Mf::Exp2 => "exp2",
2602 Mf::Log => "log",
2603 Mf::Log2 => "log2",
2604 Mf::Pow => "pow",
2605 Mf::Dot => match *context.resolve_type(arg) {
2607 crate::TypeInner::Vector {
2608 scalar:
2609 crate::Scalar {
2610 kind: crate::ScalarKind::Float,
2612 ..
2613 },
2614 ..
2615 } => "dot",
2616 crate::TypeInner::Vector {
2617 size,
2618 scalar:
2619 scalar @ crate::Scalar {
2620 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2621 ..
2622 },
2623 } => {
2624 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2626 write!(self.out, "{fun_name}(")?;
2627 self.put_expression(arg, context, true)?;
2628 write!(self.out, ", ")?;
2629 self.put_expression(arg1.unwrap(), context, true)?;
2630 write!(self.out, ")")?;
2631 return Ok(());
2632 }
2633 _ => unreachable!(
2634 "Correct TypeInner for dot product should be already validated"
2635 ),
2636 },
2637 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2638 if context.lang_version >= (2, 1) {
2639 let packed_type = match fun {
2643 Mf::Dot4I8Packed => "packed_char4",
2644 Mf::Dot4U8Packed => "packed_uchar4",
2645 _ => unreachable!(),
2646 };
2647
2648 return self.put_dot_product(
2649 Reinterpreted::new(packed_type, arg),
2650 Reinterpreted::new(packed_type, arg1.unwrap()),
2651 4,
2652 |writer, arg, index| {
2653 write!(writer.out, "{arg}[{index}]")?;
2656 Ok(())
2657 },
2658 );
2659 } else {
2660 let conversion = match fun {
2664 Mf::Dot4I8Packed => "int",
2665 Mf::Dot4U8Packed => "",
2666 _ => unreachable!(),
2667 };
2668
2669 return self.put_dot_product(
2670 arg,
2671 arg1.unwrap(),
2672 4,
2673 |writer, arg, index| {
2674 write!(writer.out, "({conversion}(")?;
2675 writer.put_expression(arg, context, true)?;
2676 if index == 3 {
2677 write!(writer.out, ") >> 24)")?;
2678 } else {
2679 write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2680 }
2681 Ok(())
2682 },
2683 );
2684 }
2685 }
2686 Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2687 Mf::Cross => "cross",
2688 Mf::Distance => "distance",
2689 Mf::Length if scalar_argument => "abs",
2690 Mf::Length => "length",
2691 Mf::Normalize => "normalize",
2692 Mf::FaceForward => "faceforward",
2693 Mf::Reflect => "reflect",
2694 Mf::Refract => "refract",
2695 Mf::Sign => match arg_type.scalar_kind() {
2697 Some(crate::ScalarKind::Sint) => {
2698 return self.put_isign(arg, context);
2699 }
2700 _ => "sign",
2701 },
2702 Mf::Fma => "fma",
2703 Mf::Mix => "mix",
2704 Mf::Step => "step",
2705 Mf::SmoothStep => "smoothstep",
2706 Mf::Sqrt => "sqrt",
2707 Mf::InverseSqrt => "rsqrt",
2708 Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2709 Mf::Transpose => "transpose",
2710 Mf::Determinant => "determinant",
2711 Mf::QuantizeToF16 => "",
2712 Mf::CountTrailingZeros => "ctz",
2714 Mf::CountLeadingZeros => "clz",
2715 Mf::CountOneBits => "popcount",
2716 Mf::ReverseBits => "reverse_bits",
2717 Mf::ExtractBits => "",
2718 Mf::InsertBits => "",
2719 Mf::FirstTrailingBit => "",
2720 Mf::FirstLeadingBit => "",
2721 Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2723 Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2724 Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2725 Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2726 Mf::Pack2x16float => "",
2727 Mf::Pack4xI8 => "",
2728 Mf::Pack4xU8 => "",
2729 Mf::Pack4xI8Clamp => "",
2730 Mf::Pack4xU8Clamp => "",
2731 Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2733 Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2734 Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2735 Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2736 Mf::Unpack2x16float => "",
2737 Mf::Unpack4xI8 => "",
2738 Mf::Unpack4xU8 => "",
2739 };
2740
2741 match fun {
2742 Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2743 if context.lang_version < (1, 2) {
2752 return Err(Error::UnsupportedFunction(fun_name.to_string()));
2753 }
2754 }
2755 _ => {}
2756 }
2757
2758 match fun {
2759 Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2760 write!(self.out, "{ABS_FUNCTION}(")?;
2761 self.put_expression(arg, context, true)?;
2762 write!(self.out, ")")?;
2763 }
2764 Mf::Distance if scalar_argument => {
2765 write!(self.out, "{NAMESPACE}::abs(")?;
2766 self.put_expression(arg, context, false)?;
2767 write!(self.out, " - ")?;
2768 self.put_expression(arg1.unwrap(), context, false)?;
2769 write!(self.out, ")")?;
2770 }
2771 Mf::FirstTrailingBit => {
2772 let scalar = context.resolve_type(arg).scalar().unwrap();
2773 let constant = scalar.width * 8 + 1;
2774
2775 write!(self.out, "((({NAMESPACE}::ctz(")?;
2776 self.put_expression(arg, context, true)?;
2777 write!(self.out, ") + 1) % {constant}) - 1)")?;
2778 }
2779 Mf::FirstLeadingBit => {
2780 let inner = context.resolve_type(arg);
2781 let scalar = inner.scalar().unwrap();
2782 let constant = scalar.width * 8 - 1;
2783
2784 write!(
2785 self.out,
2786 "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2787 )?;
2788
2789 if scalar.kind == crate::ScalarKind::Sint {
2790 write!(self.out, "{NAMESPACE}::select(")?;
2791 self.put_expression(arg, context, true)?;
2792 write!(self.out, ", ~")?;
2793 self.put_expression(arg, context, true)?;
2794 write!(self.out, ", ")?;
2795 self.put_expression(arg, context, true)?;
2796 write!(self.out, " < 0)")?;
2797 } else {
2798 self.put_expression(arg, context, true)?;
2799 }
2800
2801 write!(self.out, "), ")?;
2802
2803 match *inner {
2805 crate::TypeInner::Vector { size, scalar } => {
2806 let size = common::vector_size_str(size);
2807 let name = scalar.to_msl_name();
2808 write!(self.out, "{name}{size}")?;
2809 }
2810 crate::TypeInner::Scalar(scalar) => {
2811 let name = scalar.to_msl_name();
2812 write!(self.out, "{name}")?;
2813 }
2814 _ => (),
2815 }
2816
2817 write!(self.out, "(-1), ")?;
2818 self.put_expression(arg, context, true)?;
2819 write!(self.out, " == 0")?;
2820 if scalar.kind == crate::ScalarKind::Sint {
2821 write!(self.out, " || ")?;
2822 self.put_expression(arg, context, true)?;
2823 write!(self.out, " == -1")?;
2824 }
2825 write!(self.out, ")")?;
2826 }
2827 Mf::Unpack2x16float => {
2828 write!(self.out, "float2(as_type<half2>(")?;
2829 self.put_expression(arg, context, false)?;
2830 write!(self.out, "))")?;
2831 }
2832 Mf::Pack2x16float => {
2833 write!(self.out, "as_type<uint>(half2(")?;
2834 self.put_expression(arg, context, false)?;
2835 write!(self.out, "))")?;
2836 }
2837 Mf::ExtractBits => {
2838 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2855
2856 write!(self.out, "{NAMESPACE}::extract_bits(")?;
2857 self.put_expression(arg, context, true)?;
2858 write!(self.out, ", {NAMESPACE}::min(")?;
2859 self.put_expression(arg1.unwrap(), context, true)?;
2860 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2861 self.put_expression(arg2.unwrap(), context, true)?;
2862 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2863 self.put_expression(arg1.unwrap(), context, true)?;
2864 write!(self.out, ", {scalar_bits}u)))")?;
2865 }
2866 Mf::InsertBits => {
2867 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2872
2873 write!(self.out, "{NAMESPACE}::insert_bits(")?;
2874 self.put_expression(arg, context, true)?;
2875 write!(self.out, ", ")?;
2876 self.put_expression(arg1.unwrap(), context, true)?;
2877 write!(self.out, ", {NAMESPACE}::min(")?;
2878 self.put_expression(arg2.unwrap(), context, true)?;
2879 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2880 self.put_expression(arg3.unwrap(), context, true)?;
2881 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2882 self.put_expression(arg2.unwrap(), context, true)?;
2883 write!(self.out, ", {scalar_bits}u)))")?;
2884 }
2885 Mf::Radians => {
2886 write!(self.out, "((")?;
2887 self.put_expression(arg, context, false)?;
2888 write!(self.out, ") * 0.017453292519943295474)")?;
2889 }
2890 Mf::Degrees => {
2891 write!(self.out, "((")?;
2892 self.put_expression(arg, context, false)?;
2893 write!(self.out, ") * 57.295779513082322865)")?;
2894 }
2895 Mf::Modf | Mf::Frexp => {
2896 write!(self.out, "{fun_name}")?;
2897 self.put_call_parameters(iter::once(arg), context)?;
2898 }
2899 Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2900 Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2901 Mf::Pack4xI8Clamp => {
2902 self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2903 }
2904 Mf::Pack4xU8Clamp => {
2905 self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2906 }
2907 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2908 let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2909 "u"
2910 } else {
2911 ""
2912 };
2913
2914 if context.lang_version >= (2, 1) {
2915 write!(
2917 self.out,
2918 "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2919 )?;
2920 self.put_expression(arg, context, true)?;
2921 write!(self.out, "))")?;
2922 } else {
2923 write!(self.out, "({sign_prefix}int4(")?;
2925 self.put_expression(arg, context, true)?;
2926 write!(self.out, ", ")?;
2927 self.put_expression(arg, context, true)?;
2928 write!(self.out, " >> 8, ")?;
2929 self.put_expression(arg, context, true)?;
2930 write!(self.out, " >> 16, ")?;
2931 self.put_expression(arg, context, true)?;
2932 write!(self.out, " >> 24) << 24 >> 24)")?;
2933 }
2934 }
2935 Mf::QuantizeToF16 => {
2936 match *context.resolve_type(arg) {
2937 crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2938 crate::TypeInner::Vector { size, .. } => write!(
2939 self.out,
2940 "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2941 size = common::vector_size_str(size),
2942 )?,
2943 _ => unreachable!(
2944 "Correct TypeInner for QuantizeToF16 should be already validated"
2945 ),
2946 };
2947
2948 self.put_expression(arg, context, true)?;
2949 write!(self.out, "))")?;
2950 }
2951 _ => {
2952 write!(self.out, "{NAMESPACE}::{fun_name}")?;
2953 self.put_call_parameters(
2954 iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2955 context,
2956 )?;
2957 }
2958 }
2959 }
2960 crate::Expression::As {
2961 expr,
2962 kind,
2963 convert,
2964 } => match *context.resolve_type(expr) {
2965 crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2966 if src.kind == crate::ScalarKind::Float
2967 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2968 && convert.is_some()
2969 {
2970 let fun_name = match (kind, convert) {
2974 (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2975 (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2976 (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2977 (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2978 _ => unreachable!(),
2979 };
2980 write!(self.out, "{fun_name}(")?;
2981 self.put_expression(expr, context, true)?;
2982 write!(self.out, ")")?;
2983 } else {
2984 let target_scalar = crate::Scalar {
2985 kind,
2986 width: convert.unwrap_or(src.width),
2987 };
2988 let op = match convert {
2989 Some(_) => "static_cast",
2990 None => "as_type",
2991 };
2992 write!(self.out, "{op}<")?;
2993 match *context.resolve_type(expr) {
2994 crate::TypeInner::Vector { size, .. } => {
2995 put_numeric_type(&mut self.out, target_scalar, &[size])?
2996 }
2997 _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2998 };
2999 write!(self.out, ">(")?;
3000 self.put_expression(expr, context, true)?;
3001 write!(self.out, ")")?;
3002 }
3003 }
3004 crate::TypeInner::Matrix {
3005 columns,
3006 rows,
3007 scalar,
3008 } => {
3009 let target_scalar = crate::Scalar {
3010 kind,
3011 width: convert.unwrap_or(scalar.width),
3012 };
3013 put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
3014 write!(self.out, "(")?;
3015 self.put_expression(expr, context, true)?;
3016 write!(self.out, ")")?;
3017 }
3018 ref ty => {
3019 return Err(Error::GenericValidation(format!(
3020 "Unsupported type for As: {ty:?}"
3021 )))
3022 }
3023 },
3024 crate::Expression::CallResult(_)
3026 | crate::Expression::AtomicResult { .. }
3027 | crate::Expression::WorkGroupUniformLoadResult { .. }
3028 | crate::Expression::SubgroupBallotResult
3029 | crate::Expression::SubgroupOperationResult { .. }
3030 | crate::Expression::RayQueryProceedResult => {
3031 unreachable!()
3032 }
3033 crate::Expression::ArrayLength(expr) => {
3034 let global = context.function.originating_global(expr).ok_or_else(|| {
3035 Error::GenericValidation(format!(
3036 "Could not find global variable for ArrayLength operand {:?}",
3037 context.function.expressions[expr]
3038 ))
3039 })?;
3040
3041 if !is_scoped {
3042 write!(self.out, "(")?;
3043 }
3044 write!(self.out, "1 + ")?;
3045 self.put_dynamic_array_max_index(global, expr, context)?;
3046 if !is_scoped {
3047 write!(self.out, ")")?;
3048 }
3049 }
3050 crate::Expression::RayQueryVertexPositions { .. } => {
3051 unimplemented!()
3052 }
3053 crate::Expression::RayQueryGetIntersection { query, committed } => {
3054 if context.lang_version < (2, 4) {
3055 return Err(Error::UnsupportedRayTracing);
3056 }
3057
3058 let crate::Expression::LocalVariable(query_var) =
3060 context.function.expressions[query]
3061 else {
3062 unreachable!()
3063 };
3064
3065 let tracker_expr_name = format!(
3066 "{}{}",
3067 super::ray::RAY_QUERY_TRACKER_VARIABLE_PREFIX,
3068 self.names[&NameKey::local(context.origin, query_var)]
3069 );
3070
3071 write!(
3072 self.out,
3073 "{}_{committed}(",
3074 super::ray::INTERSECTION_FUNCTION_NAME
3075 )?;
3076 self.put_expression(query, context, true)?;
3077 if context.ray_query_initialization_tracking {
3078 write!(self.out, ", {tracker_expr_name}")?;
3079 }
3080 write!(self.out, ")")?;
3081 }
3082 crate::Expression::CooperativeLoad { ref data, .. } => {
3083 if context.lang_version < (2, 3) {
3084 return Err(Error::UnsupportedCooperativeMatrix);
3085 }
3086 write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
3087 write!(self.out, "&")?;
3088 self.put_access_chain(data.pointer, context.policies.index, context)?;
3089 write!(self.out, ", ")?;
3090 self.put_expression(data.stride, context, true)?;
3091 write!(self.out, ", {})", !data.row_major)?;
3098 }
3099 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
3100 if context.lang_version < (2, 3) {
3101 return Err(Error::UnsupportedCooperativeMatrix);
3102 }
3103 write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
3104 self.put_expression(a, context, true)?;
3105 write!(self.out, ", ")?;
3106 self.put_expression(b, context, true)?;
3107 write!(self.out, ", ")?;
3108 self.put_expression(c, context, true)?;
3109 write!(self.out, ")")?;
3110 }
3111 }
3112 Ok(())
3113 }
3114
3115 fn put_binop<F>(
3118 &mut self,
3119 op: crate::BinaryOperator,
3120 left: Handle<crate::Expression>,
3121 right: Handle<crate::Expression>,
3122 context: &ExpressionContext,
3123 is_scoped: bool,
3124 put_expression: &F,
3125 ) -> BackendResult
3126 where
3127 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
3128 {
3129 let op_str = back::binary_operation_str(op);
3130
3131 if !is_scoped {
3132 write!(self.out, "(")?;
3133 }
3134
3135 if op == crate::BinaryOperator::Multiply
3138 && matches!(
3139 context.resolve_type(right),
3140 &crate::TypeInner::Matrix { .. }
3141 )
3142 {
3143 self.put_wrapped_expression_for_packed_vec3_access(
3144 left,
3145 context,
3146 false,
3147 put_expression,
3148 )?;
3149 } else {
3150 put_expression(self, left, context, false)?;
3151 }
3152
3153 write!(self.out, " {op_str} ")?;
3154
3155 if op == crate::BinaryOperator::Multiply
3157 && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
3158 {
3159 self.put_wrapped_expression_for_packed_vec3_access(
3160 right,
3161 context,
3162 false,
3163 put_expression,
3164 )?;
3165 } else {
3166 put_expression(self, right, context, false)?;
3167 }
3168
3169 if !is_scoped {
3170 write!(self.out, ")")?;
3171 }
3172
3173 Ok(())
3174 }
3175
3176 fn put_wrapped_expression_for_packed_vec3_access<F>(
3178 &mut self,
3179 expr_handle: Handle<crate::Expression>,
3180 context: &ExpressionContext,
3181 is_scoped: bool,
3182 put_expression: &F,
3183 ) -> BackendResult
3184 where
3185 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
3186 {
3187 if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
3188 write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
3189 put_expression(self, expr_handle, context, is_scoped)?;
3190 write!(self.out, ")")?;
3191 } else {
3192 put_expression(self, expr_handle, context, is_scoped)?;
3193 }
3194 Ok(())
3195 }
3196
3197 fn put_bitcasted_expression<F>(
3200 &mut self,
3201 cast_to: &crate::TypeInner,
3202 inner_expr: Handle<crate::Expression>,
3203 context: &ExpressionContext,
3204 put_expression: &F,
3205 ) -> BackendResult
3206 where
3207 F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
3208 {
3209 let needs_truncation = match *cast_to {
3214 crate::TypeInner::Scalar(scalar) => scalar.width < 4,
3215 crate::TypeInner::Vector { scalar, .. } => scalar.width < 4,
3216 _ => false,
3217 };
3218
3219 write!(self.out, "as_type<")?;
3220 match *cast_to {
3221 crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
3222 crate::TypeInner::Vector { size, scalar } => {
3223 put_numeric_type(&mut self.out, scalar, &[size])?
3224 }
3225 _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
3226 };
3227 write!(self.out, ">(")?;
3228
3229 if needs_truncation {
3230 write!(self.out, "static_cast<")?;
3231 let unsigned_scalar = match *cast_to {
3233 crate::TypeInner::Scalar(scalar) => crate::Scalar {
3234 kind: crate::ScalarKind::Uint,
3235 ..scalar
3236 },
3237 crate::TypeInner::Vector { scalar, .. } => crate::Scalar {
3238 kind: crate::ScalarKind::Uint,
3239 ..scalar
3240 },
3241 _ => unreachable!(),
3242 };
3243 match *cast_to {
3244 crate::TypeInner::Scalar(_) => {
3245 put_numeric_type(&mut self.out, unsigned_scalar, &[])?
3246 }
3247 crate::TypeInner::Vector { size, .. } => {
3248 put_numeric_type(&mut self.out, unsigned_scalar, &[size])?
3249 }
3250 _ => unreachable!(),
3251 };
3252 write!(self.out, ">(")?;
3253 }
3254
3255 if let Some(scalar) = context.get_packed_vec_kind(inner_expr) {
3257 put_numeric_type(&mut self.out, scalar, &[crate::VectorSize::Tri])?;
3258 write!(self.out, "(")?;
3259 put_expression(self, context, true)?;
3260 write!(self.out, ")")?;
3261 } else {
3262 put_expression(self, context, true)?;
3263 }
3264
3265 if needs_truncation {
3266 write!(self.out, ")")?;
3267 }
3268
3269 write!(self.out, ")")?;
3270 Ok(())
3271 }
3272
3273 fn put_index(
3275 &mut self,
3276 index: index::GuardedIndex,
3277 context: &ExpressionContext,
3278 is_scoped: bool,
3279 ) -> BackendResult {
3280 match index {
3281 index::GuardedIndex::Expression(expr) => {
3282 self.put_expression(expr, context, is_scoped)?
3283 }
3284 index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3285 }
3286 Ok(())
3287 }
3288
3289 fn put_bounds_checks(
3319 &mut self,
3320 chain: Handle<crate::Expression>,
3321 context: &ExpressionContext,
3322 level: back::Level,
3323 prefix: &'static str,
3324 ) -> Result<bool, Error> {
3325 let mut check_written = false;
3326
3327 for item in context.bounds_check_iter(chain) {
3329 let BoundsCheck {
3330 base,
3331 index,
3332 length,
3333 } = item;
3334
3335 if check_written {
3336 write!(self.out, " && ")?;
3337 } else {
3338 write!(self.out, "{level}{prefix}")?;
3339 check_written = true;
3340 }
3341
3342 write!(self.out, "uint(")?;
3346 self.put_index(index, context, true)?;
3347 self.out.write_str(") < ")?;
3348 match length {
3349 index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3350 index::IndexableLength::Dynamic => {
3351 let global = context.function.originating_global(base).ok_or_else(|| {
3352 Error::GenericValidation("Could not find originating global".into())
3353 })?;
3354 if matches!(
3355 context.module.types[context.module.global_variables[global].ty].inner,
3356 crate::TypeInner::BindingArray { .. }
3357 ) {
3358 write!(
3359 self.out,
3360 "{} && _buffer_sizes.{}[",
3361 Self::binding_array_layout_count(
3362 context.module,
3363 context.pipeline_options,
3364 global,
3365 ),
3366 ArraySizeMember(global),
3367 )?;
3368 self.put_binding_array_size_member_index(index, context)?;
3369 write!(self.out, "] != 0u")?;
3370 } else {
3371 write!(self.out, "1 + ")?;
3372 self.put_dynamic_array_max_index(global, base, context)?
3373 }
3374 }
3375 }
3376 }
3377
3378 Ok(check_written)
3379 }
3380
3381 fn put_access_chain(
3401 &mut self,
3402 chain: Handle<crate::Expression>,
3403 policy: index::BoundsCheckPolicy,
3404 context: &ExpressionContext,
3405 ) -> BackendResult {
3406 match context.function.expressions[chain] {
3407 crate::Expression::Access { base, index } => {
3408 let mut base_ty = context.resolve_type(base);
3409
3410 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3412 base_ty = &context.module.types[base].inner;
3413 }
3414
3415 self.put_subscripted_access_chain(
3416 base,
3417 base_ty,
3418 index::GuardedIndex::Expression(index),
3419 policy,
3420 context,
3421 )?;
3422 }
3423 crate::Expression::AccessIndex { base, index } => {
3424 let base_resolution = &context.info[base].ty;
3425 let mut base_ty = base_resolution.inner_with(&context.module.types);
3426 let mut base_ty_handle = base_resolution.handle();
3427
3428 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3430 base_ty = &context.module.types[base].inner;
3431 base_ty_handle = Some(base);
3432 }
3433
3434 match *base_ty {
3438 crate::TypeInner::Struct { .. } => {
3439 let base_ty = base_ty_handle.unwrap();
3440 self.put_access_chain(base, policy, context)?;
3441 let name = &self.names[&NameKey::StructMember(base_ty, index)];
3442 write!(
3443 self.out,
3444 "{}{name}",
3445 if context.struct_member_needs_arrow(base, |ty| {
3446 matches!(ty, crate::TypeInner::BindingArray { .. })
3447 }) {
3448 "->"
3449 } else {
3450 "."
3451 },
3452 )?;
3453 }
3454 crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3455 self.put_access_chain(base, policy, context)?;
3456 if context.get_packed_vec_kind(base).is_some() {
3459 write!(self.out, "[{index}]")?;
3460 } else {
3461 write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3462 }
3463 }
3464 _ => {
3465 self.put_subscripted_access_chain(
3466 base,
3467 base_ty,
3468 index::GuardedIndex::Known(index),
3469 policy,
3470 context,
3471 )?;
3472 }
3473 }
3474 }
3475 _ => self.put_expression(chain, context, false)?,
3476 }
3477
3478 Ok(())
3479 }
3480
3481 fn put_subscripted_access_chain(
3498 &mut self,
3499 base: Handle<crate::Expression>,
3500 base_ty: &crate::TypeInner,
3501 index: index::GuardedIndex,
3502 policy: index::BoundsCheckPolicy,
3503 context: &ExpressionContext,
3504 ) -> BackendResult {
3505 let accessing_wrapped_array = match *base_ty {
3506 crate::TypeInner::Array {
3507 size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3508 ..
3509 } => true,
3510 _ => false,
3511 };
3512 let accessing_wrapped_binding_array =
3513 matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3514
3515 self.put_access_chain(base, policy, context)?;
3516 if accessing_wrapped_array {
3517 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3518 }
3519 write!(self.out, "[")?;
3520
3521 let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3523 context.access_needs_check(base, index)
3524 } else {
3525 None
3526 };
3527 if let Some(limit) = restriction_needed {
3528 write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3529 self.put_index(index, context, true)?;
3530 write!(self.out, "), ")?;
3531 match limit {
3532 index::IndexableLength::Known(limit) => {
3533 write!(self.out, "{}u", limit - 1)?;
3534 }
3535 index::IndexableLength::Dynamic => {
3536 let global = context.function.originating_global(base).ok_or_else(|| {
3537 Error::GenericValidation("Could not find originating global".into())
3538 })?;
3539 self.put_dynamic_array_max_index(global, base, context)?;
3540 }
3541 }
3542 write!(self.out, ")")?;
3543 } else {
3544 self.put_index(index, context, true)?;
3545 }
3546
3547 write!(self.out, "]")?;
3548
3549 if accessing_wrapped_binding_array {
3550 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3551 }
3552
3553 Ok(())
3554 }
3555
3556 fn put_load(
3557 &mut self,
3558 pointer: Handle<crate::Expression>,
3559 context: &ExpressionContext,
3560 is_scoped: bool,
3561 ) -> BackendResult {
3562 let policy = context.choose_bounds_check_policy(pointer);
3565 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3566 && self.put_bounds_checks(
3567 pointer,
3568 context,
3569 back::Level(0),
3570 if is_scoped { "" } else { "(" },
3571 )?
3572 {
3573 write!(self.out, " ? ")?;
3574 self.put_unchecked_load(pointer, policy, context)?;
3575 write!(self.out, " : DefaultConstructible()")?;
3576
3577 if !is_scoped {
3578 write!(self.out, ")")?;
3579 }
3580 } else {
3581 self.put_unchecked_load(pointer, policy, context)?;
3582 }
3583
3584 Ok(())
3585 }
3586
3587 fn put_unchecked_load(
3588 &mut self,
3589 pointer: Handle<crate::Expression>,
3590 policy: index::BoundsCheckPolicy,
3591 context: &ExpressionContext,
3592 ) -> BackendResult {
3593 let is_atomic_pointer = context
3594 .resolve_type(pointer)
3595 .is_atomic_pointer(&context.module.types);
3596
3597 if is_atomic_pointer {
3598 write!(
3599 self.out,
3600 "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3601 )?;
3602 self.put_access_chain(pointer, policy, context)?;
3603 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3604 } else {
3605 self.put_access_chain(pointer, policy, context)?;
3609 }
3610
3611 Ok(())
3612 }
3613
3614 fn put_return_value(
3615 &mut self,
3616 level: back::Level,
3617 expr_handle: Handle<crate::Expression>,
3618 result_struct: Option<&str>,
3619 context: &ExpressionContext,
3620 ) -> BackendResult {
3621 match result_struct {
3622 Some(struct_name) => {
3623 let mut has_point_size = false;
3624 let result_ty = context.function.result.as_ref().unwrap().ty;
3625 match context.module.types[result_ty].inner {
3626 crate::TypeInner::Struct { ref members, .. } => {
3627 let tmp = self.namer.call("_tmp");
3628 write!(self.out, "{level}const auto {tmp} = ")?;
3629 self.put_expression(expr_handle, context, true)?;
3630 writeln!(self.out, ";")?;
3631 write!(self.out, "{level}return {struct_name} {{")?;
3632
3633 let mut is_first = true;
3634
3635 for (index, member) in members.iter().enumerate() {
3636 if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3637 member.binding
3638 {
3639 has_point_size = true;
3640 if !context.pipeline_options.allow_and_force_point_size {
3641 continue;
3642 }
3643 }
3644
3645 let comma = if is_first { "" } else { "," };
3646 is_first = false;
3647 let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3648 if let crate::TypeInner::Array {
3652 size: crate::ArraySize::Constant(size),
3653 ..
3654 } = context.module.types[member.ty].inner
3655 {
3656 write!(self.out, "{comma} {{")?;
3657 for j in 0..size.get() {
3658 if j != 0 {
3659 write!(self.out, ",")?;
3660 }
3661 write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3662 }
3663 write!(self.out, "}}")?;
3664 } else {
3665 write!(self.out, "{comma} {tmp}.{name}")?;
3666 }
3667 }
3668 }
3669 _ => {
3670 write!(self.out, "{level}return {struct_name} {{ ")?;
3671 self.put_expression(expr_handle, context, true)?;
3672 }
3673 }
3674
3675 if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3676 let stage = context.module.entry_points[ep_index as usize].stage;
3677 if context.pipeline_options.allow_and_force_point_size
3678 && stage == crate::ShaderStage::Vertex
3679 && !has_point_size
3680 {
3681 write!(self.out, ", 1.0")?;
3683 }
3684 }
3685 write!(self.out, " }}")?;
3686 }
3687 None => {
3688 write!(self.out, "{level}return ")?;
3689 self.put_expression(expr_handle, context, true)?;
3690 }
3691 }
3692 writeln!(self.out, ";")?;
3693 Ok(())
3694 }
3695
3696 fn update_expressions_to_bake(
3701 &mut self,
3702 func: &crate::Function,
3703 info: &valid::FunctionInfo,
3704 context: &ExpressionContext,
3705 ) {
3706 use crate::Expression;
3707 self.need_bake_expressions.clear();
3708
3709 for (expr_handle, expr) in func.expressions.iter() {
3710 let expr_info = &info[expr_handle];
3713 let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3714 if min_ref_count <= expr_info.ref_count {
3715 self.need_bake_expressions.insert(expr_handle);
3716 } else {
3717 match expr_info.ty {
3718 TypeResolution::Handle(h)
3720 if Some(h) == context.module.special_types.ray_desc =>
3721 {
3722 self.need_bake_expressions.insert(expr_handle);
3723 }
3724 _ => {}
3725 }
3726 }
3727
3728 if let Expression::Math {
3729 fun,
3730 arg,
3731 arg1,
3732 arg2,
3733 ..
3734 } = *expr
3735 {
3736 match fun {
3737 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3747 self.need_bake_expressions.insert(arg);
3748 self.need_bake_expressions.insert(arg1.unwrap());
3749 }
3750 crate::MathFunction::FirstLeadingBit => {
3751 self.need_bake_expressions.insert(arg);
3752 }
3753 crate::MathFunction::Pack4xI8
3754 | crate::MathFunction::Pack4xU8
3755 | crate::MathFunction::Pack4xI8Clamp
3756 | crate::MathFunction::Pack4xU8Clamp
3757 | crate::MathFunction::Unpack4xI8
3758 | crate::MathFunction::Unpack4xU8 => {
3759 if context.lang_version < (2, 1) {
3762 self.need_bake_expressions.insert(arg);
3763 }
3764 }
3765 crate::MathFunction::ExtractBits => {
3766 self.need_bake_expressions.insert(arg1.unwrap());
3768 }
3769 crate::MathFunction::InsertBits => {
3770 self.need_bake_expressions.insert(arg2.unwrap());
3772 }
3773 crate::MathFunction::Sign => {
3774 let inner = context.resolve_type(expr_handle);
3779 if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3780 self.need_bake_expressions.insert(arg);
3781 }
3782 }
3783 _ => {}
3784 }
3785 }
3786 }
3787 }
3788
3789 pub(super) fn start_baking_expression(
3790 &mut self,
3791 handle: Handle<crate::Expression>,
3792 context: &ExpressionContext,
3793 name: &str,
3794 ) -> BackendResult {
3795 match context.info[handle].ty {
3796 TypeResolution::Handle(ty_handle) => {
3797 let ty_name = TypeContext {
3798 handle: ty_handle,
3799 gctx: context.module.to_ctx(),
3800 names: &self.names,
3801 access: crate::StorageAccess::empty(),
3802 first_time: false,
3803 };
3804 write!(self.out, "{ty_name}")?;
3805 }
3806 TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3807 put_numeric_type(&mut self.out, scalar, &[])?;
3808 }
3809 TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3810 put_numeric_type(&mut self.out, scalar, &[size])?;
3811 }
3812 TypeResolution::Value(crate::TypeInner::Matrix {
3813 columns,
3814 rows,
3815 scalar,
3816 }) => {
3817 put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3818 }
3819 TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3820 columns,
3821 rows,
3822 scalar,
3823 role: _,
3824 }) => {
3825 write!(
3826 self.out,
3827 "{}::simdgroup_{}{}x{}",
3828 NAMESPACE,
3829 scalar.to_msl_name(),
3830 columns as u32,
3831 rows as u32,
3832 )?;
3833 }
3834 TypeResolution::Value(ref other) => {
3835 log::warn!("Type {other:?} isn't a known local");
3836 return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3837 }
3838 }
3839
3840 write!(self.out, " {name} = ")?;
3842
3843 Ok(())
3844 }
3845
3846 fn put_cache_restricted_level(
3859 &mut self,
3860 load: Handle<crate::Expression>,
3861 image: Handle<crate::Expression>,
3862 mip_level: Option<Handle<crate::Expression>>,
3863 indent: back::Level,
3864 context: &StatementContext,
3865 ) -> BackendResult {
3866 let level_of_detail = match mip_level {
3869 Some(level) => level,
3870 None => return Ok(()),
3871 };
3872
3873 if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3874 || !context.expression.image_needs_lod(image)
3875 {
3876 return Ok(());
3877 }
3878
3879 write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3880 self.put_restricted_scalar_image_index(
3881 image,
3882 level_of_detail,
3883 "get_num_mip_levels",
3884 &context.expression,
3885 )?;
3886 writeln!(self.out, ";")?;
3887
3888 Ok(())
3889 }
3890
3891 fn put_casting_to_packed_chars(
3897 &mut self,
3898 fun: crate::MathFunction,
3899 arg0: Handle<crate::Expression>,
3900 arg1: Handle<crate::Expression>,
3901 indent: back::Level,
3902 context: &StatementContext<'_>,
3903 ) -> Result<(), Error> {
3904 let packed_type = match fun {
3905 crate::MathFunction::Dot4I8Packed => "packed_char4",
3906 crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3907 _ => unreachable!(),
3908 };
3909
3910 for arg in [arg0, arg1] {
3911 write!(
3912 self.out,
3913 "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3914 Reinterpreted::new(packed_type, arg)
3915 )?;
3916 self.put_expression(arg, &context.expression, true)?;
3917 writeln!(self.out, ");")?;
3918 }
3919
3920 Ok(())
3921 }
3922
    /// Write the MSL translation of `statements` at indentation `level`.
    ///
    /// For each expression covered by an `Emit`, the expression is "baked"
    /// into a local (`<type> <name> = <expr>;`) when any of these holds:
    /// - it appears in the function's `named_expressions`;
    /// - it is a guarded index;
    /// - it is listed in `self.need_bake_expressions`.
    ///
    /// Pointer-typed expressions are never baked. Two cases get helper
    /// declarations first:
    /// - `ImageLoad` may need a cached clamped mip level.
    /// - `dot4{I,U}8Packed` needs reinterpreted operands when the MSL version
    ///   is at least 2.1.
    ///
    /// Control flow, stores, calls, atomics, barriers, subgroup operations and
    /// cooperative stores are written directly. After all statements are
    /// written, the names recorded in `self.named_expressions` for this
    /// block's `Emit` ranges are removed again.
    fn put_block(
        &mut self,
        level: back::Level,
        statements: &[crate::Statement],
        context: &StatementContext,
    ) -> BackendResult {
        for statement in statements {
            log::trace!("statement[{}] {:?}", level.0, statement);
            match *statement {
                crate::Statement::Emit(ref range) => {
                    for handle in range.clone() {
                        use crate::MathFunction as Mf;

                        // Some expressions need helper declarations written
                        // before the expression itself is emitted.
                        match context.expression.function.expressions[handle] {
                            crate::Expression::ImageLoad {
                                image,
                                level: mip_level,
                                ..
                            } => {
                                self.put_cache_restricted_level(
                                    handle, image, mip_level, level, context,
                                )?;
                            }

                            crate::Expression::Math {
                                fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
                                arg,
                                arg1,
                                ..
                            } if context.expression.lang_version >= (2, 1) => {
                                // Packed dot products take two operands, so
                                // `arg1` is expected to be present here.
                                self.put_casting_to_packed_chars(
                                    fun,
                                    arg,
                                    arg1.unwrap(),
                                    level,
                                    context,
                                )?;
                            }

                            _ => (),
                        }

                        // Decide whether this expression gets a named temporary.
                        let ptr_class = context.expression.resolve_type(handle).pointer_space();
                        let expr_name = if ptr_class.is_some() {
                            // Pointer-typed expressions are never baked.
                            None
                        } else if let Some(name) =
                            context.expression.function.named_expressions.get(&handle)
                        {
                            // Named expressions: the stored name is passed
                            // through the namer to obtain the identifier used
                            // for the temporary.
                            Some(self.namer.call(name))
                        } else {
                            // Guarded indices are always baked. Presumably
                            // this ensures the index is evaluated once for both
                            // the bounds check and the access (TODO confirm).
                            let bake = if context.expression.guarded_indices.contains(handle) {
                                true
                            } else {
                                self.need_bake_expressions.contains(&handle)
                            };

                            if bake {
                                Some(Baked(handle).to_string())
                            } else {
                                None
                            }
                        };

                        if let Some(name) = expr_name {
                            write!(self.out, "{level}")?;
                            self.start_baking_expression(handle, &context.expression, &name)?;
                            self.put_expression(handle, &context.expression, true)?;
                            // Record the temporary's name; presumably consulted
                            // by `put_expression` for later uses of `handle`.
                            self.named_expressions.insert(handle, name);
                            writeln!(self.out, ";")?;
                        }
                    }
                }
                crate::Statement::Block(ref block) => {
                    // Empty nested blocks produce no output at all.
                    if !block.is_empty() {
                        writeln!(self.out, "{level}{{")?;
                        self.put_block(level.next(), block, context)?;
                        writeln!(self.out, "{level}}}")?;
                    }
                }
                crate::Statement::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    write!(self.out, "{level}if (")?;
                    self.put_expression(condition, &context.expression, true)?;
                    writeln!(self.out, ") {{")?;
                    self.put_block(level.next(), accept, context)?;
                    // The `else` branch is omitted when `reject` is empty.
                    if !reject.is_empty() {
                        writeln!(self.out, "{level}}} else {{")?;
                        self.put_block(level.next(), reject, context)?;
                    }
                    writeln!(self.out, "{level}}}")?;
                }
                crate::Statement::Switch {
                    selector,
                    ref cases,
                } => {
                    write!(self.out, "{level}switch(")?;
                    self.put_expression(selector, &context.expression, true)?;
                    writeln!(self.out, ") {{")?;
                    let lcase = level.next();
                    for case in cases.iter() {
                        match case.value {
                            crate::SwitchValue::I32(value) => {
                                write!(self.out, "{lcase}case {value}:")?;
                            }
                            crate::SwitchValue::U32(value) => {
                                write!(self.out, "{lcase}case {value}u:")?;
                            }
                            crate::SwitchValue::Default => {
                                write!(self.out, "{lcase}default:")?;
                            }
                        }

                        // An empty fall-through case is just a label with no
                        // braces, so it shares the following case's body.
                        let write_block_braces = !(case.fall_through && case.body.is_empty());
                        if write_block_braces {
                            writeln!(self.out, " {{")?;
                        } else {
                            writeln!(self.out)?;
                        }

                        self.put_block(lcase.next(), &case.body, context)?;
                        // Add an explicit `break;` unless the case falls
                        // through or its last statement is already a terminator.
                        if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
                        {
                            writeln!(self.out, "{}break;", lcase.next())?;
                        }

                        if write_block_braces {
                            writeln!(self.out, "{lcase}}}")?;
                        }
                    }
                    writeln!(self.out, "{level}}}")?;
                }
                crate::Statement::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    // The loop becomes `while(true)`. Any `continuing` block
                    // and `break_if` condition run at the top of every
                    // iteration except the first, gated by a `loop_init` flag.
                    let force_loop_bound_statements =
                        self.gen_force_bounded_loop_statements(level, context);
                    let gate_name = (!continuing.is_empty() || break_if.is_some())
                        .then(|| self.namer.call("loop_init"));

                    if let Some((ref decl, _)) = force_loop_bound_statements {
                        writeln!(self.out, "{decl}")?;
                    }
                    if let Some(ref gate_name) = gate_name {
                        writeln!(self.out, "{level}bool {gate_name} = true;")?;
                    }

                    writeln!(self.out, "{level}while(true) {{",)?;
                    if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
                        writeln!(self.out, "{break_and_inc}")?;
                    }
                    if let Some(ref gate_name) = gate_name {
                        let lif = level.next();
                        let lcontinuing = lif.next();
                        writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
                        self.put_block(lcontinuing, continuing, context)?;
                        if let Some(condition) = break_if {
                            write!(self.out, "{lcontinuing}if (")?;
                            self.put_expression(condition, &context.expression, true)?;
                            writeln!(self.out, ") {{")?;
                            writeln!(self.out, "{}break;", lcontinuing.next())?;
                            writeln!(self.out, "{lcontinuing}}}")?;
                        }
                        writeln!(self.out, "{lif}}}")?;
                        writeln!(self.out, "{lif}{gate_name} = false;")?;
                    }
                    self.put_block(level.next(), body, context)?;

                    writeln!(self.out, "{level}}}")?;
                }
                crate::Statement::Break => {
                    writeln!(self.out, "{level}break;")?;
                }
                crate::Statement::Continue => {
                    writeln!(self.out, "{level}continue;")?;
                }
                crate::Statement::Return {
                    value: Some(expr_handle),
                } => {
                    self.put_return_value(
                        level,
                        expr_handle,
                        context.result_struct,
                        &context.expression,
                    )?;
                }
                crate::Statement::Return { value: None } => {
                    writeln!(self.out, "{level}return;")?;
                }
                crate::Statement::Kill => {
                    writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
                }
                // Control and memory barriers are both written by `write_barrier`.
                crate::Statement::ControlBarrier(flags)
                | crate::Statement::MemoryBarrier(flags) => {
                    self.write_barrier(flags, level)?;
                }
                crate::Statement::Store { pointer, value } => {
                    self.put_store(pointer, value, level, context)?
                }
                crate::Statement::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let address = TexelAddress {
                        coordinate,
                        array_index,
                        sample: None,
                        level: None,
                    };
                    self.put_image_store(level, image, &address, value, context)?
                }
                crate::Statement::Call {
                    function,
                    ref arguments,
                    result,
                } => {
                    write!(self.out, "{level}")?;
                    // A used call result is bound to a baked temporary.
                    if let Some(expr) = result {
                        let name = Baked(expr).to_string();
                        self.start_baking_expression(expr, &context.expression, &name)?;
                        self.named_expressions.insert(expr, name);
                    }
                    let fun_name = &self.names[&NameKey::Function(function)];
                    write!(self.out, "{fun_name}(")?;
                    // The call's explicit arguments come first.
                    for (i, &handle) in arguments.iter().enumerate() {
                        if i != 0 {
                            write!(self.out, ", ")?;
                        }
                        self.put_expression(handle, &context.expression, true)?;
                    }
                    // Then implicit arguments. Each global used by the callee
                    // whose address space needs pass-through is appended.
                    // `_buffer_sizes` follows if any used global's type needs
                    // a host-provided byte size.
                    let mut separate = !arguments.is_empty();
                    let fun_info = &context.expression.mod_info[function];
                    let mut needs_buffer_sizes = false;
                    for (handle, var) in context.expression.module.global_variables.iter() {
                        if fun_info[handle].is_empty() {
                            continue;
                        }
                        if var.space.needs_pass_through() {
                            let name = &self.names[&NameKey::GlobalVariable(handle)];
                            if separate {
                                write!(self.out, ", ")?;
                            } else {
                                separate = true;
                            }
                            write!(self.out, "{name}")?;
                        }
                        needs_buffer_sizes |= context.expression.module.types[var.ty]
                            .inner
                            .needs_host_buffer_byte_size(&context.expression.module.types);
                    }
                    if needs_buffer_sizes {
                        if separate {
                            write!(self.out, ", ")?;
                        }
                        write!(self.out, "_buffer_sizes")?;
                    }

                    writeln!(self.out, ");")?;
                }
                crate::Statement::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result,
                } => {
                    let context = &context.expression;

                    write!(self.out, "{level}")?;
                    // With a result, bind the op's return value to a baked
                    // temporary. Without one, 64-bit values take the
                    // `to_msl_64_bit` spelling, which can fail (`?`).
                    let fun_key = if let Some(result) = result {
                        let res_name = Baked(result).to_string();
                        self.start_baking_expression(result, context, &res_name)?;
                        self.named_expressions.insert(result, res_name);
                        fun.to_msl()
                    } else if context.resolve_type(value).scalar_width() == Some(8) {
                        fun.to_msl_64_bit()?
                    } else {
                        fun.to_msl()
                    };

                    // Under `ReadZeroSkipWrite`, any emitted bounds checks
                    // become `<checks> ? <op> : DefaultConstructible()`.
                    let policy = context.choose_bounds_check_policy(pointer);
                    let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
                        && self.put_bounds_checks(pointer, context, back::Level(0), "")?;

                    if checked {
                        write!(self.out, " ? ")?;
                    }

                    match *fun {
                        // Compare-exchange goes through the predeclared
                        // `naga_atomic_compare_exchange_weak_explicit` helper.
                        crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
                            write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
                            self.put_access_chain(pointer, policy, context)?;
                            write!(self.out, ", ")?;
                            self.put_expression(cmp, context, true)?;
                            write!(self.out, ", ")?;
                            self.put_expression(value, context, true)?;
                            write!(self.out, ")")?;
                        }
                        _ => {
                            write!(
                                self.out,
                                "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
                            )?;
                            self.put_access_chain(pointer, policy, context)?;
                            write!(self.out, ", ")?;
                            self.put_expression(value, context, true)?;
                            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
                        }
                    }

                    if checked {
                        write!(self.out, " : DefaultConstructible()")?;
                    }

                    writeln!(self.out, ";")?;
                }
                crate::Statement::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun,
                    value,
                } => {
                    let address = TexelAddress {
                        coordinate,
                        array_index,
                        sample: None,
                        level: None,
                    };
                    self.put_image_atomic(level, image, &address, fun, value, context)?
                }
                crate::Statement::WorkGroupUniformLoad { pointer, result } => {
                    // The load is bracketed by workgroup barriers on both sides.
                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;

                    write!(self.out, "{level}")?;
                    let name = self.namer.call("");
                    self.start_baking_expression(result, &context.expression, &name)?;
                    self.put_load(pointer, &context.expression, true)?;
                    self.named_expressions.insert(result, name);

                    writeln!(self.out, ";")?;
                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
                }
                crate::Statement::RayQuery { query, ref fun } => {
                    self.write_ray_query_stmt(level, context, query, fun)?;
                }
                crate::Statement::SubgroupBallot { result, predicate } => {
                    write!(self.out, "{level}")?;
                    let name = self.namer.call("");
                    self.start_baking_expression(result, &context.expression, &name)?;
                    self.named_expressions.insert(result, name);
                    // The `simd_ballot` mask is cast to `uint64_t` and passed
                    // as the first component of a `uint4`, with the rest
                    // zeroed. NOTE(review): converting it to a single `uint`
                    // component keeps only the low 32 bits. This is presumably
                    // fine for SIMD-group widths of at most 32; confirm for
                    // wider hardware.
                    write!(
                        self.out,
                        "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
                    )?;
                    // Without a predicate, every active invocation votes `true`.
                    if let Some(predicate) = predicate {
                        self.put_expression(predicate, &context.expression, true)?;
                    } else {
                        write!(self.out, "true")?;
                    }
                    writeln!(self.out, "), 0, 0, 0);")?;
                }
                crate::Statement::SubgroupCollectiveOperation {
                    op,
                    collective_op,
                    argument,
                    result,
                } => {
                    write!(self.out, "{level}")?;
                    let name = self.namer.call("");
                    self.start_baking_expression(result, &context.expression, &name)?;
                    self.named_expressions.insert(result, name);
                    // Map (collective, op) pairs onto Metal SIMD-group builtins.
                    // Scans exist only for Add and Mul; any other pair panics.
                    match (collective_op, op) {
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
                            write!(self.out, "{NAMESPACE}::simd_all(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
                            write!(self.out, "{NAMESPACE}::simd_any(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
                            write!(self.out, "{NAMESPACE}::simd_sum(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
                            write!(self.out, "{NAMESPACE}::simd_product(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
                            write!(self.out, "{NAMESPACE}::simd_max(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
                            write!(self.out, "{NAMESPACE}::simd_min(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
                            write!(self.out, "{NAMESPACE}::simd_and(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
                            write!(self.out, "{NAMESPACE}::simd_or(")?
                        }
                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
                            write!(self.out, "{NAMESPACE}::simd_xor(")?
                        }
                        (
                            crate::CollectiveOperation::ExclusiveScan,
                            crate::SubgroupOperation::Add,
                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
                        (
                            crate::CollectiveOperation::ExclusiveScan,
                            crate::SubgroupOperation::Mul,
                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
                        (
                            crate::CollectiveOperation::InclusiveScan,
                            crate::SubgroupOperation::Add,
                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
                        (
                            crate::CollectiveOperation::InclusiveScan,
                            crate::SubgroupOperation::Mul,
                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
                        _ => unimplemented!(),
                    }
                    self.put_expression(argument, &context.expression, true)?;
                    writeln!(self.out, ");")?;
                }
                crate::Statement::SubgroupGather {
                    mode,
                    argument,
                    result,
                } => {
                    write!(self.out, "{level}")?;
                    let name = self.namer.call("");
                    self.start_baking_expression(result, &context.expression, &name)?;
                    self.named_expressions.insert(result, name);
                    // First the builtin name...
                    match mode {
                        crate::GatherMode::BroadcastFirst => {
                            write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
                        }
                        crate::GatherMode::Broadcast(_) => {
                            write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
                        }
                        crate::GatherMode::Shuffle(_) => {
                            write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
                        }
                        crate::GatherMode::ShuffleDown(_) => {
                            write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
                        }
                        crate::GatherMode::ShuffleUp(_) => {
                            write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
                        }
                        crate::GatherMode::ShuffleXor(_) => {
                            write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
                        }
                        crate::GatherMode::QuadBroadcast(_) => {
                            write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
                        }
                        crate::GatherMode::QuadSwap(_) => {
                            write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
                        }
                    }
                    self.put_expression(argument, &context.expression, true)?;
                    // ...then the mode's second operand, if it has one.
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index)
                        | crate::GatherMode::QuadBroadcast(index) => {
                            write!(self.out, ", ")?;
                            self.put_expression(index, &context.expression, true)?;
                        }
                        // Quad swaps become `quad_shuffle_xor` with a fixed
                        // lane mask: X -> 1, Y -> 2, Diagonal -> 3.
                        crate::GatherMode::QuadSwap(direction) => {
                            write!(self.out, ", ")?;
                            match direction {
                                crate::Direction::X => {
                                    write!(self.out, "1u")?;
                                }
                                crate::Direction::Y => {
                                    write!(self.out, "2u")?;
                                }
                                crate::Direction::Diagonal => {
                                    write!(self.out, "3u")?;
                                }
                            }
                        }
                    }
                    writeln!(self.out, ");")?;
                }
                crate::Statement::CooperativeStore { target, ref data } => {
                    write!(self.out, "{level}simdgroup_store(")?;
                    self.put_expression(target, &context.expression, true)?;
                    write!(self.out, ", &")?;
                    self.put_access_chain(
                        data.pointer,
                        context.expression.policies.index,
                        &context.expression,
                    )?;
                    write!(self.out, ", ")?;
                    self.put_expression(data.stride, &context.expression, true)?;
                    // For non-row-major data, also pass a zero matrix origin
                    // and request a transposed store.
                    if !data.row_major {
                        let matrix_origin = "0";
                        let transpose = true;
                        write!(self.out, ", {matrix_origin}, {transpose}")?;
                    }
                    writeln!(self.out, ");")?;
                }
                // Never expected to reach this writer.
                crate::Statement::RayPipelineFunction(_) => unreachable!(),
            }
        }

        // Un-emit: forget the temporaries named for this block's `Emit`
        // ranges so the names are not reused outside the block.
        for statement in statements {
            if let crate::Statement::Emit(ref range) = *statement {
                for handle in range.clone() {
                    self.named_expressions.shift_remove(&handle);
                }
            }
        }
        Ok(())
    }
4486
4487 fn put_store(
4488 &mut self,
4489 pointer: Handle<crate::Expression>,
4490 value: Handle<crate::Expression>,
4491 level: back::Level,
4492 context: &StatementContext,
4493 ) -> BackendResult {
4494 let policy = context.expression.choose_bounds_check_policy(pointer);
4495 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4496 && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4497 {
4498 writeln!(self.out, ") {{")?;
4499 self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4500 writeln!(self.out, "{level}}}")?;
4501 } else {
4502 self.put_unchecked_store(pointer, value, policy, level, context)?;
4503 }
4504
4505 Ok(())
4506 }
4507
4508 fn put_unchecked_store(
4509 &mut self,
4510 pointer: Handle<crate::Expression>,
4511 value: Handle<crate::Expression>,
4512 policy: index::BoundsCheckPolicy,
4513 level: back::Level,
4514 context: &StatementContext,
4515 ) -> BackendResult {
4516 let is_atomic_pointer = context
4517 .expression
4518 .resolve_type(pointer)
4519 .is_atomic_pointer(&context.expression.module.types);
4520
4521 if is_atomic_pointer {
4522 write!(
4523 self.out,
4524 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4525 )?;
4526 self.put_access_chain(pointer, policy, &context.expression)?;
4527 write!(self.out, ", ")?;
4528 self.put_expression(value, &context.expression, true)?;
4529 writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4530 } else {
4531 write!(self.out, "{level}")?;
4532 self.put_access_chain(pointer, policy, &context.expression)?;
4533 write!(self.out, " = ")?;
4534 self.put_expression(value, &context.expression, true)?;
4535 writeln!(self.out, ";")?;
4536 }
4537
4538 Ok(())
4539 }
4540
4541 pub fn write(
4542 &mut self,
4543 module: &crate::Module,
4544 info: &valid::ModuleInfo,
4545 options: &Options,
4546 pipeline_options: &PipelineOptions,
4547 ) -> Result<TranslationInfo, Error> {
4548 self.emit_int_div_checks = options.emit_int_div_checks;
4549 self.names.clear();
4550 self.namer.reset(
4551 module,
4552 &super::keywords::RESERVED_SET,
4553 proc::KeywordSet::empty(),
4554 proc::CaseInsensitiveKeywordSet::empty(),
4555 &[
4556 CLAMPED_LOD_LOAD_PREFIX,
4557 super::ray::INTERSECTION_FUNCTION_NAME,
4558 super::ray::RAY_QUERY_TRACKER_VARIABLE_PREFIX,
4559 super::ray::RAY_QUERY_T_MAX_TRACKER_VARIABLE_PREFIX,
4560 ],
4561 &mut self.names,
4562 );
4563 self.wrapped_functions.clear();
4564 self.struct_member_pads.clear();
4565
4566 writeln!(
4567 self.out,
4568 "// language: metal{}.{}",
4569 options.lang_version.0, options.lang_version.1
4570 )?;
4571 writeln!(self.out, "#include <metal_stdlib>")?;
4572 writeln!(self.out, "#include <simd/simd.h>")?;
4573 writeln!(self.out)?;
4574 writeln!(self.out, "using {NAMESPACE}::uint;")?;
4576
4577 if module.uses_mesh_shaders() && options.lang_version < (3, 0) {
4578 return Err(Error::UnsupportedMeshShader);
4579 }
4580 self.needs_object_memory_barriers = module
4581 .entry_points
4582 .iter()
4583 .any(|e| e.stage == crate::ShaderStage::Task && e.task_payload.is_some());
4584
4585 if module.special_types.ray_desc.is_some()
4586 || module.special_types.ray_intersection.is_some()
4587 {
4588 if options.lang_version < (2, 4) {
4589 return Err(Error::UnsupportedRayTracing);
4590 }
4591 }
4592
4593 if options
4594 .bounds_check_policies
4595 .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4596 {
4597 self.put_default_constructible()?;
4598 }
4599 writeln!(self.out)?;
4600
4601 {
4602 let globals: Vec<Handle<crate::GlobalVariable>> = module
4605 .global_variables
4606 .iter()
4607 .filter(|&(_, var)| {
4608 module.types[var.ty]
4609 .inner
4610 .needs_host_buffer_byte_size(&module.types)
4611 })
4612 .map(|(handle, _)| handle)
4613 .collect();
4614
4615 let mut buffer_indices = vec![];
4616 for vbm in &pipeline_options.vertex_buffer_mappings {
4617 buffer_indices.push(vbm.id);
4618 }
4619
4620 if !globals.is_empty() || !buffer_indices.is_empty() {
4621 writeln!(self.out, "struct _mslBufferSizes {{")?;
4622
4623 for global in globals {
4624 let var = &module.global_variables[global];
4625 let var_ty = var.ty;
4626 match module.types[var_ty].inner {
4627 crate::TypeInner::BindingArray { .. } => {
4628 let n =
4629 Self::binding_array_layout_count(module, pipeline_options, global);
4630 writeln!(
4631 self.out,
4632 "{}uint {}[{n}];",
4633 back::INDENT,
4634 ArraySizeMember(global),
4635 )?;
4636 }
4637 _ => writeln!(
4638 self.out,
4639 "{}uint {};",
4640 back::INDENT,
4641 ArraySizeMember(global)
4642 )?,
4643 }
4644 }
4645
4646 for idx in buffer_indices {
4647 writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4648 }
4649
4650 writeln!(self.out, "}};")?;
4651 writeln!(self.out)?;
4652 }
4653 };
4654
4655 self.write_type_defs(module)?;
4656 self.write_global_constants(module, info)?;
4657 self.write_functions(module, info, options, pipeline_options)
4658 }
4659
4660 fn put_default_constructible(&mut self) -> BackendResult {
4673 let tab = back::INDENT;
4674 writeln!(self.out, "struct DefaultConstructible {{")?;
4675 writeln!(self.out, "{tab}template<typename T>")?;
4676 writeln!(self.out, "{tab}operator T() && {{")?;
4677 writeln!(self.out, "{tab}{tab}return T {{}};")?;
4678 writeln!(self.out, "{tab}}}")?;
4679 writeln!(self.out, "}};")?;
4680 Ok(())
4681 }
4682
4683 fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4684 let mut generated_argument_buffer_wrapper = false;
4685 let mut generated_external_texture_wrapper = false;
4686 for (handle, ty) in module.types.iter() {
4687 match ty.inner {
4688 crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4689 writeln!(self.out, "template <typename T>")?;
4690 writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4691 writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4692 writeln!(self.out, "}};")?;
4693 generated_argument_buffer_wrapper = true;
4694 }
4695 crate::TypeInner::Image {
4696 class: crate::ImageClass::External,
4697 ..
4698 } if !generated_external_texture_wrapper => {
4699 let params_ty_name = &self.names
4700 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4701 writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4702 writeln!(
4703 self.out,
4704 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4705 back::INDENT
4706 )?;
4707 writeln!(
4708 self.out,
4709 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4710 back::INDENT
4711 )?;
4712 writeln!(
4713 self.out,
4714 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4715 back::INDENT
4716 )?;
4717 writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4718 writeln!(self.out, "}};")?;
4719 generated_external_texture_wrapper = true;
4720 }
4721 _ => {}
4722 }
4723
4724 if !ty.needs_alias() {
4725 continue;
4726 }
4727 let name = &self.names[&NameKey::Type(handle)];
4728 match ty.inner {
4729 crate::TypeInner::Array {
4743 base,
4744 size,
4745 stride: _,
4746 } => {
4747 let base_name = TypeContext {
4748 handle: base,
4749 gctx: module.to_ctx(),
4750 names: &self.names,
4751 access: crate::StorageAccess::empty(),
4752 first_time: false,
4753 };
4754
4755 match size.resolve(module.to_ctx())? {
4756 proc::IndexableLength::Known(size) => {
4757 writeln!(self.out, "struct {name} {{")?;
4758 writeln!(
4759 self.out,
4760 "{}{} {}[{}];",
4761 back::INDENT,
4762 base_name,
4763 WRAPPED_ARRAY_FIELD,
4764 size
4765 )?;
4766 writeln!(self.out, "}};")?;
4767 }
4768 proc::IndexableLength::Dynamic => {
4769 writeln!(self.out, "typedef {base_name} {name}[1];")?;
4770 }
4771 }
4772 }
4773 crate::TypeInner::Struct {
4774 ref members, span, ..
4775 } => {
4776 writeln!(self.out, "struct {name} {{")?;
4777 let mut last_offset = 0;
4778 for (index, member) in members.iter().enumerate() {
4779 if member.offset > last_offset {
4780 self.struct_member_pads.insert((handle, index as u32));
4781 let pad = member.offset - last_offset;
4782 writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4783 }
4784 let ty_inner = &module.types[member.ty].inner;
4785 last_offset = member.offset + ty_inner.size(module.to_ctx());
4786
4787 let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4788
4789 match should_pack_struct_member(members, span, index, module) {
4791 Some(scalar) => {
4792 writeln!(
4793 self.out,
4794 "{}{}::packed_{}3 {};",
4795 back::INDENT,
4796 NAMESPACE,
4797 scalar.to_msl_name(),
4798 member_name
4799 )?;
4800 }
4801 None => {
4802 let base_name = TypeContext {
4803 handle: member.ty,
4804 gctx: module.to_ctx(),
4805 names: &self.names,
4806 access: crate::StorageAccess::empty(),
4807 first_time: false,
4808 };
4809 writeln!(
4810 self.out,
4811 "{}{} {};",
4812 back::INDENT,
4813 base_name,
4814 member_name
4815 )?;
4816
4817 if let crate::TypeInner::Vector {
4819 size: crate::VectorSize::Tri,
4820 scalar,
4821 } = *ty_inner
4822 {
4823 last_offset += scalar.width as u32;
4824 }
4825 }
4826 }
4827 }
4828 if last_offset < span {
4829 let pad = span - last_offset;
4830 writeln!(
4831 self.out,
4832 "{}char _pad{}[{}];",
4833 back::INDENT,
4834 members.len(),
4835 pad
4836 )?;
4837 }
4838 writeln!(self.out, "}};")?;
4839 }
4840 _ => {
4841 let ty_name = TypeContext {
4842 handle,
4843 gctx: module.to_ctx(),
4844 names: &self.names,
4845 access: crate::StorageAccess::empty(),
4846 first_time: true,
4847 };
4848 writeln!(self.out, "typedef {ty_name} {name};")?;
4849 }
4850 }
4851 }
4852
4853 for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4855 match type_key {
4856 &crate::PredeclaredType::ModfResult { size, scalar }
4857 | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4858 let arg_type_name_owner;
4859 let arg_type_name = if let Some(size) = size {
4860 arg_type_name_owner = format!(
4861 "{NAMESPACE}::{}{}",
4862 if scalar.width == 8 { "double" } else { "float" },
4863 size as u8
4864 );
4865 &arg_type_name_owner
4866 } else if scalar.width == 8 {
4867 "double"
4868 } else {
4869 "float"
4870 };
4871
4872 let other_type_name_owner;
4873 let (defined_func_name, called_func_name, other_type_name) =
4874 if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4875 (MODF_FUNCTION, "modf", arg_type_name)
4876 } else {
4877 let other_type_name = if let Some(size) = size {
4878 other_type_name_owner = format!("int{}", size as u8);
4879 &other_type_name_owner
4880 } else {
4881 "int"
4882 };
4883 (FREXP_FUNCTION, "frexp", other_type_name)
4884 };
4885
4886 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4887
4888 writeln!(self.out)?;
4889 writeln!(
4890 self.out,
4891 "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4892 {other_type_name} other;
4893 {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4894 return {struct_name}{{ fract, other }};
4895}}"
4896 )?;
4897 }
4898 &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4899 let arg_type_name = scalar.to_msl_name();
4900 let called_func_name = "atomic_compare_exchange_weak_explicit";
4901 let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4902 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4903
4904 writeln!(self.out)?;
4905
4906 for address_space_name in ["device", "threadgroup"] {
4907 writeln!(
4908 self.out,
4909 "\
4910template <typename A>
4911{struct_name} {defined_func_name}(
4912 {address_space_name} A *atomic_ptr,
4913 {arg_type_name} cmp,
4914 {arg_type_name} v
4915) {{
4916 bool swapped = {NAMESPACE}::{called_func_name}(
4917 atomic_ptr, &cmp, v,
4918 metal::memory_order_relaxed, metal::memory_order_relaxed
4919 );
4920 return {struct_name}{{cmp, swapped}};
4921}}"
4922 )?;
4923 }
4924 }
4925 }
4926 }
4927
4928 Ok(())
4929 }
4930
4931 fn write_global_constants(
4933 &mut self,
4934 module: &crate::Module,
4935 mod_info: &valid::ModuleInfo,
4936 ) -> BackendResult {
4937 let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4938
4939 for (handle, constant) in constants {
4940 let ty_name = TypeContext {
4941 handle: constant.ty,
4942 gctx: module.to_ctx(),
4943 names: &self.names,
4944 access: crate::StorageAccess::empty(),
4945 first_time: false,
4946 };
4947 let name = &self.names[&NameKey::Constant(handle)];
4948 write!(self.out, "constant {ty_name} {name} = ")?;
4949 self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4950 writeln!(self.out, ";")?;
4951 }
4952
4953 Ok(())
4954 }
4955
4956 fn put_inline_sampler_properties(
4957 &mut self,
4958 level: back::Level,
4959 sampler: &sm::InlineSampler,
4960 ) -> BackendResult {
4961 for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4962 writeln!(
4963 self.out,
4964 "{}{}::{}_address::{},",
4965 level,
4966 NAMESPACE,
4967 letter,
4968 address.as_str(),
4969 )?;
4970 }
4971 writeln!(
4972 self.out,
4973 "{}{}::mag_filter::{},",
4974 level,
4975 NAMESPACE,
4976 sampler.mag_filter.as_str(),
4977 )?;
4978 writeln!(
4979 self.out,
4980 "{}{}::min_filter::{},",
4981 level,
4982 NAMESPACE,
4983 sampler.min_filter.as_str(),
4984 )?;
4985 if let Some(filter) = sampler.mip_filter {
4986 writeln!(
4987 self.out,
4988 "{}{}::mip_filter::{},",
4989 level,
4990 NAMESPACE,
4991 filter.as_str(),
4992 )?;
4993 }
4994 if sampler.border_color != sm::BorderColor::TransparentBlack {
4996 writeln!(
4997 self.out,
4998 "{}{}::border_color::{},",
4999 level,
5000 NAMESPACE,
5001 sampler.border_color.as_str(),
5002 )?;
5003 }
5004 if false {
5008 if let Some(ref lod) = sampler.lod_clamp {
5009 writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
5010 }
5011 if let Some(aniso) = sampler.max_anisotropy {
5012 writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
5013 }
5014 }
5015 if sampler.compare_func != sm::CompareFunc::Never {
5016 writeln!(
5017 self.out,
5018 "{}{}::compare_func::{},",
5019 level,
5020 NAMESPACE,
5021 sampler.compare_func.as_str(),
5022 )?;
5023 }
5024 writeln!(
5025 self.out,
5026 "{}{}::coord::{}",
5027 level,
5028 NAMESPACE,
5029 sampler.coord.as_str()
5030 )?;
5031 Ok(())
5032 }
5033
5034 fn write_unpacking_function(
5035 &mut self,
5036 format: nt::VertexFormat,
5037 ) -> Result<(String, u32, Option<crate::VectorSize>, crate::Scalar), Error> {
5038 use crate::{Scalar, VectorSize};
5039 use nt::VertexFormat::*;
5040 match format {
5041 Uint8 => {
5042 let name = self.namer.call("unpackUint8");
5043 writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
5044 writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
5045 writeln!(self.out, "}}")?;
5046 Ok((name, 1, None, Scalar::U32))
5047 }
5048 Uint8x2 => {
5049 let name = self.namer.call("unpackUint8x2");
5050 writeln!(
5051 self.out,
5052 "metal::uint2 {name}(metal::uchar b0, \
5053 metal::uchar b1) {{"
5054 )?;
5055 writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
5056 writeln!(self.out, "}}")?;
5057 Ok((name, 2, Some(VectorSize::Bi), Scalar::U32))
5058 }
5059 Uint8x4 => {
5060 let name = self.namer.call("unpackUint8x4");
5061 writeln!(
5062 self.out,
5063 "metal::uint4 {name}(metal::uchar b0, \
5064 metal::uchar b1, \
5065 metal::uchar b2, \
5066 metal::uchar b3) {{"
5067 )?;
5068 writeln!(
5069 self.out,
5070 "{}return metal::uint4(b0, b1, b2, b3);",
5071 back::INDENT
5072 )?;
5073 writeln!(self.out, "}}")?;
5074 Ok((name, 4, Some(VectorSize::Quad), Scalar::U32))
5075 }
5076 Sint8 => {
5077 let name = self.namer.call("unpackSint8");
5078 writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
5079 writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
5080 writeln!(self.out, "}}")?;
5081 Ok((name, 1, None, Scalar::I32))
5082 }
5083 Sint8x2 => {
5084 let name = self.namer.call("unpackSint8x2");
5085 writeln!(
5086 self.out,
5087 "metal::int2 {name}(metal::uchar b0, \
5088 metal::uchar b1) {{"
5089 )?;
5090 writeln!(
5091 self.out,
5092 "{}return metal::int2(as_type<char>(b0), \
5093 as_type<char>(b1));",
5094 back::INDENT
5095 )?;
5096 writeln!(self.out, "}}")?;
5097 Ok((name, 2, Some(VectorSize::Bi), Scalar::I32))
5098 }
5099 Sint8x4 => {
5100 let name = self.namer.call("unpackSint8x4");
5101 writeln!(
5102 self.out,
5103 "metal::int4 {name}(metal::uchar b0, \
5104 metal::uchar b1, \
5105 metal::uchar b2, \
5106 metal::uchar b3) {{"
5107 )?;
5108 writeln!(
5109 self.out,
5110 "{}return metal::int4(as_type<char>(b0), \
5111 as_type<char>(b1), \
5112 as_type<char>(b2), \
5113 as_type<char>(b3));",
5114 back::INDENT
5115 )?;
5116 writeln!(self.out, "}}")?;
5117 Ok((name, 4, Some(VectorSize::Quad), Scalar::I32))
5118 }
5119 Unorm8 => {
5120 let name = self.namer.call("unpackUnorm8");
5121 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
5122 writeln!(
5123 self.out,
5124 "{}return float(float(b0) / 255.0f);",
5125 back::INDENT
5126 )?;
5127 writeln!(self.out, "}}")?;
5128 Ok((name, 1, None, Scalar::F32))
5129 }
5130 Unorm8x2 => {
5131 let name = self.namer.call("unpackUnorm8x2");
5132 writeln!(
5133 self.out,
5134 "metal::float2 {name}(metal::uchar b0, \
5135 metal::uchar b1) {{"
5136 )?;
5137 writeln!(
5138 self.out,
5139 "{}return metal::float2(float(b0) / 255.0f, \
5140 float(b1) / 255.0f);",
5141 back::INDENT
5142 )?;
5143 writeln!(self.out, "}}")?;
5144 Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
5145 }
5146 Unorm8x4 => {
5147 let name = self.namer.call("unpackUnorm8x4");
5148 writeln!(
5149 self.out,
5150 "metal::float4 {name}(metal::uchar b0, \
5151 metal::uchar b1, \
5152 metal::uchar b2, \
5153 metal::uchar b3) {{"
5154 )?;
5155 writeln!(
5156 self.out,
5157 "{}return metal::float4(float(b0) / 255.0f, \
5158 float(b1) / 255.0f, \
5159 float(b2) / 255.0f, \
5160 float(b3) / 255.0f);",
5161 back::INDENT
5162 )?;
5163 writeln!(self.out, "}}")?;
5164 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5165 }
5166 Snorm8 => {
5167 let name = self.namer.call("unpackSnorm8");
5168 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
5169 writeln!(
5170 self.out,
5171 "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
5172 back::INDENT
5173 )?;
5174 writeln!(self.out, "}}")?;
5175 Ok((name, 1, None, Scalar::F32))
5176 }
5177 Snorm8x2 => {
5178 let name = self.namer.call("unpackSnorm8x2");
5179 writeln!(
5180 self.out,
5181 "metal::float2 {name}(metal::uchar b0, \
5182 metal::uchar b1) {{"
5183 )?;
5184 writeln!(
5185 self.out,
5186 "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5187 metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
5188 back::INDENT
5189 )?;
5190 writeln!(self.out, "}}")?;
5191 Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
5192 }
5193 Snorm8x4 => {
5194 let name = self.namer.call("unpackSnorm8x4");
5195 writeln!(
5196 self.out,
5197 "metal::float4 {name}(metal::uchar b0, \
5198 metal::uchar b1, \
5199 metal::uchar b2, \
5200 metal::uchar b3) {{"
5201 )?;
5202 writeln!(
5203 self.out,
5204 "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5205 metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
5206 metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
5207 metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
5208 back::INDENT
5209 )?;
5210 writeln!(self.out, "}}")?;
5211 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5212 }
5213 Uint16 => {
5214 let name = self.namer.call("unpackUint16");
5215 writeln!(
5216 self.out,
5217 "metal::uint {name}(metal::uint b0, \
5218 metal::uint b1) {{"
5219 )?;
5220 writeln!(
5221 self.out,
5222 "{}return metal::uint(b1 << 8 | b0);",
5223 back::INDENT
5224 )?;
5225 writeln!(self.out, "}}")?;
5226 Ok((name, 2, None, Scalar::U32))
5227 }
5228 Uint16x2 => {
5229 let name = self.namer.call("unpackUint16x2");
5230 writeln!(
5231 self.out,
5232 "metal::uint2 {name}(metal::uint b0, \
5233 metal::uint b1, \
5234 metal::uint b2, \
5235 metal::uint b3) {{"
5236 )?;
5237 writeln!(
5238 self.out,
5239 "{}return metal::uint2(b1 << 8 | b0, \
5240 b3 << 8 | b2);",
5241 back::INDENT
5242 )?;
5243 writeln!(self.out, "}}")?;
5244 Ok((name, 4, Some(VectorSize::Bi), Scalar::U32))
5245 }
5246 Uint16x4 => {
5247 let name = self.namer.call("unpackUint16x4");
5248 writeln!(
5249 self.out,
5250 "metal::uint4 {name}(metal::uint b0, \
5251 metal::uint b1, \
5252 metal::uint b2, \
5253 metal::uint b3, \
5254 metal::uint b4, \
5255 metal::uint b5, \
5256 metal::uint b6, \
5257 metal::uint b7) {{"
5258 )?;
5259 writeln!(
5260 self.out,
5261 "{}return metal::uint4(b1 << 8 | b0, \
5262 b3 << 8 | b2, \
5263 b5 << 8 | b4, \
5264 b7 << 8 | b6);",
5265 back::INDENT
5266 )?;
5267 writeln!(self.out, "}}")?;
5268 Ok((name, 8, Some(VectorSize::Quad), Scalar::U32))
5269 }
5270 Sint16 => {
5271 let name = self.namer.call("unpackSint16");
5272 writeln!(
5273 self.out,
5274 "int {name}(metal::ushort b0, \
5275 metal::ushort b1) {{"
5276 )?;
5277 writeln!(
5278 self.out,
5279 "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5280 back::INDENT
5281 )?;
5282 writeln!(self.out, "}}")?;
5283 Ok((name, 2, None, Scalar::I32))
5284 }
5285 Sint16x2 => {
5286 let name = self.namer.call("unpackSint16x2");
5287 writeln!(
5288 self.out,
5289 "metal::int2 {name}(metal::ushort b0, \
5290 metal::ushort b1, \
5291 metal::ushort b2, \
5292 metal::ushort b3) {{"
5293 )?;
5294 writeln!(
5295 self.out,
5296 "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5297 as_type<short>(metal::ushort(b3 << 8 | b2)));",
5298 back::INDENT
5299 )?;
5300 writeln!(self.out, "}}")?;
5301 Ok((name, 4, Some(VectorSize::Bi), Scalar::I32))
5302 }
5303 Sint16x4 => {
5304 let name = self.namer.call("unpackSint16x4");
5305 writeln!(
5306 self.out,
5307 "metal::int4 {name}(metal::ushort b0, \
5308 metal::ushort b1, \
5309 metal::ushort b2, \
5310 metal::ushort b3, \
5311 metal::ushort b4, \
5312 metal::ushort b5, \
5313 metal::ushort b6, \
5314 metal::ushort b7) {{"
5315 )?;
5316 writeln!(
5317 self.out,
5318 "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5319 as_type<short>(metal::ushort(b3 << 8 | b2)), \
5320 as_type<short>(metal::ushort(b5 << 8 | b4)), \
5321 as_type<short>(metal::ushort(b7 << 8 | b6)));",
5322 back::INDENT
5323 )?;
5324 writeln!(self.out, "}}")?;
5325 Ok((name, 8, Some(VectorSize::Quad), Scalar::I32))
5326 }
5327 Unorm16 => {
5328 let name = self.namer.call("unpackUnorm16");
5329 writeln!(
5330 self.out,
5331 "float {name}(metal::ushort b0, \
5332 metal::ushort b1) {{"
5333 )?;
5334 writeln!(
5335 self.out,
5336 "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5337 back::INDENT
5338 )?;
5339 writeln!(self.out, "}}")?;
5340 Ok((name, 2, None, Scalar::F32))
5341 }
5342 Unorm16x2 => {
5343 let name = self.namer.call("unpackUnorm16x2");
5344 writeln!(
5345 self.out,
5346 "metal::float2 {name}(metal::ushort b0, \
5347 metal::ushort b1, \
5348 metal::ushort b2, \
5349 metal::ushort b3) {{"
5350 )?;
5351 writeln!(
5352 self.out,
5353 "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5354 float(b3 << 8 | b2) / 65535.0f);",
5355 back::INDENT
5356 )?;
5357 writeln!(self.out, "}}")?;
5358 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5359 }
5360 Unorm16x4 => {
5361 let name = self.namer.call("unpackUnorm16x4");
5362 writeln!(
5363 self.out,
5364 "metal::float4 {name}(metal::ushort b0, \
5365 metal::ushort b1, \
5366 metal::ushort b2, \
5367 metal::ushort b3, \
5368 metal::ushort b4, \
5369 metal::ushort b5, \
5370 metal::ushort b6, \
5371 metal::ushort b7) {{"
5372 )?;
5373 writeln!(
5374 self.out,
5375 "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5376 float(b3 << 8 | b2) / 65535.0f, \
5377 float(b5 << 8 | b4) / 65535.0f, \
5378 float(b7 << 8 | b6) / 65535.0f);",
5379 back::INDENT
5380 )?;
5381 writeln!(self.out, "}}")?;
5382 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5383 }
5384 Snorm16 => {
5385 let name = self.namer.call("unpackSnorm16");
5386 writeln!(
5387 self.out,
5388 "float {name}(metal::ushort b0, \
5389 metal::ushort b1) {{"
5390 )?;
5391 writeln!(
5392 self.out,
5393 "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5394 back::INDENT
5395 )?;
5396 writeln!(self.out, "}}")?;
5397 Ok((name, 2, None, Scalar::F32))
5398 }
5399 Snorm16x2 => {
5400 let name = self.namer.call("unpackSnorm16x2");
5401 writeln!(
5402 self.out,
5403 "metal::float2 {name}(uint b0, \
5404 uint b1, \
5405 uint b2, \
5406 uint b3) {{"
5407 )?;
5408 writeln!(
5409 self.out,
5410 "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5411 back::INDENT
5412 )?;
5413 writeln!(self.out, "}}")?;
5414 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5415 }
5416 Snorm16x4 => {
5417 let name = self.namer.call("unpackSnorm16x4");
5418 writeln!(
5419 self.out,
5420 "metal::float4 {name}(uint b0, \
5421 uint b1, \
5422 uint b2, \
5423 uint b3, \
5424 uint b4, \
5425 uint b5, \
5426 uint b6, \
5427 uint b7) {{"
5428 )?;
5429 writeln!(
5430 self.out,
5431 "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5432 metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5433 back::INDENT
5434 )?;
5435 writeln!(self.out, "}}")?;
5436 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5437 }
5438 Float16 => {
5439 let name = self.namer.call("unpackFloat16");
5440 writeln!(
5441 self.out,
5442 "float {name}(metal::ushort b0, \
5443 metal::ushort b1) {{"
5444 )?;
5445 writeln!(
5446 self.out,
5447 "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5448 back::INDENT
5449 )?;
5450 writeln!(self.out, "}}")?;
5451 Ok((name, 2, None, Scalar::F32))
5452 }
5453 Float16x2 => {
5454 let name = self.namer.call("unpackFloat16x2");
5455 writeln!(
5456 self.out,
5457 "metal::float2 {name}(metal::ushort b0, \
5458 metal::ushort b1, \
5459 metal::ushort b2, \
5460 metal::ushort b3) {{"
5461 )?;
5462 writeln!(
5463 self.out,
5464 "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5465 as_type<half>(metal::ushort(b3 << 8 | b2)));",
5466 back::INDENT
5467 )?;
5468 writeln!(self.out, "}}")?;
5469 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5470 }
5471 Float16x4 => {
5472 let name = self.namer.call("unpackFloat16x4");
5473 writeln!(
5474 self.out,
5475 "metal::float4 {name}(metal::ushort b0, \
5476 metal::ushort b1, \
5477 metal::ushort b2, \
5478 metal::ushort b3, \
5479 metal::ushort b4, \
5480 metal::ushort b5, \
5481 metal::ushort b6, \
5482 metal::ushort b7) {{"
5483 )?;
5484 writeln!(
5485 self.out,
5486 "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5487 as_type<half>(metal::ushort(b3 << 8 | b2)), \
5488 as_type<half>(metal::ushort(b5 << 8 | b4)), \
5489 as_type<half>(metal::ushort(b7 << 8 | b6)));",
5490 back::INDENT
5491 )?;
5492 writeln!(self.out, "}}")?;
5493 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5494 }
5495 Float32 => {
5496 let name = self.namer.call("unpackFloat32");
5497 writeln!(
5498 self.out,
5499 "float {name}(uint b0, \
5500 uint b1, \
5501 uint b2, \
5502 uint b3) {{"
5503 )?;
5504 writeln!(
5505 self.out,
5506 "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5507 back::INDENT
5508 )?;
5509 writeln!(self.out, "}}")?;
5510 Ok((name, 4, None, Scalar::F32))
5511 }
5512 Float32x2 => {
5513 let name = self.namer.call("unpackFloat32x2");
5514 writeln!(
5515 self.out,
5516 "metal::float2 {name}(uint b0, \
5517 uint b1, \
5518 uint b2, \
5519 uint b3, \
5520 uint b4, \
5521 uint b5, \
5522 uint b6, \
5523 uint b7) {{"
5524 )?;
5525 writeln!(
5526 self.out,
5527 "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5528 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5529 back::INDENT
5530 )?;
5531 writeln!(self.out, "}}")?;
5532 Ok((name, 8, Some(VectorSize::Bi), Scalar::F32))
5533 }
5534 Float32x3 => {
5535 let name = self.namer.call("unpackFloat32x3");
5536 writeln!(
5537 self.out,
5538 "metal::float3 {name}(uint b0, \
5539 uint b1, \
5540 uint b2, \
5541 uint b3, \
5542 uint b4, \
5543 uint b5, \
5544 uint b6, \
5545 uint b7, \
5546 uint b8, \
5547 uint b9, \
5548 uint b10, \
5549 uint b11) {{"
5550 )?;
5551 writeln!(
5552 self.out,
5553 "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5554 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5555 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5556 back::INDENT
5557 )?;
5558 writeln!(self.out, "}}")?;
5559 Ok((name, 12, Some(VectorSize::Tri), Scalar::F32))
5560 }
5561 Float32x4 => {
5562 let name = self.namer.call("unpackFloat32x4");
5563 writeln!(
5564 self.out,
5565 "metal::float4 {name}(uint b0, \
5566 uint b1, \
5567 uint b2, \
5568 uint b3, \
5569 uint b4, \
5570 uint b5, \
5571 uint b6, \
5572 uint b7, \
5573 uint b8, \
5574 uint b9, \
5575 uint b10, \
5576 uint b11, \
5577 uint b12, \
5578 uint b13, \
5579 uint b14, \
5580 uint b15) {{"
5581 )?;
5582 writeln!(
5583 self.out,
5584 "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5585 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5586 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5587 as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5588 back::INDENT
5589 )?;
5590 writeln!(self.out, "}}")?;
5591 Ok((name, 16, Some(VectorSize::Quad), Scalar::F32))
5592 }
5593 Uint32 => {
5594 let name = self.namer.call("unpackUint32");
5595 writeln!(
5596 self.out,
5597 "uint {name}(uint b0, \
5598 uint b1, \
5599 uint b2, \
5600 uint b3) {{"
5601 )?;
5602 writeln!(
5603 self.out,
5604 "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5605 back::INDENT
5606 )?;
5607 writeln!(self.out, "}}")?;
5608 Ok((name, 4, None, Scalar::U32))
5609 }
5610 Uint32x2 => {
5611 let name = self.namer.call("unpackUint32x2");
5612 writeln!(
5613 self.out,
5614 "uint2 {name}(uint b0, \
5615 uint b1, \
5616 uint b2, \
5617 uint b3, \
5618 uint b4, \
5619 uint b5, \
5620 uint b6, \
5621 uint b7) {{"
5622 )?;
5623 writeln!(
5624 self.out,
5625 "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5626 (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5627 back::INDENT
5628 )?;
5629 writeln!(self.out, "}}")?;
5630 Ok((name, 8, Some(VectorSize::Bi), Scalar::U32))
5631 }
5632 Uint32x3 => {
5633 let name = self.namer.call("unpackUint32x3");
5634 writeln!(
5635 self.out,
5636 "uint3 {name}(uint b0, \
5637 uint b1, \
5638 uint b2, \
5639 uint b3, \
5640 uint b4, \
5641 uint b5, \
5642 uint b6, \
5643 uint b7, \
5644 uint b8, \
5645 uint b9, \
5646 uint b10, \
5647 uint b11) {{"
5648 )?;
5649 writeln!(
5650 self.out,
5651 "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5652 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5653 (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5654 back::INDENT
5655 )?;
5656 writeln!(self.out, "}}")?;
5657 Ok((name, 12, Some(VectorSize::Tri), Scalar::U32))
5658 }
5659 Uint32x4 => {
5660 let name = self.namer.call("unpackUint32x4");
5661 writeln!(
5662 self.out,
5663 "{NAMESPACE}::uint4 {name}(uint b0, \
5664 uint b1, \
5665 uint b2, \
5666 uint b3, \
5667 uint b4, \
5668 uint b5, \
5669 uint b6, \
5670 uint b7, \
5671 uint b8, \
5672 uint b9, \
5673 uint b10, \
5674 uint b11, \
5675 uint b12, \
5676 uint b13, \
5677 uint b14, \
5678 uint b15) {{"
5679 )?;
5680 writeln!(
5681 self.out,
5682 "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5683 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5684 (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5685 (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5686 back::INDENT
5687 )?;
5688 writeln!(self.out, "}}")?;
5689 Ok((name, 16, Some(VectorSize::Quad), Scalar::U32))
5690 }
5691 Sint32 => {
5692 let name = self.namer.call("unpackSint32");
5693 writeln!(
5694 self.out,
5695 "int {name}(uint b0, \
5696 uint b1, \
5697 uint b2, \
5698 uint b3) {{"
5699 )?;
5700 writeln!(
5701 self.out,
5702 "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5703 back::INDENT
5704 )?;
5705 writeln!(self.out, "}}")?;
5706 Ok((name, 4, None, Scalar::I32))
5707 }
5708 Sint32x2 => {
5709 let name = self.namer.call("unpackSint32x2");
5710 writeln!(
5711 self.out,
5712 "metal::int2 {name}(uint b0, \
5713 uint b1, \
5714 uint b2, \
5715 uint b3, \
5716 uint b4, \
5717 uint b5, \
5718 uint b6, \
5719 uint b7) {{"
5720 )?;
5721 writeln!(
5722 self.out,
5723 "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5724 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5725 back::INDENT
5726 )?;
5727 writeln!(self.out, "}}")?;
5728 Ok((name, 8, Some(VectorSize::Bi), Scalar::I32))
5729 }
5730 Sint32x3 => {
5731 let name = self.namer.call("unpackSint32x3");
5732 writeln!(
5733 self.out,
5734 "metal::int3 {name}(uint b0, \
5735 uint b1, \
5736 uint b2, \
5737 uint b3, \
5738 uint b4, \
5739 uint b5, \
5740 uint b6, \
5741 uint b7, \
5742 uint b8, \
5743 uint b9, \
5744 uint b10, \
5745 uint b11) {{"
5746 )?;
5747 writeln!(
5748 self.out,
5749 "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5750 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5751 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5752 back::INDENT
5753 )?;
5754 writeln!(self.out, "}}")?;
5755 Ok((name, 12, Some(VectorSize::Tri), Scalar::I32))
5756 }
5757 Sint32x4 => {
5758 let name = self.namer.call("unpackSint32x4");
5759 writeln!(
5760 self.out,
5761 "metal::int4 {name}(uint b0, \
5762 uint b1, \
5763 uint b2, \
5764 uint b3, \
5765 uint b4, \
5766 uint b5, \
5767 uint b6, \
5768 uint b7, \
5769 uint b8, \
5770 uint b9, \
5771 uint b10, \
5772 uint b11, \
5773 uint b12, \
5774 uint b13, \
5775 uint b14, \
5776 uint b15) {{"
5777 )?;
5778 writeln!(
5779 self.out,
5780 "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5781 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5782 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5783 as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5784 back::INDENT
5785 )?;
5786 writeln!(self.out, "}}")?;
5787 Ok((name, 16, Some(VectorSize::Quad), Scalar::I32))
5788 }
5789 Unorm10_10_10_2 => {
5790 let name = self.namer.call("unpackUnorm10_10_10_2");
5791 writeln!(
5792 self.out,
5793 "metal::float4 {name}(uint b0, \
5794 uint b1, \
5795 uint b2, \
5796 uint b3) {{"
5797 )?;
5798 writeln!(
5799 self.out,
5800 "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5812 back::INDENT
5813 )?;
5814 writeln!(self.out, "}}")?;
5815 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5816 }
5817 Unorm8x4Bgra => {
5818 let name = self.namer.call("unpackUnorm8x4Bgra");
5819 writeln!(
5820 self.out,
5821 "metal::float4 {name}(metal::uchar b0, \
5822 metal::uchar b1, \
5823 metal::uchar b2, \
5824 metal::uchar b3) {{"
5825 )?;
5826 writeln!(
5827 self.out,
5828 "{}return metal::float4(float(b2) / 255.0f, \
5829 float(b1) / 255.0f, \
5830 float(b0) / 255.0f, \
5831 float(b3) / 255.0f);",
5832 back::INDENT
5833 )?;
5834 writeln!(self.out, "}}")?;
5835 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5836 }
5837 Float64 | Float64x2 | Float64x3 | Float64x4 => unreachable!(),
5838 }
5839 }
5840
5841 fn write_wrapped_unary_op(
5842 &mut self,
5843 module: &crate::Module,
5844 func_ctx: &back::FunctionCtx,
5845 op: crate::UnaryOperator,
5846 operand: Handle<crate::Expression>,
5847 ) -> BackendResult {
5848 let operand_ty = func_ctx.resolve_type(operand, &module.types);
5849 match op {
5850 crate::UnaryOperator::Negate
5857 if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5858 {
5859 let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5860 return Ok(());
5861 };
5862 let wrapped = WrappedFunction::UnaryOp {
5863 op,
5864 ty: (vector_size, scalar),
5865 };
5866 if !self.wrapped_functions.insert(wrapped) {
5867 return Ok(());
5868 }
5869
5870 let unsigned_scalar = crate::Scalar {
5871 kind: crate::ScalarKind::Uint,
5872 ..scalar
5873 };
5874 let mut type_name = String::new();
5875 let mut unsigned_type_name = String::new();
5876 match vector_size {
5877 None => {
5878 put_numeric_type(&mut type_name, scalar, &[])?;
5879 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5880 }
5881 Some(size) => {
5882 put_numeric_type(&mut type_name, scalar, &[size])?;
5883 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5884 }
5885 };
5886
5887 writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5888 let level = back::Level(1);
5889 if scalar.width < 4 {
5893 writeln!(
5894 self.out,
5895 "{level}return as_type<{type_name}>(static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val)));"
5896 )?;
5897 } else {
5898 writeln!(
5899 self.out,
5900 "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5901 )?;
5902 }
5903 writeln!(self.out, "}}")?;
5904 writeln!(self.out)?;
5905 }
5906 _ => {}
5907 }
5908 Ok(())
5909 }
5910
5911 fn write_wrapped_binary_op(
5912 &mut self,
5913 module: &crate::Module,
5914 func_ctx: &back::FunctionCtx,
5915 expr: Handle<crate::Expression>,
5916 op: crate::BinaryOperator,
5917 left: Handle<crate::Expression>,
5918 right: Handle<crate::Expression>,
5919 ) -> BackendResult {
5920 let expr_ty = func_ctx.resolve_type(expr, &module.types);
5921 let left_ty = func_ctx.resolve_type(left, &module.types);
5922 let right_ty = func_ctx.resolve_type(right, &module.types);
5923 match (op, expr_ty.scalar_kind()) {
5924 (
5931 crate::BinaryOperator::Divide,
5932 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5933 ) if self.emit_int_div_checks => {
5934 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5935 return Ok(());
5936 };
5937 let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5938 return Ok(());
5939 };
5940 let wrapped = WrappedFunction::BinaryOp {
5941 op,
5942 left_ty: left_wrapped_ty,
5943 right_ty: right_wrapped_ty,
5944 };
5945 if !self.wrapped_functions.insert(wrapped) {
5946 return Ok(());
5947 }
5948
5949 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5950 return Ok(());
5951 };
5952 let mut type_name = String::new();
5953 match vector_size {
5954 None => put_numeric_type(&mut type_name, scalar, &[])?,
5955 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5956 };
5957 writeln!(
5958 self.out,
5959 "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5960 )?;
5961 let level = back::Level(1);
5962 let (lp, rp) = if scalar.width < 4 {
5966 (format!("{type_name}("), ")".to_string())
5967 } else {
5968 (String::new(), String::new())
5969 };
5970 match scalar.kind {
5971 crate::ScalarKind::Sint => {
5972 let min_val = match scalar.width {
5973 2 => crate::Literal::I16(i16::MIN),
5974 4 => crate::Literal::I32(i32::MIN),
5975 8 => crate::Literal::I64(i64::MIN),
5976 _ => {
5977 return Err(Error::GenericValidation(format!(
5978 "Unexpected width for scalar {scalar:?}"
5979 )));
5980 }
5981 };
5982 write!(
5983 self.out,
5984 "{level}return lhs / metal::select(rhs, {lp}1{rp}, (lhs == "
5985 )?;
5986 self.put_literal(min_val)?;
5987 writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?
5988 }
5989 crate::ScalarKind::Uint => {
5990 let suffix = if scalar.width < 4 { "" } else { "u" };
5991 writeln!(
5992 self.out,
5993 "{level}return lhs / metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5994 )?
5995 }
5996 _ => unreachable!(),
5997 }
5998 writeln!(self.out, "}}")?;
5999 writeln!(self.out)?;
6000 }
6001 (
6014 crate::BinaryOperator::Modulo,
6015 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
6016 ) if self.emit_int_div_checks => {
6017 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
6018 return Ok(());
6019 };
6020 let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
6021 else {
6022 return Ok(());
6023 };
6024 let wrapped = WrappedFunction::BinaryOp {
6025 op,
6026 left_ty: left_wrapped_ty,
6027 right_ty: (right_vector_size, right_scalar),
6028 };
6029 if !self.wrapped_functions.insert(wrapped) {
6030 return Ok(());
6031 }
6032
6033 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
6034 return Ok(());
6035 };
6036 let mut type_name = String::new();
6037 match vector_size {
6038 None => put_numeric_type(&mut type_name, scalar, &[])?,
6039 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
6040 };
6041 let mut rhs_type_name = String::new();
6042 match right_vector_size {
6043 None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
6044 Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
6045 };
6046
6047 writeln!(
6048 self.out,
6049 "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
6050 )?;
6051 let level = back::Level(1);
6052 let (lp, rp) = if scalar.width < 4 {
6053 (format!("{type_name}("), ")".to_string())
6054 } else {
6055 (String::new(), String::new())
6056 };
6057 match scalar.kind {
6058 crate::ScalarKind::Sint => {
6059 let min_val = match scalar.width {
6060 2 => crate::Literal::I16(i16::MIN),
6061 4 => crate::Literal::I32(i32::MIN),
6062 8 => crate::Literal::I64(i64::MIN),
6063 _ => {
6064 return Err(Error::GenericValidation(format!(
6065 "Unexpected width for scalar {scalar:?}"
6066 )));
6067 }
6068 };
6069 write!(
6070 self.out,
6071 "{level}{rhs_type_name} divisor = metal::select(rhs, {lp}1{rp}, (lhs == "
6072 )?;
6073 self.put_literal(min_val)?;
6074 writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?;
6075 writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
6076 }
6077 crate::ScalarKind::Uint => {
6078 let suffix = if scalar.width < 4 { "" } else { "u" };
6079 writeln!(
6080 self.out,
6081 "{level}return lhs % metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
6082 )?
6083 }
6084 _ => unreachable!(),
6085 }
6086 writeln!(self.out, "}}")?;
6087 writeln!(self.out)?;
6088 }
6089 _ => {}
6090 }
6091 Ok(())
6092 }
6093
6094 fn get_dot_wrapper_function_helper_name(
6100 &self,
6101 scalar: crate::Scalar,
6102 size: crate::VectorSize,
6103 ) -> String {
6104 debug_assert!(concrete_int_scalars().any(|s| s == scalar));
6106
6107 let type_name = scalar.to_msl_name();
6108 let size_suffix = common::vector_size_str(size);
6109 format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
6110 }
6111
6112 #[allow(clippy::too_many_arguments)]
6113 fn write_wrapped_math_function(
6114 &mut self,
6115 module: &crate::Module,
6116 func_ctx: &back::FunctionCtx,
6117 fun: crate::MathFunction,
6118 arg: Handle<crate::Expression>,
6119 _arg1: Option<Handle<crate::Expression>>,
6120 _arg2: Option<Handle<crate::Expression>>,
6121 _arg3: Option<Handle<crate::Expression>>,
6122 ) -> BackendResult {
6123 let arg_ty = func_ctx.resolve_type(arg, &module.types);
6124 match fun {
6125 crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
6133 let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
6134 return Ok(());
6135 };
6136 let wrapped = WrappedFunction::Math {
6137 fun,
6138 arg_ty: (vector_size, scalar),
6139 };
6140 if !self.wrapped_functions.insert(wrapped) {
6141 return Ok(());
6142 }
6143
6144 let unsigned_scalar = crate::Scalar {
6145 kind: crate::ScalarKind::Uint,
6146 ..scalar
6147 };
6148 let mut type_name = String::new();
6149 let mut unsigned_type_name = String::new();
6150 match vector_size {
6151 None => {
6152 put_numeric_type(&mut type_name, scalar, &[])?;
6153 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
6154 }
6155 Some(size) => {
6156 put_numeric_type(&mut type_name, scalar, &[size])?;
6157 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
6158 }
6159 };
6160
6161 writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
6162 let level = back::Level(1);
6163 let zero = if scalar.width < 4 {
6164 format!("{type_name}(0)")
6165 } else {
6166 "0".to_string()
6167 };
6168 let neg_expr = if scalar.width < 4 {
6169 format!(
6170 "static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val))"
6171 )
6172 } else {
6173 format!("-as_type<{unsigned_type_name}>(val)")
6174 };
6175 writeln!(self.out, "{level}return metal::select(as_type<{type_name}>({neg_expr}), val, val >= {zero});")?;
6176 writeln!(self.out, "}}")?;
6177 writeln!(self.out)?;
6178 }
6179
6180 crate::MathFunction::Dot => match *arg_ty {
6181 crate::TypeInner::Vector { size, scalar }
6182 if matches!(
6183 scalar.kind,
6184 crate::ScalarKind::Sint | crate::ScalarKind::Uint
6185 ) =>
6186 {
6187 let wrapped = WrappedFunction::Math {
6189 fun,
6190 arg_ty: (Some(size), scalar),
6191 };
6192 if !self.wrapped_functions.insert(wrapped) {
6193 return Ok(());
6194 }
6195
6196 let mut vec_ty = String::new();
6197 put_numeric_type(&mut vec_ty, scalar, &[size])?;
6198 let mut ret_ty = String::new();
6199 put_numeric_type(&mut ret_ty, scalar, &[])?;
6200
6201 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
6202
6203 writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
6205 let level = back::Level(1);
6206 write!(self.out, "{level}return ")?;
6207 self.put_dot_product("a", "b", size as usize, |writer, name, index| {
6208 write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
6209 Ok(())
6210 })?;
6211 writeln!(self.out, ";")?;
6212 writeln!(self.out, "}}")?;
6213 writeln!(self.out)?;
6214 }
6215 _ => {}
6216 },
6217
6218 _ => {}
6219 }
6220 Ok(())
6221 }
6222
6223 fn write_wrapped_cast(
6224 &mut self,
6225 module: &crate::Module,
6226 func_ctx: &back::FunctionCtx,
6227 expr: Handle<crate::Expression>,
6228 kind: crate::ScalarKind,
6229 convert: Option<crate::Bytes>,
6230 ) -> BackendResult {
6231 let src_ty = func_ctx.resolve_type(expr, &module.types);
6242 let Some(width) = convert else {
6243 return Ok(());
6244 };
6245 let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6246 return Ok(());
6247 };
6248 let dst_scalar = crate::Scalar { kind, width };
6249 if src_scalar.kind != crate::ScalarKind::Float
6250 || (dst_scalar.kind != crate::ScalarKind::Sint
6251 && dst_scalar.kind != crate::ScalarKind::Uint)
6252 {
6253 return Ok(());
6254 }
6255 let wrapped = WrappedFunction::Cast {
6256 src_scalar,
6257 vector_size,
6258 dst_scalar,
6259 };
6260 if !self.wrapped_functions.insert(wrapped) {
6261 return Ok(());
6262 }
6263 let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6264
6265 let mut src_type_name = String::new();
6266 match vector_size {
6267 None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6268 Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6269 };
6270 let mut dst_type_name = String::new();
6271 match vector_size {
6272 None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6273 Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6274 };
6275 let fun_name = match dst_scalar {
6276 crate::Scalar::I32 => F2I32_FUNCTION,
6277 crate::Scalar::U32 => F2U32_FUNCTION,
6278 crate::Scalar::I64 => F2I64_FUNCTION,
6279 crate::Scalar::U64 => F2U64_FUNCTION,
6280 _ => unreachable!(),
6281 };
6282
6283 writeln!(
6284 self.out,
6285 "{dst_type_name} {fun_name}({src_type_name} value) {{"
6286 )?;
6287 let level = back::Level(1);
6288 write!(
6289 self.out,
6290 "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6291 )?;
6292 self.put_literal(min)?;
6293 write!(self.out, ", ")?;
6294 self.put_literal(max)?;
6295 writeln!(self.out, "));")?;
6296 writeln!(self.out, "}}")?;
6297 writeln!(self.out)?;
6298 Ok(())
6299 }
6300
6301 fn write_convert_yuv_to_rgb_and_return(
6309 &mut self,
6310 level: back::Level,
6311 y: &str,
6312 uv: &str,
6313 params: &str,
6314 ) -> BackendResult {
6315 let l1 = level;
6316 let l2 = l1.next();
6317
6318 writeln!(
6320 self.out,
6321 "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6322 )?;
6323
6324 writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6327 writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6328 writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6329 writeln!(
6330 self.out,
6331 "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6332 )?;
6333
6334 writeln!(
6337 self.out,
6338 "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6339 )?;
6340
6341 writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6344 writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6345 writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6346 writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6347
6348 writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6349 Ok(())
6350 }
6351
6352 #[allow(clippy::too_many_arguments)]
6353 fn write_wrapped_image_load(
6354 &mut self,
6355 module: &crate::Module,
6356 func_ctx: &back::FunctionCtx,
6357 image: Handle<crate::Expression>,
6358 _coordinate: Handle<crate::Expression>,
6359 _array_index: Option<Handle<crate::Expression>>,
6360 _sample: Option<Handle<crate::Expression>>,
6361 _level: Option<Handle<crate::Expression>>,
6362 ) -> BackendResult {
6363 let class = match *func_ctx.resolve_type(image, &module.types) {
6365 crate::TypeInner::Image { class, .. } => class,
6366 _ => unreachable!(),
6367 };
6368 if class != crate::ImageClass::External {
6369 return Ok(());
6370 }
6371 let wrapped = WrappedFunction::ImageLoad { class };
6372 if !self.wrapped_functions.insert(wrapped) {
6373 return Ok(());
6374 }
6375
6376 writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6377 let l1 = back::Level(1);
6378 let l2 = l1.next();
6379 let l3 = l2.next();
6380 writeln!(
6381 self.out,
6382 "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6383 )?;
6384 writeln!(
6388 self.out,
6389 "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6390 )?;
6391 writeln!(
6392 self.out,
6393 "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6394 )?;
6395
6396 writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6398 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6399 writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6401 writeln!(self.out, "{l1}}} else {{")?;
6402
6403 writeln!(
6405 self.out,
6406 "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6407 )?;
6408 writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6409
6410 writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6412
6413 writeln!(self.out, "{l2}float2 uv;")?;
6414 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6415 writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6417 writeln!(self.out, "{l2}}} else {{")?;
6418 writeln!(
6420 self.out,
6421 "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6422 )?;
6423 writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6424 writeln!(
6425 self.out,
6426 "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6427 )?;
6428 writeln!(self.out, "{l2}}}")?;
6429
6430 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6431
6432 writeln!(self.out, "{l1}}}")?;
6433 writeln!(self.out, "}}")?;
6434 writeln!(self.out)?;
6435 Ok(())
6436 }
6437
6438 #[allow(clippy::too_many_arguments)]
6439 fn write_wrapped_image_sample(
6440 &mut self,
6441 module: &crate::Module,
6442 func_ctx: &back::FunctionCtx,
6443 image: Handle<crate::Expression>,
6444 _sampler: Handle<crate::Expression>,
6445 _gather: Option<crate::SwizzleComponent>,
6446 _coordinate: Handle<crate::Expression>,
6447 _array_index: Option<Handle<crate::Expression>>,
6448 _offset: Option<Handle<crate::Expression>>,
6449 _level: crate::SampleLevel,
6450 _depth_ref: Option<Handle<crate::Expression>>,
6451 clamp_to_edge: bool,
6452 ) -> BackendResult {
6453 if !clamp_to_edge {
6456 return Ok(());
6457 }
6458 let class = match *func_ctx.resolve_type(image, &module.types) {
6459 crate::TypeInner::Image { class, .. } => class,
6460 _ => unreachable!(),
6461 };
6462 let wrapped = WrappedFunction::ImageSample {
6463 class,
6464 clamp_to_edge: true,
6465 };
6466 if !self.wrapped_functions.insert(wrapped) {
6467 return Ok(());
6468 }
6469 match class {
6470 crate::ImageClass::External => {
6471 writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6472 let l1 = back::Level(1);
6473 let l2 = l1.next();
6474 let l3 = l2.next();
6475 writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6476 writeln!(
6477 self.out,
6478 "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6479 )?;
6480
6481 writeln!(
6489 self.out,
6490 "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6491 )?;
6492 writeln!(
6493 self.out,
6494 "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6495 )?;
6496 writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6497 writeln!(
6498 self.out,
6499 "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6500 )?;
6501 writeln!(
6502 self.out,
6503 "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6504 )?;
6505 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6506 writeln!(
6508 self.out,
6509 "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6510 )?;
6511 writeln!(self.out, "{l1}}} else {{")?;
6512 writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6513 writeln!(
6514 self.out,
6515 "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6516 )?;
6517 writeln!(
6518 self.out,
6519 "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6520 )?;
6521
6522 writeln!(
6524 self.out,
6525 "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6526 )?;
6527 writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6528 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6529 writeln!(
6531 self.out,
6532 "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6533 )?;
6534 writeln!(self.out, "{l2}}} else {{")?;
6535 writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6537 writeln!(
6538 self.out,
6539 "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6540 )?;
6541 writeln!(
6542 self.out,
6543 "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6544 )?;
6545 writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6546 writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6547 writeln!(self.out, "{l2}}}")?;
6548
6549 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6550
6551 writeln!(self.out, "{l1}}}")?;
6552 writeln!(self.out, "}}")?;
6553 writeln!(self.out)?;
6554 }
6555 _ => {
6556 writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6557 let l1 = back::Level(1);
6558 writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6559 writeln!(
6560 self.out,
6561 "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6562 )?;
6563 writeln!(self.out, "}}")?;
6564 writeln!(self.out)?;
6565 }
6566 }
6567 Ok(())
6568 }
6569
6570 fn write_wrapped_image_query(
6571 &mut self,
6572 module: &crate::Module,
6573 func_ctx: &back::FunctionCtx,
6574 image: Handle<crate::Expression>,
6575 query: crate::ImageQuery,
6576 ) -> BackendResult {
6577 if !matches!(query, crate::ImageQuery::Size { .. }) {
6579 return Ok(());
6580 }
6581 let class = match *func_ctx.resolve_type(image, &module.types) {
6582 crate::TypeInner::Image { class, .. } => class,
6583 _ => unreachable!(),
6584 };
6585 if class != crate::ImageClass::External {
6586 return Ok(());
6587 }
6588 let wrapped = WrappedFunction::ImageQuerySize { class };
6589 if !self.wrapped_functions.insert(wrapped) {
6590 return Ok(());
6591 }
6592 writeln!(
6593 self.out,
6594 "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6595 )?;
6596 let l1 = back::Level(1);
6597 let l2 = l1.next();
6598 writeln!(
6599 self.out,
6600 "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6601 )?;
6602 writeln!(self.out, "{l2}return tex.params.size;")?;
6603 writeln!(self.out, "{l1}}} else {{")?;
6604 writeln!(
6606 self.out,
6607 "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6608 )?;
6609 writeln!(self.out, "{l1}}}")?;
6610 writeln!(self.out, "}}")?;
6611 writeln!(self.out)?;
6612 Ok(())
6613 }
6614
6615 fn write_wrapped_cooperative_load(
6616 &mut self,
6617 module: &crate::Module,
6618 func_ctx: &back::FunctionCtx,
6619 columns: crate::CooperativeSize,
6620 rows: crate::CooperativeSize,
6621 pointer: Handle<crate::Expression>,
6622 ) -> BackendResult {
6623 let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6624 let space = ptr_ty.pointer_space().unwrap();
6625 let space_name = space.to_msl_name().unwrap_or_default();
6626 let scalar = ptr_ty
6627 .pointer_base_type()
6628 .unwrap()
6629 .inner_with(&module.types)
6630 .scalar()
6631 .unwrap();
6632 let wrapped = WrappedFunction::CooperativeLoad {
6633 space_name,
6634 columns,
6635 rows,
6636 scalar,
6637 };
6638 if !self.wrapped_functions.insert(wrapped) {
6639 return Ok(());
6640 }
6641 let scalar_name = scalar.to_msl_name();
6642 writeln!(
6643 self.out,
6644 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6645 columns as u32, rows as u32,
6646 )?;
6647 let l1 = back::Level(1);
6648 writeln!(
6649 self.out,
6650 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6651 columns as u32, rows as u32
6652 )?;
6653 let matrix_origin = "0";
6654 writeln!(
6655 self.out,
6656 "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6657 )?;
6658 writeln!(self.out, "{l1}return m;")?;
6659 writeln!(self.out, "}}")?;
6660 writeln!(self.out)?;
6661 Ok(())
6662 }
6663
6664 fn write_wrapped_cooperative_multiply_add(
6665 &mut self,
6666 module: &crate::Module,
6667 func_ctx: &back::FunctionCtx,
6668 space: crate::AddressSpace,
6669 a: Handle<crate::Expression>,
6670 b: Handle<crate::Expression>,
6671 c: Handle<crate::Expression>,
6672 ) -> BackendResult {
6673 let space_name = space.to_msl_name().unwrap_or_default();
6674 let (a_c, a_r, ab_scalar) = match *func_ctx.resolve_type(a, &module.types) {
6675 crate::TypeInner::CooperativeMatrix {
6676 columns,
6677 rows,
6678 scalar,
6679 ..
6680 } => (columns, rows, scalar),
6681 _ => unreachable!(),
6682 };
6683 let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6684 crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6685 _ => unreachable!(),
6686 };
6687 let c_scalar = match *func_ctx.resolve_type(c, &module.types) {
6688 crate::TypeInner::CooperativeMatrix { scalar, .. } => scalar,
6689 _ => unreachable!(),
6690 };
6691 let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6692 space_name,
6693 columns: b_c,
6694 rows: a_r,
6695 intermediate: a_c,
6696 ab_scalar,
6697 c_scalar,
6698 };
6699 if !self.wrapped_functions.insert(wrapped) {
6700 return Ok(());
6701 }
6702 let ab_scalar_name = ab_scalar.to_msl_name();
6703 let c_scalar_name = c_scalar.to_msl_name();
6704 writeln!(
6705 self.out,
6706 "{NAMESPACE}::simdgroup_{c_scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{ab_scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{ab_scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{c_scalar_name}{}x{}& c) {{",
6707 b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6708 )?;
6709 let l1 = back::Level(1);
6710 writeln!(
6711 self.out,
6712 "{l1}{NAMESPACE}::simdgroup_{c_scalar_name}{}x{} d;",
6713 b_c as u32, a_r as u32
6714 )?;
6715 writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6716 writeln!(self.out, "{l1}return d;")?;
6717 writeln!(self.out, "}}")?;
6718 writeln!(self.out)?;
6719 Ok(())
6720 }
6721
6722 pub(super) fn write_wrapped_functions(
6723 &mut self,
6724 module: &crate::Module,
6725 func_ctx: &back::FunctionCtx,
6726 options: &Options,
6727 ) -> BackendResult {
6728 for (expr_handle, expr) in func_ctx.expressions.iter() {
6729 match *expr {
6730 crate::Expression::Unary { op, expr: operand } => {
6731 self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6732 }
6733 crate::Expression::Binary { op, left, right } => {
6734 self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6735 }
6736 crate::Expression::Math {
6737 fun,
6738 arg,
6739 arg1,
6740 arg2,
6741 arg3,
6742 } => {
6743 self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6744 }
6745 crate::Expression::As {
6746 expr,
6747 kind,
6748 convert,
6749 } => {
6750 self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6751 }
6752 crate::Expression::ImageLoad {
6753 image,
6754 coordinate,
6755 array_index,
6756 sample,
6757 level,
6758 } => {
6759 self.write_wrapped_image_load(
6760 module,
6761 func_ctx,
6762 image,
6763 coordinate,
6764 array_index,
6765 sample,
6766 level,
6767 )?;
6768 }
6769 crate::Expression::ImageSample {
6770 image,
6771 sampler,
6772 gather,
6773 coordinate,
6774 array_index,
6775 offset,
6776 level,
6777 depth_ref,
6778 clamp_to_edge,
6779 } => {
6780 self.write_wrapped_image_sample(
6781 module,
6782 func_ctx,
6783 image,
6784 sampler,
6785 gather,
6786 coordinate,
6787 array_index,
6788 offset,
6789 level,
6790 depth_ref,
6791 clamp_to_edge,
6792 )?;
6793 }
6794 crate::Expression::ImageQuery { image, query } => {
6795 self.write_wrapped_image_query(module, func_ctx, image, query)?;
6796 }
6797 crate::Expression::CooperativeLoad {
6798 columns,
6799 rows,
6800 role: _,
6801 ref data,
6802 } => {
6803 self.write_wrapped_cooperative_load(
6804 module,
6805 func_ctx,
6806 columns,
6807 rows,
6808 data.pointer,
6809 )?;
6810 }
6811 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
6812 let space = crate::AddressSpace::Private;
6813 self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b, c)?;
6814 }
6815 crate::Expression::RayQueryGetIntersection { committed, .. } => {
6816 self.write_rq_get_intersection_function(module, committed, options)?;
6817 }
6818 _ => {}
6819 }
6820 }
6821
6822 Ok(())
6823 }
6824
6825 fn write_functions(
6827 &mut self,
6828 module: &crate::Module,
6829 mod_info: &valid::ModuleInfo,
6830 options: &Options,
6831 pipeline_options: &PipelineOptions,
6832 ) -> Result<TranslationInfo, Error> {
6833 use nt::VertexFormat;
6834
6835 struct AttributeMappingResolved {
6838 ty_name: String,
6839 dimension: Option<crate::VectorSize>,
6840 scalar: crate::Scalar,
6841 name: String,
6842 }
6843 let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6844
6845 struct VertexBufferMappingResolved<'a> {
6846 id: u32,
6847 stride: u32,
6848 step_mode: back::msl::VertexBufferStepMode,
6849 ty_name: String,
6850 param_name: String,
6851 elem_name: String,
6852 attributes: &'a Vec<back::msl::AttributeMapping>,
6853 }
6854 let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6855
6856 struct UnpackingFunction {
6858 name: String,
6859 byte_count: u32,
6860 dimension: Option<crate::VectorSize>,
6861 scalar: crate::Scalar,
6862 }
6863 let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6864
6865 let mut needs_vertex_id = false;
6871 let v_id = self.namer.call("v_id");
6872
6873 let mut needs_instance_id = false;
6874 let i_id = self.namer.call("i_id");
6875 if pipeline_options.vertex_pulling_transform {
6876 for vbm in &pipeline_options.vertex_buffer_mappings {
6877 let buffer_id = vbm.id;
6878 let buffer_stride = vbm.stride;
6879
6880 assert!(
6881 buffer_stride > 0,
6882 "Vertex pulling requires a non-zero buffer stride."
6883 );
6884
6885 match vbm.step_mode {
6886 back::msl::VertexBufferStepMode::Constant => {}
6887 back::msl::VertexBufferStepMode::ByVertex => {
6888 needs_vertex_id = true;
6889 }
6890 back::msl::VertexBufferStepMode::ByInstance => {
6891 needs_instance_id = true;
6892 }
6893 }
6894
6895 let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6896 let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6897 let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6898
6899 vbm_resolved.push(VertexBufferMappingResolved {
6900 id: buffer_id,
6901 stride: buffer_stride,
6902 step_mode: vbm.step_mode,
6903 ty_name: buffer_ty,
6904 param_name: buffer_param,
6905 elem_name: buffer_elem,
6906 attributes: &vbm.attributes,
6907 });
6908
6909 for attribute in &vbm.attributes {
6911 if unpacking_functions.contains_key(&attribute.format) {
6912 continue;
6913 }
6914 let (name, byte_count, dimension, scalar) =
6915 match self.write_unpacking_function(attribute.format) {
6916 Ok((name, byte_count, dimension, scalar)) => {
6917 (name, byte_count, dimension, scalar)
6918 }
6919 _ => {
6920 continue;
6921 }
6922 };
6923 unpacking_functions.insert(
6924 attribute.format,
6925 UnpackingFunction {
6926 name,
6927 byte_count,
6928 dimension,
6929 scalar,
6930 },
6931 );
6932 }
6933 }
6934 }
6935
6936 let mut pass_through_globals = Vec::new();
6937 for (fun_handle, fun) in module.functions.iter() {
6938 log::trace!(
6939 "function {:?}, handle {:?}",
6940 fun.name.as_deref().unwrap_or("(anonymous)"),
6941 fun_handle
6942 );
6943
6944 let ctx = back::FunctionCtx {
6945 ty: back::FunctionType::Function(fun_handle),
6946 info: &mod_info[fun_handle],
6947 expressions: &fun.expressions,
6948 named_expressions: &fun.named_expressions,
6949 };
6950
6951 writeln!(self.out)?;
6952 self.write_wrapped_functions(module, &ctx, options)?;
6953
6954 let fun_info = &mod_info[fun_handle];
6955 pass_through_globals.clear();
6956 let mut needs_buffer_sizes = false;
6957 for (handle, var) in module.global_variables.iter() {
6958 if !fun_info[handle].is_empty() {
6959 if var.space.needs_pass_through() {
6960 pass_through_globals.push(handle);
6961 }
6962 needs_buffer_sizes |= module.types[var.ty]
6963 .inner
6964 .needs_host_buffer_byte_size(&module.types);
6965 }
6966 }
6967
6968 let fun_name = &self.names[&NameKey::Function(fun_handle)];
6969 match fun.result {
6970 Some(ref result) => {
6971 let ty_name = TypeContext {
6972 handle: result.ty,
6973 gctx: module.to_ctx(),
6974 names: &self.names,
6975 access: crate::StorageAccess::empty(),
6976 first_time: false,
6977 };
6978 write!(self.out, "{ty_name}")?;
6979 }
6980 None => {
6981 write!(self.out, "void")?;
6982 }
6983 }
6984 writeln!(self.out, " {fun_name}(")?;
6985
6986 for (index, arg) in fun.arguments.iter().enumerate() {
6987 let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6988 let param_type_name = TypeContext {
6989 handle: arg.ty,
6990 gctx: module.to_ctx(),
6991 names: &self.names,
6992 access: crate::StorageAccess::empty(),
6993 first_time: false,
6994 };
6995 let separator = separate(
6996 !pass_through_globals.is_empty()
6997 || index + 1 != fun.arguments.len()
6998 || needs_buffer_sizes,
6999 );
7000 writeln!(
7001 self.out,
7002 "{}{} {}{}",
7003 back::INDENT,
7004 param_type_name,
7005 name,
7006 separator
7007 )?;
7008 }
7009 for (index, &handle) in pass_through_globals.iter().enumerate() {
7010 let tyvar = TypedGlobalVariable {
7011 module,
7012 names: &self.names,
7013 handle,
7014 usage: fun_info[handle],
7015 reference: true,
7016 };
7017 let separator =
7018 separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
7019 write!(self.out, "{}", back::INDENT)?;
7020 tyvar.try_fmt(&mut self.out)?;
7021 writeln!(self.out, "{separator}")?;
7022 }
7023
7024 if needs_buffer_sizes {
7025 writeln!(
7026 self.out,
7027 "{}constant _mslBufferSizes& _buffer_sizes",
7028 back::INDENT
7029 )?;
7030 }
7031
7032 writeln!(self.out, ") {{")?;
7033
7034 let guarded_indices =
7035 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
7036
7037 let context = StatementContext {
7038 expression: ExpressionContext {
7039 function: fun,
7040 origin: FunctionOrigin::Handle(fun_handle),
7041 info: fun_info,
7042 lang_version: options.lang_version,
7043 policies: options.bounds_check_policies,
7044 guarded_indices,
7045 module,
7046 mod_info,
7047 pipeline_options,
7048 force_loop_bounding: options.force_loop_bounding,
7049 emit_int_div_checks: options.emit_int_div_checks,
7050 ray_query_initialization_tracking: options.ray_query_initialization_tracking,
7051 },
7052 result_struct: None,
7053 };
7054
7055 self.put_locals(&context.expression)?;
7056 self.update_expressions_to_bake(fun, fun_info, &context.expression);
7057 self.put_block(back::Level(1), &fun.body, &context)?;
7058 writeln!(self.out, "}}")?;
7059 self.named_expressions.clear();
7060 }
7061
7062 let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
7063 .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
7064
7065 let mut info = TranslationInfo {
7066 entry_point_names: Vec::with_capacity(ep_range.len()),
7067 };
7068
7069 for ep_index in ep_range {
7070 let ep = &module.entry_points[ep_index];
7071 let fun = &ep.function;
7072 let fun_info = mod_info.get_entry_point(ep_index);
7073 let mut ep_error = None;
7074
7075 let mut v_existing_id = None;
7079 let mut i_existing_id = None;
7080
7081 log::trace!(
7082 "entry point {:?}, index {:?}",
7083 fun.name.as_deref().unwrap_or("(anonymous)"),
7084 ep_index
7085 );
7086
7087 let ctx = back::FunctionCtx {
7088 ty: back::FunctionType::EntryPoint(ep_index as u16),
7089 info: fun_info,
7090 expressions: &fun.expressions,
7091 named_expressions: &fun.named_expressions,
7092 };
7093
7094 self.write_wrapped_functions(module, &ctx, options)?;
7095
7096 let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
7097 crate::ShaderStage::Vertex => (
7098 Some("vertex"),
7099 LocationMode::VertexInput,
7100 LocationMode::VertexOutput,
7101 true,
7102 ),
7103 crate::ShaderStage::Fragment => (
7104 Some("fragment"),
7105 LocationMode::FragmentInput,
7106 LocationMode::FragmentOutput,
7107 false,
7108 ),
7109 crate::ShaderStage::Compute => (
7110 Some("kernel"),
7111 LocationMode::Uniform,
7112 LocationMode::Uniform,
7113 false,
7114 ),
7115 crate::ShaderStage::Task => {
7116 (None, LocationMode::Uniform, LocationMode::Uniform, false)
7117 }
7118 crate::ShaderStage::Mesh => {
7119 (None, LocationMode::Uniform, LocationMode::MeshOutput, false)
7120 }
7121 crate::ShaderStage::RayGeneration
7122 | crate::ShaderStage::AnyHit
7123 | crate::ShaderStage::ClosestHit
7124 | crate::ShaderStage::Miss => unimplemented!(),
7125 };
7126
7127 let do_vertex_pulling = can_vertex_pull
7129 && pipeline_options.vertex_pulling_transform
7130 && !pipeline_options.vertex_buffer_mappings.is_empty();
7131
7132 let needs_buffer_sizes = do_vertex_pulling
7134 || module
7135 .global_variables
7136 .iter()
7137 .filter(|&(handle, _)| !fun_info[handle].is_empty())
7138 .any(|(_, var)| {
7139 module.types[var.ty]
7140 .inner
7141 .needs_host_buffer_byte_size(&module.types)
7142 });
7143
7144 if !options.fake_missing_bindings {
7147 for (var_handle, var) in module.global_variables.iter() {
7148 if fun_info[var_handle].is_empty() {
7149 continue;
7150 }
7151 match var.space {
7152 crate::AddressSpace::Uniform
7153 | crate::AddressSpace::Storage { .. }
7154 | crate::AddressSpace::Handle => {
7155 let br = match var.binding {
7156 Some(ref br) => br,
7157 None => {
7158 let var_name = var.name.clone().unwrap_or_default();
7159 ep_error =
7160 Some(super::EntryPointError::MissingBinding(var_name));
7161 break;
7162 }
7163 };
7164 let target = options.get_resource_binding_target(ep, br);
7165 let good = match target {
7166 Some(target) => {
7167 match module.types[var.ty].inner {
7171 crate::TypeInner::Image {
7172 class: crate::ImageClass::External,
7173 ..
7174 } => target.external_texture.is_some(),
7175 crate::TypeInner::Image { .. } => target.texture.is_some(),
7176 crate::TypeInner::Sampler { .. } => {
7177 target.sampler.is_some()
7178 }
7179 _ => target.buffer.is_some(),
7180 }
7181 }
7182 None => false,
7183 };
7184 if !good {
7185 ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
7186 break;
7187 }
7188 }
7189 crate::AddressSpace::Immediate => {
7190 if let Err(e) = options.resolve_immediates(ep) {
7191 ep_error = Some(e);
7192 break;
7193 }
7194 }
7195 crate::AddressSpace::Function
7196 | crate::AddressSpace::Private
7197 | crate::AddressSpace::WorkGroup
7198 | crate::AddressSpace::TaskPayload => {}
7199 crate::AddressSpace::RayPayload
7200 | crate::AddressSpace::IncomingRayPayload => unimplemented!(),
7201 }
7202 }
7203 if needs_buffer_sizes {
7204 if let Err(err) = options.resolve_sizes_buffer(ep) {
7205 ep_error = Some(err);
7206 }
7207 }
7208 }
7209
7210 if let Some(err) = ep_error {
7211 info.entry_point_names.push(Err(err));
7212 continue;
7213 }
7214 let fun_name = self.names[&NameKey::EntryPoint(ep_index as _)].clone();
7215 info.entry_point_names.push(Ok(fun_name.clone()));
7216
7217 writeln!(self.out)?;
7218
7219 let mut flattened_member_names = FastHashMap::default();
7225 let mut varyings_namer = proc::Namer::default();
7227
7228 let mut empty_names = FastHashMap::default(); varyings_namer.reset(
7230 module,
7231 &super::keywords::RESERVED_SET,
7232 proc::KeywordSet::empty(),
7233 proc::CaseInsensitiveKeywordSet::empty(),
7234 &[CLAMPED_LOD_LOAD_PREFIX],
7235 &mut empty_names,
7236 );
7237
7238 let mut flattened_arguments = Vec::new();
7243 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7244 match module.types[arg.ty].inner {
7245 crate::TypeInner::Struct { ref members, .. } => {
7246 for (member_index, member) in members.iter().enumerate() {
7247 let member_index = member_index as u32;
7248 flattened_arguments.push((
7249 NameKey::StructMember(arg.ty, member_index),
7250 member.ty,
7251 member.binding.as_ref(),
7252 ));
7253 let name_key = NameKey::StructMember(arg.ty, member_index);
7254 let name = match member.binding {
7255 Some(crate::Binding::Location { .. }) => {
7256 if do_vertex_pulling {
7257 self.namer.call(&self.names[&name_key])
7258 } else {
7259 varyings_namer.call(&self.names[&name_key])
7260 }
7261 }
7262 _ => self.namer.call(&self.names[&name_key]),
7263 };
7264 flattened_member_names.insert(name_key, name);
7265 }
7266 }
7267 _ => flattened_arguments.push((
7268 NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7269 arg.ty,
7270 arg.binding.as_ref(),
7271 )),
7272 }
7273 }
7274
7275 let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7280 let varyings_member_name = self.namer.call("varyings");
7281 let mut has_varyings = false;
7282
7283 if !flattened_arguments.is_empty() {
7284 if !do_vertex_pulling {
7285 writeln!(self.out, "struct {stage_in_name} {{")?;
7286 }
7287 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7288 let Some(binding) = binding else {
7289 continue;
7290 };
7291 let name = match *name_key {
7292 NameKey::StructMember(..) => &flattened_member_names[name_key],
7293 _ => &self.names[name_key],
7294 };
7295 let ty_name = TypeContext {
7296 handle: ty,
7297 gctx: module.to_ctx(),
7298 names: &self.names,
7299 access: crate::StorageAccess::empty(),
7300 first_time: false,
7301 };
7302 let resolved = options.resolve_local_binding(binding, in_mode)?;
7303 let location = match *binding {
7304 crate::Binding::Location { location, .. } => Some(location),
7305 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7306 crate::Binding::BuiltIn(_) => continue,
7307 };
7308 if do_vertex_pulling {
7309 let Some(location) = location else {
7310 continue;
7311 };
7312 am_resolved.insert(
7314 location,
7315 AttributeMappingResolved {
7316 ty_name: ty_name.to_string(),
7317 dimension: ty_name.vector_size(),
7318 scalar: ty_name.scalar().unwrap(),
7319 name: name.to_string(),
7320 },
7321 );
7322 } else {
7323 has_varyings = true;
7324 if let super::ResolvedBinding::User {
7325 prefix,
7326 index,
7327 interpolation: Some(super::ResolvedInterpolation::PerVertex),
7328 } = resolved
7329 {
7330 if options.lang_version < (4, 0) {
7331 return Err(Error::PerVertexNotSupported);
7332 }
7333 write!(
7334 self.out,
7335 "{}{NAMESPACE}::vertex_value<{}> {name} [[user({prefix}{index})]]",
7336 back::INDENT,
7337 ty_name.unwrap_array()
7338 )?;
7339 } else {
7340 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7341 resolved.try_fmt(&mut self.out)?;
7342 }
7343 writeln!(self.out, ";")?;
7344 }
7345 }
7346 if !do_vertex_pulling {
7347 writeln!(self.out, "}};")?;
7348 }
7349 }
7350
7351 let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7354 let result_member_name = self.namer.call("member");
7355 let result_type_name = match fun.result {
7356 Some(ref result) if ep.stage != crate::ShaderStage::Task => {
7357 let mut result_members = Vec::new();
7358 if let crate::TypeInner::Struct { ref members, .. } =
7359 module.types[result.ty].inner
7360 {
7361 for (member_index, member) in members.iter().enumerate() {
7362 result_members.push((
7363 &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7364 member.ty,
7365 member.binding.as_ref(),
7366 ));
7367 }
7368 } else {
7369 result_members.push((
7370 &result_member_name,
7371 result.ty,
7372 result.binding.as_ref(),
7373 ));
7374 }
7375
7376 writeln!(self.out, "struct {stage_out_name} {{")?;
7377 let mut has_point_size = false;
7378 for (name, ty, binding) in result_members {
7379 let ty_name = TypeContext {
7380 handle: ty,
7381 gctx: module.to_ctx(),
7382 names: &self.names,
7383 access: crate::StorageAccess::empty(),
7384 first_time: true,
7385 };
7386 let binding = binding.ok_or_else(|| {
7387 Error::GenericValidation("Expected binding, got None".into())
7388 })?;
7389
7390 if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7391 has_point_size = true;
7392 if !pipeline_options.allow_and_force_point_size {
7393 continue;
7394 }
7395 }
7396
7397 let array_len = match module.types[ty].inner {
7398 crate::TypeInner::Array {
7399 size: crate::ArraySize::Constant(size),
7400 ..
7401 } => Some(size),
7402 _ => None,
7403 };
7404 let resolved = options.resolve_local_binding(binding, out_mode)?;
7405 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7406 resolved.try_fmt(&mut self.out)?;
7407 if let Some(array_len) = array_len {
7408 write!(self.out, " [{array_len}]")?;
7409 }
7410 writeln!(self.out, ";")?;
7411 }
7412
7413 if pipeline_options.allow_and_force_point_size
7414 && ep.stage == crate::ShaderStage::Vertex
7415 && !has_point_size
7416 {
7417 writeln!(
7419 self.out,
7420 "{}float _point_size [[point_size]];",
7421 back::INDENT
7422 )?;
7423 }
7424 writeln!(self.out, "}};")?;
7425 &stage_out_name
7426 }
7427 Some(ref result) if ep.stage == crate::ShaderStage::Task => {
7428 assert_eq!(
7429 module.types[result.ty].inner,
7430 crate::TypeInner::Vector {
7431 size: crate::VectorSize::Tri,
7432 scalar: crate::Scalar::U32
7433 }
7434 );
7435
7436 "metal::uint3"
7437 }
7438 _ => "void",
7439 };
7440
7441 let out_mesh_info = if let Some(ref mesh_info) = ep.mesh_info {
7442 Some(self.write_mesh_output_types(
7443 mesh_info,
7444 &fun_name,
7445 module,
7446 pipeline_options.allow_and_force_point_size,
7447 options,
7448 )?)
7449 } else {
7450 None
7451 };
7452
7453 if do_vertex_pulling {
7456 for vbm in &vbm_resolved {
7457 let buffer_stride = vbm.stride;
7458 let buffer_ty = &vbm.ty_name;
7459
7460 writeln!(
7464 self.out,
7465 "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7466 )?;
7467 }
7468 }
7469
7470 let is_wrapped = matches!(
7471 ep.stage,
7472 crate::ShaderStage::Task | crate::ShaderStage::Mesh
7473 );
7474 let fun_name = fun_name.clone();
7475 let nested_fun_name = if is_wrapped {
7476 self.namer.call(&format!("_{fun_name}"))
7477 } else {
7478 fun_name.clone()
7479 };
7480
7481 if ep.stage == crate::ShaderStage::Compute && options.lang_version >= (2, 1) {
7483 let total_threads =
7484 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2];
7485 write!(
7486 self.out,
7487 "[[max_total_threads_per_threadgroup({total_threads})]] "
7488 )?;
7489 }
7490
7491 if let Some(em_str) = em_str {
7493 write!(self.out, "{em_str} ")?;
7494 }
7495 writeln!(self.out, "{result_type_name} {nested_fun_name}(")?;
7496
7497 let mut args = Vec::new();
7498
7499 if has_varyings {
7502 args.push(EntryPointArgument {
7503 ty_name: stage_in_name,
7504 name: varyings_member_name.clone(),
7505 binding: " [[stage_in]]".to_string(),
7506 init: None,
7507 });
7508 }
7509
7510 let mut local_invocation_index = None;
7511
7512 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7515 let binding = match binding {
7516 Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7517 Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7518 _ => continue,
7519 };
7520 let name = match *name_key {
7521 NameKey::StructMember(..) => &flattened_member_names[name_key],
7522 _ => &self.names[name_key],
7523 };
7524
7525 if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
7526 local_invocation_index = Some(name_key);
7527 }
7528
7529 let ty_name = TypeContext {
7530 handle: ty,
7531 gctx: module.to_ctx(),
7532 names: &self.names,
7533 access: crate::StorageAccess::empty(),
7534 first_time: false,
7535 };
7536
7537 match *binding {
7538 crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7539 v_existing_id = Some(name.clone());
7540 }
7541 crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7542 i_existing_id = Some(name.clone());
7543 }
7544 _ => {}
7545 };
7546
7547 let resolved = options.resolve_local_binding(binding, in_mode)?;
7548 let mut binding = String::new();
7549 resolved.try_fmt(&mut binding)?;
7550
7551 args.push(EntryPointArgument {
7552 ty_name: format!("{ty_name}"),
7553 name: name.clone(),
7554 binding,
7555 init: None,
7556 });
7557 }
7558
7559 let need_workgroup_variables_initialization =
7560 self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7561
7562 if local_invocation_index.is_none()
7563 && (need_workgroup_variables_initialization
7564 || ep.stage == crate::ShaderStage::Task
7565 || ep.stage == crate::ShaderStage::Mesh)
7566 {
7567 args.push(EntryPointArgument {
7568 ty_name: "uint".to_string(),
7569 name: "__local_invocation_index".to_string(),
7570 binding: " [[thread_index_in_threadgroup]]".to_string(),
7571 init: None,
7572 });
7573 }
7574
7575 for (handle, var) in module.global_variables.iter() {
7580 let usage = fun_info[handle];
7581 if usage.is_empty() || var.space == crate::AddressSpace::Private {
7582 continue;
7583 }
7584
7585 if options.lang_version < (1, 2) {
7586 match var.space {
7587 crate::AddressSpace::Storage { access }
7597 if access.contains(crate::StorageAccess::STORE)
7598 && ep.stage == crate::ShaderStage::Fragment =>
7599 {
7600 return Err(Error::UnsupportedWritableStorageBuffer)
7601 }
7602 crate::AddressSpace::Handle => {
7603 match module.types[var.ty].inner {
7604 crate::TypeInner::Image {
7605 class: crate::ImageClass::Storage { access, .. },
7606 ..
7607 } => {
7608 if access.contains(crate::StorageAccess::STORE)
7618 && (ep.stage == crate::ShaderStage::Vertex
7619 || ep.stage == crate::ShaderStage::Fragment)
7620 {
7621 return Err(Error::UnsupportedWritableStorageTexture(
7622 ep.stage,
7623 ));
7624 }
7625
7626 if access.contains(
7627 crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7628 ) {
7629 return Err(Error::UnsupportedRWStorageTexture);
7630 }
7631 }
7632 _ => {}
7633 }
7634 }
7635 _ => {}
7636 }
7637 }
7638
7639 match var.space {
7641 crate::AddressSpace::Handle => match module.types[var.ty].inner {
7642 crate::TypeInner::BindingArray { base, .. } => {
7643 match module.types[base].inner {
7644 crate::TypeInner::Sampler { .. } => {
7645 if options.lang_version < (2, 0) {
7646 return Err(Error::UnsupportedArrayOf(
7647 "samplers".to_string(),
7648 ));
7649 }
7650 }
7651 crate::TypeInner::Image { class, .. } => match class {
7652 crate::ImageClass::Sampled { .. }
7653 | crate::ImageClass::Depth { .. }
7654 | crate::ImageClass::Storage {
7655 access: crate::StorageAccess::LOAD,
7656 ..
7657 } => {
7658 if options.lang_version < (2, 0) {
7663 return Err(Error::UnsupportedArrayOf(
7664 "textures".to_string(),
7665 ));
7666 }
7667 }
7668 crate::ImageClass::Storage {
7669 access: crate::StorageAccess::STORE,
7670 ..
7671 } => {
7672 if options.lang_version < (2, 0) {
7677 return Err(Error::UnsupportedArrayOf(
7678 "write-only textures".to_string(),
7679 ));
7680 }
7681 }
7682 crate::ImageClass::Storage { .. } => {
7683 if options.lang_version < (3, 0) {
7684 return Err(Error::UnsupportedArrayOf(
7685 "read-write textures".to_string(),
7686 ));
7687 }
7688 }
7689 crate::ImageClass::External => {
7690 return Err(Error::UnsupportedArrayOf(
7691 "external textures".to_string(),
7692 ));
7693 }
7694 },
7695 _ => {
7696 return Err(Error::UnsupportedArrayOfType(base));
7697 }
7698 }
7699 }
7700 _ => {}
7701 },
7702 _ => {}
7703 }
7704
7705 let resolved = match var.space {
7707 crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7708 crate::AddressSpace::WorkGroup => None,
7709 crate::AddressSpace::TaskPayload => Some(back::msl::ResolvedBinding::Payload),
7710 _ => options
7711 .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7712 .ok(),
7713 };
7714 if let Some(ref resolved) = resolved {
7715 if resolved.as_inline_sampler(options).is_some() {
7717 continue;
7718 }
7719 }
7720
7721 match module.types[var.ty].inner {
7722 crate::TypeInner::Image {
7723 class: crate::ImageClass::External,
7724 ..
7725 } => {
7726 let target = match resolved {
7730 Some(back::msl::ResolvedBinding::Resource(target)) => {
7731 target.external_texture
7732 }
7733 _ => None,
7734 };
7735
7736 for i in 0..3 {
7737 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7738 handle,
7739 ExternalTextureNameKey::Plane(i),
7740 )];
7741 let ty_name = format!(
7742 "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample>"
7743 );
7744 let name = plane_name.clone();
7745 let binding = if let Some(ref target) = target {
7746 format!(" [[texture({})]]", target.planes[i])
7747 } else {
7748 String::new()
7749 };
7750 args.push(EntryPointArgument {
7751 ty_name,
7752 name,
7753 binding,
7754 init: None,
7755 });
7756 }
7757 let params_ty_name = &self.names
7758 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7759 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7760 handle,
7761 ExternalTextureNameKey::Params,
7762 )];
7763 let binding = if let Some(ref target) = target {
7764 format!(" [[buffer({})]]", target.params)
7765 } else {
7766 String::new()
7767 };
7768
7769 args.push(EntryPointArgument {
7770 ty_name: format!("constant {params_ty_name}&"),
7771 name: params_name.clone(),
7772 binding,
7773 init: None,
7774 });
7775 }
7776 _ => {
7777 if var.space == crate::AddressSpace::WorkGroup
7778 && ep.stage == crate::ShaderStage::Mesh
7779 {
7780 continue;
7781 }
7782 let tyvar = TypedGlobalVariable {
7783 module,
7784 names: &self.names,
7785 handle,
7786 usage,
7787 reference: true,
7788 };
7789 let parts = tyvar.to_parts()?;
7790 let mut binding = String::new();
7791 if let Some(resolved) = resolved {
7792 resolved.try_fmt(&mut binding)?;
7793 }
7794 args.push(EntryPointArgument {
7795 ty_name: parts.ty_name,
7796 name: parts.var_name,
7797 binding,
7798 init: var.init,
7799 });
7800 }
7801 }
7802 }
7803
7804 if do_vertex_pulling {
7805 if needs_vertex_id && v_existing_id.is_none() {
7806 args.push(EntryPointArgument {
7808 ty_name: "uint".to_string(),
7809 name: v_id.clone(),
7810 binding: " [[vertex_id]]".to_string(),
7811 init: None,
7812 });
7813 }
7814
7815 if needs_instance_id && i_existing_id.is_none() {
7816 args.push(EntryPointArgument {
7817 ty_name: "uint".to_string(),
7818 name: i_id.clone(),
7819 binding: " [[instance_id]]".to_string(),
7820 init: None,
7821 });
7822 }
7823
7824 for vbm in &vbm_resolved {
7827 let id = &vbm.id;
7828 let ty_name = &vbm.ty_name;
7829 let param_name = &vbm.param_name;
7830 args.push(EntryPointArgument {
7831 ty_name: format!("const device {ty_name}*"),
7832 name: param_name.clone(),
7833 binding: format!(" [[buffer({id})]]"),
7834 init: None,
7835 });
7836 }
7837 }
7838
7839 if needs_buffer_sizes {
7842 let resolved = options.resolve_sizes_buffer(ep).unwrap();
7844 let mut binding = String::new();
7845 resolved.try_fmt(&mut binding)?;
7846 args.push(EntryPointArgument {
7847 ty_name: "constant _mslBufferSizes&".to_string(),
7848 name: "_buffer_sizes".to_string(),
7849 binding,
7850 init: None,
7851 });
7852 }
7853
7854 let mut is_first_arg = true;
7855 for arg in &args {
7856 if is_first_arg {
7857 write!(self.out, " ")?;
7858 } else {
7859 write!(self.out, ", ")?;
7860 }
7861 is_first_arg = false;
7862 write!(self.out, "{} {}", arg.ty_name, arg.name)?;
7863 if !is_wrapped {
7864 write!(self.out, "{}", arg.binding)?;
7865 if let Some(init) = arg.init {
7866 write!(self.out, " = ")?;
7867 self.put_const_expression(
7868 init,
7869 module,
7870 mod_info,
7871 &module.global_expressions,
7872 )?;
7873 }
7874 }
7875 writeln!(self.out)?;
7876 }
7877 if ep.stage == crate::ShaderStage::Mesh {
7878 for (handle, var) in module.global_variables.iter() {
7879 if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
7880 continue;
7881 }
7882 if is_first_arg {
7883 write!(self.out, " ")?;
7884 } else {
7885 write!(self.out, ", ")?;
7886 }
7887 let ty_context = TypeContext {
7888 handle: module.global_variables[handle].ty,
7889 gctx: module.to_ctx(),
7890 names: &self.names,
7891 access: crate::StorageAccess::empty(),
7892 first_time: false,
7893 };
7894 writeln!(
7895 self.out,
7896 "threadgroup {ty_context}& {}",
7897 self.names[&NameKey::GlobalVariable(handle)]
7898 )?;
7899 }
7900 }
7901
7902 writeln!(self.out, ") {{")?;
7904
7905 if do_vertex_pulling {
7907 for vbm in &vbm_resolved {
7910 for attribute in vbm.attributes {
7911 let location = attribute.shader_location;
7912 let am_option = am_resolved.get(&location);
7913 if am_option.is_none() {
7914 continue;
7917 }
7918 let am = am_option.unwrap();
7919 let attribute_ty_name = &am.ty_name;
7920 let attribute_name = &am.name;
7921
7922 writeln!(
7923 self.out,
7924 "{}{attribute_ty_name} {attribute_name} = {{}};",
7925 back::Level(1)
7926 )?;
7927 }
7928
7929 write!(self.out, "{}if (", back::Level(1))?;
7932
7933 let idx = &vbm.id;
7934 let stride = &vbm.stride;
7935 let index_name = match vbm.step_mode {
7936 back::msl::VertexBufferStepMode::Constant => "0",
7937 back::msl::VertexBufferStepMode::ByVertex => {
7938 if let Some(ref name) = v_existing_id {
7939 name
7940 } else {
7941 &v_id
7942 }
7943 }
7944 back::msl::VertexBufferStepMode::ByInstance => {
7945 if let Some(ref name) = i_existing_id {
7946 name
7947 } else {
7948 &i_id
7949 }
7950 }
7951 };
7952 write!(
7953 self.out,
7954 "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7955 )?;
7956
7957 writeln!(self.out, ") {{")?;
7958
7959 let ty_name = &vbm.ty_name;
7961 let elem_name = &vbm.elem_name;
7962 let param_name = &vbm.param_name;
7963
7964 writeln!(
7965 self.out,
7966 "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7967 back::Level(2),
7968 )?;
7969
7970 for attribute in vbm.attributes {
7973 let location = attribute.shader_location;
7974 let Some(am) = am_resolved.get(&location) else {
7975 continue;
7979 };
7980 let attribute_name = &am.name;
7981 let attribute_ty_name = &am.ty_name;
7982
7983 let offset = attribute.offset;
7984 let func = unpacking_functions
7985 .get(&attribute.format)
7986 .expect("Should have generated this unpacking function earlier.");
7987 let func_name = &func.name;
7988
7989 let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7995
7996 let needs_conversion = am.scalar != func.scalar;
7999
8000 if needs_padding_or_truncation != Ordering::Equal {
8001 writeln!(
8004 self.out,
8005 "{}// {attribute_ty_name} <- {:?}",
8006 back::Level(2),
8007 attribute.format
8008 )?;
8009 }
8010
8011 write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
8012
8013 if needs_padding_or_truncation == Ordering::Greater {
8014 write!(self.out, "{attribute_ty_name}(")?;
8016 }
8017
8018 if needs_conversion {
8020 put_numeric_type(&mut self.out, am.scalar, func.dimension.as_slice())?;
8021 write!(self.out, "(")?;
8022 }
8023 write!(self.out, "{func_name}({elem_name}.data[{offset}]")?;
8024 for i in (offset + 1)..(offset + func.byte_count) {
8025 write!(self.out, ", {elem_name}.data[{i}]")?;
8026 }
8027 write!(self.out, ")")?;
8028 if needs_conversion {
8029 write!(self.out, ")")?;
8030 }
8031
8032 match needs_padding_or_truncation {
8033 Ordering::Greater => {
8034 let ty_is_int = scalar_is_int(am.scalar);
8036 let zero_value = if ty_is_int { "0" } else { "0.0" };
8037 let one_value = if ty_is_int { "1" } else { "1.0" };
8038 for i in func.dimension.map_or(1, u8::from)
8039 ..am.dimension.map_or(1, u8::from)
8040 {
8041 write!(
8042 self.out,
8043 ", {}",
8044 if i == 3 { one_value } else { zero_value }
8045 )?;
8046 }
8047 }
8048 Ordering::Less => {
8049 write!(
8051 self.out,
8052 ".{}",
8053 &"xyzw"[0..usize::from(am.dimension.map_or(1, u8::from))]
8054 )?;
8055 }
8056 Ordering::Equal => {}
8057 }
8058
8059 if needs_padding_or_truncation == Ordering::Greater {
8060 write!(self.out, ")")?;
8061 }
8062
8063 writeln!(self.out, ";")?;
8064 }
8065
8066 writeln!(self.out, "{}}}", back::Level(1))?;
8068 }
8069 }
8070
8071 for (handle, var) in module.global_variables.iter() {
8074 let usage = fun_info[handle];
8075 if usage.is_empty() {
8076 continue;
8077 }
8078 if var.space == crate::AddressSpace::Private {
8079 let tyvar = TypedGlobalVariable {
8080 module,
8081 names: &self.names,
8082 handle,
8083 usage,
8084
8085 reference: false,
8086 };
8087 write!(self.out, "{}", back::INDENT)?;
8088 tyvar.try_fmt(&mut self.out)?;
8089 match var.init {
8090 Some(value) => {
8091 write!(self.out, " = ")?;
8092 self.put_const_expression(
8093 value,
8094 module,
8095 mod_info,
8096 &module.global_expressions,
8097 )?;
8098 writeln!(self.out, ";")?;
8099 }
8100 None => {
8101 writeln!(self.out, " = {{}};")?;
8102 }
8103 };
8104 } else if let Some(ref binding) = var.binding {
8105 let resolved = options.resolve_resource_binding(ep, binding).unwrap();
8106 if let Some(sampler) = resolved.as_inline_sampler(options) {
8107 let name = &self.names[&NameKey::GlobalVariable(handle)];
8109 writeln!(
8110 self.out,
8111 "{}constexpr {}::sampler {}(",
8112 back::INDENT,
8113 NAMESPACE,
8114 name
8115 )?;
8116 self.put_inline_sampler_properties(back::Level(2), sampler)?;
8117 writeln!(self.out, "{});", back::INDENT)?;
8118 } else if let crate::TypeInner::Image {
8119 class: crate::ImageClass::External,
8120 ..
8121 } = module.types[var.ty].inner
8122 {
8123 let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
8126 let l1 = back::Level(1);
8127 let l2 = l1.next();
8128 writeln!(
8129 self.out,
8130 "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
8131 )?;
8132 for i in 0..3 {
8133 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
8134 handle,
8135 ExternalTextureNameKey::Plane(i),
8136 )];
8137 writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
8138 }
8139 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
8140 handle,
8141 ExternalTextureNameKey::Params,
8142 )];
8143 writeln!(self.out, "{l2}.params = {params_name},")?;
8144 writeln!(self.out, "{l1}}};")?;
8145 }
8146 }
8147 }
8148
8149 if need_workgroup_variables_initialization {
8150 self.write_workgroup_variables_initialization(
8151 module,
8152 mod_info,
8153 fun_info,
8154 local_invocation_index,
8155 ep.stage,
8156 )?;
8157 }
8158
8159 for (arg_index, arg) in fun.arguments.iter().enumerate() {
8170 let arg_name =
8171 &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
8172 match module.types[arg.ty].inner {
8173 crate::TypeInner::Struct { ref members, .. } => {
8174 let struct_name = &self.names[&NameKey::Type(arg.ty)];
8175 write!(
8176 self.out,
8177 "{}const {} {} = {{ ",
8178 back::INDENT,
8179 struct_name,
8180 arg_name
8181 )?;
8182 for (member_index, member) in members.iter().enumerate() {
8183 let key = NameKey::StructMember(arg.ty, member_index as u32);
8184 let name = &flattened_member_names[&key];
8185 if member_index != 0 {
8186 write!(self.out, ", ")?;
8187 }
8188 if self
8190 .struct_member_pads
8191 .contains(&(arg.ty, member_index as u32))
8192 {
8193 write!(self.out, "{{}}, ")?;
8194 }
8195 match member.binding {
8196 Some(crate::Binding::Location {
8197 interpolation: Some(crate::Interpolation::PerVertex),
8198 ..
8199 }) => {
8200 writeln!(
8201 self.out,
8202 "{0}{{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }}",
8203 back::INDENT,
8204 varyings_member_name,
8205 arg_name,
8206 )?;
8207 continue;
8208 }
8209 Some(crate::Binding::Location { .. }) => {
8210 if has_varyings {
8211 write!(self.out, "{varyings_member_name}.")?;
8212 }
8213 }
8214 _ => (),
8215 }
8216 write!(self.out, "{name}")?;
8217 }
8218 writeln!(self.out, " }};")?;
8219 }
8220 _ => match arg.binding {
8221 Some(crate::Binding::Location {
8222 interpolation: Some(crate::Interpolation::PerVertex),
8223 ..
8224 }) => {
8225 let ty_name = TypeContext {
8226 handle: arg.ty,
8227 gctx: module.to_ctx(),
8228 names: &self.names,
8229 access: crate::StorageAccess::empty(),
8230 first_time: false,
8231 };
8232 writeln!(
8233 self.out,
8234 "{0}const {ty_name} {arg_name} = {{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }};",
8235 back::INDENT,
8236 varyings_member_name,
8237 arg_name,
8238 )?;
8239 }
8240 Some(crate::Binding::Location { .. })
8241 | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
8242 if has_varyings {
8243 writeln!(
8244 self.out,
8245 "{}const auto {} = {}.{};",
8246 back::INDENT,
8247 arg_name,
8248 varyings_member_name,
8249 arg_name
8250 )?;
8251 }
8252 }
8253 _ => {}
8254 },
8255 }
8256 }
8257
8258 let guarded_indices =
8259 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
8260
8261 let context = StatementContext {
8262 expression: ExpressionContext {
8263 function: fun,
8264 origin: FunctionOrigin::EntryPoint(ep_index as _),
8265 info: fun_info,
8266 lang_version: options.lang_version,
8267 policies: options.bounds_check_policies,
8268 guarded_indices,
8269 module,
8270 mod_info,
8271 pipeline_options,
8272 force_loop_bounding: options.force_loop_bounding,
8273 emit_int_div_checks: options.emit_int_div_checks,
8274 ray_query_initialization_tracking: options.ray_query_initialization_tracking,
8275 },
8276 result_struct: if ep.stage == crate::ShaderStage::Task {
8277 None
8278 } else {
8279 Some(&stage_out_name)
8280 },
8281 };
8282
8283 self.put_locals(&context.expression)?;
8286 self.update_expressions_to_bake(fun, fun_info, &context.expression);
8287 self.put_block(back::Level(1), &fun.body, &context)?;
8288 writeln!(self.out, "}}")?;
8289 if ep_index + 1 != module.entry_points.len() {
8290 writeln!(self.out)?;
8291 }
8292 self.named_expressions.clear();
8293
8294 if is_wrapped {
8295 self.write_wrapper_function(NestedFunctionInfo {
8296 options,
8297 ep,
8298 module,
8299 mod_info,
8300 fun_info,
8301 args,
8302 local_invocation_index,
8303 nested_name: &nested_fun_name,
8304 outer_name: &fun_name,
8305 out_mesh_info,
8306 })?;
8307 }
8308 }
8309
8310 Ok(info)
8311 }
8312
8313 pub(super) fn write_barrier(
8314 &mut self,
8315 flags: crate::Barrier,
8316 level: back::Level,
8317 ) -> BackendResult {
8318 if flags.is_empty() {
8321 writeln!(
8322 self.out,
8323 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
8324 )?;
8325 }
8326 if flags.contains(crate::Barrier::STORAGE) {
8327 writeln!(
8328 self.out,
8329 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
8330 )?;
8331 }
8332 if flags.contains(crate::Barrier::WORK_GROUP) {
8333 writeln!(
8334 self.out,
8335 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8336 )?;
8337 if self.needs_object_memory_barriers {
8338 writeln!(
8339 self.out,
8340 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_object_data);",
8341 )?;
8342 }
8343 }
8344 if flags.contains(crate::Barrier::SUB_GROUP) {
8345 writeln!(
8346 self.out,
8347 "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8348 )?;
8349 }
8350 if flags.contains(crate::Barrier::TEXTURE) {
8351 writeln!(
8352 self.out,
8353 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
8354 )?;
8355 }
8356 Ok(())
8357 }
8358}
8359
8360mod workgroup_mem_init {
8363 use crate::EntryPoint;
8364
8365 use super::*;
8366
8367 enum Access {
8368 GlobalVariable(Handle<crate::GlobalVariable>),
8369 StructMember(Handle<crate::Type>, u32),
8370 Array(usize),
8371 }
8372
8373 impl Access {
8374 fn write<W: Write>(
8375 &self,
8376 writer: &mut W,
8377 names: &FastHashMap<NameKey, String>,
8378 ) -> Result<(), core::fmt::Error> {
8379 match *self {
8380 Access::GlobalVariable(handle) => {
8381 write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
8382 }
8383 Access::StructMember(handle, index) => {
8384 write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
8385 }
8386 Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
8387 }
8388 }
8389 }
8390
8391 struct AccessStack {
8392 stack: Vec<Access>,
8393 array_depth: usize,
8394 }
8395
8396 impl AccessStack {
8397 const fn new() -> Self {
8398 Self {
8399 stack: Vec::new(),
8400 array_depth: 0,
8401 }
8402 }
8403
8404 fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
8405 let array_depth = self.array_depth;
8406 self.stack.push(Access::Array(array_depth));
8407 self.array_depth += 1;
8408 let res = cb(self, array_depth);
8409 self.stack.pop();
8410 self.array_depth -= 1;
8411 res
8412 }
8413
8414 fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
8415 self.stack.push(new);
8416 let res = cb(self);
8417 self.stack.pop();
8418 res
8419 }
8420
8421 fn write<W: Write>(
8422 &self,
8423 writer: &mut W,
8424 names: &FastHashMap<NameKey, String>,
8425 ) -> Result<(), core::fmt::Error> {
8426 for next in self.stack.iter() {
8427 next.write(writer, names)?;
8428 }
8429 Ok(())
8430 }
8431 }
8432
8433 impl<W: Write> Writer<W> {
8434 pub(super) fn need_workgroup_variables_initialization(
8435 &mut self,
8436 options: &Options,
8437 ep: &EntryPoint,
8438 module: &crate::Module,
8439 fun_info: &valid::FunctionInfo,
8440 ) -> bool {
8441 let is_task = ep.stage == crate::ShaderStage::Task;
8442 options.zero_initialize_workgroup_memory
8443 && ep.stage.compute_like()
8444 && module.global_variables.iter().any(|(handle, var)| {
8445 let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8446 || (var.space == crate::AddressSpace::TaskPayload && is_task);
8447 !fun_info[handle].is_empty() && is_right_address_space
8448 })
8449 }
8450
8451 pub fn write_workgroup_variables_initialization(
8452 &mut self,
8453 module: &crate::Module,
8454 module_info: &valid::ModuleInfo,
8455 fun_info: &valid::FunctionInfo,
8456 local_invocation_index: Option<&NameKey>,
8457 stage: crate::ShaderStage,
8458 ) -> BackendResult {
8459 let level = back::Level(1);
8460
8461 writeln!(
8462 self.out,
8463 "{}if ({} == 0u) {{",
8464 level,
8465 local_invocation_index
8466 .map(|name_key| self.names[name_key].as_str())
8467 .unwrap_or("__local_invocation_index"),
8468 )?;
8469
8470 let mut access_stack = AccessStack::new();
8471
8472 let is_task = stage == crate::ShaderStage::Task;
8473 let vars = module.global_variables.iter().filter(|&(handle, var)| {
8474 let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8475 || (var.space == crate::AddressSpace::TaskPayload && is_task);
8476 !fun_info[handle].is_empty() && is_right_address_space
8477 });
8478
8479 for (handle, var) in vars {
8480 access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
8481 self.write_workgroup_variable_initialization(
8482 module,
8483 module_info,
8484 var.ty,
8485 access_stack,
8486 level.next(),
8487 )
8488 })?;
8489 }
8490
8491 writeln!(self.out, "{level}}}")?;
8492 self.write_barrier(crate::Barrier::WORK_GROUP, level)
8493 }
8494
8495 fn write_workgroup_variable_initialization(
8496 &mut self,
8497 module: &crate::Module,
8498 module_info: &valid::ModuleInfo,
8499 ty: Handle<crate::Type>,
8500 access_stack: &mut AccessStack,
8501 level: back::Level,
8502 ) -> BackendResult {
8503 if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8504 write!(self.out, "{level}")?;
8505 access_stack.write(&mut self.out, &self.names)?;
8506 writeln!(self.out, " = {{}};")?;
8507 } else {
8508 match module.types[ty].inner {
8509 crate::TypeInner::Atomic { .. } => {
8510 write!(
8511 self.out,
8512 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8513 )?;
8514 access_stack.write(&mut self.out, &self.names)?;
8515 writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8516 }
8517 crate::TypeInner::Array { base, size, .. } => {
8518 let count = match size.resolve(module.to_ctx())? {
8519 proc::IndexableLength::Known(count) => count,
8520 proc::IndexableLength::Dynamic => unreachable!(),
8521 };
8522
8523 access_stack.enter_array(|access_stack, array_depth| {
8524 writeln!(
8525 self.out,
8526 "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8527 )?;
8528 self.write_workgroup_variable_initialization(
8529 module,
8530 module_info,
8531 base,
8532 access_stack,
8533 level.next(),
8534 )?;
8535 writeln!(self.out, "{level}}}")?;
8536 BackendResult::Ok(())
8537 })?;
8538 }
8539 crate::TypeInner::Struct { ref members, .. } => {
8540 for (index, member) in members.iter().enumerate() {
8541 access_stack.enter(
8542 Access::StructMember(ty, index as u32),
8543 |access_stack| {
8544 self.write_workgroup_variable_initialization(
8545 module,
8546 module_info,
8547 member.ty,
8548 access_stack,
8549 level,
8550 )
8551 },
8552 )?;
8553 }
8554 }
8555 _ => unreachable!(),
8556 }
8557 }
8558
8559 Ok(())
8560 }
8561 }
8562}
8563
8564impl crate::AtomicFunction {
8565 const fn to_msl(self) -> &'static str {
8566 match self {
8567 Self::Add => "fetch_add",
8568 Self::Subtract => "fetch_sub",
8569 Self::And => "fetch_and",
8570 Self::InclusiveOr => "fetch_or",
8571 Self::ExclusiveOr => "fetch_xor",
8572 Self::Min => "fetch_min",
8573 Self::Max => "fetch_max",
8574 Self::Exchange { compare: None } => "exchange",
8575 Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8576 }
8577 }
8578
8579 fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8580 Ok(match self {
8581 Self::Min => "min",
8582 Self::Max => "max",
8583 _ => Err(Error::FeatureNotImplemented(
8584 "64-bit atomic operation other than min/max".to_string(),
8585 ))?,
8586 })
8587 }
8588}