1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::{
8 cmp::Ordering,
9 fmt::{Display, Error as FmtError, Formatter, Write},
10 iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
17use crate::{
18 arena::{Handle, HandleSet},
19 back::{self, get_entry_points, Baked},
20 common,
21 proc::{
22 self, concrete_int_scalars,
23 index::{self, BoundsCheck},
24 ExternalTextureNameKey, NameKey, TypeResolution,
25 },
26 valid, FastHashMap, FastHashSet,
27};
28
29#[cfg(test)]
30use core::ptr;
31
/// Result of the writer's emitting routines: `Ok(())` on success, otherwise
/// the backend [`Error`].
type BackendResult = Result<(), Error>;
34
// ---- Identifiers written into the generated MSL -------------------------

/// Namespace prefix used for Metal standard-library types and functions.
const NAMESPACE: &str = "metal";
/// Namespace prefix used for Metal ray-tracing types.
const RT_NAMESPACE: &str = "metal::raytracing";
const WRAPPED_ARRAY_FIELD: &str = "inner";
const ATOMIC_REFERENCE: &str = "&";

// ---- Ray-query support ---------------------------------------------------

/// Name of the struct type used to represent ray queries.
const RAY_QUERY_TYPE: &str = "_RayQuery";
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";
/// Feature switch for ray queries; its use is not visible in this part of the file.
const RAY_QUERY_MODERN_SUPPORT: bool = false;

// ---- Names of helper functions the writer emits --------------------------

pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";

// ---- Names of helper structs the writer emits ----------------------------

pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
84
85fn put_numeric_type(
94 out: &mut impl Write,
95 scalar: crate::Scalar,
96 sizes: &[crate::VectorSize],
97) -> Result<(), FmtError> {
98 match (scalar, sizes) {
99 (scalar, &[]) => {
100 write!(out, "{}", scalar.to_msl_name())
101 }
102 (scalar, &[rows]) => {
103 write!(
104 out,
105 "{}::{}{}",
106 NAMESPACE,
107 scalar.to_msl_name(),
108 common::vector_size_str(rows)
109 )
110 }
111 (scalar, &[rows, columns]) => {
112 write!(
113 out,
114 "{}::{}{}x{}",
115 NAMESPACE,
116 scalar.to_msl_name(),
117 common::vector_size_str(columns),
118 common::vector_size_str(rows)
119 )
120 }
121 (_, _) => Ok(()), }
123}
124
125const fn scalar_is_int(scalar: crate::Scalar) -> bool {
126 use crate::ScalarKind::*;
127 match scalar.kind {
128 Sint | Uint | AbstractInt | Bool => true,
129 Float | AbstractFloat => false,
130 }
131}
132
133const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
135
136const REINTERPRET_PREFIX: &str = "reinterpreted_";
138
139struct ClampedLod(Handle<crate::Expression>);
145
146impl Display for ClampedLod {
147 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
148 self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
149 }
150}
151
152struct ArraySizeMember(Handle<crate::GlobalVariable>);
167
168impl Display for ArraySizeMember {
169 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
170 self.0.write_prefixed(f, "size")
171 }
172}
173
174#[derive(Clone, Copy)]
179struct Reinterpreted<'a> {
180 target_type: &'a str,
181 orig: Handle<crate::Expression>,
182}
183
184impl<'a> Reinterpreted<'a> {
185 const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
186 Self { target_type, orig }
187 }
188}
189
190impl Display for Reinterpreted<'_> {
191 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
192 f.write_str(REINTERPRET_PREFIX)?;
193 f.write_str(self.target_type)?;
194 self.orig.write_prefixed(f, "_e")
195 }
196}
197
/// Formats, via `Display`, the MSL spelling of the type `handle`.
struct TypeContext<'a> {
    /// The type to spell.
    handle: Handle<crate::Type>,
    /// Module-level context used to look up `handle` and its component types.
    gctx: proc::GlobalCtx<'a>,
    /// Generated names. Types that `needs_alias()` are written by their name
    /// from this map unless `first_time` is set.
    names: &'a FastHashMap<NameKey, String>,
    /// Access flags used to pick the `access::` qualifier of storage textures.
    access: crate::StorageAccess,
    /// When set, an aliased type is spelled out instead of written by name.
    /// NOTE(review): presumably set only when emitting the type's own
    /// definition — confirm at callers.
    first_time: bool,
}
205
206impl TypeContext<'_> {
207 fn scalar(&self) -> Option<crate::Scalar> {
208 let ty = &self.gctx.types[self.handle];
209 ty.inner.scalar()
210 }
211
212 fn vertex_input_dimension(&self) -> u32 {
213 let ty = &self.gctx.types[self.handle];
214 match ty.inner {
215 crate::TypeInner::Scalar(_) => 1,
216 crate::TypeInner::Vector { size, .. } => size as u32,
217 _ => unreachable!(),
218 }
219 }
220}
221
impl Display for TypeContext<'_> {
    /// Writes the MSL spelling of `self.handle`.
    ///
    /// Types that `needs_alias()` are written by their generated name unless
    /// `first_time` is set. Nested types (pointees, array elements, binding
    /// array bases) are always formatted with `first_time: false`.
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            // Spelled `metal::simdgroup_<scalar><columns>x<rows>`; the role does
            // not appear in the MSL type.
            crate::TypeInner::CooperativeMatrix {
                columns,
                rows,
                scalar,
                role: _,
            } => {
                write!(
                    out,
                    "{NAMESPACE}::simdgroup_{}{}x{}",
                    scalar.to_msl_name(),
                    columns as u32,
                    rows as u32,
                )
            }
            // Written as `<space> <pointee>&`.
            crate::TypeInner::Pointer { base, space } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // Spaces without an MSL name (`Handle`) produce no output at all.
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                write!(out, "{space_name} {sub}&")
            }
            // Written as `<space> <scalar or vector>&`; again nothing is
            // written for spaces without an MSL name.
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            // Only the element type is written here; the array length is not
            // part of this output (presumably emitted elsewhere — confirm).
            crate::TypeInner::Array { base, .. } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                write!(out, "{sub}")
            }
            // Structs always `needs_alias()`, so unless `first_time` is set they
            // were already written by name above.
            crate::TypeInner::Struct { .. } => unreachable!(),
            // Written as `metal::<texture|depth><dim>[_ms][_array]<T, metal::access::<a>>`.
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                let (texture_str, msaa_str, scalar, access) = match class {
                    // Multisampled textures use `access::read`, others `access::sample`.
                    crate::ImageClass::Sampled { kind, multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    // Depth textures always have 32-bit float components.
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    // Storage textures take their access from `self.access` and
                    // their component scalar from the texel format.
                    crate::ImageClass::Storage { format, .. } => {
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    // External textures are represented by the wrapper struct.
                    crate::ImageClass::External => {
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            // Vertex-return acceleration structures and ray queries are not supported.
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RAY_QUERY_TYPE}")
            }
            // Written as a pointer to the argument-buffer wrapper around the base type.
            crate::TypeInner::BindingArray { base, .. } => {
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };

                write!(
                    out,
                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                )
            }
        }
    }
}
390
/// A global variable together with what is needed to write its typed
/// declaration (see `try_fmt`).
struct TypedGlobalVariable<'a> {
    /// Module that owns the global and its type.
    module: &'a crate::Module,
    /// Generated names; the global's name is looked up here.
    names: &'a FastHashMap<NameKey, String>,
    /// The global being written.
    handle: Handle<crate::GlobalVariable>,
    /// How the global is used; without `WRITE`, `const` may be added.
    usage: valid::GlobalUse,
    /// Whether to write the address space and a trailing `&`.
    reference: bool,
}
398
399impl TypedGlobalVariable<'_> {
400 fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
401 let var = &self.module.global_variables[self.handle];
402 let name = &self.names[&NameKey::GlobalVariable(self.handle)];
403
404 let storage_access = match var.space {
405 crate::AddressSpace::Storage { access } => access,
406 _ => match self.module.types[var.ty].inner {
407 crate::TypeInner::Image {
408 class: crate::ImageClass::Storage { access, .. },
409 ..
410 } => access,
411 crate::TypeInner::BindingArray { base, .. } => {
412 match self.module.types[base].inner {
413 crate::TypeInner::Image {
414 class: crate::ImageClass::Storage { access, .. },
415 ..
416 } => access,
417 _ => crate::StorageAccess::default(),
418 }
419 }
420 _ => crate::StorageAccess::default(),
421 },
422 };
423 let ty_name = TypeContext {
424 handle: var.ty,
425 gctx: self.module.to_ctx(),
426 names: self.names,
427 access: storage_access,
428 first_time: false,
429 };
430
431 let (space, access, reference) = match var.space.to_msl_name() {
432 Some(space) if self.reference => {
433 let access = if var.space.needs_access_qualifier()
434 && !self.usage.intersects(valid::GlobalUse::WRITE)
435 {
436 "const"
437 } else {
438 ""
439 };
440 (space, access, "&")
441 }
442 _ => ("", "", ""),
443 };
444
445 Ok(write!(
446 out,
447 "{}{}{}{}{}{} {}",
448 space,
449 if space.is_empty() { "" } else { " " },
450 ty_name,
451 if access.is_empty() { "" } else { " " },
452 access,
453 reference,
454 name,
455 )?)
456 }
457}
458
/// Identifies a helper function the writer emits into the generated MSL.
///
/// The writer keeps these in a set (`Writer::wrapped_functions`).
/// NOTE(review): presumably this ensures each helper is generated only once
/// per module — confirm at the insertion sites.
#[derive(Eq, PartialEq, Hash)]
enum WrappedFunction {
    /// Helper for unary operator `op`. `ty` is the operand's vector size
    /// (`None` for a scalar) and its component scalar.
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for binary operator `op`, keyed by the shapes of both operands.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for math function `fun`, keyed by the argument's shape.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper converting `src_scalar` components to `dst_scalar`, for a scalar
    /// (`vector_size == None`) or a vector.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Image-load helper for images of the given class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Image-sample helper for the given class. `clamp_to_edge` selects the
    /// clamp-to-edge variant.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Image size-query helper for images of the given class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
    /// Cooperative-matrix load helper for the given matrix shape and scalar.
    /// `space_name` is presumably the MSL address-space name loaded from.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// Cooperative-matrix multiply-add helper for the given shapes and scalar.
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
}
503
/// Writes a Naga module as Metal Shading Language to `out`.
pub struct Writer<W> {
    /// Destination of the generated source.
    out: W,
    /// Generated identifiers for module items, keyed by `NameKey`.
    names: FastHashMap<NameKey, String>,
    named_expressions: crate::NamedExpressions,
    need_bake_expressions: back::NeedBakeExpressions,
    /// Produces fresh identifiers from a base name (e.g. `loop_bound`, `oob`).
    namer: proc::Namer,
    /// Helper functions already recorded as generated (see `WrappedFunction`).
    wrapped_functions: FastHashSet<WrappedFunction>,
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// Pairs of (type, index) for struct member padding.
    /// NOTE(review): the exact meaning of the `u32` is not visible here — confirm.
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
520
521impl crate::Scalar {
522 pub(super) fn to_msl_name(self) -> &'static str {
523 use crate::ScalarKind as Sk;
524 match self {
525 Self {
526 kind: Sk::Float,
527 width: 4,
528 } => "float",
529 Self {
530 kind: Sk::Float,
531 width: 2,
532 } => "half",
533 Self {
534 kind: Sk::Sint,
535 width: 4,
536 } => "int",
537 Self {
538 kind: Sk::Uint,
539 width: 4,
540 } => "uint",
541 Self {
542 kind: Sk::Sint,
543 width: 8,
544 } => "long",
545 Self {
546 kind: Sk::Uint,
547 width: 8,
548 } => "ulong",
549 Self {
550 kind: Sk::Bool,
551 width: _,
552 } => "bool",
553 Self {
554 kind: Sk::AbstractInt | Sk::AbstractFloat,
555 width: _,
556 } => unreachable!("Found Abstract scalar kind"),
557 _ => unreachable!("Unsupported scalar kind: {:?}", self),
558 }
559 }
560}
561
/// Returns the list separator `","` when `need_separator` is set, else `""`.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
569
570fn should_pack_struct_member(
571 members: &[crate::StructMember],
572 span: u32,
573 index: usize,
574 module: &crate::Module,
575) -> Option<crate::Scalar> {
576 let member = &members[index];
577
578 let ty_inner = &module.types[member.ty].inner;
579 let last_offset = member.offset + ty_inner.size(module.to_ctx());
580 let next_offset = match members.get(index + 1) {
581 Some(next) => next.offset,
582 None => span,
583 };
584 let is_tight = next_offset == last_offset;
585
586 match *ty_inner {
587 crate::TypeInner::Vector {
588 size: crate::VectorSize::Tri,
589 scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
590 } if is_tight => Some(scalar),
591 _ => None,
592 }
593}
594
595fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
596 match arena[ty].inner {
597 crate::TypeInner::Struct { ref members, .. } => {
598 if let Some(member) = members.last() {
599 if let crate::TypeInner::Array {
600 size: crate::ArraySize::Dynamic,
601 ..
602 } = arena[member.ty].inner
603 {
604 return true;
605 }
606 }
607 false
608 }
609 crate::TypeInner::Array {
610 size: crate::ArraySize::Dynamic,
611 ..
612 } => true,
613 _ => false,
614 }
615}
616
617impl crate::AddressSpace {
618 const fn needs_pass_through(&self) -> bool {
622 match *self {
623 Self::Uniform
624 | Self::Storage { .. }
625 | Self::Private
626 | Self::WorkGroup
627 | Self::Immediate
628 | Self::Handle
629 | Self::TaskPayload => true,
630 Self::Function => false,
631 Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
632 }
633 }
634
635 const fn needs_access_qualifier(&self) -> bool {
637 match *self {
638 Self::Storage { .. } => true,
643 Self::TaskPayload | Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
644 Self::Private | Self::WorkGroup => false,
646 Self::Uniform | Self::Immediate => false,
648 Self::Handle | Self::Function => false,
650 }
651 }
652
653 const fn to_msl_name(self) -> Option<&'static str> {
654 match self {
655 Self::Handle => None,
656 Self::Uniform | Self::Immediate => Some("constant"),
657 Self::Storage { .. } => Some("device"),
658 Self::Private | Self::Function | Self::RayPayload => Some("thread"),
662 Self::WorkGroup => Some("threadgroup"),
663 Self::TaskPayload => Some("object_data"),
664 Self::IncomingRayPayload => Some("ray_data"),
665 }
666 }
667}
668
669impl crate::Type {
670 const fn needs_alias(&self) -> bool {
672 use crate::TypeInner as Ti;
673
674 match self.inner {
675 Ti::Scalar(_)
677 | Ti::Vector { .. }
678 | Ti::Matrix { .. }
679 | Ti::CooperativeMatrix { .. }
680 | Ti::Atomic(_)
681 | Ti::Pointer { .. }
682 | Ti::ValuePointer { .. } => self.name.is_some(),
683 Ti::Struct { .. } | Ti::Array { .. } => true,
685 Ti::Image { .. }
687 | Ti::Sampler { .. }
688 | Ti::AccelerationStructure { .. }
689 | Ti::RayQuery { .. }
690 | Ti::BindingArray { .. } => false,
691 }
692 }
693}
694
/// The function whose locals are being named: a regular function by handle,
/// or an entry point by index.
#[derive(Clone, Copy)]
enum FunctionOrigin {
    Handle(Handle<crate::Function>),
    EntryPoint(proc::EntryPointIndex),
}
700
701trait NameKeyExt {
702 fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
703 match origin {
704 FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
705 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
706 }
707 }
708
709 fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
714 match origin {
715 FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
716 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
717 }
718 }
719}
720
721impl NameKeyExt for NameKey {}
722
/// The level-of-detail operand of an image access.
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Use this expression as the level directly.
    Direct(Handle<crate::Expression>),
    /// Use the clamped-LOD name associated with this image-load expression,
    /// written via `ClampedLod`.
    Restricted(Handle<crate::Expression>),
}
737
/// The operands that locate a texel for an image access.
struct TexelAddress {
    /// Texel coordinate (a scalar or vector expression).
    coordinate: Handle<crate::Expression>,
    /// Array layer, for arrayed images.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index, for multisampled images.
    sample: Option<Handle<crate::Expression>>,
    /// Mip level, if the access specifies one.
    level: Option<LevelOfDetail>,
}
753
/// Everything needed to write the expressions of one function.
struct ExpressionContext<'a> {
    /// The function whose expressions are written.
    function: &'a crate::Function,
    /// Which function or entry point this is; used to build local `NameKey`s.
    origin: FunctionOrigin,
    /// Validation info for `function`, including resolved expression types.
    info: &'a valid::FunctionInfo,
    module: &'a crate::Module,
    mod_info: &'a valid::ModuleInfo,
    pipeline_options: &'a PipelineOptions,
    /// NOTE(review): presumably the target MSL version as (major, minor).
    lang_version: (u8, u8),
    /// Bounds-check policies (e.g. `image_load` selects how image loads are guarded).
    policies: index::BoundsCheckPolicies,

    /// NOTE(review): presumably the index expressions whose bounds checks are
    /// handled ahead of time — not used in this part of the file.
    guarded_indices: HandleSet<crate::Expression>,
    /// When set, loops get the counter from `gen_force_bounded_loop_statements`.
    force_loop_bounding: bool,
}
771
772impl<'a> ExpressionContext<'a> {
773 fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
774 self.info[handle].ty.inner_with(&self.module.types)
775 }
776
777 fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
784 let image_ty = self.resolve_type(image);
785 if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
786 class.is_mipmapped() && dim != crate::ImageDimension::D1
787 } else {
788 false
789 }
790 }
791
792 fn choose_bounds_check_policy(
793 &self,
794 pointer: Handle<crate::Expression>,
795 ) -> index::BoundsCheckPolicy {
796 self.policies
797 .choose_policy(pointer, &self.module.types, self.info)
798 }
799
800 fn access_needs_check(
802 &self,
803 base: Handle<crate::Expression>,
804 index: index::GuardedIndex,
805 ) -> Option<index::IndexableLength> {
806 index::access_needs_check(
807 base,
808 index,
809 self.module,
810 &self.function.expressions,
811 self.info,
812 )
813 }
814
815 fn bounds_check_iter(
817 &self,
818 chain: Handle<crate::Expression>,
819 ) -> impl Iterator<Item = BoundsCheck> + '_ {
820 index::bounds_check_iter(chain, self.module, self.function, self.info)
821 }
822
823 fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
825 index::oob_local_types(self.module, self.function, self.info, self.policies)
826 }
827
828 fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
829 match self.function.expressions[expr_handle] {
830 crate::Expression::AccessIndex { base, index } => {
831 let ty = match *self.resolve_type(base) {
832 crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
833 ref ty => ty,
834 };
835 match *ty {
836 crate::TypeInner::Struct {
837 ref members, span, ..
838 } => should_pack_struct_member(members, span, index as usize, self.module),
839 _ => None,
840 }
841 }
842 _ => None,
843 }
844 }
845}
846
/// Everything needed to write the statements of one function.
struct StatementContext<'a> {
    /// Context for the expressions those statements contain.
    expression: ExpressionContext<'a>,
    /// NOTE(review): presumably the name of the struct the entry point returns,
    /// if any — not used in this part of the file; confirm at use sites.
    result_struct: Option<&'a str>,
}
851
852impl<W: Write> Writer<W> {
853 pub fn new(out: W) -> Self {
855 Writer {
856 out,
857 names: FastHashMap::default(),
858 named_expressions: Default::default(),
859 need_bake_expressions: Default::default(),
860 namer: proc::Namer::default(),
861 wrapped_functions: FastHashSet::default(),
862 #[cfg(test)]
863 put_expression_stack_pointers: Default::default(),
864 #[cfg(test)]
865 put_block_stack_pointers: Default::default(),
866 struct_member_pads: FastHashSet::default(),
867 }
868 }
869
870 pub fn finish(self) -> W {
873 self.out
874 }
875
876 fn gen_force_bounded_loop_statements(
983 &mut self,
984 level: back::Level,
985 context: &StatementContext,
986 ) -> Option<(String, String)> {
987 if !context.expression.force_loop_bounding {
988 return None;
989 }
990
991 let loop_bound_name = self.namer.call("loop_bound");
992 let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
995 let level = level.next();
996 let break_and_inc = format!(
997 "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
998{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
999 );
1000
1001 Some((decl, break_and_inc))
1002 }
1003
1004 fn put_call_parameters(
1005 &mut self,
1006 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1007 context: &ExpressionContext,
1008 ) -> BackendResult {
1009 self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1010 writer.put_expression(expr, context, true)
1011 })
1012 }
1013
1014 fn put_call_parameters_impl<C, E>(
1015 &mut self,
1016 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1017 ctx: &C,
1018 put_expression: E,
1019 ) -> BackendResult
1020 where
1021 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1022 {
1023 write!(self.out, "(")?;
1024 for (i, handle) in parameters.enumerate() {
1025 if i != 0 {
1026 write!(self.out, ", ")?;
1027 }
1028 put_expression(self, ctx, handle)?;
1029 }
1030 write!(self.out, ")")?;
1031 Ok(())
1032 }
1033
1034 fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1040 let oob_local_types = context.oob_local_types();
1041 for &ty in oob_local_types.iter() {
1042 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1043 self.names.insert(name_key, self.namer.call("oob"));
1044 }
1045
1046 for (name_key, ty, init) in context
1047 .function
1048 .local_variables
1049 .iter()
1050 .map(|(local_handle, local)| {
1051 let name_key = NameKey::local(context.origin, local_handle);
1052 (name_key, local.ty, local.init)
1053 })
1054 .chain(oob_local_types.iter().map(|&ty| {
1055 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1056 (name_key, ty, None)
1057 }))
1058 {
1059 let ty_name = TypeContext {
1060 handle: ty,
1061 gctx: context.module.to_ctx(),
1062 names: &self.names,
1063 access: crate::StorageAccess::empty(),
1064 first_time: false,
1065 };
1066 write!(
1067 self.out,
1068 "{}{} {}",
1069 back::INDENT,
1070 ty_name,
1071 self.names[&name_key]
1072 )?;
1073 match init {
1074 Some(value) => {
1075 write!(self.out, " = ")?;
1076 self.put_expression(value, context, true)?;
1077 }
1078 None => {
1079 write!(self.out, " = {{}}")?;
1080 }
1081 };
1082 writeln!(self.out, ";")?;
1083 }
1084 Ok(())
1085 }
1086
1087 fn put_level_of_detail(
1088 &mut self,
1089 level: LevelOfDetail,
1090 context: &ExpressionContext,
1091 ) -> BackendResult {
1092 match level {
1093 LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1094 LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1095 }
1096 Ok(())
1097 }
1098
1099 fn put_image_query(
1100 &mut self,
1101 image: Handle<crate::Expression>,
1102 query: &str,
1103 level: Option<LevelOfDetail>,
1104 context: &ExpressionContext,
1105 ) -> BackendResult {
1106 self.put_expression(image, context, false)?;
1107 write!(self.out, ".get_{query}(")?;
1108 if let Some(level) = level {
1109 self.put_level_of_detail(level, context)?;
1110 }
1111 write!(self.out, ")")?;
1112 Ok(())
1113 }
1114
1115 fn put_image_size_query(
1116 &mut self,
1117 image: Handle<crate::Expression>,
1118 level: Option<LevelOfDetail>,
1119 kind: crate::ScalarKind,
1120 context: &ExpressionContext,
1121 ) -> BackendResult {
1122 if let crate::TypeInner::Image {
1123 class: crate::ImageClass::External,
1124 ..
1125 } = *context.resolve_type(image)
1126 {
1127 write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1128 self.put_expression(image, context, true)?;
1129 write!(self.out, ")")?;
1130 return Ok(());
1131 }
1132
1133 let dim = match *context.resolve_type(image) {
1136 crate::TypeInner::Image { dim, .. } => dim,
1137 ref other => unreachable!("Unexpected type {:?}", other),
1138 };
1139 let scalar = crate::Scalar { kind, width: 4 };
1140 let coordinate_type = scalar.to_msl_name();
1141 match dim {
1142 crate::ImageDimension::D1 => {
1143 if kind == crate::ScalarKind::Uint {
1147 self.put_image_query(image, "width", None, context)?;
1149 } else {
1150 write!(self.out, "int(")?;
1152 self.put_image_query(image, "width", None, context)?;
1153 write!(self.out, ")")?;
1154 }
1155 }
1156 crate::ImageDimension::D2 => {
1157 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1158 self.put_image_query(image, "width", level, context)?;
1159 write!(self.out, ", ")?;
1160 self.put_image_query(image, "height", level, context)?;
1161 write!(self.out, ")")?;
1162 }
1163 crate::ImageDimension::D3 => {
1164 write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1165 self.put_image_query(image, "width", level, context)?;
1166 write!(self.out, ", ")?;
1167 self.put_image_query(image, "height", level, context)?;
1168 write!(self.out, ", ")?;
1169 self.put_image_query(image, "depth", level, context)?;
1170 write!(self.out, ")")?;
1171 }
1172 crate::ImageDimension::Cube => {
1173 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1174 self.put_image_query(image, "width", level, context)?;
1175 write!(self.out, ")")?;
1176 }
1177 }
1178 Ok(())
1179 }
1180
1181 fn put_cast_to_uint_scalar_or_vector(
1182 &mut self,
1183 expr: Handle<crate::Expression>,
1184 context: &ExpressionContext,
1185 ) -> BackendResult {
1186 match *context.resolve_type(expr) {
1188 crate::TypeInner::Scalar(_) => {
1189 put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1190 }
1191 crate::TypeInner::Vector { size, .. } => {
1192 put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1193 }
1194 _ => {
1195 return Err(Error::GenericValidation(
1196 "Invalid type for image coordinate".into(),
1197 ))
1198 }
1199 };
1200
1201 write!(self.out, "(")?;
1202 self.put_expression(expr, context, true)?;
1203 write!(self.out, ")")?;
1204 Ok(())
1205 }
1206
1207 fn put_image_sample_level(
1208 &mut self,
1209 image: Handle<crate::Expression>,
1210 level: crate::SampleLevel,
1211 context: &ExpressionContext,
1212 ) -> BackendResult {
1213 let has_levels = context.image_needs_lod(image);
1214 match level {
1215 crate::SampleLevel::Auto => {}
1216 crate::SampleLevel::Zero => {
1217 }
1219 _ if !has_levels => {
1220 log::warn!("1D image can't be sampled with level {level:?}");
1221 }
1222 crate::SampleLevel::Exact(h) => {
1223 write!(self.out, ", {NAMESPACE}::level(")?;
1224 self.put_expression(h, context, true)?;
1225 write!(self.out, ")")?;
1226 }
1227 crate::SampleLevel::Bias(h) => {
1228 write!(self.out, ", {NAMESPACE}::bias(")?;
1229 self.put_expression(h, context, true)?;
1230 write!(self.out, ")")?;
1231 }
1232 crate::SampleLevel::Gradient { x, y } => {
1233 write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1234 self.put_expression(x, context, true)?;
1235 write!(self.out, ", ")?;
1236 self.put_expression(y, context, true)?;
1237 write!(self.out, ")")?;
1238 }
1239 }
1240 Ok(())
1241 }
1242
1243 fn put_image_coordinate_limits(
1244 &mut self,
1245 image: Handle<crate::Expression>,
1246 level: Option<LevelOfDetail>,
1247 context: &ExpressionContext,
1248 ) -> BackendResult {
1249 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1250 write!(self.out, " - 1")?;
1251 Ok(())
1252 }
1253
1254 fn put_restricted_scalar_image_index(
1272 &mut self,
1273 image: Handle<crate::Expression>,
1274 index: Handle<crate::Expression>,
1275 limit_method: &str,
1276 context: &ExpressionContext,
1277 ) -> BackendResult {
1278 write!(self.out, "{NAMESPACE}::min(uint(")?;
1279 self.put_expression(index, context, true)?;
1280 write!(self.out, "), ")?;
1281 self.put_expression(image, context, false)?;
1282 write!(self.out, ".{limit_method}() - 1)")?;
1283 Ok(())
1284 }
1285
1286 fn put_restricted_texel_address(
1287 &mut self,
1288 image: Handle<crate::Expression>,
1289 address: &TexelAddress,
1290 context: &ExpressionContext,
1291 ) -> BackendResult {
1292 write!(self.out, "{NAMESPACE}::min(")?;
1294 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1295 write!(self.out, ", ")?;
1296 self.put_image_coordinate_limits(image, address.level, context)?;
1297 write!(self.out, ")")?;
1298
1299 if let Some(array_index) = address.array_index {
1301 write!(self.out, ", ")?;
1302 self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1303 }
1304
1305 if let Some(sample) = address.sample {
1307 write!(self.out, ", ")?;
1308 self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1309 }
1310
1311 if let Some(level) = address.level {
1314 write!(self.out, ", ")?;
1315 self.put_level_of_detail(level, context)?;
1316 }
1317
1318 Ok(())
1319 }
1320
1321 fn put_image_access_bounds_check(
1323 &mut self,
1324 image: Handle<crate::Expression>,
1325 address: &TexelAddress,
1326 context: &ExpressionContext,
1327 ) -> BackendResult {
1328 let mut conjunction = "";
1329
1330 let level = if let Some(level) = address.level {
1333 write!(self.out, "uint(")?;
1334 self.put_level_of_detail(level, context)?;
1335 write!(self.out, ") < ")?;
1336 self.put_expression(image, context, true)?;
1337 write!(self.out, ".get_num_mip_levels()")?;
1338 conjunction = " && ";
1339 Some(level)
1340 } else {
1341 None
1342 };
1343
1344 if let Some(sample) = address.sample {
1346 write!(self.out, "uint(")?;
1347 self.put_expression(sample, context, true)?;
1348 write!(self.out, ") < ")?;
1349 self.put_expression(image, context, true)?;
1350 write!(self.out, ".get_num_samples()")?;
1351 conjunction = " && ";
1352 }
1353
1354 if let Some(array_index) = address.array_index {
1356 write!(self.out, "{conjunction}uint(")?;
1357 self.put_expression(array_index, context, true)?;
1358 write!(self.out, ") < ")?;
1359 self.put_expression(image, context, true)?;
1360 write!(self.out, ".get_array_size()")?;
1361 conjunction = " && ";
1362 }
1363
1364 let coord_is_vector = match *context.resolve_type(address.coordinate) {
1366 crate::TypeInner::Vector { .. } => true,
1367 _ => false,
1368 };
1369 write!(self.out, "{conjunction}")?;
1370 if coord_is_vector {
1371 write!(self.out, "{NAMESPACE}::all(")?;
1372 }
1373 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1374 write!(self.out, " < ")?;
1375 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1376 if coord_is_vector {
1377 write!(self.out, ")")?;
1378 }
1379
1380 Ok(())
1381 }
1382
1383 fn put_image_load(
1384 &mut self,
1385 load: Handle<crate::Expression>,
1386 image: Handle<crate::Expression>,
1387 mut address: TexelAddress,
1388 context: &ExpressionContext,
1389 ) -> BackendResult {
1390 if let crate::TypeInner::Image {
1391 class: crate::ImageClass::External,
1392 ..
1393 } = *context.resolve_type(image)
1394 {
1395 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1396 self.put_expression(image, context, true)?;
1397 write!(self.out, ", ")?;
1398 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1399 write!(self.out, ")")?;
1400 return Ok(());
1401 }
1402
1403 match context.policies.image_load {
1404 proc::BoundsCheckPolicy::Restrict => {
1405 if address.level.is_some() {
1408 address.level = if context.image_needs_lod(image) {
1409 Some(LevelOfDetail::Restricted(load))
1410 } else {
1411 None
1412 }
1413 }
1414
1415 self.put_expression(image, context, false)?;
1416 write!(self.out, ".read(")?;
1417 self.put_restricted_texel_address(image, &address, context)?;
1418 write!(self.out, ")")?;
1419 }
1420 proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1421 write!(self.out, "(")?;
1422 self.put_image_access_bounds_check(image, &address, context)?;
1423 write!(self.out, " ? ")?;
1424 self.put_unchecked_image_load(image, &address, context)?;
1425 write!(self.out, ": DefaultConstructible())")?;
1426 }
1427 proc::BoundsCheckPolicy::Unchecked => {
1428 self.put_unchecked_image_load(image, &address, context)?;
1429 }
1430 }
1431
1432 Ok(())
1433 }
1434
1435 fn put_unchecked_image_load(
1436 &mut self,
1437 image: Handle<crate::Expression>,
1438 address: &TexelAddress,
1439 context: &ExpressionContext,
1440 ) -> BackendResult {
1441 self.put_expression(image, context, false)?;
1442 write!(self.out, ".read(")?;
1443 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1445 if let Some(expr) = address.array_index {
1446 write!(self.out, ", ")?;
1447 self.put_expression(expr, context, true)?;
1448 }
1449 if let Some(sample) = address.sample {
1450 write!(self.out, ", ")?;
1451 self.put_expression(sample, context, true)?;
1452 }
1453 if let Some(level) = address.level {
1454 if context.image_needs_lod(image) {
1455 write!(self.out, ", ")?;
1456 self.put_level_of_detail(level, context)?;
1457 }
1458 }
1459 write!(self.out, ")")?;
1460
1461 Ok(())
1462 }
1463
1464 fn put_image_atomic(
1465 &mut self,
1466 level: back::Level,
1467 image: Handle<crate::Expression>,
1468 address: &TexelAddress,
1469 fun: crate::AtomicFunction,
1470 value: Handle<crate::Expression>,
1471 context: &StatementContext,
1472 ) -> BackendResult {
1473 write!(self.out, "{level}")?;
1474 self.put_expression(image, &context.expression, false)?;
1475 let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1476 fun.to_msl_64_bit()?
1477 } else {
1478 fun.to_msl()
1479 };
1480 write!(self.out, ".atomic_{op}(")?;
1481 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1483 write!(self.out, ", ")?;
1484 self.put_expression(value, &context.expression, true)?;
1485 writeln!(self.out, ");")?;
1486
1487 Ok(())
1488 }
1489
1490 fn put_image_store(
1491 &mut self,
1492 level: back::Level,
1493 image: Handle<crate::Expression>,
1494 address: &TexelAddress,
1495 value: Handle<crate::Expression>,
1496 context: &StatementContext,
1497 ) -> BackendResult {
1498 write!(self.out, "{level}")?;
1499 self.put_expression(image, &context.expression, false)?;
1500 write!(self.out, ".write(")?;
1501 self.put_expression(value, &context.expression, true)?;
1502 write!(self.out, ", ")?;
1503 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1505 if let Some(expr) = address.array_index {
1506 write!(self.out, ", ")?;
1507 self.put_expression(expr, &context.expression, true)?;
1508 }
1509 writeln!(self.out, ");")?;
1510
1511 Ok(())
1512 }
1513
1514 fn put_dynamic_array_max_index(
1525 &mut self,
1526 handle: Handle<crate::GlobalVariable>,
1527 context: &ExpressionContext,
1528 ) -> BackendResult {
1529 let global = &context.module.global_variables[handle];
1530 let (offset, array_ty) = match context.module.types[global.ty].inner {
1531 crate::TypeInner::Struct { ref members, .. } => match members.last() {
1532 Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1533 None => return Err(Error::GenericValidation("Struct has no members".into())),
1534 },
1535 crate::TypeInner::Array {
1536 size: crate::ArraySize::Dynamic,
1537 ..
1538 } => (0, global.ty),
1539 ref ty => {
1540 return Err(Error::GenericValidation(format!(
1541 "Expected type with dynamic array, got {ty:?}"
1542 )))
1543 }
1544 };
1545
1546 let (size, stride) = match context.module.types[array_ty].inner {
1547 crate::TypeInner::Array { base, stride, .. } => (
1548 context.module.types[base]
1549 .inner
1550 .size(context.module.to_ctx()),
1551 stride,
1552 ),
1553 ref ty => {
1554 return Err(Error::GenericValidation(format!(
1555 "Expected array type, got {ty:?}"
1556 )))
1557 }
1558 };
1559
1560 write!(
1573 self.out,
1574 "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1575 member = ArraySizeMember(handle),
1576 offset = offset,
1577 size = size,
1578 stride = stride,
1579 )?;
1580 Ok(())
1581 }
1582
1583 fn put_dot_product<T: Copy>(
1588 &mut self,
1589 arg: T,
1590 arg1: T,
1591 size: usize,
1592 extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1593 ) -> BackendResult {
1594 write!(self.out, "(")?;
1597
1598 for index in 0..size {
1600 write!(self.out, " + ")?;
1603 extractor(self, arg, index)?;
1604 write!(self.out, " * ")?;
1605 extractor(self, arg1, index)?;
1606 }
1607
1608 write!(self.out, ")")?;
1609 Ok(())
1610 }
1611
1612 fn put_pack4x8(
1614 &mut self,
1615 arg: Handle<crate::Expression>,
1616 context: &ExpressionContext<'_>,
1617 was_signed: bool,
1618 clamp_bounds: Option<(&str, &str)>,
1619 ) -> Result<(), Error> {
1620 let write_arg = |this: &mut Self| -> BackendResult {
1621 if let Some((min, max)) = clamp_bounds {
1622 write!(this.out, "{NAMESPACE}::clamp(")?;
1624 this.put_expression(arg, context, true)?;
1625 write!(this.out, ", {min}, {max})")?;
1626 } else {
1627 this.put_expression(arg, context, true)?;
1628 }
1629 Ok(())
1630 };
1631
1632 if context.lang_version >= (2, 1) {
1633 let packed_type = if was_signed {
1634 "packed_char4"
1635 } else {
1636 "packed_uchar4"
1637 };
1638 write!(self.out, "as_type<uint>({packed_type}(")?;
1640 write_arg(self)?;
1641 write!(self.out, "))")?;
1642 } else {
1643 if was_signed {
1645 write!(self.out, "uint(")?;
1646 }
1647 write!(self.out, "(")?;
1648 write_arg(self)?;
1649 write!(self.out, "[0] & 0xFF) | ((")?;
1650 write_arg(self)?;
1651 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1652 write_arg(self)?;
1653 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1654 write_arg(self)?;
1655 write!(self.out, "[3] & 0xFF) << 24)")?;
1656 if was_signed {
1657 write!(self.out, ")")?;
1658 }
1659 }
1660
1661 Ok(())
1662 }
1663
1664 fn put_isign(
1667 &mut self,
1668 arg: Handle<crate::Expression>,
1669 context: &ExpressionContext,
1670 ) -> BackendResult {
1671 write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1672 let scalar = context
1673 .resolve_type(arg)
1674 .scalar()
1675 .expect("put_isign should only be called for args which have an integer scalar type")
1676 .to_msl_name();
1677 match context.resolve_type(arg) {
1678 &crate::TypeInner::Vector { size, .. } => {
1679 let size = common::vector_size_str(size);
1680 write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1681 }
1682 _ => {
1683 write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1684 }
1685 }
1686 write!(self.out, ", (")?;
1687 self.put_expression(arg, context, true)?;
1688 write!(self.out, " > 0)), {scalar}(0), (")?;
1689 self.put_expression(arg, context, true)?;
1690 write!(self.out, " == 0))")?;
1691 Ok(())
1692 }
1693
1694 fn put_const_expression(
1695 &mut self,
1696 expr_handle: Handle<crate::Expression>,
1697 module: &crate::Module,
1698 mod_info: &valid::ModuleInfo,
1699 arena: &crate::Arena<crate::Expression>,
1700 ) -> BackendResult {
1701 self.put_possibly_const_expression(
1702 expr_handle,
1703 arena,
1704 module,
1705 mod_info,
1706 &(module, mod_info),
1707 |&(_, mod_info), expr| &mod_info[expr],
1708 |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1709 )
1710 }
1711
1712 fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1713 match literal {
1714 crate::Literal::F64(_) => {
1715 return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1716 }
1717 crate::Literal::F16(value) => {
1718 if value.is_infinite() {
1719 let sign = if value.is_sign_negative() { "-" } else { "" };
1720 write!(self.out, "{sign}INFINITY")?;
1721 } else if value.is_nan() {
1722 write!(self.out, "NAN")?;
1723 } else {
1724 let suffix = if value.fract() == f16::from_f32(0.0) {
1725 ".0h"
1726 } else {
1727 "h"
1728 };
1729 write!(self.out, "{value}{suffix}")?;
1730 }
1731 }
1732 crate::Literal::F32(value) => {
1733 if value.is_infinite() {
1734 let sign = if value.is_sign_negative() { "-" } else { "" };
1735 write!(self.out, "{sign}INFINITY")?;
1736 } else if value.is_nan() {
1737 write!(self.out, "NAN")?;
1738 } else {
1739 let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1740 write!(self.out, "{value}{suffix}")?;
1741 }
1742 }
1743 crate::Literal::U32(value) => {
1744 write!(self.out, "{value}u")?;
1745 }
1746 crate::Literal::I32(value) => {
1747 if value == i32::MIN {
1752 write!(self.out, "({} - 1)", value + 1)?;
1753 } else {
1754 write!(self.out, "{value}")?;
1755 }
1756 }
1757 crate::Literal::U64(value) => {
1758 write!(self.out, "{value}uL")?;
1759 }
1760 crate::Literal::I64(value) => {
1761 if value == i64::MIN {
1768 write!(self.out, "({}L - 1L)", value + 1)?;
1769 } else {
1770 write!(self.out, "{value}L")?;
1771 }
1772 }
1773 crate::Literal::Bool(value) => {
1774 write!(self.out, "{value}")?;
1775 }
1776 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1777 return Err(Error::GenericValidation(
1778 "Unsupported abstract literal".into(),
1779 ));
1780 }
1781 }
1782 Ok(())
1783 }
1784
1785 #[allow(clippy::too_many_arguments)]
1786 fn put_possibly_const_expression<C, I, E>(
1787 &mut self,
1788 expr_handle: Handle<crate::Expression>,
1789 expressions: &crate::Arena<crate::Expression>,
1790 module: &crate::Module,
1791 mod_info: &valid::ModuleInfo,
1792 ctx: &C,
1793 get_expr_ty: I,
1794 put_expression: E,
1795 ) -> BackendResult
1796 where
1797 I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1798 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1799 {
1800 match expressions[expr_handle] {
1801 crate::Expression::Literal(literal) => {
1802 self.put_literal(literal)?;
1803 }
1804 crate::Expression::Constant(handle) => {
1805 let constant = &module.constants[handle];
1806 if constant.name.is_some() {
1807 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1808 } else {
1809 self.put_const_expression(
1810 constant.init,
1811 module,
1812 mod_info,
1813 &module.global_expressions,
1814 )?;
1815 }
1816 }
1817 crate::Expression::ZeroValue(ty) => {
1818 let ty_name = TypeContext {
1819 handle: ty,
1820 gctx: module.to_ctx(),
1821 names: &self.names,
1822 access: crate::StorageAccess::empty(),
1823 first_time: false,
1824 };
1825 write!(self.out, "{ty_name} {{}}")?;
1826 }
1827 crate::Expression::Compose { ty, ref components } => {
1828 let ty_name = TypeContext {
1829 handle: ty,
1830 gctx: module.to_ctx(),
1831 names: &self.names,
1832 access: crate::StorageAccess::empty(),
1833 first_time: false,
1834 };
1835 write!(self.out, "{ty_name}")?;
1836 match module.types[ty].inner {
1837 crate::TypeInner::Scalar(_)
1838 | crate::TypeInner::Vector { .. }
1839 | crate::TypeInner::Matrix { .. } => {
1840 self.put_call_parameters_impl(
1841 components.iter().copied(),
1842 ctx,
1843 put_expression,
1844 )?;
1845 }
1846 crate::TypeInner::Array { .. } => {
1847 write!(self.out, " {{{{")?;
1850 for (index, &component) in components.iter().enumerate() {
1851 if index != 0 {
1852 write!(self.out, ", ")?;
1853 }
1854 put_expression(self, ctx, component)?;
1855 }
1856 write!(self.out, "}}}}")?;
1857 }
1858 crate::TypeInner::Struct { .. } => {
1859 write!(self.out, " {{")?;
1860 for (index, &component) in components.iter().enumerate() {
1861 if index != 0 {
1862 write!(self.out, ", ")?;
1863 }
1864 if self.struct_member_pads.contains(&(ty, index as u32)) {
1866 write!(self.out, "{{}}, ")?;
1867 }
1868 put_expression(self, ctx, component)?;
1869 }
1870 write!(self.out, "}}")?;
1871 }
1872 _ => return Err(Error::UnsupportedCompose(ty)),
1873 }
1874 }
1875 crate::Expression::Splat { size, value } => {
1876 let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1877 crate::TypeInner::Scalar(scalar) => scalar,
1878 ref ty => {
1879 return Err(Error::GenericValidation(format!(
1880 "Expected splat value type must be a scalar, got {ty:?}",
1881 )))
1882 }
1883 };
1884 put_numeric_type(&mut self.out, scalar, &[size])?;
1885 write!(self.out, "(")?;
1886 put_expression(self, ctx, value)?;
1887 write!(self.out, ")")?;
1888 }
1889 _ => {
1890 return Err(Error::Override);
1891 }
1892 }
1893
1894 Ok(())
1895 }
1896
1897 fn put_expression(
1909 &mut self,
1910 expr_handle: Handle<crate::Expression>,
1911 context: &ExpressionContext,
1912 is_scoped: bool,
1913 ) -> BackendResult {
1914 #[cfg(test)]
1916 self.put_expression_stack_pointers
1917 .insert(ptr::from_ref(&expr_handle).cast());
1918
1919 if let Some(name) = self.named_expressions.get(&expr_handle) {
1920 write!(self.out, "{name}")?;
1921 return Ok(());
1922 }
1923
1924 let expression = &context.function.expressions[expr_handle];
1925 match *expression {
1926 crate::Expression::Literal(_)
1927 | crate::Expression::Constant(_)
1928 | crate::Expression::ZeroValue(_)
1929 | crate::Expression::Compose { .. }
1930 | crate::Expression::Splat { .. } => {
1931 self.put_possibly_const_expression(
1932 expr_handle,
1933 &context.function.expressions,
1934 context.module,
1935 context.mod_info,
1936 context,
1937 |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1938 |writer, context, expr| writer.put_expression(expr, context, true),
1939 )?;
1940 }
1941 crate::Expression::Override(_) => return Err(Error::Override),
1942 crate::Expression::Access { base, .. }
1943 | crate::Expression::AccessIndex { base, .. } => {
1944 let policy = context.choose_bounds_check_policy(base);
1950 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1951 && self.put_bounds_checks(
1952 expr_handle,
1953 context,
1954 back::Level(0),
1955 if is_scoped { "" } else { "(" },
1956 )?
1957 {
1958 write!(self.out, " ? ")?;
1959 self.put_access_chain(expr_handle, policy, context)?;
1960 write!(self.out, " : ")?;
1961
1962 if context.resolve_type(base).pointer_space().is_some() {
1963 let result_ty = context.info[expr_handle]
1967 .ty
1968 .inner_with(&context.module.types)
1969 .pointer_base_type();
1970 let result_ty_handle = match result_ty {
1971 Some(TypeResolution::Handle(handle)) => handle,
1972 Some(TypeResolution::Value(_)) => {
1973 unreachable!(
1981 "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
1982 );
1983 }
1984 None => {
1985 unreachable!(
1986 "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
1987 )
1988 }
1989 };
1990 let name_key =
1991 NameKey::oob_local_for_type(context.origin, result_ty_handle);
1992 self.out.write_str(&self.names[&name_key])?;
1993 } else {
1994 write!(self.out, "DefaultConstructible()")?;
1995 }
1996
1997 if !is_scoped {
1998 write!(self.out, ")")?;
1999 }
2000 } else {
2001 self.put_access_chain(expr_handle, policy, context)?;
2002 }
2003 }
2004 crate::Expression::Swizzle {
2005 size,
2006 vector,
2007 pattern,
2008 } => {
2009 self.put_wrapped_expression_for_packed_vec3_access(
2010 vector,
2011 context,
2012 false,
2013 &Self::put_expression,
2014 )?;
2015 write!(self.out, ".")?;
2016 for &sc in pattern[..size as usize].iter() {
2017 write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2018 }
2019 }
2020 crate::Expression::FunctionArgument(index) => {
2021 let name_key = match context.origin {
2022 FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2023 FunctionOrigin::EntryPoint(ep_index) => {
2024 NameKey::EntryPointArgument(ep_index, index)
2025 }
2026 };
2027 let name = &self.names[&name_key];
2028 write!(self.out, "{name}")?;
2029 }
2030 crate::Expression::GlobalVariable(handle) => {
2031 let name = &self.names[&NameKey::GlobalVariable(handle)];
2032 write!(self.out, "{name}")?;
2033 }
2034 crate::Expression::LocalVariable(handle) => {
2035 let name_key = NameKey::local(context.origin, handle);
2036 let name = &self.names[&name_key];
2037 write!(self.out, "{name}")?;
2038 }
2039 crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2040 crate::Expression::ImageSample {
2041 coordinate,
2042 image,
2043 sampler,
2044 clamp_to_edge: true,
2045 gather: None,
2046 array_index: None,
2047 offset: None,
2048 level: crate::SampleLevel::Zero,
2049 depth_ref: None,
2050 } => {
2051 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2052 self.put_expression(image, context, true)?;
2053 write!(self.out, ", ")?;
2054 self.put_expression(sampler, context, true)?;
2055 write!(self.out, ", ")?;
2056 self.put_expression(coordinate, context, true)?;
2057 write!(self.out, ")")?;
2058 }
2059 crate::Expression::ImageSample {
2060 image,
2061 sampler,
2062 gather,
2063 coordinate,
2064 array_index,
2065 offset,
2066 level,
2067 depth_ref,
2068 clamp_to_edge,
2069 } => {
2070 if clamp_to_edge {
2071 return Err(Error::GenericValidation(
2072 "ImageSample::clamp_to_edge should have been validated out".to_string(),
2073 ));
2074 }
2075
2076 let main_op = match gather {
2077 Some(_) => "gather",
2078 None => "sample",
2079 };
2080 let comparison_op = match depth_ref {
2081 Some(_) => "_compare",
2082 None => "",
2083 };
2084 self.put_expression(image, context, false)?;
2085 write!(self.out, ".{main_op}{comparison_op}(")?;
2086 self.put_expression(sampler, context, true)?;
2087 write!(self.out, ", ")?;
2088 self.put_expression(coordinate, context, true)?;
2089 if let Some(expr) = array_index {
2090 write!(self.out, ", ")?;
2091 self.put_expression(expr, context, true)?;
2092 }
2093 if let Some(dref) = depth_ref {
2094 write!(self.out, ", ")?;
2095 self.put_expression(dref, context, true)?;
2096 }
2097
2098 self.put_image_sample_level(image, level, context)?;
2099
2100 if let Some(offset) = offset {
2101 write!(self.out, ", ")?;
2102 self.put_expression(offset, context, true)?;
2103 }
2104
2105 match gather {
2106 None | Some(crate::SwizzleComponent::X) => {}
2107 Some(component) => {
2108 let is_cube_map = match *context.resolve_type(image) {
2109 crate::TypeInner::Image {
2110 dim: crate::ImageDimension::Cube,
2111 ..
2112 } => true,
2113 _ => false,
2114 };
2115 if offset.is_none() && !is_cube_map {
2118 write!(self.out, ", {NAMESPACE}::int2(0)")?;
2119 }
2120 let letter = back::COMPONENTS[component as usize];
2121 write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2122 }
2123 }
2124 write!(self.out, ")")?;
2125 }
2126 crate::Expression::ImageLoad {
2127 image,
2128 coordinate,
2129 array_index,
2130 sample,
2131 level,
2132 } => {
2133 let address = TexelAddress {
2134 coordinate,
2135 array_index,
2136 sample,
2137 level: level.map(LevelOfDetail::Direct),
2138 };
2139 self.put_image_load(expr_handle, image, address, context)?;
2140 }
2141 crate::Expression::ImageQuery { image, query } => match query {
2144 crate::ImageQuery::Size { level } => {
2145 self.put_image_size_query(
2146 image,
2147 level.map(LevelOfDetail::Direct),
2148 crate::ScalarKind::Uint,
2149 context,
2150 )?;
2151 }
2152 crate::ImageQuery::NumLevels => {
2153 self.put_expression(image, context, false)?;
2154 write!(self.out, ".get_num_mip_levels()")?;
2155 }
2156 crate::ImageQuery::NumLayers => {
2157 self.put_expression(image, context, false)?;
2158 write!(self.out, ".get_array_size()")?;
2159 }
2160 crate::ImageQuery::NumSamples => {
2161 self.put_expression(image, context, false)?;
2162 write!(self.out, ".get_num_samples()")?;
2163 }
2164 },
2165 crate::Expression::Unary { op, expr } => {
2166 let op_str = match op {
2167 crate::UnaryOperator::Negate => {
2168 match context.resolve_type(expr).scalar_kind() {
2169 Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2170 _ => "-",
2171 }
2172 }
2173 crate::UnaryOperator::LogicalNot => "!",
2174 crate::UnaryOperator::BitwiseNot => "~",
2175 };
2176 write!(self.out, "{op_str}(")?;
2177 self.put_expression(expr, context, false)?;
2178 write!(self.out, ")")?;
2179 }
2180 crate::Expression::Binary { op, left, right } => {
2181 let kind = context
2182 .resolve_type(left)
2183 .scalar_kind()
2184 .ok_or(Error::UnsupportedBinaryOp(op))?;
2185
2186 if op == crate::BinaryOperator::Divide
2187 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2188 {
2189 write!(self.out, "{DIV_FUNCTION}(")?;
2190 self.put_expression(left, context, true)?;
2191 write!(self.out, ", ")?;
2192 self.put_expression(right, context, true)?;
2193 write!(self.out, ")")?;
2194 } else if op == crate::BinaryOperator::Modulo
2195 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2196 {
2197 write!(self.out, "{MOD_FUNCTION}(")?;
2198 self.put_expression(left, context, true)?;
2199 write!(self.out, ", ")?;
2200 self.put_expression(right, context, true)?;
2201 write!(self.out, ")")?;
2202 } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2203 write!(self.out, "{NAMESPACE}::fmod(")?;
2208 self.put_expression(left, context, true)?;
2209 write!(self.out, ", ")?;
2210 self.put_expression(right, context, true)?;
2211 write!(self.out, ")")?;
2212 } else if (op == crate::BinaryOperator::Add
2213 || op == crate::BinaryOperator::Subtract
2214 || op == crate::BinaryOperator::Multiply)
2215 && kind == crate::ScalarKind::Sint
2216 {
2217 let to_unsigned = |ty: &crate::TypeInner| match *ty {
2218 crate::TypeInner::Scalar(scalar) => {
2219 Ok(crate::TypeInner::Scalar(crate::Scalar {
2220 kind: crate::ScalarKind::Uint,
2221 ..scalar
2222 }))
2223 }
2224 crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2225 size,
2226 scalar: crate::Scalar {
2227 kind: crate::ScalarKind::Uint,
2228 ..scalar
2229 },
2230 }),
2231 _ => Err(Error::UnsupportedBitCast(ty.clone())),
2232 };
2233
2234 self.put_bitcasted_expression(
2239 context.resolve_type(expr_handle),
2240 context,
2241 &|writer, context, is_scoped| {
2242 writer.put_binop(
2243 op,
2244 left,
2245 right,
2246 context,
2247 is_scoped,
2248 &|writer, expr, context, _is_scoped| {
2249 writer.put_bitcasted_expression(
2250 &to_unsigned(context.resolve_type(expr))?,
2251 context,
2252 &|writer, context, is_scoped| {
2253 writer.put_expression(expr, context, is_scoped)
2254 },
2255 )
2256 },
2257 )
2258 },
2259 )?;
2260 } else {
2261 self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2262 }
2263 }
2264 crate::Expression::Select {
2265 condition,
2266 accept,
2267 reject,
2268 } => match *context.resolve_type(condition) {
2269 crate::TypeInner::Scalar(crate::Scalar {
2270 kind: crate::ScalarKind::Bool,
2271 ..
2272 }) => {
2273 if !is_scoped {
2274 write!(self.out, "(")?;
2275 }
2276 self.put_expression(condition, context, false)?;
2277 write!(self.out, " ? ")?;
2278 self.put_expression(accept, context, false)?;
2279 write!(self.out, " : ")?;
2280 self.put_expression(reject, context, false)?;
2281 if !is_scoped {
2282 write!(self.out, ")")?;
2283 }
2284 }
2285 crate::TypeInner::Vector {
2286 scalar:
2287 crate::Scalar {
2288 kind: crate::ScalarKind::Bool,
2289 ..
2290 },
2291 ..
2292 } => {
2293 write!(self.out, "{NAMESPACE}::select(")?;
2294 self.put_expression(reject, context, true)?;
2295 write!(self.out, ", ")?;
2296 self.put_expression(accept, context, true)?;
2297 write!(self.out, ", ")?;
2298 self.put_expression(condition, context, true)?;
2299 write!(self.out, ")")?;
2300 }
2301 ref ty => {
2302 return Err(Error::GenericValidation(format!(
2303 "Expected select condition to be a non-bool type, got {ty:?}",
2304 )))
2305 }
2306 },
2307 crate::Expression::Derivative { axis, expr, .. } => {
2308 use crate::DerivativeAxis as Axis;
2309 let op = match axis {
2310 Axis::X => "dfdx",
2311 Axis::Y => "dfdy",
2312 Axis::Width => "fwidth",
2313 };
2314 write!(self.out, "{NAMESPACE}::{op}")?;
2315 self.put_call_parameters(iter::once(expr), context)?;
2316 }
2317 crate::Expression::Relational { fun, argument } => {
2318 let op = match fun {
2319 crate::RelationalFunction::Any => "any",
2320 crate::RelationalFunction::All => "all",
2321 crate::RelationalFunction::IsNan => "isnan",
2322 crate::RelationalFunction::IsInf => "isinf",
2323 };
2324 write!(self.out, "{NAMESPACE}::{op}")?;
2325 self.put_call_parameters(iter::once(argument), context)?;
2326 }
2327 crate::Expression::Math {
2328 fun,
2329 arg,
2330 arg1,
2331 arg2,
2332 arg3,
2333 } => {
2334 use crate::MathFunction as Mf;
2335
2336 let arg_type = context.resolve_type(arg);
2337 let scalar_argument = match arg_type {
2338 &crate::TypeInner::Scalar(_) => true,
2339 _ => false,
2340 };
2341
2342 let fun_name = match fun {
2343 Mf::Abs => "abs",
2345 Mf::Min => "min",
2346 Mf::Max => "max",
2347 Mf::Clamp => "clamp",
2348 Mf::Saturate => "saturate",
2349 Mf::Cos => "cos",
2351 Mf::Cosh => "cosh",
2352 Mf::Sin => "sin",
2353 Mf::Sinh => "sinh",
2354 Mf::Tan => "tan",
2355 Mf::Tanh => "tanh",
2356 Mf::Acos => "acos",
2357 Mf::Asin => "asin",
2358 Mf::Atan => "atan",
2359 Mf::Atan2 => "atan2",
2360 Mf::Asinh => "asinh",
2361 Mf::Acosh => "acosh",
2362 Mf::Atanh => "atanh",
2363 Mf::Radians => "",
2364 Mf::Degrees => "",
2365 Mf::Ceil => "ceil",
2367 Mf::Floor => "floor",
2368 Mf::Round => "rint",
2369 Mf::Fract => "fract",
2370 Mf::Trunc => "trunc",
2371 Mf::Modf => MODF_FUNCTION,
2372 Mf::Frexp => FREXP_FUNCTION,
2373 Mf::Ldexp => "ldexp",
2374 Mf::Exp => "exp",
2376 Mf::Exp2 => "exp2",
2377 Mf::Log => "log",
2378 Mf::Log2 => "log2",
2379 Mf::Pow => "pow",
2380 Mf::Dot => match *context.resolve_type(arg) {
2382 crate::TypeInner::Vector {
2383 scalar:
2384 crate::Scalar {
2385 kind: crate::ScalarKind::Float,
2387 ..
2388 },
2389 ..
2390 } => "dot",
2391 crate::TypeInner::Vector {
2392 size,
2393 scalar:
2394 scalar @ crate::Scalar {
2395 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2396 ..
2397 },
2398 } => {
2399 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2401 write!(self.out, "{fun_name}(")?;
2402 self.put_expression(arg, context, true)?;
2403 write!(self.out, ", ")?;
2404 self.put_expression(arg1.unwrap(), context, true)?;
2405 write!(self.out, ")")?;
2406 return Ok(());
2407 }
2408 _ => unreachable!(
2409 "Correct TypeInner for dot product should be already validated"
2410 ),
2411 },
2412 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2413 if context.lang_version >= (2, 1) {
2414 let packed_type = match fun {
2418 Mf::Dot4I8Packed => "packed_char4",
2419 Mf::Dot4U8Packed => "packed_uchar4",
2420 _ => unreachable!(),
2421 };
2422
2423 return self.put_dot_product(
2424 Reinterpreted::new(packed_type, arg),
2425 Reinterpreted::new(packed_type, arg1.unwrap()),
2426 4,
2427 |writer, arg, index| {
2428 write!(writer.out, "{arg}[{index}]")?;
2431 Ok(())
2432 },
2433 );
2434 } else {
2435 let conversion = match fun {
2439 Mf::Dot4I8Packed => "int",
2440 Mf::Dot4U8Packed => "",
2441 _ => unreachable!(),
2442 };
2443
2444 return self.put_dot_product(
2445 arg,
2446 arg1.unwrap(),
2447 4,
2448 |writer, arg, index| {
2449 write!(writer.out, "({conversion}(")?;
2450 writer.put_expression(arg, context, true)?;
2451 if index == 3 {
2452 write!(writer.out, ") >> 24)")?;
2453 } else {
2454 write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2455 }
2456 Ok(())
2457 },
2458 );
2459 }
2460 }
2461 Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2462 Mf::Cross => "cross",
2463 Mf::Distance => "distance",
2464 Mf::Length if scalar_argument => "abs",
2465 Mf::Length => "length",
2466 Mf::Normalize => "normalize",
2467 Mf::FaceForward => "faceforward",
2468 Mf::Reflect => "reflect",
2469 Mf::Refract => "refract",
2470 Mf::Sign => match arg_type.scalar_kind() {
2472 Some(crate::ScalarKind::Sint) => {
2473 return self.put_isign(arg, context);
2474 }
2475 _ => "sign",
2476 },
2477 Mf::Fma => "fma",
2478 Mf::Mix => "mix",
2479 Mf::Step => "step",
2480 Mf::SmoothStep => "smoothstep",
2481 Mf::Sqrt => "sqrt",
2482 Mf::InverseSqrt => "rsqrt",
2483 Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2484 Mf::Transpose => "transpose",
2485 Mf::Determinant => "determinant",
2486 Mf::QuantizeToF16 => "",
2487 Mf::CountTrailingZeros => "ctz",
2489 Mf::CountLeadingZeros => "clz",
2490 Mf::CountOneBits => "popcount",
2491 Mf::ReverseBits => "reverse_bits",
2492 Mf::ExtractBits => "",
2493 Mf::InsertBits => "",
2494 Mf::FirstTrailingBit => "",
2495 Mf::FirstLeadingBit => "",
2496 Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2498 Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2499 Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2500 Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2501 Mf::Pack2x16float => "",
2502 Mf::Pack4xI8 => "",
2503 Mf::Pack4xU8 => "",
2504 Mf::Pack4xI8Clamp => "",
2505 Mf::Pack4xU8Clamp => "",
2506 Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2508 Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2509 Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2510 Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2511 Mf::Unpack2x16float => "",
2512 Mf::Unpack4xI8 => "",
2513 Mf::Unpack4xU8 => "",
2514 };
2515
2516 match fun {
2517 Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2518 if context.lang_version < (1, 2) {
2527 return Err(Error::UnsupportedFunction(fun_name.to_string()));
2528 }
2529 }
2530 _ => {}
2531 }
2532
2533 match fun {
2534 Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2535 write!(self.out, "{ABS_FUNCTION}(")?;
2536 self.put_expression(arg, context, true)?;
2537 write!(self.out, ")")?;
2538 }
2539 Mf::Distance if scalar_argument => {
2540 write!(self.out, "{NAMESPACE}::abs(")?;
2541 self.put_expression(arg, context, false)?;
2542 write!(self.out, " - ")?;
2543 self.put_expression(arg1.unwrap(), context, false)?;
2544 write!(self.out, ")")?;
2545 }
2546 Mf::FirstTrailingBit => {
2547 let scalar = context.resolve_type(arg).scalar().unwrap();
2548 let constant = scalar.width * 8 + 1;
2549
2550 write!(self.out, "((({NAMESPACE}::ctz(")?;
2551 self.put_expression(arg, context, true)?;
2552 write!(self.out, ") + 1) % {constant}) - 1)")?;
2553 }
2554 Mf::FirstLeadingBit => {
2555 let inner = context.resolve_type(arg);
2556 let scalar = inner.scalar().unwrap();
2557 let constant = scalar.width * 8 - 1;
2558
2559 write!(
2560 self.out,
2561 "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2562 )?;
2563
2564 if scalar.kind == crate::ScalarKind::Sint {
2565 write!(self.out, "{NAMESPACE}::select(")?;
2566 self.put_expression(arg, context, true)?;
2567 write!(self.out, ", ~")?;
2568 self.put_expression(arg, context, true)?;
2569 write!(self.out, ", ")?;
2570 self.put_expression(arg, context, true)?;
2571 write!(self.out, " < 0)")?;
2572 } else {
2573 self.put_expression(arg, context, true)?;
2574 }
2575
2576 write!(self.out, "), ")?;
2577
2578 match *inner {
2580 crate::TypeInner::Vector { size, scalar } => {
2581 let size = common::vector_size_str(size);
2582 let name = scalar.to_msl_name();
2583 write!(self.out, "{name}{size}")?;
2584 }
2585 crate::TypeInner::Scalar(scalar) => {
2586 let name = scalar.to_msl_name();
2587 write!(self.out, "{name}")?;
2588 }
2589 _ => (),
2590 }
2591
2592 write!(self.out, "(-1), ")?;
2593 self.put_expression(arg, context, true)?;
2594 write!(self.out, " == 0 || ")?;
2595 self.put_expression(arg, context, true)?;
2596 write!(self.out, " == -1)")?;
2597 }
2598 Mf::Unpack2x16float => {
2599 write!(self.out, "float2(as_type<half2>(")?;
2600 self.put_expression(arg, context, false)?;
2601 write!(self.out, "))")?;
2602 }
2603 Mf::Pack2x16float => {
2604 write!(self.out, "as_type<uint>(half2(")?;
2605 self.put_expression(arg, context, false)?;
2606 write!(self.out, "))")?;
2607 }
2608 Mf::ExtractBits => {
2609 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2626
2627 write!(self.out, "{NAMESPACE}::extract_bits(")?;
2628 self.put_expression(arg, context, true)?;
2629 write!(self.out, ", {NAMESPACE}::min(")?;
2630 self.put_expression(arg1.unwrap(), context, true)?;
2631 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2632 self.put_expression(arg2.unwrap(), context, true)?;
2633 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2634 self.put_expression(arg1.unwrap(), context, true)?;
2635 write!(self.out, ", {scalar_bits}u)))")?;
2636 }
2637 Mf::InsertBits => {
2638 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2643
2644 write!(self.out, "{NAMESPACE}::insert_bits(")?;
2645 self.put_expression(arg, context, true)?;
2646 write!(self.out, ", ")?;
2647 self.put_expression(arg1.unwrap(), context, true)?;
2648 write!(self.out, ", {NAMESPACE}::min(")?;
2649 self.put_expression(arg2.unwrap(), context, true)?;
2650 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2651 self.put_expression(arg3.unwrap(), context, true)?;
2652 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2653 self.put_expression(arg2.unwrap(), context, true)?;
2654 write!(self.out, ", {scalar_bits}u)))")?;
2655 }
2656 Mf::Radians => {
2657 write!(self.out, "((")?;
2658 self.put_expression(arg, context, false)?;
2659 write!(self.out, ") * 0.017453292519943295474)")?;
2660 }
2661 Mf::Degrees => {
2662 write!(self.out, "((")?;
2663 self.put_expression(arg, context, false)?;
2664 write!(self.out, ") * 57.295779513082322865)")?;
2665 }
2666 Mf::Modf | Mf::Frexp => {
2667 write!(self.out, "{fun_name}")?;
2668 self.put_call_parameters(iter::once(arg), context)?;
2669 }
2670 Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2671 Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2672 Mf::Pack4xI8Clamp => {
2673 self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2674 }
2675 Mf::Pack4xU8Clamp => {
2676 self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2677 }
2678 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2679 let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2680 "u"
2681 } else {
2682 ""
2683 };
2684
2685 if context.lang_version >= (2, 1) {
2686 write!(
2688 self.out,
2689 "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2690 )?;
2691 self.put_expression(arg, context, true)?;
2692 write!(self.out, "))")?;
2693 } else {
2694 write!(self.out, "({sign_prefix}int4(")?;
2696 self.put_expression(arg, context, true)?;
2697 write!(self.out, ", ")?;
2698 self.put_expression(arg, context, true)?;
2699 write!(self.out, " >> 8, ")?;
2700 self.put_expression(arg, context, true)?;
2701 write!(self.out, " >> 16, ")?;
2702 self.put_expression(arg, context, true)?;
2703 write!(self.out, " >> 24) << 24 >> 24)")?;
2704 }
2705 }
2706 Mf::QuantizeToF16 => {
2707 match *context.resolve_type(arg) {
2708 crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2709 crate::TypeInner::Vector { size, .. } => write!(
2710 self.out,
2711 "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2712 size = common::vector_size_str(size),
2713 )?,
2714 _ => unreachable!(
2715 "Correct TypeInner for QuantizeToF16 should be already validated"
2716 ),
2717 };
2718
2719 self.put_expression(arg, context, true)?;
2720 write!(self.out, "))")?;
2721 }
2722 _ => {
2723 write!(self.out, "{NAMESPACE}::{fun_name}")?;
2724 self.put_call_parameters(
2725 iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2726 context,
2727 )?;
2728 }
2729 }
2730 }
2731 crate::Expression::As {
2732 expr,
2733 kind,
2734 convert,
2735 } => match *context.resolve_type(expr) {
2736 crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2737 if src.kind == crate::ScalarKind::Float
2738 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2739 && convert.is_some()
2740 {
2741 let fun_name = match (kind, convert) {
2745 (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2746 (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2747 (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2748 (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2749 _ => unreachable!(),
2750 };
2751 write!(self.out, "{fun_name}(")?;
2752 self.put_expression(expr, context, true)?;
2753 write!(self.out, ")")?;
2754 } else {
2755 let target_scalar = crate::Scalar {
2756 kind,
2757 width: convert.unwrap_or(src.width),
2758 };
2759 let op = match convert {
2760 Some(_) => "static_cast",
2761 None => "as_type",
2762 };
2763 write!(self.out, "{op}<")?;
2764 match *context.resolve_type(expr) {
2765 crate::TypeInner::Vector { size, .. } => {
2766 put_numeric_type(&mut self.out, target_scalar, &[size])?
2767 }
2768 _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2769 };
2770 write!(self.out, ">(")?;
2771 self.put_expression(expr, context, true)?;
2772 write!(self.out, ")")?;
2773 }
2774 }
2775 crate::TypeInner::Matrix {
2776 columns,
2777 rows,
2778 scalar,
2779 } => {
2780 let target_scalar = crate::Scalar {
2781 kind,
2782 width: convert.unwrap_or(scalar.width),
2783 };
2784 put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2785 write!(self.out, "(")?;
2786 self.put_expression(expr, context, true)?;
2787 write!(self.out, ")")?;
2788 }
2789 ref ty => {
2790 return Err(Error::GenericValidation(format!(
2791 "Unsupported type for As: {ty:?}"
2792 )))
2793 }
2794 },
2795 crate::Expression::CallResult(_)
2797 | crate::Expression::AtomicResult { .. }
2798 | crate::Expression::WorkGroupUniformLoadResult { .. }
2799 | crate::Expression::SubgroupBallotResult
2800 | crate::Expression::SubgroupOperationResult { .. }
2801 | crate::Expression::RayQueryProceedResult => {
2802 unreachable!()
2803 }
2804 crate::Expression::ArrayLength(expr) => {
2805 let global = match context.function.expressions[expr] {
2807 crate::Expression::AccessIndex { base, .. } => {
2808 match context.function.expressions[base] {
2809 crate::Expression::GlobalVariable(handle) => handle,
2810 ref ex => {
2811 return Err(Error::GenericValidation(format!(
2812 "Expected global variable in AccessIndex, got {ex:?}"
2813 )))
2814 }
2815 }
2816 }
2817 crate::Expression::GlobalVariable(handle) => handle,
2818 ref ex => {
2819 return Err(Error::GenericValidation(format!(
2820 "Unexpected expression in ArrayLength, got {ex:?}"
2821 )))
2822 }
2823 };
2824
2825 if !is_scoped {
2826 write!(self.out, "(")?;
2827 }
2828 write!(self.out, "1 + ")?;
2829 self.put_dynamic_array_max_index(global, context)?;
2830 if !is_scoped {
2831 write!(self.out, ")")?;
2832 }
2833 }
2834 crate::Expression::RayQueryVertexPositions { .. } => {
2835 unimplemented!()
2836 }
2837 crate::Expression::RayQueryGetIntersection {
2838 query,
2839 committed: _,
2840 } => {
2841 if context.lang_version < (2, 4) {
2842 return Err(Error::UnsupportedRayTracing);
2843 }
2844
2845 let ty = context.module.special_types.ray_intersection.unwrap();
2846 let type_name = &self.names[&NameKey::Type(ty)];
2847 write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2848 self.put_expression(query, context, true)?;
2849 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2850 let fields = [
2851 "distance",
2852 "user_instance_id", "instance_id",
2854 "", "geometry_id",
2856 "primitive_id",
2857 "triangle_barycentric_coord",
2858 "triangle_front_facing",
2859 "", "object_to_world_transform", "world_to_object_transform", ];
2863 for field in fields {
2864 write!(self.out, ", ")?;
2865 if field.is_empty() {
2866 write!(self.out, "{{}}")?;
2867 } else {
2868 self.put_expression(query, context, true)?;
2869 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2870 }
2871 }
2872 write!(self.out, "}}")?;
2873 }
2874 crate::Expression::CooperativeLoad { ref data, .. } => {
2875 if context.lang_version < (2, 3) {
2876 return Err(Error::UnsupportedCooperativeMatrix);
2877 }
2878 write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2879 write!(self.out, "&")?;
2880 self.put_access_chain(data.pointer, context.policies.index, context)?;
2881 write!(self.out, ", ")?;
2882 self.put_expression(data.stride, context, true)?;
2883 write!(self.out, ", {})", data.row_major)?;
2884 }
2885 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2886 if context.lang_version < (2, 3) {
2887 return Err(Error::UnsupportedCooperativeMatrix);
2888 }
2889 write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2890 self.put_expression(a, context, true)?;
2891 write!(self.out, ", ")?;
2892 self.put_expression(b, context, true)?;
2893 write!(self.out, ", ")?;
2894 self.put_expression(c, context, true)?;
2895 write!(self.out, ")")?;
2896 }
2897 }
2898 Ok(())
2899 }
2900
2901 fn put_binop<F>(
2904 &mut self,
2905 op: crate::BinaryOperator,
2906 left: Handle<crate::Expression>,
2907 right: Handle<crate::Expression>,
2908 context: &ExpressionContext,
2909 is_scoped: bool,
2910 put_expression: &F,
2911 ) -> BackendResult
2912 where
2913 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2914 {
2915 let op_str = back::binary_operation_str(op);
2916
2917 if !is_scoped {
2918 write!(self.out, "(")?;
2919 }
2920
2921 if op == crate::BinaryOperator::Multiply
2924 && matches!(
2925 context.resolve_type(right),
2926 &crate::TypeInner::Matrix { .. }
2927 )
2928 {
2929 self.put_wrapped_expression_for_packed_vec3_access(
2930 left,
2931 context,
2932 false,
2933 put_expression,
2934 )?;
2935 } else {
2936 put_expression(self, left, context, false)?;
2937 }
2938
2939 write!(self.out, " {op_str} ")?;
2940
2941 if op == crate::BinaryOperator::Multiply
2943 && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
2944 {
2945 self.put_wrapped_expression_for_packed_vec3_access(
2946 right,
2947 context,
2948 false,
2949 put_expression,
2950 )?;
2951 } else {
2952 put_expression(self, right, context, false)?;
2953 }
2954
2955 if !is_scoped {
2956 write!(self.out, ")")?;
2957 }
2958
2959 Ok(())
2960 }
2961
2962 fn put_wrapped_expression_for_packed_vec3_access<F>(
2964 &mut self,
2965 expr_handle: Handle<crate::Expression>,
2966 context: &ExpressionContext,
2967 is_scoped: bool,
2968 put_expression: &F,
2969 ) -> BackendResult
2970 where
2971 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2972 {
2973 if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
2974 write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
2975 put_expression(self, expr_handle, context, is_scoped)?;
2976 write!(self.out, ")")?;
2977 } else {
2978 put_expression(self, expr_handle, context, is_scoped)?;
2979 }
2980 Ok(())
2981 }
2982
2983 fn put_bitcasted_expression<F>(
2986 &mut self,
2987 cast_to: &crate::TypeInner,
2988 context: &ExpressionContext,
2989 put_expression: &F,
2990 ) -> BackendResult
2991 where
2992 F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
2993 {
2994 write!(self.out, "as_type<")?;
2995 match *cast_to {
2996 crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
2997 crate::TypeInner::Vector { size, scalar } => {
2998 put_numeric_type(&mut self.out, scalar, &[size])?
2999 }
3000 _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
3001 };
3002 write!(self.out, ">(")?;
3003 put_expression(self, context, true)?;
3004 write!(self.out, ")")?;
3005
3006 Ok(())
3007 }
3008
3009 fn put_index(
3011 &mut self,
3012 index: index::GuardedIndex,
3013 context: &ExpressionContext,
3014 is_scoped: bool,
3015 ) -> BackendResult {
3016 match index {
3017 index::GuardedIndex::Expression(expr) => {
3018 self.put_expression(expr, context, is_scoped)?
3019 }
3020 index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3021 }
3022 Ok(())
3023 }
3024
3025 fn put_bounds_checks(
3055 &mut self,
3056 chain: Handle<crate::Expression>,
3057 context: &ExpressionContext,
3058 level: back::Level,
3059 prefix: &'static str,
3060 ) -> Result<bool, Error> {
3061 let mut check_written = false;
3062
3063 for item in context.bounds_check_iter(chain) {
3065 let BoundsCheck {
3066 base,
3067 index,
3068 length,
3069 } = item;
3070
3071 if check_written {
3072 write!(self.out, " && ")?;
3073 } else {
3074 write!(self.out, "{level}{prefix}")?;
3075 check_written = true;
3076 }
3077
3078 write!(self.out, "uint(")?;
3082 self.put_index(index, context, true)?;
3083 self.out.write_str(") < ")?;
3084 match length {
3085 index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3086 index::IndexableLength::Dynamic => {
3087 let global = context.function.originating_global(base).ok_or_else(|| {
3088 Error::GenericValidation("Could not find originating global".into())
3089 })?;
3090 write!(self.out, "1 + ")?;
3091 self.put_dynamic_array_max_index(global, context)?
3092 }
3093 }
3094 }
3095
3096 Ok(check_written)
3097 }
3098
3099 fn put_access_chain(
3119 &mut self,
3120 chain: Handle<crate::Expression>,
3121 policy: index::BoundsCheckPolicy,
3122 context: &ExpressionContext,
3123 ) -> BackendResult {
3124 match context.function.expressions[chain] {
3125 crate::Expression::Access { base, index } => {
3126 let mut base_ty = context.resolve_type(base);
3127
3128 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3130 base_ty = &context.module.types[base].inner;
3131 }
3132
3133 self.put_subscripted_access_chain(
3134 base,
3135 base_ty,
3136 index::GuardedIndex::Expression(index),
3137 policy,
3138 context,
3139 )?;
3140 }
3141 crate::Expression::AccessIndex { base, index } => {
3142 let base_resolution = &context.info[base].ty;
3143 let mut base_ty = base_resolution.inner_with(&context.module.types);
3144 let mut base_ty_handle = base_resolution.handle();
3145
3146 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3148 base_ty = &context.module.types[base].inner;
3149 base_ty_handle = Some(base);
3150 }
3151
3152 match *base_ty {
3156 crate::TypeInner::Struct { .. } => {
3157 let base_ty = base_ty_handle.unwrap();
3158 self.put_access_chain(base, policy, context)?;
3159 let name = &self.names[&NameKey::StructMember(base_ty, index)];
3160 write!(self.out, ".{name}")?;
3161 }
3162 crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3163 self.put_access_chain(base, policy, context)?;
3164 if context.get_packed_vec_kind(base).is_some() {
3167 write!(self.out, "[{index}]")?;
3168 } else {
3169 write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3170 }
3171 }
3172 _ => {
3173 self.put_subscripted_access_chain(
3174 base,
3175 base_ty,
3176 index::GuardedIndex::Known(index),
3177 policy,
3178 context,
3179 )?;
3180 }
3181 }
3182 }
3183 _ => self.put_expression(chain, context, false)?,
3184 }
3185
3186 Ok(())
3187 }
3188
3189 fn put_subscripted_access_chain(
3206 &mut self,
3207 base: Handle<crate::Expression>,
3208 base_ty: &crate::TypeInner,
3209 index: index::GuardedIndex,
3210 policy: index::BoundsCheckPolicy,
3211 context: &ExpressionContext,
3212 ) -> BackendResult {
3213 let accessing_wrapped_array = match *base_ty {
3214 crate::TypeInner::Array {
3215 size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3216 ..
3217 } => true,
3218 _ => false,
3219 };
3220 let accessing_wrapped_binding_array =
3221 matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3222
3223 self.put_access_chain(base, policy, context)?;
3224 if accessing_wrapped_array {
3225 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3226 }
3227 write!(self.out, "[")?;
3228
3229 let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3231 context.access_needs_check(base, index)
3232 } else {
3233 None
3234 };
3235 if let Some(limit) = restriction_needed {
3236 write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3237 self.put_index(index, context, true)?;
3238 write!(self.out, "), ")?;
3239 match limit {
3240 index::IndexableLength::Known(limit) => {
3241 write!(self.out, "{}u", limit - 1)?;
3242 }
3243 index::IndexableLength::Dynamic => {
3244 let global = context.function.originating_global(base).ok_or_else(|| {
3245 Error::GenericValidation("Could not find originating global".into())
3246 })?;
3247 self.put_dynamic_array_max_index(global, context)?;
3248 }
3249 }
3250 write!(self.out, ")")?;
3251 } else {
3252 self.put_index(index, context, true)?;
3253 }
3254
3255 write!(self.out, "]")?;
3256
3257 if accessing_wrapped_binding_array {
3258 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3259 }
3260
3261 Ok(())
3262 }
3263
3264 fn put_load(
3265 &mut self,
3266 pointer: Handle<crate::Expression>,
3267 context: &ExpressionContext,
3268 is_scoped: bool,
3269 ) -> BackendResult {
3270 let policy = context.choose_bounds_check_policy(pointer);
3273 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3274 && self.put_bounds_checks(
3275 pointer,
3276 context,
3277 back::Level(0),
3278 if is_scoped { "" } else { "(" },
3279 )?
3280 {
3281 write!(self.out, " ? ")?;
3282 self.put_unchecked_load(pointer, policy, context)?;
3283 write!(self.out, " : DefaultConstructible()")?;
3284
3285 if !is_scoped {
3286 write!(self.out, ")")?;
3287 }
3288 } else {
3289 self.put_unchecked_load(pointer, policy, context)?;
3290 }
3291
3292 Ok(())
3293 }
3294
3295 fn put_unchecked_load(
3296 &mut self,
3297 pointer: Handle<crate::Expression>,
3298 policy: index::BoundsCheckPolicy,
3299 context: &ExpressionContext,
3300 ) -> BackendResult {
3301 let is_atomic_pointer = context
3302 .resolve_type(pointer)
3303 .is_atomic_pointer(&context.module.types);
3304
3305 if is_atomic_pointer {
3306 write!(
3307 self.out,
3308 "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3309 )?;
3310 self.put_access_chain(pointer, policy, context)?;
3311 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3312 } else {
3313 self.put_access_chain(pointer, policy, context)?;
3317 }
3318
3319 Ok(())
3320 }
3321
    /// Writes a `return` statement (terminated by `;\n`) for `expr_handle`.
    ///
    /// When `result_struct` is `None`, this is simply `return <expr>;`.
    ///
    /// When `result_struct` names a wrapper struct, the returned value is
    /// repacked into a `<struct_name> { ... }` initializer:
    /// - If the function's result type is a struct, the value is first stored
    ///   in a `_tmp` local and each member is then copied into the initializer.
    ///   Constant-size array members are expanded element by element from their
    ///   `WRAPPED_ARRAY_FIELD` field. `PointSize` built-in members are skipped
    ///   unless `allow_and_force_point_size` is set.
    /// - Otherwise, the value itself is the first initializer.
    ///
    /// For vertex entry points with `allow_and_force_point_size` set and no
    /// `PointSize` member in the result, a trailing `1.0` initializer is
    /// appended.
    fn put_return_value(
        &mut self,
        level: back::Level,
        expr_handle: Handle<crate::Expression>,
        result_struct: Option<&str>,
        context: &ExpressionContext,
    ) -> BackendResult {
        match result_struct {
            Some(struct_name) => {
                // Set when the result struct already carries a `PointSize` built-in.
                let mut has_point_size = false;
                let result_ty = context.function.result.as_ref().unwrap().ty;
                match context.module.types[result_ty].inner {
                    crate::TypeInner::Struct { ref members, .. } => {
                        // Evaluate the result once; every member is read from `_tmp`.
                        let tmp = "_tmp";
                        write!(self.out, "{level}const auto {tmp} = ")?;
                        self.put_expression(expr_handle, context, true)?;
                        writeln!(self.out, ";")?;
                        write!(self.out, "{level}return {struct_name} {{")?;

                        let mut is_first = true;

                        for (index, member) in members.iter().enumerate() {
                            if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
                                member.binding
                            {
                                has_point_size = true;
                                // Without `allow_and_force_point_size`, the point
                                // size member is left out of the initializer.
                                if !context.pipeline_options.allow_and_force_point_size {
                                    continue;
                                }
                            }

                            let comma = if is_first { "" } else { "," };
                            is_first = false;
                            let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
                            // Constant-size arrays are stored in a wrapper with a
                            // `WRAPPED_ARRAY_FIELD` member, so they are unpacked
                            // element by element into a brace initializer.
                            // NOTE(review): presumably the output struct declares a
                            // plain array here; confirm against its declaration.
                            if let crate::TypeInner::Array {
                                size: crate::ArraySize::Constant(size),
                                ..
                            } = context.module.types[member.ty].inner
                            {
                                write!(self.out, "{comma} {{")?;
                                for j in 0..size.get() {
                                    if j != 0 {
                                        write!(self.out, ",")?;
                                    }
                                    write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
                                }
                                write!(self.out, "}}")?;
                            } else {
                                write!(self.out, "{comma} {tmp}.{name}")?;
                            }
                        }
                    }
                    _ => {
                        // Non-struct result: the value is the first initializer.
                        write!(self.out, "{level}return {struct_name} {{ ")?;
                        self.put_expression(expr_handle, context, true)?;
                    }
                }

                if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
                    let stage = context.module.entry_points[ep_index as usize].stage;
                    if context.pipeline_options.allow_and_force_point_size
                        && stage == crate::ShaderStage::Vertex
                        && !has_point_size
                    {
                        // Forced point size goes last in the initializer.
                        // NOTE(review): assumes the wrapper struct declares the
                        // injected point-size member last; confirm.
                        write!(self.out, ", 1.0")?;
                    }
                }
                write!(self.out, " }}")?;
            }
            None => {
                write!(self.out, "{level}return ")?;
                self.put_expression(expr_handle, context, true)?;
            }
        }
        writeln!(self.out, ";")?;
        Ok(())
    }
