1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::{
8 cmp::Ordering,
9 fmt::{Display, Error as FmtError, Formatter, Write},
10 iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
17use crate::{
18 arena::{Handle, HandleSet},
19 back::{self, get_entry_points, Baked},
20 common,
21 proc::{
22 self, concrete_int_scalars,
23 index::{self, BoundsCheck},
24 ExternalTextureNameKey, NameKey, TypeResolution,
25 },
26 valid, FastHashMap, FastHashSet,
27};
28
29#[cfg(test)]
30use core::ptr;
31
/// Result type returned by the fallible writer routines in this module.
type BackendResult = Result<(), Error>;

/// Namespace prefixing MSL library types and functions (e.g. `metal::float4`).
const NAMESPACE: &str = "metal";
/// Field name used for wrapped arrays (presumably the single member of a
/// struct that wraps an MSL array — confirm at its use sites).
const WRAPPED_ARRAY_FIELD: &str = "inner";
/// Reference marker used with atomics (presumably appended to atomic pointer
/// spellings — confirm at its use sites).
const ATOMIC_REFERENCE: &str = "&";

/// Namespace of Metal's ray tracing types (e.g. `instance_acceleration_structure`).
const RT_NAMESPACE: &str = "metal::raytracing";
/// Name the writer uses for the generated ray query type.
const RAY_QUERY_TYPE: &str = "_RayQuery";
/// Field names of the generated ray query type (presumably; defined where
/// `RAY_QUERY_TYPE` is emitted).
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
/// Switch (judging by its name) for newer ray query support; currently `false`.
const RAY_QUERY_MODERN_SUPPORT: bool = false;
const RAY_QUERY_FIELD_READY: &str = "ready";
/// Name of a generated ray query helper (presumably mapping the Metal
/// intersection type to the IR's — confirm).
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

// Names of helper functions and wrapper types emitted into the generated MSL.
// The writer refers to a helper by one of these names when it calls it (see,
// e.g., `IMAGE_LOAD_EXTERNAL_FUNCTION` in `put_image_load`).
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
/// Helper called as `nagaTextureLoadExternal(image, coords)` for external textures.
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
/// Helper called as `nagaTextureDimensionsExternal(image)` for external textures.
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// Wrapper struct template for binding arrays; such bindings are typed as
/// `constant NagaArgumentBufferWrapper<T>*`.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Type name that external textures are written as.
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
84
85fn put_numeric_type(
94 out: &mut impl Write,
95 scalar: crate::Scalar,
96 sizes: &[crate::VectorSize],
97) -> Result<(), FmtError> {
98 match (scalar, sizes) {
99 (scalar, &[]) => {
100 write!(out, "{}", scalar.to_msl_name())
101 }
102 (scalar, &[rows]) => {
103 write!(
104 out,
105 "{}::{}{}",
106 NAMESPACE,
107 scalar.to_msl_name(),
108 common::vector_size_str(rows)
109 )
110 }
111 (scalar, &[rows, columns]) => {
112 write!(
113 out,
114 "{}::{}{}x{}",
115 NAMESPACE,
116 scalar.to_msl_name(),
117 common::vector_size_str(columns),
118 common::vector_size_str(rows)
119 )
120 }
121 (_, _) => Ok(()), }
123}
124
125const fn scalar_is_int(scalar: crate::Scalar) -> bool {
126 use crate::ScalarKind::*;
127 match scalar.kind {
128 Sint | Uint | AbstractInt | Bool => true,
129 Float | AbstractFloat => false,
130 }
131}
132
/// Prefix of the names given to cached, clamped level-of-detail values
/// (rendered by `ClampedLod`).
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix of the names given to reinterpreted values (rendered by `Reinterpreted`).
const REINTERPRET_PREFIX: &str = "reinterpreted_";
138
139struct ClampedLod(Handle<crate::Expression>);
145
146impl Display for ClampedLod {
147 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
148 self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
149 }
150}
151
152struct ArraySizeMember(Handle<crate::GlobalVariable>);
167
168impl Display for ArraySizeMember {
169 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
170 self.0.write_prefixed(f, "size")
171 }
172}
173
174#[derive(Clone, Copy)]
179struct Reinterpreted<'a> {
180 target_type: &'a str,
181 orig: Handle<crate::Expression>,
182}
183
184impl<'a> Reinterpreted<'a> {
185 const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
186 Self { target_type, orig }
187 }
188}
189
190impl Display for Reinterpreted<'_> {
191 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
192 f.write_str(REINTERPRET_PREFIX)?;
193 f.write_str(self.target_type)?;
194 self.orig.write_prefixed(f, "_e")
195 }
196}
197
198struct TypeContext<'a> {
199 handle: Handle<crate::Type>,
200 gctx: proc::GlobalCtx<'a>,
201 names: &'a FastHashMap<NameKey, String>,
202 access: crate::StorageAccess,
203 first_time: bool,
204}
205
206impl TypeContext<'_> {
207 fn scalar(&self) -> Option<crate::Scalar> {
208 let ty = &self.gctx.types[self.handle];
209 ty.inner.scalar()
210 }
211
212 fn vector_size(&self) -> Option<crate::VectorSize> {
213 let ty = &self.gctx.types[self.handle];
214 match ty.inner {
215 crate::TypeInner::Vector { size, .. } => Some(size),
216 _ => None,
217 }
218 }
219}
220
/// Writes the MSL spelling of the type `self.handle`.
///
/// Types whose `needs_alias()` is true are written as their generated name
/// from `self.names`, unless `first_time` is set. Otherwise the full spelling
/// is produced by the match below.
///
/// Structs always need an alias, so reaching the `Struct` arm (a struct with
/// `first_time` set) is treated as unreachable.
impl Display for TypeContext<'_> {
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            // Written as `metal::simdgroup_<scalar><columns>x<rows>`.
            crate::TypeInner::CooperativeMatrix {
                columns,
                rows,
                scalar,
                role: _,
            } => {
                write!(
                    out,
                    "{NAMESPACE}::simdgroup_{}{}x{}",
                    scalar.to_msl_name(),
                    columns as u32,
                    rows as u32,
                )
            }
            // `<space> <pointee>&`. Spaces without an MSL name (`Handle`)
            // produce no output at all.
            crate::TypeInner::Pointer { base, space } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                write!(out, "{space_name} {sub}&")
            }
            // Same shape as `Pointer`, with the pointee spelled from `size`/`scalar`.
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            // Only the element type is written here. The array length is not part
            // of this output (presumably emitted by the caller — confirm).
            crate::TypeInner::Array { base, .. } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                write!(out, "{sub}")
            }
            crate::TypeInner::Struct { .. } => unreachable!(),
            // `metal::<texture|depth><dim>[_ms][_array]<scalar, metal::access::<access>>`,
            // or the external-texture wrapper name.
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                let (texture_str, msaa_str, scalar, access) = match class {
                    // Multisampled textures are accessed with `read`; others with `sample`.
                    crate::ImageClass::Sampled { kind, multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    // The access qualifier comes from `self.access`. No access at all
                    // indicates an invalid module.
                    crate::ImageClass::Storage { format, .. } => {
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    crate::ImageClass::External => {
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            // Comparison and non-comparison samplers share one MSL type.
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RAY_QUERY_TYPE}")
            }
            // Binding arrays become a constant pointer to the argument-buffer wrapper.
            crate::TypeInner::BindingArray { base, .. } => {
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };

                write!(
                    out,
                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                )
            }
        }
    }
}
389
390struct TypedGlobalVariable<'a> {
391 module: &'a crate::Module,
392 names: &'a FastHashMap<NameKey, String>,
393 handle: Handle<crate::GlobalVariable>,
394 usage: valid::GlobalUse,
395 reference: bool,
396}
397
398impl TypedGlobalVariable<'_> {
399 fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
400 let var = &self.module.global_variables[self.handle];
401 let name = &self.names[&NameKey::GlobalVariable(self.handle)];
402
403 let storage_access = match var.space {
404 crate::AddressSpace::Storage { access } => access,
405 _ => match self.module.types[var.ty].inner {
406 crate::TypeInner::Image {
407 class: crate::ImageClass::Storage { access, .. },
408 ..
409 } => access,
410 crate::TypeInner::BindingArray { base, .. } => {
411 match self.module.types[base].inner {
412 crate::TypeInner::Image {
413 class: crate::ImageClass::Storage { access, .. },
414 ..
415 } => access,
416 _ => crate::StorageAccess::default(),
417 }
418 }
419 _ => crate::StorageAccess::default(),
420 },
421 };
422 let ty_name = TypeContext {
423 handle: var.ty,
424 gctx: self.module.to_ctx(),
425 names: self.names,
426 access: storage_access,
427 first_time: false,
428 };
429
430 let (coherent, space, access, reference) = match var.space.to_msl_name() {
431 Some(space) if self.reference => {
432 let coherent = if var
433 .memory_decorations
434 .contains(crate::MemoryDecorations::COHERENT)
435 {
436 "coherent "
437 } else {
438 ""
439 };
440 let access = if var.space.needs_access_qualifier()
441 && !self.usage.intersects(valid::GlobalUse::WRITE)
442 {
443 "const"
444 } else {
445 ""
446 };
447 (coherent, space, access, "&")
448 }
449 _ => ("", "", "", ""),
450 };
451
452 Ok(write!(
453 out,
454 "{}{}{}{}{}{}{} {}",
455 coherent,
456 space,
457 if space.is_empty() { "" } else { " " },
458 ty_name,
459 if access.is_empty() { "" } else { " " },
460 access,
461 reference,
462 name,
463 )?)
464 }
465}
466
/// Key identifying a helper function the writer generates on demand.
///
/// Keys are stored in `Writer::wrapped_functions`, presumably so that each
/// helper is emitted at most once.
#[derive(Eq, PartialEq, Hash)]
enum WrappedFunction {
    /// Helper for a unary operator on the given `(vector size, scalar)` type.
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a binary operator with the given operand types.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a math function taking the given argument type.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper converting `src_scalar` to `dst_scalar`, optionally as vectors
    /// of `vector_size` (presumably — confirm at the emitting site).
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Image load helper for the given image class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Image sample helper; `clamp_to_edge` presumably selects the
    /// clamp-to-edge variant — confirm.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Image size query helper for the given image class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
    /// Cooperative-matrix load helper, keyed by address space name, shape and scalar.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// Cooperative-matrix multiply-add helper; `intermediate` is presumably the
    /// shared inner dimension — confirm.
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
}
511
/// MSL backend writer: emits Metal Shading Language source into `out`.
pub struct Writer<W> {
    /// Destination of the generated MSL.
    out: W,
    /// Names assigned to module items, keyed by [`NameKey`].
    names: FastHashMap<NameKey, String>,
    named_expressions: crate::NamedExpressions,
    need_bake_expressions: back::NeedBakeExpressions,
    /// Source of fresh identifiers (e.g. the `oob` locals and `loop_bound` counters).
    namer: proc::Namer,
    /// Helper functions that have been requested (see [`WrappedFunction`]).
    wrapped_functions: FastHashSet<WrappedFunction>,
    // Test-only bookkeeping; presumably records stack addresses of recursive
    // `put_expression`/`put_block` calls — confirm in the test code.
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// Struct members that get padding, presumably keyed by `(struct type,
    /// member index)` — confirm at use sites.
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
528
529impl crate::Scalar {
530 pub(super) fn to_msl_name(self) -> &'static str {
531 use crate::ScalarKind as Sk;
532 match self {
533 Self {
534 kind: Sk::Float,
535 width: 4,
536 } => "float",
537 Self {
538 kind: Sk::Float,
539 width: 2,
540 } => "half",
541 Self {
542 kind: Sk::Sint,
543 width: 4,
544 } => "int",
545 Self {
546 kind: Sk::Uint,
547 width: 4,
548 } => "uint",
549 Self {
550 kind: Sk::Sint,
551 width: 8,
552 } => "long",
553 Self {
554 kind: Sk::Uint,
555 width: 8,
556 } => "ulong",
557 Self {
558 kind: Sk::Bool,
559 width: _,
560 } => "bool",
561 Self {
562 kind: Sk::AbstractInt | Sk::AbstractFloat,
563 width: _,
564 } => unreachable!("Found Abstract scalar kind"),
565 _ => unreachable!("Unsupported scalar kind: {:?}", self),
566 }
567 }
568}
569
/// Returns the text to place after a list item: `","` when another item
/// follows (`need_separator`), otherwise the empty string.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
577
578fn should_pack_struct_member(
579 members: &[crate::StructMember],
580 span: u32,
581 index: usize,
582 module: &crate::Module,
583) -> Option<crate::Scalar> {
584 let member = &members[index];
585
586 let ty_inner = &module.types[member.ty].inner;
587 let last_offset = member.offset + ty_inner.size(module.to_ctx());
588 let next_offset = match members.get(index + 1) {
589 Some(next) => next.offset,
590 None => span,
591 };
592 let is_tight = next_offset == last_offset;
593
594 match *ty_inner {
595 crate::TypeInner::Vector {
596 size: crate::VectorSize::Tri,
597 scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
598 } if is_tight => Some(scalar),
599 _ => None,
600 }
601}
602
603fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
604 match arena[ty].inner {
605 crate::TypeInner::Struct { ref members, .. } => {
606 if let Some(member) = members.last() {
607 if let crate::TypeInner::Array {
608 size: crate::ArraySize::Dynamic,
609 ..
610 } = arena[member.ty].inner
611 {
612 return true;
613 }
614 }
615 false
616 }
617 crate::TypeInner::Array {
618 size: crate::ArraySize::Dynamic,
619 ..
620 } => true,
621 _ => false,
622 }
623}
624
625impl crate::AddressSpace {
626 const fn needs_pass_through(&self) -> bool {
630 match *self {
631 Self::Uniform
632 | Self::Storage { .. }
633 | Self::Private
634 | Self::WorkGroup
635 | Self::Immediate
636 | Self::Handle
637 | Self::TaskPayload => true,
638 Self::Function => false,
639 Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
640 }
641 }
642
643 const fn needs_access_qualifier(&self) -> bool {
645 match *self {
646 Self::Storage { .. } => true,
651 Self::TaskPayload | Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
652 Self::Private | Self::WorkGroup => false,
654 Self::Uniform | Self::Immediate => false,
656 Self::Handle | Self::Function => false,
658 }
659 }
660
661 const fn to_msl_name(self) -> Option<&'static str> {
662 match self {
663 Self::Handle => None,
664 Self::Uniform | Self::Immediate => Some("constant"),
665 Self::Storage { .. } => Some("device"),
666 Self::Private | Self::Function | Self::RayPayload => Some("thread"),
670 Self::WorkGroup => Some("threadgroup"),
671 Self::TaskPayload => Some("object_data"),
672 Self::IncomingRayPayload => Some("ray_data"),
673 }
674 }
675}
676
677impl crate::Type {
678 const fn needs_alias(&self) -> bool {
680 use crate::TypeInner as Ti;
681
682 match self.inner {
683 Ti::Scalar(_)
685 | Ti::Vector { .. }
686 | Ti::Matrix { .. }
687 | Ti::CooperativeMatrix { .. }
688 | Ti::Atomic(_)
689 | Ti::Pointer { .. }
690 | Ti::ValuePointer { .. } => self.name.is_some(),
691 Ti::Struct { .. } | Ti::Array { .. } => true,
693 Ti::Image { .. }
695 | Ti::Sampler { .. }
696 | Ti::AccelerationStructure { .. }
697 | Ti::RayQuery { .. }
698 | Ti::BindingArray { .. } => false,
699 }
700 }
701}
702
/// Identifies the function whose body is being written: a regular function by
/// handle, or an entry point by index. Used to build per-function
/// [`NameKey`]s (see `NameKeyExt`).
#[derive(Clone, Copy)]
enum FunctionOrigin {
    Handle(Handle<crate::Function>),
    EntryPoint(proc::EntryPointIndex),
}
708
709trait NameKeyExt {
710 fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
711 match origin {
712 FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
713 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
714 }
715 }
716
717 fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
722 match origin {
723 FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
724 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
725 }
726 }
727}
728
729impl NameKeyExt for NameKey {}
730
/// How a level-of-detail argument is written (see `put_level_of_detail`).
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Write this expression as-is.
    Direct(Handle<crate::Expression>),
    /// Write the name of the cached, clamped level of the given image load
    /// expression (rendered by `ClampedLod`).
    Restricted(Handle<crate::Expression>),
}
745
/// The components addressing a texel in an image access.
struct TexelAddress {
    /// Texel coordinate, a scalar or vector expression.
    coordinate: Handle<crate::Expression>,
    /// Array layer, checked against `get_array_size()` when present.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index, checked against `get_num_samples()` when present.
    sample: Option<Handle<crate::Expression>>,
    /// Mip level, when one is supplied.
    level: Option<LevelOfDetail>,
}
761
/// Per-function state shared by the expression-writing routines.
struct ExpressionContext<'a> {
    /// Function whose body is being written.
    function: &'a crate::Function,
    /// Whether `function` is a regular function or an entry point; used to
    /// build name keys.
    origin: FunctionOrigin,
    /// Validation results for `function`, e.g. resolved expression types.
    info: &'a valid::FunctionInfo,
    module: &'a crate::Module,
    mod_info: &'a valid::ModuleInfo,
    pipeline_options: &'a PipelineOptions,
    /// Target language version, presumably `(major, minor)` — confirm.
    lang_version: (u8, u8),
    /// Bounds-check policies in effect for this translation.
    policies: index::BoundsCheckPolicies,

    /// Index expressions tracked as guarded (presumably those already bounds
    /// checked elsewhere — confirm at use sites).
    guarded_indices: HandleSet<crate::Expression>,
    /// When set, loops get an artificial iteration bound
    /// (see `gen_force_bounded_loop_statements`).
    force_loop_bounding: bool,
}
779
780impl<'a> ExpressionContext<'a> {
781 fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
782 self.info[handle].ty.inner_with(&self.module.types)
783 }
784
785 fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
792 let image_ty = self.resolve_type(image);
793 if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
794 class.is_mipmapped() && dim != crate::ImageDimension::D1
795 } else {
796 false
797 }
798 }
799
800 fn choose_bounds_check_policy(
801 &self,
802 pointer: Handle<crate::Expression>,
803 ) -> index::BoundsCheckPolicy {
804 self.policies
805 .choose_policy(pointer, &self.module.types, self.info)
806 }
807
808 fn access_needs_check(
810 &self,
811 base: Handle<crate::Expression>,
812 index: index::GuardedIndex,
813 ) -> Option<index::IndexableLength> {
814 index::access_needs_check(
815 base,
816 index,
817 self.module,
818 &self.function.expressions,
819 self.info,
820 )
821 }
822
823 fn bounds_check_iter(
825 &self,
826 chain: Handle<crate::Expression>,
827 ) -> impl Iterator<Item = BoundsCheck> + '_ {
828 index::bounds_check_iter(chain, self.module, self.function, self.info)
829 }
830
831 fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
833 index::oob_local_types(self.module, self.function, self.info, self.policies)
834 }
835
836 fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
837 match self.function.expressions[expr_handle] {
838 crate::Expression::AccessIndex { base, index } => {
839 let ty = match *self.resolve_type(base) {
840 crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
841 ref ty => ty,
842 };
843 match *ty {
844 crate::TypeInner::Struct {
845 ref members, span, ..
846 } => should_pack_struct_member(members, span, index as usize, self.module),
847 _ => None,
848 }
849 }
850 _ => None,
851 }
852 }
853}
854
/// State for writing statements: the expression context plus an optional
/// struct name.
struct StatementContext<'a> {
    expression: ExpressionContext<'a>,
    /// Presumably the name of the struct an entry point returns its outputs
    /// in, `None` otherwise — confirm at use sites.
    result_struct: Option<&'a str>,
}
859
860impl<W: Write> Writer<W> {
861 pub fn new(out: W) -> Self {
863 Writer {
864 out,
865 names: FastHashMap::default(),
866 named_expressions: Default::default(),
867 need_bake_expressions: Default::default(),
868 namer: proc::Namer::default(),
869 wrapped_functions: FastHashSet::default(),
870 #[cfg(test)]
871 put_expression_stack_pointers: Default::default(),
872 #[cfg(test)]
873 put_block_stack_pointers: Default::default(),
874 struct_member_pads: FastHashSet::default(),
875 }
876 }
877
878 pub fn finish(self) -> W {
881 self.out
882 }
883
884 fn gen_force_bounded_loop_statements(
991 &mut self,
992 level: back::Level,
993 context: &StatementContext,
994 ) -> Option<(String, String)> {
995 if !context.expression.force_loop_bounding {
996 return None;
997 }
998
999 let loop_bound_name = self.namer.call("loop_bound");
1000 let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
1003 let level = level.next();
1004 let break_and_inc = format!(
1005 "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
1006{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
1007 );
1008
1009 Some((decl, break_and_inc))
1010 }
1011
1012 fn put_call_parameters(
1013 &mut self,
1014 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1015 context: &ExpressionContext,
1016 ) -> BackendResult {
1017 self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1018 writer.put_expression(expr, context, true)
1019 })
1020 }
1021
1022 fn put_call_parameters_impl<C, E>(
1023 &mut self,
1024 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1025 ctx: &C,
1026 put_expression: E,
1027 ) -> BackendResult
1028 where
1029 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1030 {
1031 write!(self.out, "(")?;
1032 for (i, handle) in parameters.enumerate() {
1033 if i != 0 {
1034 write!(self.out, ", ")?;
1035 }
1036 put_expression(self, ctx, handle)?;
1037 }
1038 write!(self.out, ")")?;
1039 Ok(())
1040 }
1041
1042 fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1048 let oob_local_types = context.oob_local_types();
1049 for &ty in oob_local_types.iter() {
1050 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1051 self.names.insert(name_key, self.namer.call("oob"));
1052 }
1053
1054 for (name_key, ty, init) in context
1055 .function
1056 .local_variables
1057 .iter()
1058 .map(|(local_handle, local)| {
1059 let name_key = NameKey::local(context.origin, local_handle);
1060 (name_key, local.ty, local.init)
1061 })
1062 .chain(oob_local_types.iter().map(|&ty| {
1063 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1064 (name_key, ty, None)
1065 }))
1066 {
1067 let ty_name = TypeContext {
1068 handle: ty,
1069 gctx: context.module.to_ctx(),
1070 names: &self.names,
1071 access: crate::StorageAccess::empty(),
1072 first_time: false,
1073 };
1074 write!(
1075 self.out,
1076 "{}{} {}",
1077 back::INDENT,
1078 ty_name,
1079 self.names[&name_key]
1080 )?;
1081 match init {
1082 Some(value) => {
1083 write!(self.out, " = ")?;
1084 self.put_expression(value, context, true)?;
1085 }
1086 None => {
1087 write!(self.out, " = {{}}")?;
1088 }
1089 };
1090 writeln!(self.out, ";")?;
1091 }
1092 Ok(())
1093 }
1094
1095 fn put_level_of_detail(
1096 &mut self,
1097 level: LevelOfDetail,
1098 context: &ExpressionContext,
1099 ) -> BackendResult {
1100 match level {
1101 LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1102 LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1103 }
1104 Ok(())
1105 }
1106
1107 fn put_image_query(
1108 &mut self,
1109 image: Handle<crate::Expression>,
1110 query: &str,
1111 level: Option<LevelOfDetail>,
1112 context: &ExpressionContext,
1113 ) -> BackendResult {
1114 self.put_expression(image, context, false)?;
1115 write!(self.out, ".get_{query}(")?;
1116 if let Some(level) = level {
1117 self.put_level_of_detail(level, context)?;
1118 }
1119 write!(self.out, ")")?;
1120 Ok(())
1121 }
1122
1123 fn put_image_size_query(
1124 &mut self,
1125 image: Handle<crate::Expression>,
1126 level: Option<LevelOfDetail>,
1127 kind: crate::ScalarKind,
1128 context: &ExpressionContext,
1129 ) -> BackendResult {
1130 if let crate::TypeInner::Image {
1131 class: crate::ImageClass::External,
1132 ..
1133 } = *context.resolve_type(image)
1134 {
1135 write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1136 self.put_expression(image, context, true)?;
1137 write!(self.out, ")")?;
1138 return Ok(());
1139 }
1140
1141 let dim = match *context.resolve_type(image) {
1144 crate::TypeInner::Image { dim, .. } => dim,
1145 ref other => unreachable!("Unexpected type {:?}", other),
1146 };
1147 let scalar = crate::Scalar { kind, width: 4 };
1148 let coordinate_type = scalar.to_msl_name();
1149 match dim {
1150 crate::ImageDimension::D1 => {
1151 if kind == crate::ScalarKind::Uint {
1155 self.put_image_query(image, "width", None, context)?;
1157 } else {
1158 write!(self.out, "int(")?;
1160 self.put_image_query(image, "width", None, context)?;
1161 write!(self.out, ")")?;
1162 }
1163 }
1164 crate::ImageDimension::D2 => {
1165 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1166 self.put_image_query(image, "width", level, context)?;
1167 write!(self.out, ", ")?;
1168 self.put_image_query(image, "height", level, context)?;
1169 write!(self.out, ")")?;
1170 }
1171 crate::ImageDimension::D3 => {
1172 write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1173 self.put_image_query(image, "width", level, context)?;
1174 write!(self.out, ", ")?;
1175 self.put_image_query(image, "height", level, context)?;
1176 write!(self.out, ", ")?;
1177 self.put_image_query(image, "depth", level, context)?;
1178 write!(self.out, ")")?;
1179 }
1180 crate::ImageDimension::Cube => {
1181 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1182 self.put_image_query(image, "width", level, context)?;
1183 write!(self.out, ")")?;
1184 }
1185 }
1186 Ok(())
1187 }
1188
1189 fn put_cast_to_uint_scalar_or_vector(
1190 &mut self,
1191 expr: Handle<crate::Expression>,
1192 context: &ExpressionContext,
1193 ) -> BackendResult {
1194 match *context.resolve_type(expr) {
1196 crate::TypeInner::Scalar(_) => {
1197 put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1198 }
1199 crate::TypeInner::Vector { size, .. } => {
1200 put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1201 }
1202 _ => {
1203 return Err(Error::GenericValidation(
1204 "Invalid type for image coordinate".into(),
1205 ))
1206 }
1207 };
1208
1209 write!(self.out, "(")?;
1210 self.put_expression(expr, context, true)?;
1211 write!(self.out, ")")?;
1212 Ok(())
1213 }
1214
1215 fn put_image_sample_level(
1216 &mut self,
1217 image: Handle<crate::Expression>,
1218 level: crate::SampleLevel,
1219 context: &ExpressionContext,
1220 ) -> BackendResult {
1221 let has_levels = context.image_needs_lod(image);
1222 match level {
1223 crate::SampleLevel::Auto => {}
1224 crate::SampleLevel::Zero => {
1225 }
1227 _ if !has_levels => {
1228 log::warn!("1D image can't be sampled with level {level:?}");
1229 }
1230 crate::SampleLevel::Exact(h) => {
1231 write!(self.out, ", {NAMESPACE}::level(")?;
1232 self.put_expression(h, context, true)?;
1233 write!(self.out, ")")?;
1234 }
1235 crate::SampleLevel::Bias(h) => {
1236 write!(self.out, ", {NAMESPACE}::bias(")?;
1237 self.put_expression(h, context, true)?;
1238 write!(self.out, ")")?;
1239 }
1240 crate::SampleLevel::Gradient { x, y } => {
1241 write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1242 self.put_expression(x, context, true)?;
1243 write!(self.out, ", ")?;
1244 self.put_expression(y, context, true)?;
1245 write!(self.out, ")")?;
1246 }
1247 }
1248 Ok(())
1249 }
1250
1251 fn put_image_coordinate_limits(
1252 &mut self,
1253 image: Handle<crate::Expression>,
1254 level: Option<LevelOfDetail>,
1255 context: &ExpressionContext,
1256 ) -> BackendResult {
1257 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1258 write!(self.out, " - 1")?;
1259 Ok(())
1260 }
1261
1262 fn put_restricted_scalar_image_index(
1280 &mut self,
1281 image: Handle<crate::Expression>,
1282 index: Handle<crate::Expression>,
1283 limit_method: &str,
1284 context: &ExpressionContext,
1285 ) -> BackendResult {
1286 write!(self.out, "{NAMESPACE}::min(uint(")?;
1287 self.put_expression(index, context, true)?;
1288 write!(self.out, "), ")?;
1289 self.put_expression(image, context, false)?;
1290 write!(self.out, ".{limit_method}() - 1)")?;
1291 Ok(())
1292 }
1293
1294 fn put_restricted_texel_address(
1295 &mut self,
1296 image: Handle<crate::Expression>,
1297 address: &TexelAddress,
1298 context: &ExpressionContext,
1299 ) -> BackendResult {
1300 write!(self.out, "{NAMESPACE}::min(")?;
1302 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1303 write!(self.out, ", ")?;
1304 self.put_image_coordinate_limits(image, address.level, context)?;
1305 write!(self.out, ")")?;
1306
1307 if let Some(array_index) = address.array_index {
1309 write!(self.out, ", ")?;
1310 self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1311 }
1312
1313 if let Some(sample) = address.sample {
1315 write!(self.out, ", ")?;
1316 self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1317 }
1318
1319 if let Some(level) = address.level {
1322 write!(self.out, ", ")?;
1323 self.put_level_of_detail(level, context)?;
1324 }
1325
1326 Ok(())
1327 }
1328
1329 fn put_image_access_bounds_check(
1331 &mut self,
1332 image: Handle<crate::Expression>,
1333 address: &TexelAddress,
1334 context: &ExpressionContext,
1335 ) -> BackendResult {
1336 let mut conjunction = "";
1337
1338 let level = if let Some(level) = address.level {
1341 write!(self.out, "uint(")?;
1342 self.put_level_of_detail(level, context)?;
1343 write!(self.out, ") < ")?;
1344 self.put_expression(image, context, true)?;
1345 write!(self.out, ".get_num_mip_levels()")?;
1346 conjunction = " && ";
1347 Some(level)
1348 } else {
1349 None
1350 };
1351
1352 if let Some(sample) = address.sample {
1354 write!(self.out, "uint(")?;
1355 self.put_expression(sample, context, true)?;
1356 write!(self.out, ") < ")?;
1357 self.put_expression(image, context, true)?;
1358 write!(self.out, ".get_num_samples()")?;
1359 conjunction = " && ";
1360 }
1361
1362 if let Some(array_index) = address.array_index {
1364 write!(self.out, "{conjunction}uint(")?;
1365 self.put_expression(array_index, context, true)?;
1366 write!(self.out, ") < ")?;
1367 self.put_expression(image, context, true)?;
1368 write!(self.out, ".get_array_size()")?;
1369 conjunction = " && ";
1370 }
1371
1372 let coord_is_vector = match *context.resolve_type(address.coordinate) {
1374 crate::TypeInner::Vector { .. } => true,
1375 _ => false,
1376 };
1377 write!(self.out, "{conjunction}")?;
1378 if coord_is_vector {
1379 write!(self.out, "{NAMESPACE}::all(")?;
1380 }
1381 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1382 write!(self.out, " < ")?;
1383 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1384 if coord_is_vector {
1385 write!(self.out, ")")?;
1386 }
1387
1388 Ok(())
1389 }
1390
1391 fn put_image_load(
1392 &mut self,
1393 load: Handle<crate::Expression>,
1394 image: Handle<crate::Expression>,
1395 mut address: TexelAddress,
1396 context: &ExpressionContext,
1397 ) -> BackendResult {
1398 if let crate::TypeInner::Image {
1399 class: crate::ImageClass::External,
1400 ..
1401 } = *context.resolve_type(image)
1402 {
1403 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1404 self.put_expression(image, context, true)?;
1405 write!(self.out, ", ")?;
1406 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1407 write!(self.out, ")")?;
1408 return Ok(());
1409 }
1410
1411 match context.policies.image_load {
1412 proc::BoundsCheckPolicy::Restrict => {
1413 if address.level.is_some() {
1416 address.level = if context.image_needs_lod(image) {
1417 Some(LevelOfDetail::Restricted(load))
1418 } else {
1419 None
1420 }
1421 }
1422
1423 self.put_expression(image, context, false)?;
1424 write!(self.out, ".read(")?;
1425 self.put_restricted_texel_address(image, &address, context)?;
1426 write!(self.out, ")")?;
1427 }
1428 proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1429 write!(self.out, "(")?;
1430 self.put_image_access_bounds_check(image, &address, context)?;
1431 write!(self.out, " ? ")?;
1432 self.put_unchecked_image_load(image, &address, context)?;
1433 write!(self.out, ": DefaultConstructible())")?;
1434 }
1435 proc::BoundsCheckPolicy::Unchecked => {
1436 self.put_unchecked_image_load(image, &address, context)?;
1437 }
1438 }
1439
1440 Ok(())
1441 }
1442
1443 fn put_unchecked_image_load(
1444 &mut self,
1445 image: Handle<crate::Expression>,
1446 address: &TexelAddress,
1447 context: &ExpressionContext,
1448 ) -> BackendResult {
1449 self.put_expression(image, context, false)?;
1450 write!(self.out, ".read(")?;
1451 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1453 if let Some(expr) = address.array_index {
1454 write!(self.out, ", ")?;
1455 self.put_expression(expr, context, true)?;
1456 }
1457 if let Some(sample) = address.sample {
1458 write!(self.out, ", ")?;
1459 self.put_expression(sample, context, true)?;
1460 }
1461 if let Some(level) = address.level {
1462 if context.image_needs_lod(image) {
1463 write!(self.out, ", ")?;
1464 self.put_level_of_detail(level, context)?;
1465 }
1466 }
1467 write!(self.out, ")")?;
1468
1469 Ok(())
1470 }
1471
1472 fn put_image_atomic(
1473 &mut self,
1474 level: back::Level,
1475 image: Handle<crate::Expression>,
1476 address: &TexelAddress,
1477 fun: crate::AtomicFunction,
1478 value: Handle<crate::Expression>,
1479 context: &StatementContext,
1480 ) -> BackendResult {
1481 write!(self.out, "{level}")?;
1482 self.put_expression(image, &context.expression, false)?;
1483 let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1484 fun.to_msl_64_bit()?
1485 } else {
1486 fun.to_msl()
1487 };
1488 write!(self.out, ".atomic_{op}(")?;
1489 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1491 write!(self.out, ", ")?;
1492 self.put_expression(value, &context.expression, true)?;
1493 writeln!(self.out, ");")?;
1494
1495 let value_ty = context.expression.resolve_type(value);
1501 let zero_value = match (value_ty.scalar_kind(), value_ty.scalar_width()) {
1502 (Some(crate::ScalarKind::Sint), _) => "int4(0)",
1503 (_, Some(8)) => "ulong4(0uL)",
1504 _ => "uint4(0u)",
1505 };
1506 let coord_ty = context.expression.resolve_type(address.coordinate);
1507 let x = if matches!(coord_ty, crate::TypeInner::Scalar(_)) {
1508 ""
1509 } else {
1510 ".x"
1511 };
1512 write!(self.out, "{level}if (")?;
1513 self.put_expression(address.coordinate, &context.expression, true)?;
1514 write!(self.out, "{x} == -99999) {{ ")?;
1515 self.put_expression(image, &context.expression, false)?;
1516 write!(self.out, ".write({zero_value}, ")?;
1517 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1518 if let Some(array_index) = address.array_index {
1519 write!(self.out, ", ")?;
1520 self.put_expression(array_index, &context.expression, true)?;
1521 }
1522 writeln!(self.out, "); }}")?;
1523
1524 Ok(())
1525 }
1526
1527 fn put_image_store(
1528 &mut self,
1529 level: back::Level,
1530 image: Handle<crate::Expression>,
1531 address: &TexelAddress,
1532 value: Handle<crate::Expression>,
1533 context: &StatementContext,
1534 ) -> BackendResult {
1535 write!(self.out, "{level}")?;
1536 self.put_expression(image, &context.expression, false)?;
1537 write!(self.out, ".write(")?;
1538 self.put_expression(value, &context.expression, true)?;
1539 write!(self.out, ", ")?;
1540 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1542 if let Some(expr) = address.array_index {
1543 write!(self.out, ", ")?;
1544 self.put_expression(expr, &context.expression, true)?;
1545 }
1546 writeln!(self.out, ");")?;
1547
1548 Ok(())
1549 }
1550
1551 fn put_dynamic_array_max_index(
1562 &mut self,
1563 handle: Handle<crate::GlobalVariable>,
1564 context: &ExpressionContext,
1565 ) -> BackendResult {
1566 let global = &context.module.global_variables[handle];
1567 let (offset, array_ty) = match context.module.types[global.ty].inner {
1568 crate::TypeInner::Struct { ref members, .. } => match members.last() {
1569 Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1570 None => return Err(Error::GenericValidation("Struct has no members".into())),
1571 },
1572 crate::TypeInner::Array {
1573 size: crate::ArraySize::Dynamic,
1574 ..
1575 } => (0, global.ty),
1576 ref ty => {
1577 return Err(Error::GenericValidation(format!(
1578 "Expected type with dynamic array, got {ty:?}"
1579 )))
1580 }
1581 };
1582
1583 let (size, stride) = match context.module.types[array_ty].inner {
1584 crate::TypeInner::Array { base, stride, .. } => (
1585 context.module.types[base]
1586 .inner
1587 .size(context.module.to_ctx()),
1588 stride,
1589 ),
1590 ref ty => {
1591 return Err(Error::GenericValidation(format!(
1592 "Expected array type, got {ty:?}"
1593 )))
1594 }
1595 };
1596
1597 write!(
1610 self.out,
1611 "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1612 member = ArraySizeMember(handle),
1613 offset = offset,
1614 size = size,
1615 stride = stride,
1616 )?;
1617 Ok(())
1618 }
1619
1620 fn put_dot_product<T: Copy>(
1625 &mut self,
1626 arg: T,
1627 arg1: T,
1628 size: usize,
1629 extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1630 ) -> BackendResult {
1631 write!(self.out, "(")?;
1634
1635 for index in 0..size {
1637 write!(self.out, " + ")?;
1640 extractor(self, arg, index)?;
1641 write!(self.out, " * ")?;
1642 extractor(self, arg1, index)?;
1643 }
1644
1645 write!(self.out, ")")?;
1646 Ok(())
1647 }
1648
1649 fn put_pack4x8(
1651 &mut self,
1652 arg: Handle<crate::Expression>,
1653 context: &ExpressionContext<'_>,
1654 was_signed: bool,
1655 clamp_bounds: Option<(&str, &str)>,
1656 ) -> Result<(), Error> {
1657 let write_arg = |this: &mut Self| -> BackendResult {
1658 if let Some((min, max)) = clamp_bounds {
1659 write!(this.out, "{NAMESPACE}::clamp(")?;
1661 this.put_expression(arg, context, true)?;
1662 write!(this.out, ", {min}, {max})")?;
1663 } else {
1664 this.put_expression(arg, context, true)?;
1665 }
1666 Ok(())
1667 };
1668
1669 if context.lang_version >= (2, 1) {
1670 let packed_type = if was_signed {
1671 "packed_char4"
1672 } else {
1673 "packed_uchar4"
1674 };
1675 write!(self.out, "as_type<uint>({packed_type}(")?;
1677 write_arg(self)?;
1678 write!(self.out, "))")?;
1679 } else {
1680 if was_signed {
1682 write!(self.out, "uint(")?;
1683 }
1684 write!(self.out, "(")?;
1685 write_arg(self)?;
1686 write!(self.out, "[0] & 0xFF) | ((")?;
1687 write_arg(self)?;
1688 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1689 write_arg(self)?;
1690 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1691 write_arg(self)?;
1692 write!(self.out, "[3] & 0xFF) << 24)")?;
1693 if was_signed {
1694 write!(self.out, ")")?;
1695 }
1696 }
1697
1698 Ok(())
1699 }
1700
1701 fn put_isign(
1704 &mut self,
1705 arg: Handle<crate::Expression>,
1706 context: &ExpressionContext,
1707 ) -> BackendResult {
1708 write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1709 let scalar = context
1710 .resolve_type(arg)
1711 .scalar()
1712 .expect("put_isign should only be called for args which have an integer scalar type")
1713 .to_msl_name();
1714 match context.resolve_type(arg) {
1715 &crate::TypeInner::Vector { size, .. } => {
1716 let size = common::vector_size_str(size);
1717 write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1718 }
1719 _ => {
1720 write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1721 }
1722 }
1723 write!(self.out, ", (")?;
1724 self.put_expression(arg, context, true)?;
1725 write!(self.out, " > 0)), {scalar}(0), (")?;
1726 self.put_expression(arg, context, true)?;
1727 write!(self.out, " == 0))")?;
1728 Ok(())
1729 }
1730
1731 fn put_const_expression(
1732 &mut self,
1733 expr_handle: Handle<crate::Expression>,
1734 module: &crate::Module,
1735 mod_info: &valid::ModuleInfo,
1736 arena: &crate::Arena<crate::Expression>,
1737 ) -> BackendResult {
1738 self.put_possibly_const_expression(
1739 expr_handle,
1740 arena,
1741 module,
1742 mod_info,
1743 &(module, mod_info),
1744 |&(_, mod_info), expr| &mod_info[expr],
1745 |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1746 )
1747 }
1748
1749 fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1750 match literal {
1751 crate::Literal::F64(_) => {
1752 return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1753 }
1754 crate::Literal::F16(value) => {
1755 if value.is_infinite() {
1756 let sign = if value.is_sign_negative() { "-" } else { "" };
1757 write!(self.out, "{sign}INFINITY")?;
1758 } else if value.is_nan() {
1759 write!(self.out, "NAN")?;
1760 } else {
1761 let suffix = if value.fract() == f16::from_f32(0.0) {
1762 ".0h"
1763 } else {
1764 "h"
1765 };
1766 write!(self.out, "{value}{suffix}")?;
1767 }
1768 }
1769 crate::Literal::F32(value) => {
1770 if value.is_infinite() {
1771 let sign = if value.is_sign_negative() { "-" } else { "" };
1772 write!(self.out, "{sign}INFINITY")?;
1773 } else if value.is_nan() {
1774 write!(self.out, "NAN")?;
1775 } else {
1776 let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1777 write!(self.out, "{value}{suffix}")?;
1778 }
1779 }
1780 crate::Literal::U32(value) => {
1781 write!(self.out, "{value}u")?;
1782 }
1783 crate::Literal::I32(value) => {
1784 if value == i32::MIN {
1789 write!(self.out, "({} - 1)", value + 1)?;
1790 } else {
1791 write!(self.out, "{value}")?;
1792 }
1793 }
1794 crate::Literal::U64(value) => {
1795 write!(self.out, "{value}uL")?;
1796 }
1797 crate::Literal::I64(value) => {
1798 if value == i64::MIN {
1805 write!(self.out, "({}L - 1L)", value + 1)?;
1806 } else {
1807 write!(self.out, "{value}L")?;
1808 }
1809 }
1810 crate::Literal::Bool(value) => {
1811 write!(self.out, "{value}")?;
1812 }
1813 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1814 return Err(Error::GenericValidation(
1815 "Unsupported abstract literal".into(),
1816 ));
1817 }
1818 }
1819 Ok(())
1820 }
1821
1822 #[allow(clippy::too_many_arguments)]
1823 fn put_possibly_const_expression<C, I, E>(
1824 &mut self,
1825 expr_handle: Handle<crate::Expression>,
1826 expressions: &crate::Arena<crate::Expression>,
1827 module: &crate::Module,
1828 mod_info: &valid::ModuleInfo,
1829 ctx: &C,
1830 get_expr_ty: I,
1831 put_expression: E,
1832 ) -> BackendResult
1833 where
1834 I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1835 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1836 {
1837 match expressions[expr_handle] {
1838 crate::Expression::Literal(literal) => {
1839 self.put_literal(literal)?;
1840 }
1841 crate::Expression::Constant(handle) => {
1842 let constant = &module.constants[handle];
1843 if constant.name.is_some() {
1844 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1845 } else {
1846 self.put_const_expression(
1847 constant.init,
1848 module,
1849 mod_info,
1850 &module.global_expressions,
1851 )?;
1852 }
1853 }
1854 crate::Expression::ZeroValue(ty) => {
1855 let ty_name = TypeContext {
1856 handle: ty,
1857 gctx: module.to_ctx(),
1858 names: &self.names,
1859 access: crate::StorageAccess::empty(),
1860 first_time: false,
1861 };
1862 write!(self.out, "{ty_name} {{}}")?;
1863 }
1864 crate::Expression::Compose { ty, ref components } => {
1865 let ty_name = TypeContext {
1866 handle: ty,
1867 gctx: module.to_ctx(),
1868 names: &self.names,
1869 access: crate::StorageAccess::empty(),
1870 first_time: false,
1871 };
1872 write!(self.out, "{ty_name}")?;
1873 match module.types[ty].inner {
1874 crate::TypeInner::Scalar(_)
1875 | crate::TypeInner::Vector { .. }
1876 | crate::TypeInner::Matrix { .. } => {
1877 self.put_call_parameters_impl(
1878 components.iter().copied(),
1879 ctx,
1880 put_expression,
1881 )?;
1882 }
1883 crate::TypeInner::Array { .. } => {
1884 write!(self.out, " {{{{")?;
1887 for (index, &component) in components.iter().enumerate() {
1888 if index != 0 {
1889 write!(self.out, ", ")?;
1890 }
1891 put_expression(self, ctx, component)?;
1892 }
1893 write!(self.out, "}}}}")?;
1894 }
1895 crate::TypeInner::Struct { .. } => {
1896 write!(self.out, " {{")?;
1897 for (index, &component) in components.iter().enumerate() {
1898 if index != 0 {
1899 write!(self.out, ", ")?;
1900 }
1901 if self.struct_member_pads.contains(&(ty, index as u32)) {
1903 write!(self.out, "{{}}, ")?;
1904 }
1905 put_expression(self, ctx, component)?;
1906 }
1907 write!(self.out, "}}")?;
1908 }
1909 _ => return Err(Error::UnsupportedCompose(ty)),
1910 }
1911 }
1912 crate::Expression::Splat { size, value } => {
1913 let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1914 crate::TypeInner::Scalar(scalar) => scalar,
1915 ref ty => {
1916 return Err(Error::GenericValidation(format!(
1917 "Expected splat value type must be a scalar, got {ty:?}",
1918 )))
1919 }
1920 };
1921 put_numeric_type(&mut self.out, scalar, &[size])?;
1922 write!(self.out, "(")?;
1923 put_expression(self, ctx, value)?;
1924 write!(self.out, ")")?;
1925 }
1926 _ => {
1927 return Err(Error::Override);
1928 }
1929 }
1930
1931 Ok(())
1932 }
1933
1934 fn put_expression(
1946 &mut self,
1947 expr_handle: Handle<crate::Expression>,
1948 context: &ExpressionContext,
1949 is_scoped: bool,
1950 ) -> BackendResult {
1951 #[cfg(test)]
1953 self.put_expression_stack_pointers
1954 .insert(ptr::from_ref(&expr_handle).cast());
1955
1956 if let Some(name) = self.named_expressions.get(&expr_handle) {
1957 write!(self.out, "{name}")?;
1958 return Ok(());
1959 }
1960
1961 let expression = &context.function.expressions[expr_handle];
1962 match *expression {
1963 crate::Expression::Literal(_)
1964 | crate::Expression::Constant(_)
1965 | crate::Expression::ZeroValue(_)
1966 | crate::Expression::Compose { .. }
1967 | crate::Expression::Splat { .. } => {
1968 self.put_possibly_const_expression(
1969 expr_handle,
1970 &context.function.expressions,
1971 context.module,
1972 context.mod_info,
1973 context,
1974 |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1975 |writer, context, expr| writer.put_expression(expr, context, true),
1976 )?;
1977 }
1978 crate::Expression::Override(_) => return Err(Error::Override),
1979 crate::Expression::Access { base, .. }
1980 | crate::Expression::AccessIndex { base, .. } => {
1981 let policy = context.choose_bounds_check_policy(base);
1987 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1988 && self.put_bounds_checks(
1989 expr_handle,
1990 context,
1991 back::Level(0),
1992 if is_scoped { "" } else { "(" },
1993 )?
1994 {
1995 write!(self.out, " ? ")?;
1996 self.put_access_chain(expr_handle, policy, context)?;
1997 write!(self.out, " : ")?;
1998
1999 if context.resolve_type(base).pointer_space().is_some() {
2000 let result_ty = context.info[expr_handle]
2004 .ty
2005 .inner_with(&context.module.types)
2006 .pointer_base_type();
2007 let result_ty_handle = match result_ty {
2008 Some(TypeResolution::Handle(handle)) => handle,
2009 Some(TypeResolution::Value(_)) => {
2010 unreachable!(
2018 "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
2019 );
2020 }
2021 None => {
2022 unreachable!(
2023 "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
2024 )
2025 }
2026 };
2027 let name_key =
2028 NameKey::oob_local_for_type(context.origin, result_ty_handle);
2029 self.out.write_str(&self.names[&name_key])?;
2030 } else {
2031 write!(self.out, "DefaultConstructible()")?;
2032 }
2033
2034 if !is_scoped {
2035 write!(self.out, ")")?;
2036 }
2037 } else {
2038 self.put_access_chain(expr_handle, policy, context)?;
2039 }
2040 }
2041 crate::Expression::Swizzle {
2042 size,
2043 vector,
2044 pattern,
2045 } => {
2046 self.put_wrapped_expression_for_packed_vec3_access(
2047 vector,
2048 context,
2049 false,
2050 &Self::put_expression,
2051 )?;
2052 write!(self.out, ".")?;
2053 for &sc in pattern[..size as usize].iter() {
2054 write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2055 }
2056 }
2057 crate::Expression::FunctionArgument(index) => {
2058 let name_key = match context.origin {
2059 FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2060 FunctionOrigin::EntryPoint(ep_index) => {
2061 NameKey::EntryPointArgument(ep_index, index)
2062 }
2063 };
2064 let name = &self.names[&name_key];
2065 write!(self.out, "{name}")?;
2066 }
2067 crate::Expression::GlobalVariable(handle) => {
2068 let name = &self.names[&NameKey::GlobalVariable(handle)];
2069 write!(self.out, "{name}")?;
2070 }
2071 crate::Expression::LocalVariable(handle) => {
2072 let name_key = NameKey::local(context.origin, handle);
2073 let name = &self.names[&name_key];
2074 write!(self.out, "{name}")?;
2075 }
2076 crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2077 crate::Expression::ImageSample {
2078 coordinate,
2079 image,
2080 sampler,
2081 clamp_to_edge: true,
2082 gather: None,
2083 array_index: None,
2084 offset: None,
2085 level: crate::SampleLevel::Zero,
2086 depth_ref: None,
2087 } => {
2088 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2089 self.put_expression(image, context, true)?;
2090 write!(self.out, ", ")?;
2091 self.put_expression(sampler, context, true)?;
2092 write!(self.out, ", ")?;
2093 self.put_expression(coordinate, context, true)?;
2094 write!(self.out, ")")?;
2095 }
2096 crate::Expression::ImageSample {
2097 image,
2098 sampler,
2099 gather,
2100 coordinate,
2101 array_index,
2102 offset,
2103 level,
2104 depth_ref,
2105 clamp_to_edge,
2106 } => {
2107 if clamp_to_edge {
2108 return Err(Error::GenericValidation(
2109 "ImageSample::clamp_to_edge should have been validated out".to_string(),
2110 ));
2111 }
2112
2113 let main_op = match gather {
2114 Some(_) => "gather",
2115 None => "sample",
2116 };
2117 let comparison_op = match depth_ref {
2118 Some(_) => "_compare",
2119 None => "",
2120 };
2121 self.put_expression(image, context, false)?;
2122 write!(self.out, ".{main_op}{comparison_op}(")?;
2123 self.put_expression(sampler, context, true)?;
2124 write!(self.out, ", ")?;
2125 self.put_expression(coordinate, context, true)?;
2126 if let Some(expr) = array_index {
2127 write!(self.out, ", ")?;
2128 self.put_expression(expr, context, true)?;
2129 }
2130 if let Some(dref) = depth_ref {
2131 write!(self.out, ", ")?;
2132 self.put_expression(dref, context, true)?;
2133 }
2134
2135 self.put_image_sample_level(image, level, context)?;
2136
2137 if let Some(offset) = offset {
2138 write!(self.out, ", ")?;
2139 self.put_expression(offset, context, true)?;
2140 }
2141
2142 match gather {
2143 None | Some(crate::SwizzleComponent::X) => {}
2144 Some(component) => {
2145 let is_cube_map = match *context.resolve_type(image) {
2146 crate::TypeInner::Image {
2147 dim: crate::ImageDimension::Cube,
2148 ..
2149 } => true,
2150 _ => false,
2151 };
2152 if offset.is_none() && !is_cube_map {
2155 write!(self.out, ", {NAMESPACE}::int2(0)")?;
2156 }
2157 let letter = back::COMPONENTS[component as usize];
2158 write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2159 }
2160 }
2161 write!(self.out, ")")?;
2162 }
2163 crate::Expression::ImageLoad {
2164 image,
2165 coordinate,
2166 array_index,
2167 sample,
2168 level,
2169 } => {
2170 let address = TexelAddress {
2171 coordinate,
2172 array_index,
2173 sample,
2174 level: level.map(LevelOfDetail::Direct),
2175 };
2176 self.put_image_load(expr_handle, image, address, context)?;
2177 }
2178 crate::Expression::ImageQuery { image, query } => match query {
2181 crate::ImageQuery::Size { level } => {
2182 self.put_image_size_query(
2183 image,
2184 level.map(LevelOfDetail::Direct),
2185 crate::ScalarKind::Uint,
2186 context,
2187 )?;
2188 }
2189 crate::ImageQuery::NumLevels => {
2190 self.put_expression(image, context, false)?;
2191 write!(self.out, ".get_num_mip_levels()")?;
2192 }
2193 crate::ImageQuery::NumLayers => {
2194 self.put_expression(image, context, false)?;
2195 write!(self.out, ".get_array_size()")?;
2196 }
2197 crate::ImageQuery::NumSamples => {
2198 self.put_expression(image, context, false)?;
2199 write!(self.out, ".get_num_samples()")?;
2200 }
2201 },
2202 crate::Expression::Unary { op, expr } => {
2203 let op_str = match op {
2204 crate::UnaryOperator::Negate => {
2205 match context.resolve_type(expr).scalar_kind() {
2206 Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2207 _ => "-",
2208 }
2209 }
2210 crate::UnaryOperator::LogicalNot => "!",
2211 crate::UnaryOperator::BitwiseNot => "~",
2212 };
2213 write!(self.out, "{op_str}(")?;
2214 self.put_expression(expr, context, false)?;
2215 write!(self.out, ")")?;
2216 }
2217 crate::Expression::Binary { op, left, right } => {
2218 let kind = context
2219 .resolve_type(left)
2220 .scalar_kind()
2221 .ok_or(Error::UnsupportedBinaryOp(op))?;
2222
2223 if op == crate::BinaryOperator::Divide
2224 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2225 {
2226 write!(self.out, "{DIV_FUNCTION}(")?;
2227 self.put_expression(left, context, true)?;
2228 write!(self.out, ", ")?;
2229 self.put_expression(right, context, true)?;
2230 write!(self.out, ")")?;
2231 } else if op == crate::BinaryOperator::Modulo
2232 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2233 {
2234 write!(self.out, "{MOD_FUNCTION}(")?;
2235 self.put_expression(left, context, true)?;
2236 write!(self.out, ", ")?;
2237 self.put_expression(right, context, true)?;
2238 write!(self.out, ")")?;
2239 } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2240 write!(self.out, "{NAMESPACE}::fmod(")?;
2245 self.put_expression(left, context, true)?;
2246 write!(self.out, ", ")?;
2247 self.put_expression(right, context, true)?;
2248 write!(self.out, ")")?;
2249 } else if (op == crate::BinaryOperator::Add
2250 || op == crate::BinaryOperator::Subtract
2251 || op == crate::BinaryOperator::Multiply)
2252 && kind == crate::ScalarKind::Sint
2253 {
2254 let to_unsigned = |ty: &crate::TypeInner| match *ty {
2255 crate::TypeInner::Scalar(scalar) => {
2256 Ok(crate::TypeInner::Scalar(crate::Scalar {
2257 kind: crate::ScalarKind::Uint,
2258 ..scalar
2259 }))
2260 }
2261 crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2262 size,
2263 scalar: crate::Scalar {
2264 kind: crate::ScalarKind::Uint,
2265 ..scalar
2266 },
2267 }),
2268 _ => Err(Error::UnsupportedBitCast(ty.clone())),
2269 };
2270
2271 self.put_bitcasted_expression(
2276 context.resolve_type(expr_handle),
2277 context,
2278 &|writer, context, is_scoped| {
2279 writer.put_binop(
2280 op,
2281 left,
2282 right,
2283 context,
2284 is_scoped,
2285 &|writer, expr, context, _is_scoped| {
2286 writer.put_bitcasted_expression(
2287 &to_unsigned(context.resolve_type(expr))?,
2288 context,
2289 &|writer, context, is_scoped| {
2290 writer.put_expression(expr, context, is_scoped)
2291 },
2292 )
2293 },
2294 )
2295 },
2296 )?;
2297 } else {
2298 self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2299 }
2300 }
2301 crate::Expression::Select {
2302 condition,
2303 accept,
2304 reject,
2305 } => match *context.resolve_type(condition) {
2306 crate::TypeInner::Scalar(crate::Scalar {
2307 kind: crate::ScalarKind::Bool,
2308 ..
2309 }) => {
2310 if !is_scoped {
2311 write!(self.out, "(")?;
2312 }
2313 self.put_expression(condition, context, false)?;
2314 write!(self.out, " ? ")?;
2315 self.put_expression(accept, context, false)?;
2316 write!(self.out, " : ")?;
2317 self.put_expression(reject, context, false)?;
2318 if !is_scoped {
2319 write!(self.out, ")")?;
2320 }
2321 }
2322 crate::TypeInner::Vector {
2323 scalar:
2324 crate::Scalar {
2325 kind: crate::ScalarKind::Bool,
2326 ..
2327 },
2328 ..
2329 } => {
2330 write!(self.out, "{NAMESPACE}::select(")?;
2331 self.put_expression(reject, context, true)?;
2332 write!(self.out, ", ")?;
2333 self.put_expression(accept, context, true)?;
2334 write!(self.out, ", ")?;
2335 self.put_expression(condition, context, true)?;
2336 write!(self.out, ")")?;
2337 }
2338 ref ty => {
2339 return Err(Error::GenericValidation(format!(
2340 "Expected select condition to be a non-bool type, got {ty:?}",
2341 )))
2342 }
2343 },
2344 crate::Expression::Derivative { axis, expr, .. } => {
2345 use crate::DerivativeAxis as Axis;
2346 let op = match axis {
2347 Axis::X => "dfdx",
2348 Axis::Y => "dfdy",
2349 Axis::Width => "fwidth",
2350 };
2351 write!(self.out, "{NAMESPACE}::{op}")?;
2352 self.put_call_parameters(iter::once(expr), context)?;
2353 }
2354 crate::Expression::Relational { fun, argument } => {
2355 let op = match fun {
2356 crate::RelationalFunction::Any => "any",
2357 crate::RelationalFunction::All => "all",
2358 crate::RelationalFunction::IsNan => "isnan",
2359 crate::RelationalFunction::IsInf => "isinf",
2360 };
2361 write!(self.out, "{NAMESPACE}::{op}")?;
2362 self.put_call_parameters(iter::once(argument), context)?;
2363 }
2364 crate::Expression::Math {
2365 fun,
2366 arg,
2367 arg1,
2368 arg2,
2369 arg3,
2370 } => {
2371 use crate::MathFunction as Mf;
2372
2373 let arg_type = context.resolve_type(arg);
2374 let scalar_argument = match arg_type {
2375 &crate::TypeInner::Scalar(_) => true,
2376 _ => false,
2377 };
2378
2379 let fun_name = match fun {
2380 Mf::Abs => "abs",
2382 Mf::Min => "min",
2383 Mf::Max => "max",
2384 Mf::Clamp => "clamp",
2385 Mf::Saturate => "saturate",
2386 Mf::Cos => "cos",
2388 Mf::Cosh => "cosh",
2389 Mf::Sin => "sin",
2390 Mf::Sinh => "sinh",
2391 Mf::Tan => "tan",
2392 Mf::Tanh => "tanh",
2393 Mf::Acos => "acos",
2394 Mf::Asin => "asin",
2395 Mf::Atan => "atan",
2396 Mf::Atan2 => "atan2",
2397 Mf::Asinh => "asinh",
2398 Mf::Acosh => "acosh",
2399 Mf::Atanh => "atanh",
2400 Mf::Radians => "",
2401 Mf::Degrees => "",
2402 Mf::Ceil => "ceil",
2404 Mf::Floor => "floor",
2405 Mf::Round => "rint",
2406 Mf::Fract => "fract",
2407 Mf::Trunc => "trunc",
2408 Mf::Modf => MODF_FUNCTION,
2409 Mf::Frexp => FREXP_FUNCTION,
2410 Mf::Ldexp => "ldexp",
2411 Mf::Exp => "exp",
2413 Mf::Exp2 => "exp2",
2414 Mf::Log => "log",
2415 Mf::Log2 => "log2",
2416 Mf::Pow => "pow",
2417 Mf::Dot => match *context.resolve_type(arg) {
2419 crate::TypeInner::Vector {
2420 scalar:
2421 crate::Scalar {
2422 kind: crate::ScalarKind::Float,
2424 ..
2425 },
2426 ..
2427 } => "dot",
2428 crate::TypeInner::Vector {
2429 size,
2430 scalar:
2431 scalar @ crate::Scalar {
2432 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2433 ..
2434 },
2435 } => {
2436 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2438 write!(self.out, "{fun_name}(")?;
2439 self.put_expression(arg, context, true)?;
2440 write!(self.out, ", ")?;
2441 self.put_expression(arg1.unwrap(), context, true)?;
2442 write!(self.out, ")")?;
2443 return Ok(());
2444 }
2445 _ => unreachable!(
2446 "Correct TypeInner for dot product should be already validated"
2447 ),
2448 },
2449 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2450 if context.lang_version >= (2, 1) {
2451 let packed_type = match fun {
2455 Mf::Dot4I8Packed => "packed_char4",
2456 Mf::Dot4U8Packed => "packed_uchar4",
2457 _ => unreachable!(),
2458 };
2459
2460 return self.put_dot_product(
2461 Reinterpreted::new(packed_type, arg),
2462 Reinterpreted::new(packed_type, arg1.unwrap()),
2463 4,
2464 |writer, arg, index| {
2465 write!(writer.out, "{arg}[{index}]")?;
2468 Ok(())
2469 },
2470 );
2471 } else {
2472 let conversion = match fun {
2476 Mf::Dot4I8Packed => "int",
2477 Mf::Dot4U8Packed => "",
2478 _ => unreachable!(),
2479 };
2480
2481 return self.put_dot_product(
2482 arg,
2483 arg1.unwrap(),
2484 4,
2485 |writer, arg, index| {
2486 write!(writer.out, "({conversion}(")?;
2487 writer.put_expression(arg, context, true)?;
2488 if index == 3 {
2489 write!(writer.out, ") >> 24)")?;
2490 } else {
2491 write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2492 }
2493 Ok(())
2494 },
2495 );
2496 }
2497 }
2498 Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2499 Mf::Cross => "cross",
2500 Mf::Distance => "distance",
2501 Mf::Length if scalar_argument => "abs",
2502 Mf::Length => "length",
2503 Mf::Normalize => "normalize",
2504 Mf::FaceForward => "faceforward",
2505 Mf::Reflect => "reflect",
2506 Mf::Refract => "refract",
2507 Mf::Sign => match arg_type.scalar_kind() {
2509 Some(crate::ScalarKind::Sint) => {
2510 return self.put_isign(arg, context);
2511 }
2512 _ => "sign",
2513 },
2514 Mf::Fma => "fma",
2515 Mf::Mix => "mix",
2516 Mf::Step => "step",
2517 Mf::SmoothStep => "smoothstep",
2518 Mf::Sqrt => "sqrt",
2519 Mf::InverseSqrt => "rsqrt",
2520 Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2521 Mf::Transpose => "transpose",
2522 Mf::Determinant => "determinant",
2523 Mf::QuantizeToF16 => "",
2524 Mf::CountTrailingZeros => "ctz",
2526 Mf::CountLeadingZeros => "clz",
2527 Mf::CountOneBits => "popcount",
2528 Mf::ReverseBits => "reverse_bits",
2529 Mf::ExtractBits => "",
2530 Mf::InsertBits => "",
2531 Mf::FirstTrailingBit => "",
2532 Mf::FirstLeadingBit => "",
2533 Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2535 Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2536 Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2537 Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2538 Mf::Pack2x16float => "",
2539 Mf::Pack4xI8 => "",
2540 Mf::Pack4xU8 => "",
2541 Mf::Pack4xI8Clamp => "",
2542 Mf::Pack4xU8Clamp => "",
2543 Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2545 Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2546 Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2547 Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2548 Mf::Unpack2x16float => "",
2549 Mf::Unpack4xI8 => "",
2550 Mf::Unpack4xU8 => "",
2551 };
2552
2553 match fun {
2554 Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2555 if context.lang_version < (1, 2) {
2564 return Err(Error::UnsupportedFunction(fun_name.to_string()));
2565 }
2566 }
2567 _ => {}
2568 }
2569
2570 match fun {
2571 Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2572 write!(self.out, "{ABS_FUNCTION}(")?;
2573 self.put_expression(arg, context, true)?;
2574 write!(self.out, ")")?;
2575 }
2576 Mf::Distance if scalar_argument => {
2577 write!(self.out, "{NAMESPACE}::abs(")?;
2578 self.put_expression(arg, context, false)?;
2579 write!(self.out, " - ")?;
2580 self.put_expression(arg1.unwrap(), context, false)?;
2581 write!(self.out, ")")?;
2582 }
2583 Mf::FirstTrailingBit => {
2584 let scalar = context.resolve_type(arg).scalar().unwrap();
2585 let constant = scalar.width * 8 + 1;
2586
2587 write!(self.out, "((({NAMESPACE}::ctz(")?;
2588 self.put_expression(arg, context, true)?;
2589 write!(self.out, ") + 1) % {constant}) - 1)")?;
2590 }
2591 Mf::FirstLeadingBit => {
2592 let inner = context.resolve_type(arg);
2593 let scalar = inner.scalar().unwrap();
2594 let constant = scalar.width * 8 - 1;
2595
2596 write!(
2597 self.out,
2598 "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2599 )?;
2600
2601 if scalar.kind == crate::ScalarKind::Sint {
2602 write!(self.out, "{NAMESPACE}::select(")?;
2603 self.put_expression(arg, context, true)?;
2604 write!(self.out, ", ~")?;
2605 self.put_expression(arg, context, true)?;
2606 write!(self.out, ", ")?;
2607 self.put_expression(arg, context, true)?;
2608 write!(self.out, " < 0)")?;
2609 } else {
2610 self.put_expression(arg, context, true)?;
2611 }
2612
2613 write!(self.out, "), ")?;
2614
2615 match *inner {
2617 crate::TypeInner::Vector { size, scalar } => {
2618 let size = common::vector_size_str(size);
2619 let name = scalar.to_msl_name();
2620 write!(self.out, "{name}{size}")?;
2621 }
2622 crate::TypeInner::Scalar(scalar) => {
2623 let name = scalar.to_msl_name();
2624 write!(self.out, "{name}")?;
2625 }
2626 _ => (),
2627 }
2628
2629 write!(self.out, "(-1), ")?;
2630 self.put_expression(arg, context, true)?;
2631 write!(self.out, " == 0 || ")?;
2632 self.put_expression(arg, context, true)?;
2633 write!(self.out, " == -1)")?;
2634 }
2635 Mf::Unpack2x16float => {
2636 write!(self.out, "float2(as_type<half2>(")?;
2637 self.put_expression(arg, context, false)?;
2638 write!(self.out, "))")?;
2639 }
2640 Mf::Pack2x16float => {
2641 write!(self.out, "as_type<uint>(half2(")?;
2642 self.put_expression(arg, context, false)?;
2643 write!(self.out, "))")?;
2644 }
2645 Mf::ExtractBits => {
2646 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2663
2664 write!(self.out, "{NAMESPACE}::extract_bits(")?;
2665 self.put_expression(arg, context, true)?;
2666 write!(self.out, ", {NAMESPACE}::min(")?;
2667 self.put_expression(arg1.unwrap(), context, true)?;
2668 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2669 self.put_expression(arg2.unwrap(), context, true)?;
2670 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2671 self.put_expression(arg1.unwrap(), context, true)?;
2672 write!(self.out, ", {scalar_bits}u)))")?;
2673 }
2674 Mf::InsertBits => {
2675 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2680
2681 write!(self.out, "{NAMESPACE}::insert_bits(")?;
2682 self.put_expression(arg, context, true)?;
2683 write!(self.out, ", ")?;
2684 self.put_expression(arg1.unwrap(), context, true)?;
2685 write!(self.out, ", {NAMESPACE}::min(")?;
2686 self.put_expression(arg2.unwrap(), context, true)?;
2687 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2688 self.put_expression(arg3.unwrap(), context, true)?;
2689 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2690 self.put_expression(arg2.unwrap(), context, true)?;
2691 write!(self.out, ", {scalar_bits}u)))")?;
2692 }
2693 Mf::Radians => {
2694 write!(self.out, "((")?;
2695 self.put_expression(arg, context, false)?;
2696 write!(self.out, ") * 0.017453292519943295474)")?;
2697 }
2698 Mf::Degrees => {
2699 write!(self.out, "((")?;
2700 self.put_expression(arg, context, false)?;
2701 write!(self.out, ") * 57.295779513082322865)")?;
2702 }
2703 Mf::Modf | Mf::Frexp => {
2704 write!(self.out, "{fun_name}")?;
2705 self.put_call_parameters(iter::once(arg), context)?;
2706 }
2707 Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2708 Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2709 Mf::Pack4xI8Clamp => {
2710 self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2711 }
2712 Mf::Pack4xU8Clamp => {
2713 self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2714 }
2715 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2716 let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2717 "u"
2718 } else {
2719 ""
2720 };
2721
2722 if context.lang_version >= (2, 1) {
2723 write!(
2725 self.out,
2726 "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2727 )?;
2728 self.put_expression(arg, context, true)?;
2729 write!(self.out, "))")?;
2730 } else {
2731 write!(self.out, "({sign_prefix}int4(")?;
2733 self.put_expression(arg, context, true)?;
2734 write!(self.out, ", ")?;
2735 self.put_expression(arg, context, true)?;
2736 write!(self.out, " >> 8, ")?;
2737 self.put_expression(arg, context, true)?;
2738 write!(self.out, " >> 16, ")?;
2739 self.put_expression(arg, context, true)?;
2740 write!(self.out, " >> 24) << 24 >> 24)")?;
2741 }
2742 }
2743 Mf::QuantizeToF16 => {
2744 match *context.resolve_type(arg) {
2745 crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2746 crate::TypeInner::Vector { size, .. } => write!(
2747 self.out,
2748 "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2749 size = common::vector_size_str(size),
2750 )?,
2751 _ => unreachable!(
2752 "Correct TypeInner for QuantizeToF16 should be already validated"
2753 ),
2754 };
2755
2756 self.put_expression(arg, context, true)?;
2757 write!(self.out, "))")?;
2758 }
2759 _ => {
2760 write!(self.out, "{NAMESPACE}::{fun_name}")?;
2761 self.put_call_parameters(
2762 iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2763 context,
2764 )?;
2765 }
2766 }
2767 }
2768 crate::Expression::As {
2769 expr,
2770 kind,
2771 convert,
2772 } => match *context.resolve_type(expr) {
2773 crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2774 if src.kind == crate::ScalarKind::Float
2775 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2776 && convert.is_some()
2777 {
2778 let fun_name = match (kind, convert) {
2782 (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2783 (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2784 (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2785 (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2786 _ => unreachable!(),
2787 };
2788 write!(self.out, "{fun_name}(")?;
2789 self.put_expression(expr, context, true)?;
2790 write!(self.out, ")")?;
2791 } else {
2792 let target_scalar = crate::Scalar {
2793 kind,
2794 width: convert.unwrap_or(src.width),
2795 };
2796 let op = match convert {
2797 Some(_) => "static_cast",
2798 None => "as_type",
2799 };
2800 write!(self.out, "{op}<")?;
2801 match *context.resolve_type(expr) {
2802 crate::TypeInner::Vector { size, .. } => {
2803 put_numeric_type(&mut self.out, target_scalar, &[size])?
2804 }
2805 _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2806 };
2807 write!(self.out, ">(")?;
2808 self.put_expression(expr, context, true)?;
2809 write!(self.out, ")")?;
2810 }
2811 }
2812 crate::TypeInner::Matrix {
2813 columns,
2814 rows,
2815 scalar,
2816 } => {
2817 let target_scalar = crate::Scalar {
2818 kind,
2819 width: convert.unwrap_or(scalar.width),
2820 };
2821 put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2822 write!(self.out, "(")?;
2823 self.put_expression(expr, context, true)?;
2824 write!(self.out, ")")?;
2825 }
2826 ref ty => {
2827 return Err(Error::GenericValidation(format!(
2828 "Unsupported type for As: {ty:?}"
2829 )))
2830 }
2831 },
2832 crate::Expression::CallResult(_)
2834 | crate::Expression::AtomicResult { .. }
2835 | crate::Expression::WorkGroupUniformLoadResult { .. }
2836 | crate::Expression::SubgroupBallotResult
2837 | crate::Expression::SubgroupOperationResult { .. }
2838 | crate::Expression::RayQueryProceedResult => {
2839 unreachable!()
2840 }
2841 crate::Expression::ArrayLength(expr) => {
2842 let global = match context.function.expressions[expr] {
2844 crate::Expression::AccessIndex { base, .. } => {
2845 match context.function.expressions[base] {
2846 crate::Expression::GlobalVariable(handle) => handle,
2847 ref ex => {
2848 return Err(Error::GenericValidation(format!(
2849 "Expected global variable in AccessIndex, got {ex:?}"
2850 )))
2851 }
2852 }
2853 }
2854 crate::Expression::GlobalVariable(handle) => handle,
2855 ref ex => {
2856 return Err(Error::GenericValidation(format!(
2857 "Unexpected expression in ArrayLength, got {ex:?}"
2858 )))
2859 }
2860 };
2861
2862 if !is_scoped {
2863 write!(self.out, "(")?;
2864 }
2865 write!(self.out, "1 + ")?;
2866 self.put_dynamic_array_max_index(global, context)?;
2867 if !is_scoped {
2868 write!(self.out, ")")?;
2869 }
2870 }
2871 crate::Expression::RayQueryVertexPositions { .. } => {
2872 unimplemented!()
2873 }
2874 crate::Expression::RayQueryGetIntersection {
2875 query,
2876 committed: _,
2877 } => {
2878 if context.lang_version < (2, 4) {
2879 return Err(Error::UnsupportedRayTracing);
2880 }
2881
2882 let ty = context.module.special_types.ray_intersection.unwrap();
2883 let type_name = &self.names[&NameKey::Type(ty)];
2884 write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2885 self.put_expression(query, context, true)?;
2886 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2887 let fields = [
2888 "distance",
2889 "user_instance_id", "instance_id",
2891 "", "geometry_id",
2893 "primitive_id",
2894 "triangle_barycentric_coord",
2895 "triangle_front_facing",
2896 "", "object_to_world_transform", "world_to_object_transform", ];
2900 for field in fields {
2901 write!(self.out, ", ")?;
2902 if field.is_empty() {
2903 write!(self.out, "{{}}")?;
2904 } else {
2905 self.put_expression(query, context, true)?;
2906 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2907 }
2908 }
2909 write!(self.out, "}}")?;
2910 }
2911 crate::Expression::CooperativeLoad { ref data, .. } => {
2912 if context.lang_version < (2, 3) {
2913 return Err(Error::UnsupportedCooperativeMatrix);
2914 }
2915 write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2916 write!(self.out, "&")?;
2917 self.put_access_chain(data.pointer, context.policies.index, context)?;
2918 write!(self.out, ", ")?;
2919 self.put_expression(data.stride, context, true)?;
2920 write!(self.out, ", {})", data.row_major)?;
2921 }
2922 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2923 if context.lang_version < (2, 3) {
2924 return Err(Error::UnsupportedCooperativeMatrix);
2925 }
2926 write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2927 self.put_expression(a, context, true)?;
2928 write!(self.out, ", ")?;
2929 self.put_expression(b, context, true)?;
2930 write!(self.out, ", ")?;
2931 self.put_expression(c, context, true)?;
2932 write!(self.out, ")")?;
2933 }
2934 }
2935 Ok(())
2936 }
2937
2938 fn put_binop<F>(
2941 &mut self,
2942 op: crate::BinaryOperator,
2943 left: Handle<crate::Expression>,
2944 right: Handle<crate::Expression>,
2945 context: &ExpressionContext,
2946 is_scoped: bool,
2947 put_expression: &F,
2948 ) -> BackendResult
2949 where
2950 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2951 {
2952 let op_str = back::binary_operation_str(op);
2953
2954 if !is_scoped {
2955 write!(self.out, "(")?;
2956 }
2957
2958 if op == crate::BinaryOperator::Multiply
2961 && matches!(
2962 context.resolve_type(right),
2963 &crate::TypeInner::Matrix { .. }
2964 )
2965 {
2966 self.put_wrapped_expression_for_packed_vec3_access(
2967 left,
2968 context,
2969 false,
2970 put_expression,
2971 )?;
2972 } else {
2973 put_expression(self, left, context, false)?;
2974 }
2975
2976 write!(self.out, " {op_str} ")?;
2977
2978 if op == crate::BinaryOperator::Multiply
2980 && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
2981 {
2982 self.put_wrapped_expression_for_packed_vec3_access(
2983 right,
2984 context,
2985 false,
2986 put_expression,
2987 )?;
2988 } else {
2989 put_expression(self, right, context, false)?;
2990 }
2991
2992 if !is_scoped {
2993 write!(self.out, ")")?;
2994 }
2995
2996 Ok(())
2997 }
2998
2999 fn put_wrapped_expression_for_packed_vec3_access<F>(
3001 &mut self,
3002 expr_handle: Handle<crate::Expression>,
3003 context: &ExpressionContext,
3004 is_scoped: bool,
3005 put_expression: &F,
3006 ) -> BackendResult
3007 where
3008 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
3009 {
3010 if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
3011 write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
3012 put_expression(self, expr_handle, context, is_scoped)?;
3013 write!(self.out, ")")?;
3014 } else {
3015 put_expression(self, expr_handle, context, is_scoped)?;
3016 }
3017 Ok(())
3018 }
3019
3020 fn put_bitcasted_expression<F>(
3023 &mut self,
3024 cast_to: &crate::TypeInner,
3025 context: &ExpressionContext,
3026 put_expression: &F,
3027 ) -> BackendResult
3028 where
3029 F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
3030 {
3031 write!(self.out, "as_type<")?;
3032 match *cast_to {
3033 crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
3034 crate::TypeInner::Vector { size, scalar } => {
3035 put_numeric_type(&mut self.out, scalar, &[size])?
3036 }
3037 _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
3038 };
3039 write!(self.out, ">(")?;
3040 put_expression(self, context, true)?;
3041 write!(self.out, ")")?;
3042
3043 Ok(())
3044 }
3045
3046 fn put_index(
3048 &mut self,
3049 index: index::GuardedIndex,
3050 context: &ExpressionContext,
3051 is_scoped: bool,
3052 ) -> BackendResult {
3053 match index {
3054 index::GuardedIndex::Expression(expr) => {
3055 self.put_expression(expr, context, is_scoped)?
3056 }
3057 index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3058 }
3059 Ok(())
3060 }
3061
3062 fn put_bounds_checks(
3092 &mut self,
3093 chain: Handle<crate::Expression>,
3094 context: &ExpressionContext,
3095 level: back::Level,
3096 prefix: &'static str,
3097 ) -> Result<bool, Error> {
3098 let mut check_written = false;
3099
3100 for item in context.bounds_check_iter(chain) {
3102 let BoundsCheck {
3103 base,
3104 index,
3105 length,
3106 } = item;
3107
3108 if check_written {
3109 write!(self.out, " && ")?;
3110 } else {
3111 write!(self.out, "{level}{prefix}")?;
3112 check_written = true;
3113 }
3114
3115 write!(self.out, "uint(")?;
3119 self.put_index(index, context, true)?;
3120 self.out.write_str(") < ")?;
3121 match length {
3122 index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3123 index::IndexableLength::Dynamic => {
3124 let global = context.function.originating_global(base).ok_or_else(|| {
3125 Error::GenericValidation("Could not find originating global".into())
3126 })?;
3127 write!(self.out, "1 + ")?;
3128 self.put_dynamic_array_max_index(global, context)?
3129 }
3130 }
3131 }
3132
3133 Ok(check_written)
3134 }
3135
3136 fn put_access_chain(
3156 &mut self,
3157 chain: Handle<crate::Expression>,
3158 policy: index::BoundsCheckPolicy,
3159 context: &ExpressionContext,
3160 ) -> BackendResult {
3161 match context.function.expressions[chain] {
3162 crate::Expression::Access { base, index } => {
3163 let mut base_ty = context.resolve_type(base);
3164
3165 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3167 base_ty = &context.module.types[base].inner;
3168 }
3169
3170 self.put_subscripted_access_chain(
3171 base,
3172 base_ty,
3173 index::GuardedIndex::Expression(index),
3174 policy,
3175 context,
3176 )?;
3177 }
3178 crate::Expression::AccessIndex { base, index } => {
3179 let base_resolution = &context.info[base].ty;
3180 let mut base_ty = base_resolution.inner_with(&context.module.types);
3181 let mut base_ty_handle = base_resolution.handle();
3182
3183 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3185 base_ty = &context.module.types[base].inner;
3186 base_ty_handle = Some(base);
3187 }
3188
3189 match *base_ty {
3193 crate::TypeInner::Struct { .. } => {
3194 let base_ty = base_ty_handle.unwrap();
3195 self.put_access_chain(base, policy, context)?;
3196 let name = &self.names[&NameKey::StructMember(base_ty, index)];
3197 write!(self.out, ".{name}")?;
3198 }
3199 crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3200 self.put_access_chain(base, policy, context)?;
3201 if context.get_packed_vec_kind(base).is_some() {
3204 write!(self.out, "[{index}]")?;
3205 } else {
3206 write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3207 }
3208 }
3209 _ => {
3210 self.put_subscripted_access_chain(
3211 base,
3212 base_ty,
3213 index::GuardedIndex::Known(index),
3214 policy,
3215 context,
3216 )?;
3217 }
3218 }
3219 }
3220 _ => self.put_expression(chain, context, false)?,
3221 }
3222
3223 Ok(())
3224 }
3225
3226 fn put_subscripted_access_chain(
3243 &mut self,
3244 base: Handle<crate::Expression>,
3245 base_ty: &crate::TypeInner,
3246 index: index::GuardedIndex,
3247 policy: index::BoundsCheckPolicy,
3248 context: &ExpressionContext,
3249 ) -> BackendResult {
3250 let accessing_wrapped_array = match *base_ty {
3251 crate::TypeInner::Array {
3252 size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3253 ..
3254 } => true,
3255 _ => false,
3256 };
3257 let accessing_wrapped_binding_array =
3258 matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3259
3260 self.put_access_chain(base, policy, context)?;
3261 if accessing_wrapped_array {
3262 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3263 }
3264 write!(self.out, "[")?;
3265
3266 let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3268 context.access_needs_check(base, index)
3269 } else {
3270 None
3271 };
3272 if let Some(limit) = restriction_needed {
3273 write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3274 self.put_index(index, context, true)?;
3275 write!(self.out, "), ")?;
3276 match limit {
3277 index::IndexableLength::Known(limit) => {
3278 write!(self.out, "{}u", limit - 1)?;
3279 }
3280 index::IndexableLength::Dynamic => {
3281 let global = context.function.originating_global(base).ok_or_else(|| {
3282 Error::GenericValidation("Could not find originating global".into())
3283 })?;
3284 self.put_dynamic_array_max_index(global, context)?;
3285 }
3286 }
3287 write!(self.out, ")")?;
3288 } else {
3289 self.put_index(index, context, true)?;
3290 }
3291
3292 write!(self.out, "]")?;
3293
3294 if accessing_wrapped_binding_array {
3295 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3296 }
3297
3298 Ok(())
3299 }
3300
3301 fn put_load(
3302 &mut self,
3303 pointer: Handle<crate::Expression>,
3304 context: &ExpressionContext,
3305 is_scoped: bool,
3306 ) -> BackendResult {
3307 let policy = context.choose_bounds_check_policy(pointer);
3310 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3311 && self.put_bounds_checks(
3312 pointer,
3313 context,
3314 back::Level(0),
3315 if is_scoped { "" } else { "(" },
3316 )?
3317 {
3318 write!(self.out, " ? ")?;
3319 self.put_unchecked_load(pointer, policy, context)?;
3320 write!(self.out, " : DefaultConstructible()")?;
3321
3322 if !is_scoped {
3323 write!(self.out, ")")?;
3324 }
3325 } else {
3326 self.put_unchecked_load(pointer, policy, context)?;
3327 }
3328
3329 Ok(())
3330 }
3331
3332 fn put_unchecked_load(
3333 &mut self,
3334 pointer: Handle<crate::Expression>,
3335 policy: index::BoundsCheckPolicy,
3336 context: &ExpressionContext,
3337 ) -> BackendResult {
3338 let is_atomic_pointer = context
3339 .resolve_type(pointer)
3340 .is_atomic_pointer(&context.module.types);
3341
3342 if is_atomic_pointer {
3343 write!(
3344 self.out,
3345 "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3346 )?;
3347 self.put_access_chain(pointer, policy, context)?;
3348 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3349 } else {
3350 self.put_access_chain(pointer, policy, context)?;
3354 }
3355
3356 Ok(())
3357 }
3358
3359 fn put_return_value(
3360 &mut self,
3361 level: back::Level,
3362 expr_handle: Handle<crate::Expression>,
3363 result_struct: Option<&str>,
3364 context: &ExpressionContext,
3365 ) -> BackendResult {
3366 match result_struct {
3367 Some(struct_name) => {
3368 let mut has_point_size = false;
3369 let result_ty = context.function.result.as_ref().unwrap().ty;
3370 match context.module.types[result_ty].inner {
3371 crate::TypeInner::Struct { ref members, .. } => {
3372 let tmp = "_tmp";
3373 write!(self.out, "{level}const auto {tmp} = ")?;
3374 self.put_expression(expr_handle, context, true)?;
3375 writeln!(self.out, ";")?;
3376 write!(self.out, "{level}return {struct_name} {{")?;
3377
3378 let mut is_first = true;
3379
3380 for (index, member) in members.iter().enumerate() {
3381 if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3382 member.binding
3383 {
3384 has_point_size = true;
3385 if !context.pipeline_options.allow_and_force_point_size {
3386 continue;
3387 }
3388 }
3389
3390 let comma = if is_first { "" } else { "," };
3391 is_first = false;
3392 let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3393 if let crate::TypeInner::Array {
3397 size: crate::ArraySize::Constant(size),
3398 ..
3399 } = context.module.types[member.ty].inner
3400 {
3401 write!(self.out, "{comma} {{")?;
3402 for j in 0..size.get() {
3403 if j != 0 {
3404 write!(self.out, ",")?;
3405 }
3406 write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3407 }
3408 write!(self.out, "}}")?;
3409 } else {
3410 write!(self.out, "{comma} {tmp}.{name}")?;
3411 }
3412 }
3413 }
3414 _ => {
3415 write!(self.out, "{level}return {struct_name} {{ ")?;
3416 self.put_expression(expr_handle, context, true)?;
3417 }
3418 }
3419
3420 if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3421 let stage = context.module.entry_points[ep_index as usize].stage;
3422 if context.pipeline_options.allow_and_force_point_size
3423 && stage == crate::ShaderStage::Vertex
3424 && !has_point_size
3425 {
3426 write!(self.out, ", 1.0")?;
3428 }
3429 }
3430 write!(self.out, " }}")?;
3431 }
3432 None => {
3433 write!(self.out, "{level}return ")?;
3434 self.put_expression(expr_handle, context, true)?;
3435 }
3436 }
3437 writeln!(self.out, ";")?;
3438 Ok(())
3439 }
3440
3441 fn update_expressions_to_bake(
3446 &mut self,
3447 func: &crate::Function,
3448 info: &valid::FunctionInfo,
3449 context: &ExpressionContext,
3450 ) {
3451 use crate::Expression;
3452 self.need_bake_expressions.clear();
3453
3454 for (expr_handle, expr) in func.expressions.iter() {
3455 let expr_info = &info[expr_handle];
3458 let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3459 if min_ref_count <= expr_info.ref_count {
3460 self.need_bake_expressions.insert(expr_handle);
3461 } else {
3462 match expr_info.ty {
3463 TypeResolution::Handle(h)
3465 if Some(h) == context.module.special_types.ray_desc =>
3466 {
3467 self.need_bake_expressions.insert(expr_handle);
3468 }
3469 _ => {}
3470 }
3471 }
3472
3473 if let Expression::Math {
3474 fun,
3475 arg,
3476 arg1,
3477 arg2,
3478 ..
3479 } = *expr
3480 {
3481 match fun {
3482 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3492 self.need_bake_expressions.insert(arg);
3493 self.need_bake_expressions.insert(arg1.unwrap());
3494 }
3495 crate::MathFunction::FirstLeadingBit => {
3496 self.need_bake_expressions.insert(arg);
3497 }
3498 crate::MathFunction::Pack4xI8
3499 | crate::MathFunction::Pack4xU8
3500 | crate::MathFunction::Pack4xI8Clamp
3501 | crate::MathFunction::Pack4xU8Clamp
3502 | crate::MathFunction::Unpack4xI8
3503 | crate::MathFunction::Unpack4xU8 => {
3504 if context.lang_version < (2, 1) {
3507 self.need_bake_expressions.insert(arg);
3508 }
3509 }
3510 crate::MathFunction::ExtractBits => {
3511 self.need_bake_expressions.insert(arg1.unwrap());
3513 }
3514 crate::MathFunction::InsertBits => {
3515 self.need_bake_expressions.insert(arg2.unwrap());
3517 }
3518 crate::MathFunction::Sign => {
3519 let inner = context.resolve_type(expr_handle);
3524 if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3525 self.need_bake_expressions.insert(arg);
3526 }
3527 }
3528 _ => {}
3529 }
3530 }
3531 }
3532 }
3533
3534 fn start_baking_expression(
3535 &mut self,
3536 handle: Handle<crate::Expression>,
3537 context: &ExpressionContext,
3538 name: &str,
3539 ) -> BackendResult {
3540 match context.info[handle].ty {
3541 TypeResolution::Handle(ty_handle) => {
3542 let ty_name = TypeContext {
3543 handle: ty_handle,
3544 gctx: context.module.to_ctx(),
3545 names: &self.names,
3546 access: crate::StorageAccess::empty(),
3547 first_time: false,
3548 };
3549 write!(self.out, "{ty_name}")?;
3550 }
3551 TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3552 put_numeric_type(&mut self.out, scalar, &[])?;
3553 }
3554 TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3555 put_numeric_type(&mut self.out, scalar, &[size])?;
3556 }
3557 TypeResolution::Value(crate::TypeInner::Matrix {
3558 columns,
3559 rows,
3560 scalar,
3561 }) => {
3562 put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3563 }
3564 TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3565 columns,
3566 rows,
3567 scalar,
3568 role: _,
3569 }) => {
3570 write!(
3571 self.out,
3572 "{}::simdgroup_{}{}x{}",
3573 NAMESPACE,
3574 scalar.to_msl_name(),
3575 columns as u32,
3576 rows as u32,
3577 )?;
3578 }
3579 TypeResolution::Value(ref other) => {
3580 log::warn!("Type {other:?} isn't a known local");
3581 return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3582 }
3583 }
3584
3585 write!(self.out, " {name} = ")?;
3587
3588 Ok(())
3589 }
3590
3591 fn put_cache_restricted_level(
3604 &mut self,
3605 load: Handle<crate::Expression>,
3606 image: Handle<crate::Expression>,
3607 mip_level: Option<Handle<crate::Expression>>,
3608 indent: back::Level,
3609 context: &StatementContext,
3610 ) -> BackendResult {
3611 let level_of_detail = match mip_level {
3614 Some(level) => level,
3615 None => return Ok(()),
3616 };
3617
3618 if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3619 || !context.expression.image_needs_lod(image)
3620 {
3621 return Ok(());
3622 }
3623
3624 write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3625 self.put_restricted_scalar_image_index(
3626 image,
3627 level_of_detail,
3628 "get_num_mip_levels",
3629 &context.expression,
3630 )?;
3631 writeln!(self.out, ";")?;
3632
3633 Ok(())
3634 }
3635
3636 fn put_casting_to_packed_chars(
3642 &mut self,
3643 fun: crate::MathFunction,
3644 arg0: Handle<crate::Expression>,
3645 arg1: Handle<crate::Expression>,
3646 indent: back::Level,
3647 context: &StatementContext<'_>,
3648 ) -> Result<(), Error> {
3649 let packed_type = match fun {
3650 crate::MathFunction::Dot4I8Packed => "packed_char4",
3651 crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3652 _ => unreachable!(),
3653 };
3654
3655 for arg in [arg0, arg1] {
3656 write!(
3657 self.out,
3658 "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3659 Reinterpreted::new(packed_type, arg)
3660 )?;
3661 self.put_expression(arg, &context.expression, true)?;
3662 writeln!(self.out, ");")?;
3663 }
3664
3665 Ok(())
3666 }
3667
3668 fn put_block(
3669 &mut self,
3670 level: back::Level,
3671 statements: &[crate::Statement],
3672 context: &StatementContext,
3673 ) -> BackendResult {
3674 #[cfg(test)]
3676 self.put_block_stack_pointers
3677 .insert(ptr::from_ref(&level).cast());
3678
3679 for statement in statements {
3680 log::trace!("statement[{}] {:?}", level.0, statement);
3681 match *statement {
3682 crate::Statement::Emit(ref range) => {
3683 for handle in range.clone() {
3684 use crate::MathFunction as Mf;
3685
3686 match context.expression.function.expressions[handle] {
3687 crate::Expression::ImageLoad {
3690 image,
3691 level: mip_level,
3692 ..
3693 } => {
3694 self.put_cache_restricted_level(
3695 handle, image, mip_level, level, context,
3696 )?;
3697 }
3698
3699 crate::Expression::Math {
3708 fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3709 arg,
3710 arg1,
3711 ..
3712 } if context.expression.lang_version >= (2, 1) => {
3713 self.put_casting_to_packed_chars(
3714 fun,
3715 arg,
3716 arg1.unwrap(),
3717 level,
3718 context,
3719 )?;
3720 }
3721
3722 _ => (),
3723 }
3724
3725 let ptr_class = context.expression.resolve_type(handle).pointer_space();
3726 let expr_name = if ptr_class.is_some() {
3727 None } else if let Some(name) =
3729 context.expression.function.named_expressions.get(&handle)
3730 {
3731 Some(self.namer.call(name))
3741 } else {
3742 let bake = if context.expression.guarded_indices.contains(handle) {
3746 true
3747 } else {
3748 self.need_bake_expressions.contains(&handle)
3749 };
3750
3751 if bake {
3752 Some(Baked(handle).to_string())
3753 } else {
3754 None
3755 }
3756 };
3757
3758 if let Some(name) = expr_name {
3759 write!(self.out, "{level}")?;
3760 self.start_baking_expression(handle, &context.expression, &name)?;
3761 self.put_expression(handle, &context.expression, true)?;
3762 self.named_expressions.insert(handle, name);
3763 writeln!(self.out, ";")?;
3764 }
3765 }
3766 }
3767 crate::Statement::Block(ref block) => {
3768 if !block.is_empty() {
3769 writeln!(self.out, "{level}{{")?;
3770 self.put_block(level.next(), block, context)?;
3771 writeln!(self.out, "{level}}}")?;
3772 }
3773 }
3774 crate::Statement::If {
3775 condition,
3776 ref accept,
3777 ref reject,
3778 } => {
3779 write!(self.out, "{level}if (")?;
3780 self.put_expression(condition, &context.expression, true)?;
3781 writeln!(self.out, ") {{")?;
3782 self.put_block(level.next(), accept, context)?;
3783 if !reject.is_empty() {
3784 writeln!(self.out, "{level}}} else {{")?;
3785 self.put_block(level.next(), reject, context)?;
3786 }
3787 writeln!(self.out, "{level}}}")?;
3788 }
3789 crate::Statement::Switch {
3790 selector,
3791 ref cases,
3792 } => {
3793 write!(self.out, "{level}switch(")?;
3794 self.put_expression(selector, &context.expression, true)?;
3795 writeln!(self.out, ") {{")?;
3796 let lcase = level.next();
3797 for case in cases.iter() {
3798 match case.value {
3799 crate::SwitchValue::I32(value) => {
3800 write!(self.out, "{lcase}case {value}:")?;
3801 }
3802 crate::SwitchValue::U32(value) => {
3803 write!(self.out, "{lcase}case {value}u:")?;
3804 }
3805 crate::SwitchValue::Default => {
3806 write!(self.out, "{lcase}default:")?;
3807 }
3808 }
3809
3810 let write_block_braces = !(case.fall_through && case.body.is_empty());
3811 if write_block_braces {
3812 writeln!(self.out, " {{")?;
3813 } else {
3814 writeln!(self.out)?;
3815 }
3816
3817 self.put_block(lcase.next(), &case.body, context)?;
3818 if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3819 {
3820 writeln!(self.out, "{}break;", lcase.next())?;
3821 }
3822
3823 if write_block_braces {
3824 writeln!(self.out, "{lcase}}}")?;
3825 }
3826 }
3827 writeln!(self.out, "{level}}}")?;
3828 }
3829 crate::Statement::Loop {
3830 ref body,
3831 ref continuing,
3832 break_if,
3833 } => {
3834 let force_loop_bound_statements =
3835 self.gen_force_bounded_loop_statements(level, context);
3836 let gate_name = (!continuing.is_empty() || break_if.is_some())
3837 .then(|| self.namer.call("loop_init"));
3838
3839 if let Some((ref decl, _)) = force_loop_bound_statements {
3840 writeln!(self.out, "{decl}")?;
3841 }
3842 if let Some(ref gate_name) = gate_name {
3843 writeln!(self.out, "{level}bool {gate_name} = true;")?;
3844 }
3845
3846 writeln!(self.out, "{level}while(true) {{",)?;
3847 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3848 writeln!(self.out, "{break_and_inc}")?;
3849 }
3850 if let Some(ref gate_name) = gate_name {
3851 let lif = level.next();
3852 let lcontinuing = lif.next();
3853 writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3854 self.put_block(lcontinuing, continuing, context)?;
3855 if let Some(condition) = break_if {
3856 write!(self.out, "{lcontinuing}if (")?;
3857 self.put_expression(condition, &context.expression, true)?;
3858 writeln!(self.out, ") {{")?;
3859 writeln!(self.out, "{}break;", lcontinuing.next())?;
3860 writeln!(self.out, "{lcontinuing}}}")?;
3861 }
3862 writeln!(self.out, "{lif}}}")?;
3863 writeln!(self.out, "{lif}{gate_name} = false;")?;
3864 }
3865 self.put_block(level.next(), body, context)?;
3866
3867 writeln!(self.out, "{level}}}")?;
3868 }
3869 crate::Statement::Break => {
3870 writeln!(self.out, "{level}break;")?;
3871 }
3872 crate::Statement::Continue => {
3873 writeln!(self.out, "{level}continue;")?;
3874 }
3875 crate::Statement::Return {
3876 value: Some(expr_handle),
3877 } => {
3878 self.put_return_value(
3879 level,
3880 expr_handle,
3881 context.result_struct,
3882 &context.expression,
3883 )?;
3884 }
3885 crate::Statement::Return { value: None } => {
3886 writeln!(self.out, "{level}return;")?;
3887 }
3888 crate::Statement::Kill => {
3889 writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3890 }
3891 crate::Statement::ControlBarrier(flags)
3892 | crate::Statement::MemoryBarrier(flags) => {
3893 self.write_barrier(flags, level)?;
3894 }
3895 crate::Statement::Store { pointer, value } => {
3896 self.put_store(pointer, value, level, context)?
3897 }
3898 crate::Statement::ImageStore {
3899 image,
3900 coordinate,
3901 array_index,
3902 value,
3903 } => {
3904 let address = TexelAddress {
3905 coordinate,
3906 array_index,
3907 sample: None,
3908 level: None,
3909 };
3910 self.put_image_store(level, image, &address, value, context)?
3911 }
3912 crate::Statement::Call {
3913 function,
3914 ref arguments,
3915 result,
3916 } => {
3917 write!(self.out, "{level}")?;
3918 if let Some(expr) = result {
3919 let name = Baked(expr).to_string();
3920 self.start_baking_expression(expr, &context.expression, &name)?;
3921 self.named_expressions.insert(expr, name);
3922 }
3923 let fun_name = &self.names[&NameKey::Function(function)];
3924 write!(self.out, "{fun_name}(")?;
3925 for (i, &handle) in arguments.iter().enumerate() {
3927 if i != 0 {
3928 write!(self.out, ", ")?;
3929 }
3930 self.put_expression(handle, &context.expression, true)?;
3931 }
3932 let mut separate = !arguments.is_empty();
3934 let fun_info = &context.expression.mod_info[function];
3935 let mut needs_buffer_sizes = false;
3936 for (handle, var) in context.expression.module.global_variables.iter() {
3937 if fun_info[handle].is_empty() {
3938 continue;
3939 }
3940 if var.space.needs_pass_through() {
3941 let name = &self.names[&NameKey::GlobalVariable(handle)];
3942 if separate {
3943 write!(self.out, ", ")?;
3944 } else {
3945 separate = true;
3946 }
3947 write!(self.out, "{name}")?;
3948 }
3949 needs_buffer_sizes |=
3950 needs_array_length(var.ty, &context.expression.module.types);
3951 }
3952 if needs_buffer_sizes {
3953 if separate {
3954 write!(self.out, ", ")?;
3955 }
3956 write!(self.out, "_buffer_sizes")?;
3957 }
3958
3959 writeln!(self.out, ");")?;
3961 }
3962 crate::Statement::Atomic {
3963 pointer,
3964 ref fun,
3965 value,
3966 result,
3967 } => {
3968 let context = &context.expression;
3969
3970 write!(self.out, "{level}")?;
3975 let fun_key = if let Some(result) = result {
3976 let res_name = Baked(result).to_string();
3977 self.start_baking_expression(result, context, &res_name)?;
3978 self.named_expressions.insert(result, res_name);
3979 fun.to_msl()
3980 } else if context.resolve_type(value).scalar_width() == Some(8) {
3981 fun.to_msl_64_bit()?
3982 } else {
3983 fun.to_msl()
3984 };
3985
3986 let policy = context.choose_bounds_check_policy(pointer);
3990 let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3991 && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
3992
3993 if checked {
3995 write!(self.out, " ? ")?;
3996 }
3997
3998 match *fun {
4000 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4001 write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
4002 self.put_access_chain(pointer, policy, context)?;
4003 write!(self.out, ", ")?;
4004 self.put_expression(cmp, context, true)?;
4005 write!(self.out, ", ")?;
4006 self.put_expression(value, context, true)?;
4007 write!(self.out, ")")?;
4008 }
4009 _ => {
4010 write!(
4011 self.out,
4012 "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
4013 )?;
4014 self.put_access_chain(pointer, policy, context)?;
4015 write!(self.out, ", ")?;
4016 self.put_expression(value, context, true)?;
4017 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
4018 }
4019 }
4020
4021 if checked {
4023 write!(self.out, " : DefaultConstructible()")?;
4024 }
4025
4026 writeln!(self.out, ";")?;
4028 }
4029 crate::Statement::ImageAtomic {
4030 image,
4031 coordinate,
4032 array_index,
4033 fun,
4034 value,
4035 } => {
4036 let address = TexelAddress {
4037 coordinate,
4038 array_index,
4039 sample: None,
4040 level: None,
4041 };
4042 self.put_image_atomic(level, image, &address, fun, value, context)?
4043 }
4044 crate::Statement::WorkGroupUniformLoad { pointer, result } => {
4045 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4046
4047 write!(self.out, "{level}")?;
4048 let name = self.namer.call("");
4049 self.start_baking_expression(result, &context.expression, &name)?;
4050 self.put_load(pointer, &context.expression, true)?;
4051 self.named_expressions.insert(result, name);
4052
4053 writeln!(self.out, ";")?;
4054 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4055 }
4056 crate::Statement::RayQuery { query, ref fun } => {
4057 if context.expression.lang_version < (2, 4) {
4058 return Err(Error::UnsupportedRayTracing);
4059 }
4060
4061 match *fun {
4062 crate::RayQueryFunction::Initialize {
4063 acceleration_structure,
4064 descriptor,
4065 } => {
4066 write!(self.out, "{level}")?;
4068 self.put_expression(query, &context.expression, true)?;
4069 writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
4070 {
4071 let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
4072 let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
4073 write!(self.out, "{level}")?;
4074 self.put_expression(query, &context.expression, true)?;
4075 write!(
4076 self.out,
4077 ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
4078 )?;
4079 self.put_expression(descriptor, &context.expression, true)?;
4080 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
4081 self.put_expression(descriptor, &context.expression, true)?;
4082 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
4083 writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
4084 }
4085 {
4086 let f_opaque = back::RayFlag::OPAQUE.bits();
4087 let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
4088 write!(self.out, "{level}")?;
4089 self.put_expression(query, &context.expression, true)?;
4090 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
4091 self.put_expression(descriptor, &context.expression, true)?;
4092 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
4093 self.put_expression(descriptor, &context.expression, true)?;
4094 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
4095 writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
4096 }
4097 {
4098 let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
4099 write!(self.out, "{level}")?;
4100 self.put_expression(query, &context.expression, true)?;
4101 write!(
4102 self.out,
4103 ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
4104 )?;
4105 self.put_expression(descriptor, &context.expression, true)?;
4106 writeln!(self.out, ".flags & {flag}) != 0);")?;
4107 }
4108
4109 write!(self.out, "{level}")?;
4110 self.put_expression(query, &context.expression, true)?;
4111 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
4112 self.put_expression(query, &context.expression, true)?;
4113 write!(
4114 self.out,
4115 ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
4116 )?;
4117 self.put_expression(descriptor, &context.expression, true)?;
4118 write!(self.out, ".origin, ")?;
4119 self.put_expression(descriptor, &context.expression, true)?;
4120 write!(self.out, ".dir, ")?;
4121 self.put_expression(descriptor, &context.expression, true)?;
4122 write!(self.out, ".tmin, ")?;
4123 self.put_expression(descriptor, &context.expression, true)?;
4124 write!(self.out, ".tmax), ")?;
4125 self.put_expression(acceleration_structure, &context.expression, true)?;
4126 write!(self.out, ", ")?;
4127 self.put_expression(descriptor, &context.expression, true)?;
4128 write!(self.out, ".cull_mask);")?;
4129
4130 write!(self.out, "{level}")?;
4131 self.put_expression(query, &context.expression, true)?;
4132 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
4133 }
4134 crate::RayQueryFunction::Proceed { result } => {
4135 write!(self.out, "{level}")?;
4136 let name = Baked(result).to_string();
4137 self.start_baking_expression(result, &context.expression, &name)?;
4138 self.named_expressions.insert(result, name);
4139 self.put_expression(query, &context.expression, true)?;
4140 writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
4141 if RAY_QUERY_MODERN_SUPPORT {
4142 write!(self.out, "{level}")?;
4143 self.put_expression(query, &context.expression, true)?;
4144 writeln!(self.out, ".?.next();")?;
4145 }
4146 }
4147 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
4148 if RAY_QUERY_MODERN_SUPPORT {
4149 write!(self.out, "{level}")?;
4150 self.put_expression(query, &context.expression, true)?;
4151 write!(self.out, ".?.commit_bounding_box_intersection(")?;
4152 self.put_expression(hit_t, &context.expression, true)?;
4153 writeln!(self.out, ");")?;
4154 } else {
4155 log::warn!("Ray Query GenerateIntersection is not yet supported");
4156 }
4157 }
4158 crate::RayQueryFunction::ConfirmIntersection => {
4159 if RAY_QUERY_MODERN_SUPPORT {
4160 write!(self.out, "{level}")?;
4161 self.put_expression(query, &context.expression, true)?;
4162 writeln!(self.out, ".?.commit_triangle_intersection();")?;
4163 } else {
4164 log::warn!("Ray Query ConfirmIntersection is not yet supported");
4165 }
4166 }
4167 crate::RayQueryFunction::Terminate => {
4168 if RAY_QUERY_MODERN_SUPPORT {
4169 write!(self.out, "{level}")?;
4170 self.put_expression(query, &context.expression, true)?;
4171 writeln!(self.out, ".?.abort();")?;
4172 }
4173 write!(self.out, "{level}")?;
4174 self.put_expression(query, &context.expression, true)?;
4175 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
4176 }
4177 }
4178 }
4179 crate::Statement::SubgroupBallot { result, predicate } => {
4180 write!(self.out, "{level}")?;
4181 let name = self.namer.call("");
4182 self.start_baking_expression(result, &context.expression, &name)?;
4183 self.named_expressions.insert(result, name);
4184 write!(
4185 self.out,
4186 "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4187 )?;
4188 if let Some(predicate) = predicate {
4189 self.put_expression(predicate, &context.expression, true)?;
4190 } else {
4191 write!(self.out, "true")?;
4192 }
4193 writeln!(self.out, "), 0, 0, 0);")?;
4194 }
4195 crate::Statement::SubgroupCollectiveOperation {
4196 op,
4197 collective_op,
4198 argument,
4199 result,
4200 } => {
4201 write!(self.out, "{level}")?;
4202 let name = self.namer.call("");
4203 self.start_baking_expression(result, &context.expression, &name)?;
4204 self.named_expressions.insert(result, name);
4205 match (collective_op, op) {
4206 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4207 write!(self.out, "{NAMESPACE}::simd_all(")?
4208 }
4209 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4210 write!(self.out, "{NAMESPACE}::simd_any(")?
4211 }
4212 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4213 write!(self.out, "{NAMESPACE}::simd_sum(")?
4214 }
4215 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4216 write!(self.out, "{NAMESPACE}::simd_product(")?
4217 }
4218 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4219 write!(self.out, "{NAMESPACE}::simd_max(")?
4220 }
4221 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4222 write!(self.out, "{NAMESPACE}::simd_min(")?
4223 }
4224 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4225 write!(self.out, "{NAMESPACE}::simd_and(")?
4226 }
4227 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4228 write!(self.out, "{NAMESPACE}::simd_or(")?
4229 }
4230 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4231 write!(self.out, "{NAMESPACE}::simd_xor(")?
4232 }
4233 (
4234 crate::CollectiveOperation::ExclusiveScan,
4235 crate::SubgroupOperation::Add,
4236 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4237 (
4238 crate::CollectiveOperation::ExclusiveScan,
4239 crate::SubgroupOperation::Mul,
4240 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4241 (
4242 crate::CollectiveOperation::InclusiveScan,
4243 crate::SubgroupOperation::Add,
4244 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4245 (
4246 crate::CollectiveOperation::InclusiveScan,
4247 crate::SubgroupOperation::Mul,
4248 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4249 _ => unimplemented!(),
4250 }
4251 self.put_expression(argument, &context.expression, true)?;
4252 writeln!(self.out, ");")?;
4253 }
4254 crate::Statement::SubgroupGather {
4255 mode,
4256 argument,
4257 result,
4258 } => {
4259 write!(self.out, "{level}")?;
4260 let name = self.namer.call("");
4261 self.start_baking_expression(result, &context.expression, &name)?;
4262 self.named_expressions.insert(result, name);
4263 match mode {
4264 crate::GatherMode::BroadcastFirst => {
4265 write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4266 }
4267 crate::GatherMode::Broadcast(_) => {
4268 write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4269 }
4270 crate::GatherMode::Shuffle(_) => {
4271 write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4272 }
4273 crate::GatherMode::ShuffleDown(_) => {
4274 write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4275 }
4276 crate::GatherMode::ShuffleUp(_) => {
4277 write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4278 }
4279 crate::GatherMode::ShuffleXor(_) => {
4280 write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4281 }
4282 crate::GatherMode::QuadBroadcast(_) => {
4283 write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4284 }
4285 crate::GatherMode::QuadSwap(_) => {
4286 write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4287 }
4288 }
4289 self.put_expression(argument, &context.expression, true)?;
4290 match mode {
4291 crate::GatherMode::BroadcastFirst => {}
4292 crate::GatherMode::Broadcast(index)
4293 | crate::GatherMode::Shuffle(index)
4294 | crate::GatherMode::ShuffleDown(index)
4295 | crate::GatherMode::ShuffleUp(index)
4296 | crate::GatherMode::ShuffleXor(index)
4297 | crate::GatherMode::QuadBroadcast(index) => {
4298 write!(self.out, ", ")?;
4299 self.put_expression(index, &context.expression, true)?;
4300 }
4301 crate::GatherMode::QuadSwap(direction) => {
4302 write!(self.out, ", ")?;
4303 match direction {
4304 crate::Direction::X => {
4305 write!(self.out, "1u")?;
4306 }
4307 crate::Direction::Y => {
4308 write!(self.out, "2u")?;
4309 }
4310 crate::Direction::Diagonal => {
4311 write!(self.out, "3u")?;
4312 }
4313 }
4314 }
4315 }
4316 writeln!(self.out, ");")?;
4317 }
4318 crate::Statement::CooperativeStore { target, ref data } => {
4319 write!(self.out, "{level}simdgroup_store(")?;
4320 self.put_expression(target, &context.expression, true)?;
4321 write!(self.out, ", &")?;
4322 self.put_access_chain(
4323 data.pointer,
4324 context.expression.policies.index,
4325 &context.expression,
4326 )?;
4327 write!(self.out, ", ")?;
4328 self.put_expression(data.stride, &context.expression, true)?;
4329 if data.row_major {
4330 let matrix_origin = "0";
4331 let transpose = true;
4332 write!(self.out, ", {matrix_origin}, {transpose}")?;
4333 }
4334 writeln!(self.out, ");")?;
4335 }
4336 crate::Statement::RayPipelineFunction(_) => unreachable!(),
4337 }
4338 }
4339
4340 for statement in statements {
4343 if let crate::Statement::Emit(ref range) = *statement {
4344 for handle in range.clone() {
4345 self.named_expressions.shift_remove(&handle);
4346 }
4347 }
4348 }
4349 Ok(())
4350 }
4351
4352 fn put_store(
4353 &mut self,
4354 pointer: Handle<crate::Expression>,
4355 value: Handle<crate::Expression>,
4356 level: back::Level,
4357 context: &StatementContext,
4358 ) -> BackendResult {
4359 let policy = context.expression.choose_bounds_check_policy(pointer);
4360 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4361 && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4362 {
4363 writeln!(self.out, ") {{")?;
4364 self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4365 writeln!(self.out, "{level}}}")?;
4366 } else {
4367 self.put_unchecked_store(pointer, value, policy, level, context)?;
4368 }
4369
4370 Ok(())
4371 }
4372
4373 fn put_unchecked_store(
4374 &mut self,
4375 pointer: Handle<crate::Expression>,
4376 value: Handle<crate::Expression>,
4377 policy: index::BoundsCheckPolicy,
4378 level: back::Level,
4379 context: &StatementContext,
4380 ) -> BackendResult {
4381 let is_atomic_pointer = context
4382 .expression
4383 .resolve_type(pointer)
4384 .is_atomic_pointer(&context.expression.module.types);
4385
4386 if is_atomic_pointer {
4387 write!(
4388 self.out,
4389 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4390 )?;
4391 self.put_access_chain(pointer, policy, &context.expression)?;
4392 write!(self.out, ", ")?;
4393 self.put_expression(value, &context.expression, true)?;
4394 writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4395 } else {
4396 write!(self.out, "{level}")?;
4397 self.put_access_chain(pointer, policy, &context.expression)?;
4398 write!(self.out, " = ")?;
4399 self.put_expression(value, &context.expression, true)?;
4400 writeln!(self.out, ";")?;
4401 }
4402
4403 Ok(())
4404 }
4405
4406 pub fn write(
4407 &mut self,
4408 module: &crate::Module,
4409 info: &valid::ModuleInfo,
4410 options: &Options,
4411 pipeline_options: &PipelineOptions,
4412 ) -> Result<TranslationInfo, Error> {
4413 self.names.clear();
4414 self.namer.reset(
4415 module,
4416 &super::keywords::RESERVED_SET,
4417 proc::KeywordSet::empty(),
4418 proc::CaseInsensitiveKeywordSet::empty(),
4419 &[CLAMPED_LOD_LOAD_PREFIX],
4420 &mut self.names,
4421 );
4422 self.wrapped_functions.clear();
4423 self.struct_member_pads.clear();
4424
4425 writeln!(
4426 self.out,
4427 "// language: metal{}.{}",
4428 options.lang_version.0, options.lang_version.1
4429 )?;
4430 writeln!(self.out, "#include <metal_stdlib>")?;
4431 writeln!(self.out, "#include <simd/simd.h>")?;
4432 writeln!(self.out)?;
4433 writeln!(self.out, "using {NAMESPACE}::uint;")?;
4435
4436 let mut uses_ray_query = false;
4437 for (_, ty) in module.types.iter() {
4438 match ty.inner {
4439 crate::TypeInner::AccelerationStructure { .. } => {
4440 if options.lang_version < (2, 4) {
4441 return Err(Error::UnsupportedRayTracing);
4442 }
4443 }
4444 crate::TypeInner::RayQuery { .. } => {
4445 if options.lang_version < (2, 4) {
4446 return Err(Error::UnsupportedRayTracing);
4447 }
4448 uses_ray_query = true;
4449 }
4450 _ => (),
4451 }
4452 }
4453
4454 if module.special_types.ray_desc.is_some()
4455 || module.special_types.ray_intersection.is_some()
4456 {
4457 if options.lang_version < (2, 4) {
4458 return Err(Error::UnsupportedRayTracing);
4459 }
4460 }
4461
4462 if uses_ray_query {
4463 self.put_ray_query_type()?;
4464 }
4465
4466 if options
4467 .bounds_check_policies
4468 .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4469 {
4470 self.put_default_constructible()?;
4471 }
4472 writeln!(self.out)?;
4473
4474 {
4475 let globals: Vec<Handle<crate::GlobalVariable>> = module
4478 .global_variables
4479 .iter()
4480 .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4481 .map(|(handle, _)| handle)
4482 .collect();
4483
4484 let mut buffer_indices = vec![];
4485 for vbm in &pipeline_options.vertex_buffer_mappings {
4486 buffer_indices.push(vbm.id);
4487 }
4488
4489 if !globals.is_empty() || !buffer_indices.is_empty() {
4490 writeln!(self.out, "struct _mslBufferSizes {{")?;
4491
4492 for global in globals {
4493 writeln!(
4494 self.out,
4495 "{}uint {};",
4496 back::INDENT,
4497 ArraySizeMember(global)
4498 )?;
4499 }
4500
4501 for idx in buffer_indices {
4502 writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4503 }
4504
4505 writeln!(self.out, "}};")?;
4506 writeln!(self.out)?;
4507 }
4508 };
4509
4510 self.write_type_defs(module)?;
4511 self.write_global_constants(module, info)?;
4512 self.write_functions(module, info, options, pipeline_options)
4513 }
4514
4515 fn put_default_constructible(&mut self) -> BackendResult {
4528 let tab = back::INDENT;
4529 writeln!(self.out, "struct DefaultConstructible {{")?;
4530 writeln!(self.out, "{tab}template<typename T>")?;
4531 writeln!(self.out, "{tab}operator T() && {{")?;
4532 writeln!(self.out, "{tab}{tab}return T {{}};")?;
4533 writeln!(self.out, "{tab}}}")?;
4534 writeln!(self.out, "}};")?;
4535 Ok(())
4536 }
4537
4538 fn put_ray_query_type(&mut self) -> BackendResult {
4539 let tab = back::INDENT;
4540 writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
4541 let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
4542 writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
4543 writeln!(
4544 self.out,
4545 "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
4546 )?;
4547 writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
4548 writeln!(self.out, "}};")?;
4549 writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
4550 let v_triangle = back::RayIntersectionType::Triangle as u32;
4551 let v_bbox = back::RayIntersectionType::BoundingBox as u32;
4552 writeln!(
4553 self.out,
4554 "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
4555 )?;
4556 writeln!(
4557 self.out,
4558 "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
4559 )?;
4560 writeln!(self.out, "}}")?;
4561 Ok(())
4562 }
4563
4564 fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4565 let mut generated_argument_buffer_wrapper = false;
4566 let mut generated_external_texture_wrapper = false;
4567 for (handle, ty) in module.types.iter() {
4568 match ty.inner {
4569 crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4570 writeln!(self.out, "template <typename T>")?;
4571 writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4572 writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4573 writeln!(self.out, "}};")?;
4574 generated_argument_buffer_wrapper = true;
4575 }
4576 crate::TypeInner::Image {
4577 class: crate::ImageClass::External,
4578 ..
4579 } if !generated_external_texture_wrapper => {
4580 let params_ty_name = &self.names
4581 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4582 writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4583 writeln!(
4584 self.out,
4585 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4586 back::INDENT
4587 )?;
4588 writeln!(
4589 self.out,
4590 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4591 back::INDENT
4592 )?;
4593 writeln!(
4594 self.out,
4595 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4596 back::INDENT
4597 )?;
4598 writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4599 writeln!(self.out, "}};")?;
4600 generated_external_texture_wrapper = true;
4601 }
4602 _ => {}
4603 }
4604
4605 if !ty.needs_alias() {
4606 continue;
4607 }
4608 let name = &self.names[&NameKey::Type(handle)];
4609 match ty.inner {
4610 crate::TypeInner::Array {
4624 base,
4625 size,
4626 stride: _,
4627 } => {
4628 let base_name = TypeContext {
4629 handle: base,
4630 gctx: module.to_ctx(),
4631 names: &self.names,
4632 access: crate::StorageAccess::empty(),
4633 first_time: false,
4634 };
4635
4636 match size.resolve(module.to_ctx())? {
4637 proc::IndexableLength::Known(size) => {
4638 writeln!(self.out, "struct {name} {{")?;
4639 writeln!(
4640 self.out,
4641 "{}{} {}[{}];",
4642 back::INDENT,
4643 base_name,
4644 WRAPPED_ARRAY_FIELD,
4645 size
4646 )?;
4647 writeln!(self.out, "}};")?;
4648 }
4649 proc::IndexableLength::Dynamic => {
4650 writeln!(self.out, "typedef {base_name} {name}[1];")?;
4651 }
4652 }
4653 }
4654 crate::TypeInner::Struct {
4655 ref members, span, ..
4656 } => {
4657 writeln!(self.out, "struct {name} {{")?;
4658 let mut last_offset = 0;
4659 for (index, member) in members.iter().enumerate() {
4660 if member.offset > last_offset {
4661 self.struct_member_pads.insert((handle, index as u32));
4662 let pad = member.offset - last_offset;
4663 writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4664 }
4665 let ty_inner = &module.types[member.ty].inner;
4666 last_offset = member.offset + ty_inner.size(module.to_ctx());
4667
4668 let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4669
4670 match should_pack_struct_member(members, span, index, module) {
4672 Some(scalar) => {
4673 writeln!(
4674 self.out,
4675 "{}{}::packed_{}3 {};",
4676 back::INDENT,
4677 NAMESPACE,
4678 scalar.to_msl_name(),
4679 member_name
4680 )?;
4681 }
4682 None => {
4683 let base_name = TypeContext {
4684 handle: member.ty,
4685 gctx: module.to_ctx(),
4686 names: &self.names,
4687 access: crate::StorageAccess::empty(),
4688 first_time: false,
4689 };
4690 writeln!(
4691 self.out,
4692 "{}{} {};",
4693 back::INDENT,
4694 base_name,
4695 member_name
4696 )?;
4697
4698 if let crate::TypeInner::Vector {
4700 size: crate::VectorSize::Tri,
4701 scalar,
4702 } = *ty_inner
4703 {
4704 last_offset += scalar.width as u32;
4705 }
4706 }
4707 }
4708 }
4709 if last_offset < span {
4710 let pad = span - last_offset;
4711 writeln!(
4712 self.out,
4713 "{}char _pad{}[{}];",
4714 back::INDENT,
4715 members.len(),
4716 pad
4717 )?;
4718 }
4719 writeln!(self.out, "}};")?;
4720 }
4721 _ => {
4722 let ty_name = TypeContext {
4723 handle,
4724 gctx: module.to_ctx(),
4725 names: &self.names,
4726 access: crate::StorageAccess::empty(),
4727 first_time: true,
4728 };
4729 writeln!(self.out, "typedef {ty_name} {name};")?;
4730 }
4731 }
4732 }
4733
4734 for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4736 match type_key {
4737 &crate::PredeclaredType::ModfResult { size, scalar }
4738 | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4739 let arg_type_name_owner;
4740 let arg_type_name = if let Some(size) = size {
4741 arg_type_name_owner = format!(
4742 "{NAMESPACE}::{}{}",
4743 if scalar.width == 8 { "double" } else { "float" },
4744 size as u8
4745 );
4746 &arg_type_name_owner
4747 } else if scalar.width == 8 {
4748 "double"
4749 } else {
4750 "float"
4751 };
4752
4753 let other_type_name_owner;
4754 let (defined_func_name, called_func_name, other_type_name) =
4755 if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4756 (MODF_FUNCTION, "modf", arg_type_name)
4757 } else {
4758 let other_type_name = if let Some(size) = size {
4759 other_type_name_owner = format!("int{}", size as u8);
4760 &other_type_name_owner
4761 } else {
4762 "int"
4763 };
4764 (FREXP_FUNCTION, "frexp", other_type_name)
4765 };
4766
4767 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4768
4769 writeln!(self.out)?;
4770 writeln!(
4771 self.out,
4772 "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4773 {other_type_name} other;
4774 {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4775 return {struct_name}{{ fract, other }};
4776}}"
4777 )?;
4778 }
4779 &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4780 let arg_type_name = scalar.to_msl_name();
4781 let called_func_name = "atomic_compare_exchange_weak_explicit";
4782 let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4783 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4784
4785 writeln!(self.out)?;
4786
4787 for address_space_name in ["device", "threadgroup"] {
4788 writeln!(
4789 self.out,
4790 "\
4791template <typename A>
4792{struct_name} {defined_func_name}(
4793 {address_space_name} A *atomic_ptr,
4794 {arg_type_name} cmp,
4795 {arg_type_name} v
4796) {{
4797 bool swapped = {NAMESPACE}::{called_func_name}(
4798 atomic_ptr, &cmp, v,
4799 metal::memory_order_relaxed, metal::memory_order_relaxed
4800 );
4801 return {struct_name}{{cmp, swapped}};
4802}}"
4803 )?;
4804 }
4805 }
4806 }
4807 }
4808
4809 Ok(())
4810 }
4811
4812 fn write_global_constants(
4814 &mut self,
4815 module: &crate::Module,
4816 mod_info: &valid::ModuleInfo,
4817 ) -> BackendResult {
4818 let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4819
4820 for (handle, constant) in constants {
4821 let ty_name = TypeContext {
4822 handle: constant.ty,
4823 gctx: module.to_ctx(),
4824 names: &self.names,
4825 access: crate::StorageAccess::empty(),
4826 first_time: false,
4827 };
4828 let name = &self.names[&NameKey::Constant(handle)];
4829 write!(self.out, "constant {ty_name} {name} = ")?;
4830 self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4831 writeln!(self.out, ";")?;
4832 }
4833
4834 Ok(())
4835 }
4836
4837 fn put_inline_sampler_properties(
4838 &mut self,
4839 level: back::Level,
4840 sampler: &sm::InlineSampler,
4841 ) -> BackendResult {
4842 for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4843 writeln!(
4844 self.out,
4845 "{}{}::{}_address::{},",
4846 level,
4847 NAMESPACE,
4848 letter,
4849 address.as_str(),
4850 )?;
4851 }
4852 writeln!(
4853 self.out,
4854 "{}{}::mag_filter::{},",
4855 level,
4856 NAMESPACE,
4857 sampler.mag_filter.as_str(),
4858 )?;
4859 writeln!(
4860 self.out,
4861 "{}{}::min_filter::{},",
4862 level,
4863 NAMESPACE,
4864 sampler.min_filter.as_str(),
4865 )?;
4866 if let Some(filter) = sampler.mip_filter {
4867 writeln!(
4868 self.out,
4869 "{}{}::mip_filter::{},",
4870 level,
4871 NAMESPACE,
4872 filter.as_str(),
4873 )?;
4874 }
4875 if sampler.border_color != sm::BorderColor::TransparentBlack {
4877 writeln!(
4878 self.out,
4879 "{}{}::border_color::{},",
4880 level,
4881 NAMESPACE,
4882 sampler.border_color.as_str(),
4883 )?;
4884 }
4885 if false {
4889 if let Some(ref lod) = sampler.lod_clamp {
4890 writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4891 }
4892 if let Some(aniso) = sampler.max_anisotropy {
4893 writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4894 }
4895 }
4896 if sampler.compare_func != sm::CompareFunc::Never {
4897 writeln!(
4898 self.out,
4899 "{}{}::compare_func::{},",
4900 level,
4901 NAMESPACE,
4902 sampler.compare_func.as_str(),
4903 )?;
4904 }
4905 writeln!(
4906 self.out,
4907 "{}{}::coord::{}",
4908 level,
4909 NAMESPACE,
4910 sampler.coord.as_str()
4911 )?;
4912 Ok(())
4913 }
4914
4915 fn write_unpacking_function(
4916 &mut self,
4917 format: back::msl::VertexFormat,
4918 ) -> Result<(String, u32, Option<crate::VectorSize>, crate::Scalar), Error> {
4919 use crate::{Scalar, VectorSize};
4920 use back::msl::VertexFormat::*;
4921 match format {
4922 Uint8 => {
4923 let name = self.namer.call("unpackUint8");
4924 writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4925 writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4926 writeln!(self.out, "}}")?;
4927 Ok((name, 1, None, Scalar::U32))
4928 }
4929 Uint8x2 => {
4930 let name = self.namer.call("unpackUint8x2");
4931 writeln!(
4932 self.out,
4933 "metal::uint2 {name}(metal::uchar b0, \
4934 metal::uchar b1) {{"
4935 )?;
4936 writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4937 writeln!(self.out, "}}")?;
4938 Ok((name, 2, Some(VectorSize::Bi), Scalar::U32))
4939 }
4940 Uint8x4 => {
4941 let name = self.namer.call("unpackUint8x4");
4942 writeln!(
4943 self.out,
4944 "metal::uint4 {name}(metal::uchar b0, \
4945 metal::uchar b1, \
4946 metal::uchar b2, \
4947 metal::uchar b3) {{"
4948 )?;
4949 writeln!(
4950 self.out,
4951 "{}return metal::uint4(b0, b1, b2, b3);",
4952 back::INDENT
4953 )?;
4954 writeln!(self.out, "}}")?;
4955 Ok((name, 4, Some(VectorSize::Quad), Scalar::U32))
4956 }
4957 Sint8 => {
4958 let name = self.namer.call("unpackSint8");
4959 writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4960 writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4961 writeln!(self.out, "}}")?;
4962 Ok((name, 1, None, Scalar::I32))
4963 }
4964 Sint8x2 => {
4965 let name = self.namer.call("unpackSint8x2");
4966 writeln!(
4967 self.out,
4968 "metal::int2 {name}(metal::uchar b0, \
4969 metal::uchar b1) {{"
4970 )?;
4971 writeln!(
4972 self.out,
4973 "{}return metal::int2(as_type<char>(b0), \
4974 as_type<char>(b1));",
4975 back::INDENT
4976 )?;
4977 writeln!(self.out, "}}")?;
4978 Ok((name, 2, Some(VectorSize::Bi), Scalar::I32))
4979 }
4980 Sint8x4 => {
4981 let name = self.namer.call("unpackSint8x4");
4982 writeln!(
4983 self.out,
4984 "metal::int4 {name}(metal::uchar b0, \
4985 metal::uchar b1, \
4986 metal::uchar b2, \
4987 metal::uchar b3) {{"
4988 )?;
4989 writeln!(
4990 self.out,
4991 "{}return metal::int4(as_type<char>(b0), \
4992 as_type<char>(b1), \
4993 as_type<char>(b2), \
4994 as_type<char>(b3));",
4995 back::INDENT
4996 )?;
4997 writeln!(self.out, "}}")?;
4998 Ok((name, 4, Some(VectorSize::Quad), Scalar::I32))
4999 }
5000 Unorm8 => {
5001 let name = self.namer.call("unpackUnorm8");
5002 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
5003 writeln!(
5004 self.out,
5005 "{}return float(float(b0) / 255.0f);",
5006 back::INDENT
5007 )?;
5008 writeln!(self.out, "}}")?;
5009 Ok((name, 1, None, Scalar::F32))
5010 }
5011 Unorm8x2 => {
5012 let name = self.namer.call("unpackUnorm8x2");
5013 writeln!(
5014 self.out,
5015 "metal::float2 {name}(metal::uchar b0, \
5016 metal::uchar b1) {{"
5017 )?;
5018 writeln!(
5019 self.out,
5020 "{}return metal::float2(float(b0) / 255.0f, \
5021 float(b1) / 255.0f);",
5022 back::INDENT
5023 )?;
5024 writeln!(self.out, "}}")?;
5025 Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
5026 }
5027 Unorm8x4 => {
5028 let name = self.namer.call("unpackUnorm8x4");
5029 writeln!(
5030 self.out,
5031 "metal::float4 {name}(metal::uchar b0, \
5032 metal::uchar b1, \
5033 metal::uchar b2, \
5034 metal::uchar b3) {{"
5035 )?;
5036 writeln!(
5037 self.out,
5038 "{}return metal::float4(float(b0) / 255.0f, \
5039 float(b1) / 255.0f, \
5040 float(b2) / 255.0f, \
5041 float(b3) / 255.0f);",
5042 back::INDENT
5043 )?;
5044 writeln!(self.out, "}}")?;
5045 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5046 }
5047 Snorm8 => {
5048 let name = self.namer.call("unpackSnorm8");
5049 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
5050 writeln!(
5051 self.out,
5052 "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
5053 back::INDENT
5054 )?;
5055 writeln!(self.out, "}}")?;
5056 Ok((name, 1, None, Scalar::F32))
5057 }
5058 Snorm8x2 => {
5059 let name = self.namer.call("unpackSnorm8x2");
5060 writeln!(
5061 self.out,
5062 "metal::float2 {name}(metal::uchar b0, \
5063 metal::uchar b1) {{"
5064 )?;
5065 writeln!(
5066 self.out,
5067 "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5068 metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
5069 back::INDENT
5070 )?;
5071 writeln!(self.out, "}}")?;
5072 Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
5073 }
5074 Snorm8x4 => {
5075 let name = self.namer.call("unpackSnorm8x4");
5076 writeln!(
5077 self.out,
5078 "metal::float4 {name}(metal::uchar b0, \
5079 metal::uchar b1, \
5080 metal::uchar b2, \
5081 metal::uchar b3) {{"
5082 )?;
5083 writeln!(
5084 self.out,
5085 "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5086 metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
5087 metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
5088 metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
5089 back::INDENT
5090 )?;
5091 writeln!(self.out, "}}")?;
5092 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5093 }
5094 Uint16 => {
5095 let name = self.namer.call("unpackUint16");
5096 writeln!(
5097 self.out,
5098 "metal::uint {name}(metal::uint b0, \
5099 metal::uint b1) {{"
5100 )?;
5101 writeln!(
5102 self.out,
5103 "{}return metal::uint(b1 << 8 | b0);",
5104 back::INDENT
5105 )?;
5106 writeln!(self.out, "}}")?;
5107 Ok((name, 2, None, Scalar::U32))
5108 }
5109 Uint16x2 => {
5110 let name = self.namer.call("unpackUint16x2");
5111 writeln!(
5112 self.out,
5113 "metal::uint2 {name}(metal::uint b0, \
5114 metal::uint b1, \
5115 metal::uint b2, \
5116 metal::uint b3) {{"
5117 )?;
5118 writeln!(
5119 self.out,
5120 "{}return metal::uint2(b1 << 8 | b0, \
5121 b3 << 8 | b2);",
5122 back::INDENT
5123 )?;
5124 writeln!(self.out, "}}")?;
5125 Ok((name, 4, Some(VectorSize::Bi), Scalar::U32))
5126 }
5127 Uint16x4 => {
5128 let name = self.namer.call("unpackUint16x4");
5129 writeln!(
5130 self.out,
5131 "metal::uint4 {name}(metal::uint b0, \
5132 metal::uint b1, \
5133 metal::uint b2, \
5134 metal::uint b3, \
5135 metal::uint b4, \
5136 metal::uint b5, \
5137 metal::uint b6, \
5138 metal::uint b7) {{"
5139 )?;
5140 writeln!(
5141 self.out,
5142 "{}return metal::uint4(b1 << 8 | b0, \
5143 b3 << 8 | b2, \
5144 b5 << 8 | b4, \
5145 b7 << 8 | b6);",
5146 back::INDENT
5147 )?;
5148 writeln!(self.out, "}}")?;
5149 Ok((name, 8, Some(VectorSize::Quad), Scalar::U32))
5150 }
5151 Sint16 => {
5152 let name = self.namer.call("unpackSint16");
5153 writeln!(
5154 self.out,
5155 "int {name}(metal::ushort b0, \
5156 metal::ushort b1) {{"
5157 )?;
5158 writeln!(
5159 self.out,
5160 "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5161 back::INDENT
5162 )?;
5163 writeln!(self.out, "}}")?;
5164 Ok((name, 2, None, Scalar::I32))
5165 }
5166 Sint16x2 => {
5167 let name = self.namer.call("unpackSint16x2");
5168 writeln!(
5169 self.out,
5170 "metal::int2 {name}(metal::ushort b0, \
5171 metal::ushort b1, \
5172 metal::ushort b2, \
5173 metal::ushort b3) {{"
5174 )?;
5175 writeln!(
5176 self.out,
5177 "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5178 as_type<short>(metal::ushort(b3 << 8 | b2)));",
5179 back::INDENT
5180 )?;
5181 writeln!(self.out, "}}")?;
5182 Ok((name, 4, Some(VectorSize::Bi), Scalar::I32))
5183 }
5184 Sint16x4 => {
5185 let name = self.namer.call("unpackSint16x4");
5186 writeln!(
5187 self.out,
5188 "metal::int4 {name}(metal::ushort b0, \
5189 metal::ushort b1, \
5190 metal::ushort b2, \
5191 metal::ushort b3, \
5192 metal::ushort b4, \
5193 metal::ushort b5, \
5194 metal::ushort b6, \
5195 metal::ushort b7) {{"
5196 )?;
5197 writeln!(
5198 self.out,
5199 "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5200 as_type<short>(metal::ushort(b3 << 8 | b2)), \
5201 as_type<short>(metal::ushort(b5 << 8 | b4)), \
5202 as_type<short>(metal::ushort(b7 << 8 | b6)));",
5203 back::INDENT
5204 )?;
5205 writeln!(self.out, "}}")?;
5206 Ok((name, 8, Some(VectorSize::Quad), Scalar::I32))
5207 }
5208 Unorm16 => {
5209 let name = self.namer.call("unpackUnorm16");
5210 writeln!(
5211 self.out,
5212 "float {name}(metal::ushort b0, \
5213 metal::ushort b1) {{"
5214 )?;
5215 writeln!(
5216 self.out,
5217 "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5218 back::INDENT
5219 )?;
5220 writeln!(self.out, "}}")?;
5221 Ok((name, 2, None, Scalar::F32))
5222 }
5223 Unorm16x2 => {
5224 let name = self.namer.call("unpackUnorm16x2");
5225 writeln!(
5226 self.out,
5227 "metal::float2 {name}(metal::ushort b0, \
5228 metal::ushort b1, \
5229 metal::ushort b2, \
5230 metal::ushort b3) {{"
5231 )?;
5232 writeln!(
5233 self.out,
5234 "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5235 float(b3 << 8 | b2) / 65535.0f);",
5236 back::INDENT
5237 )?;
5238 writeln!(self.out, "}}")?;
5239 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5240 }
5241 Unorm16x4 => {
5242 let name = self.namer.call("unpackUnorm16x4");
5243 writeln!(
5244 self.out,
5245 "metal::float4 {name}(metal::ushort b0, \
5246 metal::ushort b1, \
5247 metal::ushort b2, \
5248 metal::ushort b3, \
5249 metal::ushort b4, \
5250 metal::ushort b5, \
5251 metal::ushort b6, \
5252 metal::ushort b7) {{"
5253 )?;
5254 writeln!(
5255 self.out,
5256 "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5257 float(b3 << 8 | b2) / 65535.0f, \
5258 float(b5 << 8 | b4) / 65535.0f, \
5259 float(b7 << 8 | b6) / 65535.0f);",
5260 back::INDENT
5261 )?;
5262 writeln!(self.out, "}}")?;
5263 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5264 }
5265 Snorm16 => {
5266 let name = self.namer.call("unpackSnorm16");
5267 writeln!(
5268 self.out,
5269 "float {name}(metal::ushort b0, \
5270 metal::ushort b1) {{"
5271 )?;
5272 writeln!(
5273 self.out,
5274 "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5275 back::INDENT
5276 )?;
5277 writeln!(self.out, "}}")?;
5278 Ok((name, 2, None, Scalar::F32))
5279 }
5280 Snorm16x2 => {
5281 let name = self.namer.call("unpackSnorm16x2");
5282 writeln!(
5283 self.out,
5284 "metal::float2 {name}(uint b0, \
5285 uint b1, \
5286 uint b2, \
5287 uint b3) {{"
5288 )?;
5289 writeln!(
5290 self.out,
5291 "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5292 back::INDENT
5293 )?;
5294 writeln!(self.out, "}}")?;
5295 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5296 }
5297 Snorm16x4 => {
5298 let name = self.namer.call("unpackSnorm16x4");
5299 writeln!(
5300 self.out,
5301 "metal::float4 {name}(uint b0, \
5302 uint b1, \
5303 uint b2, \
5304 uint b3, \
5305 uint b4, \
5306 uint b5, \
5307 uint b6, \
5308 uint b7) {{"
5309 )?;
5310 writeln!(
5311 self.out,
5312 "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5313 metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5314 back::INDENT
5315 )?;
5316 writeln!(self.out, "}}")?;
5317 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5318 }
5319 Float16 => {
5320 let name = self.namer.call("unpackFloat16");
5321 writeln!(
5322 self.out,
5323 "float {name}(metal::ushort b0, \
5324 metal::ushort b1) {{"
5325 )?;
5326 writeln!(
5327 self.out,
5328 "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5329 back::INDENT
5330 )?;
5331 writeln!(self.out, "}}")?;
5332 Ok((name, 2, None, Scalar::F32))
5333 }
5334 Float16x2 => {
5335 let name = self.namer.call("unpackFloat16x2");
5336 writeln!(
5337 self.out,
5338 "metal::float2 {name}(metal::ushort b0, \
5339 metal::ushort b1, \
5340 metal::ushort b2, \
5341 metal::ushort b3) {{"
5342 )?;
5343 writeln!(
5344 self.out,
5345 "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5346 as_type<half>(metal::ushort(b3 << 8 | b2)));",
5347 back::INDENT
5348 )?;
5349 writeln!(self.out, "}}")?;
5350 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5351 }
5352 Float16x4 => {
5353 let name = self.namer.call("unpackFloat16x4");
5354 writeln!(
5355 self.out,
5356 "metal::float4 {name}(metal::ushort b0, \
5357 metal::ushort b1, \
5358 metal::ushort b2, \
5359 metal::ushort b3, \
5360 metal::ushort b4, \
5361 metal::ushort b5, \
5362 metal::ushort b6, \
5363 metal::ushort b7) {{"
5364 )?;
5365 writeln!(
5366 self.out,
5367 "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5368 as_type<half>(metal::ushort(b3 << 8 | b2)), \
5369 as_type<half>(metal::ushort(b5 << 8 | b4)), \
5370 as_type<half>(metal::ushort(b7 << 8 | b6)));",
5371 back::INDENT
5372 )?;
5373 writeln!(self.out, "}}")?;
5374 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5375 }
5376 Float32 => {
5377 let name = self.namer.call("unpackFloat32");
5378 writeln!(
5379 self.out,
5380 "float {name}(uint b0, \
5381 uint b1, \
5382 uint b2, \
5383 uint b3) {{"
5384 )?;
5385 writeln!(
5386 self.out,
5387 "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5388 back::INDENT
5389 )?;
5390 writeln!(self.out, "}}")?;
5391 Ok((name, 4, None, Scalar::F32))
5392 }
5393 Float32x2 => {
5394 let name = self.namer.call("unpackFloat32x2");
5395 writeln!(
5396 self.out,
5397 "metal::float2 {name}(uint b0, \
5398 uint b1, \
5399 uint b2, \
5400 uint b3, \
5401 uint b4, \
5402 uint b5, \
5403 uint b6, \
5404 uint b7) {{"
5405 )?;
5406 writeln!(
5407 self.out,
5408 "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5409 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5410 back::INDENT
5411 )?;
5412 writeln!(self.out, "}}")?;
5413 Ok((name, 8, Some(VectorSize::Bi), Scalar::F32))
5414 }
5415 Float32x3 => {
5416 let name = self.namer.call("unpackFloat32x3");
5417 writeln!(
5418 self.out,
5419 "metal::float3 {name}(uint b0, \
5420 uint b1, \
5421 uint b2, \
5422 uint b3, \
5423 uint b4, \
5424 uint b5, \
5425 uint b6, \
5426 uint b7, \
5427 uint b8, \
5428 uint b9, \
5429 uint b10, \
5430 uint b11) {{"
5431 )?;
5432 writeln!(
5433 self.out,
5434 "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5435 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5436 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5437 back::INDENT
5438 )?;
5439 writeln!(self.out, "}}")?;
5440 Ok((name, 12, Some(VectorSize::Tri), Scalar::F32))
5441 }
5442 Float32x4 => {
5443 let name = self.namer.call("unpackFloat32x4");
5444 writeln!(
5445 self.out,
5446 "metal::float4 {name}(uint b0, \
5447 uint b1, \
5448 uint b2, \
5449 uint b3, \
5450 uint b4, \
5451 uint b5, \
5452 uint b6, \
5453 uint b7, \
5454 uint b8, \
5455 uint b9, \
5456 uint b10, \
5457 uint b11, \
5458 uint b12, \
5459 uint b13, \
5460 uint b14, \
5461 uint b15) {{"
5462 )?;
5463 writeln!(
5464 self.out,
5465 "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5466 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5467 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5468 as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5469 back::INDENT
5470 )?;
5471 writeln!(self.out, "}}")?;
5472 Ok((name, 16, Some(VectorSize::Quad), Scalar::F32))
5473 }
5474 Uint32 => {
5475 let name = self.namer.call("unpackUint32");
5476 writeln!(
5477 self.out,
5478 "uint {name}(uint b0, \
5479 uint b1, \
5480 uint b2, \
5481 uint b3) {{"
5482 )?;
5483 writeln!(
5484 self.out,
5485 "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5486 back::INDENT
5487 )?;
5488 writeln!(self.out, "}}")?;
5489 Ok((name, 4, None, Scalar::U32))
5490 }
5491 Uint32x2 => {
5492 let name = self.namer.call("unpackUint32x2");
5493 writeln!(
5494 self.out,
5495 "uint2 {name}(uint b0, \
5496 uint b1, \
5497 uint b2, \
5498 uint b3, \
5499 uint b4, \
5500 uint b5, \
5501 uint b6, \
5502 uint b7) {{"
5503 )?;
5504 writeln!(
5505 self.out,
5506 "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5507 (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5508 back::INDENT
5509 )?;
5510 writeln!(self.out, "}}")?;
5511 Ok((name, 8, Some(VectorSize::Bi), Scalar::U32))
5512 }
5513 Uint32x3 => {
5514 let name = self.namer.call("unpackUint32x3");
5515 writeln!(
5516 self.out,
5517 "uint3 {name}(uint b0, \
5518 uint b1, \
5519 uint b2, \
5520 uint b3, \
5521 uint b4, \
5522 uint b5, \
5523 uint b6, \
5524 uint b7, \
5525 uint b8, \
5526 uint b9, \
5527 uint b10, \
5528 uint b11) {{"
5529 )?;
5530 writeln!(
5531 self.out,
5532 "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5533 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5534 (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5535 back::INDENT
5536 )?;
5537 writeln!(self.out, "}}")?;
5538 Ok((name, 12, Some(VectorSize::Tri), Scalar::U32))
5539 }
5540 Uint32x4 => {
5541 let name = self.namer.call("unpackUint32x4");
5542 writeln!(
5543 self.out,
5544 "{NAMESPACE}::uint4 {name}(uint b0, \
5545 uint b1, \
5546 uint b2, \
5547 uint b3, \
5548 uint b4, \
5549 uint b5, \
5550 uint b6, \
5551 uint b7, \
5552 uint b8, \
5553 uint b9, \
5554 uint b10, \
5555 uint b11, \
5556 uint b12, \
5557 uint b13, \
5558 uint b14, \
5559 uint b15) {{"
5560 )?;
5561 writeln!(
5562 self.out,
5563 "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5564 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5565 (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5566 (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5567 back::INDENT
5568 )?;
5569 writeln!(self.out, "}}")?;
5570 Ok((name, 16, Some(VectorSize::Quad), Scalar::U32))
5571 }
5572 Sint32 => {
5573 let name = self.namer.call("unpackSint32");
5574 writeln!(
5575 self.out,
5576 "int {name}(uint b0, \
5577 uint b1, \
5578 uint b2, \
5579 uint b3) {{"
5580 )?;
5581 writeln!(
5582 self.out,
5583 "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5584 back::INDENT
5585 )?;
5586 writeln!(self.out, "}}")?;
5587 Ok((name, 4, None, Scalar::I32))
5588 }
5589 Sint32x2 => {
5590 let name = self.namer.call("unpackSint32x2");
5591 writeln!(
5592 self.out,
5593 "metal::int2 {name}(uint b0, \
5594 uint b1, \
5595 uint b2, \
5596 uint b3, \
5597 uint b4, \
5598 uint b5, \
5599 uint b6, \
5600 uint b7) {{"
5601 )?;
5602 writeln!(
5603 self.out,
5604 "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5605 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5606 back::INDENT
5607 )?;
5608 writeln!(self.out, "}}")?;
5609 Ok((name, 8, Some(VectorSize::Bi), Scalar::I32))
5610 }
5611 Sint32x3 => {
5612 let name = self.namer.call("unpackSint32x3");
5613 writeln!(
5614 self.out,
5615 "metal::int3 {name}(uint b0, \
5616 uint b1, \
5617 uint b2, \
5618 uint b3, \
5619 uint b4, \
5620 uint b5, \
5621 uint b6, \
5622 uint b7, \
5623 uint b8, \
5624 uint b9, \
5625 uint b10, \
5626 uint b11) {{"
5627 )?;
5628 writeln!(
5629 self.out,
5630 "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5631 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5632 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5633 back::INDENT
5634 )?;
5635 writeln!(self.out, "}}")?;
5636 Ok((name, 12, Some(VectorSize::Tri), Scalar::I32))
5637 }
5638 Sint32x4 => {
5639 let name = self.namer.call("unpackSint32x4");
5640 writeln!(
5641 self.out,
5642 "metal::int4 {name}(uint b0, \
5643 uint b1, \
5644 uint b2, \
5645 uint b3, \
5646 uint b4, \
5647 uint b5, \
5648 uint b6, \
5649 uint b7, \
5650 uint b8, \
5651 uint b9, \
5652 uint b10, \
5653 uint b11, \
5654 uint b12, \
5655 uint b13, \
5656 uint b14, \
5657 uint b15) {{"
5658 )?;
5659 writeln!(
5660 self.out,
5661 "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5662 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5663 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5664 as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5665 back::INDENT
5666 )?;
5667 writeln!(self.out, "}}")?;
5668 Ok((name, 16, Some(VectorSize::Quad), Scalar::I32))
5669 }
5670 Unorm10_10_10_2 => {
5671 let name = self.namer.call("unpackUnorm10_10_10_2");
5672 writeln!(
5673 self.out,
5674 "metal::float4 {name}(uint b0, \
5675 uint b1, \
5676 uint b2, \
5677 uint b3) {{"
5678 )?;
5679 writeln!(
5680 self.out,
5681 "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5693 back::INDENT
5694 )?;
5695 writeln!(self.out, "}}")?;
5696 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5697 }
5698 Unorm8x4Bgra => {
5699 let name = self.namer.call("unpackUnorm8x4Bgra");
5700 writeln!(
5701 self.out,
5702 "metal::float4 {name}(metal::uchar b0, \
5703 metal::uchar b1, \
5704 metal::uchar b2, \
5705 metal::uchar b3) {{"
5706 )?;
5707 writeln!(
5708 self.out,
5709 "{}return metal::float4(float(b2) / 255.0f, \
5710 float(b1) / 255.0f, \
5711 float(b0) / 255.0f, \
5712 float(b3) / 255.0f);",
5713 back::INDENT
5714 )?;
5715 writeln!(self.out, "}}")?;
5716 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5717 }
5718 }
5719 }
5720
5721 fn write_wrapped_unary_op(
5722 &mut self,
5723 module: &crate::Module,
5724 func_ctx: &back::FunctionCtx,
5725 op: crate::UnaryOperator,
5726 operand: Handle<crate::Expression>,
5727 ) -> BackendResult {
5728 let operand_ty = func_ctx.resolve_type(operand, &module.types);
5729 match op {
5730 crate::UnaryOperator::Negate
5737 if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5738 {
5739 let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5740 return Ok(());
5741 };
5742 let wrapped = WrappedFunction::UnaryOp {
5743 op,
5744 ty: (vector_size, scalar),
5745 };
5746 if !self.wrapped_functions.insert(wrapped) {
5747 return Ok(());
5748 }
5749
5750 let unsigned_scalar = crate::Scalar {
5751 kind: crate::ScalarKind::Uint,
5752 ..scalar
5753 };
5754 let mut type_name = String::new();
5755 let mut unsigned_type_name = String::new();
5756 match vector_size {
5757 None => {
5758 put_numeric_type(&mut type_name, scalar, &[])?;
5759 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5760 }
5761 Some(size) => {
5762 put_numeric_type(&mut type_name, scalar, &[size])?;
5763 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5764 }
5765 };
5766
5767 writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5768 let level = back::Level(1);
5769 writeln!(
5770 self.out,
5771 "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5772 )?;
5773 writeln!(self.out, "}}")?;
5774 writeln!(self.out)?;
5775 }
5776 _ => {}
5777 }
5778 Ok(())
5779 }
5780
5781 fn write_wrapped_binary_op(
5782 &mut self,
5783 module: &crate::Module,
5784 func_ctx: &back::FunctionCtx,
5785 expr: Handle<crate::Expression>,
5786 op: crate::BinaryOperator,
5787 left: Handle<crate::Expression>,
5788 right: Handle<crate::Expression>,
5789 ) -> BackendResult {
5790 let expr_ty = func_ctx.resolve_type(expr, &module.types);
5791 let left_ty = func_ctx.resolve_type(left, &module.types);
5792 let right_ty = func_ctx.resolve_type(right, &module.types);
5793 match (op, expr_ty.scalar_kind()) {
5794 (
5801 crate::BinaryOperator::Divide,
5802 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5803 ) => {
5804 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5805 return Ok(());
5806 };
5807 let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5808 return Ok(());
5809 };
5810 let wrapped = WrappedFunction::BinaryOp {
5811 op,
5812 left_ty: left_wrapped_ty,
5813 right_ty: right_wrapped_ty,
5814 };
5815 if !self.wrapped_functions.insert(wrapped) {
5816 return Ok(());
5817 }
5818
5819 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5820 return Ok(());
5821 };
5822 let mut type_name = String::new();
5823 match vector_size {
5824 None => put_numeric_type(&mut type_name, scalar, &[])?,
5825 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5826 };
5827 writeln!(
5828 self.out,
5829 "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5830 )?;
5831 let level = back::Level(1);
5832 match scalar.kind {
5833 crate::ScalarKind::Sint => {
5834 let min_val = match scalar.width {
5835 4 => crate::Literal::I32(i32::MIN),
5836 8 => crate::Literal::I64(i64::MIN),
5837 _ => {
5838 return Err(Error::GenericValidation(format!(
5839 "Unexpected width for scalar {scalar:?}"
5840 )));
5841 }
5842 };
5843 write!(
5844 self.out,
5845 "{level}return lhs / metal::select(rhs, 1, (lhs == "
5846 )?;
5847 self.put_literal(min_val)?;
5848 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
5849 }
5850 crate::ScalarKind::Uint => writeln!(
5851 self.out,
5852 "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
5853 )?,
5854 _ => unreachable!(),
5855 }
5856 writeln!(self.out, "}}")?;
5857 writeln!(self.out)?;
5858 }
5859 (
5872 crate::BinaryOperator::Modulo,
5873 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5874 ) => {
5875 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5876 return Ok(());
5877 };
5878 let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5879 else {
5880 return Ok(());
5881 };
5882 let wrapped = WrappedFunction::BinaryOp {
5883 op,
5884 left_ty: left_wrapped_ty,
5885 right_ty: (right_vector_size, right_scalar),
5886 };
5887 if !self.wrapped_functions.insert(wrapped) {
5888 return Ok(());
5889 }
5890
5891 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5892 return Ok(());
5893 };
5894 let mut type_name = String::new();
5895 match vector_size {
5896 None => put_numeric_type(&mut type_name, scalar, &[])?,
5897 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5898 };
5899 let mut rhs_type_name = String::new();
5900 match right_vector_size {
5901 None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5902 Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5903 };
5904
5905 writeln!(
5906 self.out,
5907 "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5908 )?;
5909 let level = back::Level(1);
5910 match scalar.kind {
5911 crate::ScalarKind::Sint => {
5912 let min_val = match scalar.width {
5913 4 => crate::Literal::I32(i32::MIN),
5914 8 => crate::Literal::I64(i64::MIN),
5915 _ => {
5916 return Err(Error::GenericValidation(format!(
5917 "Unexpected width for scalar {scalar:?}"
5918 )));
5919 }
5920 };
5921 write!(
5922 self.out,
5923 "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
5924 )?;
5925 self.put_literal(min_val)?;
5926 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
5927 writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5928 }
5929 crate::ScalarKind::Uint => writeln!(
5930 self.out,
5931 "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
5932 )?,
5933 _ => unreachable!(),
5934 }
5935 writeln!(self.out, "}}")?;
5936 writeln!(self.out)?;
5937 }
5938 _ => {}
5939 }
5940 Ok(())
5941 }
5942
5943 fn get_dot_wrapper_function_helper_name(
5949 &self,
5950 scalar: crate::Scalar,
5951 size: crate::VectorSize,
5952 ) -> String {
5953 debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5955
5956 let type_name = scalar.to_msl_name();
5957 let size_suffix = common::vector_size_str(size);
5958 format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5959 }
5960
5961 #[allow(clippy::too_many_arguments)]
5962 fn write_wrapped_math_function(
5963 &mut self,
5964 module: &crate::Module,
5965 func_ctx: &back::FunctionCtx,
5966 fun: crate::MathFunction,
5967 arg: Handle<crate::Expression>,
5968 _arg1: Option<Handle<crate::Expression>>,
5969 _arg2: Option<Handle<crate::Expression>>,
5970 _arg3: Option<Handle<crate::Expression>>,
5971 ) -> BackendResult {
5972 let arg_ty = func_ctx.resolve_type(arg, &module.types);
5973 match fun {
5974 crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5982 let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5983 return Ok(());
5984 };
5985 let wrapped = WrappedFunction::Math {
5986 fun,
5987 arg_ty: (vector_size, scalar),
5988 };
5989 if !self.wrapped_functions.insert(wrapped) {
5990 return Ok(());
5991 }
5992
5993 let unsigned_scalar = crate::Scalar {
5994 kind: crate::ScalarKind::Uint,
5995 ..scalar
5996 };
5997 let mut type_name = String::new();
5998 let mut unsigned_type_name = String::new();
5999 match vector_size {
6000 None => {
6001 put_numeric_type(&mut type_name, scalar, &[])?;
6002 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
6003 }
6004 Some(size) => {
6005 put_numeric_type(&mut type_name, scalar, &[size])?;
6006 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
6007 }
6008 };
6009
6010 writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
6011 let level = back::Level(1);
6012 writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
6013 writeln!(self.out, "}}")?;
6014 writeln!(self.out)?;
6015 }
6016
6017 crate::MathFunction::Dot => match *arg_ty {
6018 crate::TypeInner::Vector { size, scalar }
6019 if matches!(
6020 scalar.kind,
6021 crate::ScalarKind::Sint | crate::ScalarKind::Uint
6022 ) =>
6023 {
6024 let wrapped = WrappedFunction::Math {
6026 fun,
6027 arg_ty: (Some(size), scalar),
6028 };
6029 if !self.wrapped_functions.insert(wrapped) {
6030 return Ok(());
6031 }
6032
6033 let mut vec_ty = String::new();
6034 put_numeric_type(&mut vec_ty, scalar, &[size])?;
6035 let mut ret_ty = String::new();
6036 put_numeric_type(&mut ret_ty, scalar, &[])?;
6037
6038 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
6039
6040 writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
6042 let level = back::Level(1);
6043 write!(self.out, "{level}return ")?;
6044 self.put_dot_product("a", "b", size as usize, |writer, name, index| {
6045 write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
6046 Ok(())
6047 })?;
6048 writeln!(self.out, ";")?;
6049 writeln!(self.out, "}}")?;
6050 writeln!(self.out)?;
6051 }
6052 _ => {}
6053 },
6054
6055 _ => {}
6056 }
6057 Ok(())
6058 }
6059
6060 fn write_wrapped_cast(
6061 &mut self,
6062 module: &crate::Module,
6063 func_ctx: &back::FunctionCtx,
6064 expr: Handle<crate::Expression>,
6065 kind: crate::ScalarKind,
6066 convert: Option<crate::Bytes>,
6067 ) -> BackendResult {
6068 let src_ty = func_ctx.resolve_type(expr, &module.types);
6079 let Some(width) = convert else {
6080 return Ok(());
6081 };
6082 let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6083 return Ok(());
6084 };
6085 let dst_scalar = crate::Scalar { kind, width };
6086 if src_scalar.kind != crate::ScalarKind::Float
6087 || (dst_scalar.kind != crate::ScalarKind::Sint
6088 && dst_scalar.kind != crate::ScalarKind::Uint)
6089 {
6090 return Ok(());
6091 }
6092 let wrapped = WrappedFunction::Cast {
6093 src_scalar,
6094 vector_size,
6095 dst_scalar,
6096 };
6097 if !self.wrapped_functions.insert(wrapped) {
6098 return Ok(());
6099 }
6100 let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6101
6102 let mut src_type_name = String::new();
6103 match vector_size {
6104 None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6105 Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6106 };
6107 let mut dst_type_name = String::new();
6108 match vector_size {
6109 None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6110 Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6111 };
6112 let fun_name = match dst_scalar {
6113 crate::Scalar::I32 => F2I32_FUNCTION,
6114 crate::Scalar::U32 => F2U32_FUNCTION,
6115 crate::Scalar::I64 => F2I64_FUNCTION,
6116 crate::Scalar::U64 => F2U64_FUNCTION,
6117 _ => unreachable!(),
6118 };
6119
6120 writeln!(
6121 self.out,
6122 "{dst_type_name} {fun_name}({src_type_name} value) {{"
6123 )?;
6124 let level = back::Level(1);
6125 write!(
6126 self.out,
6127 "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6128 )?;
6129 self.put_literal(min)?;
6130 write!(self.out, ", ")?;
6131 self.put_literal(max)?;
6132 writeln!(self.out, "));")?;
6133 writeln!(self.out, "}}")?;
6134 writeln!(self.out)?;
6135 Ok(())
6136 }
6137
6138 fn write_convert_yuv_to_rgb_and_return(
6146 &mut self,
6147 level: back::Level,
6148 y: &str,
6149 uv: &str,
6150 params: &str,
6151 ) -> BackendResult {
6152 let l1 = level;
6153 let l2 = l1.next();
6154
6155 writeln!(
6157 self.out,
6158 "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6159 )?;
6160
6161 writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6164 writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6165 writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6166 writeln!(
6167 self.out,
6168 "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6169 )?;
6170
6171 writeln!(
6174 self.out,
6175 "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6176 )?;
6177
6178 writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6181 writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6182 writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6183 writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6184
6185 writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6186 Ok(())
6187 }
6188
    /// Emits the `IMAGE_LOAD_EXTERNAL_FUNCTION` helper used for texel loads from
    /// external textures.
    ///
    /// Nothing is written in two cases:
    /// * `image` is not of class `ImageClass::External`.
    /// * The helper was already recorded in `self.wrapped_functions`.
    ///
    /// Only `image` is inspected. The remaining expression handles mirror the
    /// fields of `Expression::ImageLoad` and are unused here.
    ///
    /// The emitted MSL function works in these steps:
    /// * Clamp `coords` to the texture size. This is `tex.params.size`, or
    ///   plane 0's dimensions when that size is all zeros.
    /// * Map the clamped coordinates through `tex.params.load_transform`.
    /// * For a single plane, return plane 0's texel directly.
    /// * Otherwise read Y from plane 0 and UV from plane 1 (two planes) or from
    ///   planes 1 and 2 (three planes), then convert to RGB.
    #[allow(clippy::too_many_arguments)]
    fn write_wrapped_image_load(
        &mut self,
        module: &crate::Module,
        func_ctx: &back::FunctionCtx,
        image: Handle<crate::Expression>,
        _coordinate: Handle<crate::Expression>,
        _array_index: Option<Handle<crate::Expression>>,
        _sample: Option<Handle<crate::Expression>>,
        _level: Option<Handle<crate::Expression>>,
    ) -> BackendResult {
        let class = match *func_ctx.resolve_type(image, &module.types) {
            crate::TypeInner::Image { class, .. } => class,
            _ => unreachable!(),
        };
        // Only external textures get a load wrapper; every other class returns early.
        if class != crate::ImageClass::External {
            return Ok(());
        }
        let wrapped = WrappedFunction::ImageLoad { class };
        // Set-style `insert`: false means this helper was already emitted.
        if !self.wrapped_functions.insert(wrapped) {
            return Ok(());
        }

        writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
        let l1 = back::Level(1);
        let l2 = l1.next();
        let l3 = l2.next();
        writeln!(
            self.out,
            "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
        )?;
        // An all-zero `params.size` means "use plane 0's dimensions".
        writeln!(
            self.out,
            "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
        )?;
        // Keep the requested texel inside the (possibly cropped) texture.
        writeln!(
            self.out,
            "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
        )?;

        writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
        // Single-plane textures are returned as-is, with no YUV conversion.
        writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
        writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
        writeln!(self.out, "{l1}}} else {{")?;

        // Chroma planes may be smaller than plane 0. Scale coordinates by the
        // ratio of the plane sizes.
        writeln!(
            self.out,
            "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
        )?;
        writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;

        writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;

        writeln!(self.out, "{l2}float2 uv;")?;
        // Two planes: UV are interleaved in plane 1.
        writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
        writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
        // Three planes: U comes from plane 1 and V from plane 2.
        writeln!(self.out, "{l2}}} else {{")?;
        // NOTE(review): the next two lines are emitted at `l2` although they sit
        // inside this `else` branch, whose other statement uses `l3`. This only
        // affects the indentation of the generated MSL.
        writeln!(
            self.out,
            "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
        )?;
        writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
        writeln!(
            self.out,
            "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
        )?;
        writeln!(self.out, "{l2}}}")?;

        // Emits the YUV -> RGB conversion and the final `return` inside the else-branch.
        self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;

        writeln!(self.out, "{l1}}}")?;
        writeln!(self.out, "}}")?;
        writeln!(self.out)?;
        Ok(())
    }
6274
6275 #[allow(clippy::too_many_arguments)]
6276 fn write_wrapped_image_sample(
6277 &mut self,
6278 module: &crate::Module,
6279 func_ctx: &back::FunctionCtx,
6280 image: Handle<crate::Expression>,
6281 _sampler: Handle<crate::Expression>,
6282 _gather: Option<crate::SwizzleComponent>,
6283 _coordinate: Handle<crate::Expression>,
6284 _array_index: Option<Handle<crate::Expression>>,
6285 _offset: Option<Handle<crate::Expression>>,
6286 _level: crate::SampleLevel,
6287 _depth_ref: Option<Handle<crate::Expression>>,
6288 clamp_to_edge: bool,
6289 ) -> BackendResult {
6290 if !clamp_to_edge {
6293 return Ok(());
6294 }
6295 let class = match *func_ctx.resolve_type(image, &module.types) {
6296 crate::TypeInner::Image { class, .. } => class,
6297 _ => unreachable!(),
6298 };
6299 let wrapped = WrappedFunction::ImageSample {
6300 class,
6301 clamp_to_edge: true,
6302 };
6303 if !self.wrapped_functions.insert(wrapped) {
6304 return Ok(());
6305 }
6306 match class {
6307 crate::ImageClass::External => {
6308 writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6309 let l1 = back::Level(1);
6310 let l2 = l1.next();
6311 let l3 = l2.next();
6312 writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6313 writeln!(
6314 self.out,
6315 "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6316 )?;
6317
6318 writeln!(
6326 self.out,
6327 "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6328 )?;
6329 writeln!(
6330 self.out,
6331 "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6332 )?;
6333 writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6334 writeln!(
6335 self.out,
6336 "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6337 )?;
6338 writeln!(
6339 self.out,
6340 "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6341 )?;
6342 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6343 writeln!(
6345 self.out,
6346 "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6347 )?;
6348 writeln!(self.out, "{l1}}} else {{")?;
6349 writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6350 writeln!(
6351 self.out,
6352 "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6353 )?;
6354 writeln!(
6355 self.out,
6356 "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6357 )?;
6358
6359 writeln!(
6361 self.out,
6362 "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6363 )?;
6364 writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6365 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6366 writeln!(
6368 self.out,
6369 "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6370 )?;
6371 writeln!(self.out, "{l2}}} else {{")?;
6372 writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6374 writeln!(
6375 self.out,
6376 "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6377 )?;
6378 writeln!(
6379 self.out,
6380 "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6381 )?;
6382 writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6383 writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6384 writeln!(self.out, "{l2}}}")?;
6385
6386 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6387
6388 writeln!(self.out, "{l1}}}")?;
6389 writeln!(self.out, "}}")?;
6390 writeln!(self.out)?;
6391 }
6392 _ => {
6393 writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6394 let l1 = back::Level(1);
6395 writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6396 writeln!(
6397 self.out,
6398 "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6399 )?;
6400 writeln!(self.out, "}}")?;
6401 writeln!(self.out)?;
6402 }
6403 }
6404 Ok(())
6405 }
6406
6407 fn write_wrapped_image_query(
6408 &mut self,
6409 module: &crate::Module,
6410 func_ctx: &back::FunctionCtx,
6411 image: Handle<crate::Expression>,
6412 query: crate::ImageQuery,
6413 ) -> BackendResult {
6414 if !matches!(query, crate::ImageQuery::Size { .. }) {
6416 return Ok(());
6417 }
6418 let class = match *func_ctx.resolve_type(image, &module.types) {
6419 crate::TypeInner::Image { class, .. } => class,
6420 _ => unreachable!(),
6421 };
6422 if class != crate::ImageClass::External {
6423 return Ok(());
6424 }
6425 let wrapped = WrappedFunction::ImageQuerySize { class };
6426 if !self.wrapped_functions.insert(wrapped) {
6427 return Ok(());
6428 }
6429 writeln!(
6430 self.out,
6431 "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6432 )?;
6433 let l1 = back::Level(1);
6434 let l2 = l1.next();
6435 writeln!(
6436 self.out,
6437 "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6438 )?;
6439 writeln!(self.out, "{l2}return tex.params.size;")?;
6440 writeln!(self.out, "{l1}}} else {{")?;
6441 writeln!(
6443 self.out,
6444 "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6445 )?;
6446 writeln!(self.out, "{l1}}}")?;
6447 writeln!(self.out, "}}")?;
6448 writeln!(self.out)?;
6449 Ok(())
6450 }
6451
6452 fn write_wrapped_cooperative_load(
6453 &mut self,
6454 module: &crate::Module,
6455 func_ctx: &back::FunctionCtx,
6456 columns: crate::CooperativeSize,
6457 rows: crate::CooperativeSize,
6458 pointer: Handle<crate::Expression>,
6459 ) -> BackendResult {
6460 let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6461 let space = ptr_ty.pointer_space().unwrap();
6462 let space_name = space.to_msl_name().unwrap_or_default();
6463 let scalar = ptr_ty
6464 .pointer_base_type()
6465 .unwrap()
6466 .inner_with(&module.types)
6467 .scalar()
6468 .unwrap();
6469 let wrapped = WrappedFunction::CooperativeLoad {
6470 space_name,
6471 columns,
6472 rows,
6473 scalar,
6474 };
6475 if !self.wrapped_functions.insert(wrapped) {
6476 return Ok(());
6477 }
6478 let scalar_name = scalar.to_msl_name();
6479 writeln!(
6480 self.out,
6481 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6482 columns as u32, rows as u32,
6483 )?;
6484 let l1 = back::Level(1);
6485 writeln!(
6486 self.out,
6487 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6488 columns as u32, rows as u32
6489 )?;
6490 let matrix_origin = "0";
6491 writeln!(
6492 self.out,
6493 "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6494 )?;
6495 writeln!(self.out, "{l1}return m;")?;
6496 writeln!(self.out, "}}")?;
6497 writeln!(self.out)?;
6498 Ok(())
6499 }
6500
6501 fn write_wrapped_cooperative_multiply_add(
6502 &mut self,
6503 module: &crate::Module,
6504 func_ctx: &back::FunctionCtx,
6505 space: crate::AddressSpace,
6506 a: Handle<crate::Expression>,
6507 b: Handle<crate::Expression>,
6508 ) -> BackendResult {
6509 let space_name = space.to_msl_name().unwrap_or_default();
6510 let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6511 crate::TypeInner::CooperativeMatrix {
6512 columns,
6513 rows,
6514 scalar,
6515 ..
6516 } => (columns, rows, scalar),
6517 _ => unreachable!(),
6518 };
6519 let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6520 crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6521 _ => unreachable!(),
6522 };
6523 let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6524 space_name,
6525 columns: b_c,
6526 rows: a_r,
6527 intermediate: a_c,
6528 scalar,
6529 };
6530 if !self.wrapped_functions.insert(wrapped) {
6531 return Ok(());
6532 }
6533 let scalar_name = scalar.to_msl_name();
6534 writeln!(
6535 self.out,
6536 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6537 b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6538 )?;
6539 let l1 = back::Level(1);
6540 writeln!(
6541 self.out,
6542 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6543 b_c as u32, a_r as u32
6544 )?;
6545 writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6546 writeln!(self.out, "{l1}return d;")?;
6547 writeln!(self.out, "}}")?;
6548 writeln!(self.out)?;
6549 Ok(())
6550 }
6551
6552 pub(super) fn write_wrapped_functions(
6553 &mut self,
6554 module: &crate::Module,
6555 func_ctx: &back::FunctionCtx,
6556 ) -> BackendResult {
6557 for (expr_handle, expr) in func_ctx.expressions.iter() {
6558 match *expr {
6559 crate::Expression::Unary { op, expr: operand } => {
6560 self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6561 }
6562 crate::Expression::Binary { op, left, right } => {
6563 self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6564 }
6565 crate::Expression::Math {
6566 fun,
6567 arg,
6568 arg1,
6569 arg2,
6570 arg3,
6571 } => {
6572 self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6573 }
6574 crate::Expression::As {
6575 expr,
6576 kind,
6577 convert,
6578 } => {
6579 self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6580 }
6581 crate::Expression::ImageLoad {
6582 image,
6583 coordinate,
6584 array_index,
6585 sample,
6586 level,
6587 } => {
6588 self.write_wrapped_image_load(
6589 module,
6590 func_ctx,
6591 image,
6592 coordinate,
6593 array_index,
6594 sample,
6595 level,
6596 )?;
6597 }
6598 crate::Expression::ImageSample {
6599 image,
6600 sampler,
6601 gather,
6602 coordinate,
6603 array_index,
6604 offset,
6605 level,
6606 depth_ref,
6607 clamp_to_edge,
6608 } => {
6609 self.write_wrapped_image_sample(
6610 module,
6611 func_ctx,
6612 image,
6613 sampler,
6614 gather,
6615 coordinate,
6616 array_index,
6617 offset,
6618 level,
6619 depth_ref,
6620 clamp_to_edge,
6621 )?;
6622 }
6623 crate::Expression::ImageQuery { image, query } => {
6624 self.write_wrapped_image_query(module, func_ctx, image, query)?;
6625 }
6626 crate::Expression::CooperativeLoad {
6627 columns,
6628 rows,
6629 role: _,
6630 ref data,
6631 } => {
6632 self.write_wrapped_cooperative_load(
6633 module,
6634 func_ctx,
6635 columns,
6636 rows,
6637 data.pointer,
6638 )?;
6639 }
6640 crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6641 let space = crate::AddressSpace::Private;
6642 self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
6643 }
6644 _ => {}
6645 }
6646 }
6647
6648 Ok(())
6649 }
6650
6651 fn write_functions(
6653 &mut self,
6654 module: &crate::Module,
6655 mod_info: &valid::ModuleInfo,
6656 options: &Options,
6657 pipeline_options: &PipelineOptions,
6658 ) -> Result<TranslationInfo, Error> {
6659 use back::msl::VertexFormat;
6660
6661 struct AttributeMappingResolved {
6664 ty_name: String,
6665 dimension: Option<crate::VectorSize>,
6666 scalar: crate::Scalar,
6667 name: String,
6668 }
6669 let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6670
6671 struct VertexBufferMappingResolved<'a> {
6672 id: u32,
6673 stride: u32,
6674 step_mode: back::msl::VertexBufferStepMode,
6675 ty_name: String,
6676 param_name: String,
6677 elem_name: String,
6678 attributes: &'a Vec<back::msl::AttributeMapping>,
6679 }
6680 let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6681
6682 struct UnpackingFunction {
6684 name: String,
6685 byte_count: u32,
6686 dimension: Option<crate::VectorSize>,
6687 scalar: crate::Scalar,
6688 }
6689 let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6690
6691 let mut needs_vertex_id = false;
6697 let v_id = self.namer.call("v_id");
6698
6699 let mut needs_instance_id = false;
6700 let i_id = self.namer.call("i_id");
6701 if pipeline_options.vertex_pulling_transform {
6702 for vbm in &pipeline_options.vertex_buffer_mappings {
6703 let buffer_id = vbm.id;
6704 let buffer_stride = vbm.stride;
6705
6706 assert!(
6707 buffer_stride > 0,
6708 "Vertex pulling requires a non-zero buffer stride."
6709 );
6710
6711 match vbm.step_mode {
6712 back::msl::VertexBufferStepMode::Constant => {}
6713 back::msl::VertexBufferStepMode::ByVertex => {
6714 needs_vertex_id = true;
6715 }
6716 back::msl::VertexBufferStepMode::ByInstance => {
6717 needs_instance_id = true;
6718 }
6719 }
6720
6721 let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6722 let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6723 let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6724
6725 vbm_resolved.push(VertexBufferMappingResolved {
6726 id: buffer_id,
6727 stride: buffer_stride,
6728 step_mode: vbm.step_mode,
6729 ty_name: buffer_ty,
6730 param_name: buffer_param,
6731 elem_name: buffer_elem,
6732 attributes: &vbm.attributes,
6733 });
6734
6735 for attribute in &vbm.attributes {
6737 if unpacking_functions.contains_key(&attribute.format) {
6738 continue;
6739 }
6740 let (name, byte_count, dimension, scalar) =
6741 match self.write_unpacking_function(attribute.format) {
6742 Ok((name, byte_count, dimension, scalar)) => {
6743 (name, byte_count, dimension, scalar)
6744 }
6745 _ => {
6746 continue;
6747 }
6748 };
6749 unpacking_functions.insert(
6750 attribute.format,
6751 UnpackingFunction {
6752 name,
6753 byte_count,
6754 dimension,
6755 scalar,
6756 },
6757 );
6758 }
6759 }
6760 }
6761
6762 let mut pass_through_globals = Vec::new();
6763 for (fun_handle, fun) in module.functions.iter() {
6764 log::trace!(
6765 "function {:?}, handle {:?}",
6766 fun.name.as_deref().unwrap_or("(anonymous)"),
6767 fun_handle
6768 );
6769
6770 let ctx = back::FunctionCtx {
6771 ty: back::FunctionType::Function(fun_handle),
6772 info: &mod_info[fun_handle],
6773 expressions: &fun.expressions,
6774 named_expressions: &fun.named_expressions,
6775 };
6776
6777 writeln!(self.out)?;
6778 self.write_wrapped_functions(module, &ctx)?;
6779
6780 let fun_info = &mod_info[fun_handle];
6781 pass_through_globals.clear();
6782 let mut needs_buffer_sizes = false;
6783 for (handle, var) in module.global_variables.iter() {
6784 if !fun_info[handle].is_empty() {
6785 if var.space.needs_pass_through() {
6786 pass_through_globals.push(handle);
6787 }
6788 needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6789 }
6790 }
6791
6792 let fun_name = &self.names[&NameKey::Function(fun_handle)];
6793 match fun.result {
6794 Some(ref result) => {
6795 let ty_name = TypeContext {
6796 handle: result.ty,
6797 gctx: module.to_ctx(),
6798 names: &self.names,
6799 access: crate::StorageAccess::empty(),
6800 first_time: false,
6801 };
6802 write!(self.out, "{ty_name}")?;
6803 }
6804 None => {
6805 write!(self.out, "void")?;
6806 }
6807 }
6808 writeln!(self.out, " {fun_name}(")?;
6809
6810 for (index, arg) in fun.arguments.iter().enumerate() {
6811 let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6812 let param_type_name = TypeContext {
6813 handle: arg.ty,
6814 gctx: module.to_ctx(),
6815 names: &self.names,
6816 access: crate::StorageAccess::empty(),
6817 first_time: false,
6818 };
6819 let separator = separate(
6820 !pass_through_globals.is_empty()
6821 || index + 1 != fun.arguments.len()
6822 || needs_buffer_sizes,
6823 );
6824 writeln!(
6825 self.out,
6826 "{}{} {}{}",
6827 back::INDENT,
6828 param_type_name,
6829 name,
6830 separator
6831 )?;
6832 }
6833 for (index, &handle) in pass_through_globals.iter().enumerate() {
6834 let tyvar = TypedGlobalVariable {
6835 module,
6836 names: &self.names,
6837 handle,
6838 usage: fun_info[handle],
6839 reference: true,
6840 };
6841 let separator =
6842 separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6843 write!(self.out, "{}", back::INDENT)?;
6844 tyvar.try_fmt(&mut self.out)?;
6845 writeln!(self.out, "{separator}")?;
6846 }
6847
6848 if needs_buffer_sizes {
6849 writeln!(
6850 self.out,
6851 "{}constant _mslBufferSizes& _buffer_sizes",
6852 back::INDENT
6853 )?;
6854 }
6855
6856 writeln!(self.out, ") {{")?;
6857
6858 let guarded_indices =
6859 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6860
6861 let context = StatementContext {
6862 expression: ExpressionContext {
6863 function: fun,
6864 origin: FunctionOrigin::Handle(fun_handle),
6865 info: fun_info,
6866 lang_version: options.lang_version,
6867 policies: options.bounds_check_policies,
6868 guarded_indices,
6869 module,
6870 mod_info,
6871 pipeline_options,
6872 force_loop_bounding: options.force_loop_bounding,
6873 },
6874 result_struct: None,
6875 };
6876
6877 self.put_locals(&context.expression)?;
6878 self.update_expressions_to_bake(fun, fun_info, &context.expression);
6879 self.put_block(back::Level(1), &fun.body, &context)?;
6880 writeln!(self.out, "}}")?;
6881 self.named_expressions.clear();
6882 }
6883
6884 let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6885 .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6886
6887 let mut info = TranslationInfo {
6888 entry_point_names: Vec::with_capacity(ep_range.len()),
6889 };
6890
6891 for ep_index in ep_range {
6892 let ep = &module.entry_points[ep_index];
6893 let fun = &ep.function;
6894 let fun_info = mod_info.get_entry_point(ep_index);
6895 let mut ep_error = None;
6896
6897 let mut v_existing_id = None;
6901 let mut i_existing_id = None;
6902
6903 log::trace!(
6904 "entry point {:?}, index {:?}",
6905 fun.name.as_deref().unwrap_or("(anonymous)"),
6906 ep_index
6907 );
6908
6909 let ctx = back::FunctionCtx {
6910 ty: back::FunctionType::EntryPoint(ep_index as u16),
6911 info: fun_info,
6912 expressions: &fun.expressions,
6913 named_expressions: &fun.named_expressions,
6914 };
6915
6916 self.write_wrapped_functions(module, &ctx)?;
6917
6918 let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6919 crate::ShaderStage::Vertex => (
6920 "vertex",
6921 LocationMode::VertexInput,
6922 LocationMode::VertexOutput,
6923 true,
6924 ),
6925 crate::ShaderStage::Fragment => (
6926 "fragment",
6927 LocationMode::FragmentInput,
6928 LocationMode::FragmentOutput,
6929 false,
6930 ),
6931 crate::ShaderStage::Compute => (
6932 "kernel",
6933 LocationMode::Uniform,
6934 LocationMode::Uniform,
6935 false,
6936 ),
6937 crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(),
6938 crate::ShaderStage::RayGeneration
6939 | crate::ShaderStage::AnyHit
6940 | crate::ShaderStage::ClosestHit
6941 | crate::ShaderStage::Miss => unimplemented!(),
6942 };
6943
6944 let do_vertex_pulling = can_vertex_pull
6946 && pipeline_options.vertex_pulling_transform
6947 && !pipeline_options.vertex_buffer_mappings.is_empty();
6948
6949 let needs_buffer_sizes = do_vertex_pulling
6951 || module
6952 .global_variables
6953 .iter()
6954 .filter(|&(handle, _)| !fun_info[handle].is_empty())
6955 .any(|(_, var)| needs_array_length(var.ty, &module.types));
6956
6957 if !options.fake_missing_bindings {
6960 for (var_handle, var) in module.global_variables.iter() {
6961 if fun_info[var_handle].is_empty() {
6962 continue;
6963 }
6964 match var.space {
6965 crate::AddressSpace::Uniform
6966 | crate::AddressSpace::Storage { .. }
6967 | crate::AddressSpace::Handle => {
6968 let br = match var.binding {
6969 Some(ref br) => br,
6970 None => {
6971 let var_name = var.name.clone().unwrap_or_default();
6972 ep_error =
6973 Some(super::EntryPointError::MissingBinding(var_name));
6974 break;
6975 }
6976 };
6977 let target = options.get_resource_binding_target(ep, br);
6978 let good = match target {
6979 Some(target) => {
6980 match module.types[var.ty].inner {
6984 crate::TypeInner::Image {
6985 class: crate::ImageClass::External,
6986 ..
6987 } => target.external_texture.is_some(),
6988 crate::TypeInner::Image { .. } => target.texture.is_some(),
6989 crate::TypeInner::Sampler { .. } => {
6990 target.sampler.is_some()
6991 }
6992 _ => target.buffer.is_some(),
6993 }
6994 }
6995 None => false,
6996 };
6997 if !good {
6998 ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6999 break;
7000 }
7001 }
7002 crate::AddressSpace::Immediate => {
7003 if let Err(e) = options.resolve_immediates(ep) {
7004 ep_error = Some(e);
7005 break;
7006 }
7007 }
7008 crate::AddressSpace::TaskPayload => {
7009 unimplemented!()
7010 }
7011 crate::AddressSpace::Function
7012 | crate::AddressSpace::Private
7013 | crate::AddressSpace::WorkGroup => {}
7014 crate::AddressSpace::RayPayload
7015 | crate::AddressSpace::IncomingRayPayload => unimplemented!(),
7016 }
7017 }
7018 if needs_buffer_sizes {
7019 if let Err(err) = options.resolve_sizes_buffer(ep) {
7020 ep_error = Some(err);
7021 }
7022 }
7023 }
7024
7025 if let Some(err) = ep_error {
7026 info.entry_point_names.push(Err(err));
7027 continue;
7028 }
7029 let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
7030 info.entry_point_names.push(Ok(fun_name.clone()));
7031
7032 writeln!(self.out)?;
7033
7034 let mut flattened_member_names = FastHashMap::default();
7040 let mut varyings_namer = proc::Namer::default();
7042
7043 let mut flattened_arguments = Vec::new();
7048 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7049 match module.types[arg.ty].inner {
7050 crate::TypeInner::Struct { ref members, .. } => {
7051 for (member_index, member) in members.iter().enumerate() {
7052 let member_index = member_index as u32;
7053 flattened_arguments.push((
7054 NameKey::StructMember(arg.ty, member_index),
7055 member.ty,
7056 member.binding.as_ref(),
7057 ));
7058 let name_key = NameKey::StructMember(arg.ty, member_index);
7059 let name = match member.binding {
7060 Some(crate::Binding::Location { .. }) => {
7061 if do_vertex_pulling {
7062 self.namer.call(&self.names[&name_key])
7063 } else {
7064 varyings_namer.call(&self.names[&name_key])
7065 }
7066 }
7067 _ => self.namer.call(&self.names[&name_key]),
7068 };
7069 flattened_member_names.insert(name_key, name);
7070 }
7071 }
7072 _ => flattened_arguments.push((
7073 NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7074 arg.ty,
7075 arg.binding.as_ref(),
7076 )),
7077 }
7078 }
7079
7080 let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7085 let varyings_member_name = self.namer.call("varyings");
7086 let mut has_varyings = false;
7087 if !flattened_arguments.is_empty() {
7088 if !do_vertex_pulling {
7089 writeln!(self.out, "struct {stage_in_name} {{")?;
7090 }
7091 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7092 let Some(binding) = binding else {
7093 continue;
7094 };
7095 let name = match *name_key {
7096 NameKey::StructMember(..) => &flattened_member_names[name_key],
7097 _ => &self.names[name_key],
7098 };
7099 let ty_name = TypeContext {
7100 handle: ty,
7101 gctx: module.to_ctx(),
7102 names: &self.names,
7103 access: crate::StorageAccess::empty(),
7104 first_time: false,
7105 };
7106 let resolved = options.resolve_local_binding(binding, in_mode)?;
7107 let location = match *binding {
7108 crate::Binding::Location { location, .. } => Some(location),
7109 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7110 crate::Binding::BuiltIn(_) => continue,
7111 };
7112 if do_vertex_pulling {
7113 let Some(location) = location else {
7114 continue;
7115 };
7116 am_resolved.insert(
7118 location,
7119 AttributeMappingResolved {
7120 ty_name: ty_name.to_string(),
7121 dimension: ty_name.vector_size(),
7122 scalar: ty_name.scalar().unwrap(),
7123 name: name.to_string(),
7124 },
7125 );
7126 } else {
7127 has_varyings = true;
7128 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7129 resolved.try_fmt(&mut self.out)?;
7130 writeln!(self.out, ";")?;
7131 }
7132 }
7133 if !do_vertex_pulling {
7134 writeln!(self.out, "}};")?;
7135 }
7136 }
7137
7138 let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7141 let result_member_name = self.namer.call("member");
7142 let result_type_name = match fun.result {
7143 Some(ref result) => {
7144 let mut result_members = Vec::new();
7145 if let crate::TypeInner::Struct { ref members, .. } =
7146 module.types[result.ty].inner
7147 {
7148 for (member_index, member) in members.iter().enumerate() {
7149 result_members.push((
7150 &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7151 member.ty,
7152 member.binding.as_ref(),
7153 ));
7154 }
7155 } else {
7156 result_members.push((
7157 &result_member_name,
7158 result.ty,
7159 result.binding.as_ref(),
7160 ));
7161 }
7162
7163 writeln!(self.out, "struct {stage_out_name} {{")?;
7164 let mut has_point_size = false;
7165 for (name, ty, binding) in result_members {
7166 let ty_name = TypeContext {
7167 handle: ty,
7168 gctx: module.to_ctx(),
7169 names: &self.names,
7170 access: crate::StorageAccess::empty(),
7171 first_time: true,
7172 };
7173 let binding = binding.ok_or_else(|| {
7174 Error::GenericValidation("Expected binding, got None".into())
7175 })?;
7176
7177 if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7178 has_point_size = true;
7179 if !pipeline_options.allow_and_force_point_size {
7180 continue;
7181 }
7182 }
7183
7184 let array_len = match module.types[ty].inner {
7185 crate::TypeInner::Array {
7186 size: crate::ArraySize::Constant(size),
7187 ..
7188 } => Some(size),
7189 _ => None,
7190 };
7191 let resolved = options.resolve_local_binding(binding, out_mode)?;
7192 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7193 resolved.try_fmt(&mut self.out)?;
7194 if let Some(array_len) = array_len {
7195 write!(self.out, " [{array_len}]")?;
7196 }
7197 writeln!(self.out, ";")?;
7198 }
7199
7200 if pipeline_options.allow_and_force_point_size
7201 && ep.stage == crate::ShaderStage::Vertex
7202 && !has_point_size
7203 {
7204 writeln!(
7206 self.out,
7207 "{}float _point_size [[point_size]];",
7208 back::INDENT
7209 )?;
7210 }
7211 writeln!(self.out, "}};")?;
7212 &stage_out_name
7213 }
7214 None => "void",
7215 };
7216
7217 if do_vertex_pulling {
7220 for vbm in &vbm_resolved {
7221 let buffer_stride = vbm.stride;
7222 let buffer_ty = &vbm.ty_name;
7223
7224 writeln!(
7228 self.out,
7229 "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7230 )?;
7231 }
7232 }
7233
7234 writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
7236
7237 let mut is_first_argument = true;
7238 let mut separator = || {
7239 if is_first_argument {
7240 is_first_argument = false;
7241 ' '
7242 } else {
7243 ','
7244 }
7245 };
7246
7247 if has_varyings {
7250 writeln!(
7251 self.out,
7252 "{} {stage_in_name} {varyings_member_name} [[stage_in]]",
7253 separator()
7254 )?;
7255 }
7256
7257 let mut local_invocation_id = None;
7258
7259 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7262 let binding = match binding {
7263 Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7264 Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7265 _ => continue,
7266 };
7267 let name = match *name_key {
7268 NameKey::StructMember(..) => &flattened_member_names[name_key],
7269 _ => &self.names[name_key],
7270 };
7271
7272 if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
7273 local_invocation_id = Some(name_key);
7274 }
7275
7276 let ty_name = TypeContext {
7277 handle: ty,
7278 gctx: module.to_ctx(),
7279 names: &self.names,
7280 access: crate::StorageAccess::empty(),
7281 first_time: false,
7282 };
7283
7284 match *binding {
7285 crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7286 v_existing_id = Some(name.clone());
7287 }
7288 crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7289 i_existing_id = Some(name.clone());
7290 }
7291 _ => {}
7292 };
7293
7294 let resolved = options.resolve_local_binding(binding, in_mode)?;
7295 write!(self.out, "{} {ty_name} {name}", separator())?;
7296 resolved.try_fmt(&mut self.out)?;
7297 writeln!(self.out)?;
7298 }
7299
7300 let need_workgroup_variables_initialization =
7301 self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7302
7303 if need_workgroup_variables_initialization && local_invocation_id.is_none() {
7304 writeln!(
7305 self.out,
7306 "{} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]", separator()
7307 )?;
7308 }
7309
7310 for (handle, var) in module.global_variables.iter() {
7315 let usage = fun_info[handle];
7316 if usage.is_empty() || var.space == crate::AddressSpace::Private {
7317 continue;
7318 }
7319
7320 if options.lang_version < (1, 2) {
7321 match var.space {
7322 crate::AddressSpace::Storage { access }
7332 if access.contains(crate::StorageAccess::STORE)
7333 && ep.stage == crate::ShaderStage::Fragment =>
7334 {
7335 return Err(Error::UnsupportedWriteableStorageBuffer)
7336 }
7337 crate::AddressSpace::Handle => {
7338 match module.types[var.ty].inner {
7339 crate::TypeInner::Image {
7340 class: crate::ImageClass::Storage { access, .. },
7341 ..
7342 } => {
7343 if access.contains(crate::StorageAccess::STORE)
7353 && (ep.stage == crate::ShaderStage::Vertex
7354 || ep.stage == crate::ShaderStage::Fragment)
7355 {
7356 return Err(Error::UnsupportedWriteableStorageTexture(
7357 ep.stage,
7358 ));
7359 }
7360
7361 if access.contains(
7362 crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7363 ) {
7364 return Err(Error::UnsupportedRWStorageTexture);
7365 }
7366 }
7367 _ => {}
7368 }
7369 }
7370 _ => {}
7371 }
7372 }
7373
7374 match var.space {
7376 crate::AddressSpace::Handle => match module.types[var.ty].inner {
7377 crate::TypeInner::BindingArray { base, .. } => {
7378 match module.types[base].inner {
7379 crate::TypeInner::Sampler { .. } => {
7380 if options.lang_version < (2, 0) {
7381 return Err(Error::UnsupportedArrayOf(
7382 "samplers".to_string(),
7383 ));
7384 }
7385 }
7386 crate::TypeInner::Image { class, .. } => match class {
7387 crate::ImageClass::Sampled { .. }
7388 | crate::ImageClass::Depth { .. }
7389 | crate::ImageClass::Storage {
7390 access: crate::StorageAccess::LOAD,
7391 ..
7392 } => {
7393 if options.lang_version < (2, 0) {
7398 return Err(Error::UnsupportedArrayOf(
7399 "textures".to_string(),
7400 ));
7401 }
7402 }
7403 crate::ImageClass::Storage {
7404 access: crate::StorageAccess::STORE,
7405 ..
7406 } => {
7407 if options.lang_version < (2, 0) {
7412 return Err(Error::UnsupportedArrayOf(
7413 "write-only textures".to_string(),
7414 ));
7415 }
7416 }
7417 crate::ImageClass::Storage { .. } => {
7418 if options.lang_version < (3, 0) {
7419 return Err(Error::UnsupportedArrayOf(
7420 "read-write textures".to_string(),
7421 ));
7422 }
7423 }
7424 crate::ImageClass::External => {
7425 return Err(Error::UnsupportedArrayOf(
7426 "external textures".to_string(),
7427 ));
7428 }
7429 },
7430 _ => {
7431 return Err(Error::UnsupportedArrayOfType(base));
7432 }
7433 }
7434 }
7435 _ => {}
7436 },
7437 _ => {}
7438 }
7439
7440 let resolved = match var.space {
7442 crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7443 crate::AddressSpace::WorkGroup => None,
7444 _ => options
7445 .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7446 .ok(),
7447 };
7448 if let Some(ref resolved) = resolved {
7449 if resolved.as_inline_sampler(options).is_some() {
7451 continue;
7452 }
7453 }
7454
7455 match module.types[var.ty].inner {
7456 crate::TypeInner::Image {
7457 class: crate::ImageClass::External,
7458 ..
7459 } => {
7460 let target = match resolved {
7464 Some(back::msl::ResolvedBinding::Resource(target)) => {
7465 target.external_texture
7466 }
7467 _ => None,
7468 };
7469
7470 for i in 0..3 {
7471 write!(self.out, "{} ", separator())?;
7472
7473 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7474 handle,
7475 ExternalTextureNameKey::Plane(i),
7476 )];
7477 write!(
7478 self.out,
7479 "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> {plane_name}"
7480 )?;
7481 if let Some(ref target) = target {
7482 write!(self.out, " [[texture({})]]", target.planes[i])?;
7483 }
7484 writeln!(self.out)?;
7485 }
7486 let params_ty_name = &self.names
7487 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7488 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7489 handle,
7490 ExternalTextureNameKey::Params,
7491 )];
7492 write!(self.out, "{} ", separator())?;
7493 write!(self.out, "constant {params_ty_name}& {params_name}")?;
7494 if let Some(ref target) = target {
7495 write!(self.out, " [[buffer({})]]", target.params)?;
7496 }
7497 }
7498 _ => {
7499 let tyvar = TypedGlobalVariable {
7500 module,
7501 names: &self.names,
7502 handle,
7503 usage,
7504 reference: true,
7505 };
7506 write!(self.out, "{} ", separator())?;
7507 tyvar.try_fmt(&mut self.out)?;
7508 if let Some(resolved) = resolved {
7509 resolved.try_fmt(&mut self.out)?;
7510 }
7511 if let Some(value) = var.init {
7512 write!(self.out, " = ")?;
7513 self.put_const_expression(
7514 value,
7515 module,
7516 mod_info,
7517 &module.global_expressions,
7518 )?;
7519 }
7520 }
7521 }
7522 writeln!(self.out)?;
7523 }
7524
7525 if do_vertex_pulling {
7526 if needs_vertex_id && v_existing_id.is_none() {
7527 writeln!(self.out, "{} uint {v_id} [[vertex_id]]", separator())?;
7529 }
7530
7531 if needs_instance_id && i_existing_id.is_none() {
7532 writeln!(self.out, "{} uint {i_id} [[instance_id]]", separator())?;
7533 }
7534
7535 for vbm in &vbm_resolved {
7538 let id = &vbm.id;
7539 let ty_name = &vbm.ty_name;
7540 let param_name = &vbm.param_name;
7541 writeln!(
7542 self.out,
7543 "{} const device {ty_name}* {param_name} [[buffer({id})]]",
7544 separator()
7545 )?;
7546 }
7547 }
7548
7549 if needs_buffer_sizes {
7552 let resolved = options.resolve_sizes_buffer(ep).unwrap();
7554 write!(
7555 self.out,
7556 "{} constant _mslBufferSizes& _buffer_sizes",
7557 separator()
7558 )?;
7559 resolved.try_fmt(&mut self.out)?;
7560 writeln!(self.out)?;
7561 }
7562
7563 writeln!(self.out, ") {{")?;
7565
7566 if do_vertex_pulling {
7568 for vbm in &vbm_resolved {
7571 for attribute in vbm.attributes {
7572 let location = attribute.shader_location;
7573 let am_option = am_resolved.get(&location);
7574 if am_option.is_none() {
7575 continue;
7578 }
7579 let am = am_option.unwrap();
7580 let attribute_ty_name = &am.ty_name;
7581 let attribute_name = &am.name;
7582
7583 writeln!(
7584 self.out,
7585 "{}{attribute_ty_name} {attribute_name} = {{}};",
7586 back::Level(1)
7587 )?;
7588 }
7589
7590 write!(self.out, "{}if (", back::Level(1))?;
7593
7594 let idx = &vbm.id;
7595 let stride = &vbm.stride;
7596 let index_name = match vbm.step_mode {
7597 back::msl::VertexBufferStepMode::Constant => "0",
7598 back::msl::VertexBufferStepMode::ByVertex => {
7599 if let Some(ref name) = v_existing_id {
7600 name
7601 } else {
7602 &v_id
7603 }
7604 }
7605 back::msl::VertexBufferStepMode::ByInstance => {
7606 if let Some(ref name) = i_existing_id {
7607 name
7608 } else {
7609 &i_id
7610 }
7611 }
7612 };
7613 write!(
7614 self.out,
7615 "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7616 )?;
7617
7618 writeln!(self.out, ") {{")?;
7619
7620 let ty_name = &vbm.ty_name;
7622 let elem_name = &vbm.elem_name;
7623 let param_name = &vbm.param_name;
7624
7625 writeln!(
7626 self.out,
7627 "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7628 back::Level(2),
7629 )?;
7630
7631 for attribute in vbm.attributes {
7634 let location = attribute.shader_location;
7635 let Some(am) = am_resolved.get(&location) else {
7636 continue;
7640 };
7641 let attribute_name = &am.name;
7642 let attribute_ty_name = &am.ty_name;
7643
7644 let offset = attribute.offset;
7645 let func = unpacking_functions
7646 .get(&attribute.format)
7647 .expect("Should have generated this unpacking function earlier.");
7648 let func_name = &func.name;
7649
7650 let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7656
7657 let needs_conversion = am.scalar != func.scalar;
7660
7661 if needs_padding_or_truncation != Ordering::Equal {
7662 writeln!(
7665 self.out,
7666 "{}// {attribute_ty_name} <- {:?}",
7667 back::Level(2),
7668 attribute.format
7669 )?;
7670 }
7671
7672 write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7673
7674 if needs_padding_or_truncation == Ordering::Greater {
7675 write!(self.out, "{attribute_ty_name}(")?;
7677 }
7678
7679 if needs_conversion {
7681 put_numeric_type(&mut self.out, am.scalar, func.dimension.as_slice())?;
7682 write!(self.out, "(")?;
7683 }
7684 write!(self.out, "{func_name}({elem_name}.data[{offset}]")?;
7685 for i in (offset + 1)..(offset + func.byte_count) {
7686 write!(self.out, ", {elem_name}.data[{i}]")?;
7687 }
7688 write!(self.out, ")")?;
7689 if needs_conversion {
7690 write!(self.out, ")")?;
7691 }
7692
7693 match needs_padding_or_truncation {
7694 Ordering::Greater => {
7695 let ty_is_int = scalar_is_int(am.scalar);
7697 let zero_value = if ty_is_int { "0" } else { "0.0" };
7698 let one_value = if ty_is_int { "1" } else { "1.0" };
7699 for i in func.dimension.map_or(1, u8::from)
7700 ..am.dimension.map_or(1, u8::from)
7701 {
7702 write!(
7703 self.out,
7704 ", {}",
7705 if i == 3 { one_value } else { zero_value }
7706 )?;
7707 }
7708 }
7709 Ordering::Less => {
7710 write!(
7712 self.out,
7713 ".{}",
7714 &"xyzw"[0..usize::from(am.dimension.map_or(1, u8::from))]
7715 )?;
7716 }
7717 Ordering::Equal => {}
7718 }
7719
7720 if needs_padding_or_truncation == Ordering::Greater {
7721 write!(self.out, ")")?;
7722 }
7723
7724 writeln!(self.out, ";")?;
7725 }
7726
7727 writeln!(self.out, "{}}}", back::Level(1))?;
7729 }
7730 }
7731
7732 if need_workgroup_variables_initialization {
7733 self.write_workgroup_variables_initialization(
7734 module,
7735 mod_info,
7736 fun_info,
7737 local_invocation_id,
7738 )?;
7739 }
7740
7741 for (handle, var) in module.global_variables.iter() {
7744 let usage = fun_info[handle];
7745 if usage.is_empty() {
7746 continue;
7747 }
7748 if var.space == crate::AddressSpace::Private {
7749 let tyvar = TypedGlobalVariable {
7750 module,
7751 names: &self.names,
7752 handle,
7753 usage,
7754
7755 reference: false,
7756 };
7757 write!(self.out, "{}", back::INDENT)?;
7758 tyvar.try_fmt(&mut self.out)?;
7759 match var.init {
7760 Some(value) => {
7761 write!(self.out, " = ")?;
7762 self.put_const_expression(
7763 value,
7764 module,
7765 mod_info,
7766 &module.global_expressions,
7767 )?;
7768 writeln!(self.out, ";")?;
7769 }
7770 None => {
7771 writeln!(self.out, " = {{}};")?;
7772 }
7773 };
7774 } else if let Some(ref binding) = var.binding {
7775 let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7776 if let Some(sampler) = resolved.as_inline_sampler(options) {
7777 let name = &self.names[&NameKey::GlobalVariable(handle)];
7779 writeln!(
7780 self.out,
7781 "{}constexpr {}::sampler {}(",
7782 back::INDENT,
7783 NAMESPACE,
7784 name
7785 )?;
7786 self.put_inline_sampler_properties(back::Level(2), sampler)?;
7787 writeln!(self.out, "{});", back::INDENT)?;
7788 } else if let crate::TypeInner::Image {
7789 class: crate::ImageClass::External,
7790 ..
7791 } = module.types[var.ty].inner
7792 {
7793 let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7796 let l1 = back::Level(1);
7797 let l2 = l1.next();
7798 writeln!(
7799 self.out,
7800 "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7801 )?;
7802 for i in 0..3 {
7803 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7804 handle,
7805 ExternalTextureNameKey::Plane(i),
7806 )];
7807 writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7808 }
7809 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7810 handle,
7811 ExternalTextureNameKey::Params,
7812 )];
7813 writeln!(self.out, "{l2}.params = {params_name},")?;
7814 writeln!(self.out, "{l1}}};")?;
7815 }
7816 }
7817 }
7818
7819 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7830 let arg_name =
7831 &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7832 match module.types[arg.ty].inner {
7833 crate::TypeInner::Struct { ref members, .. } => {
7834 let struct_name = &self.names[&NameKey::Type(arg.ty)];
7835 write!(
7836 self.out,
7837 "{}const {} {} = {{ ",
7838 back::INDENT,
7839 struct_name,
7840 arg_name
7841 )?;
7842 for (member_index, member) in members.iter().enumerate() {
7843 let key = NameKey::StructMember(arg.ty, member_index as u32);
7844 let name = &flattened_member_names[&key];
7845 if member_index != 0 {
7846 write!(self.out, ", ")?;
7847 }
7848 if self
7850 .struct_member_pads
7851 .contains(&(arg.ty, member_index as u32))
7852 {
7853 write!(self.out, "{{}}, ")?;
7854 }
7855 if let Some(crate::Binding::Location { .. }) = member.binding {
7856 if has_varyings {
7857 write!(self.out, "{varyings_member_name}.")?;
7858 }
7859 }
7860 write!(self.out, "{name}")?;
7861 }
7862 writeln!(self.out, " }};")?;
7863 }
7864 _ => match arg.binding {
7865 Some(crate::Binding::Location { .. })
7866 | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
7867 if has_varyings {
7868 writeln!(
7869 self.out,
7870 "{}const auto {} = {}.{};",
7871 back::INDENT,
7872 arg_name,
7873 varyings_member_name,
7874 arg_name
7875 )?;
7876 }
7877 }
7878 _ => {}
7879 },
7880 }
7881 }
7882
7883 let guarded_indices =
7884 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
7885
7886 let context = StatementContext {
7887 expression: ExpressionContext {
7888 function: fun,
7889 origin: FunctionOrigin::EntryPoint(ep_index as _),
7890 info: fun_info,
7891 lang_version: options.lang_version,
7892 policies: options.bounds_check_policies,
7893 guarded_indices,
7894 module,
7895 mod_info,
7896 pipeline_options,
7897 force_loop_bounding: options.force_loop_bounding,
7898 },
7899 result_struct: Some(&stage_out_name),
7900 };
7901
7902 self.put_locals(&context.expression)?;
7905 self.update_expressions_to_bake(fun, fun_info, &context.expression);
7906 self.put_block(back::Level(1), &fun.body, &context)?;
7907 writeln!(self.out, "}}")?;
7908 if ep_index + 1 != module.entry_points.len() {
7909 writeln!(self.out)?;
7910 }
7911 self.named_expressions.clear();
7912 }
7913
7914 Ok(info)
7915 }
7916
7917 fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
7918 if flags.is_empty() {
7921 writeln!(
7922 self.out,
7923 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
7924 )?;
7925 }
7926 if flags.contains(crate::Barrier::STORAGE) {
7927 writeln!(
7928 self.out,
7929 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
7930 )?;
7931 }
7932 if flags.contains(crate::Barrier::WORK_GROUP) {
7933 writeln!(
7934 self.out,
7935 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7936 )?;
7937 }
7938 if flags.contains(crate::Barrier::SUB_GROUP) {
7939 writeln!(
7940 self.out,
7941 "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7942 )?;
7943 }
7944 if flags.contains(crate::Barrier::TEXTURE) {
7945 writeln!(
7946 self.out,
7947 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
7948 )?;
7949 }
7950 Ok(())
7951 }
7952}
7953
7954mod workgroup_mem_init {
7957 use crate::EntryPoint;
7958
7959 use super::*;
7960
7961 enum Access {
7962 GlobalVariable(Handle<crate::GlobalVariable>),
7963 StructMember(Handle<crate::Type>, u32),
7964 Array(usize),
7965 }
7966
7967 impl Access {
7968 fn write<W: Write>(
7969 &self,
7970 writer: &mut W,
7971 names: &FastHashMap<NameKey, String>,
7972 ) -> Result<(), core::fmt::Error> {
7973 match *self {
7974 Access::GlobalVariable(handle) => {
7975 write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
7976 }
7977 Access::StructMember(handle, index) => {
7978 write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
7979 }
7980 Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
7981 }
7982 }
7983 }
7984
7985 struct AccessStack {
7986 stack: Vec<Access>,
7987 array_depth: usize,
7988 }
7989
7990 impl AccessStack {
7991 const fn new() -> Self {
7992 Self {
7993 stack: Vec::new(),
7994 array_depth: 0,
7995 }
7996 }
7997
7998 fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
7999 let array_depth = self.array_depth;
8000 self.stack.push(Access::Array(array_depth));
8001 self.array_depth += 1;
8002 let res = cb(self, array_depth);
8003 self.stack.pop();
8004 self.array_depth -= 1;
8005 res
8006 }
8007
8008 fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
8009 self.stack.push(new);
8010 let res = cb(self);
8011 self.stack.pop();
8012 res
8013 }
8014
8015 fn write<W: Write>(
8016 &self,
8017 writer: &mut W,
8018 names: &FastHashMap<NameKey, String>,
8019 ) -> Result<(), core::fmt::Error> {
8020 for next in self.stack.iter() {
8021 next.write(writer, names)?;
8022 }
8023 Ok(())
8024 }
8025 }
8026
8027 impl<W: Write> Writer<W> {
8028 pub(super) fn need_workgroup_variables_initialization(
8029 &mut self,
8030 options: &Options,
8031 ep: &EntryPoint,
8032 module: &crate::Module,
8033 fun_info: &valid::FunctionInfo,
8034 ) -> bool {
8035 options.zero_initialize_workgroup_memory
8036 && ep.stage.compute_like()
8037 && module.global_variables.iter().any(|(handle, var)| {
8038 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
8039 })
8040 }
8041
8042 pub(super) fn write_workgroup_variables_initialization(
8043 &mut self,
8044 module: &crate::Module,
8045 module_info: &valid::ModuleInfo,
8046 fun_info: &valid::FunctionInfo,
8047 local_invocation_id: Option<&NameKey>,
8048 ) -> BackendResult {
8049 let level = back::Level(1);
8050
8051 writeln!(
8052 self.out,
8053 "{}if ({}::all({} == {}::uint3(0u))) {{",
8054 level,
8055 NAMESPACE,
8056 local_invocation_id
8057 .map(|name_key| self.names[name_key].as_str())
8058 .unwrap_or("__local_invocation_id"),
8059 NAMESPACE,
8060 )?;
8061
8062 let mut access_stack = AccessStack::new();
8063
8064 let vars = module.global_variables.iter().filter(|&(handle, var)| {
8065 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
8066 });
8067
8068 for (handle, var) in vars {
8069 access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
8070 self.write_workgroup_variable_initialization(
8071 module,
8072 module_info,
8073 var.ty,
8074 access_stack,
8075 level.next(),
8076 )
8077 })?;
8078 }
8079
8080 writeln!(self.out, "{level}}}")?;
8081 self.write_barrier(crate::Barrier::WORK_GROUP, level)
8082 }
8083
8084 fn write_workgroup_variable_initialization(
8085 &mut self,
8086 module: &crate::Module,
8087 module_info: &valid::ModuleInfo,
8088 ty: Handle<crate::Type>,
8089 access_stack: &mut AccessStack,
8090 level: back::Level,
8091 ) -> BackendResult {
8092 if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8093 write!(self.out, "{level}")?;
8094 access_stack.write(&mut self.out, &self.names)?;
8095 writeln!(self.out, " = {{}};")?;
8096 } else {
8097 match module.types[ty].inner {
8098 crate::TypeInner::Atomic { .. } => {
8099 write!(
8100 self.out,
8101 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8102 )?;
8103 access_stack.write(&mut self.out, &self.names)?;
8104 writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8105 }
8106 crate::TypeInner::Array { base, size, .. } => {
8107 let count = match size.resolve(module.to_ctx())? {
8108 proc::IndexableLength::Known(count) => count,
8109 proc::IndexableLength::Dynamic => unreachable!(),
8110 };
8111
8112 access_stack.enter_array(|access_stack, array_depth| {
8113 writeln!(
8114 self.out,
8115 "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8116 )?;
8117 self.write_workgroup_variable_initialization(
8118 module,
8119 module_info,
8120 base,
8121 access_stack,
8122 level.next(),
8123 )?;
8124 writeln!(self.out, "{level}}}")?;
8125 BackendResult::Ok(())
8126 })?;
8127 }
8128 crate::TypeInner::Struct { ref members, .. } => {
8129 for (index, member) in members.iter().enumerate() {
8130 access_stack.enter(
8131 Access::StructMember(ty, index as u32),
8132 |access_stack| {
8133 self.write_workgroup_variable_initialization(
8134 module,
8135 module_info,
8136 member.ty,
8137 access_stack,
8138 level,
8139 )
8140 },
8141 )?;
8142 }
8143 }
8144 _ => unreachable!(),
8145 }
8146 }
8147
8148 Ok(())
8149 }
8150 }
8151}
8152
8153impl crate::AtomicFunction {
8154 const fn to_msl(self) -> &'static str {
8155 match self {
8156 Self::Add => "fetch_add",
8157 Self::Subtract => "fetch_sub",
8158 Self::And => "fetch_and",
8159 Self::InclusiveOr => "fetch_or",
8160 Self::ExclusiveOr => "fetch_xor",
8161 Self::Min => "fetch_min",
8162 Self::Max => "fetch_max",
8163 Self::Exchange { compare: None } => "exchange",
8164 Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8165 }
8166 }
8167
8168 fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8169 Ok(match self {
8170 Self::Min => "min",
8171 Self::Max => "max",
8172 _ => Err(Error::FeatureNotImplemented(
8173 "64-bit atomic operation other than min/max".to_string(),
8174 ))?,
8175 })
8176 }
8177}