3403
3404 fn update_expressions_to_bake(
3409 &mut self,
3410 func: &crate::Function,
3411 info: &valid::FunctionInfo,
3412 context: &ExpressionContext,
3413 ) {
3414 use crate::Expression;
3415 self.need_bake_expressions.clear();
3416
3417 for (expr_handle, expr) in func.expressions.iter() {
3418 let expr_info = &info[expr_handle];
3421 let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3422 if min_ref_count <= expr_info.ref_count {
3423 self.need_bake_expressions.insert(expr_handle);
3424 } else {
3425 match expr_info.ty {
3426 TypeResolution::Handle(h)
3428 if Some(h) == context.module.special_types.ray_desc =>
3429 {
3430 self.need_bake_expressions.insert(expr_handle);
3431 }
3432 _ => {}
3433 }
3434 }
3435
3436 if let Expression::Math {
3437 fun,
3438 arg,
3439 arg1,
3440 arg2,
3441 ..
3442 } = *expr
3443 {
3444 match fun {
3445 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3455 self.need_bake_expressions.insert(arg);
3456 self.need_bake_expressions.insert(arg1.unwrap());
3457 }
3458 crate::MathFunction::FirstLeadingBit => {
3459 self.need_bake_expressions.insert(arg);
3460 }
3461 crate::MathFunction::Pack4xI8
3462 | crate::MathFunction::Pack4xU8
3463 | crate::MathFunction::Pack4xI8Clamp
3464 | crate::MathFunction::Pack4xU8Clamp
3465 | crate::MathFunction::Unpack4xI8
3466 | crate::MathFunction::Unpack4xU8 => {
3467 if context.lang_version < (2, 1) {
3470 self.need_bake_expressions.insert(arg);
3471 }
3472 }
3473 crate::MathFunction::ExtractBits => {
3474 self.need_bake_expressions.insert(arg1.unwrap());
3476 }
3477 crate::MathFunction::InsertBits => {
3478 self.need_bake_expressions.insert(arg2.unwrap());
3480 }
3481 crate::MathFunction::Sign => {
3482 let inner = context.resolve_type(expr_handle);
3487 if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3488 self.need_bake_expressions.insert(arg);
3489 }
3490 }
3491 _ => {}
3492 }
3493 }
3494 }
3495 }
3496
3497 fn start_baking_expression(
3498 &mut self,
3499 handle: Handle<crate::Expression>,
3500 context: &ExpressionContext,
3501 name: &str,
3502 ) -> BackendResult {
3503 match context.info[handle].ty {
3504 TypeResolution::Handle(ty_handle) => {
3505 let ty_name = TypeContext {
3506 handle: ty_handle,
3507 gctx: context.module.to_ctx(),
3508 names: &self.names,
3509 access: crate::StorageAccess::empty(),
3510 first_time: false,
3511 };
3512 write!(self.out, "{ty_name}")?;
3513 }
3514 TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3515 put_numeric_type(&mut self.out, scalar, &[])?;
3516 }
3517 TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3518 put_numeric_type(&mut self.out, scalar, &[size])?;
3519 }
3520 TypeResolution::Value(crate::TypeInner::Matrix {
3521 columns,
3522 rows,
3523 scalar,
3524 }) => {
3525 put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3526 }
3527 TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3528 columns,
3529 rows,
3530 scalar,
3531 role: _,
3532 }) => {
3533 write!(
3534 self.out,
3535 "{}::simdgroup_{}{}x{}",
3536 NAMESPACE,
3537 scalar.to_msl_name(),
3538 columns as u32,
3539 rows as u32,
3540 )?;
3541 }
3542 TypeResolution::Value(ref other) => {
3543 log::warn!("Type {other:?} isn't a known local");
3544 return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3545 }
3546 }
3547
3548 write!(self.out, " {name} = ")?;
3550
3551 Ok(())
3552 }
3553
3554 fn put_cache_restricted_level(
3567 &mut self,
3568 load: Handle<crate::Expression>,
3569 image: Handle<crate::Expression>,
3570 mip_level: Option<Handle<crate::Expression>>,
3571 indent: back::Level,
3572 context: &StatementContext,
3573 ) -> BackendResult {
3574 let level_of_detail = match mip_level {
3577 Some(level) => level,
3578 None => return Ok(()),
3579 };
3580
3581 if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3582 || !context.expression.image_needs_lod(image)
3583 {
3584 return Ok(());
3585 }
3586
3587 write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3588 self.put_restricted_scalar_image_index(
3589 image,
3590 level_of_detail,
3591 "get_num_mip_levels",
3592 &context.expression,
3593 )?;
3594 writeln!(self.out, ";")?;
3595
3596 Ok(())
3597 }
3598
3599 fn put_casting_to_packed_chars(
3605 &mut self,
3606 fun: crate::MathFunction,
3607 arg0: Handle<crate::Expression>,
3608 arg1: Handle<crate::Expression>,
3609 indent: back::Level,
3610 context: &StatementContext<'_>,
3611 ) -> Result<(), Error> {
3612 let packed_type = match fun {
3613 crate::MathFunction::Dot4I8Packed => "packed_char4",
3614 crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3615 _ => unreachable!(),
3616 };
3617
3618 for arg in [arg0, arg1] {
3619 write!(
3620 self.out,
3621 "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3622 Reinterpreted::new(packed_type, arg)
3623 )?;
3624 self.put_expression(arg, &context.expression, true)?;
3625 writeln!(self.out, ");")?;
3626 }
3627
3628 Ok(())
3629 }
3630
3631 fn put_block(
3632 &mut self,
3633 level: back::Level,
3634 statements: &[crate::Statement],
3635 context: &StatementContext,
3636 ) -> BackendResult {
3637 #[cfg(test)]
3639 self.put_block_stack_pointers
3640 .insert(ptr::from_ref(&level).cast());
3641
3642 for statement in statements {
3643 log::trace!("statement[{}] {:?}", level.0, statement);
3644 match *statement {
3645 crate::Statement::Emit(ref range) => {
3646 for handle in range.clone() {
3647 use crate::MathFunction as Mf;
3648
3649 match context.expression.function.expressions[handle] {
3650 crate::Expression::ImageLoad {
3653 image,
3654 level: mip_level,
3655 ..
3656 } => {
3657 self.put_cache_restricted_level(
3658 handle, image, mip_level, level, context,
3659 )?;
3660 }
3661
3662 crate::Expression::Math {
3671 fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3672 arg,
3673 arg1,
3674 ..
3675 } if context.expression.lang_version >= (2, 1) => {
3676 self.put_casting_to_packed_chars(
3677 fun,
3678 arg,
3679 arg1.unwrap(),
3680 level,
3681 context,
3682 )?;
3683 }
3684
3685 _ => (),
3686 }
3687
3688 let ptr_class = context.expression.resolve_type(handle).pointer_space();
3689 let expr_name = if ptr_class.is_some() {
3690 None } else if let Some(name) =
3692 context.expression.function.named_expressions.get(&handle)
3693 {
3694 Some(self.namer.call(name))
3704 } else {
3705 let bake = if context.expression.guarded_indices.contains(handle) {
3709 true
3710 } else {
3711 self.need_bake_expressions.contains(&handle)
3712 };
3713
3714 if bake {
3715 Some(Baked(handle).to_string())
3716 } else {
3717 None
3718 }
3719 };
3720
3721 if let Some(name) = expr_name {
3722 write!(self.out, "{level}")?;
3723 self.start_baking_expression(handle, &context.expression, &name)?;
3724 self.put_expression(handle, &context.expression, true)?;
3725 self.named_expressions.insert(handle, name);
3726 writeln!(self.out, ";")?;
3727 }
3728 }
3729 }
3730 crate::Statement::Block(ref block) => {
3731 if !block.is_empty() {
3732 writeln!(self.out, "{level}{{")?;
3733 self.put_block(level.next(), block, context)?;
3734 writeln!(self.out, "{level}}}")?;
3735 }
3736 }
3737 crate::Statement::If {
3738 condition,
3739 ref accept,
3740 ref reject,
3741 } => {
3742 write!(self.out, "{level}if (")?;
3743 self.put_expression(condition, &context.expression, true)?;
3744 writeln!(self.out, ") {{")?;
3745 self.put_block(level.next(), accept, context)?;
3746 if !reject.is_empty() {
3747 writeln!(self.out, "{level}}} else {{")?;
3748 self.put_block(level.next(), reject, context)?;
3749 }
3750 writeln!(self.out, "{level}}}")?;
3751 }
3752 crate::Statement::Switch {
3753 selector,
3754 ref cases,
3755 } => {
3756 write!(self.out, "{level}switch(")?;
3757 self.put_expression(selector, &context.expression, true)?;
3758 writeln!(self.out, ") {{")?;
3759 let lcase = level.next();
3760 for case in cases.iter() {
3761 match case.value {
3762 crate::SwitchValue::I32(value) => {
3763 write!(self.out, "{lcase}case {value}:")?;
3764 }
3765 crate::SwitchValue::U32(value) => {
3766 write!(self.out, "{lcase}case {value}u:")?;
3767 }
3768 crate::SwitchValue::Default => {
3769 write!(self.out, "{lcase}default:")?;
3770 }
3771 }
3772
3773 let write_block_braces = !(case.fall_through && case.body.is_empty());
3774 if write_block_braces {
3775 writeln!(self.out, " {{")?;
3776 } else {
3777 writeln!(self.out)?;
3778 }
3779
3780 self.put_block(lcase.next(), &case.body, context)?;
3781 if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3782 {
3783 writeln!(self.out, "{}break;", lcase.next())?;
3784 }
3785
3786 if write_block_braces {
3787 writeln!(self.out, "{lcase}}}")?;
3788 }
3789 }
3790 writeln!(self.out, "{level}}}")?;
3791 }
3792 crate::Statement::Loop {
3793 ref body,
3794 ref continuing,
3795 break_if,
3796 } => {
3797 let force_loop_bound_statements =
3798 self.gen_force_bounded_loop_statements(level, context);
3799 let gate_name = (!continuing.is_empty() || break_if.is_some())
3800 .then(|| self.namer.call("loop_init"));
3801
3802 if let Some((ref decl, _)) = force_loop_bound_statements {
3803 writeln!(self.out, "{decl}")?;
3804 }
3805 if let Some(ref gate_name) = gate_name {
3806 writeln!(self.out, "{level}bool {gate_name} = true;")?;
3807 }
3808
3809 writeln!(self.out, "{level}while(true) {{",)?;
3810 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3811 writeln!(self.out, "{break_and_inc}")?;
3812 }
3813 if let Some(ref gate_name) = gate_name {
3814 let lif = level.next();
3815 let lcontinuing = lif.next();
3816 writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3817 self.put_block(lcontinuing, continuing, context)?;
3818 if let Some(condition) = break_if {
3819 write!(self.out, "{lcontinuing}if (")?;
3820 self.put_expression(condition, &context.expression, true)?;
3821 writeln!(self.out, ") {{")?;
3822 writeln!(self.out, "{}break;", lcontinuing.next())?;
3823 writeln!(self.out, "{lcontinuing}}}")?;
3824 }
3825 writeln!(self.out, "{lif}}}")?;
3826 writeln!(self.out, "{lif}{gate_name} = false;")?;
3827 }
3828 self.put_block(level.next(), body, context)?;
3829
3830 writeln!(self.out, "{level}}}")?;
3831 }
3832 crate::Statement::Break => {
3833 writeln!(self.out, "{level}break;")?;
3834 }
3835 crate::Statement::Continue => {
3836 writeln!(self.out, "{level}continue;")?;
3837 }
3838 crate::Statement::Return {
3839 value: Some(expr_handle),
3840 } => {
3841 self.put_return_value(
3842 level,
3843 expr_handle,
3844 context.result_struct,
3845 &context.expression,
3846 )?;
3847 }
3848 crate::Statement::Return { value: None } => {
3849 writeln!(self.out, "{level}return;")?;
3850 }
3851 crate::Statement::Kill => {
3852 writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3853 }
3854 crate::Statement::ControlBarrier(flags)
3855 | crate::Statement::MemoryBarrier(flags) => {
3856 self.write_barrier(flags, level)?;
3857 }
3858 crate::Statement::Store { pointer, value } => {
3859 self.put_store(pointer, value, level, context)?
3860 }
3861 crate::Statement::ImageStore {
3862 image,
3863 coordinate,
3864 array_index,
3865 value,
3866 } => {
3867 let address = TexelAddress {
3868 coordinate,
3869 array_index,
3870 sample: None,
3871 level: None,
3872 };
3873 self.put_image_store(level, image, &address, value, context)?
3874 }
3875 crate::Statement::Call {
3876 function,
3877 ref arguments,
3878 result,
3879 } => {
3880 write!(self.out, "{level}")?;
3881 if let Some(expr) = result {
3882 let name = Baked(expr).to_string();
3883 self.start_baking_expression(expr, &context.expression, &name)?;
3884 self.named_expressions.insert(expr, name);
3885 }
3886 let fun_name = &self.names[&NameKey::Function(function)];
3887 write!(self.out, "{fun_name}(")?;
3888 for (i, &handle) in arguments.iter().enumerate() {
3890 if i != 0 {
3891 write!(self.out, ", ")?;
3892 }
3893 self.put_expression(handle, &context.expression, true)?;
3894 }
3895 let mut separate = !arguments.is_empty();
3897 let fun_info = &context.expression.mod_info[function];
3898 let mut needs_buffer_sizes = false;
3899 for (handle, var) in context.expression.module.global_variables.iter() {
3900 if fun_info[handle].is_empty() {
3901 continue;
3902 }
3903 if var.space.needs_pass_through() {
3904 let name = &self.names[&NameKey::GlobalVariable(handle)];
3905 if separate {
3906 write!(self.out, ", ")?;
3907 } else {
3908 separate = true;
3909 }
3910 write!(self.out, "{name}")?;
3911 }
3912 needs_buffer_sizes |=
3913 needs_array_length(var.ty, &context.expression.module.types);
3914 }
3915 if needs_buffer_sizes {
3916 if separate {
3917 write!(self.out, ", ")?;
3918 }
3919 write!(self.out, "_buffer_sizes")?;
3920 }
3921
3922 writeln!(self.out, ");")?;
3924 }
3925 crate::Statement::Atomic {
3926 pointer,
3927 ref fun,
3928 value,
3929 result,
3930 } => {
3931 let context = &context.expression;
3932
3933 write!(self.out, "{level}")?;
3938 let fun_key = if let Some(result) = result {
3939 let res_name = Baked(result).to_string();
3940 self.start_baking_expression(result, context, &res_name)?;
3941 self.named_expressions.insert(result, res_name);
3942 fun.to_msl()
3943 } else if context.resolve_type(value).scalar_width() == Some(8) {
3944 fun.to_msl_64_bit()?
3945 } else {
3946 fun.to_msl()
3947 };
3948
3949 let policy = context.choose_bounds_check_policy(pointer);
3953 let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3954 && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
3955
3956 if checked {
3958 write!(self.out, " ? ")?;
3959 }
3960
3961 match *fun {
3963 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
3964 write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
3965 self.put_access_chain(pointer, policy, context)?;
3966 write!(self.out, ", ")?;
3967 self.put_expression(cmp, context, true)?;
3968 write!(self.out, ", ")?;
3969 self.put_expression(value, context, true)?;
3970 write!(self.out, ")")?;
3971 }
3972 _ => {
3973 write!(
3974 self.out,
3975 "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
3976 )?;
3977 self.put_access_chain(pointer, policy, context)?;
3978 write!(self.out, ", ")?;
3979 self.put_expression(value, context, true)?;
3980 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3981 }
3982 }
3983
3984 if checked {
3986 write!(self.out, " : DefaultConstructible()")?;
3987 }
3988
3989 writeln!(self.out, ";")?;
3991 }
3992 crate::Statement::ImageAtomic {
3993 image,
3994 coordinate,
3995 array_index,
3996 fun,
3997 value,
3998 } => {
3999 let address = TexelAddress {
4000 coordinate,
4001 array_index,
4002 sample: None,
4003 level: None,
4004 };
4005 self.put_image_atomic(level, image, &address, fun, value, context)?
4006 }
4007 crate::Statement::WorkGroupUniformLoad { pointer, result } => {
4008 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4009
4010 write!(self.out, "{level}")?;
4011 let name = self.namer.call("");
4012 self.start_baking_expression(result, &context.expression, &name)?;
4013 self.put_load(pointer, &context.expression, true)?;
4014 self.named_expressions.insert(result, name);
4015
4016 writeln!(self.out, ";")?;
4017 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4018 }
4019 crate::Statement::RayQuery { query, ref fun } => {
4020 if context.expression.lang_version < (2, 4) {
4021 return Err(Error::UnsupportedRayTracing);
4022 }
4023
4024 match *fun {
4025 crate::RayQueryFunction::Initialize {
4026 acceleration_structure,
4027 descriptor,
4028 } => {
4029 write!(self.out, "{level}")?;
4031 self.put_expression(query, &context.expression, true)?;
4032 writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
4033 {
4034 let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
4035 let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
4036 write!(self.out, "{level}")?;
4037 self.put_expression(query, &context.expression, true)?;
4038 write!(
4039 self.out,
4040 ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
4041 )?;
4042 self.put_expression(descriptor, &context.expression, true)?;
4043 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
4044 self.put_expression(descriptor, &context.expression, true)?;
4045 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
4046 writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
4047 }
4048 {
4049 let f_opaque = back::RayFlag::OPAQUE.bits();
4050 let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
4051 write!(self.out, "{level}")?;
4052 self.put_expression(query, &context.expression, true)?;
4053 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
4054 self.put_expression(descriptor, &context.expression, true)?;
4055 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
4056 self.put_expression(descriptor, &context.expression, true)?;
4057 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
4058 writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
4059 }
4060 {
4061 let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
4062 write!(self.out, "{level}")?;
4063 self.put_expression(query, &context.expression, true)?;
4064 write!(
4065 self.out,
4066 ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
4067 )?;
4068 self.put_expression(descriptor, &context.expression, true)?;
4069 writeln!(self.out, ".flags & {flag}) != 0);")?;
4070 }
4071
4072 write!(self.out, "{level}")?;
4073 self.put_expression(query, &context.expression, true)?;
4074 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
4075 self.put_expression(query, &context.expression, true)?;
4076 write!(
4077 self.out,
4078 ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
4079 )?;
4080 self.put_expression(descriptor, &context.expression, true)?;
4081 write!(self.out, ".origin, ")?;
4082 self.put_expression(descriptor, &context.expression, true)?;
4083 write!(self.out, ".dir, ")?;
4084 self.put_expression(descriptor, &context.expression, true)?;
4085 write!(self.out, ".tmin, ")?;
4086 self.put_expression(descriptor, &context.expression, true)?;
4087 write!(self.out, ".tmax), ")?;
4088 self.put_expression(acceleration_structure, &context.expression, true)?;
4089 write!(self.out, ", ")?;
4090 self.put_expression(descriptor, &context.expression, true)?;
4091 write!(self.out, ".cull_mask);")?;
4092
4093 write!(self.out, "{level}")?;
4094 self.put_expression(query, &context.expression, true)?;
4095 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
4096 }
4097 crate::RayQueryFunction::Proceed { result } => {
4098 write!(self.out, "{level}")?;
4099 let name = Baked(result).to_string();
4100 self.start_baking_expression(result, &context.expression, &name)?;
4101 self.named_expressions.insert(result, name);
4102 self.put_expression(query, &context.expression, true)?;
4103 writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
4104 if RAY_QUERY_MODERN_SUPPORT {
4105 write!(self.out, "{level}")?;
4106 self.put_expression(query, &context.expression, true)?;
4107 writeln!(self.out, ".?.next();")?;
4108 }
4109 }
4110 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
4111 if RAY_QUERY_MODERN_SUPPORT {
4112 write!(self.out, "{level}")?;
4113 self.put_expression(query, &context.expression, true)?;
4114 write!(self.out, ".?.commit_bounding_box_intersection(")?;
4115 self.put_expression(hit_t, &context.expression, true)?;
4116 writeln!(self.out, ");")?;
4117 } else {
4118 log::warn!("Ray Query GenerateIntersection is not yet supported");
4119 }
4120 }
4121 crate::RayQueryFunction::ConfirmIntersection => {
4122 if RAY_QUERY_MODERN_SUPPORT {
4123 write!(self.out, "{level}")?;
4124 self.put_expression(query, &context.expression, true)?;
4125 writeln!(self.out, ".?.commit_triangle_intersection();")?;
4126 } else {
4127 log::warn!("Ray Query ConfirmIntersection is not yet supported");
4128 }
4129 }
4130 crate::RayQueryFunction::Terminate => {
4131 if RAY_QUERY_MODERN_SUPPORT {
4132 write!(self.out, "{level}")?;
4133 self.put_expression(query, &context.expression, true)?;
4134 writeln!(self.out, ".?.abort();")?;
4135 }
4136 write!(self.out, "{level}")?;
4137 self.put_expression(query, &context.expression, true)?;
4138 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
4139 }
4140 }
4141 }
4142 crate::Statement::SubgroupBallot { result, predicate } => {
4143 write!(self.out, "{level}")?;
4144 let name = self.namer.call("");
4145 self.start_baking_expression(result, &context.expression, &name)?;
4146 self.named_expressions.insert(result, name);
4147 write!(
4148 self.out,
4149 "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4150 )?;
4151 if let Some(predicate) = predicate {
4152 self.put_expression(predicate, &context.expression, true)?;
4153 } else {
4154 write!(self.out, "true")?;
4155 }
4156 writeln!(self.out, "), 0, 0, 0);")?;
4157 }
4158 crate::Statement::SubgroupCollectiveOperation {
4159 op,
4160 collective_op,
4161 argument,
4162 result,
4163 } => {
4164 write!(self.out, "{level}")?;
4165 let name = self.namer.call("");
4166 self.start_baking_expression(result, &context.expression, &name)?;
4167 self.named_expressions.insert(result, name);
4168 match (collective_op, op) {
4169 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4170 write!(self.out, "{NAMESPACE}::simd_all(")?
4171 }
4172 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4173 write!(self.out, "{NAMESPACE}::simd_any(")?
4174 }
4175 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4176 write!(self.out, "{NAMESPACE}::simd_sum(")?
4177 }
4178 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4179 write!(self.out, "{NAMESPACE}::simd_product(")?
4180 }
4181 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4182 write!(self.out, "{NAMESPACE}::simd_max(")?
4183 }
4184 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4185 write!(self.out, "{NAMESPACE}::simd_min(")?
4186 }
4187 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4188 write!(self.out, "{NAMESPACE}::simd_and(")?
4189 }
4190 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4191 write!(self.out, "{NAMESPACE}::simd_or(")?
4192 }
4193 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4194 write!(self.out, "{NAMESPACE}::simd_xor(")?
4195 }
4196 (
4197 crate::CollectiveOperation::ExclusiveScan,
4198 crate::SubgroupOperation::Add,
4199 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4200 (
4201 crate::CollectiveOperation::ExclusiveScan,
4202 crate::SubgroupOperation::Mul,
4203 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4204 (
4205 crate::CollectiveOperation::InclusiveScan,
4206 crate::SubgroupOperation::Add,
4207 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4208 (
4209 crate::CollectiveOperation::InclusiveScan,
4210 crate::SubgroupOperation::Mul,
4211 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4212 _ => unimplemented!(),
4213 }
4214 self.put_expression(argument, &context.expression, true)?;
4215 writeln!(self.out, ");")?;
4216 }
4217 crate::Statement::SubgroupGather {
4218 mode,
4219 argument,
4220 result,
4221 } => {
4222 write!(self.out, "{level}")?;
4223 let name = self.namer.call("");
4224 self.start_baking_expression(result, &context.expression, &name)?;
4225 self.named_expressions.insert(result, name);
4226 match mode {
4227 crate::GatherMode::BroadcastFirst => {
4228 write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4229 }
4230 crate::GatherMode::Broadcast(_) => {
4231 write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4232 }
4233 crate::GatherMode::Shuffle(_) => {
4234 write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4235 }
4236 crate::GatherMode::ShuffleDown(_) => {
4237 write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4238 }
4239 crate::GatherMode::ShuffleUp(_) => {
4240 write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4241 }
4242 crate::GatherMode::ShuffleXor(_) => {
4243 write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4244 }
4245 crate::GatherMode::QuadBroadcast(_) => {
4246 write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4247 }
4248 crate::GatherMode::QuadSwap(_) => {
4249 write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4250 }
4251 }
4252 self.put_expression(argument, &context.expression, true)?;
4253 match mode {
4254 crate::GatherMode::BroadcastFirst => {}
4255 crate::GatherMode::Broadcast(index)
4256 | crate::GatherMode::Shuffle(index)
4257 | crate::GatherMode::ShuffleDown(index)
4258 | crate::GatherMode::ShuffleUp(index)
4259 | crate::GatherMode::ShuffleXor(index)
4260 | crate::GatherMode::QuadBroadcast(index) => {
4261 write!(self.out, ", ")?;
4262 self.put_expression(index, &context.expression, true)?;
4263 }
4264 crate::GatherMode::QuadSwap(direction) => {
4265 write!(self.out, ", ")?;
4266 match direction {
4267 crate::Direction::X => {
4268 write!(self.out, "1u")?;
4269 }
4270 crate::Direction::Y => {
4271 write!(self.out, "2u")?;
4272 }
4273 crate::Direction::Diagonal => {
4274 write!(self.out, "3u")?;
4275 }
4276 }
4277 }
4278 }
4279 writeln!(self.out, ");")?;
4280 }
4281 crate::Statement::CooperativeStore { target, ref data } => {
4282 write!(self.out, "{level}simdgroup_store(")?;
4283 self.put_expression(target, &context.expression, true)?;
4284 write!(self.out, ", &")?;
4285 self.put_access_chain(
4286 data.pointer,
4287 context.expression.policies.index,
4288 &context.expression,
4289 )?;
4290 write!(self.out, ", ")?;
4291 self.put_expression(data.stride, &context.expression, true)?;
4292 if data.row_major {
4293 let matrix_origin = "0";
4294 let transpose = true;
4295 write!(self.out, ", {matrix_origin}, {transpose}")?;
4296 }
4297 writeln!(self.out, ");")?;
4298 }
4299 crate::Statement::RayPipelineFunction(_) => unreachable!(),
4300 }
4301 }
4302
4303 for statement in statements {
4306 if let crate::Statement::Emit(ref range) = *statement {
4307 for handle in range.clone() {
4308 self.named_expressions.shift_remove(&handle);
4309 }
4310 }
4311 }
4312 Ok(())
4313 }
4314
4315 fn put_store(
4316 &mut self,
4317 pointer: Handle<crate::Expression>,
4318 value: Handle<crate::Expression>,
4319 level: back::Level,
4320 context: &StatementContext,
4321 ) -> BackendResult {
4322 let policy = context.expression.choose_bounds_check_policy(pointer);
4323 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4324 && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4325 {
4326 writeln!(self.out, ") {{")?;
4327 self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4328 writeln!(self.out, "{level}}}")?;
4329 } else {
4330 self.put_unchecked_store(pointer, value, policy, level, context)?;
4331 }
4332
4333 Ok(())
4334 }
4335
4336 fn put_unchecked_store(
4337 &mut self,
4338 pointer: Handle<crate::Expression>,
4339 value: Handle<crate::Expression>,
4340 policy: index::BoundsCheckPolicy,
4341 level: back::Level,
4342 context: &StatementContext,
4343 ) -> BackendResult {
4344 let is_atomic_pointer = context
4345 .expression
4346 .resolve_type(pointer)
4347 .is_atomic_pointer(&context.expression.module.types);
4348
4349 if is_atomic_pointer {
4350 write!(
4351 self.out,
4352 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4353 )?;
4354 self.put_access_chain(pointer, policy, &context.expression)?;
4355 write!(self.out, ", ")?;
4356 self.put_expression(value, &context.expression, true)?;
4357 writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4358 } else {
4359 write!(self.out, "{level}")?;
4360 self.put_access_chain(pointer, policy, &context.expression)?;
4361 write!(self.out, " = ")?;
4362 self.put_expression(value, &context.expression, true)?;
4363 writeln!(self.out, ";")?;
4364 }
4365
4366 Ok(())
4367 }
4368
4369 pub fn write(
4370 &mut self,
4371 module: &crate::Module,
4372 info: &valid::ModuleInfo,
4373 options: &Options,
4374 pipeline_options: &PipelineOptions,
4375 ) -> Result<TranslationInfo, Error> {
4376 self.names.clear();
4377 self.namer.reset(
4378 module,
4379 &super::keywords::RESERVED_SET,
4380 proc::KeywordSet::empty(),
4381 proc::CaseInsensitiveKeywordSet::empty(),
4382 &[CLAMPED_LOD_LOAD_PREFIX],
4383 &mut self.names,
4384 );
4385 self.wrapped_functions.clear();
4386 self.struct_member_pads.clear();
4387
4388 writeln!(
4389 self.out,
4390 "// language: metal{}.{}",
4391 options.lang_version.0, options.lang_version.1
4392 )?;
4393 writeln!(self.out, "#include <metal_stdlib>")?;
4394 writeln!(self.out, "#include <simd/simd.h>")?;
4395 writeln!(self.out)?;
4396 writeln!(self.out, "using {NAMESPACE}::uint;")?;
4398
4399 let mut uses_ray_query = false;
4400 for (_, ty) in module.types.iter() {
4401 match ty.inner {
4402 crate::TypeInner::AccelerationStructure { .. } => {
4403 if options.lang_version < (2, 4) {
4404 return Err(Error::UnsupportedRayTracing);
4405 }
4406 }
4407 crate::TypeInner::RayQuery { .. } => {
4408 if options.lang_version < (2, 4) {
4409 return Err(Error::UnsupportedRayTracing);
4410 }
4411 uses_ray_query = true;
4412 }
4413 _ => (),
4414 }
4415 }
4416
4417 if module.special_types.ray_desc.is_some()
4418 || module.special_types.ray_intersection.is_some()
4419 {
4420 if options.lang_version < (2, 4) {
4421 return Err(Error::UnsupportedRayTracing);
4422 }
4423 }
4424
4425 if uses_ray_query {
4426 self.put_ray_query_type()?;
4427 }
4428
4429 if options
4430 .bounds_check_policies
4431 .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4432 {
4433 self.put_default_constructible()?;
4434 }
4435 writeln!(self.out)?;
4436
4437 {
4438 let globals: Vec<Handle<crate::GlobalVariable>> = module
4441 .global_variables
4442 .iter()
4443 .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4444 .map(|(handle, _)| handle)
4445 .collect();
4446
4447 let mut buffer_indices = vec![];
4448 for vbm in &pipeline_options.vertex_buffer_mappings {
4449 buffer_indices.push(vbm.id);
4450 }
4451
4452 if !globals.is_empty() || !buffer_indices.is_empty() {
4453 writeln!(self.out, "struct _mslBufferSizes {{")?;
4454
4455 for global in globals {
4456 writeln!(
4457 self.out,
4458 "{}uint {};",
4459 back::INDENT,
4460 ArraySizeMember(global)
4461 )?;
4462 }
4463
4464 for idx in buffer_indices {
4465 writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4466 }
4467
4468 writeln!(self.out, "}};")?;
4469 writeln!(self.out)?;
4470 }
4471 };
4472
4473 self.write_type_defs(module)?;
4474 self.write_global_constants(module, info)?;
4475 self.write_functions(module, info, options, pipeline_options)
4476 }
4477
4478 fn put_default_constructible(&mut self) -> BackendResult {
4491 let tab = back::INDENT;
4492 writeln!(self.out, "struct DefaultConstructible {{")?;
4493 writeln!(self.out, "{tab}template<typename T>")?;
4494 writeln!(self.out, "{tab}operator T() && {{")?;
4495 writeln!(self.out, "{tab}{tab}return T {{}};")?;
4496 writeln!(self.out, "{tab}}}")?;
4497 writeln!(self.out, "}};")?;
4498 Ok(())
4499 }
4500
4501 fn put_ray_query_type(&mut self) -> BackendResult {
4502 let tab = back::INDENT;
4503 writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
4504 let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
4505 writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
4506 writeln!(
4507 self.out,
4508 "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
4509 )?;
4510 writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
4511 writeln!(self.out, "}};")?;
4512 writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
4513 let v_triangle = back::RayIntersectionType::Triangle as u32;
4514 let v_bbox = back::RayIntersectionType::BoundingBox as u32;
4515 writeln!(
4516 self.out,
4517 "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
4518 )?;
4519 writeln!(
4520 self.out,
4521 "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
4522 )?;
4523 writeln!(self.out, "}}")?;
4524 Ok(())
4525 }
4526
    /// Emits MSL declarations for the module's types, then helper functions for the
    /// module's predeclared result types.
    ///
    /// For every type in `module.types`:
    /// - The first `BindingArray` triggers a one-time generic
    ///   `ARGUMENT_BUFFER_WRAPPER_STRUCT<T>` template with a single `T` field.
    /// - The first external image triggers a one-time `EXTERNAL_TEXTURE_WRAPPER_STRUCT`
    ///   holding three sampled `texture2d<float>` planes plus the external-texture params
    ///   struct.
    /// - Types whose `needs_alias()` is true get a named declaration:
    ///   - a wrapper struct for fixed-size arrays;
    ///   - a one-element-array typedef for runtime-sized arrays;
    ///   - an explicitly padded struct for structs;
    ///   - a plain typedef for anything else.
    ///
    /// Afterwards, for each predeclared type it defines either `MODF_FUNCTION` /
    /// `FREXP_FUNCTION` wrappers, or `device` and `threadgroup` overloads of
    /// `ATOMIC_COMP_EXCH_FUNCTION`.
    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
        // Each shared wrapper definition is emitted at most once per module.
        let mut generated_argument_buffer_wrapper = false;
        let mut generated_external_texture_wrapper = false;
        for (handle, ty) in module.types.iter() {
            match ty.inner {
                crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
                    writeln!(self.out, "template <typename T>")?;
                    writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
                    writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
                    writeln!(self.out, "}};")?;
                    generated_argument_buffer_wrapper = true;
                }
                crate::TypeInner::Image {
                    class: crate::ImageClass::External,
                    ..
                } if !generated_external_texture_wrapper => {
                    // Assumes the module registered `external_texture_params` whenever an
                    // external image type exists; the `unwrap` panics otherwise.
                    let params_ty_name = &self.names
                        [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
                    writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
                        back::INDENT
                    )?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
                        back::INDENT
                    )?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
                        back::INDENT
                    )?;
                    writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
                    writeln!(self.out, "}};")?;
                    generated_external_texture_wrapper = true;
                }
                _ => {}
            }

            // Only types that need a named alias get their own declaration below.
            if !ty.needs_alias() {
                continue;
            }
            let name = &self.names[&NameKey::Type(handle)];
            match ty.inner {
                // Fixed-size arrays are wrapped in a struct with a single `inner` array
                // member. NOTE(review): presumably so they can be copied and returned by
                // value like other MSL types; confirm. Runtime-sized arrays become a
                // one-element array typedef.
                crate::TypeInner::Array {
                    base,
                    size,
                    stride: _,
                } => {
                    let base_name = TypeContext {
                        handle: base,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        first_time: false,
                    };

                    match size.resolve(module.to_ctx())? {
                        proc::IndexableLength::Known(size) => {
                            writeln!(self.out, "struct {name} {{")?;
                            writeln!(
                                self.out,
                                "{}{} {}[{}];",
                                back::INDENT,
                                base_name,
                                WRAPPED_ARRAY_FIELD,
                                size
                            )?;
                            writeln!(self.out, "}};")?;
                        }
                        proc::IndexableLength::Dynamic => {
                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
                        }
                    }
                }
                // Structs are laid out explicitly. Any gap before a member's declared
                // offset is filled with `char _pad<index>[gap]`, and that member is
                // recorded in `struct_member_pads`. A trailing pad brings the total size
                // up to `span`.
                crate::TypeInner::Struct {
                    ref members, span, ..
                } => {
                    writeln!(self.out, "struct {name} {{")?;
                    // Byte offset just past the previously emitted member.
                    let mut last_offset = 0;
                    for (index, member) in members.iter().enumerate() {
                        if member.offset > last_offset {
                            self.struct_member_pads.insert((handle, index as u32));
                            let pad = member.offset - last_offset;
                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
                        }
                        let ty_inner = &module.types[member.ty].inner;
                        last_offset = member.offset + ty_inner.size(module.to_ctx());

                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];

                        // When `should_pack_struct_member` returns a scalar, the member is
                        // written as `metal::packed_<scalar>3` instead of its regular type.
                        match should_pack_struct_member(members, span, index, module) {
                            Some(scalar) => {
                                writeln!(
                                    self.out,
                                    "{}{}::packed_{}3 {};",
                                    back::INDENT,
                                    NAMESPACE,
                                    scalar.to_msl_name(),
                                    member_name
                                )?;
                            }
                            None => {
                                let base_name = TypeContext {
                                    handle: member.ty,
                                    gctx: module.to_ctx(),
                                    names: &self.names,
                                    access: crate::StorageAccess::empty(),
                                    first_time: false,
                                };
                                writeln!(
                                    self.out,
                                    "{}{} {};",
                                    back::INDENT,
                                    base_name,
                                    member_name
                                )?;

                                // An unpacked 3-component vector is counted as one scalar
                                // wider than its naga size. NOTE(review): presumably because
                                // MSL's non-packed `T3` occupies the storage of four
                                // components; confirm.
                                if let crate::TypeInner::Vector {
                                    size: crate::VectorSize::Tri,
                                    scalar,
                                } = *ty_inner
                                {
                                    last_offset += scalar.width as u32;
                                }
                            }
                        }
                    }
                    // Trailing padding up to the struct's declared span.
                    if last_offset < span {
                        let pad = span - last_offset;
                        writeln!(
                            self.out,
                            "{}char _pad{}[{}];",
                            back::INDENT,
                            members.len(),
                            pad
                        )?;
                    }
                    writeln!(self.out, "}};")?;
                }
                // Everything else is a plain typedef of the spelled-out type.
                _ => {
                    let ty_name = TypeContext {
                        handle,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        first_time: true,
                    };
                    writeln!(self.out, "typedef {ty_name} {name};")?;
                }
            }
        }

        // Helper functions that construct the module's predeclared result structs.
        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
            match type_key {
                &crate::PredeclaredType::ModfResult { size, scalar }
                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
                    // Argument type: `metal::floatN`/`metal::doubleN` for vectors, or
                    // `float`/`double` for scalars, chosen by the scalar width.
                    let arg_type_name_owner;
                    let arg_type_name = if let Some(size) = size {
                        arg_type_name_owner = format!(
                            "{NAMESPACE}::{}{}",
                            if scalar.width == 8 { "double" } else { "float" },
                            size as u8
                        );
                        &arg_type_name_owner
                    } else if scalar.width == 8 {
                        "double"
                    } else {
                        "float"
                    };

                    // `modf`'s second output has the argument type; `frexp`'s exponent
                    // output is `int`/`intN`.
                    let other_type_name_owner;
                    let (defined_func_name, called_func_name, other_type_name) =
                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
                            (MODF_FUNCTION, "modf", arg_type_name)
                        } else {
                            let other_type_name = if let Some(size) = size {
                                other_type_name_owner = format!("int{}", size as u8);
                                &other_type_name_owner
                            } else {
                                "int"
                            };
                            (FREXP_FUNCTION, "frexp", other_type_name)
                        };

                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                    // Emits a wrapper that calls `metal::modf`/`metal::frexp` and returns
                    // `{ fract, other }` as the predeclared struct.
                    writeln!(self.out)?;
                    writeln!(
                        self.out,
                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
    {other_type_name} other;
    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
    return {struct_name}{{ fract, other }};
}}"
                    )?;
                }
                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
                    let arg_type_name = scalar.to_msl_name();
                    let called_func_name = "atomic_compare_exchange_weak_explicit";
                    let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                    writeln!(self.out)?;

                    // One overload per pointer address space. Each calls
                    // `metal::atomic_compare_exchange_weak_explicit` with relaxed
                    // ordering and returns `{ cmp, swapped }`; `cmp` holds the value
                    // observed by the exchange.
                    for address_space_name in ["device", "threadgroup"] {
                        writeln!(
                            self.out,
                            "\
template <typename A>
{struct_name} {defined_func_name}(
    {address_space_name} A *atomic_ptr,
    {arg_type_name} cmp,
    {arg_type_name} v
) {{
    bool swapped = {NAMESPACE}::{called_func_name}(
        atomic_ptr, &cmp, v,
        metal::memory_order_relaxed, metal::memory_order_relaxed
    );
    return {struct_name}{{cmp, swapped}};
}}"
                        )?;
                    }
                }
            }
        }

        Ok(())
    }
4774
4775 fn write_global_constants(
4777 &mut self,
4778 module: &crate::Module,
4779 mod_info: &valid::ModuleInfo,
4780 ) -> BackendResult {
4781 let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4782
4783 for (handle, constant) in constants {
4784 let ty_name = TypeContext {
4785 handle: constant.ty,
4786 gctx: module.to_ctx(),
4787 names: &self.names,
4788 access: crate::StorageAccess::empty(),
4789 first_time: false,
4790 };
4791 let name = &self.names[&NameKey::Constant(handle)];
4792 write!(self.out, "constant {ty_name} {name} = ")?;
4793 self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4794 writeln!(self.out, ";")?;
4795 }
4796
4797 Ok(())
4798 }
4799
4800 fn put_inline_sampler_properties(
4801 &mut self,
4802 level: back::Level,
4803 sampler: &sm::InlineSampler,
4804 ) -> BackendResult {
4805 for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4806 writeln!(
4807 self.out,
4808 "{}{}::{}_address::{},",
4809 level,
4810 NAMESPACE,
4811 letter,
4812 address.as_str(),
4813 )?;
4814 }
4815 writeln!(
4816 self.out,
4817 "{}{}::mag_filter::{},",
4818 level,
4819 NAMESPACE,
4820 sampler.mag_filter.as_str(),
4821 )?;
4822 writeln!(
4823 self.out,
4824 "{}{}::min_filter::{},",
4825 level,
4826 NAMESPACE,
4827 sampler.min_filter.as_str(),
4828 )?;
4829 if let Some(filter) = sampler.mip_filter {
4830 writeln!(
4831 self.out,
4832 "{}{}::mip_filter::{},",
4833 level,
4834 NAMESPACE,
4835 filter.as_str(),
4836 )?;
4837 }
4838 if sampler.border_color != sm::BorderColor::TransparentBlack {
4840 writeln!(
4841 self.out,
4842 "{}{}::border_color::{},",
4843 level,
4844 NAMESPACE,
4845 sampler.border_color.as_str(),
4846 )?;
4847 }
4848 if false {
4852 if let Some(ref lod) = sampler.lod_clamp {
4853 writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4854 }
4855 if let Some(aniso) = sampler.max_anisotropy {
4856 writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4857 }
4858 }
4859 if sampler.compare_func != sm::CompareFunc::Never {
4860 writeln!(
4861 self.out,
4862 "{}{}::compare_func::{},",
4863 level,
4864 NAMESPACE,
4865 sampler.compare_func.as_str(),
4866 )?;
4867 }
4868 writeln!(
4869 self.out,
4870 "{}{}::coord::{}",
4871 level,
4872 NAMESPACE,
4873 sampler.coord.as_str()
4874 )?;
4875 Ok(())
4876 }
4877
4878 fn write_unpacking_function(
4879 &mut self,
4880 format: back::msl::VertexFormat,
4881 ) -> Result<(String, u32, u32), Error> {
4882 use back::msl::VertexFormat::*;
4883 match format {
4884 Uint8 => {
4885 let name = self.namer.call("unpackUint8");
4886 writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4887 writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4888 writeln!(self.out, "}}")?;
4889 Ok((name, 1, 1))
4890 }
4891 Uint8x2 => {
4892 let name = self.namer.call("unpackUint8x2");
4893 writeln!(
4894 self.out,
4895 "metal::uint2 {name}(metal::uchar b0, \
4896 metal::uchar b1) {{"
4897 )?;
4898 writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4899 writeln!(self.out, "}}")?;
4900 Ok((name, 2, 2))
4901 }
4902 Uint8x4 => {
4903 let name = self.namer.call("unpackUint8x4");
4904 writeln!(
4905 self.out,
4906 "metal::uint4 {name}(metal::uchar b0, \
4907 metal::uchar b1, \
4908 metal::uchar b2, \
4909 metal::uchar b3) {{"
4910 )?;
4911 writeln!(
4912 self.out,
4913 "{}return metal::uint4(b0, b1, b2, b3);",
4914 back::INDENT
4915 )?;
4916 writeln!(self.out, "}}")?;
4917 Ok((name, 4, 4))
4918 }
4919 Sint8 => {
4920 let name = self.namer.call("unpackSint8");
4921 writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4922 writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4923 writeln!(self.out, "}}")?;
4924 Ok((name, 1, 1))
4925 }
4926 Sint8x2 => {
4927 let name = self.namer.call("unpackSint8x2");
4928 writeln!(
4929 self.out,
4930 "metal::int2 {name}(metal::uchar b0, \
4931 metal::uchar b1) {{"
4932 )?;
4933 writeln!(
4934 self.out,
4935 "{}return metal::int2(as_type<char>(b0), \
4936 as_type<char>(b1));",
4937 back::INDENT
4938 )?;
4939 writeln!(self.out, "}}")?;
4940 Ok((name, 2, 2))
4941 }
4942 Sint8x4 => {
4943 let name = self.namer.call("unpackSint8x4");
4944 writeln!(
4945 self.out,
4946 "metal::int4 {name}(metal::uchar b0, \
4947 metal::uchar b1, \
4948 metal::uchar b2, \
4949 metal::uchar b3) {{"
4950 )?;
4951 writeln!(
4952 self.out,
4953 "{}return metal::int4(as_type<char>(b0), \
4954 as_type<char>(b1), \
4955 as_type<char>(b2), \
4956 as_type<char>(b3));",
4957 back::INDENT
4958 )?;
4959 writeln!(self.out, "}}")?;
4960 Ok((name, 4, 4))
4961 }
4962 Unorm8 => {
4963 let name = self.namer.call("unpackUnorm8");
4964 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4965 writeln!(
4966 self.out,
4967 "{}return float(float(b0) / 255.0f);",
4968 back::INDENT
4969 )?;
4970 writeln!(self.out, "}}")?;
4971 Ok((name, 1, 1))
4972 }
4973 Unorm8x2 => {
4974 let name = self.namer.call("unpackUnorm8x2");
4975 writeln!(
4976 self.out,
4977 "metal::float2 {name}(metal::uchar b0, \
4978 metal::uchar b1) {{"
4979 )?;
4980 writeln!(
4981 self.out,
4982 "{}return metal::float2(float(b0) / 255.0f, \
4983 float(b1) / 255.0f);",
4984 back::INDENT
4985 )?;
4986 writeln!(self.out, "}}")?;
4987 Ok((name, 2, 2))
4988 }
4989 Unorm8x4 => {
4990 let name = self.namer.call("unpackUnorm8x4");
4991 writeln!(
4992 self.out,
4993 "metal::float4 {name}(metal::uchar b0, \
4994 metal::uchar b1, \
4995 metal::uchar b2, \
4996 metal::uchar b3) {{"
4997 )?;
4998 writeln!(
4999 self.out,
5000 "{}return metal::float4(float(b0) / 255.0f, \
5001 float(b1) / 255.0f, \
5002 float(b2) / 255.0f, \
5003 float(b3) / 255.0f);",
5004 back::INDENT
5005 )?;
5006 writeln!(self.out, "}}")?;
5007 Ok((name, 4, 4))
5008 }
5009 Snorm8 => {
5010 let name = self.namer.call("unpackSnorm8");
5011 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
5012 writeln!(
5013 self.out,
5014 "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
5015 back::INDENT
5016 )?;
5017 writeln!(self.out, "}}")?;
5018 Ok((name, 1, 1))
5019 }
5020 Snorm8x2 => {
5021 let name = self.namer.call("unpackSnorm8x2");
5022 writeln!(
5023 self.out,
5024 "metal::float2 {name}(metal::uchar b0, \
5025 metal::uchar b1) {{"
5026 )?;
5027 writeln!(
5028 self.out,
5029 "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5030 metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
5031 back::INDENT
5032 )?;
5033 writeln!(self.out, "}}")?;
5034 Ok((name, 2, 2))
5035 }
5036 Snorm8x4 => {
5037 let name = self.namer.call("unpackSnorm8x4");
5038 writeln!(
5039 self.out,
5040 "metal::float4 {name}(metal::uchar b0, \
5041 metal::uchar b1, \
5042 metal::uchar b2, \
5043 metal::uchar b3) {{"
5044 )?;
5045 writeln!(
5046 self.out,
5047 "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5048 metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
5049 metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
5050 metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
5051 back::INDENT
5052 )?;
5053 writeln!(self.out, "}}")?;
5054 Ok((name, 4, 4))
5055 }
5056 Uint16 => {
5057 let name = self.namer.call("unpackUint16");
5058 writeln!(
5059 self.out,
5060 "metal::uint {name}(metal::uint b0, \
5061 metal::uint b1) {{"
5062 )?;
5063 writeln!(
5064 self.out,
5065 "{}return metal::uint(b1 << 8 | b0);",
5066 back::INDENT
5067 )?;
5068 writeln!(self.out, "}}")?;
5069 Ok((name, 2, 1))
5070 }
5071 Uint16x2 => {
5072 let name = self.namer.call("unpackUint16x2");
5073 writeln!(
5074 self.out,
5075 "metal::uint2 {name}(metal::uint b0, \
5076 metal::uint b1, \
5077 metal::uint b2, \
5078 metal::uint b3) {{"
5079 )?;
5080 writeln!(
5081 self.out,
5082 "{}return metal::uint2(b1 << 8 | b0, \
5083 b3 << 8 | b2);",
5084 back::INDENT
5085 )?;
5086 writeln!(self.out, "}}")?;
5087 Ok((name, 4, 2))
5088 }
5089 Uint16x4 => {
5090 let name = self.namer.call("unpackUint16x4");
5091 writeln!(
5092 self.out,
5093 "metal::uint4 {name}(metal::uint b0, \
5094 metal::uint b1, \
5095 metal::uint b2, \
5096 metal::uint b3, \
5097 metal::uint b4, \
5098 metal::uint b5, \
5099 metal::uint b6, \
5100 metal::uint b7) {{"
5101 )?;
5102 writeln!(
5103 self.out,
5104 "{}return metal::uint4(b1 << 8 | b0, \
5105 b3 << 8 | b2, \
5106 b5 << 8 | b4, \
5107 b7 << 8 | b6);",
5108 back::INDENT
5109 )?;
5110 writeln!(self.out, "}}")?;
5111 Ok((name, 8, 4))
5112 }
5113 Sint16 => {
5114 let name = self.namer.call("unpackSint16");
5115 writeln!(
5116 self.out,
5117 "int {name}(metal::ushort b0, \
5118 metal::ushort b1) {{"
5119 )?;
5120 writeln!(
5121 self.out,
5122 "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5123 back::INDENT
5124 )?;
5125 writeln!(self.out, "}}")?;
5126 Ok((name, 2, 1))
5127 }
5128 Sint16x2 => {
5129 let name = self.namer.call("unpackSint16x2");
5130 writeln!(
5131 self.out,
5132 "metal::int2 {name}(metal::ushort b0, \
5133 metal::ushort b1, \
5134 metal::ushort b2, \
5135 metal::ushort b3) {{"
5136 )?;
5137 writeln!(
5138 self.out,
5139 "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5140 as_type<short>(metal::ushort(b3 << 8 | b2)));",
5141 back::INDENT
5142 )?;
5143 writeln!(self.out, "}}")?;
5144 Ok((name, 4, 2))
5145 }
5146 Sint16x4 => {
5147 let name = self.namer.call("unpackSint16x4");
5148 writeln!(
5149 self.out,
5150 "metal::int4 {name}(metal::ushort b0, \
5151 metal::ushort b1, \
5152 metal::ushort b2, \
5153 metal::ushort b3, \
5154 metal::ushort b4, \
5155 metal::ushort b5, \
5156 metal::ushort b6, \
5157 metal::ushort b7) {{"
5158 )?;
5159 writeln!(
5160 self.out,
5161 "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5162 as_type<short>(metal::ushort(b3 << 8 | b2)), \
5163 as_type<short>(metal::ushort(b5 << 8 | b4)), \
5164 as_type<short>(metal::ushort(b7 << 8 | b6)));",
5165 back::INDENT
5166 )?;
5167 writeln!(self.out, "}}")?;
5168 Ok((name, 8, 4))
5169 }
5170 Unorm16 => {
5171 let name = self.namer.call("unpackUnorm16");
5172 writeln!(
5173 self.out,
5174 "float {name}(metal::ushort b0, \
5175 metal::ushort b1) {{"
5176 )?;
5177 writeln!(
5178 self.out,
5179 "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5180 back::INDENT
5181 )?;
5182 writeln!(self.out, "}}")?;
5183 Ok((name, 2, 1))
5184 }
5185 Unorm16x2 => {
5186 let name = self.namer.call("unpackUnorm16x2");
5187 writeln!(
5188 self.out,
5189 "metal::float2 {name}(metal::ushort b0, \
5190 metal::ushort b1, \
5191 metal::ushort b2, \
5192 metal::ushort b3) {{"
5193 )?;
5194 writeln!(
5195 self.out,
5196 "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5197 float(b3 << 8 | b2) / 65535.0f);",
5198 back::INDENT
5199 )?;
5200 writeln!(self.out, "}}")?;
5201 Ok((name, 4, 2))
5202 }
5203 Unorm16x4 => {
5204 let name = self.namer.call("unpackUnorm16x4");
5205 writeln!(
5206 self.out,
5207 "metal::float4 {name}(metal::ushort b0, \
5208 metal::ushort b1, \
5209 metal::ushort b2, \
5210 metal::ushort b3, \
5211 metal::ushort b4, \
5212 metal::ushort b5, \
5213 metal::ushort b6, \
5214 metal::ushort b7) {{"
5215 )?;
5216 writeln!(
5217 self.out,
5218 "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5219 float(b3 << 8 | b2) / 65535.0f, \
5220 float(b5 << 8 | b4) / 65535.0f, \
5221 float(b7 << 8 | b6) / 65535.0f);",
5222 back::INDENT
5223 )?;
5224 writeln!(self.out, "}}")?;
5225 Ok((name, 8, 4))
5226 }
5227 Snorm16 => {
5228 let name = self.namer.call("unpackSnorm16");
5229 writeln!(
5230 self.out,
5231 "float {name}(metal::ushort b0, \
5232 metal::ushort b1) {{"
5233 )?;
5234 writeln!(
5235 self.out,
5236 "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5237 back::INDENT
5238 )?;
5239 writeln!(self.out, "}}")?;
5240 Ok((name, 2, 1))
5241 }
5242 Snorm16x2 => {
5243 let name = self.namer.call("unpackSnorm16x2");
5244 writeln!(
5245 self.out,
5246 "metal::float2 {name}(uint b0, \
5247 uint b1, \
5248 uint b2, \
5249 uint b3) {{"
5250 )?;
5251 writeln!(
5252 self.out,
5253 "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5254 back::INDENT
5255 )?;
5256 writeln!(self.out, "}}")?;
5257 Ok((name, 4, 2))
5258 }
5259 Snorm16x4 => {
5260 let name = self.namer.call("unpackSnorm16x4");
5261 writeln!(
5262 self.out,
5263 "metal::float4 {name}(uint b0, \
5264 uint b1, \
5265 uint b2, \
5266 uint b3, \
5267 uint b4, \
5268 uint b5, \
5269 uint b6, \
5270 uint b7) {{"
5271 )?;
5272 writeln!(
5273 self.out,
5274 "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5275 metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5276 back::INDENT
5277 )?;
5278 writeln!(self.out, "}}")?;
5279 Ok((name, 8, 4))
5280 }
5281 Float16 => {
5282 let name = self.namer.call("unpackFloat16");
5283 writeln!(
5284 self.out,
5285 "float {name}(metal::ushort b0, \
5286 metal::ushort b1) {{"
5287 )?;
5288 writeln!(
5289 self.out,
5290 "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5291 back::INDENT
5292 )?;
5293 writeln!(self.out, "}}")?;
5294 Ok((name, 2, 1))
5295 }
5296 Float16x2 => {
5297 let name = self.namer.call("unpackFloat16x2");
5298 writeln!(
5299 self.out,
5300 "metal::float2 {name}(metal::ushort b0, \
5301 metal::ushort b1, \
5302 metal::ushort b2, \
5303 metal::ushort b3) {{"
5304 )?;
5305 writeln!(
5306 self.out,
5307 "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5308 as_type<half>(metal::ushort(b3 << 8 | b2)));",
5309 back::INDENT
5310 )?;
5311 writeln!(self.out, "}}")?;
5312 Ok((name, 4, 2))
5313 }
5314 Float16x4 => {
5315 let name = self.namer.call("unpackFloat16x4");
5316 writeln!(
5317 self.out,
5318 "metal::float4 {name}(metal::ushort b0, \
5319 metal::ushort b1, \
5320 metal::ushort b2, \
5321 metal::ushort b3, \
5322 metal::ushort b4, \
5323 metal::ushort b5, \
5324 metal::ushort b6, \
5325 metal::ushort b7) {{"
5326 )?;
5327 writeln!(
5328 self.out,
5329 "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5330 as_type<half>(metal::ushort(b3 << 8 | b2)), \
5331 as_type<half>(metal::ushort(b5 << 8 | b4)), \
5332 as_type<half>(metal::ushort(b7 << 8 | b6)));",
5333 back::INDENT
5334 )?;
5335 writeln!(self.out, "}}")?;
5336 Ok((name, 8, 4))
5337 }
5338 Float32 => {
5339 let name = self.namer.call("unpackFloat32");
5340 writeln!(
5341 self.out,
5342 "float {name}(uint b0, \
5343 uint b1, \
5344 uint b2, \
5345 uint b3) {{"
5346 )?;
5347 writeln!(
5348 self.out,
5349 "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5350 back::INDENT
5351 )?;
5352 writeln!(self.out, "}}")?;
5353 Ok((name, 4, 1))
5354 }
5355 Float32x2 => {
5356 let name = self.namer.call("unpackFloat32x2");
5357 writeln!(
5358 self.out,
5359 "metal::float2 {name}(uint b0, \
5360 uint b1, \
5361 uint b2, \
5362 uint b3, \
5363 uint b4, \
5364 uint b5, \
5365 uint b6, \
5366 uint b7) {{"
5367 )?;
5368 writeln!(
5369 self.out,
5370 "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5371 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5372 back::INDENT
5373 )?;
5374 writeln!(self.out, "}}")?;
5375 Ok((name, 8, 2))
5376 }
5377 Float32x3 => {
5378 let name = self.namer.call("unpackFloat32x3");
5379 writeln!(
5380 self.out,
5381 "metal::float3 {name}(uint b0, \
5382 uint b1, \
5383 uint b2, \
5384 uint b3, \
5385 uint b4, \
5386 uint b5, \
5387 uint b6, \
5388 uint b7, \
5389 uint b8, \
5390 uint b9, \
5391 uint b10, \
5392 uint b11) {{"
5393 )?;
5394 writeln!(
5395 self.out,
5396 "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5397 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5398 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5399 back::INDENT
5400 )?;
5401 writeln!(self.out, "}}")?;
5402 Ok((name, 12, 3))
5403 }
5404 Float32x4 => {
5405 let name = self.namer.call("unpackFloat32x4");
5406 writeln!(
5407 self.out,
5408 "metal::float4 {name}(uint b0, \
5409 uint b1, \
5410 uint b2, \
5411 uint b3, \
5412 uint b4, \
5413 uint b5, \
5414 uint b6, \
5415 uint b7, \
5416 uint b8, \
5417 uint b9, \
5418 uint b10, \
5419 uint b11, \
5420 uint b12, \
5421 uint b13, \
5422 uint b14, \
5423 uint b15) {{"
5424 )?;
5425 writeln!(
5426 self.out,
5427 "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5428 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5429 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5430 as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5431 back::INDENT
5432 )?;
5433 writeln!(self.out, "}}")?;
5434 Ok((name, 16, 4))
5435 }
5436 Uint32 => {
5437 let name = self.namer.call("unpackUint32");
5438 writeln!(
5439 self.out,
5440 "uint {name}(uint b0, \
5441 uint b1, \
5442 uint b2, \
5443 uint b3) {{"
5444 )?;
5445 writeln!(
5446 self.out,
5447 "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5448 back::INDENT
5449 )?;
5450 writeln!(self.out, "}}")?;
5451 Ok((name, 4, 1))
5452 }
5453 Uint32x2 => {
5454 let name = self.namer.call("unpackUint32x2");
5455 writeln!(
5456 self.out,
5457 "uint2 {name}(uint b0, \
5458 uint b1, \
5459 uint b2, \
5460 uint b3, \
5461 uint b4, \
5462 uint b5, \
5463 uint b6, \
5464 uint b7) {{"
5465 )?;
5466 writeln!(
5467 self.out,
5468 "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5469 (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5470 back::INDENT
5471 )?;
5472 writeln!(self.out, "}}")?;
5473 Ok((name, 8, 2))
5474 }
5475 Uint32x3 => {
5476 let name = self.namer.call("unpackUint32x3");
5477 writeln!(
5478 self.out,
5479 "uint3 {name}(uint b0, \
5480 uint b1, \
5481 uint b2, \
5482 uint b3, \
5483 uint b4, \
5484 uint b5, \
5485 uint b6, \
5486 uint b7, \
5487 uint b8, \
5488 uint b9, \
5489 uint b10, \
5490 uint b11) {{"
5491 )?;
5492 writeln!(
5493 self.out,
5494 "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5495 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5496 (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5497 back::INDENT
5498 )?;
5499 writeln!(self.out, "}}")?;
5500 Ok((name, 12, 3))
5501 }
5502 Uint32x4 => {
5503 let name = self.namer.call("unpackUint32x4");
5504 writeln!(
5505 self.out,
5506 "{NAMESPACE}::uint4 {name}(uint b0, \
5507 uint b1, \
5508 uint b2, \
5509 uint b3, \
5510 uint b4, \
5511 uint b5, \
5512 uint b6, \
5513 uint b7, \
5514 uint b8, \
5515 uint b9, \
5516 uint b10, \
5517 uint b11, \
5518 uint b12, \
5519 uint b13, \
5520 uint b14, \
5521 uint b15) {{"
5522 )?;
5523 writeln!(
5524 self.out,
5525 "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5526 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5527 (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5528 (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5529 back::INDENT
5530 )?;
5531 writeln!(self.out, "}}")?;
5532 Ok((name, 16, 4))
5533 }
5534 Sint32 => {
5535 let name = self.namer.call("unpackSint32");
5536 writeln!(
5537 self.out,
5538 "int {name}(uint b0, \
5539 uint b1, \
5540 uint b2, \
5541 uint b3) {{"
5542 )?;
5543 writeln!(
5544 self.out,
5545 "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5546 back::INDENT
5547 )?;
5548 writeln!(self.out, "}}")?;
5549 Ok((name, 4, 1))
5550 }
5551 Sint32x2 => {
5552 let name = self.namer.call("unpackSint32x2");
5553 writeln!(
5554 self.out,
5555 "metal::int2 {name}(uint b0, \
5556 uint b1, \
5557 uint b2, \
5558 uint b3, \
5559 uint b4, \
5560 uint b5, \
5561 uint b6, \
5562 uint b7) {{"
5563 )?;
5564 writeln!(
5565 self.out,
5566 "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5567 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5568 back::INDENT
5569 )?;
5570 writeln!(self.out, "}}")?;
5571 Ok((name, 8, 2))
5572 }
5573 Sint32x3 => {
5574 let name = self.namer.call("unpackSint32x3");
5575 writeln!(
5576 self.out,
5577 "metal::int3 {name}(uint b0, \
5578 uint b1, \
5579 uint b2, \
5580 uint b3, \
5581 uint b4, \
5582 uint b5, \
5583 uint b6, \
5584 uint b7, \
5585 uint b8, \
5586 uint b9, \
5587 uint b10, \
5588 uint b11) {{"
5589 )?;
5590 writeln!(
5591 self.out,
5592 "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5593 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5594 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5595 back::INDENT
5596 )?;
5597 writeln!(self.out, "}}")?;
5598 Ok((name, 12, 3))
5599 }
5600 Sint32x4 => {
5601 let name = self.namer.call("unpackSint32x4");
5602 writeln!(
5603 self.out,
5604 "metal::int4 {name}(uint b0, \
5605 uint b1, \
5606 uint b2, \
5607 uint b3, \
5608 uint b4, \
5609 uint b5, \
5610 uint b6, \
5611 uint b7, \
5612 uint b8, \
5613 uint b9, \
5614 uint b10, \
5615 uint b11, \
5616 uint b12, \
5617 uint b13, \
5618 uint b14, \
5619 uint b15) {{"
5620 )?;
5621 writeln!(
5622 self.out,
5623 "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5624 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5625 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5626 as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5627 back::INDENT
5628 )?;
5629 writeln!(self.out, "}}")?;
5630 Ok((name, 16, 4))
5631 }
5632 Unorm10_10_10_2 => {
5633 let name = self.namer.call("unpackUnorm10_10_10_2");
5634 writeln!(
5635 self.out,
5636 "metal::float4 {name}(uint b0, \
5637 uint b1, \
5638 uint b2, \
5639 uint b3) {{"
5640 )?;
5641 writeln!(
5642 self.out,
5643 "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5655 back::INDENT
5656 )?;
5657 writeln!(self.out, "}}")?;
5658 Ok((name, 4, 4))
5659 }
5660 Unorm8x4Bgra => {
5661 let name = self.namer.call("unpackUnorm8x4Bgra");
5662 writeln!(
5663 self.out,
5664 "metal::float4 {name}(metal::uchar b0, \
5665 metal::uchar b1, \
5666 metal::uchar b2, \
5667 metal::uchar b3) {{"
5668 )?;
5669 writeln!(
5670 self.out,
5671 "{}return metal::float4(float(b2) / 255.0f, \
5672 float(b1) / 255.0f, \
5673 float(b0) / 255.0f, \
5674 float(b3) / 255.0f);",
5675 back::INDENT
5676 )?;
5677 writeln!(self.out, "}}")?;
5678 Ok((name, 4, 4))
5679 }
5680 }
5681 }
5682
5683 fn write_wrapped_unary_op(
5684 &mut self,
5685 module: &crate::Module,
5686 func_ctx: &back::FunctionCtx,
5687 op: crate::UnaryOperator,
5688 operand: Handle<crate::Expression>,
5689 ) -> BackendResult {
5690 let operand_ty = func_ctx.resolve_type(operand, &module.types);
5691 match op {
5692 crate::UnaryOperator::Negate
5699 if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5700 {
5701 let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5702 return Ok(());
5703 };
5704 let wrapped = WrappedFunction::UnaryOp {
5705 op,
5706 ty: (vector_size, scalar),
5707 };
5708 if !self.wrapped_functions.insert(wrapped) {
5709 return Ok(());
5710 }
5711
5712 let unsigned_scalar = crate::Scalar {
5713 kind: crate::ScalarKind::Uint,
5714 ..scalar
5715 };
5716 let mut type_name = String::new();
5717 let mut unsigned_type_name = String::new();
5718 match vector_size {
5719 None => {
5720 put_numeric_type(&mut type_name, scalar, &[])?;
5721 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5722 }
5723 Some(size) => {
5724 put_numeric_type(&mut type_name, scalar, &[size])?;
5725 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5726 }
5727 };
5728
5729 writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5730 let level = back::Level(1);
5731 writeln!(
5732 self.out,
5733 "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5734 )?;
5735 writeln!(self.out, "}}")?;
5736 writeln!(self.out)?;
5737 }
5738 _ => {}
5739 }
5740 Ok(())
5741 }
5742
5743 fn write_wrapped_binary_op(
5744 &mut self,
5745 module: &crate::Module,
5746 func_ctx: &back::FunctionCtx,
5747 expr: Handle<crate::Expression>,
5748 op: crate::BinaryOperator,
5749 left: Handle<crate::Expression>,
5750 right: Handle<crate::Expression>,
5751 ) -> BackendResult {
5752 let expr_ty = func_ctx.resolve_type(expr, &module.types);
5753 let left_ty = func_ctx.resolve_type(left, &module.types);
5754 let right_ty = func_ctx.resolve_type(right, &module.types);
5755 match (op, expr_ty.scalar_kind()) {
5756 (
5763 crate::BinaryOperator::Divide,
5764 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5765 ) => {
5766 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5767 return Ok(());
5768 };
5769 let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5770 return Ok(());
5771 };
5772 let wrapped = WrappedFunction::BinaryOp {
5773 op,
5774 left_ty: left_wrapped_ty,
5775 right_ty: right_wrapped_ty,
5776 };
5777 if !self.wrapped_functions.insert(wrapped) {
5778 return Ok(());
5779 }
5780
5781 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5782 return Ok(());
5783 };
5784 let mut type_name = String::new();
5785 match vector_size {
5786 None => put_numeric_type(&mut type_name, scalar, &[])?,
5787 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5788 };
5789 writeln!(
5790 self.out,
5791 "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5792 )?;
5793 let level = back::Level(1);
5794 match scalar.kind {
5795 crate::ScalarKind::Sint => {
5796 let min_val = match scalar.width {
5797 4 => crate::Literal::I32(i32::MIN),
5798 8 => crate::Literal::I64(i64::MIN),
5799 _ => {
5800 return Err(Error::GenericValidation(format!(
5801 "Unexpected width for scalar {scalar:?}"
5802 )));
5803 }
5804 };
5805 write!(
5806 self.out,
5807 "{level}return lhs / metal::select(rhs, 1, (lhs == "
5808 )?;
5809 self.put_literal(min_val)?;
5810 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
5811 }
5812 crate::ScalarKind::Uint => writeln!(
5813 self.out,
5814 "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
5815 )?,
5816 _ => unreachable!(),
5817 }
5818 writeln!(self.out, "}}")?;
5819 writeln!(self.out)?;
5820 }
5821 (
5834 crate::BinaryOperator::Modulo,
5835 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5836 ) => {
5837 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5838 return Ok(());
5839 };
5840 let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5841 else {
5842 return Ok(());
5843 };
5844 let wrapped = WrappedFunction::BinaryOp {
5845 op,
5846 left_ty: left_wrapped_ty,
5847 right_ty: (right_vector_size, right_scalar),
5848 };
5849 if !self.wrapped_functions.insert(wrapped) {
5850 return Ok(());
5851 }
5852
5853 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5854 return Ok(());
5855 };
5856 let mut type_name = String::new();
5857 match vector_size {
5858 None => put_numeric_type(&mut type_name, scalar, &[])?,
5859 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5860 };
5861 let mut rhs_type_name = String::new();
5862 match right_vector_size {
5863 None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5864 Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5865 };
5866
5867 writeln!(
5868 self.out,
5869 "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5870 )?;
5871 let level = back::Level(1);
5872 match scalar.kind {
5873 crate::ScalarKind::Sint => {
5874 let min_val = match scalar.width {
5875 4 => crate::Literal::I32(i32::MIN),
5876 8 => crate::Literal::I64(i64::MIN),
5877 _ => {
5878 return Err(Error::GenericValidation(format!(
5879 "Unexpected width for scalar {scalar:?}"
5880 )));
5881 }
5882 };
5883 write!(
5884 self.out,
5885 "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
5886 )?;
5887 self.put_literal(min_val)?;
5888 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
5889 writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5890 }
5891 crate::ScalarKind::Uint => writeln!(
5892 self.out,
5893 "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
5894 )?,
5895 _ => unreachable!(),
5896 }
5897 writeln!(self.out, "}}")?;
5898 writeln!(self.out)?;
5899 }
5900 _ => {}
5901 }
5902 Ok(())
5903 }
5904
5905 fn get_dot_wrapper_function_helper_name(
5911 &self,
5912 scalar: crate::Scalar,
5913 size: crate::VectorSize,
5914 ) -> String {
5915 debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5917
5918 let type_name = scalar.to_msl_name();
5919 let size_suffix = common::vector_size_str(size);
5920 format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5921 }
5922
5923 #[allow(clippy::too_many_arguments)]
5924 fn write_wrapped_math_function(
5925 &mut self,
5926 module: &crate::Module,
5927 func_ctx: &back::FunctionCtx,
5928 fun: crate::MathFunction,
5929 arg: Handle<crate::Expression>,
5930 _arg1: Option<Handle<crate::Expression>>,
5931 _arg2: Option<Handle<crate::Expression>>,
5932 _arg3: Option<Handle<crate::Expression>>,
5933 ) -> BackendResult {
5934 let arg_ty = func_ctx.resolve_type(arg, &module.types);
5935 match fun {
5936 crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5944 let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5945 return Ok(());
5946 };
5947 let wrapped = WrappedFunction::Math {
5948 fun,
5949 arg_ty: (vector_size, scalar),
5950 };
5951 if !self.wrapped_functions.insert(wrapped) {
5952 return Ok(());
5953 }
5954
5955 let unsigned_scalar = crate::Scalar {
5956 kind: crate::ScalarKind::Uint,
5957 ..scalar
5958 };
5959 let mut type_name = String::new();
5960 let mut unsigned_type_name = String::new();
5961 match vector_size {
5962 None => {
5963 put_numeric_type(&mut type_name, scalar, &[])?;
5964 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5965 }
5966 Some(size) => {
5967 put_numeric_type(&mut type_name, scalar, &[size])?;
5968 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5969 }
5970 };
5971
5972 writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5973 let level = back::Level(1);
5974 writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
5975 writeln!(self.out, "}}")?;
5976 writeln!(self.out)?;
5977 }
5978
5979 crate::MathFunction::Dot => match *arg_ty {
5980 crate::TypeInner::Vector { size, scalar }
5981 if matches!(
5982 scalar.kind,
5983 crate::ScalarKind::Sint | crate::ScalarKind::Uint
5984 ) =>
5985 {
5986 let wrapped = WrappedFunction::Math {
5988 fun,
5989 arg_ty: (Some(size), scalar),
5990 };
5991 if !self.wrapped_functions.insert(wrapped) {
5992 return Ok(());
5993 }
5994
5995 let mut vec_ty = String::new();
5996 put_numeric_type(&mut vec_ty, scalar, &[size])?;
5997 let mut ret_ty = String::new();
5998 put_numeric_type(&mut ret_ty, scalar, &[])?;
5999
6000 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
6001
6002 writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
6004 let level = back::Level(1);
6005 write!(self.out, "{level}return ")?;
6006 self.put_dot_product("a", "b", size as usize, |writer, name, index| {
6007 write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
6008 Ok(())
6009 })?;
6010 writeln!(self.out, ";")?;
6011 writeln!(self.out, "}}")?;
6012 writeln!(self.out)?;
6013 }
6014 _ => {}
6015 },
6016
6017 _ => {}
6018 }
6019 Ok(())
6020 }
6021
6022 fn write_wrapped_cast(
6023 &mut self,
6024 module: &crate::Module,
6025 func_ctx: &back::FunctionCtx,
6026 expr: Handle<crate::Expression>,
6027 kind: crate::ScalarKind,
6028 convert: Option<crate::Bytes>,
6029 ) -> BackendResult {
6030 let src_ty = func_ctx.resolve_type(expr, &module.types);
6041 let Some(width) = convert else {
6042 return Ok(());
6043 };
6044 let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6045 return Ok(());
6046 };
6047 let dst_scalar = crate::Scalar { kind, width };
6048 if src_scalar.kind != crate::ScalarKind::Float
6049 || (dst_scalar.kind != crate::ScalarKind::Sint
6050 && dst_scalar.kind != crate::ScalarKind::Uint)
6051 {
6052 return Ok(());
6053 }
6054 let wrapped = WrappedFunction::Cast {
6055 src_scalar,
6056 vector_size,
6057 dst_scalar,
6058 };
6059 if !self.wrapped_functions.insert(wrapped) {
6060 return Ok(());
6061 }
6062 let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6063
6064 let mut src_type_name = String::new();
6065 match vector_size {
6066 None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6067 Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6068 };
6069 let mut dst_type_name = String::new();
6070 match vector_size {
6071 None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6072 Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6073 };
6074 let fun_name = match dst_scalar {
6075 crate::Scalar::I32 => F2I32_FUNCTION,
6076 crate::Scalar::U32 => F2U32_FUNCTION,
6077 crate::Scalar::I64 => F2I64_FUNCTION,
6078 crate::Scalar::U64 => F2U64_FUNCTION,
6079 _ => unreachable!(),
6080 };
6081
6082 writeln!(
6083 self.out,
6084 "{dst_type_name} {fun_name}({src_type_name} value) {{"
6085 )?;
6086 let level = back::Level(1);
6087 write!(
6088 self.out,
6089 "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6090 )?;
6091 self.put_literal(min)?;
6092 write!(self.out, ", ")?;
6093 self.put_literal(max)?;
6094 writeln!(self.out, "));")?;
6095 writeln!(self.out, "}}")?;
6096 writeln!(self.out)?;
6097 Ok(())
6098 }
6099
    /// Writes MSL statements that convert a YUV sample to RGB and `return` it
    /// as `float4(rgb, 1.0)`.
    ///
    /// `y`, `uv` and `params` are MSL expressions spliced verbatim into the
    /// generated code. `params` must name a value with the members used below:
    /// `yuv_conversion_matrix`, `gamut_conversion_matrix`, `src_tf` and `dst_tf`.
    /// Statements are written at `level`. Continuation lines of the multi-line
    /// `select(...)` calls are written one level deeper.
    ///
    /// The emitted code ends in a `return`, so it should be the last thing the
    /// caller writes in that generated block.
    fn write_convert_yuv_to_rgb_and_return(
        &mut self,
        level: back::Level,
        y: &str,
        uv: &str,
        params: &str,
    ) -> BackendResult {
        let l1 = level;
        let l2 = l1.next();

        // YUV -> non-linear RGB via the conversion matrix applied to (y, uv, 1).
        writeln!(
            self.out,
            "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
        )?;

        // Piecewise curve on the source values. MSL `select(a, b, c)` yields `b`
        // where `c` holds. Components below `src_tf.k * src_tf.b` use the linear
        // segment `x / k`; the rest use the power segment.
        // NOTE(review): the `*_tf` members (a, b, g, k) look like
        // transfer-function parameters; confirm against the params struct.
        writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
        writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
        writeln!(
            self.out,
            "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
        )?;

        // Map the linearized values through the gamut conversion matrix.
        writeln!(
            self.out,
            "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
        )?;

        // Destination curve: linear segment `k * x` below `dst_tf.b`, otherwise
        // `a * pow(x, 1 / g) - (a - 1)`.
        writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
        writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
        writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;

        // Alpha is always 1.0.
        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
        Ok(())
    }
6150
    /// Emits the `nagaTextureLoadExternal` helper when `image` is an external
    /// texture. Nothing is written for other image classes, or if the helper is
    /// already recorded in `self.wrapped_functions`.
    ///
    /// Only the type of `image` is inspected; the other expression handles are
    /// unused here.
    ///
    /// The generated MSL function:
    /// - clamps `coords` to the cropped size minus one. The cropped size is
    ///   `tex.params.size` when any component is non-zero, otherwise plane 0's size.
    /// - maps the coordinates through `tex.params.load_transform` and rounds them
    ///   to get plane 0 coordinates.
    /// - with `num_planes == 1u`, returns the plane 0 texel directly.
    /// - otherwise reads `y` from plane 0 `.x`, and `uv` either from plane 1
    ///   `.xy` (`num_planes == 2u`) or from plane 1 `.x` / plane 2 `.x` (any other
    ///   count). Plane coordinates are scaled by each plane's size relative to
    ///   plane 0. It finishes via `write_convert_yuv_to_rgb_and_return`.
    #[allow(clippy::too_many_arguments)]
    fn write_wrapped_image_load(
        &mut self,
        module: &crate::Module,
        func_ctx: &back::FunctionCtx,
        image: Handle<crate::Expression>,
        _coordinate: Handle<crate::Expression>,
        _array_index: Option<Handle<crate::Expression>>,
        _sample: Option<Handle<crate::Expression>>,
        _level: Option<Handle<crate::Expression>>,
    ) -> BackendResult {
        // Callers only pass image-typed expressions here.
        let class = match *func_ctx.resolve_type(image, &module.types) {
            crate::TypeInner::Image { class, .. } => class,
            _ => unreachable!(),
        };
        if class != crate::ImageClass::External {
            return Ok(());
        }
        let wrapped = WrappedFunction::ImageLoad { class };
        if !self.wrapped_functions.insert(wrapped) {
            return Ok(());
        }

        writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
        let l1 = back::Level(1);
        let l2 = l1.next();
        let l3 = l2.next();
        writeln!(
            self.out,
            "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
        )?;
        // An all-zero `params.size` means "no crop": fall back to plane 0's size.
        writeln!(
            self.out,
            "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
        )?;
        // Clamp to the last valid texel of the cropped region.
        writeln!(
            self.out,
            "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
        )?;

        writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
        // Single-plane textures are returned as read, with no YUV conversion.
        writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
        writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
        writeln!(self.out, "{l1}}} else {{")?;

        // Scale plane 0 coordinates into plane 1's (possibly smaller) size.
        writeln!(
            self.out,
            "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
        )?;
        writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;

        writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;

        // Two planes: interleaved uv in plane 1. Otherwise: u in plane 1, v in plane 2.
        writeln!(self.out, "{l2}float2 uv;")?;
        writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
        writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
        writeln!(self.out, "{l2}}} else {{")?;
        // NOTE(review): the two `plane2_*` lines below are emitted at `l2` even
        // though they sit inside the `else` opened above. The `uv =` line after
        // them uses `l3`. This only affects indentation of the generated MSL;
        // changing it would change the emitted text.
        writeln!(
            self.out,
            "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
        )?;
        writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
        writeln!(
            self.out,
            "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
        )?;
        writeln!(self.out, "{l2}}}")?;

        // Emits the final `return` of the multi-plane branch.
        self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;

        writeln!(self.out, "{l1}}}")?;
        writeln!(self.out, "}}")?;
        writeln!(self.out)?;
        Ok(())
    }
6236
6237 #[allow(clippy::too_many_arguments)]
6238 fn write_wrapped_image_sample(
6239 &mut self,
6240 module: &crate::Module,
6241 func_ctx: &back::FunctionCtx,
6242 image: Handle<crate::Expression>,
6243 _sampler: Handle<crate::Expression>,
6244 _gather: Option<crate::SwizzleComponent>,
6245 _coordinate: Handle<crate::Expression>,
6246 _array_index: Option<Handle<crate::Expression>>,
6247 _offset: Option<Handle<crate::Expression>>,
6248 _level: crate::SampleLevel,
6249 _depth_ref: Option<Handle<crate::Expression>>,
6250 clamp_to_edge: bool,
6251 ) -> BackendResult {
6252 if !clamp_to_edge {
6255 return Ok(());
6256 }
6257 let class = match *func_ctx.resolve_type(image, &module.types) {
6258 crate::TypeInner::Image { class, .. } => class,
6259 _ => unreachable!(),
6260 };
6261 let wrapped = WrappedFunction::ImageSample {
6262 class,
6263 clamp_to_edge: true,
6264 };
6265 if !self.wrapped_functions.insert(wrapped) {
6266 return Ok(());
6267 }
6268 match class {
6269 crate::ImageClass::External => {
6270 writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6271 let l1 = back::Level(1);
6272 let l2 = l1.next();
6273 let l3 = l2.next();
6274 writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6275 writeln!(
6276 self.out,
6277 "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6278 )?;
6279
6280 writeln!(
6288 self.out,
6289 "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6290 )?;
6291 writeln!(
6292 self.out,
6293 "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6294 )?;
6295 writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6296 writeln!(
6297 self.out,
6298 "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6299 )?;
6300 writeln!(
6301 self.out,
6302 "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6303 )?;
6304 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6305 writeln!(
6307 self.out,
6308 "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6309 )?;
6310 writeln!(self.out, "{l1}}} else {{")?;
6311 writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6312 writeln!(
6313 self.out,
6314 "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6315 )?;
6316 writeln!(
6317 self.out,
6318 "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6319 )?;
6320
6321 writeln!(
6323 self.out,
6324 "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6325 )?;
6326 writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6327 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6328 writeln!(
6330 self.out,
6331 "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6332 )?;
6333 writeln!(self.out, "{l2}}} else {{")?;
6334 writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6336 writeln!(
6337 self.out,
6338 "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6339 )?;
6340 writeln!(
6341 self.out,
6342 "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6343 )?;
6344 writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6345 writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6346 writeln!(self.out, "{l2}}}")?;
6347
6348 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6349
6350 writeln!(self.out, "{l1}}}")?;
6351 writeln!(self.out, "}}")?;
6352 writeln!(self.out)?;
6353 }
6354 _ => {
6355 writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6356 let l1 = back::Level(1);
6357 writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6358 writeln!(
6359 self.out,
6360 "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6361 )?;
6362 writeln!(self.out, "}}")?;
6363 writeln!(self.out)?;
6364 }
6365 }
6366 Ok(())
6367 }
6368
6369 fn write_wrapped_image_query(
6370 &mut self,
6371 module: &crate::Module,
6372 func_ctx: &back::FunctionCtx,
6373 image: Handle<crate::Expression>,
6374 query: crate::ImageQuery,
6375 ) -> BackendResult {
6376 if !matches!(query, crate::ImageQuery::Size { .. }) {
6378 return Ok(());
6379 }
6380 let class = match *func_ctx.resolve_type(image, &module.types) {
6381 crate::TypeInner::Image { class, .. } => class,
6382 _ => unreachable!(),
6383 };
6384 if class != crate::ImageClass::External {
6385 return Ok(());
6386 }
6387 let wrapped = WrappedFunction::ImageQuerySize { class };
6388 if !self.wrapped_functions.insert(wrapped) {
6389 return Ok(());
6390 }
6391 writeln!(
6392 self.out,
6393 "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6394 )?;
6395 let l1 = back::Level(1);
6396 let l2 = l1.next();
6397 writeln!(
6398 self.out,
6399 "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6400 )?;
6401 writeln!(self.out, "{l2}return tex.params.size;")?;
6402 writeln!(self.out, "{l1}}} else {{")?;
6403 writeln!(
6405 self.out,
6406 "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6407 )?;
6408 writeln!(self.out, "{l1}}}")?;
6409 writeln!(self.out, "}}")?;
6410 writeln!(self.out)?;
6411 Ok(())
6412 }
6413
6414 fn write_wrapped_cooperative_load(
6415 &mut self,
6416 module: &crate::Module,
6417 func_ctx: &back::FunctionCtx,
6418 columns: crate::CooperativeSize,
6419 rows: crate::CooperativeSize,
6420 pointer: Handle<crate::Expression>,
6421 ) -> BackendResult {
6422 let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6423 let space = ptr_ty.pointer_space().unwrap();
6424 let space_name = space.to_msl_name().unwrap_or_default();
6425 let scalar = ptr_ty
6426 .pointer_base_type()
6427 .unwrap()
6428 .inner_with(&module.types)
6429 .scalar()
6430 .unwrap();
6431 let wrapped = WrappedFunction::CooperativeLoad {
6432 space_name,
6433 columns,
6434 rows,
6435 scalar,
6436 };
6437 if !self.wrapped_functions.insert(wrapped) {
6438 return Ok(());
6439 }
6440 let scalar_name = scalar.to_msl_name();
6441 writeln!(
6442 self.out,
6443 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6444 columns as u32, rows as u32,
6445 )?;
6446 let l1 = back::Level(1);
6447 writeln!(
6448 self.out,
6449 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6450 columns as u32, rows as u32
6451 )?;
6452 let matrix_origin = "0";
6453 writeln!(
6454 self.out,
6455 "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6456 )?;
6457 writeln!(self.out, "{l1}return m;")?;
6458 writeln!(self.out, "}}")?;
6459 writeln!(self.out)?;
6460 Ok(())
6461 }
6462
6463 fn write_wrapped_cooperative_multiply_add(
6464 &mut self,
6465 module: &crate::Module,
6466 func_ctx: &back::FunctionCtx,
6467 space: crate::AddressSpace,
6468 a: Handle<crate::Expression>,
6469 b: Handle<crate::Expression>,
6470 ) -> BackendResult {
6471 let space_name = space.to_msl_name().unwrap_or_default();
6472 let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6473 crate::TypeInner::CooperativeMatrix {
6474 columns,
6475 rows,
6476 scalar,
6477 ..
6478 } => (columns, rows, scalar),
6479 _ => unreachable!(),
6480 };
6481 let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6482 crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6483 _ => unreachable!(),
6484 };
6485 let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6486 space_name,
6487 columns: b_c,
6488 rows: a_r,
6489 intermediate: a_c,
6490 scalar,
6491 };
6492 if !self.wrapped_functions.insert(wrapped) {
6493 return Ok(());
6494 }
6495 let scalar_name = scalar.to_msl_name();
6496 writeln!(
6497 self.out,
6498 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6499 b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6500 )?;
6501 let l1 = back::Level(1);
6502 writeln!(
6503 self.out,
6504 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6505 b_c as u32, a_r as u32
6506 )?;
6507 writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6508 writeln!(self.out, "{l1}return d;")?;
6509 writeln!(self.out, "}}")?;
6510 writeln!(self.out)?;
6511 Ok(())
6512 }
6513
6514 pub(super) fn write_wrapped_functions(
6515 &mut self,
6516 module: &crate::Module,
6517 func_ctx: &back::FunctionCtx,
6518 ) -> BackendResult {
6519 for (expr_handle, expr) in func_ctx.expressions.iter() {
6520 match *expr {
6521 crate::Expression::Unary { op, expr: operand } => {
6522 self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6523 }
6524 crate::Expression::Binary { op, left, right } => {
6525 self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6526 }
6527 crate::Expression::Math {
6528 fun,
6529 arg,
6530 arg1,
6531 arg2,
6532 arg3,
6533 } => {
6534 self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6535 }
6536 crate::Expression::As {
6537 expr,
6538 kind,
6539 convert,
6540 } => {
6541 self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6542 }
6543 crate::Expression::ImageLoad {
6544 image,
6545 coordinate,
6546 array_index,
6547 sample,
6548 level,
6549 } => {
6550 self.write_wrapped_image_load(
6551 module,
6552 func_ctx,
6553 image,
6554 coordinate,
6555 array_index,
6556 sample,
6557 level,
6558 )?;
6559 }
6560 crate::Expression::ImageSample {
6561 image,
6562 sampler,
6563 gather,
6564 coordinate,
6565 array_index,
6566 offset,
6567 level,
6568 depth_ref,
6569 clamp_to_edge,
6570 } => {
6571 self.write_wrapped_image_sample(
6572 module,
6573 func_ctx,
6574 image,
6575 sampler,
6576 gather,
6577 coordinate,
6578 array_index,
6579 offset,
6580 level,
6581 depth_ref,
6582 clamp_to_edge,
6583 )?;
6584 }
6585 crate::Expression::ImageQuery { image, query } => {
6586 self.write_wrapped_image_query(module, func_ctx, image, query)?;
6587 }
6588 crate::Expression::CooperativeLoad {
6589 columns,
6590 rows,
6591 role: _,
6592 ref data,
6593 } => {
6594 self.write_wrapped_cooperative_load(
6595 module,
6596 func_ctx,
6597 columns,
6598 rows,
6599 data.pointer,
6600 )?;
6601 }
6602 crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6603 let space = crate::AddressSpace::Private;
6604 self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
6605 }
6606 _ => {}
6607 }
6608 }
6609
6610 Ok(())
6611 }
6612
6613 fn write_functions(
6615 &mut self,
6616 module: &crate::Module,
6617 mod_info: &valid::ModuleInfo,
6618 options: &Options,
6619 pipeline_options: &PipelineOptions,
6620 ) -> Result<TranslationInfo, Error> {
6621 use back::msl::VertexFormat;
6622
6623 struct AttributeMappingResolved {
6626 ty_name: String,
6627 dimension: u32,
6628 ty_is_int: bool,
6629 name: String,
6630 }
6631 let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6632
6633 struct VertexBufferMappingResolved<'a> {
6634 id: u32,
6635 stride: u32,
6636 step_mode: back::msl::VertexBufferStepMode,
6637 ty_name: String,
6638 param_name: String,
6639 elem_name: String,
6640 attributes: &'a Vec<back::msl::AttributeMapping>,
6641 }
6642 let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6643
6644 struct UnpackingFunction {
6646 name: String,
6647 byte_count: u32,
6648 dimension: u32,
6649 }
6650 let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6651
6652 let mut needs_vertex_id = false;
6658 let v_id = self.namer.call("v_id");
6659
6660 let mut needs_instance_id = false;
6661 let i_id = self.namer.call("i_id");
6662 if pipeline_options.vertex_pulling_transform {
6663 for vbm in &pipeline_options.vertex_buffer_mappings {
6664 let buffer_id = vbm.id;
6665 let buffer_stride = vbm.stride;
6666
6667 assert!(
6668 buffer_stride > 0,
6669 "Vertex pulling requires a non-zero buffer stride."
6670 );
6671
6672 match vbm.step_mode {
6673 back::msl::VertexBufferStepMode::Constant => {}
6674 back::msl::VertexBufferStepMode::ByVertex => {
6675 needs_vertex_id = true;
6676 }
6677 back::msl::VertexBufferStepMode::ByInstance => {
6678 needs_instance_id = true;
6679 }
6680 }
6681
6682 let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6683 let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6684 let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6685
6686 vbm_resolved.push(VertexBufferMappingResolved {
6687 id: buffer_id,
6688 stride: buffer_stride,
6689 step_mode: vbm.step_mode,
6690 ty_name: buffer_ty,
6691 param_name: buffer_param,
6692 elem_name: buffer_elem,
6693 attributes: &vbm.attributes,
6694 });
6695
6696 for attribute in &vbm.attributes {
6698 if unpacking_functions.contains_key(&attribute.format) {
6699 continue;
6700 }
6701 let (name, byte_count, dimension) =
6702 match self.write_unpacking_function(attribute.format) {
6703 Ok((name, byte_count, dimension)) => (name, byte_count, dimension),
6704 _ => {
6705 continue;
6706 }
6707 };
6708 unpacking_functions.insert(
6709 attribute.format,
6710 UnpackingFunction {
6711 name,
6712 byte_count,
6713 dimension,
6714 },
6715 );
6716 }
6717 }
6718 }
6719
6720 let mut pass_through_globals = Vec::new();
6721 for (fun_handle, fun) in module.functions.iter() {
6722 log::trace!(
6723 "function {:?}, handle {:?}",
6724 fun.name.as_deref().unwrap_or("(anonymous)"),
6725 fun_handle
6726 );
6727
6728 let ctx = back::FunctionCtx {
6729 ty: back::FunctionType::Function(fun_handle),
6730 info: &mod_info[fun_handle],
6731 expressions: &fun.expressions,
6732 named_expressions: &fun.named_expressions,
6733 };
6734
6735 writeln!(self.out)?;
6736 self.write_wrapped_functions(module, &ctx)?;
6737
6738 let fun_info = &mod_info[fun_handle];
6739 pass_through_globals.clear();
6740 let mut needs_buffer_sizes = false;
6741 for (handle, var) in module.global_variables.iter() {
6742 if !fun_info[handle].is_empty() {
6743 if var.space.needs_pass_through() {
6744 pass_through_globals.push(handle);
6745 }
6746 needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6747 }
6748 }
6749
6750 let fun_name = &self.names[&NameKey::Function(fun_handle)];
6751 match fun.result {
6752 Some(ref result) => {
6753 let ty_name = TypeContext {
6754 handle: result.ty,
6755 gctx: module.to_ctx(),
6756 names: &self.names,
6757 access: crate::StorageAccess::empty(),
6758 first_time: false,
6759 };
6760 write!(self.out, "{ty_name}")?;
6761 }
6762 None => {
6763 write!(self.out, "void")?;
6764 }
6765 }
6766 writeln!(self.out, " {fun_name}(")?;
6767
6768 for (index, arg) in fun.arguments.iter().enumerate() {
6769 let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6770 let param_type_name = TypeContext {
6771 handle: arg.ty,
6772 gctx: module.to_ctx(),
6773 names: &self.names,
6774 access: crate::StorageAccess::empty(),
6775 first_time: false,
6776 };
6777 let separator = separate(
6778 !pass_through_globals.is_empty()
6779 || index + 1 != fun.arguments.len()
6780 || needs_buffer_sizes,
6781 );
6782 writeln!(
6783 self.out,
6784 "{}{} {}{}",
6785 back::INDENT,
6786 param_type_name,
6787 name,
6788 separator
6789 )?;
6790 }
6791 for (index, &handle) in pass_through_globals.iter().enumerate() {
6792 let tyvar = TypedGlobalVariable {
6793 module,
6794 names: &self.names,
6795 handle,
6796 usage: fun_info[handle],
6797 reference: true,
6798 };
6799 let separator =
6800 separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6801 write!(self.out, "{}", back::INDENT)?;
6802 tyvar.try_fmt(&mut self.out)?;
6803 writeln!(self.out, "{separator}")?;
6804 }
6805
6806 if needs_buffer_sizes {
6807 writeln!(
6808 self.out,
6809 "{}constant _mslBufferSizes& _buffer_sizes",
6810 back::INDENT
6811 )?;
6812 }
6813
6814 writeln!(self.out, ") {{")?;
6815
6816 let guarded_indices =
6817 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6818
6819 let context = StatementContext {
6820 expression: ExpressionContext {
6821 function: fun,
6822 origin: FunctionOrigin::Handle(fun_handle),
6823 info: fun_info,
6824 lang_version: options.lang_version,
6825 policies: options.bounds_check_policies,
6826 guarded_indices,
6827 module,
6828 mod_info,
6829 pipeline_options,
6830 force_loop_bounding: options.force_loop_bounding,
6831 },
6832 result_struct: None,
6833 };
6834
6835 self.put_locals(&context.expression)?;
6836 self.update_expressions_to_bake(fun, fun_info, &context.expression);
6837 self.put_block(back::Level(1), &fun.body, &context)?;
6838 writeln!(self.out, "}}")?;
6839 self.named_expressions.clear();
6840 }
6841
6842 let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6843 .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6844
6845 let mut info = TranslationInfo {
6846 entry_point_names: Vec::with_capacity(ep_range.len()),
6847 };
6848
6849 for ep_index in ep_range {
6850 let ep = &module.entry_points[ep_index];
6851 let fun = &ep.function;
6852 let fun_info = mod_info.get_entry_point(ep_index);
6853 let mut ep_error = None;
6854
6855 let mut v_existing_id = None;
6859 let mut i_existing_id = None;
6860
6861 log::trace!(
6862 "entry point {:?}, index {:?}",
6863 fun.name.as_deref().unwrap_or("(anonymous)"),
6864 ep_index
6865 );
6866
6867 let ctx = back::FunctionCtx {
6868 ty: back::FunctionType::EntryPoint(ep_index as u16),
6869 info: fun_info,
6870 expressions: &fun.expressions,
6871 named_expressions: &fun.named_expressions,
6872 };
6873
6874 self.write_wrapped_functions(module, &ctx)?;
6875
6876 let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6877 crate::ShaderStage::Vertex => (
6878 "vertex",
6879 LocationMode::VertexInput,
6880 LocationMode::VertexOutput,
6881 true,
6882 ),
6883 crate::ShaderStage::Fragment => (
6884 "fragment",
6885 LocationMode::FragmentInput,
6886 LocationMode::FragmentOutput,
6887 false,
6888 ),
6889 crate::ShaderStage::Compute => (
6890 "kernel",
6891 LocationMode::Uniform,
6892 LocationMode::Uniform,
6893 false,
6894 ),
6895 crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(),
6896 crate::ShaderStage::RayGeneration
6897 | crate::ShaderStage::AnyHit
6898 | crate::ShaderStage::ClosestHit
6899 | crate::ShaderStage::Miss => unimplemented!(),
6900 };
6901
6902 let do_vertex_pulling = can_vertex_pull
6904 && pipeline_options.vertex_pulling_transform
6905 && !pipeline_options.vertex_buffer_mappings.is_empty();
6906
6907 let needs_buffer_sizes = do_vertex_pulling
6909 || module
6910 .global_variables
6911 .iter()
6912 .filter(|&(handle, _)| !fun_info[handle].is_empty())
6913 .any(|(_, var)| needs_array_length(var.ty, &module.types));
6914
6915 if !options.fake_missing_bindings {
6918 for (var_handle, var) in module.global_variables.iter() {
6919 if fun_info[var_handle].is_empty() {
6920 continue;
6921 }
6922 match var.space {
6923 crate::AddressSpace::Uniform
6924 | crate::AddressSpace::Storage { .. }
6925 | crate::AddressSpace::Handle => {
6926 let br = match var.binding {
6927 Some(ref br) => br,
6928 None => {
6929 let var_name = var.name.clone().unwrap_or_default();
6930 ep_error =
6931 Some(super::EntryPointError::MissingBinding(var_name));
6932 break;
6933 }
6934 };
6935 let target = options.get_resource_binding_target(ep, br);
6936 let good = match target {
6937 Some(target) => {
6938 match module.types[var.ty].inner {
6942 crate::TypeInner::Image {
6943 class: crate::ImageClass::External,
6944 ..
6945 } => target.external_texture.is_some(),
6946 crate::TypeInner::Image { .. } => target.texture.is_some(),
6947 crate::TypeInner::Sampler { .. } => {
6948 target.sampler.is_some()
6949 }
6950 _ => target.buffer.is_some(),
6951 }
6952 }
6953 None => false,
6954 };
6955 if !good {
6956 ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6957 break;
6958 }
6959 }
6960 crate::AddressSpace::Immediate => {
6961 if let Err(e) = options.resolve_immediates(ep) {
6962 ep_error = Some(e);
6963 break;
6964 }
6965 }
6966 crate::AddressSpace::TaskPayload => {
6967 unimplemented!()
6968 }
6969 crate::AddressSpace::Function
6970 | crate::AddressSpace::Private
6971 | crate::AddressSpace::WorkGroup => {}
6972 crate::AddressSpace::RayPayload
6973 | crate::AddressSpace::IncomingRayPayload => unimplemented!(),
6974 }
6975 }
6976 if needs_buffer_sizes {
6977 if let Err(err) = options.resolve_sizes_buffer(ep) {
6978 ep_error = Some(err);
6979 }
6980 }
6981 }
6982
6983 if let Some(err) = ep_error {
6984 info.entry_point_names.push(Err(err));
6985 continue;
6986 }
6987 let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
6988 info.entry_point_names.push(Ok(fun_name.clone()));
6989
6990 writeln!(self.out)?;
6991
6992 let mut flattened_member_names = FastHashMap::default();
6998 let mut varyings_namer = proc::Namer::default();
7000
7001 let mut flattened_arguments = Vec::new();
7006 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7007 match module.types[arg.ty].inner {
7008 crate::TypeInner::Struct { ref members, .. } => {
7009 for (member_index, member) in members.iter().enumerate() {
7010 let member_index = member_index as u32;
7011 flattened_arguments.push((
7012 NameKey::StructMember(arg.ty, member_index),
7013 member.ty,
7014 member.binding.as_ref(),
7015 ));
7016 let name_key = NameKey::StructMember(arg.ty, member_index);
7017 let name = match member.binding {
7018 Some(crate::Binding::Location { .. }) => {
7019 if do_vertex_pulling {
7020 self.namer.call(&self.names[&name_key])
7021 } else {
7022 varyings_namer.call(&self.names[&name_key])
7023 }
7024 }
7025 _ => self.namer.call(&self.names[&name_key]),
7026 };
7027 flattened_member_names.insert(name_key, name);
7028 }
7029 }
7030 _ => flattened_arguments.push((
7031 NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7032 arg.ty,
7033 arg.binding.as_ref(),
7034 )),
7035 }
7036 }
7037
7038 let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7043 let varyings_member_name = self.namer.call("varyings");
7044 let mut has_varyings = false;
7045 if !flattened_arguments.is_empty() {
7046 if !do_vertex_pulling {
7047 writeln!(self.out, "struct {stage_in_name} {{")?;
7048 }
7049 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7050 let Some(binding) = binding else {
7051 continue;
7052 };
7053 let name = match *name_key {
7054 NameKey::StructMember(..) => &flattened_member_names[name_key],
7055 _ => &self.names[name_key],
7056 };
7057 let ty_name = TypeContext {
7058 handle: ty,
7059 gctx: module.to_ctx(),
7060 names: &self.names,
7061 access: crate::StorageAccess::empty(),
7062 first_time: false,
7063 };
7064 let resolved = options.resolve_local_binding(binding, in_mode)?;
7065 let location = match *binding {
7066 crate::Binding::Location { location, .. } => Some(location),
7067 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7068 crate::Binding::BuiltIn(_) => continue,
7069 };
7070 if do_vertex_pulling {
7071 let Some(location) = location else {
7072 continue;
7073 };
7074 am_resolved.insert(
7076 location,
7077 AttributeMappingResolved {
7078 ty_name: ty_name.to_string(),
7079 dimension: ty_name.vertex_input_dimension(),
7080 ty_is_int: ty_name.scalar().is_some_and(scalar_is_int),
7081 name: name.to_string(),
7082 },
7083 );
7084 } else {
7085 has_varyings = true;
7086 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7087 resolved.try_fmt(&mut self.out)?;
7088 writeln!(self.out, ";")?;
7089 }
7090 }
7091 if !do_vertex_pulling {
7092 writeln!(self.out, "}};")?;
7093 }
7094 }
7095
7096 let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7099 let result_member_name = self.namer.call("member");
7100 let result_type_name = match fun.result {
7101 Some(ref result) => {
7102 let mut result_members = Vec::new();
7103 if let crate::TypeInner::Struct { ref members, .. } =
7104 module.types[result.ty].inner
7105 {
7106 for (member_index, member) in members.iter().enumerate() {
7107 result_members.push((
7108 &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7109 member.ty,
7110 member.binding.as_ref(),
7111 ));
7112 }
7113 } else {
7114 result_members.push((
7115 &result_member_name,
7116 result.ty,
7117 result.binding.as_ref(),
7118 ));
7119 }
7120
7121 writeln!(self.out, "struct {stage_out_name} {{")?;
7122 let mut has_point_size = false;
7123 for (name, ty, binding) in result_members {
7124 let ty_name = TypeContext {
7125 handle: ty,
7126 gctx: module.to_ctx(),
7127 names: &self.names,
7128 access: crate::StorageAccess::empty(),
7129 first_time: true,
7130 };
7131 let binding = binding.ok_or_else(|| {
7132 Error::GenericValidation("Expected binding, got None".into())
7133 })?;
7134
7135 if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7136 has_point_size = true;
7137 if !pipeline_options.allow_and_force_point_size {
7138 continue;
7139 }
7140 }
7141
7142 let array_len = match module.types[ty].inner {
7143 crate::TypeInner::Array {
7144 size: crate::ArraySize::Constant(size),
7145 ..
7146 } => Some(size),
7147 _ => None,
7148 };
7149 let resolved = options.resolve_local_binding(binding, out_mode)?;
7150 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7151 if let Some(array_len) = array_len {
7152 write!(self.out, " [{array_len}]")?;
7153 }
7154 resolved.try_fmt(&mut self.out)?;
7155 writeln!(self.out, ";")?;
7156 }
7157
7158 if pipeline_options.allow_and_force_point_size
7159 && ep.stage == crate::ShaderStage::Vertex
7160 && !has_point_size
7161 {
7162 writeln!(
7164 self.out,
7165 "{}float _point_size [[point_size]];",
7166 back::INDENT
7167 )?;
7168 }
7169 writeln!(self.out, "}};")?;
7170 &stage_out_name
7171 }
7172 None => "void",
7173 };
7174
7175 if do_vertex_pulling {
7178 for vbm in &vbm_resolved {
7179 let buffer_stride = vbm.stride;
7180 let buffer_ty = &vbm.ty_name;
7181
7182 writeln!(
7186 self.out,
7187 "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7188 )?;
7189 }
7190 }
7191
7192 writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
7194
7195 let mut is_first_argument = true;
7196 let mut separator = || {
7197 if is_first_argument {
7198 is_first_argument = false;
7199 ' '
7200 } else {
7201 ','
7202 }
7203 };
7204
7205 if has_varyings {
7208 writeln!(
7209 self.out,
7210 "{} {stage_in_name} {varyings_member_name} [[stage_in]]",
7211 separator()
7212 )?;
7213 }
7214
7215 let mut local_invocation_id = None;
7216
7217 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7220 let binding = match binding {
7221 Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7222 Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7223 _ => continue,
7224 };
7225 let name = match *name_key {
7226 NameKey::StructMember(..) => &flattened_member_names[name_key],
7227 _ => &self.names[name_key],
7228 };
7229
7230 if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
7231 local_invocation_id = Some(name_key);
7232 }
7233
7234 let ty_name = TypeContext {
7235 handle: ty,
7236 gctx: module.to_ctx(),
7237 names: &self.names,
7238 access: crate::StorageAccess::empty(),
7239 first_time: false,
7240 };
7241
7242 match *binding {
7243 crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7244 v_existing_id = Some(name.clone());
7245 }
7246 crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7247 i_existing_id = Some(name.clone());
7248 }
7249 _ => {}
7250 };
7251
7252 let resolved = options.resolve_local_binding(binding, in_mode)?;
7253 write!(self.out, "{} {ty_name} {name}", separator())?;
7254 resolved.try_fmt(&mut self.out)?;
7255 writeln!(self.out)?;
7256 }
7257
7258 let need_workgroup_variables_initialization =
7259 self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7260
7261 if need_workgroup_variables_initialization && local_invocation_id.is_none() {
7262 writeln!(
7263 self.out,
7264 "{} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]", separator()
7265 )?;
7266 }
7267
7268 for (handle, var) in module.global_variables.iter() {
7273 let usage = fun_info[handle];
7274 if usage.is_empty() || var.space == crate::AddressSpace::Private {
7275 continue;
7276 }
7277
7278 if options.lang_version < (1, 2) {
7279 match var.space {
7280 crate::AddressSpace::Storage { access }
7290 if access.contains(crate::StorageAccess::STORE)
7291 && ep.stage == crate::ShaderStage::Fragment =>
7292 {
7293 return Err(Error::UnsupportedWriteableStorageBuffer)
7294 }
7295 crate::AddressSpace::Handle => {
7296 match module.types[var.ty].inner {
7297 crate::TypeInner::Image {
7298 class: crate::ImageClass::Storage { access, .. },
7299 ..
7300 } => {
7301 if access.contains(crate::StorageAccess::STORE)
7311 && (ep.stage == crate::ShaderStage::Vertex
7312 || ep.stage == crate::ShaderStage::Fragment)
7313 {
7314 return Err(Error::UnsupportedWriteableStorageTexture(
7315 ep.stage,
7316 ));
7317 }
7318
7319 if access.contains(
7320 crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7321 ) {
7322 return Err(Error::UnsupportedRWStorageTexture);
7323 }
7324 }
7325 _ => {}
7326 }
7327 }
7328 _ => {}
7329 }
7330 }
7331
7332 match var.space {
7334 crate::AddressSpace::Handle => match module.types[var.ty].inner {
7335 crate::TypeInner::BindingArray { base, .. } => {
7336 match module.types[base].inner {
7337 crate::TypeInner::Sampler { .. } => {
7338 if options.lang_version < (2, 0) {
7339 return Err(Error::UnsupportedArrayOf(
7340 "samplers".to_string(),
7341 ));
7342 }
7343 }
7344 crate::TypeInner::Image { class, .. } => match class {
7345 crate::ImageClass::Sampled { .. }
7346 | crate::ImageClass::Depth { .. }
7347 | crate::ImageClass::Storage {
7348 access: crate::StorageAccess::LOAD,
7349 ..
7350 } => {
7351 if options.lang_version < (2, 0) {
7356 return Err(Error::UnsupportedArrayOf(
7357 "textures".to_string(),
7358 ));
7359 }
7360 }
7361 crate::ImageClass::Storage {
7362 access: crate::StorageAccess::STORE,
7363 ..
7364 } => {
7365 if options.lang_version < (2, 0) {
7370 return Err(Error::UnsupportedArrayOf(
7371 "write-only textures".to_string(),
7372 ));
7373 }
7374 }
7375 crate::ImageClass::Storage { .. } => {
7376 if options.lang_version < (3, 0) {
7377 return Err(Error::UnsupportedArrayOf(
7378 "read-write textures".to_string(),
7379 ));
7380 }
7381 }
7382 crate::ImageClass::External => {
7383 return Err(Error::UnsupportedArrayOf(
7384 "external textures".to_string(),
7385 ));
7386 }
7387 },
7388 _ => {
7389 return Err(Error::UnsupportedArrayOfType(base));
7390 }
7391 }
7392 }
7393 _ => {}
7394 },
7395 _ => {}
7396 }
7397
7398 let resolved = match var.space {
7400 crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7401 crate::AddressSpace::WorkGroup => None,
7402 _ => options
7403 .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7404 .ok(),
7405 };
7406 if let Some(ref resolved) = resolved {
7407 if resolved.as_inline_sampler(options).is_some() {
7409 continue;
7410 }
7411 }
7412
7413 match module.types[var.ty].inner {
7414 crate::TypeInner::Image {
7415 class: crate::ImageClass::External,
7416 ..
7417 } => {
7418 let target = match resolved {
7422 Some(back::msl::ResolvedBinding::Resource(target)) => {
7423 target.external_texture
7424 }
7425 _ => None,
7426 };
7427
7428 for i in 0..3 {
7429 write!(self.out, "{} ", separator())?;
7430
7431 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7432 handle,
7433 ExternalTextureNameKey::Plane(i),
7434 )];
7435 write!(
7436 self.out,
7437 "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> {plane_name}"
7438 )?;
7439 if let Some(ref target) = target {
7440 write!(self.out, " [[texture({})]]", target.planes[i])?;
7441 }
7442 writeln!(self.out)?;
7443 }
7444 let params_ty_name = &self.names
7445 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7446 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7447 handle,
7448 ExternalTextureNameKey::Params,
7449 )];
7450 write!(self.out, "{} ", separator())?;
7451 write!(self.out, "constant {params_ty_name}& {params_name}")?;
7452 if let Some(ref target) = target {
7453 write!(self.out, " [[buffer({})]]", target.params)?;
7454 }
7455 }
7456 _ => {
7457 let tyvar = TypedGlobalVariable {
7458 module,
7459 names: &self.names,
7460 handle,
7461 usage,
7462 reference: true,
7463 };
7464 write!(self.out, "{} ", separator())?;
7465 tyvar.try_fmt(&mut self.out)?;
7466 if let Some(resolved) = resolved {
7467 resolved.try_fmt(&mut self.out)?;
7468 }
7469 if let Some(value) = var.init {
7470 write!(self.out, " = ")?;
7471 self.put_const_expression(
7472 value,
7473 module,
7474 mod_info,
7475 &module.global_expressions,
7476 )?;
7477 }
7478 }
7479 }
7480 writeln!(self.out)?;
7481 }
7482
7483 if do_vertex_pulling {
7484 if needs_vertex_id && v_existing_id.is_none() {
7485 writeln!(self.out, "{} uint {v_id} [[vertex_id]]", separator())?;
7487 }
7488
7489 if needs_instance_id && i_existing_id.is_none() {
7490 writeln!(self.out, "{} uint {i_id} [[instance_id]]", separator())?;
7491 }
7492
7493 for vbm in &vbm_resolved {
7496 let id = &vbm.id;
7497 let ty_name = &vbm.ty_name;
7498 let param_name = &vbm.param_name;
7499 writeln!(
7500 self.out,
7501 "{} const device {ty_name}* {param_name} [[buffer({id})]]",
7502 separator()
7503 )?;
7504 }
7505 }
7506
7507 if needs_buffer_sizes {
7510 let resolved = options.resolve_sizes_buffer(ep).unwrap();
7512 write!(
7513 self.out,
7514 "{} constant _mslBufferSizes& _buffer_sizes",
7515 separator()
7516 )?;
7517 resolved.try_fmt(&mut self.out)?;
7518 writeln!(self.out)?;
7519 }
7520
7521 writeln!(self.out, ") {{")?;
7523
7524 if do_vertex_pulling {
7526 for vbm in &vbm_resolved {
7529 for attribute in vbm.attributes {
7530 let location = attribute.shader_location;
7531 let am_option = am_resolved.get(&location);
7532 if am_option.is_none() {
7533 continue;
7536 }
7537 let am = am_option.unwrap();
7538 let attribute_ty_name = &am.ty_name;
7539 let attribute_name = &am.name;
7540
7541 writeln!(
7542 self.out,
7543 "{}{attribute_ty_name} {attribute_name} = {{}};",
7544 back::Level(1)
7545 )?;
7546 }
7547
7548 write!(self.out, "{}if (", back::Level(1))?;
7551
7552 let idx = &vbm.id;
7553 let stride = &vbm.stride;
7554 let index_name = match vbm.step_mode {
7555 back::msl::VertexBufferStepMode::Constant => "0",
7556 back::msl::VertexBufferStepMode::ByVertex => {
7557 if let Some(ref name) = v_existing_id {
7558 name
7559 } else {
7560 &v_id
7561 }
7562 }
7563 back::msl::VertexBufferStepMode::ByInstance => {
7564 if let Some(ref name) = i_existing_id {
7565 name
7566 } else {
7567 &i_id
7568 }
7569 }
7570 };
7571 write!(
7572 self.out,
7573 "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7574 )?;
7575
7576 writeln!(self.out, ") {{")?;
7577
7578 let ty_name = &vbm.ty_name;
7580 let elem_name = &vbm.elem_name;
7581 let param_name = &vbm.param_name;
7582
7583 writeln!(
7584 self.out,
7585 "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7586 back::Level(2),
7587 )?;
7588
7589 for attribute in vbm.attributes {
7592 let location = attribute.shader_location;
7593 let Some(am) = am_resolved.get(&location) else {
7594 continue;
7598 };
7599 let attribute_name = &am.name;
7600 let attribute_ty_name = &am.ty_name;
7601
7602 let offset = attribute.offset;
7603 let func = unpacking_functions
7604 .get(&attribute.format)
7605 .expect("Should have generated this unpacking function earlier.");
7606 let func_name = &func.name;
7607
7608 let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7615
7616 if needs_padding_or_truncation != Ordering::Equal {
7617 writeln!(
7620 self.out,
7621 "{}// {attribute_ty_name} <- {:?}",
7622 back::Level(2),
7623 attribute.format
7624 )?;
7625 }
7626
7627 write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7628
7629 if needs_padding_or_truncation == Ordering::Greater {
7630 write!(self.out, "{attribute_ty_name}(")?;
7632 }
7633
7634 write!(self.out, "{func_name}({elem_name}.data[{offset}]",)?;
7636 for i in (offset + 1)..(offset + func.byte_count) {
7637 write!(self.out, ", {elem_name}.data[{i}]")?;
7638 }
7639 write!(self.out, ")")?;
7640
7641 match needs_padding_or_truncation {
7642 Ordering::Greater => {
7643 let zero_value = if am.ty_is_int { "0" } else { "0.0" };
7645 let one_value = if am.ty_is_int { "1" } else { "1.0" };
7646 for i in func.dimension..am.dimension {
7647 write!(
7648 self.out,
7649 ", {}",
7650 if i == 3 { one_value } else { zero_value }
7651 )?;
7652 }
7653 write!(self.out, ")")?;
7654 }
7655 Ordering::Less => {
7656 write!(
7658 self.out,
7659 ".{}",
7660 &"xyzw"[0..usize::try_from(am.dimension).unwrap()]
7661 )?;
7662 }
7663 Ordering::Equal => {}
7664 }
7665
7666 writeln!(self.out, ";")?;
7667 }
7668
7669 writeln!(self.out, "{}}}", back::Level(1))?;
7671 }
7672 }
7673
7674 if need_workgroup_variables_initialization {
7675 self.write_workgroup_variables_initialization(
7676 module,
7677 mod_info,
7678 fun_info,
7679 local_invocation_id,
7680 )?;
7681 }
7682
7683 for (handle, var) in module.global_variables.iter() {
7686 let usage = fun_info[handle];
7687 if usage.is_empty() {
7688 continue;
7689 }
7690 if var.space == crate::AddressSpace::Private {
7691 let tyvar = TypedGlobalVariable {
7692 module,
7693 names: &self.names,
7694 handle,
7695 usage,
7696
7697 reference: false,
7698 };
7699 write!(self.out, "{}", back::INDENT)?;
7700 tyvar.try_fmt(&mut self.out)?;
7701 match var.init {
7702 Some(value) => {
7703 write!(self.out, " = ")?;
7704 self.put_const_expression(
7705 value,
7706 module,
7707 mod_info,
7708 &module.global_expressions,
7709 )?;
7710 writeln!(self.out, ";")?;
7711 }
7712 None => {
7713 writeln!(self.out, " = {{}};")?;
7714 }
7715 };
7716 } else if let Some(ref binding) = var.binding {
7717 let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7718 if let Some(sampler) = resolved.as_inline_sampler(options) {
7719 let name = &self.names[&NameKey::GlobalVariable(handle)];
7721 writeln!(
7722 self.out,
7723 "{}constexpr {}::sampler {}(",
7724 back::INDENT,
7725 NAMESPACE,
7726 name
7727 )?;
7728 self.put_inline_sampler_properties(back::Level(2), sampler)?;
7729 writeln!(self.out, "{});", back::INDENT)?;
7730 } else if let crate::TypeInner::Image {
7731 class: crate::ImageClass::External,
7732 ..
7733 } = module.types[var.ty].inner
7734 {
7735 let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7738 let l1 = back::Level(1);
7739 let l2 = l1.next();
7740 writeln!(
7741 self.out,
7742 "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7743 )?;
7744 for i in 0..3 {
7745 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7746 handle,
7747 ExternalTextureNameKey::Plane(i),
7748 )];
7749 writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7750 }
7751 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7752 handle,
7753 ExternalTextureNameKey::Params,
7754 )];
7755 writeln!(self.out, "{l2}.params = {params_name},")?;
7756 writeln!(self.out, "{l1}}};")?;
7757 }
7758 }
7759 }
7760
7761 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7772 let arg_name =
7773 &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7774 match module.types[arg.ty].inner {
7775 crate::TypeInner::Struct { ref members, .. } => {
7776 let struct_name = &self.names[&NameKey::Type(arg.ty)];
7777 write!(
7778 self.out,
7779 "{}const {} {} = {{ ",
7780 back::INDENT,
7781 struct_name,
7782 arg_name
7783 )?;
7784 for (member_index, member) in members.iter().enumerate() {
7785 let key = NameKey::StructMember(arg.ty, member_index as u32);
7786 let name = &flattened_member_names[&key];
7787 if member_index != 0 {
7788 write!(self.out, ", ")?;
7789 }
7790 if self
7792 .struct_member_pads
7793 .contains(&(arg.ty, member_index as u32))
7794 {
7795 write!(self.out, "{{}}, ")?;
7796 }
7797 if let Some(crate::Binding::Location { .. }) = member.binding {
7798 if has_varyings {
7799 write!(self.out, "{varyings_member_name}.")?;
7800 }
7801 }
7802 write!(self.out, "{name}")?;
7803 }
7804 writeln!(self.out, " }};")?;
7805 }
7806 _ => match arg.binding {
7807 Some(crate::Binding::Location { .. })
7808 | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
7809 if has_varyings {
7810 writeln!(
7811 self.out,
7812 "{}const auto {} = {}.{};",
7813 back::INDENT,
7814 arg_name,
7815 varyings_member_name,
7816 arg_name
7817 )?;
7818 }
7819 }
7820 _ => {}
7821 },
7822 }
7823 }
7824
7825 let guarded_indices =
7826 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
7827
7828 let context = StatementContext {
7829 expression: ExpressionContext {
7830 function: fun,
7831 origin: FunctionOrigin::EntryPoint(ep_index as _),
7832 info: fun_info,
7833 lang_version: options.lang_version,
7834 policies: options.bounds_check_policies,
7835 guarded_indices,
7836 module,
7837 mod_info,
7838 pipeline_options,
7839 force_loop_bounding: options.force_loop_bounding,
7840 },
7841 result_struct: Some(&stage_out_name),
7842 };
7843
7844 self.put_locals(&context.expression)?;
7847 self.update_expressions_to_bake(fun, fun_info, &context.expression);
7848 self.put_block(back::Level(1), &fun.body, &context)?;
7849 writeln!(self.out, "}}")?;
7850 if ep_index + 1 != module.entry_points.len() {
7851 writeln!(self.out)?;
7852 }
7853 self.named_expressions.clear();
7854 }
7855
7856 Ok(info)
7857 }
7858
7859 fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
7860 if flags.is_empty() {
7863 writeln!(
7864 self.out,
7865 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
7866 )?;
7867 }
7868 if flags.contains(crate::Barrier::STORAGE) {
7869 writeln!(
7870 self.out,
7871 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
7872 )?;
7873 }
7874 if flags.contains(crate::Barrier::WORK_GROUP) {
7875 writeln!(
7876 self.out,
7877 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7878 )?;
7879 }
7880 if flags.contains(crate::Barrier::SUB_GROUP) {
7881 writeln!(
7882 self.out,
7883 "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7884 )?;
7885 }
7886 if flags.contains(crate::Barrier::TEXTURE) {
7887 writeln!(
7888 self.out,
7889 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
7890 )?;
7891 }
7892 Ok(())
7893 }
7894}
7895
7896mod workgroup_mem_init {
7899 use crate::EntryPoint;
7900
7901 use super::*;
7902
7903 enum Access {
7904 GlobalVariable(Handle<crate::GlobalVariable>),
7905 StructMember(Handle<crate::Type>, u32),
7906 Array(usize),
7907 }
7908
7909 impl Access {
7910 fn write<W: Write>(
7911 &self,
7912 writer: &mut W,
7913 names: &FastHashMap<NameKey, String>,
7914 ) -> Result<(), core::fmt::Error> {
7915 match *self {
7916 Access::GlobalVariable(handle) => {
7917 write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
7918 }
7919 Access::StructMember(handle, index) => {
7920 write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
7921 }
7922 Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
7923 }
7924 }
7925 }
7926
7927 struct AccessStack {
7928 stack: Vec<Access>,
7929 array_depth: usize,
7930 }
7931
7932 impl AccessStack {
7933 const fn new() -> Self {
7934 Self {
7935 stack: Vec::new(),
7936 array_depth: 0,
7937 }
7938 }
7939
7940 fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
7941 let array_depth = self.array_depth;
7942 self.stack.push(Access::Array(array_depth));
7943 self.array_depth += 1;
7944 let res = cb(self, array_depth);
7945 self.stack.pop();
7946 self.array_depth -= 1;
7947 res
7948 }
7949
7950 fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
7951 self.stack.push(new);
7952 let res = cb(self);
7953 self.stack.pop();
7954 res
7955 }
7956
7957 fn write<W: Write>(
7958 &self,
7959 writer: &mut W,
7960 names: &FastHashMap<NameKey, String>,
7961 ) -> Result<(), core::fmt::Error> {
7962 for next in self.stack.iter() {
7963 next.write(writer, names)?;
7964 }
7965 Ok(())
7966 }
7967 }
7968
7969 impl<W: Write> Writer<W> {
7970 pub(super) fn need_workgroup_variables_initialization(
7971 &mut self,
7972 options: &Options,
7973 ep: &EntryPoint,
7974 module: &crate::Module,
7975 fun_info: &valid::FunctionInfo,
7976 ) -> bool {
7977 options.zero_initialize_workgroup_memory
7978 && ep.stage.compute_like()
7979 && module.global_variables.iter().any(|(handle, var)| {
7980 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7981 })
7982 }
7983
7984 pub(super) fn write_workgroup_variables_initialization(
7985 &mut self,
7986 module: &crate::Module,
7987 module_info: &valid::ModuleInfo,
7988 fun_info: &valid::FunctionInfo,
7989 local_invocation_id: Option<&NameKey>,
7990 ) -> BackendResult {
7991 let level = back::Level(1);
7992
7993 writeln!(
7994 self.out,
7995 "{}if ({}::all({} == {}::uint3(0u))) {{",
7996 level,
7997 NAMESPACE,
7998 local_invocation_id
7999 .map(|name_key| self.names[name_key].as_str())
8000 .unwrap_or("__local_invocation_id"),
8001 NAMESPACE,
8002 )?;
8003
8004 let mut access_stack = AccessStack::new();
8005
8006 let vars = module.global_variables.iter().filter(|&(handle, var)| {
8007 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
8008 });
8009
8010 for (handle, var) in vars {
8011 access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
8012 self.write_workgroup_variable_initialization(
8013 module,
8014 module_info,
8015 var.ty,
8016 access_stack,
8017 level.next(),
8018 )
8019 })?;
8020 }
8021
8022 writeln!(self.out, "{level}}}")?;
8023 self.write_barrier(crate::Barrier::WORK_GROUP, level)
8024 }
8025
8026 fn write_workgroup_variable_initialization(
8027 &mut self,
8028 module: &crate::Module,
8029 module_info: &valid::ModuleInfo,
8030 ty: Handle<crate::Type>,
8031 access_stack: &mut AccessStack,
8032 level: back::Level,
8033 ) -> BackendResult {
8034 if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8035 write!(self.out, "{level}")?;
8036 access_stack.write(&mut self.out, &self.names)?;
8037 writeln!(self.out, " = {{}};")?;
8038 } else {
8039 match module.types[ty].inner {
8040 crate::TypeInner::Atomic { .. } => {
8041 write!(
8042 self.out,
8043 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8044 )?;
8045 access_stack.write(&mut self.out, &self.names)?;
8046 writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8047 }
8048 crate::TypeInner::Array { base, size, .. } => {
8049 let count = match size.resolve(module.to_ctx())? {
8050 proc::IndexableLength::Known(count) => count,
8051 proc::IndexableLength::Dynamic => unreachable!(),
8052 };
8053
8054 access_stack.enter_array(|access_stack, array_depth| {
8055 writeln!(
8056 self.out,
8057 "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8058 )?;
8059 self.write_workgroup_variable_initialization(
8060 module,
8061 module_info,
8062 base,
8063 access_stack,
8064 level.next(),
8065 )?;
8066 writeln!(self.out, "{level}}}")?;
8067 BackendResult::Ok(())
8068 })?;
8069 }
8070 crate::TypeInner::Struct { ref members, .. } => {
8071 for (index, member) in members.iter().enumerate() {
8072 access_stack.enter(
8073 Access::StructMember(ty, index as u32),
8074 |access_stack| {
8075 self.write_workgroup_variable_initialization(
8076 module,
8077 module_info,
8078 member.ty,
8079 access_stack,
8080 level,
8081 )
8082 },
8083 )?;
8084 }
8085 }
8086 _ => unreachable!(),
8087 }
8088 }
8089
8090 Ok(())
8091 }
8092 }
8093}
8094
8095impl crate::AtomicFunction {
8096 const fn to_msl(self) -> &'static str {
8097 match self {
8098 Self::Add => "fetch_add",
8099 Self::Subtract => "fetch_sub",
8100 Self::And => "fetch_and",
8101 Self::InclusiveOr => "fetch_or",
8102 Self::ExclusiveOr => "fetch_xor",
8103 Self::Min => "fetch_min",
8104 Self::Max => "fetch_max",
8105 Self::Exchange { compare: None } => "exchange",
8106 Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8107 }
8108 }
8109
8110 fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8111 Ok(match self {
8112 Self::Min => "min",
8113 Self::Max => "max",
8114 _ => Err(Error::FeatureNotImplemented(
8115 "64-bit atomic operation other than min/max".to_string(),
8116 ))?,
8117 })
8118 }
8119}