1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::{
8 cmp::Ordering,
9 fmt::{Display, Error as FmtError, Formatter, Write},
10 iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
17use crate::{
18 arena::{Handle, HandleSet},
19 back::{self, get_entry_points, Baked},
20 common,
21 proc::{
22 self, concrete_int_scalars,
23 index::{self, BoundsCheck},
24 ExternalTextureNameKey, NameKey, TypeResolution,
25 },
26 valid, FastHashMap, FastHashSet,
27};
28
29#[cfg(test)]
30use core::ptr;
31
/// Result of the writer's fallible emit routines: success carries no value,
/// failure carries the backend `Error` (formatting errors convert via `?`).
type BackendResult = Result<(), Error>;
34
// ---- General MSL spellings -------------------------------------------------

/// Namespace that qualifies Metal standard-library types and functions.
const NAMESPACE: &str = "metal";
/// Field name used for wrapped arrays (presumably the single member of the
/// struct generated around an array type; confirm at its use sites).
const WRAPPED_ARRAY_FIELD: &str = "inner";
/// Reference marker (presumably appended when atomics are referenced; confirm).
const ATOMIC_REFERENCE: &str = "&";

// ---- Ray tracing -----------------------------------------------------------

/// Namespace of Metal's ray-tracing types, e.g. acceleration structures.
const RT_NAMESPACE: &str = "metal::raytracing";
/// Type name written for `RayQuery` values.
const RAY_QUERY_TYPE: &str = "_RayQuery";
// Member names of the emitted ray-query type.
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
/// Switch for a newer ray-query code path; currently off.
const RAY_QUERY_MODERN_SUPPORT: bool = false;
/// Name of the intersection-type mapping helper (presumably; confirm).
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

// ---- Names of emitted helper functions -------------------------------------

pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";

// ---- External textures -----------------------------------------------------

/// Called in place of `.read(...)` for external textures.
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
/// Called in place of size queries for external textures.
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";

// ---- Wrapper structs -------------------------------------------------------

/// Generic wrapper used to spell binding arrays: `constant Wrapper<T>*`.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Type name written for images of class `External`.
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";

// ---- Cooperative matrices --------------------------------------------------

pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
84
85fn put_numeric_type(
94 out: &mut impl Write,
95 scalar: crate::Scalar,
96 sizes: &[crate::VectorSize],
97) -> Result<(), FmtError> {
98 match (scalar, sizes) {
99 (scalar, &[]) => {
100 write!(out, "{}", scalar.to_msl_name())
101 }
102 (scalar, &[rows]) => {
103 write!(
104 out,
105 "{}::{}{}",
106 NAMESPACE,
107 scalar.to_msl_name(),
108 common::vector_size_str(rows)
109 )
110 }
111 (scalar, &[rows, columns]) => {
112 write!(
113 out,
114 "{}::{}{}x{}",
115 NAMESPACE,
116 scalar.to_msl_name(),
117 common::vector_size_str(columns),
118 common::vector_size_str(rows)
119 )
120 }
121 (_, _) => Ok(()), }
123}
124
125const fn scalar_is_int(scalar: crate::Scalar) -> bool {
126 use crate::ScalarKind::*;
127 match scalar.kind {
128 Sint | Uint | AbstractInt | Bool => true,
129 Float | AbstractFloat => false,
130 }
131}
132
/// Prefix of the identifier that `ClampedLod` writes for a clamped level of detail.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix of the identifier that `Reinterpreted` writes for a re-typed expression.
const REINTERPRET_PREFIX: &str = "reinterpreted_";
138
139struct ClampedLod(Handle<crate::Expression>);
145
146impl Display for ClampedLod {
147 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
148 self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
149 }
150}
151
152struct ArraySizeMember(Handle<crate::GlobalVariable>);
167
168impl Display for ArraySizeMember {
169 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
170 self.0.write_prefixed(f, "size")
171 }
172}
173
174#[derive(Clone, Copy)]
179struct Reinterpreted<'a> {
180 target_type: &'a str,
181 orig: Handle<crate::Expression>,
182}
183
184impl<'a> Reinterpreted<'a> {
185 const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
186 Self { target_type, orig }
187 }
188}
189
190impl Display for Reinterpreted<'_> {
191 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
192 f.write_str(REINTERPRET_PREFIX)?;
193 f.write_str(self.target_type)?;
194 self.orig.write_prefixed(f, "_e")
195 }
196}
197
198struct TypeContext<'a> {
199 handle: Handle<crate::Type>,
200 gctx: proc::GlobalCtx<'a>,
201 names: &'a FastHashMap<NameKey, String>,
202 access: crate::StorageAccess,
203 first_time: bool,
204}
205
206impl TypeContext<'_> {
207 fn scalar(&self) -> Option<crate::Scalar> {
208 let ty = &self.gctx.types[self.handle];
209 ty.inner.scalar()
210 }
211
212 fn vertex_input_dimension(&self) -> u32 {
213 let ty = &self.gctx.types[self.handle];
214 match ty.inner {
215 crate::TypeInner::Scalar(_) => 1,
216 crate::TypeInner::Vector { size, .. } => size as u32,
217 _ => unreachable!(),
218 }
219 }
220}
221
/// Writes the MSL spelling of the type `self.handle`.
///
/// A type for which `needs_alias` is true is written by its assigned name,
/// unless `first_time` is set. All other types are spelled out inline.
impl Display for TypeContext<'_> {
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        // Aliased types are referenced by name. `first_time` forces the full
        // spelling instead.
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            crate::TypeInner::CooperativeMatrix {
                columns,
                rows,
                scalar,
                role: _,
            } => {
                // Spelled `metal::simdgroup_<scalar><columns>x<rows>`; the role is not part of the name.
                write!(
                    out,
                    "{NAMESPACE}::simdgroup_{}{}x{}",
                    scalar.to_msl_name(),
                    columns as u32,
                    rows as u32,
                )
            }
            crate::TypeInner::Pointer { base, space } => {
                // The pointee is always written through its alias, if it has one.
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // An address space with no MSL name (`Handle`) produces no output.
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                write!(out, "{space_name} {sub}&")
            }
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            crate::TypeInner::Array { base, .. } => {
                // Only the element type is written here.
                // NOTE(review): the array length is presumably emitted by the caller; confirm.
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                write!(out, "{sub}")
            }
            // `needs_alias` is always true for structs, so this arm is reached
            // only when `first_time` is set.
            // NOTE(review): presumably callers never do that for structs; confirm.
            crate::TypeInner::Struct { .. } => unreachable!(),
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                // Choose the texture family, the multisample suffix, the texel
                // scalar, and the access qualifier for each image class.
                let (texture_str, msaa_str, scalar, access) = match class {
                    crate::ImageClass::Sampled { kind, multi } => {
                        // Multisampled textures use `read` access; the others use `sample`.
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Storage { format, .. } => {
                        // The access qualifier comes from `self.access`, not from the image class.
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            // A storage image with neither load nor store access
                            // is logged and treated as an invalid module.
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    crate::ImageClass::External => {
                        // External textures are spelled as the wrapper struct.
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            // Comparison and non-comparison samplers share one MSL type.
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RAY_QUERY_TYPE}")
            }
            crate::TypeInner::BindingArray { base, .. } => {
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };

                // Binding arrays become a `constant` pointer to the generic wrapper struct.
                write!(
                    out,
                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                )
            }
        }
    }
}
390
391struct TypedGlobalVariable<'a> {
392 module: &'a crate::Module,
393 names: &'a FastHashMap<NameKey, String>,
394 handle: Handle<crate::GlobalVariable>,
395 usage: valid::GlobalUse,
396 reference: bool,
397}
398
399impl TypedGlobalVariable<'_> {
400 fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
401 let var = &self.module.global_variables[self.handle];
402 let name = &self.names[&NameKey::GlobalVariable(self.handle)];
403
404 let storage_access = match var.space {
405 crate::AddressSpace::Storage { access } => access,
406 _ => match self.module.types[var.ty].inner {
407 crate::TypeInner::Image {
408 class: crate::ImageClass::Storage { access, .. },
409 ..
410 } => access,
411 crate::TypeInner::BindingArray { base, .. } => {
412 match self.module.types[base].inner {
413 crate::TypeInner::Image {
414 class: crate::ImageClass::Storage { access, .. },
415 ..
416 } => access,
417 _ => crate::StorageAccess::default(),
418 }
419 }
420 _ => crate::StorageAccess::default(),
421 },
422 };
423 let ty_name = TypeContext {
424 handle: var.ty,
425 gctx: self.module.to_ctx(),
426 names: self.names,
427 access: storage_access,
428 first_time: false,
429 };
430
431 let (space, access, reference) = match var.space.to_msl_name() {
432 Some(space) if self.reference => {
433 let access = if var.space.needs_access_qualifier()
434 && !self.usage.intersects(valid::GlobalUse::WRITE)
435 {
436 "const"
437 } else {
438 ""
439 };
440 (space, access, "&")
441 }
442 _ => ("", "", ""),
443 };
444
445 Ok(write!(
446 out,
447 "{}{}{}{}{}{} {}",
448 space,
449 if space.is_empty() { "" } else { " " },
450 ty_name,
451 if access.is_empty() { "" } else { " " },
452 access,
453 reference,
454 name,
455 )?)
456 }
457}
458
/// Key identifying one emitted helper function.
///
/// Stored in `Writer::wrapped_functions`, a hash set.
/// NOTE(review): presumably used so each helper is written at most once per
/// module; confirm where the set is checked.
#[derive(Eq, PartialEq, Hash)]
enum WrappedFunction {
    /// Helper for a unary operator on a scalar or vector operand (`None` size = scalar).
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a binary operator, keyed by both operand shapes.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a math function, keyed by its argument shape.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a scalar/vector conversion from `src_scalar` to `dst_scalar`.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Image-load helper for an image class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Image-sample helper for an image class, with or without clamp-to-edge.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Image size-query helper for an image class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
    /// Cooperative-matrix load helper, keyed by address-space name and matrix shape.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// Cooperative-matrix multiply-add helper, keyed by address-space name and dimensions.
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
}
503
/// Writer that emits Metal Shading Language source into `out`.
pub struct Writer<W> {
    /// Output sink; returned by `finish`.
    out: W,
    /// Identifiers assigned to IR entities (types, globals, locals, ...), keyed by `NameKey`.
    names: FastHashMap<NameKey, String>,
    named_expressions: crate::NamedExpressions,
    /// NOTE(review): presumably the expressions that must be written into named temporaries; confirm.
    need_bake_expressions: back::NeedBakeExpressions,
    /// Produces fresh identifiers, e.g. `namer.call("loop_bound")` and `namer.call("oob")`.
    namer: proc::Namer,
    /// Helper functions already recorded as emitted.
    wrapped_functions: FastHashSet<WrappedFunction>,
    /// Test-only bookkeeping (presumably recursion/stack tracking for `put_expression`; confirm).
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    /// Test-only bookkeeping (presumably recursion/stack tracking for `put_block`; confirm).
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// NOTE(review): presumably the (struct type, member index) pairs that got padding emitted; confirm.
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
520
521impl crate::Scalar {
522 pub(super) fn to_msl_name(self) -> &'static str {
523 use crate::ScalarKind as Sk;
524 match self {
525 Self {
526 kind: Sk::Float,
527 width: 4,
528 } => "float",
529 Self {
530 kind: Sk::Float,
531 width: 2,
532 } => "half",
533 Self {
534 kind: Sk::Sint,
535 width: 4,
536 } => "int",
537 Self {
538 kind: Sk::Uint,
539 width: 4,
540 } => "uint",
541 Self {
542 kind: Sk::Sint,
543 width: 8,
544 } => "long",
545 Self {
546 kind: Sk::Uint,
547 width: 8,
548 } => "ulong",
549 Self {
550 kind: Sk::Bool,
551 width: _,
552 } => "bool",
553 Self {
554 kind: Sk::AbstractInt | Sk::AbstractFloat,
555 width: _,
556 } => unreachable!("Found Abstract scalar kind"),
557 _ => unreachable!("Unsupported scalar kind: {:?}", self),
558 }
559 }
560}
561
/// Returns the list separator to write: `","` when `need_separator` is true,
/// otherwise the empty string.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
569
570fn should_pack_struct_member(
571 members: &[crate::StructMember],
572 span: u32,
573 index: usize,
574 module: &crate::Module,
575) -> Option<crate::Scalar> {
576 let member = &members[index];
577
578 let ty_inner = &module.types[member.ty].inner;
579 let last_offset = member.offset + ty_inner.size(module.to_ctx());
580 let next_offset = match members.get(index + 1) {
581 Some(next) => next.offset,
582 None => span,
583 };
584 let is_tight = next_offset == last_offset;
585
586 match *ty_inner {
587 crate::TypeInner::Vector {
588 size: crate::VectorSize::Tri,
589 scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
590 } if is_tight => Some(scalar),
591 _ => None,
592 }
593}
594
595fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
596 match arena[ty].inner {
597 crate::TypeInner::Struct { ref members, .. } => {
598 if let Some(member) = members.last() {
599 if let crate::TypeInner::Array {
600 size: crate::ArraySize::Dynamic,
601 ..
602 } = arena[member.ty].inner
603 {
604 return true;
605 }
606 }
607 false
608 }
609 crate::TypeInner::Array {
610 size: crate::ArraySize::Dynamic,
611 ..
612 } => true,
613 _ => false,
614 }
615}
616
617impl crate::AddressSpace {
618 const fn needs_pass_through(&self) -> bool {
622 match *self {
623 Self::Uniform
624 | Self::Storage { .. }
625 | Self::Private
626 | Self::WorkGroup
627 | Self::Immediate
628 | Self::Handle
629 | Self::TaskPayload => true,
630 Self::Function => false,
631 }
632 }
633
634 const fn needs_access_qualifier(&self) -> bool {
636 match *self {
637 Self::Storage { .. } => true,
642 Self::TaskPayload => unimplemented!(),
643 Self::Private | Self::WorkGroup => false,
645 Self::Uniform | Self::Immediate => false,
647 Self::Handle | Self::Function => false,
649 }
650 }
651
652 const fn to_msl_name(self) -> Option<&'static str> {
653 match self {
654 Self::Handle => None,
655 Self::Uniform | Self::Immediate => Some("constant"),
656 Self::Storage { .. } => Some("device"),
657 Self::Private | Self::Function => Some("thread"),
658 Self::WorkGroup => Some("threadgroup"),
659 Self::TaskPayload => Some("object_data"),
660 }
661 }
662}
663
664impl crate::Type {
665 const fn needs_alias(&self) -> bool {
667 use crate::TypeInner as Ti;
668
669 match self.inner {
670 Ti::Scalar(_)
672 | Ti::Vector { .. }
673 | Ti::Matrix { .. }
674 | Ti::CooperativeMatrix { .. }
675 | Ti::Atomic(_)
676 | Ti::Pointer { .. }
677 | Ti::ValuePointer { .. } => self.name.is_some(),
678 Ti::Struct { .. } | Ti::Array { .. } => true,
680 Ti::Image { .. }
682 | Ti::Sampler { .. }
683 | Ti::AccelerationStructure { .. }
684 | Ti::RayQuery { .. }
685 | Ti::BindingArray { .. } => false,
686 }
687 }
688}
689
/// Which kind of function the writer is currently emitting. It determines the
/// `NameKey` variants used for locals (see `NameKeyExt`).
#[derive(Clone, Copy)]
enum FunctionOrigin {
    /// A regular function in the module's function arena.
    Handle(Handle<crate::Function>),
    /// An entry point, identified by its index.
    EntryPoint(proc::EntryPointIndex),
}
695
696trait NameKeyExt {
697 fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
698 match origin {
699 FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
700 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
701 }
702 }
703
704 fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
709 match origin {
710 FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
711 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
712 }
713 }
714}
715
716impl NameKeyExt for NameKey {}
717
/// How an image level-of-detail argument is written (see `put_level_of_detail`).
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Write this expression as-is.
    Direct(Handle<crate::Expression>),
    /// Write the `ClampedLod` identifier for this expression.
    /// `put_image_load` uses the image-load expression here under the `Restrict` policy.
    Restricted(Handle<crate::Expression>),
}

/// Components of a texel address used when reading an image.
struct TexelAddress {
    /// Texel coordinate; always cast to unsigned when written.
    coordinate: Handle<crate::Expression>,
    /// Array layer, for arrayed images.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index, for multisampled images.
    sample: Option<Handle<crate::Expression>>,
    /// Mip level, if one was given.
    level: Option<LevelOfDetail>,
}
748
/// Per-function state needed while writing expressions.
struct ExpressionContext<'a> {
    /// Function (or entry point body) being written.
    function: &'a crate::Function,
    /// Kind of function; selects the `NameKey` variants for its locals.
    origin: FunctionOrigin,
    /// Validation info for `function`; used to resolve expression types.
    info: &'a valid::FunctionInfo,
    module: &'a crate::Module,
    mod_info: &'a valid::ModuleInfo,
    pipeline_options: &'a PipelineOptions,
    /// NOTE(review): presumably the target MSL version as (major, minor); confirm.
    lang_version: (u8, u8),
    /// Bounds-check policies; e.g. `image_load` selects how image loads are guarded.
    policies: index::BoundsCheckPolicies,

    /// NOTE(review): presumably the index expressions already covered by an
    /// emitted bounds guard; confirm where it is filled.
    guarded_indices: HandleSet<crate::Expression>,
    /// When set, `gen_force_bounded_loop_statements` emits a counter that
    /// forces loops to terminate.
    force_loop_bounding: bool,
}
766
767impl<'a> ExpressionContext<'a> {
768 fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
769 self.info[handle].ty.inner_with(&self.module.types)
770 }
771
772 fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
779 let image_ty = self.resolve_type(image);
780 if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
781 class.is_mipmapped() && dim != crate::ImageDimension::D1
782 } else {
783 false
784 }
785 }
786
787 fn choose_bounds_check_policy(
788 &self,
789 pointer: Handle<crate::Expression>,
790 ) -> index::BoundsCheckPolicy {
791 self.policies
792 .choose_policy(pointer, &self.module.types, self.info)
793 }
794
795 fn access_needs_check(
797 &self,
798 base: Handle<crate::Expression>,
799 index: index::GuardedIndex,
800 ) -> Option<index::IndexableLength> {
801 index::access_needs_check(
802 base,
803 index,
804 self.module,
805 &self.function.expressions,
806 self.info,
807 )
808 }
809
810 fn bounds_check_iter(
812 &self,
813 chain: Handle<crate::Expression>,
814 ) -> impl Iterator<Item = BoundsCheck> + '_ {
815 index::bounds_check_iter(chain, self.module, self.function, self.info)
816 }
817
818 fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
820 index::oob_local_types(self.module, self.function, self.info, self.policies)
821 }
822
823 fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
824 match self.function.expressions[expr_handle] {
825 crate::Expression::AccessIndex { base, index } => {
826 let ty = match *self.resolve_type(base) {
827 crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
828 ref ty => ty,
829 };
830 match *ty {
831 crate::TypeInner::Struct {
832 ref members, span, ..
833 } => should_pack_struct_member(members, span, index as usize, self.module),
834 _ => None,
835 }
836 }
837 _ => None,
838 }
839 }
840}
841
/// State needed while writing statements.
struct StatementContext<'a> {
    /// Expression-level context for the function being written.
    expression: ExpressionContext<'a>,
    /// NOTE(review): presumably the name of the struct an entry point returns
    /// its outputs through, if any; confirm at use sites.
    result_struct: Option<&'a str>,
}
846
847impl<W: Write> Writer<W> {
848 pub fn new(out: W) -> Self {
850 Writer {
851 out,
852 names: FastHashMap::default(),
853 named_expressions: Default::default(),
854 need_bake_expressions: Default::default(),
855 namer: proc::Namer::default(),
856 wrapped_functions: FastHashSet::default(),
857 #[cfg(test)]
858 put_expression_stack_pointers: Default::default(),
859 #[cfg(test)]
860 put_block_stack_pointers: Default::default(),
861 struct_member_pads: FastHashSet::default(),
862 }
863 }
864
865 #[allow(clippy::missing_const_for_fn)]
868 pub fn finish(self) -> W {
869 self.out
870 }
871
872 fn gen_force_bounded_loop_statements(
979 &mut self,
980 level: back::Level,
981 context: &StatementContext,
982 ) -> Option<(String, String)> {
983 if !context.expression.force_loop_bounding {
984 return None;
985 }
986
987 let loop_bound_name = self.namer.call("loop_bound");
988 let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
991 let level = level.next();
992 let break_and_inc = format!(
993 "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
994{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
995 );
996
997 Some((decl, break_and_inc))
998 }
999
1000 fn put_call_parameters(
1001 &mut self,
1002 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1003 context: &ExpressionContext,
1004 ) -> BackendResult {
1005 self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1006 writer.put_expression(expr, context, true)
1007 })
1008 }
1009
1010 fn put_call_parameters_impl<C, E>(
1011 &mut self,
1012 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1013 ctx: &C,
1014 put_expression: E,
1015 ) -> BackendResult
1016 where
1017 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1018 {
1019 write!(self.out, "(")?;
1020 for (i, handle) in parameters.enumerate() {
1021 if i != 0 {
1022 write!(self.out, ", ")?;
1023 }
1024 put_expression(self, ctx, handle)?;
1025 }
1026 write!(self.out, ")")?;
1027 Ok(())
1028 }
1029
1030 fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1036 let oob_local_types = context.oob_local_types();
1037 for &ty in oob_local_types.iter() {
1038 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1039 self.names.insert(name_key, self.namer.call("oob"));
1040 }
1041
1042 for (name_key, ty, init) in context
1043 .function
1044 .local_variables
1045 .iter()
1046 .map(|(local_handle, local)| {
1047 let name_key = NameKey::local(context.origin, local_handle);
1048 (name_key, local.ty, local.init)
1049 })
1050 .chain(oob_local_types.iter().map(|&ty| {
1051 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1052 (name_key, ty, None)
1053 }))
1054 {
1055 let ty_name = TypeContext {
1056 handle: ty,
1057 gctx: context.module.to_ctx(),
1058 names: &self.names,
1059 access: crate::StorageAccess::empty(),
1060 first_time: false,
1061 };
1062 write!(
1063 self.out,
1064 "{}{} {}",
1065 back::INDENT,
1066 ty_name,
1067 self.names[&name_key]
1068 )?;
1069 match init {
1070 Some(value) => {
1071 write!(self.out, " = ")?;
1072 self.put_expression(value, context, true)?;
1073 }
1074 None => {
1075 write!(self.out, " = {{}}")?;
1076 }
1077 };
1078 writeln!(self.out, ";")?;
1079 }
1080 Ok(())
1081 }
1082
1083 fn put_level_of_detail(
1084 &mut self,
1085 level: LevelOfDetail,
1086 context: &ExpressionContext,
1087 ) -> BackendResult {
1088 match level {
1089 LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1090 LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1091 }
1092 Ok(())
1093 }
1094
1095 fn put_image_query(
1096 &mut self,
1097 image: Handle<crate::Expression>,
1098 query: &str,
1099 level: Option<LevelOfDetail>,
1100 context: &ExpressionContext,
1101 ) -> BackendResult {
1102 self.put_expression(image, context, false)?;
1103 write!(self.out, ".get_{query}(")?;
1104 if let Some(level) = level {
1105 self.put_level_of_detail(level, context)?;
1106 }
1107 write!(self.out, ")")?;
1108 Ok(())
1109 }
1110
1111 fn put_image_size_query(
1112 &mut self,
1113 image: Handle<crate::Expression>,
1114 level: Option<LevelOfDetail>,
1115 kind: crate::ScalarKind,
1116 context: &ExpressionContext,
1117 ) -> BackendResult {
1118 if let crate::TypeInner::Image {
1119 class: crate::ImageClass::External,
1120 ..
1121 } = *context.resolve_type(image)
1122 {
1123 write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1124 self.put_expression(image, context, true)?;
1125 write!(self.out, ")")?;
1126 return Ok(());
1127 }
1128
1129 let dim = match *context.resolve_type(image) {
1132 crate::TypeInner::Image { dim, .. } => dim,
1133 ref other => unreachable!("Unexpected type {:?}", other),
1134 };
1135 let scalar = crate::Scalar { kind, width: 4 };
1136 let coordinate_type = scalar.to_msl_name();
1137 match dim {
1138 crate::ImageDimension::D1 => {
1139 if kind == crate::ScalarKind::Uint {
1143 self.put_image_query(image, "width", None, context)?;
1145 } else {
1146 write!(self.out, "int(")?;
1148 self.put_image_query(image, "width", None, context)?;
1149 write!(self.out, ")")?;
1150 }
1151 }
1152 crate::ImageDimension::D2 => {
1153 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1154 self.put_image_query(image, "width", level, context)?;
1155 write!(self.out, ", ")?;
1156 self.put_image_query(image, "height", level, context)?;
1157 write!(self.out, ")")?;
1158 }
1159 crate::ImageDimension::D3 => {
1160 write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1161 self.put_image_query(image, "width", level, context)?;
1162 write!(self.out, ", ")?;
1163 self.put_image_query(image, "height", level, context)?;
1164 write!(self.out, ", ")?;
1165 self.put_image_query(image, "depth", level, context)?;
1166 write!(self.out, ")")?;
1167 }
1168 crate::ImageDimension::Cube => {
1169 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1170 self.put_image_query(image, "width", level, context)?;
1171 write!(self.out, ")")?;
1172 }
1173 }
1174 Ok(())
1175 }
1176
1177 fn put_cast_to_uint_scalar_or_vector(
1178 &mut self,
1179 expr: Handle<crate::Expression>,
1180 context: &ExpressionContext,
1181 ) -> BackendResult {
1182 match *context.resolve_type(expr) {
1184 crate::TypeInner::Scalar(_) => {
1185 put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1186 }
1187 crate::TypeInner::Vector { size, .. } => {
1188 put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1189 }
1190 _ => {
1191 return Err(Error::GenericValidation(
1192 "Invalid type for image coordinate".into(),
1193 ))
1194 }
1195 };
1196
1197 write!(self.out, "(")?;
1198 self.put_expression(expr, context, true)?;
1199 write!(self.out, ")")?;
1200 Ok(())
1201 }
1202
1203 fn put_image_sample_level(
1204 &mut self,
1205 image: Handle<crate::Expression>,
1206 level: crate::SampleLevel,
1207 context: &ExpressionContext,
1208 ) -> BackendResult {
1209 let has_levels = context.image_needs_lod(image);
1210 match level {
1211 crate::SampleLevel::Auto => {}
1212 crate::SampleLevel::Zero => {
1213 }
1215 _ if !has_levels => {
1216 log::warn!("1D image can't be sampled with level {level:?}");
1217 }
1218 crate::SampleLevel::Exact(h) => {
1219 write!(self.out, ", {NAMESPACE}::level(")?;
1220 self.put_expression(h, context, true)?;
1221 write!(self.out, ")")?;
1222 }
1223 crate::SampleLevel::Bias(h) => {
1224 write!(self.out, ", {NAMESPACE}::bias(")?;
1225 self.put_expression(h, context, true)?;
1226 write!(self.out, ")")?;
1227 }
1228 crate::SampleLevel::Gradient { x, y } => {
1229 write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1230 self.put_expression(x, context, true)?;
1231 write!(self.out, ", ")?;
1232 self.put_expression(y, context, true)?;
1233 write!(self.out, ")")?;
1234 }
1235 }
1236 Ok(())
1237 }
1238
1239 fn put_image_coordinate_limits(
1240 &mut self,
1241 image: Handle<crate::Expression>,
1242 level: Option<LevelOfDetail>,
1243 context: &ExpressionContext,
1244 ) -> BackendResult {
1245 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1246 write!(self.out, " - 1")?;
1247 Ok(())
1248 }
1249
1250 fn put_restricted_scalar_image_index(
1268 &mut self,
1269 image: Handle<crate::Expression>,
1270 index: Handle<crate::Expression>,
1271 limit_method: &str,
1272 context: &ExpressionContext,
1273 ) -> BackendResult {
1274 write!(self.out, "{NAMESPACE}::min(uint(")?;
1275 self.put_expression(index, context, true)?;
1276 write!(self.out, "), ")?;
1277 self.put_expression(image, context, false)?;
1278 write!(self.out, ".{limit_method}() - 1)")?;
1279 Ok(())
1280 }
1281
1282 fn put_restricted_texel_address(
1283 &mut self,
1284 image: Handle<crate::Expression>,
1285 address: &TexelAddress,
1286 context: &ExpressionContext,
1287 ) -> BackendResult {
1288 write!(self.out, "{NAMESPACE}::min(")?;
1290 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1291 write!(self.out, ", ")?;
1292 self.put_image_coordinate_limits(image, address.level, context)?;
1293 write!(self.out, ")")?;
1294
1295 if let Some(array_index) = address.array_index {
1297 write!(self.out, ", ")?;
1298 self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1299 }
1300
1301 if let Some(sample) = address.sample {
1303 write!(self.out, ", ")?;
1304 self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1305 }
1306
1307 if let Some(level) = address.level {
1310 write!(self.out, ", ")?;
1311 self.put_level_of_detail(level, context)?;
1312 }
1313
1314 Ok(())
1315 }
1316
1317 fn put_image_access_bounds_check(
1319 &mut self,
1320 image: Handle<crate::Expression>,
1321 address: &TexelAddress,
1322 context: &ExpressionContext,
1323 ) -> BackendResult {
1324 let mut conjunction = "";
1325
1326 let level = if let Some(level) = address.level {
1329 write!(self.out, "uint(")?;
1330 self.put_level_of_detail(level, context)?;
1331 write!(self.out, ") < ")?;
1332 self.put_expression(image, context, true)?;
1333 write!(self.out, ".get_num_mip_levels()")?;
1334 conjunction = " && ";
1335 Some(level)
1336 } else {
1337 None
1338 };
1339
1340 if let Some(sample) = address.sample {
1342 write!(self.out, "uint(")?;
1343 self.put_expression(sample, context, true)?;
1344 write!(self.out, ") < ")?;
1345 self.put_expression(image, context, true)?;
1346 write!(self.out, ".get_num_samples()")?;
1347 conjunction = " && ";
1348 }
1349
1350 if let Some(array_index) = address.array_index {
1352 write!(self.out, "{conjunction}uint(")?;
1353 self.put_expression(array_index, context, true)?;
1354 write!(self.out, ") < ")?;
1355 self.put_expression(image, context, true)?;
1356 write!(self.out, ".get_array_size()")?;
1357 conjunction = " && ";
1358 }
1359
1360 let coord_is_vector = match *context.resolve_type(address.coordinate) {
1362 crate::TypeInner::Vector { .. } => true,
1363 _ => false,
1364 };
1365 write!(self.out, "{conjunction}")?;
1366 if coord_is_vector {
1367 write!(self.out, "{NAMESPACE}::all(")?;
1368 }
1369 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1370 write!(self.out, " < ")?;
1371 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1372 if coord_is_vector {
1373 write!(self.out, ")")?;
1374 }
1375
1376 Ok(())
1377 }
1378
1379 fn put_image_load(
1380 &mut self,
1381 load: Handle<crate::Expression>,
1382 image: Handle<crate::Expression>,
1383 mut address: TexelAddress,
1384 context: &ExpressionContext,
1385 ) -> BackendResult {
1386 if let crate::TypeInner::Image {
1387 class: crate::ImageClass::External,
1388 ..
1389 } = *context.resolve_type(image)
1390 {
1391 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1392 self.put_expression(image, context, true)?;
1393 write!(self.out, ", ")?;
1394 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1395 write!(self.out, ")")?;
1396 return Ok(());
1397 }
1398
1399 match context.policies.image_load {
1400 proc::BoundsCheckPolicy::Restrict => {
1401 if address.level.is_some() {
1404 address.level = if context.image_needs_lod(image) {
1405 Some(LevelOfDetail::Restricted(load))
1406 } else {
1407 None
1408 }
1409 }
1410
1411 self.put_expression(image, context, false)?;
1412 write!(self.out, ".read(")?;
1413 self.put_restricted_texel_address(image, &address, context)?;
1414 write!(self.out, ")")?;
1415 }
1416 proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1417 write!(self.out, "(")?;
1418 self.put_image_access_bounds_check(image, &address, context)?;
1419 write!(self.out, " ? ")?;
1420 self.put_unchecked_image_load(image, &address, context)?;
1421 write!(self.out, ": DefaultConstructible())")?;
1422 }
1423 proc::BoundsCheckPolicy::Unchecked => {
1424 self.put_unchecked_image_load(image, &address, context)?;
1425 }
1426 }
1427
1428 Ok(())
1429 }
1430
1431 fn put_unchecked_image_load(
1432 &mut self,
1433 image: Handle<crate::Expression>,
1434 address: &TexelAddress,
1435 context: &ExpressionContext,
1436 ) -> BackendResult {
1437 self.put_expression(image, context, false)?;
1438 write!(self.out, ".read(")?;
1439 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1441 if let Some(expr) = address.array_index {
1442 write!(self.out, ", ")?;
1443 self.put_expression(expr, context, true)?;
1444 }
1445 if let Some(sample) = address.sample {
1446 write!(self.out, ", ")?;
1447 self.put_expression(sample, context, true)?;
1448 }
1449 if let Some(level) = address.level {
1450 if context.image_needs_lod(image) {
1451 write!(self.out, ", ")?;
1452 self.put_level_of_detail(level, context)?;
1453 }
1454 }
1455 write!(self.out, ")")?;
1456
1457 Ok(())
1458 }
1459
1460 fn put_image_atomic(
1461 &mut self,
1462 level: back::Level,
1463 image: Handle<crate::Expression>,
1464 address: &TexelAddress,
1465 fun: crate::AtomicFunction,
1466 value: Handle<crate::Expression>,
1467 context: &StatementContext,
1468 ) -> BackendResult {
1469 write!(self.out, "{level}")?;
1470 self.put_expression(image, &context.expression, false)?;
1471 let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1472 fun.to_msl_64_bit()?
1473 } else {
1474 fun.to_msl()
1475 };
1476 write!(self.out, ".atomic_{op}(")?;
1477 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1479 write!(self.out, ", ")?;
1480 self.put_expression(value, &context.expression, true)?;
1481 writeln!(self.out, ");")?;
1482
1483 Ok(())
1484 }
1485
1486 fn put_image_store(
1487 &mut self,
1488 level: back::Level,
1489 image: Handle<crate::Expression>,
1490 address: &TexelAddress,
1491 value: Handle<crate::Expression>,
1492 context: &StatementContext,
1493 ) -> BackendResult {
1494 write!(self.out, "{level}")?;
1495 self.put_expression(image, &context.expression, false)?;
1496 write!(self.out, ".write(")?;
1497 self.put_expression(value, &context.expression, true)?;
1498 write!(self.out, ", ")?;
1499 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1501 if let Some(expr) = address.array_index {
1502 write!(self.out, ", ")?;
1503 self.put_expression(expr, &context.expression, true)?;
1504 }
1505 writeln!(self.out, ");")?;
1506
1507 Ok(())
1508 }
1509
1510 fn put_dynamic_array_max_index(
1521 &mut self,
1522 handle: Handle<crate::GlobalVariable>,
1523 context: &ExpressionContext,
1524 ) -> BackendResult {
1525 let global = &context.module.global_variables[handle];
1526 let (offset, array_ty) = match context.module.types[global.ty].inner {
1527 crate::TypeInner::Struct { ref members, .. } => match members.last() {
1528 Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1529 None => return Err(Error::GenericValidation("Struct has no members".into())),
1530 },
1531 crate::TypeInner::Array {
1532 size: crate::ArraySize::Dynamic,
1533 ..
1534 } => (0, global.ty),
1535 ref ty => {
1536 return Err(Error::GenericValidation(format!(
1537 "Expected type with dynamic array, got {ty:?}"
1538 )))
1539 }
1540 };
1541
1542 let (size, stride) = match context.module.types[array_ty].inner {
1543 crate::TypeInner::Array { base, stride, .. } => (
1544 context.module.types[base]
1545 .inner
1546 .size(context.module.to_ctx()),
1547 stride,
1548 ),
1549 ref ty => {
1550 return Err(Error::GenericValidation(format!(
1551 "Expected array type, got {ty:?}"
1552 )))
1553 }
1554 };
1555
1556 write!(
1569 self.out,
1570 "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1571 member = ArraySizeMember(handle),
1572 offset = offset,
1573 size = size,
1574 stride = stride,
1575 )?;
1576 Ok(())
1577 }
1578
1579 fn put_dot_product<T: Copy>(
1584 &mut self,
1585 arg: T,
1586 arg1: T,
1587 size: usize,
1588 extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1589 ) -> BackendResult {
1590 write!(self.out, "(")?;
1593
1594 for index in 0..size {
1596 write!(self.out, " + ")?;
1599 extractor(self, arg, index)?;
1600 write!(self.out, " * ")?;
1601 extractor(self, arg1, index)?;
1602 }
1603
1604 write!(self.out, ")")?;
1605 Ok(())
1606 }
1607
1608 fn put_pack4x8(
1610 &mut self,
1611 arg: Handle<crate::Expression>,
1612 context: &ExpressionContext<'_>,
1613 was_signed: bool,
1614 clamp_bounds: Option<(&str, &str)>,
1615 ) -> Result<(), Error> {
1616 let write_arg = |this: &mut Self| -> BackendResult {
1617 if let Some((min, max)) = clamp_bounds {
1618 write!(this.out, "{NAMESPACE}::clamp(")?;
1620 this.put_expression(arg, context, true)?;
1621 write!(this.out, ", {min}, {max})")?;
1622 } else {
1623 this.put_expression(arg, context, true)?;
1624 }
1625 Ok(())
1626 };
1627
1628 if context.lang_version >= (2, 1) {
1629 let packed_type = if was_signed {
1630 "packed_char4"
1631 } else {
1632 "packed_uchar4"
1633 };
1634 write!(self.out, "as_type<uint>({packed_type}(")?;
1636 write_arg(self)?;
1637 write!(self.out, "))")?;
1638 } else {
1639 if was_signed {
1641 write!(self.out, "uint(")?;
1642 }
1643 write!(self.out, "(")?;
1644 write_arg(self)?;
1645 write!(self.out, "[0] & 0xFF) | ((")?;
1646 write_arg(self)?;
1647 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1648 write_arg(self)?;
1649 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1650 write_arg(self)?;
1651 write!(self.out, "[3] & 0xFF) << 24)")?;
1652 if was_signed {
1653 write!(self.out, ")")?;
1654 }
1655 }
1656
1657 Ok(())
1658 }
1659
1660 fn put_isign(
1663 &mut self,
1664 arg: Handle<crate::Expression>,
1665 context: &ExpressionContext,
1666 ) -> BackendResult {
1667 write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1668 let scalar = context
1669 .resolve_type(arg)
1670 .scalar()
1671 .expect("put_isign should only be called for args which have an integer scalar type")
1672 .to_msl_name();
1673 match context.resolve_type(arg) {
1674 &crate::TypeInner::Vector { size, .. } => {
1675 let size = common::vector_size_str(size);
1676 write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1677 }
1678 _ => {
1679 write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1680 }
1681 }
1682 write!(self.out, ", (")?;
1683 self.put_expression(arg, context, true)?;
1684 write!(self.out, " > 0)), {scalar}(0), (")?;
1685 self.put_expression(arg, context, true)?;
1686 write!(self.out, " == 0))")?;
1687 Ok(())
1688 }
1689
1690 fn put_const_expression(
1691 &mut self,
1692 expr_handle: Handle<crate::Expression>,
1693 module: &crate::Module,
1694 mod_info: &valid::ModuleInfo,
1695 arena: &crate::Arena<crate::Expression>,
1696 ) -> BackendResult {
1697 self.put_possibly_const_expression(
1698 expr_handle,
1699 arena,
1700 module,
1701 mod_info,
1702 &(module, mod_info),
1703 |&(_, mod_info), expr| &mod_info[expr],
1704 |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1705 )
1706 }
1707
1708 fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1709 match literal {
1710 crate::Literal::F64(_) => {
1711 return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1712 }
1713 crate::Literal::F16(value) => {
1714 if value.is_infinite() {
1715 let sign = if value.is_sign_negative() { "-" } else { "" };
1716 write!(self.out, "{sign}INFINITY")?;
1717 } else if value.is_nan() {
1718 write!(self.out, "NAN")?;
1719 } else {
1720 let suffix = if value.fract() == f16::from_f32(0.0) {
1721 ".0h"
1722 } else {
1723 "h"
1724 };
1725 write!(self.out, "{value}{suffix}")?;
1726 }
1727 }
1728 crate::Literal::F32(value) => {
1729 if value.is_infinite() {
1730 let sign = if value.is_sign_negative() { "-" } else { "" };
1731 write!(self.out, "{sign}INFINITY")?;
1732 } else if value.is_nan() {
1733 write!(self.out, "NAN")?;
1734 } else {
1735 let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1736 write!(self.out, "{value}{suffix}")?;
1737 }
1738 }
1739 crate::Literal::U32(value) => {
1740 write!(self.out, "{value}u")?;
1741 }
1742 crate::Literal::I32(value) => {
1743 if value == i32::MIN {
1748 write!(self.out, "({} - 1)", value + 1)?;
1749 } else {
1750 write!(self.out, "{value}")?;
1751 }
1752 }
1753 crate::Literal::U64(value) => {
1754 write!(self.out, "{value}uL")?;
1755 }
1756 crate::Literal::I64(value) => {
1757 if value == i64::MIN {
1764 write!(self.out, "({}L - 1L)", value + 1)?;
1765 } else {
1766 write!(self.out, "{value}L")?;
1767 }
1768 }
1769 crate::Literal::Bool(value) => {
1770 write!(self.out, "{value}")?;
1771 }
1772 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1773 return Err(Error::GenericValidation(
1774 "Unsupported abstract literal".into(),
1775 ));
1776 }
1777 }
1778 Ok(())
1779 }
1780
1781 #[allow(clippy::too_many_arguments)]
1782 fn put_possibly_const_expression<C, I, E>(
1783 &mut self,
1784 expr_handle: Handle<crate::Expression>,
1785 expressions: &crate::Arena<crate::Expression>,
1786 module: &crate::Module,
1787 mod_info: &valid::ModuleInfo,
1788 ctx: &C,
1789 get_expr_ty: I,
1790 put_expression: E,
1791 ) -> BackendResult
1792 where
1793 I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1794 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1795 {
1796 match expressions[expr_handle] {
1797 crate::Expression::Literal(literal) => {
1798 self.put_literal(literal)?;
1799 }
1800 crate::Expression::Constant(handle) => {
1801 let constant = &module.constants[handle];
1802 if constant.name.is_some() {
1803 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1804 } else {
1805 self.put_const_expression(
1806 constant.init,
1807 module,
1808 mod_info,
1809 &module.global_expressions,
1810 )?;
1811 }
1812 }
1813 crate::Expression::ZeroValue(ty) => {
1814 let ty_name = TypeContext {
1815 handle: ty,
1816 gctx: module.to_ctx(),
1817 names: &self.names,
1818 access: crate::StorageAccess::empty(),
1819 first_time: false,
1820 };
1821 write!(self.out, "{ty_name} {{}}")?;
1822 }
1823 crate::Expression::Compose { ty, ref components } => {
1824 let ty_name = TypeContext {
1825 handle: ty,
1826 gctx: module.to_ctx(),
1827 names: &self.names,
1828 access: crate::StorageAccess::empty(),
1829 first_time: false,
1830 };
1831 write!(self.out, "{ty_name}")?;
1832 match module.types[ty].inner {
1833 crate::TypeInner::Scalar(_)
1834 | crate::TypeInner::Vector { .. }
1835 | crate::TypeInner::Matrix { .. } => {
1836 self.put_call_parameters_impl(
1837 components.iter().copied(),
1838 ctx,
1839 put_expression,
1840 )?;
1841 }
1842 crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
1843 write!(self.out, " {{")?;
1844 for (index, &component) in components.iter().enumerate() {
1845 if index != 0 {
1846 write!(self.out, ", ")?;
1847 }
1848 if self.struct_member_pads.contains(&(ty, index as u32)) {
1850 write!(self.out, "{{}}, ")?;
1851 }
1852 put_expression(self, ctx, component)?;
1853 }
1854 write!(self.out, "}}")?;
1855 }
1856 _ => return Err(Error::UnsupportedCompose(ty)),
1857 }
1858 }
1859 crate::Expression::Splat { size, value } => {
1860 let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1861 crate::TypeInner::Scalar(scalar) => scalar,
1862 ref ty => {
1863 return Err(Error::GenericValidation(format!(
1864 "Expected splat value type must be a scalar, got {ty:?}",
1865 )))
1866 }
1867 };
1868 put_numeric_type(&mut self.out, scalar, &[size])?;
1869 write!(self.out, "(")?;
1870 put_expression(self, ctx, value)?;
1871 write!(self.out, ")")?;
1872 }
1873 _ => {
1874 return Err(Error::Override);
1875 }
1876 }
1877
1878 Ok(())
1879 }
1880
1881 fn put_expression(
1893 &mut self,
1894 expr_handle: Handle<crate::Expression>,
1895 context: &ExpressionContext,
1896 is_scoped: bool,
1897 ) -> BackendResult {
1898 #[cfg(test)]
1900 self.put_expression_stack_pointers
1901 .insert(ptr::from_ref(&expr_handle).cast());
1902
1903 if let Some(name) = self.named_expressions.get(&expr_handle) {
1904 write!(self.out, "{name}")?;
1905 return Ok(());
1906 }
1907
1908 let expression = &context.function.expressions[expr_handle];
1909 match *expression {
1910 crate::Expression::Literal(_)
1911 | crate::Expression::Constant(_)
1912 | crate::Expression::ZeroValue(_)
1913 | crate::Expression::Compose { .. }
1914 | crate::Expression::Splat { .. } => {
1915 self.put_possibly_const_expression(
1916 expr_handle,
1917 &context.function.expressions,
1918 context.module,
1919 context.mod_info,
1920 context,
1921 |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1922 |writer, context, expr| writer.put_expression(expr, context, true),
1923 )?;
1924 }
1925 crate::Expression::Override(_) => return Err(Error::Override),
1926 crate::Expression::Access { base, .. }
1927 | crate::Expression::AccessIndex { base, .. } => {
1928 let policy = context.choose_bounds_check_policy(base);
1934 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1935 && self.put_bounds_checks(
1936 expr_handle,
1937 context,
1938 back::Level(0),
1939 if is_scoped { "" } else { "(" },
1940 )?
1941 {
1942 write!(self.out, " ? ")?;
1943 self.put_access_chain(expr_handle, policy, context)?;
1944 write!(self.out, " : ")?;
1945
1946 if context.resolve_type(base).pointer_space().is_some() {
1947 let result_ty = context.info[expr_handle]
1951 .ty
1952 .inner_with(&context.module.types)
1953 .pointer_base_type();
1954 let result_ty_handle = match result_ty {
1955 Some(TypeResolution::Handle(handle)) => handle,
1956 Some(TypeResolution::Value(_)) => {
1957 unreachable!(
1965 "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
1966 );
1967 }
1968 None => {
1969 unreachable!(
1970 "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
1971 )
1972 }
1973 };
1974 let name_key =
1975 NameKey::oob_local_for_type(context.origin, result_ty_handle);
1976 self.out.write_str(&self.names[&name_key])?;
1977 } else {
1978 write!(self.out, "DefaultConstructible()")?;
1979 }
1980
1981 if !is_scoped {
1982 write!(self.out, ")")?;
1983 }
1984 } else {
1985 self.put_access_chain(expr_handle, policy, context)?;
1986 }
1987 }
1988 crate::Expression::Swizzle {
1989 size,
1990 vector,
1991 pattern,
1992 } => {
1993 self.put_wrapped_expression_for_packed_vec3_access(
1994 vector,
1995 context,
1996 false,
1997 &Self::put_expression,
1998 )?;
1999 write!(self.out, ".")?;
2000 for &sc in pattern[..size as usize].iter() {
2001 write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2002 }
2003 }
2004 crate::Expression::FunctionArgument(index) => {
2005 let name_key = match context.origin {
2006 FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2007 FunctionOrigin::EntryPoint(ep_index) => {
2008 NameKey::EntryPointArgument(ep_index, index)
2009 }
2010 };
2011 let name = &self.names[&name_key];
2012 write!(self.out, "{name}")?;
2013 }
2014 crate::Expression::GlobalVariable(handle) => {
2015 let name = &self.names[&NameKey::GlobalVariable(handle)];
2016 write!(self.out, "{name}")?;
2017 }
2018 crate::Expression::LocalVariable(handle) => {
2019 let name_key = NameKey::local(context.origin, handle);
2020 let name = &self.names[&name_key];
2021 write!(self.out, "{name}")?;
2022 }
2023 crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2024 crate::Expression::ImageSample {
2025 coordinate,
2026 image,
2027 sampler,
2028 clamp_to_edge: true,
2029 gather: None,
2030 array_index: None,
2031 offset: None,
2032 level: crate::SampleLevel::Zero,
2033 depth_ref: None,
2034 } => {
2035 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2036 self.put_expression(image, context, true)?;
2037 write!(self.out, ", ")?;
2038 self.put_expression(sampler, context, true)?;
2039 write!(self.out, ", ")?;
2040 self.put_expression(coordinate, context, true)?;
2041 write!(self.out, ")")?;
2042 }
2043 crate::Expression::ImageSample {
2044 image,
2045 sampler,
2046 gather,
2047 coordinate,
2048 array_index,
2049 offset,
2050 level,
2051 depth_ref,
2052 clamp_to_edge,
2053 } => {
2054 if clamp_to_edge {
2055 return Err(Error::GenericValidation(
2056 "ImageSample::clamp_to_edge should have been validated out".to_string(),
2057 ));
2058 }
2059
2060 let main_op = match gather {
2061 Some(_) => "gather",
2062 None => "sample",
2063 };
2064 let comparison_op = match depth_ref {
2065 Some(_) => "_compare",
2066 None => "",
2067 };
2068 self.put_expression(image, context, false)?;
2069 write!(self.out, ".{main_op}{comparison_op}(")?;
2070 self.put_expression(sampler, context, true)?;
2071 write!(self.out, ", ")?;
2072 self.put_expression(coordinate, context, true)?;
2073 if let Some(expr) = array_index {
2074 write!(self.out, ", ")?;
2075 self.put_expression(expr, context, true)?;
2076 }
2077 if let Some(dref) = depth_ref {
2078 write!(self.out, ", ")?;
2079 self.put_expression(dref, context, true)?;
2080 }
2081
2082 self.put_image_sample_level(image, level, context)?;
2083
2084 if let Some(offset) = offset {
2085 write!(self.out, ", ")?;
2086 self.put_expression(offset, context, true)?;
2087 }
2088
2089 match gather {
2090 None | Some(crate::SwizzleComponent::X) => {}
2091 Some(component) => {
2092 let is_cube_map = match *context.resolve_type(image) {
2093 crate::TypeInner::Image {
2094 dim: crate::ImageDimension::Cube,
2095 ..
2096 } => true,
2097 _ => false,
2098 };
2099 if offset.is_none() && !is_cube_map {
2102 write!(self.out, ", {NAMESPACE}::int2(0)")?;
2103 }
2104 let letter = back::COMPONENTS[component as usize];
2105 write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2106 }
2107 }
2108 write!(self.out, ")")?;
2109 }
2110 crate::Expression::ImageLoad {
2111 image,
2112 coordinate,
2113 array_index,
2114 sample,
2115 level,
2116 } => {
2117 let address = TexelAddress {
2118 coordinate,
2119 array_index,
2120 sample,
2121 level: level.map(LevelOfDetail::Direct),
2122 };
2123 self.put_image_load(expr_handle, image, address, context)?;
2124 }
2125 crate::Expression::ImageQuery { image, query } => match query {
2128 crate::ImageQuery::Size { level } => {
2129 self.put_image_size_query(
2130 image,
2131 level.map(LevelOfDetail::Direct),
2132 crate::ScalarKind::Uint,
2133 context,
2134 )?;
2135 }
2136 crate::ImageQuery::NumLevels => {
2137 self.put_expression(image, context, false)?;
2138 write!(self.out, ".get_num_mip_levels()")?;
2139 }
2140 crate::ImageQuery::NumLayers => {
2141 self.put_expression(image, context, false)?;
2142 write!(self.out, ".get_array_size()")?;
2143 }
2144 crate::ImageQuery::NumSamples => {
2145 self.put_expression(image, context, false)?;
2146 write!(self.out, ".get_num_samples()")?;
2147 }
2148 },
2149 crate::Expression::Unary { op, expr } => {
2150 let op_str = match op {
2151 crate::UnaryOperator::Negate => {
2152 match context.resolve_type(expr).scalar_kind() {
2153 Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2154 _ => "-",
2155 }
2156 }
2157 crate::UnaryOperator::LogicalNot => "!",
2158 crate::UnaryOperator::BitwiseNot => "~",
2159 };
2160 write!(self.out, "{op_str}(")?;
2161 self.put_expression(expr, context, false)?;
2162 write!(self.out, ")")?;
2163 }
2164 crate::Expression::Binary { op, left, right } => {
2165 let kind = context
2166 .resolve_type(left)
2167 .scalar_kind()
2168 .ok_or(Error::UnsupportedBinaryOp(op))?;
2169
2170 if op == crate::BinaryOperator::Divide
2171 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2172 {
2173 write!(self.out, "{DIV_FUNCTION}(")?;
2174 self.put_expression(left, context, true)?;
2175 write!(self.out, ", ")?;
2176 self.put_expression(right, context, true)?;
2177 write!(self.out, ")")?;
2178 } else if op == crate::BinaryOperator::Modulo
2179 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2180 {
2181 write!(self.out, "{MOD_FUNCTION}(")?;
2182 self.put_expression(left, context, true)?;
2183 write!(self.out, ", ")?;
2184 self.put_expression(right, context, true)?;
2185 write!(self.out, ")")?;
2186 } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2187 write!(self.out, "{NAMESPACE}::fmod(")?;
2192 self.put_expression(left, context, true)?;
2193 write!(self.out, ", ")?;
2194 self.put_expression(right, context, true)?;
2195 write!(self.out, ")")?;
2196 } else if (op == crate::BinaryOperator::Add
2197 || op == crate::BinaryOperator::Subtract
2198 || op == crate::BinaryOperator::Multiply)
2199 && kind == crate::ScalarKind::Sint
2200 {
2201 let to_unsigned = |ty: &crate::TypeInner| match *ty {
2202 crate::TypeInner::Scalar(scalar) => {
2203 Ok(crate::TypeInner::Scalar(crate::Scalar {
2204 kind: crate::ScalarKind::Uint,
2205 ..scalar
2206 }))
2207 }
2208 crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2209 size,
2210 scalar: crate::Scalar {
2211 kind: crate::ScalarKind::Uint,
2212 ..scalar
2213 },
2214 }),
2215 _ => Err(Error::UnsupportedBitCast(ty.clone())),
2216 };
2217
2218 self.put_bitcasted_expression(
2223 context.resolve_type(expr_handle),
2224 context,
2225 &|writer, context, is_scoped| {
2226 writer.put_binop(
2227 op,
2228 left,
2229 right,
2230 context,
2231 is_scoped,
2232 &|writer, expr, context, _is_scoped| {
2233 writer.put_bitcasted_expression(
2234 &to_unsigned(context.resolve_type(expr))?,
2235 context,
2236 &|writer, context, is_scoped| {
2237 writer.put_expression(expr, context, is_scoped)
2238 },
2239 )
2240 },
2241 )
2242 },
2243 )?;
2244 } else {
2245 self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2246 }
2247 }
2248 crate::Expression::Select {
2249 condition,
2250 accept,
2251 reject,
2252 } => match *context.resolve_type(condition) {
2253 crate::TypeInner::Scalar(crate::Scalar {
2254 kind: crate::ScalarKind::Bool,
2255 ..
2256 }) => {
2257 if !is_scoped {
2258 write!(self.out, "(")?;
2259 }
2260 self.put_expression(condition, context, false)?;
2261 write!(self.out, " ? ")?;
2262 self.put_expression(accept, context, false)?;
2263 write!(self.out, " : ")?;
2264 self.put_expression(reject, context, false)?;
2265 if !is_scoped {
2266 write!(self.out, ")")?;
2267 }
2268 }
2269 crate::TypeInner::Vector {
2270 scalar:
2271 crate::Scalar {
2272 kind: crate::ScalarKind::Bool,
2273 ..
2274 },
2275 ..
2276 } => {
2277 write!(self.out, "{NAMESPACE}::select(")?;
2278 self.put_expression(reject, context, true)?;
2279 write!(self.out, ", ")?;
2280 self.put_expression(accept, context, true)?;
2281 write!(self.out, ", ")?;
2282 self.put_expression(condition, context, true)?;
2283 write!(self.out, ")")?;
2284 }
2285 ref ty => {
2286 return Err(Error::GenericValidation(format!(
2287 "Expected select condition to be a non-bool type, got {ty:?}",
2288 )))
2289 }
2290 },
2291 crate::Expression::Derivative { axis, expr, .. } => {
2292 use crate::DerivativeAxis as Axis;
2293 let op = match axis {
2294 Axis::X => "dfdx",
2295 Axis::Y => "dfdy",
2296 Axis::Width => "fwidth",
2297 };
2298 write!(self.out, "{NAMESPACE}::{op}")?;
2299 self.put_call_parameters(iter::once(expr), context)?;
2300 }
2301 crate::Expression::Relational { fun, argument } => {
2302 let op = match fun {
2303 crate::RelationalFunction::Any => "any",
2304 crate::RelationalFunction::All => "all",
2305 crate::RelationalFunction::IsNan => "isnan",
2306 crate::RelationalFunction::IsInf => "isinf",
2307 };
2308 write!(self.out, "{NAMESPACE}::{op}")?;
2309 self.put_call_parameters(iter::once(argument), context)?;
2310 }
2311 crate::Expression::Math {
2312 fun,
2313 arg,
2314 arg1,
2315 arg2,
2316 arg3,
2317 } => {
2318 use crate::MathFunction as Mf;
2319
2320 let arg_type = context.resolve_type(arg);
2321 let scalar_argument = match arg_type {
2322 &crate::TypeInner::Scalar(_) => true,
2323 _ => false,
2324 };
2325
2326 let fun_name = match fun {
2327 Mf::Abs => "abs",
2329 Mf::Min => "min",
2330 Mf::Max => "max",
2331 Mf::Clamp => "clamp",
2332 Mf::Saturate => "saturate",
2333 Mf::Cos => "cos",
2335 Mf::Cosh => "cosh",
2336 Mf::Sin => "sin",
2337 Mf::Sinh => "sinh",
2338 Mf::Tan => "tan",
2339 Mf::Tanh => "tanh",
2340 Mf::Acos => "acos",
2341 Mf::Asin => "asin",
2342 Mf::Atan => "atan",
2343 Mf::Atan2 => "atan2",
2344 Mf::Asinh => "asinh",
2345 Mf::Acosh => "acosh",
2346 Mf::Atanh => "atanh",
2347 Mf::Radians => "",
2348 Mf::Degrees => "",
2349 Mf::Ceil => "ceil",
2351 Mf::Floor => "floor",
2352 Mf::Round => "rint",
2353 Mf::Fract => "fract",
2354 Mf::Trunc => "trunc",
2355 Mf::Modf => MODF_FUNCTION,
2356 Mf::Frexp => FREXP_FUNCTION,
2357 Mf::Ldexp => "ldexp",
2358 Mf::Exp => "exp",
2360 Mf::Exp2 => "exp2",
2361 Mf::Log => "log",
2362 Mf::Log2 => "log2",
2363 Mf::Pow => "pow",
2364 Mf::Dot => match *context.resolve_type(arg) {
2366 crate::TypeInner::Vector {
2367 scalar:
2368 crate::Scalar {
2369 kind: crate::ScalarKind::Float,
2371 ..
2372 },
2373 ..
2374 } => "dot",
2375 crate::TypeInner::Vector {
2376 size,
2377 scalar:
2378 scalar @ crate::Scalar {
2379 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2380 ..
2381 },
2382 } => {
2383 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2385 write!(self.out, "{fun_name}(")?;
2386 self.put_expression(arg, context, true)?;
2387 write!(self.out, ", ")?;
2388 self.put_expression(arg1.unwrap(), context, true)?;
2389 write!(self.out, ")")?;
2390 return Ok(());
2391 }
2392 _ => unreachable!(
2393 "Correct TypeInner for dot product should be already validated"
2394 ),
2395 },
2396 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2397 if context.lang_version >= (2, 1) {
2398 let packed_type = match fun {
2402 Mf::Dot4I8Packed => "packed_char4",
2403 Mf::Dot4U8Packed => "packed_uchar4",
2404 _ => unreachable!(),
2405 };
2406
2407 return self.put_dot_product(
2408 Reinterpreted::new(packed_type, arg),
2409 Reinterpreted::new(packed_type, arg1.unwrap()),
2410 4,
2411 |writer, arg, index| {
2412 write!(writer.out, "{arg}[{index}]")?;
2415 Ok(())
2416 },
2417 );
2418 } else {
2419 let conversion = match fun {
2423 Mf::Dot4I8Packed => "int",
2424 Mf::Dot4U8Packed => "",
2425 _ => unreachable!(),
2426 };
2427
2428 return self.put_dot_product(
2429 arg,
2430 arg1.unwrap(),
2431 4,
2432 |writer, arg, index| {
2433 write!(writer.out, "({conversion}(")?;
2434 writer.put_expression(arg, context, true)?;
2435 if index == 3 {
2436 write!(writer.out, ") >> 24)")?;
2437 } else {
2438 write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2439 }
2440 Ok(())
2441 },
2442 );
2443 }
2444 }
2445 Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2446 Mf::Cross => "cross",
2447 Mf::Distance => "distance",
2448 Mf::Length if scalar_argument => "abs",
2449 Mf::Length => "length",
2450 Mf::Normalize => "normalize",
2451 Mf::FaceForward => "faceforward",
2452 Mf::Reflect => "reflect",
2453 Mf::Refract => "refract",
2454 Mf::Sign => match arg_type.scalar_kind() {
2456 Some(crate::ScalarKind::Sint) => {
2457 return self.put_isign(arg, context);
2458 }
2459 _ => "sign",
2460 },
2461 Mf::Fma => "fma",
2462 Mf::Mix => "mix",
2463 Mf::Step => "step",
2464 Mf::SmoothStep => "smoothstep",
2465 Mf::Sqrt => "sqrt",
2466 Mf::InverseSqrt => "rsqrt",
2467 Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2468 Mf::Transpose => "transpose",
2469 Mf::Determinant => "determinant",
2470 Mf::QuantizeToF16 => "",
2471 Mf::CountTrailingZeros => "ctz",
2473 Mf::CountLeadingZeros => "clz",
2474 Mf::CountOneBits => "popcount",
2475 Mf::ReverseBits => "reverse_bits",
2476 Mf::ExtractBits => "",
2477 Mf::InsertBits => "",
2478 Mf::FirstTrailingBit => "",
2479 Mf::FirstLeadingBit => "",
2480 Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2482 Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2483 Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2484 Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2485 Mf::Pack2x16float => "",
2486 Mf::Pack4xI8 => "",
2487 Mf::Pack4xU8 => "",
2488 Mf::Pack4xI8Clamp => "",
2489 Mf::Pack4xU8Clamp => "",
2490 Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2492 Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2493 Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2494 Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2495 Mf::Unpack2x16float => "",
2496 Mf::Unpack4xI8 => "",
2497 Mf::Unpack4xU8 => "",
2498 };
2499
2500 match fun {
2501 Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2502 if context.lang_version < (1, 2) {
2511 return Err(Error::UnsupportedFunction(fun_name.to_string()));
2512 }
2513 }
2514 _ => {}
2515 }
2516
2517 match fun {
2518 Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2519 write!(self.out, "{ABS_FUNCTION}(")?;
2520 self.put_expression(arg, context, true)?;
2521 write!(self.out, ")")?;
2522 }
2523 Mf::Distance if scalar_argument => {
2524 write!(self.out, "{NAMESPACE}::abs(")?;
2525 self.put_expression(arg, context, false)?;
2526 write!(self.out, " - ")?;
2527 self.put_expression(arg1.unwrap(), context, false)?;
2528 write!(self.out, ")")?;
2529 }
2530 Mf::FirstTrailingBit => {
2531 let scalar = context.resolve_type(arg).scalar().unwrap();
2532 let constant = scalar.width * 8 + 1;
2533
2534 write!(self.out, "((({NAMESPACE}::ctz(")?;
2535 self.put_expression(arg, context, true)?;
2536 write!(self.out, ") + 1) % {constant}) - 1)")?;
2537 }
2538 Mf::FirstLeadingBit => {
2539 let inner = context.resolve_type(arg);
2540 let scalar = inner.scalar().unwrap();
2541 let constant = scalar.width * 8 - 1;
2542
2543 write!(
2544 self.out,
2545 "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2546 )?;
2547
2548 if scalar.kind == crate::ScalarKind::Sint {
2549 write!(self.out, "{NAMESPACE}::select(")?;
2550 self.put_expression(arg, context, true)?;
2551 write!(self.out, ", ~")?;
2552 self.put_expression(arg, context, true)?;
2553 write!(self.out, ", ")?;
2554 self.put_expression(arg, context, true)?;
2555 write!(self.out, " < 0)")?;
2556 } else {
2557 self.put_expression(arg, context, true)?;
2558 }
2559
2560 write!(self.out, "), ")?;
2561
2562 match *inner {
2564 crate::TypeInner::Vector { size, scalar } => {
2565 let size = common::vector_size_str(size);
2566 let name = scalar.to_msl_name();
2567 write!(self.out, "{name}{size}")?;
2568 }
2569 crate::TypeInner::Scalar(scalar) => {
2570 let name = scalar.to_msl_name();
2571 write!(self.out, "{name}")?;
2572 }
2573 _ => (),
2574 }
2575
2576 write!(self.out, "(-1), ")?;
2577 self.put_expression(arg, context, true)?;
2578 write!(self.out, " == 0 || ")?;
2579 self.put_expression(arg, context, true)?;
2580 write!(self.out, " == -1)")?;
2581 }
2582 Mf::Unpack2x16float => {
2583 write!(self.out, "float2(as_type<half2>(")?;
2584 self.put_expression(arg, context, false)?;
2585 write!(self.out, "))")?;
2586 }
2587 Mf::Pack2x16float => {
2588 write!(self.out, "as_type<uint>(half2(")?;
2589 self.put_expression(arg, context, false)?;
2590 write!(self.out, "))")?;
2591 }
2592 Mf::ExtractBits => {
2593 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2610
2611 write!(self.out, "{NAMESPACE}::extract_bits(")?;
2612 self.put_expression(arg, context, true)?;
2613 write!(self.out, ", {NAMESPACE}::min(")?;
2614 self.put_expression(arg1.unwrap(), context, true)?;
2615 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2616 self.put_expression(arg2.unwrap(), context, true)?;
2617 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2618 self.put_expression(arg1.unwrap(), context, true)?;
2619 write!(self.out, ", {scalar_bits}u)))")?;
2620 }
2621 Mf::InsertBits => {
2622 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2627
2628 write!(self.out, "{NAMESPACE}::insert_bits(")?;
2629 self.put_expression(arg, context, true)?;
2630 write!(self.out, ", ")?;
2631 self.put_expression(arg1.unwrap(), context, true)?;
2632 write!(self.out, ", {NAMESPACE}::min(")?;
2633 self.put_expression(arg2.unwrap(), context, true)?;
2634 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2635 self.put_expression(arg3.unwrap(), context, true)?;
2636 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2637 self.put_expression(arg2.unwrap(), context, true)?;
2638 write!(self.out, ", {scalar_bits}u)))")?;
2639 }
2640 Mf::Radians => {
2641 write!(self.out, "((")?;
2642 self.put_expression(arg, context, false)?;
2643 write!(self.out, ") * 0.017453292519943295474)")?;
2644 }
2645 Mf::Degrees => {
2646 write!(self.out, "((")?;
2647 self.put_expression(arg, context, false)?;
2648 write!(self.out, ") * 57.295779513082322865)")?;
2649 }
2650 Mf::Modf | Mf::Frexp => {
2651 write!(self.out, "{fun_name}")?;
2652 self.put_call_parameters(iter::once(arg), context)?;
2653 }
2654 Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2655 Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2656 Mf::Pack4xI8Clamp => {
2657 self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2658 }
2659 Mf::Pack4xU8Clamp => {
2660 self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2661 }
2662 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2663 let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2664 "u"
2665 } else {
2666 ""
2667 };
2668
2669 if context.lang_version >= (2, 1) {
2670 write!(
2672 self.out,
2673 "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2674 )?;
2675 self.put_expression(arg, context, true)?;
2676 write!(self.out, "))")?;
2677 } else {
2678 write!(self.out, "({sign_prefix}int4(")?;
2680 self.put_expression(arg, context, true)?;
2681 write!(self.out, ", ")?;
2682 self.put_expression(arg, context, true)?;
2683 write!(self.out, " >> 8, ")?;
2684 self.put_expression(arg, context, true)?;
2685 write!(self.out, " >> 16, ")?;
2686 self.put_expression(arg, context, true)?;
2687 write!(self.out, " >> 24) << 24 >> 24)")?;
2688 }
2689 }
2690 Mf::QuantizeToF16 => {
2691 match *context.resolve_type(arg) {
2692 crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2693 crate::TypeInner::Vector { size, .. } => write!(
2694 self.out,
2695 "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2696 size = common::vector_size_str(size),
2697 )?,
2698 _ => unreachable!(
2699 "Correct TypeInner for QuantizeToF16 should be already validated"
2700 ),
2701 };
2702
2703 self.put_expression(arg, context, true)?;
2704 write!(self.out, "))")?;
2705 }
2706 _ => {
2707 write!(self.out, "{NAMESPACE}::{fun_name}")?;
2708 self.put_call_parameters(
2709 iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2710 context,
2711 )?;
2712 }
2713 }
2714 }
2715 crate::Expression::As {
2716 expr,
2717 kind,
2718 convert,
2719 } => match *context.resolve_type(expr) {
2720 crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2721 if src.kind == crate::ScalarKind::Float
2722 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2723 && convert.is_some()
2724 {
2725 let fun_name = match (kind, convert) {
2729 (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2730 (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2731 (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2732 (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2733 _ => unreachable!(),
2734 };
2735 write!(self.out, "{fun_name}(")?;
2736 self.put_expression(expr, context, true)?;
2737 write!(self.out, ")")?;
2738 } else {
2739 let target_scalar = crate::Scalar {
2740 kind,
2741 width: convert.unwrap_or(src.width),
2742 };
2743 let op = match convert {
2744 Some(_) => "static_cast",
2745 None => "as_type",
2746 };
2747 write!(self.out, "{op}<")?;
2748 match *context.resolve_type(expr) {
2749 crate::TypeInner::Vector { size, .. } => {
2750 put_numeric_type(&mut self.out, target_scalar, &[size])?
2751 }
2752 _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2753 };
2754 write!(self.out, ">(")?;
2755 self.put_expression(expr, context, true)?;
2756 write!(self.out, ")")?;
2757 }
2758 }
2759 crate::TypeInner::Matrix {
2760 columns,
2761 rows,
2762 scalar,
2763 } => {
2764 let target_scalar = crate::Scalar {
2765 kind,
2766 width: convert.unwrap_or(scalar.width),
2767 };
2768 put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2769 write!(self.out, "(")?;
2770 self.put_expression(expr, context, true)?;
2771 write!(self.out, ")")?;
2772 }
2773 ref ty => {
2774 return Err(Error::GenericValidation(format!(
2775 "Unsupported type for As: {ty:?}"
2776 )))
2777 }
2778 },
2779 crate::Expression::CallResult(_)
2781 | crate::Expression::AtomicResult { .. }
2782 | crate::Expression::WorkGroupUniformLoadResult { .. }
2783 | crate::Expression::SubgroupBallotResult
2784 | crate::Expression::SubgroupOperationResult { .. }
2785 | crate::Expression::RayQueryProceedResult => {
2786 unreachable!()
2787 }
2788 crate::Expression::ArrayLength(expr) => {
2789 let global = match context.function.expressions[expr] {
2791 crate::Expression::AccessIndex { base, .. } => {
2792 match context.function.expressions[base] {
2793 crate::Expression::GlobalVariable(handle) => handle,
2794 ref ex => {
2795 return Err(Error::GenericValidation(format!(
2796 "Expected global variable in AccessIndex, got {ex:?}"
2797 )))
2798 }
2799 }
2800 }
2801 crate::Expression::GlobalVariable(handle) => handle,
2802 ref ex => {
2803 return Err(Error::GenericValidation(format!(
2804 "Unexpected expression in ArrayLength, got {ex:?}"
2805 )))
2806 }
2807 };
2808
2809 if !is_scoped {
2810 write!(self.out, "(")?;
2811 }
2812 write!(self.out, "1 + ")?;
2813 self.put_dynamic_array_max_index(global, context)?;
2814 if !is_scoped {
2815 write!(self.out, ")")?;
2816 }
2817 }
2818 crate::Expression::RayQueryVertexPositions { .. } => {
2819 unimplemented!()
2820 }
2821 crate::Expression::RayQueryGetIntersection {
2822 query,
2823 committed: _,
2824 } => {
2825 if context.lang_version < (2, 4) {
2826 return Err(Error::UnsupportedRayTracing);
2827 }
2828
2829 let ty = context.module.special_types.ray_intersection.unwrap();
2830 let type_name = &self.names[&NameKey::Type(ty)];
2831 write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2832 self.put_expression(query, context, true)?;
2833 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2834 let fields = [
2835 "distance",
2836 "user_instance_id", "instance_id",
2838 "", "geometry_id",
2840 "primitive_id",
2841 "triangle_barycentric_coord",
2842 "triangle_front_facing",
2843 "", "object_to_world_transform", "world_to_object_transform", ];
2847 for field in fields {
2848 write!(self.out, ", ")?;
2849 if field.is_empty() {
2850 write!(self.out, "{{}}")?;
2851 } else {
2852 self.put_expression(query, context, true)?;
2853 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2854 }
2855 }
2856 write!(self.out, "}}")?;
2857 }
2858 crate::Expression::CooperativeLoad { ref data, .. } => {
2859 if context.lang_version < (2, 3) {
2860 return Err(Error::UnsupportedCooperativeMatrix);
2861 }
2862 write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2863 write!(self.out, "&")?;
2864 self.put_access_chain(data.pointer, context.policies.index, context)?;
2865 write!(self.out, ", ")?;
2866 self.put_expression(data.stride, context, true)?;
2867 write!(self.out, ", {})", data.row_major)?;
2868 }
2869 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2870 if context.lang_version < (2, 3) {
2871 return Err(Error::UnsupportedCooperativeMatrix);
2872 }
2873 write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2874 self.put_expression(a, context, true)?;
2875 write!(self.out, ", ")?;
2876 self.put_expression(b, context, true)?;
2877 write!(self.out, ", ")?;
2878 self.put_expression(c, context, true)?;
2879 write!(self.out, ")")?;
2880 }
2881 }
2882 Ok(())
2883 }
2884
2885 fn put_binop<F>(
2888 &mut self,
2889 op: crate::BinaryOperator,
2890 left: Handle<crate::Expression>,
2891 right: Handle<crate::Expression>,
2892 context: &ExpressionContext,
2893 is_scoped: bool,
2894 put_expression: &F,
2895 ) -> BackendResult
2896 where
2897 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2898 {
2899 let op_str = back::binary_operation_str(op);
2900
2901 if !is_scoped {
2902 write!(self.out, "(")?;
2903 }
2904
2905 if op == crate::BinaryOperator::Multiply
2908 && matches!(
2909 context.resolve_type(right),
2910 &crate::TypeInner::Matrix { .. }
2911 )
2912 {
2913 self.put_wrapped_expression_for_packed_vec3_access(
2914 left,
2915 context,
2916 false,
2917 put_expression,
2918 )?;
2919 } else {
2920 put_expression(self, left, context, false)?;
2921 }
2922
2923 write!(self.out, " {op_str} ")?;
2924
2925 if op == crate::BinaryOperator::Multiply
2927 && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
2928 {
2929 self.put_wrapped_expression_for_packed_vec3_access(
2930 right,
2931 context,
2932 false,
2933 put_expression,
2934 )?;
2935 } else {
2936 put_expression(self, right, context, false)?;
2937 }
2938
2939 if !is_scoped {
2940 write!(self.out, ")")?;
2941 }
2942
2943 Ok(())
2944 }
2945
2946 fn put_wrapped_expression_for_packed_vec3_access<F>(
2948 &mut self,
2949 expr_handle: Handle<crate::Expression>,
2950 context: &ExpressionContext,
2951 is_scoped: bool,
2952 put_expression: &F,
2953 ) -> BackendResult
2954 where
2955 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2956 {
2957 if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
2958 write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
2959 put_expression(self, expr_handle, context, is_scoped)?;
2960 write!(self.out, ")")?;
2961 } else {
2962 put_expression(self, expr_handle, context, is_scoped)?;
2963 }
2964 Ok(())
2965 }
2966
2967 fn put_bitcasted_expression<F>(
2970 &mut self,
2971 cast_to: &crate::TypeInner,
2972 context: &ExpressionContext,
2973 put_expression: &F,
2974 ) -> BackendResult
2975 where
2976 F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
2977 {
2978 write!(self.out, "as_type<")?;
2979 match *cast_to {
2980 crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
2981 crate::TypeInner::Vector { size, scalar } => {
2982 put_numeric_type(&mut self.out, scalar, &[size])?
2983 }
2984 _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
2985 };
2986 write!(self.out, ">(")?;
2987 put_expression(self, context, true)?;
2988 write!(self.out, ")")?;
2989
2990 Ok(())
2991 }
2992
2993 fn put_index(
2995 &mut self,
2996 index: index::GuardedIndex,
2997 context: &ExpressionContext,
2998 is_scoped: bool,
2999 ) -> BackendResult {
3000 match index {
3001 index::GuardedIndex::Expression(expr) => {
3002 self.put_expression(expr, context, is_scoped)?
3003 }
3004 index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3005 }
3006 Ok(())
3007 }
3008
3009 #[allow(unused_variables)]
3039 fn put_bounds_checks(
3040 &mut self,
3041 chain: Handle<crate::Expression>,
3042 context: &ExpressionContext,
3043 level: back::Level,
3044 prefix: &'static str,
3045 ) -> Result<bool, Error> {
3046 let mut check_written = false;
3047
3048 for item in context.bounds_check_iter(chain) {
3050 let BoundsCheck {
3051 base,
3052 index,
3053 length,
3054 } = item;
3055
3056 if check_written {
3057 write!(self.out, " && ")?;
3058 } else {
3059 write!(self.out, "{level}{prefix}")?;
3060 check_written = true;
3061 }
3062
3063 write!(self.out, "uint(")?;
3067 self.put_index(index, context, true)?;
3068 self.out.write_str(") < ")?;
3069 match length {
3070 index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3071 index::IndexableLength::Dynamic => {
3072 let global = context.function.originating_global(base).ok_or_else(|| {
3073 Error::GenericValidation("Could not find originating global".into())
3074 })?;
3075 write!(self.out, "1 + ")?;
3076 self.put_dynamic_array_max_index(global, context)?
3077 }
3078 }
3079 }
3080
3081 Ok(check_written)
3082 }
3083
3084 fn put_access_chain(
3104 &mut self,
3105 chain: Handle<crate::Expression>,
3106 policy: index::BoundsCheckPolicy,
3107 context: &ExpressionContext,
3108 ) -> BackendResult {
3109 match context.function.expressions[chain] {
3110 crate::Expression::Access { base, index } => {
3111 let mut base_ty = context.resolve_type(base);
3112
3113 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3115 base_ty = &context.module.types[base].inner;
3116 }
3117
3118 self.put_subscripted_access_chain(
3119 base,
3120 base_ty,
3121 index::GuardedIndex::Expression(index),
3122 policy,
3123 context,
3124 )?;
3125 }
3126 crate::Expression::AccessIndex { base, index } => {
3127 let base_resolution = &context.info[base].ty;
3128 let mut base_ty = base_resolution.inner_with(&context.module.types);
3129 let mut base_ty_handle = base_resolution.handle();
3130
3131 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3133 base_ty = &context.module.types[base].inner;
3134 base_ty_handle = Some(base);
3135 }
3136
3137 match *base_ty {
3141 crate::TypeInner::Struct { .. } => {
3142 let base_ty = base_ty_handle.unwrap();
3143 self.put_access_chain(base, policy, context)?;
3144 let name = &self.names[&NameKey::StructMember(base_ty, index)];
3145 write!(self.out, ".{name}")?;
3146 }
3147 crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3148 self.put_access_chain(base, policy, context)?;
3149 if context.get_packed_vec_kind(base).is_some() {
3152 write!(self.out, "[{index}]")?;
3153 } else {
3154 write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3155 }
3156 }
3157 _ => {
3158 self.put_subscripted_access_chain(
3159 base,
3160 base_ty,
3161 index::GuardedIndex::Known(index),
3162 policy,
3163 context,
3164 )?;
3165 }
3166 }
3167 }
3168 _ => self.put_expression(chain, context, false)?,
3169 }
3170
3171 Ok(())
3172 }
3173
3174 fn put_subscripted_access_chain(
3191 &mut self,
3192 base: Handle<crate::Expression>,
3193 base_ty: &crate::TypeInner,
3194 index: index::GuardedIndex,
3195 policy: index::BoundsCheckPolicy,
3196 context: &ExpressionContext,
3197 ) -> BackendResult {
3198 let accessing_wrapped_array = match *base_ty {
3199 crate::TypeInner::Array {
3200 size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3201 ..
3202 } => true,
3203 _ => false,
3204 };
3205 let accessing_wrapped_binding_array =
3206 matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3207
3208 self.put_access_chain(base, policy, context)?;
3209 if accessing_wrapped_array {
3210 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3211 }
3212 write!(self.out, "[")?;
3213
3214 let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3216 context.access_needs_check(base, index)
3217 } else {
3218 None
3219 };
3220 if let Some(limit) = restriction_needed {
3221 write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3222 self.put_index(index, context, true)?;
3223 write!(self.out, "), ")?;
3224 match limit {
3225 index::IndexableLength::Known(limit) => {
3226 write!(self.out, "{}u", limit - 1)?;
3227 }
3228 index::IndexableLength::Dynamic => {
3229 let global = context.function.originating_global(base).ok_or_else(|| {
3230 Error::GenericValidation("Could not find originating global".into())
3231 })?;
3232 self.put_dynamic_array_max_index(global, context)?;
3233 }
3234 }
3235 write!(self.out, ")")?;
3236 } else {
3237 self.put_index(index, context, true)?;
3238 }
3239
3240 write!(self.out, "]")?;
3241
3242 if accessing_wrapped_binding_array {
3243 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3244 }
3245
3246 Ok(())
3247 }
3248
3249 fn put_load(
3250 &mut self,
3251 pointer: Handle<crate::Expression>,
3252 context: &ExpressionContext,
3253 is_scoped: bool,
3254 ) -> BackendResult {
3255 let policy = context.choose_bounds_check_policy(pointer);
3258 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3259 && self.put_bounds_checks(
3260 pointer,
3261 context,
3262 back::Level(0),
3263 if is_scoped { "" } else { "(" },
3264 )?
3265 {
3266 write!(self.out, " ? ")?;
3267 self.put_unchecked_load(pointer, policy, context)?;
3268 write!(self.out, " : DefaultConstructible()")?;
3269
3270 if !is_scoped {
3271 write!(self.out, ")")?;
3272 }
3273 } else {
3274 self.put_unchecked_load(pointer, policy, context)?;
3275 }
3276
3277 Ok(())
3278 }
3279
3280 fn put_unchecked_load(
3281 &mut self,
3282 pointer: Handle<crate::Expression>,
3283 policy: index::BoundsCheckPolicy,
3284 context: &ExpressionContext,
3285 ) -> BackendResult {
3286 let is_atomic_pointer = context
3287 .resolve_type(pointer)
3288 .is_atomic_pointer(&context.module.types);
3289
3290 if is_atomic_pointer {
3291 write!(
3292 self.out,
3293 "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3294 )?;
3295 self.put_access_chain(pointer, policy, context)?;
3296 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3297 } else {
3298 self.put_access_chain(pointer, policy, context)?;
3302 }
3303
3304 Ok(())
3305 }
3306
3307 fn put_return_value(
3308 &mut self,
3309 level: back::Level,
3310 expr_handle: Handle<crate::Expression>,
3311 result_struct: Option<&str>,
3312 context: &ExpressionContext,
3313 ) -> BackendResult {
3314 match result_struct {
3315 Some(struct_name) => {
3316 let mut has_point_size = false;
3317 let result_ty = context.function.result.as_ref().unwrap().ty;
3318 match context.module.types[result_ty].inner {
3319 crate::TypeInner::Struct { ref members, .. } => {
3320 let tmp = "_tmp";
3321 write!(self.out, "{level}const auto {tmp} = ")?;
3322 self.put_expression(expr_handle, context, true)?;
3323 writeln!(self.out, ";")?;
3324 write!(self.out, "{level}return {struct_name} {{")?;
3325
3326 let mut is_first = true;
3327
3328 for (index, member) in members.iter().enumerate() {
3329 if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3330 member.binding
3331 {
3332 has_point_size = true;
3333 if !context.pipeline_options.allow_and_force_point_size {
3334 continue;
3335 }
3336 }
3337
3338 let comma = if is_first { "" } else { "," };
3339 is_first = false;
3340 let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3341 if let crate::TypeInner::Array {
3345 size: crate::ArraySize::Constant(size),
3346 ..
3347 } = context.module.types[member.ty].inner
3348 {
3349 write!(self.out, "{comma} {{")?;
3350 for j in 0..size.get() {
3351 if j != 0 {
3352 write!(self.out, ",")?;
3353 }
3354 write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3355 }
3356 write!(self.out, "}}")?;
3357 } else {
3358 write!(self.out, "{comma} {tmp}.{name}")?;
3359 }
3360 }
3361 }
3362 _ => {
3363 write!(self.out, "{level}return {struct_name} {{ ")?;
3364 self.put_expression(expr_handle, context, true)?;
3365 }
3366 }
3367
3368 if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3369 let stage = context.module.entry_points[ep_index as usize].stage;
3370 if context.pipeline_options.allow_and_force_point_size
3371 && stage == crate::ShaderStage::Vertex
3372 && !has_point_size
3373 {
3374 write!(self.out, ", 1.0")?;
3376 }
3377 }
3378 write!(self.out, " }}")?;
3379 }
3380 None => {
3381 write!(self.out, "{level}return ")?;
3382 self.put_expression(expr_handle, context, true)?;
3383 }
3384 }
3385 writeln!(self.out, ";")?;
3386 Ok(())
3387 }
3388
3389 fn update_expressions_to_bake(
3394 &mut self,
3395 func: &crate::Function,
3396 info: &valid::FunctionInfo,
3397 context: &ExpressionContext,
3398 ) {
3399 use crate::Expression;
3400 self.need_bake_expressions.clear();
3401
3402 for (expr_handle, expr) in func.expressions.iter() {
3403 let expr_info = &info[expr_handle];
3406 let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3407 if min_ref_count <= expr_info.ref_count {
3408 self.need_bake_expressions.insert(expr_handle);
3409 } else {
3410 match expr_info.ty {
3411 TypeResolution::Handle(h)
3413 if Some(h) == context.module.special_types.ray_desc =>
3414 {
3415 self.need_bake_expressions.insert(expr_handle);
3416 }
3417 _ => {}
3418 }
3419 }
3420
3421 if let Expression::Math {
3422 fun,
3423 arg,
3424 arg1,
3425 arg2,
3426 ..
3427 } = *expr
3428 {
3429 match fun {
3430 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3440 self.need_bake_expressions.insert(arg);
3441 self.need_bake_expressions.insert(arg1.unwrap());
3442 }
3443 crate::MathFunction::FirstLeadingBit => {
3444 self.need_bake_expressions.insert(arg);
3445 }
3446 crate::MathFunction::Pack4xI8
3447 | crate::MathFunction::Pack4xU8
3448 | crate::MathFunction::Pack4xI8Clamp
3449 | crate::MathFunction::Pack4xU8Clamp
3450 | crate::MathFunction::Unpack4xI8
3451 | crate::MathFunction::Unpack4xU8 => {
3452 if context.lang_version < (2, 1) {
3455 self.need_bake_expressions.insert(arg);
3456 }
3457 }
3458 crate::MathFunction::ExtractBits => {
3459 self.need_bake_expressions.insert(arg1.unwrap());
3461 }
3462 crate::MathFunction::InsertBits => {
3463 self.need_bake_expressions.insert(arg2.unwrap());
3465 }
3466 crate::MathFunction::Sign => {
3467 let inner = context.resolve_type(expr_handle);
3472 if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3473 self.need_bake_expressions.insert(arg);
3474 }
3475 }
3476 _ => {}
3477 }
3478 }
3479 }
3480 }
3481
3482 fn start_baking_expression(
3483 &mut self,
3484 handle: Handle<crate::Expression>,
3485 context: &ExpressionContext,
3486 name: &str,
3487 ) -> BackendResult {
3488 match context.info[handle].ty {
3489 TypeResolution::Handle(ty_handle) => {
3490 let ty_name = TypeContext {
3491 handle: ty_handle,
3492 gctx: context.module.to_ctx(),
3493 names: &self.names,
3494 access: crate::StorageAccess::empty(),
3495 first_time: false,
3496 };
3497 write!(self.out, "{ty_name}")?;
3498 }
3499 TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3500 put_numeric_type(&mut self.out, scalar, &[])?;
3501 }
3502 TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3503 put_numeric_type(&mut self.out, scalar, &[size])?;
3504 }
3505 TypeResolution::Value(crate::TypeInner::Matrix {
3506 columns,
3507 rows,
3508 scalar,
3509 }) => {
3510 put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3511 }
3512 TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3513 columns,
3514 rows,
3515 scalar,
3516 role: _,
3517 }) => {
3518 write!(
3519 self.out,
3520 "{}::simdgroup_{}{}x{}",
3521 NAMESPACE,
3522 scalar.to_msl_name(),
3523 columns as u32,
3524 rows as u32,
3525 )?;
3526 }
3527 TypeResolution::Value(ref other) => {
3528 log::warn!("Type {other:?} isn't a known local");
3529 return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3530 }
3531 }
3532
3533 write!(self.out, " {name} = ")?;
3535
3536 Ok(())
3537 }
3538
3539 fn put_cache_restricted_level(
3552 &mut self,
3553 load: Handle<crate::Expression>,
3554 image: Handle<crate::Expression>,
3555 mip_level: Option<Handle<crate::Expression>>,
3556 indent: back::Level,
3557 context: &StatementContext,
3558 ) -> BackendResult {
3559 let level_of_detail = match mip_level {
3562 Some(level) => level,
3563 None => return Ok(()),
3564 };
3565
3566 if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3567 || !context.expression.image_needs_lod(image)
3568 {
3569 return Ok(());
3570 }
3571
3572 write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3573 self.put_restricted_scalar_image_index(
3574 image,
3575 level_of_detail,
3576 "get_num_mip_levels",
3577 &context.expression,
3578 )?;
3579 writeln!(self.out, ";")?;
3580
3581 Ok(())
3582 }
3583
3584 fn put_casting_to_packed_chars(
3590 &mut self,
3591 fun: crate::MathFunction,
3592 arg0: Handle<crate::Expression>,
3593 arg1: Handle<crate::Expression>,
3594 indent: back::Level,
3595 context: &StatementContext<'_>,
3596 ) -> Result<(), Error> {
3597 let packed_type = match fun {
3598 crate::MathFunction::Dot4I8Packed => "packed_char4",
3599 crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3600 _ => unreachable!(),
3601 };
3602
3603 for arg in [arg0, arg1] {
3604 write!(
3605 self.out,
3606 "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3607 Reinterpreted::new(packed_type, arg)
3608 )?;
3609 self.put_expression(arg, &context.expression, true)?;
3610 writeln!(self.out, ");")?;
3611 }
3612
3613 Ok(())
3614 }
3615
3616 fn put_block(
3617 &mut self,
3618 level: back::Level,
3619 statements: &[crate::Statement],
3620 context: &StatementContext,
3621 ) -> BackendResult {
3622 #[cfg(test)]
3624 self.put_block_stack_pointers
3625 .insert(ptr::from_ref(&level).cast());
3626
3627 for statement in statements {
3628 log::trace!("statement[{}] {:?}", level.0, statement);
3629 match *statement {
3630 crate::Statement::Emit(ref range) => {
3631 for handle in range.clone() {
3632 use crate::MathFunction as Mf;
3633
3634 match context.expression.function.expressions[handle] {
3635 crate::Expression::ImageLoad {
3638 image,
3639 level: mip_level,
3640 ..
3641 } => {
3642 self.put_cache_restricted_level(
3643 handle, image, mip_level, level, context,
3644 )?;
3645 }
3646
3647 crate::Expression::Math {
3656 fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3657 arg,
3658 arg1,
3659 ..
3660 } if context.expression.lang_version >= (2, 1) => {
3661 self.put_casting_to_packed_chars(
3662 fun,
3663 arg,
3664 arg1.unwrap(),
3665 level,
3666 context,
3667 )?;
3668 }
3669
3670 _ => (),
3671 }
3672
3673 let ptr_class = context.expression.resolve_type(handle).pointer_space();
3674 let expr_name = if ptr_class.is_some() {
3675 None } else if let Some(name) =
3677 context.expression.function.named_expressions.get(&handle)
3678 {
3679 Some(self.namer.call(name))
3689 } else {
3690 let bake = if context.expression.guarded_indices.contains(handle) {
3694 true
3695 } else {
3696 self.need_bake_expressions.contains(&handle)
3697 };
3698
3699 if bake {
3700 Some(Baked(handle).to_string())
3701 } else {
3702 None
3703 }
3704 };
3705
3706 if let Some(name) = expr_name {
3707 write!(self.out, "{level}")?;
3708 self.start_baking_expression(handle, &context.expression, &name)?;
3709 self.put_expression(handle, &context.expression, true)?;
3710 self.named_expressions.insert(handle, name);
3711 writeln!(self.out, ";")?;
3712 }
3713 }
3714 }
3715 crate::Statement::Block(ref block) => {
3716 if !block.is_empty() {
3717 writeln!(self.out, "{level}{{")?;
3718 self.put_block(level.next(), block, context)?;
3719 writeln!(self.out, "{level}}}")?;
3720 }
3721 }
3722 crate::Statement::If {
3723 condition,
3724 ref accept,
3725 ref reject,
3726 } => {
3727 write!(self.out, "{level}if (")?;
3728 self.put_expression(condition, &context.expression, true)?;
3729 writeln!(self.out, ") {{")?;
3730 self.put_block(level.next(), accept, context)?;
3731 if !reject.is_empty() {
3732 writeln!(self.out, "{level}}} else {{")?;
3733 self.put_block(level.next(), reject, context)?;
3734 }
3735 writeln!(self.out, "{level}}}")?;
3736 }
3737 crate::Statement::Switch {
3738 selector,
3739 ref cases,
3740 } => {
3741 write!(self.out, "{level}switch(")?;
3742 self.put_expression(selector, &context.expression, true)?;
3743 writeln!(self.out, ") {{")?;
3744 let lcase = level.next();
3745 for case in cases.iter() {
3746 match case.value {
3747 crate::SwitchValue::I32(value) => {
3748 write!(self.out, "{lcase}case {value}:")?;
3749 }
3750 crate::SwitchValue::U32(value) => {
3751 write!(self.out, "{lcase}case {value}u:")?;
3752 }
3753 crate::SwitchValue::Default => {
3754 write!(self.out, "{lcase}default:")?;
3755 }
3756 }
3757
3758 let write_block_braces = !(case.fall_through && case.body.is_empty());
3759 if write_block_braces {
3760 writeln!(self.out, " {{")?;
3761 } else {
3762 writeln!(self.out)?;
3763 }
3764
3765 self.put_block(lcase.next(), &case.body, context)?;
3766 if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3767 {
3768 writeln!(self.out, "{}break;", lcase.next())?;
3769 }
3770
3771 if write_block_braces {
3772 writeln!(self.out, "{lcase}}}")?;
3773 }
3774 }
3775 writeln!(self.out, "{level}}}")?;
3776 }
3777 crate::Statement::Loop {
3778 ref body,
3779 ref continuing,
3780 break_if,
3781 } => {
3782 let force_loop_bound_statements =
3783 self.gen_force_bounded_loop_statements(level, context);
3784 let gate_name = (!continuing.is_empty() || break_if.is_some())
3785 .then(|| self.namer.call("loop_init"));
3786
3787 if let Some((ref decl, _)) = force_loop_bound_statements {
3788 writeln!(self.out, "{decl}")?;
3789 }
3790 if let Some(ref gate_name) = gate_name {
3791 writeln!(self.out, "{level}bool {gate_name} = true;")?;
3792 }
3793
3794 writeln!(self.out, "{level}while(true) {{",)?;
3795 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3796 writeln!(self.out, "{break_and_inc}")?;
3797 }
3798 if let Some(ref gate_name) = gate_name {
3799 let lif = level.next();
3800 let lcontinuing = lif.next();
3801 writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3802 self.put_block(lcontinuing, continuing, context)?;
3803 if let Some(condition) = break_if {
3804 write!(self.out, "{lcontinuing}if (")?;
3805 self.put_expression(condition, &context.expression, true)?;
3806 writeln!(self.out, ") {{")?;
3807 writeln!(self.out, "{}break;", lcontinuing.next())?;
3808 writeln!(self.out, "{lcontinuing}}}")?;
3809 }
3810 writeln!(self.out, "{lif}}}")?;
3811 writeln!(self.out, "{lif}{gate_name} = false;")?;
3812 }
3813 self.put_block(level.next(), body, context)?;
3814
3815 writeln!(self.out, "{level}}}")?;
3816 }
3817 crate::Statement::Break => {
3818 writeln!(self.out, "{level}break;")?;
3819 }
3820 crate::Statement::Continue => {
3821 writeln!(self.out, "{level}continue;")?;
3822 }
3823 crate::Statement::Return {
3824 value: Some(expr_handle),
3825 } => {
3826 self.put_return_value(
3827 level,
3828 expr_handle,
3829 context.result_struct,
3830 &context.expression,
3831 )?;
3832 }
3833 crate::Statement::Return { value: None } => {
3834 writeln!(self.out, "{level}return;")?;
3835 }
3836 crate::Statement::Kill => {
3837 writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3838 }
3839 crate::Statement::ControlBarrier(flags)
3840 | crate::Statement::MemoryBarrier(flags) => {
3841 self.write_barrier(flags, level)?;
3842 }
3843 crate::Statement::Store { pointer, value } => {
3844 self.put_store(pointer, value, level, context)?
3845 }
3846 crate::Statement::ImageStore {
3847 image,
3848 coordinate,
3849 array_index,
3850 value,
3851 } => {
3852 let address = TexelAddress {
3853 coordinate,
3854 array_index,
3855 sample: None,
3856 level: None,
3857 };
3858 self.put_image_store(level, image, &address, value, context)?
3859 }
3860 crate::Statement::Call {
3861 function,
3862 ref arguments,
3863 result,
3864 } => {
3865 write!(self.out, "{level}")?;
3866 if let Some(expr) = result {
3867 let name = Baked(expr).to_string();
3868 self.start_baking_expression(expr, &context.expression, &name)?;
3869 self.named_expressions.insert(expr, name);
3870 }
3871 let fun_name = &self.names[&NameKey::Function(function)];
3872 write!(self.out, "{fun_name}(")?;
3873 for (i, &handle) in arguments.iter().enumerate() {
3875 if i != 0 {
3876 write!(self.out, ", ")?;
3877 }
3878 self.put_expression(handle, &context.expression, true)?;
3879 }
3880 let mut separate = !arguments.is_empty();
3882 let fun_info = &context.expression.mod_info[function];
3883 let mut needs_buffer_sizes = false;
3884 for (handle, var) in context.expression.module.global_variables.iter() {
3885 if fun_info[handle].is_empty() {
3886 continue;
3887 }
3888 if var.space.needs_pass_through() {
3889 let name = &self.names[&NameKey::GlobalVariable(handle)];
3890 if separate {
3891 write!(self.out, ", ")?;
3892 } else {
3893 separate = true;
3894 }
3895 write!(self.out, "{name}")?;
3896 }
3897 needs_buffer_sizes |=
3898 needs_array_length(var.ty, &context.expression.module.types);
3899 }
3900 if needs_buffer_sizes {
3901 if separate {
3902 write!(self.out, ", ")?;
3903 }
3904 write!(self.out, "_buffer_sizes")?;
3905 }
3906
3907 writeln!(self.out, ");")?;
3909 }
3910 crate::Statement::Atomic {
3911 pointer,
3912 ref fun,
3913 value,
3914 result,
3915 } => {
3916 let context = &context.expression;
3917
3918 write!(self.out, "{level}")?;
3923 let fun_key = if let Some(result) = result {
3924 let res_name = Baked(result).to_string();
3925 self.start_baking_expression(result, context, &res_name)?;
3926 self.named_expressions.insert(result, res_name);
3927 fun.to_msl()
3928 } else if context.resolve_type(value).scalar_width() == Some(8) {
3929 fun.to_msl_64_bit()?
3930 } else {
3931 fun.to_msl()
3932 };
3933
3934 let policy = context.choose_bounds_check_policy(pointer);
3938 let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3939 && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
3940
3941 if checked {
3943 write!(self.out, " ? ")?;
3944 }
3945
3946 match *fun {
3948 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
3949 write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
3950 self.put_access_chain(pointer, policy, context)?;
3951 write!(self.out, ", ")?;
3952 self.put_expression(cmp, context, true)?;
3953 write!(self.out, ", ")?;
3954 self.put_expression(value, context, true)?;
3955 write!(self.out, ")")?;
3956 }
3957 _ => {
3958 write!(
3959 self.out,
3960 "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
3961 )?;
3962 self.put_access_chain(pointer, policy, context)?;
3963 write!(self.out, ", ")?;
3964 self.put_expression(value, context, true)?;
3965 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3966 }
3967 }
3968
3969 if checked {
3971 write!(self.out, " : DefaultConstructible()")?;
3972 }
3973
3974 writeln!(self.out, ";")?;
3976 }
3977 crate::Statement::ImageAtomic {
3978 image,
3979 coordinate,
3980 array_index,
3981 fun,
3982 value,
3983 } => {
3984 let address = TexelAddress {
3985 coordinate,
3986 array_index,
3987 sample: None,
3988 level: None,
3989 };
3990 self.put_image_atomic(level, image, &address, fun, value, context)?
3991 }
3992 crate::Statement::WorkGroupUniformLoad { pointer, result } => {
3993 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3994
3995 write!(self.out, "{level}")?;
3996 let name = self.namer.call("");
3997 self.start_baking_expression(result, &context.expression, &name)?;
3998 self.put_load(pointer, &context.expression, true)?;
3999 self.named_expressions.insert(result, name);
4000
4001 writeln!(self.out, ";")?;
4002 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4003 }
4004 crate::Statement::RayQuery { query, ref fun } => {
4005 if context.expression.lang_version < (2, 4) {
4006 return Err(Error::UnsupportedRayTracing);
4007 }
4008
4009 match *fun {
4010 crate::RayQueryFunction::Initialize {
4011 acceleration_structure,
4012 descriptor,
4013 } => {
4014 write!(self.out, "{level}")?;
4016 self.put_expression(query, &context.expression, true)?;
4017 writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
4018 {
4019 let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
4020 let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
4021 write!(self.out, "{level}")?;
4022 self.put_expression(query, &context.expression, true)?;
4023 write!(
4024 self.out,
4025 ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
4026 )?;
4027 self.put_expression(descriptor, &context.expression, true)?;
4028 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
4029 self.put_expression(descriptor, &context.expression, true)?;
4030 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
4031 writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
4032 }
4033 {
4034 let f_opaque = back::RayFlag::OPAQUE.bits();
4035 let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
4036 write!(self.out, "{level}")?;
4037 self.put_expression(query, &context.expression, true)?;
4038 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
4039 self.put_expression(descriptor, &context.expression, true)?;
4040 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
4041 self.put_expression(descriptor, &context.expression, true)?;
4042 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
4043 writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
4044 }
4045 {
4046 let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
4047 write!(self.out, "{level}")?;
4048 self.put_expression(query, &context.expression, true)?;
4049 write!(
4050 self.out,
4051 ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
4052 )?;
4053 self.put_expression(descriptor, &context.expression, true)?;
4054 writeln!(self.out, ".flags & {flag}) != 0);")?;
4055 }
4056
4057 write!(self.out, "{level}")?;
4058 self.put_expression(query, &context.expression, true)?;
4059 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
4060 self.put_expression(query, &context.expression, true)?;
4061 write!(
4062 self.out,
4063 ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
4064 )?;
4065 self.put_expression(descriptor, &context.expression, true)?;
4066 write!(self.out, ".origin, ")?;
4067 self.put_expression(descriptor, &context.expression, true)?;
4068 write!(self.out, ".dir, ")?;
4069 self.put_expression(descriptor, &context.expression, true)?;
4070 write!(self.out, ".tmin, ")?;
4071 self.put_expression(descriptor, &context.expression, true)?;
4072 write!(self.out, ".tmax), ")?;
4073 self.put_expression(acceleration_structure, &context.expression, true)?;
4074 write!(self.out, ", ")?;
4075 self.put_expression(descriptor, &context.expression, true)?;
4076 write!(self.out, ".cull_mask);")?;
4077
4078 write!(self.out, "{level}")?;
4079 self.put_expression(query, &context.expression, true)?;
4080 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
4081 }
4082 crate::RayQueryFunction::Proceed { result } => {
4083 write!(self.out, "{level}")?;
4084 let name = Baked(result).to_string();
4085 self.start_baking_expression(result, &context.expression, &name)?;
4086 self.named_expressions.insert(result, name);
4087 self.put_expression(query, &context.expression, true)?;
4088 writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
4089 if RAY_QUERY_MODERN_SUPPORT {
4090 write!(self.out, "{level}")?;
4091 self.put_expression(query, &context.expression, true)?;
4092 writeln!(self.out, ".?.next();")?;
4093 }
4094 }
4095 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
4096 if RAY_QUERY_MODERN_SUPPORT {
4097 write!(self.out, "{level}")?;
4098 self.put_expression(query, &context.expression, true)?;
4099 write!(self.out, ".?.commit_bounding_box_intersection(")?;
4100 self.put_expression(hit_t, &context.expression, true)?;
4101 writeln!(self.out, ");")?;
4102 } else {
4103 log::warn!("Ray Query GenerateIntersection is not yet supported");
4104 }
4105 }
4106 crate::RayQueryFunction::ConfirmIntersection => {
4107 if RAY_QUERY_MODERN_SUPPORT {
4108 write!(self.out, "{level}")?;
4109 self.put_expression(query, &context.expression, true)?;
4110 writeln!(self.out, ".?.commit_triangle_intersection();")?;
4111 } else {
4112 log::warn!("Ray Query ConfirmIntersection is not yet supported");
4113 }
4114 }
4115 crate::RayQueryFunction::Terminate => {
4116 if RAY_QUERY_MODERN_SUPPORT {
4117 write!(self.out, "{level}")?;
4118 self.put_expression(query, &context.expression, true)?;
4119 writeln!(self.out, ".?.abort();")?;
4120 }
4121 write!(self.out, "{level}")?;
4122 self.put_expression(query, &context.expression, true)?;
4123 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
4124 }
4125 }
4126 }
4127 crate::Statement::SubgroupBallot { result, predicate } => {
4128 write!(self.out, "{level}")?;
4129 let name = self.namer.call("");
4130 self.start_baking_expression(result, &context.expression, &name)?;
4131 self.named_expressions.insert(result, name);
4132 write!(
4133 self.out,
4134 "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4135 )?;
4136 if let Some(predicate) = predicate {
4137 self.put_expression(predicate, &context.expression, true)?;
4138 } else {
4139 write!(self.out, "true")?;
4140 }
4141 writeln!(self.out, "), 0, 0, 0);")?;
4142 }
4143 crate::Statement::SubgroupCollectiveOperation {
4144 op,
4145 collective_op,
4146 argument,
4147 result,
4148 } => {
4149 write!(self.out, "{level}")?;
4150 let name = self.namer.call("");
4151 self.start_baking_expression(result, &context.expression, &name)?;
4152 self.named_expressions.insert(result, name);
4153 match (collective_op, op) {
4154 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4155 write!(self.out, "{NAMESPACE}::simd_all(")?
4156 }
4157 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4158 write!(self.out, "{NAMESPACE}::simd_any(")?
4159 }
4160 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4161 write!(self.out, "{NAMESPACE}::simd_sum(")?
4162 }
4163 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4164 write!(self.out, "{NAMESPACE}::simd_product(")?
4165 }
4166 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4167 write!(self.out, "{NAMESPACE}::simd_max(")?
4168 }
4169 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4170 write!(self.out, "{NAMESPACE}::simd_min(")?
4171 }
4172 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4173 write!(self.out, "{NAMESPACE}::simd_and(")?
4174 }
4175 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4176 write!(self.out, "{NAMESPACE}::simd_or(")?
4177 }
4178 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4179 write!(self.out, "{NAMESPACE}::simd_xor(")?
4180 }
4181 (
4182 crate::CollectiveOperation::ExclusiveScan,
4183 crate::SubgroupOperation::Add,
4184 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4185 (
4186 crate::CollectiveOperation::ExclusiveScan,
4187 crate::SubgroupOperation::Mul,
4188 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4189 (
4190 crate::CollectiveOperation::InclusiveScan,
4191 crate::SubgroupOperation::Add,
4192 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4193 (
4194 crate::CollectiveOperation::InclusiveScan,
4195 crate::SubgroupOperation::Mul,
4196 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4197 _ => unimplemented!(),
4198 }
4199 self.put_expression(argument, &context.expression, true)?;
4200 writeln!(self.out, ");")?;
4201 }
4202 crate::Statement::SubgroupGather {
4203 mode,
4204 argument,
4205 result,
4206 } => {
4207 write!(self.out, "{level}")?;
4208 let name = self.namer.call("");
4209 self.start_baking_expression(result, &context.expression, &name)?;
4210 self.named_expressions.insert(result, name);
4211 match mode {
4212 crate::GatherMode::BroadcastFirst => {
4213 write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4214 }
4215 crate::GatherMode::Broadcast(_) => {
4216 write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4217 }
4218 crate::GatherMode::Shuffle(_) => {
4219 write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4220 }
4221 crate::GatherMode::ShuffleDown(_) => {
4222 write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4223 }
4224 crate::GatherMode::ShuffleUp(_) => {
4225 write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4226 }
4227 crate::GatherMode::ShuffleXor(_) => {
4228 write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4229 }
4230 crate::GatherMode::QuadBroadcast(_) => {
4231 write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4232 }
4233 crate::GatherMode::QuadSwap(_) => {
4234 write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4235 }
4236 }
4237 self.put_expression(argument, &context.expression, true)?;
4238 match mode {
4239 crate::GatherMode::BroadcastFirst => {}
4240 crate::GatherMode::Broadcast(index)
4241 | crate::GatherMode::Shuffle(index)
4242 | crate::GatherMode::ShuffleDown(index)
4243 | crate::GatherMode::ShuffleUp(index)
4244 | crate::GatherMode::ShuffleXor(index)
4245 | crate::GatherMode::QuadBroadcast(index) => {
4246 write!(self.out, ", ")?;
4247 self.put_expression(index, &context.expression, true)?;
4248 }
4249 crate::GatherMode::QuadSwap(direction) => {
4250 write!(self.out, ", ")?;
4251 match direction {
4252 crate::Direction::X => {
4253 write!(self.out, "1u")?;
4254 }
4255 crate::Direction::Y => {
4256 write!(self.out, "2u")?;
4257 }
4258 crate::Direction::Diagonal => {
4259 write!(self.out, "3u")?;
4260 }
4261 }
4262 }
4263 }
4264 writeln!(self.out, ");")?;
4265 }
4266 crate::Statement::CooperativeStore { target, ref data } => {
4267 write!(self.out, "{level}simdgroup_store(")?;
4268 self.put_expression(target, &context.expression, true)?;
4269 write!(self.out, ", &")?;
4270 self.put_access_chain(
4271 data.pointer,
4272 context.expression.policies.index,
4273 &context.expression,
4274 )?;
4275 write!(self.out, ", ")?;
4276 self.put_expression(data.stride, &context.expression, true)?;
4277 if data.row_major {
4278 let matrix_origin = "0";
4279 let transpose = true;
4280 write!(self.out, ", {matrix_origin}, {transpose}")?;
4281 }
4282 writeln!(self.out, ");")?;
4283 }
4284 }
4285 }
4286
4287 for statement in statements {
4290 if let crate::Statement::Emit(ref range) = *statement {
4291 for handle in range.clone() {
4292 self.named_expressions.shift_remove(&handle);
4293 }
4294 }
4295 }
4296 Ok(())
4297 }
4298
4299 fn put_store(
4300 &mut self,
4301 pointer: Handle<crate::Expression>,
4302 value: Handle<crate::Expression>,
4303 level: back::Level,
4304 context: &StatementContext,
4305 ) -> BackendResult {
4306 let policy = context.expression.choose_bounds_check_policy(pointer);
4307 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4308 && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4309 {
4310 writeln!(self.out, ") {{")?;
4311 self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4312 writeln!(self.out, "{level}}}")?;
4313 } else {
4314 self.put_unchecked_store(pointer, value, policy, level, context)?;
4315 }
4316
4317 Ok(())
4318 }
4319
4320 fn put_unchecked_store(
4321 &mut self,
4322 pointer: Handle<crate::Expression>,
4323 value: Handle<crate::Expression>,
4324 policy: index::BoundsCheckPolicy,
4325 level: back::Level,
4326 context: &StatementContext,
4327 ) -> BackendResult {
4328 let is_atomic_pointer = context
4329 .expression
4330 .resolve_type(pointer)
4331 .is_atomic_pointer(&context.expression.module.types);
4332
4333 if is_atomic_pointer {
4334 write!(
4335 self.out,
4336 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4337 )?;
4338 self.put_access_chain(pointer, policy, &context.expression)?;
4339 write!(self.out, ", ")?;
4340 self.put_expression(value, &context.expression, true)?;
4341 writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4342 } else {
4343 write!(self.out, "{level}")?;
4344 self.put_access_chain(pointer, policy, &context.expression)?;
4345 write!(self.out, " = ")?;
4346 self.put_expression(value, &context.expression, true)?;
4347 writeln!(self.out, ";")?;
4348 }
4349
4350 Ok(())
4351 }
4352
4353 pub fn write(
4354 &mut self,
4355 module: &crate::Module,
4356 info: &valid::ModuleInfo,
4357 options: &Options,
4358 pipeline_options: &PipelineOptions,
4359 ) -> Result<TranslationInfo, Error> {
4360 self.names.clear();
4361 self.namer.reset(
4362 module,
4363 &super::keywords::RESERVED_SET,
4364 proc::CaseInsensitiveKeywordSet::empty(),
4365 &[CLAMPED_LOD_LOAD_PREFIX],
4366 &mut self.names,
4367 );
4368 self.wrapped_functions.clear();
4369 self.struct_member_pads.clear();
4370
4371 writeln!(
4372 self.out,
4373 "// language: metal{}.{}",
4374 options.lang_version.0, options.lang_version.1
4375 )?;
4376 writeln!(self.out, "#include <metal_stdlib>")?;
4377 writeln!(self.out, "#include <simd/simd.h>")?;
4378 writeln!(self.out)?;
4379 writeln!(self.out, "using {NAMESPACE}::uint;")?;
4381
4382 let mut uses_ray_query = false;
4383 for (_, ty) in module.types.iter() {
4384 match ty.inner {
4385 crate::TypeInner::AccelerationStructure { .. } => {
4386 if options.lang_version < (2, 4) {
4387 return Err(Error::UnsupportedRayTracing);
4388 }
4389 }
4390 crate::TypeInner::RayQuery { .. } => {
4391 if options.lang_version < (2, 4) {
4392 return Err(Error::UnsupportedRayTracing);
4393 }
4394 uses_ray_query = true;
4395 }
4396 _ => (),
4397 }
4398 }
4399
4400 if module.special_types.ray_desc.is_some()
4401 || module.special_types.ray_intersection.is_some()
4402 {
4403 if options.lang_version < (2, 4) {
4404 return Err(Error::UnsupportedRayTracing);
4405 }
4406 }
4407
4408 if uses_ray_query {
4409 self.put_ray_query_type()?;
4410 }
4411
4412 if options
4413 .bounds_check_policies
4414 .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4415 {
4416 self.put_default_constructible()?;
4417 }
4418 writeln!(self.out)?;
4419
4420 {
4421 let globals: Vec<Handle<crate::GlobalVariable>> = module
4424 .global_variables
4425 .iter()
4426 .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4427 .map(|(handle, _)| handle)
4428 .collect();
4429
4430 let mut buffer_indices = vec![];
4431 for vbm in &pipeline_options.vertex_buffer_mappings {
4432 buffer_indices.push(vbm.id);
4433 }
4434
4435 if !globals.is_empty() || !buffer_indices.is_empty() {
4436 writeln!(self.out, "struct _mslBufferSizes {{")?;
4437
4438 for global in globals {
4439 writeln!(
4440 self.out,
4441 "{}uint {};",
4442 back::INDENT,
4443 ArraySizeMember(global)
4444 )?;
4445 }
4446
4447 for idx in buffer_indices {
4448 writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4449 }
4450
4451 writeln!(self.out, "}};")?;
4452 writeln!(self.out)?;
4453 }
4454 };
4455
4456 self.write_type_defs(module)?;
4457 self.write_global_constants(module, info)?;
4458 self.write_functions(module, info, options, pipeline_options)
4459 }
4460
4461 fn put_default_constructible(&mut self) -> BackendResult {
4474 let tab = back::INDENT;
4475 writeln!(self.out, "struct DefaultConstructible {{")?;
4476 writeln!(self.out, "{tab}template<typename T>")?;
4477 writeln!(self.out, "{tab}operator T() && {{")?;
4478 writeln!(self.out, "{tab}{tab}return T {{}};")?;
4479 writeln!(self.out, "{tab}}}")?;
4480 writeln!(self.out, "}};")?;
4481 Ok(())
4482 }
4483
4484 fn put_ray_query_type(&mut self) -> BackendResult {
4485 let tab = back::INDENT;
4486 writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
4487 let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
4488 writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
4489 writeln!(
4490 self.out,
4491 "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
4492 )?;
4493 writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
4494 writeln!(self.out, "}};")?;
4495 writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
4496 let v_triangle = back::RayIntersectionType::Triangle as u32;
4497 let v_bbox = back::RayIntersectionType::BoundingBox as u32;
4498 writeln!(
4499 self.out,
4500 "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
4501 )?;
4502 writeln!(
4503 self.out,
4504 "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
4505 )?;
4506 writeln!(self.out, "}}")?;
4507 Ok(())
4508 }
4509
4510 fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4511 let mut generated_argument_buffer_wrapper = false;
4512 let mut generated_external_texture_wrapper = false;
4513 for (handle, ty) in module.types.iter() {
4514 match ty.inner {
4515 crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4516 writeln!(self.out, "template <typename T>")?;
4517 writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4518 writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4519 writeln!(self.out, "}};")?;
4520 generated_argument_buffer_wrapper = true;
4521 }
4522 crate::TypeInner::Image {
4523 class: crate::ImageClass::External,
4524 ..
4525 } if !generated_external_texture_wrapper => {
4526 let params_ty_name = &self.names
4527 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4528 writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4529 writeln!(
4530 self.out,
4531 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4532 back::INDENT
4533 )?;
4534 writeln!(
4535 self.out,
4536 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4537 back::INDENT
4538 )?;
4539 writeln!(
4540 self.out,
4541 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4542 back::INDENT
4543 )?;
4544 writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4545 writeln!(self.out, "}};")?;
4546 generated_external_texture_wrapper = true;
4547 }
4548 _ => {}
4549 }
4550
4551 if !ty.needs_alias() {
4552 continue;
4553 }
4554 let name = &self.names[&NameKey::Type(handle)];
4555 match ty.inner {
4556 crate::TypeInner::Array {
4570 base,
4571 size,
4572 stride: _,
4573 } => {
4574 let base_name = TypeContext {
4575 handle: base,
4576 gctx: module.to_ctx(),
4577 names: &self.names,
4578 access: crate::StorageAccess::empty(),
4579 first_time: false,
4580 };
4581
4582 match size.resolve(module.to_ctx())? {
4583 proc::IndexableLength::Known(size) => {
4584 writeln!(self.out, "struct {name} {{")?;
4585 writeln!(
4586 self.out,
4587 "{}{} {}[{}];",
4588 back::INDENT,
4589 base_name,
4590 WRAPPED_ARRAY_FIELD,
4591 size
4592 )?;
4593 writeln!(self.out, "}};")?;
4594 }
4595 proc::IndexableLength::Dynamic => {
4596 writeln!(self.out, "typedef {base_name} {name}[1];")?;
4597 }
4598 }
4599 }
4600 crate::TypeInner::Struct {
4601 ref members, span, ..
4602 } => {
4603 writeln!(self.out, "struct {name} {{")?;
4604 let mut last_offset = 0;
4605 for (index, member) in members.iter().enumerate() {
4606 if member.offset > last_offset {
4607 self.struct_member_pads.insert((handle, index as u32));
4608 let pad = member.offset - last_offset;
4609 writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4610 }
4611 let ty_inner = &module.types[member.ty].inner;
4612 last_offset = member.offset + ty_inner.size(module.to_ctx());
4613
4614 let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4615
4616 match should_pack_struct_member(members, span, index, module) {
4618 Some(scalar) => {
4619 writeln!(
4620 self.out,
4621 "{}{}::packed_{}3 {};",
4622 back::INDENT,
4623 NAMESPACE,
4624 scalar.to_msl_name(),
4625 member_name
4626 )?;
4627 }
4628 None => {
4629 let base_name = TypeContext {
4630 handle: member.ty,
4631 gctx: module.to_ctx(),
4632 names: &self.names,
4633 access: crate::StorageAccess::empty(),
4634 first_time: false,
4635 };
4636 writeln!(
4637 self.out,
4638 "{}{} {};",
4639 back::INDENT,
4640 base_name,
4641 member_name
4642 )?;
4643
4644 if let crate::TypeInner::Vector {
4646 size: crate::VectorSize::Tri,
4647 scalar,
4648 } = *ty_inner
4649 {
4650 last_offset += scalar.width as u32;
4651 }
4652 }
4653 }
4654 }
4655 if last_offset < span {
4656 let pad = span - last_offset;
4657 writeln!(
4658 self.out,
4659 "{}char _pad{}[{}];",
4660 back::INDENT,
4661 members.len(),
4662 pad
4663 )?;
4664 }
4665 writeln!(self.out, "}};")?;
4666 }
4667 _ => {
4668 let ty_name = TypeContext {
4669 handle,
4670 gctx: module.to_ctx(),
4671 names: &self.names,
4672 access: crate::StorageAccess::empty(),
4673 first_time: true,
4674 };
4675 writeln!(self.out, "typedef {ty_name} {name};")?;
4676 }
4677 }
4678 }
4679
4680 for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4682 match type_key {
4683 &crate::PredeclaredType::ModfResult { size, scalar }
4684 | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4685 let arg_type_name_owner;
4686 let arg_type_name = if let Some(size) = size {
4687 arg_type_name_owner = format!(
4688 "{NAMESPACE}::{}{}",
4689 if scalar.width == 8 { "double" } else { "float" },
4690 size as u8
4691 );
4692 &arg_type_name_owner
4693 } else if scalar.width == 8 {
4694 "double"
4695 } else {
4696 "float"
4697 };
4698
4699 let other_type_name_owner;
4700 let (defined_func_name, called_func_name, other_type_name) =
4701 if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4702 (MODF_FUNCTION, "modf", arg_type_name)
4703 } else {
4704 let other_type_name = if let Some(size) = size {
4705 other_type_name_owner = format!("int{}", size as u8);
4706 &other_type_name_owner
4707 } else {
4708 "int"
4709 };
4710 (FREXP_FUNCTION, "frexp", other_type_name)
4711 };
4712
4713 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4714
4715 writeln!(self.out)?;
4716 writeln!(
4717 self.out,
4718 "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4719 {other_type_name} other;
4720 {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4721 return {struct_name}{{ fract, other }};
4722}}"
4723 )?;
4724 }
4725 &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4726 let arg_type_name = scalar.to_msl_name();
4727 let called_func_name = "atomic_compare_exchange_weak_explicit";
4728 let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4729 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4730
4731 writeln!(self.out)?;
4732
4733 for address_space_name in ["device", "threadgroup"] {
4734 writeln!(
4735 self.out,
4736 "\
4737template <typename A>
4738{struct_name} {defined_func_name}(
4739 {address_space_name} A *atomic_ptr,
4740 {arg_type_name} cmp,
4741 {arg_type_name} v
4742) {{
4743 bool swapped = {NAMESPACE}::{called_func_name}(
4744 atomic_ptr, &cmp, v,
4745 metal::memory_order_relaxed, metal::memory_order_relaxed
4746 );
4747 return {struct_name}{{cmp, swapped}};
4748}}"
4749 )?;
4750 }
4751 }
4752 }
4753 }
4754
4755 Ok(())
4756 }
4757
4758 fn write_global_constants(
4760 &mut self,
4761 module: &crate::Module,
4762 mod_info: &valid::ModuleInfo,
4763 ) -> BackendResult {
4764 let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4765
4766 for (handle, constant) in constants {
4767 let ty_name = TypeContext {
4768 handle: constant.ty,
4769 gctx: module.to_ctx(),
4770 names: &self.names,
4771 access: crate::StorageAccess::empty(),
4772 first_time: false,
4773 };
4774 let name = &self.names[&NameKey::Constant(handle)];
4775 write!(self.out, "constant {ty_name} {name} = ")?;
4776 self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4777 writeln!(self.out, ";")?;
4778 }
4779
4780 Ok(())
4781 }
4782
4783 fn put_inline_sampler_properties(
4784 &mut self,
4785 level: back::Level,
4786 sampler: &sm::InlineSampler,
4787 ) -> BackendResult {
4788 for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4789 writeln!(
4790 self.out,
4791 "{}{}::{}_address::{},",
4792 level,
4793 NAMESPACE,
4794 letter,
4795 address.as_str(),
4796 )?;
4797 }
4798 writeln!(
4799 self.out,
4800 "{}{}::mag_filter::{},",
4801 level,
4802 NAMESPACE,
4803 sampler.mag_filter.as_str(),
4804 )?;
4805 writeln!(
4806 self.out,
4807 "{}{}::min_filter::{},",
4808 level,
4809 NAMESPACE,
4810 sampler.min_filter.as_str(),
4811 )?;
4812 if let Some(filter) = sampler.mip_filter {
4813 writeln!(
4814 self.out,
4815 "{}{}::mip_filter::{},",
4816 level,
4817 NAMESPACE,
4818 filter.as_str(),
4819 )?;
4820 }
4821 if sampler.border_color != sm::BorderColor::TransparentBlack {
4823 writeln!(
4824 self.out,
4825 "{}{}::border_color::{},",
4826 level,
4827 NAMESPACE,
4828 sampler.border_color.as_str(),
4829 )?;
4830 }
4831 if false {
4835 if let Some(ref lod) = sampler.lod_clamp {
4836 writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4837 }
4838 if let Some(aniso) = sampler.max_anisotropy {
4839 writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4840 }
4841 }
4842 if sampler.compare_func != sm::CompareFunc::Never {
4843 writeln!(
4844 self.out,
4845 "{}{}::compare_func::{},",
4846 level,
4847 NAMESPACE,
4848 sampler.compare_func.as_str(),
4849 )?;
4850 }
4851 writeln!(
4852 self.out,
4853 "{}{}::coord::{}",
4854 level,
4855 NAMESPACE,
4856 sampler.coord.as_str()
4857 )?;
4858 Ok(())
4859 }
4860
4861 fn write_unpacking_function(
4862 &mut self,
4863 format: back::msl::VertexFormat,
4864 ) -> Result<(String, u32, u32), Error> {
4865 use back::msl::VertexFormat::*;
4866 match format {
4867 Uint8 => {
4868 let name = self.namer.call("unpackUint8");
4869 writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4870 writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4871 writeln!(self.out, "}}")?;
4872 Ok((name, 1, 1))
4873 }
4874 Uint8x2 => {
4875 let name = self.namer.call("unpackUint8x2");
4876 writeln!(
4877 self.out,
4878 "metal::uint2 {name}(metal::uchar b0, \
4879 metal::uchar b1) {{"
4880 )?;
4881 writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4882 writeln!(self.out, "}}")?;
4883 Ok((name, 2, 2))
4884 }
4885 Uint8x4 => {
4886 let name = self.namer.call("unpackUint8x4");
4887 writeln!(
4888 self.out,
4889 "metal::uint4 {name}(metal::uchar b0, \
4890 metal::uchar b1, \
4891 metal::uchar b2, \
4892 metal::uchar b3) {{"
4893 )?;
4894 writeln!(
4895 self.out,
4896 "{}return metal::uint4(b0, b1, b2, b3);",
4897 back::INDENT
4898 )?;
4899 writeln!(self.out, "}}")?;
4900 Ok((name, 4, 4))
4901 }
4902 Sint8 => {
4903 let name = self.namer.call("unpackSint8");
4904 writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4905 writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4906 writeln!(self.out, "}}")?;
4907 Ok((name, 1, 1))
4908 }
4909 Sint8x2 => {
4910 let name = self.namer.call("unpackSint8x2");
4911 writeln!(
4912 self.out,
4913 "metal::int2 {name}(metal::uchar b0, \
4914 metal::uchar b1) {{"
4915 )?;
4916 writeln!(
4917 self.out,
4918 "{}return metal::int2(as_type<char>(b0), \
4919 as_type<char>(b1));",
4920 back::INDENT
4921 )?;
4922 writeln!(self.out, "}}")?;
4923 Ok((name, 2, 2))
4924 }
4925 Sint8x4 => {
4926 let name = self.namer.call("unpackSint8x4");
4927 writeln!(
4928 self.out,
4929 "metal::int4 {name}(metal::uchar b0, \
4930 metal::uchar b1, \
4931 metal::uchar b2, \
4932 metal::uchar b3) {{"
4933 )?;
4934 writeln!(
4935 self.out,
4936 "{}return metal::int4(as_type<char>(b0), \
4937 as_type<char>(b1), \
4938 as_type<char>(b2), \
4939 as_type<char>(b3));",
4940 back::INDENT
4941 )?;
4942 writeln!(self.out, "}}")?;
4943 Ok((name, 4, 4))
4944 }
4945 Unorm8 => {
4946 let name = self.namer.call("unpackUnorm8");
4947 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4948 writeln!(
4949 self.out,
4950 "{}return float(float(b0) / 255.0f);",
4951 back::INDENT
4952 )?;
4953 writeln!(self.out, "}}")?;
4954 Ok((name, 1, 1))
4955 }
4956 Unorm8x2 => {
4957 let name = self.namer.call("unpackUnorm8x2");
4958 writeln!(
4959 self.out,
4960 "metal::float2 {name}(metal::uchar b0, \
4961 metal::uchar b1) {{"
4962 )?;
4963 writeln!(
4964 self.out,
4965 "{}return metal::float2(float(b0) / 255.0f, \
4966 float(b1) / 255.0f);",
4967 back::INDENT
4968 )?;
4969 writeln!(self.out, "}}")?;
4970 Ok((name, 2, 2))
4971 }
4972 Unorm8x4 => {
4973 let name = self.namer.call("unpackUnorm8x4");
4974 writeln!(
4975 self.out,
4976 "metal::float4 {name}(metal::uchar b0, \
4977 metal::uchar b1, \
4978 metal::uchar b2, \
4979 metal::uchar b3) {{"
4980 )?;
4981 writeln!(
4982 self.out,
4983 "{}return metal::float4(float(b0) / 255.0f, \
4984 float(b1) / 255.0f, \
4985 float(b2) / 255.0f, \
4986 float(b3) / 255.0f);",
4987 back::INDENT
4988 )?;
4989 writeln!(self.out, "}}")?;
4990 Ok((name, 4, 4))
4991 }
4992 Snorm8 => {
4993 let name = self.namer.call("unpackSnorm8");
4994 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4995 writeln!(
4996 self.out,
4997 "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4998 back::INDENT
4999 )?;
5000 writeln!(self.out, "}}")?;
5001 Ok((name, 1, 1))
5002 }
5003 Snorm8x2 => {
5004 let name = self.namer.call("unpackSnorm8x2");
5005 writeln!(
5006 self.out,
5007 "metal::float2 {name}(metal::uchar b0, \
5008 metal::uchar b1) {{"
5009 )?;
5010 writeln!(
5011 self.out,
5012 "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5013 metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
5014 back::INDENT
5015 )?;
5016 writeln!(self.out, "}}")?;
5017 Ok((name, 2, 2))
5018 }
5019 Snorm8x4 => {
5020 let name = self.namer.call("unpackSnorm8x4");
5021 writeln!(
5022 self.out,
5023 "metal::float4 {name}(metal::uchar b0, \
5024 metal::uchar b1, \
5025 metal::uchar b2, \
5026 metal::uchar b3) {{"
5027 )?;
5028 writeln!(
5029 self.out,
5030 "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5031 metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
5032 metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
5033 metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
5034 back::INDENT
5035 )?;
5036 writeln!(self.out, "}}")?;
5037 Ok((name, 4, 4))
5038 }
5039 Uint16 => {
5040 let name = self.namer.call("unpackUint16");
5041 writeln!(
5042 self.out,
5043 "metal::uint {name}(metal::uint b0, \
5044 metal::uint b1) {{"
5045 )?;
5046 writeln!(
5047 self.out,
5048 "{}return metal::uint(b1 << 8 | b0);",
5049 back::INDENT
5050 )?;
5051 writeln!(self.out, "}}")?;
5052 Ok((name, 2, 1))
5053 }
5054 Uint16x2 => {
5055 let name = self.namer.call("unpackUint16x2");
5056 writeln!(
5057 self.out,
5058 "metal::uint2 {name}(metal::uint b0, \
5059 metal::uint b1, \
5060 metal::uint b2, \
5061 metal::uint b3) {{"
5062 )?;
5063 writeln!(
5064 self.out,
5065 "{}return metal::uint2(b1 << 8 | b0, \
5066 b3 << 8 | b2);",
5067 back::INDENT
5068 )?;
5069 writeln!(self.out, "}}")?;
5070 Ok((name, 4, 2))
5071 }
5072 Uint16x4 => {
5073 let name = self.namer.call("unpackUint16x4");
5074 writeln!(
5075 self.out,
5076 "metal::uint4 {name}(metal::uint b0, \
5077 metal::uint b1, \
5078 metal::uint b2, \
5079 metal::uint b3, \
5080 metal::uint b4, \
5081 metal::uint b5, \
5082 metal::uint b6, \
5083 metal::uint b7) {{"
5084 )?;
5085 writeln!(
5086 self.out,
5087 "{}return metal::uint4(b1 << 8 | b0, \
5088 b3 << 8 | b2, \
5089 b5 << 8 | b4, \
5090 b7 << 8 | b6);",
5091 back::INDENT
5092 )?;
5093 writeln!(self.out, "}}")?;
5094 Ok((name, 8, 4))
5095 }
5096 Sint16 => {
5097 let name = self.namer.call("unpackSint16");
5098 writeln!(
5099 self.out,
5100 "int {name}(metal::ushort b0, \
5101 metal::ushort b1) {{"
5102 )?;
5103 writeln!(
5104 self.out,
5105 "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5106 back::INDENT
5107 )?;
5108 writeln!(self.out, "}}")?;
5109 Ok((name, 2, 1))
5110 }
5111 Sint16x2 => {
5112 let name = self.namer.call("unpackSint16x2");
5113 writeln!(
5114 self.out,
5115 "metal::int2 {name}(metal::ushort b0, \
5116 metal::ushort b1, \
5117 metal::ushort b2, \
5118 metal::ushort b3) {{"
5119 )?;
5120 writeln!(
5121 self.out,
5122 "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5123 as_type<short>(metal::ushort(b3 << 8 | b2)));",
5124 back::INDENT
5125 )?;
5126 writeln!(self.out, "}}")?;
5127 Ok((name, 4, 2))
5128 }
5129 Sint16x4 => {
5130 let name = self.namer.call("unpackSint16x4");
5131 writeln!(
5132 self.out,
5133 "metal::int4 {name}(metal::ushort b0, \
5134 metal::ushort b1, \
5135 metal::ushort b2, \
5136 metal::ushort b3, \
5137 metal::ushort b4, \
5138 metal::ushort b5, \
5139 metal::ushort b6, \
5140 metal::ushort b7) {{"
5141 )?;
5142 writeln!(
5143 self.out,
5144 "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5145 as_type<short>(metal::ushort(b3 << 8 | b2)), \
5146 as_type<short>(metal::ushort(b5 << 8 | b4)), \
5147 as_type<short>(metal::ushort(b7 << 8 | b6)));",
5148 back::INDENT
5149 )?;
5150 writeln!(self.out, "}}")?;
5151 Ok((name, 8, 4))
5152 }
5153 Unorm16 => {
5154 let name = self.namer.call("unpackUnorm16");
5155 writeln!(
5156 self.out,
5157 "float {name}(metal::ushort b0, \
5158 metal::ushort b1) {{"
5159 )?;
5160 writeln!(
5161 self.out,
5162 "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5163 back::INDENT
5164 )?;
5165 writeln!(self.out, "}}")?;
5166 Ok((name, 2, 1))
5167 }
5168 Unorm16x2 => {
5169 let name = self.namer.call("unpackUnorm16x2");
5170 writeln!(
5171 self.out,
5172 "metal::float2 {name}(metal::ushort b0, \
5173 metal::ushort b1, \
5174 metal::ushort b2, \
5175 metal::ushort b3) {{"
5176 )?;
5177 writeln!(
5178 self.out,
5179 "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5180 float(b3 << 8 | b2) / 65535.0f);",
5181 back::INDENT
5182 )?;
5183 writeln!(self.out, "}}")?;
5184 Ok((name, 4, 2))
5185 }
5186 Unorm16x4 => {
5187 let name = self.namer.call("unpackUnorm16x4");
5188 writeln!(
5189 self.out,
5190 "metal::float4 {name}(metal::ushort b0, \
5191 metal::ushort b1, \
5192 metal::ushort b2, \
5193 metal::ushort b3, \
5194 metal::ushort b4, \
5195 metal::ushort b5, \
5196 metal::ushort b6, \
5197 metal::ushort b7) {{"
5198 )?;
5199 writeln!(
5200 self.out,
5201 "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5202 float(b3 << 8 | b2) / 65535.0f, \
5203 float(b5 << 8 | b4) / 65535.0f, \
5204 float(b7 << 8 | b6) / 65535.0f);",
5205 back::INDENT
5206 )?;
5207 writeln!(self.out, "}}")?;
5208 Ok((name, 8, 4))
5209 }
5210 Snorm16 => {
5211 let name = self.namer.call("unpackSnorm16");
5212 writeln!(
5213 self.out,
5214 "float {name}(metal::ushort b0, \
5215 metal::ushort b1) {{"
5216 )?;
5217 writeln!(
5218 self.out,
5219 "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5220 back::INDENT
5221 )?;
5222 writeln!(self.out, "}}")?;
5223 Ok((name, 2, 1))
5224 }
5225 Snorm16x2 => {
5226 let name = self.namer.call("unpackSnorm16x2");
5227 writeln!(
5228 self.out,
5229 "metal::float2 {name}(uint b0, \
5230 uint b1, \
5231 uint b2, \
5232 uint b3) {{"
5233 )?;
5234 writeln!(
5235 self.out,
5236 "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5237 back::INDENT
5238 )?;
5239 writeln!(self.out, "}}")?;
5240 Ok((name, 4, 2))
5241 }
5242 Snorm16x4 => {
5243 let name = self.namer.call("unpackSnorm16x4");
5244 writeln!(
5245 self.out,
5246 "metal::float4 {name}(uint b0, \
5247 uint b1, \
5248 uint b2, \
5249 uint b3, \
5250 uint b4, \
5251 uint b5, \
5252 uint b6, \
5253 uint b7) {{"
5254 )?;
5255 writeln!(
5256 self.out,
5257 "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5258 metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5259 back::INDENT
5260 )?;
5261 writeln!(self.out, "}}")?;
5262 Ok((name, 8, 4))
5263 }
5264 Float16 => {
5265 let name = self.namer.call("unpackFloat16");
5266 writeln!(
5267 self.out,
5268 "float {name}(metal::ushort b0, \
5269 metal::ushort b1) {{"
5270 )?;
5271 writeln!(
5272 self.out,
5273 "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5274 back::INDENT
5275 )?;
5276 writeln!(self.out, "}}")?;
5277 Ok((name, 2, 1))
5278 }
5279 Float16x2 => {
5280 let name = self.namer.call("unpackFloat16x2");
5281 writeln!(
5282 self.out,
5283 "metal::float2 {name}(metal::ushort b0, \
5284 metal::ushort b1, \
5285 metal::ushort b2, \
5286 metal::ushort b3) {{"
5287 )?;
5288 writeln!(
5289 self.out,
5290 "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5291 as_type<half>(metal::ushort(b3 << 8 | b2)));",
5292 back::INDENT
5293 )?;
5294 writeln!(self.out, "}}")?;
5295 Ok((name, 4, 2))
5296 }
5297 Float16x4 => {
5298 let name = self.namer.call("unpackFloat16x4");
5299 writeln!(
5300 self.out,
5301 "metal::float4 {name}(metal::ushort b0, \
5302 metal::ushort b1, \
5303 metal::ushort b2, \
5304 metal::ushort b3, \
5305 metal::ushort b4, \
5306 metal::ushort b5, \
5307 metal::ushort b6, \
5308 metal::ushort b7) {{"
5309 )?;
5310 writeln!(
5311 self.out,
5312 "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5313 as_type<half>(metal::ushort(b3 << 8 | b2)), \
5314 as_type<half>(metal::ushort(b5 << 8 | b4)), \
5315 as_type<half>(metal::ushort(b7 << 8 | b6)));",
5316 back::INDENT
5317 )?;
5318 writeln!(self.out, "}}")?;
5319 Ok((name, 8, 4))
5320 }
5321 Float32 => {
5322 let name = self.namer.call("unpackFloat32");
5323 writeln!(
5324 self.out,
5325 "float {name}(uint b0, \
5326 uint b1, \
5327 uint b2, \
5328 uint b3) {{"
5329 )?;
5330 writeln!(
5331 self.out,
5332 "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5333 back::INDENT
5334 )?;
5335 writeln!(self.out, "}}")?;
5336 Ok((name, 4, 1))
5337 }
5338 Float32x2 => {
5339 let name = self.namer.call("unpackFloat32x2");
5340 writeln!(
5341 self.out,
5342 "metal::float2 {name}(uint b0, \
5343 uint b1, \
5344 uint b2, \
5345 uint b3, \
5346 uint b4, \
5347 uint b5, \
5348 uint b6, \
5349 uint b7) {{"
5350 )?;
5351 writeln!(
5352 self.out,
5353 "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5354 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5355 back::INDENT
5356 )?;
5357 writeln!(self.out, "}}")?;
5358 Ok((name, 8, 2))
5359 }
5360 Float32x3 => {
5361 let name = self.namer.call("unpackFloat32x3");
5362 writeln!(
5363 self.out,
5364 "metal::float3 {name}(uint b0, \
5365 uint b1, \
5366 uint b2, \
5367 uint b3, \
5368 uint b4, \
5369 uint b5, \
5370 uint b6, \
5371 uint b7, \
5372 uint b8, \
5373 uint b9, \
5374 uint b10, \
5375 uint b11) {{"
5376 )?;
5377 writeln!(
5378 self.out,
5379 "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5380 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5381 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5382 back::INDENT
5383 )?;
5384 writeln!(self.out, "}}")?;
5385 Ok((name, 12, 3))
5386 }
5387 Float32x4 => {
5388 let name = self.namer.call("unpackFloat32x4");
5389 writeln!(
5390 self.out,
5391 "metal::float4 {name}(uint b0, \
5392 uint b1, \
5393 uint b2, \
5394 uint b3, \
5395 uint b4, \
5396 uint b5, \
5397 uint b6, \
5398 uint b7, \
5399 uint b8, \
5400 uint b9, \
5401 uint b10, \
5402 uint b11, \
5403 uint b12, \
5404 uint b13, \
5405 uint b14, \
5406 uint b15) {{"
5407 )?;
5408 writeln!(
5409 self.out,
5410 "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5411 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5412 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5413 as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5414 back::INDENT
5415 )?;
5416 writeln!(self.out, "}}")?;
5417 Ok((name, 16, 4))
5418 }
5419 Uint32 => {
5420 let name = self.namer.call("unpackUint32");
5421 writeln!(
5422 self.out,
5423 "uint {name}(uint b0, \
5424 uint b1, \
5425 uint b2, \
5426 uint b3) {{"
5427 )?;
5428 writeln!(
5429 self.out,
5430 "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5431 back::INDENT
5432 )?;
5433 writeln!(self.out, "}}")?;
5434 Ok((name, 4, 1))
5435 }
5436 Uint32x2 => {
5437 let name = self.namer.call("unpackUint32x2");
5438 writeln!(
5439 self.out,
5440 "uint2 {name}(uint b0, \
5441 uint b1, \
5442 uint b2, \
5443 uint b3, \
5444 uint b4, \
5445 uint b5, \
5446 uint b6, \
5447 uint b7) {{"
5448 )?;
5449 writeln!(
5450 self.out,
5451 "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5452 (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5453 back::INDENT
5454 )?;
5455 writeln!(self.out, "}}")?;
5456 Ok((name, 8, 2))
5457 }
5458 Uint32x3 => {
5459 let name = self.namer.call("unpackUint32x3");
5460 writeln!(
5461 self.out,
5462 "uint3 {name}(uint b0, \
5463 uint b1, \
5464 uint b2, \
5465 uint b3, \
5466 uint b4, \
5467 uint b5, \
5468 uint b6, \
5469 uint b7, \
5470 uint b8, \
5471 uint b9, \
5472 uint b10, \
5473 uint b11) {{"
5474 )?;
5475 writeln!(
5476 self.out,
5477 "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5478 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5479 (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5480 back::INDENT
5481 )?;
5482 writeln!(self.out, "}}")?;
5483 Ok((name, 12, 3))
5484 }
5485 Uint32x4 => {
5486 let name = self.namer.call("unpackUint32x4");
5487 writeln!(
5488 self.out,
5489 "{NAMESPACE}::uint4 {name}(uint b0, \
5490 uint b1, \
5491 uint b2, \
5492 uint b3, \
5493 uint b4, \
5494 uint b5, \
5495 uint b6, \
5496 uint b7, \
5497 uint b8, \
5498 uint b9, \
5499 uint b10, \
5500 uint b11, \
5501 uint b12, \
5502 uint b13, \
5503 uint b14, \
5504 uint b15) {{"
5505 )?;
5506 writeln!(
5507 self.out,
5508 "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5509 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5510 (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5511 (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5512 back::INDENT
5513 )?;
5514 writeln!(self.out, "}}")?;
5515 Ok((name, 16, 4))
5516 }
5517 Sint32 => {
5518 let name = self.namer.call("unpackSint32");
5519 writeln!(
5520 self.out,
5521 "int {name}(uint b0, \
5522 uint b1, \
5523 uint b2, \
5524 uint b3) {{"
5525 )?;
5526 writeln!(
5527 self.out,
5528 "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5529 back::INDENT
5530 )?;
5531 writeln!(self.out, "}}")?;
5532 Ok((name, 4, 1))
5533 }
5534 Sint32x2 => {
5535 let name = self.namer.call("unpackSint32x2");
5536 writeln!(
5537 self.out,
5538 "metal::int2 {name}(uint b0, \
5539 uint b1, \
5540 uint b2, \
5541 uint b3, \
5542 uint b4, \
5543 uint b5, \
5544 uint b6, \
5545 uint b7) {{"
5546 )?;
5547 writeln!(
5548 self.out,
5549 "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5550 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5551 back::INDENT
5552 )?;
5553 writeln!(self.out, "}}")?;
5554 Ok((name, 8, 2))
5555 }
5556 Sint32x3 => {
5557 let name = self.namer.call("unpackSint32x3");
5558 writeln!(
5559 self.out,
5560 "metal::int3 {name}(uint b0, \
5561 uint b1, \
5562 uint b2, \
5563 uint b3, \
5564 uint b4, \
5565 uint b5, \
5566 uint b6, \
5567 uint b7, \
5568 uint b8, \
5569 uint b9, \
5570 uint b10, \
5571 uint b11) {{"
5572 )?;
5573 writeln!(
5574 self.out,
5575 "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5576 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5577 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5578 back::INDENT
5579 )?;
5580 writeln!(self.out, "}}")?;
5581 Ok((name, 12, 3))
5582 }
5583 Sint32x4 => {
5584 let name = self.namer.call("unpackSint32x4");
5585 writeln!(
5586 self.out,
5587 "metal::int4 {name}(uint b0, \
5588 uint b1, \
5589 uint b2, \
5590 uint b3, \
5591 uint b4, \
5592 uint b5, \
5593 uint b6, \
5594 uint b7, \
5595 uint b8, \
5596 uint b9, \
5597 uint b10, \
5598 uint b11, \
5599 uint b12, \
5600 uint b13, \
5601 uint b14, \
5602 uint b15) {{"
5603 )?;
5604 writeln!(
5605 self.out,
5606 "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5607 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5608 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5609 as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5610 back::INDENT
5611 )?;
5612 writeln!(self.out, "}}")?;
5613 Ok((name, 16, 4))
5614 }
5615 Unorm10_10_10_2 => {
5616 let name = self.namer.call("unpackUnorm10_10_10_2");
5617 writeln!(
5618 self.out,
5619 "metal::float4 {name}(uint b0, \
5620 uint b1, \
5621 uint b2, \
5622 uint b3) {{"
5623 )?;
5624 writeln!(
5625 self.out,
5626 "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5638 back::INDENT
5639 )?;
5640 writeln!(self.out, "}}")?;
5641 Ok((name, 4, 4))
5642 }
5643 Unorm8x4Bgra => {
5644 let name = self.namer.call("unpackUnorm8x4Bgra");
5645 writeln!(
5646 self.out,
5647 "metal::float4 {name}(metal::uchar b0, \
5648 metal::uchar b1, \
5649 metal::uchar b2, \
5650 metal::uchar b3) {{"
5651 )?;
5652 writeln!(
5653 self.out,
5654 "{}return metal::float4(float(b2) / 255.0f, \
5655 float(b1) / 255.0f, \
5656 float(b0) / 255.0f, \
5657 float(b3) / 255.0f);",
5658 back::INDENT
5659 )?;
5660 writeln!(self.out, "}}")?;
5661 Ok((name, 4, 4))
5662 }
5663 }
5664 }
5665
5666 fn write_wrapped_unary_op(
5667 &mut self,
5668 module: &crate::Module,
5669 func_ctx: &back::FunctionCtx,
5670 op: crate::UnaryOperator,
5671 operand: Handle<crate::Expression>,
5672 ) -> BackendResult {
5673 let operand_ty = func_ctx.resolve_type(operand, &module.types);
5674 match op {
5675 crate::UnaryOperator::Negate
5682 if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5683 {
5684 let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5685 return Ok(());
5686 };
5687 let wrapped = WrappedFunction::UnaryOp {
5688 op,
5689 ty: (vector_size, scalar),
5690 };
5691 if !self.wrapped_functions.insert(wrapped) {
5692 return Ok(());
5693 }
5694
5695 let unsigned_scalar = crate::Scalar {
5696 kind: crate::ScalarKind::Uint,
5697 ..scalar
5698 };
5699 let mut type_name = String::new();
5700 let mut unsigned_type_name = String::new();
5701 match vector_size {
5702 None => {
5703 put_numeric_type(&mut type_name, scalar, &[])?;
5704 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5705 }
5706 Some(size) => {
5707 put_numeric_type(&mut type_name, scalar, &[size])?;
5708 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5709 }
5710 };
5711
5712 writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5713 let level = back::Level(1);
5714 writeln!(
5715 self.out,
5716 "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5717 )?;
5718 writeln!(self.out, "}}")?;
5719 writeln!(self.out)?;
5720 }
5721 _ => {}
5722 }
5723 Ok(())
5724 }
5725
5726 fn write_wrapped_binary_op(
5727 &mut self,
5728 module: &crate::Module,
5729 func_ctx: &back::FunctionCtx,
5730 expr: Handle<crate::Expression>,
5731 op: crate::BinaryOperator,
5732 left: Handle<crate::Expression>,
5733 right: Handle<crate::Expression>,
5734 ) -> BackendResult {
5735 let expr_ty = func_ctx.resolve_type(expr, &module.types);
5736 let left_ty = func_ctx.resolve_type(left, &module.types);
5737 let right_ty = func_ctx.resolve_type(right, &module.types);
5738 match (op, expr_ty.scalar_kind()) {
5739 (
5746 crate::BinaryOperator::Divide,
5747 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5748 ) => {
5749 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5750 return Ok(());
5751 };
5752 let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5753 return Ok(());
5754 };
5755 let wrapped = WrappedFunction::BinaryOp {
5756 op,
5757 left_ty: left_wrapped_ty,
5758 right_ty: right_wrapped_ty,
5759 };
5760 if !self.wrapped_functions.insert(wrapped) {
5761 return Ok(());
5762 }
5763
5764 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5765 return Ok(());
5766 };
5767 let mut type_name = String::new();
5768 match vector_size {
5769 None => put_numeric_type(&mut type_name, scalar, &[])?,
5770 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5771 };
5772 writeln!(
5773 self.out,
5774 "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5775 )?;
5776 let level = back::Level(1);
5777 match scalar.kind {
5778 crate::ScalarKind::Sint => {
5779 let min_val = match scalar.width {
5780 4 => crate::Literal::I32(i32::MIN),
5781 8 => crate::Literal::I64(i64::MIN),
5782 _ => {
5783 return Err(Error::GenericValidation(format!(
5784 "Unexpected width for scalar {scalar:?}"
5785 )));
5786 }
5787 };
5788 write!(
5789 self.out,
5790 "{level}return lhs / metal::select(rhs, 1, (lhs == "
5791 )?;
5792 self.put_literal(min_val)?;
5793 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
5794 }
5795 crate::ScalarKind::Uint => writeln!(
5796 self.out,
5797 "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
5798 )?,
5799 _ => unreachable!(),
5800 }
5801 writeln!(self.out, "}}")?;
5802 writeln!(self.out)?;
5803 }
5804 (
5817 crate::BinaryOperator::Modulo,
5818 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5819 ) => {
5820 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5821 return Ok(());
5822 };
5823 let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5824 else {
5825 return Ok(());
5826 };
5827 let wrapped = WrappedFunction::BinaryOp {
5828 op,
5829 left_ty: left_wrapped_ty,
5830 right_ty: (right_vector_size, right_scalar),
5831 };
5832 if !self.wrapped_functions.insert(wrapped) {
5833 return Ok(());
5834 }
5835
5836 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5837 return Ok(());
5838 };
5839 let mut type_name = String::new();
5840 match vector_size {
5841 None => put_numeric_type(&mut type_name, scalar, &[])?,
5842 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5843 };
5844 let mut rhs_type_name = String::new();
5845 match right_vector_size {
5846 None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5847 Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5848 };
5849
5850 writeln!(
5851 self.out,
5852 "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5853 )?;
5854 let level = back::Level(1);
5855 match scalar.kind {
5856 crate::ScalarKind::Sint => {
5857 let min_val = match scalar.width {
5858 4 => crate::Literal::I32(i32::MIN),
5859 8 => crate::Literal::I64(i64::MIN),
5860 _ => {
5861 return Err(Error::GenericValidation(format!(
5862 "Unexpected width for scalar {scalar:?}"
5863 )));
5864 }
5865 };
5866 write!(
5867 self.out,
5868 "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
5869 )?;
5870 self.put_literal(min_val)?;
5871 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
5872 writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5873 }
5874 crate::ScalarKind::Uint => writeln!(
5875 self.out,
5876 "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
5877 )?,
5878 _ => unreachable!(),
5879 }
5880 writeln!(self.out, "}}")?;
5881 writeln!(self.out)?;
5882 }
5883 _ => {}
5884 }
5885 Ok(())
5886 }
5887
5888 fn get_dot_wrapper_function_helper_name(
5894 &self,
5895 scalar: crate::Scalar,
5896 size: crate::VectorSize,
5897 ) -> String {
5898 debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5900
5901 let type_name = scalar.to_msl_name();
5902 let size_suffix = common::vector_size_str(size);
5903 format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5904 }
5905
5906 #[allow(clippy::too_many_arguments)]
5907 fn write_wrapped_math_function(
5908 &mut self,
5909 module: &crate::Module,
5910 func_ctx: &back::FunctionCtx,
5911 fun: crate::MathFunction,
5912 arg: Handle<crate::Expression>,
5913 _arg1: Option<Handle<crate::Expression>>,
5914 _arg2: Option<Handle<crate::Expression>>,
5915 _arg3: Option<Handle<crate::Expression>>,
5916 ) -> BackendResult {
5917 let arg_ty = func_ctx.resolve_type(arg, &module.types);
5918 match fun {
5919 crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5927 let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5928 return Ok(());
5929 };
5930 let wrapped = WrappedFunction::Math {
5931 fun,
5932 arg_ty: (vector_size, scalar),
5933 };
5934 if !self.wrapped_functions.insert(wrapped) {
5935 return Ok(());
5936 }
5937
5938 let unsigned_scalar = crate::Scalar {
5939 kind: crate::ScalarKind::Uint,
5940 ..scalar
5941 };
5942 let mut type_name = String::new();
5943 let mut unsigned_type_name = String::new();
5944 match vector_size {
5945 None => {
5946 put_numeric_type(&mut type_name, scalar, &[])?;
5947 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5948 }
5949 Some(size) => {
5950 put_numeric_type(&mut type_name, scalar, &[size])?;
5951 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5952 }
5953 };
5954
5955 writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5956 let level = back::Level(1);
5957 writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
5958 writeln!(self.out, "}}")?;
5959 writeln!(self.out)?;
5960 }
5961
5962 crate::MathFunction::Dot => match *arg_ty {
5963 crate::TypeInner::Vector { size, scalar }
5964 if matches!(
5965 scalar.kind,
5966 crate::ScalarKind::Sint | crate::ScalarKind::Uint
5967 ) =>
5968 {
5969 let wrapped = WrappedFunction::Math {
5971 fun,
5972 arg_ty: (Some(size), scalar),
5973 };
5974 if !self.wrapped_functions.insert(wrapped) {
5975 return Ok(());
5976 }
5977
5978 let mut vec_ty = String::new();
5979 put_numeric_type(&mut vec_ty, scalar, &[size])?;
5980 let mut ret_ty = String::new();
5981 put_numeric_type(&mut ret_ty, scalar, &[])?;
5982
5983 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
5984
5985 writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
5987 let level = back::Level(1);
5988 write!(self.out, "{level}return ")?;
5989 self.put_dot_product("a", "b", size as usize, |writer, name, index| {
5990 write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
5991 Ok(())
5992 })?;
5993 writeln!(self.out, ";")?;
5994 writeln!(self.out, "}}")?;
5995 writeln!(self.out)?;
5996 }
5997 _ => {}
5998 },
5999
6000 _ => {}
6001 }
6002 Ok(())
6003 }
6004
6005 fn write_wrapped_cast(
6006 &mut self,
6007 module: &crate::Module,
6008 func_ctx: &back::FunctionCtx,
6009 expr: Handle<crate::Expression>,
6010 kind: crate::ScalarKind,
6011 convert: Option<crate::Bytes>,
6012 ) -> BackendResult {
6013 let src_ty = func_ctx.resolve_type(expr, &module.types);
6024 let Some(width) = convert else {
6025 return Ok(());
6026 };
6027 let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6028 return Ok(());
6029 };
6030 let dst_scalar = crate::Scalar { kind, width };
6031 if src_scalar.kind != crate::ScalarKind::Float
6032 || (dst_scalar.kind != crate::ScalarKind::Sint
6033 && dst_scalar.kind != crate::ScalarKind::Uint)
6034 {
6035 return Ok(());
6036 }
6037 let wrapped = WrappedFunction::Cast {
6038 src_scalar,
6039 vector_size,
6040 dst_scalar,
6041 };
6042 if !self.wrapped_functions.insert(wrapped) {
6043 return Ok(());
6044 }
6045 let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6046
6047 let mut src_type_name = String::new();
6048 match vector_size {
6049 None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6050 Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6051 };
6052 let mut dst_type_name = String::new();
6053 match vector_size {
6054 None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6055 Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6056 };
6057 let fun_name = match dst_scalar {
6058 crate::Scalar::I32 => F2I32_FUNCTION,
6059 crate::Scalar::U32 => F2U32_FUNCTION,
6060 crate::Scalar::I64 => F2I64_FUNCTION,
6061 crate::Scalar::U64 => F2U64_FUNCTION,
6062 _ => unreachable!(),
6063 };
6064
6065 writeln!(
6066 self.out,
6067 "{dst_type_name} {fun_name}({src_type_name} value) {{"
6068 )?;
6069 let level = back::Level(1);
6070 write!(
6071 self.out,
6072 "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6073 )?;
6074 self.put_literal(min)?;
6075 write!(self.out, ", ")?;
6076 self.put_literal(max)?;
6077 writeln!(self.out, "));")?;
6078 writeln!(self.out, "}}")?;
6079 writeln!(self.out)?;
6080 Ok(())
6081 }
6082
6083 fn write_convert_yuv_to_rgb_and_return(
6091 &mut self,
6092 level: back::Level,
6093 y: &str,
6094 uv: &str,
6095 params: &str,
6096 ) -> BackendResult {
6097 let l1 = level;
6098 let l2 = l1.next();
6099
6100 writeln!(
6102 self.out,
6103 "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6104 )?;
6105
6106 writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6109 writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6110 writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6111 writeln!(
6112 self.out,
6113 "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6114 )?;
6115
6116 writeln!(
6119 self.out,
6120 "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6121 )?;
6122
6123 writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6126 writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6127 writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6128 writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6129
6130 writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6131 Ok(())
6132 }
6133
6134 #[allow(clippy::too_many_arguments)]
6135 fn write_wrapped_image_load(
6136 &mut self,
6137 module: &crate::Module,
6138 func_ctx: &back::FunctionCtx,
6139 image: Handle<crate::Expression>,
6140 _coordinate: Handle<crate::Expression>,
6141 _array_index: Option<Handle<crate::Expression>>,
6142 _sample: Option<Handle<crate::Expression>>,
6143 _level: Option<Handle<crate::Expression>>,
6144 ) -> BackendResult {
6145 let class = match *func_ctx.resolve_type(image, &module.types) {
6147 crate::TypeInner::Image { class, .. } => class,
6148 _ => unreachable!(),
6149 };
6150 if class != crate::ImageClass::External {
6151 return Ok(());
6152 }
6153 let wrapped = WrappedFunction::ImageLoad { class };
6154 if !self.wrapped_functions.insert(wrapped) {
6155 return Ok(());
6156 }
6157
6158 writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6159 let l1 = back::Level(1);
6160 let l2 = l1.next();
6161 let l3 = l2.next();
6162 writeln!(
6163 self.out,
6164 "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6165 )?;
6166 writeln!(
6170 self.out,
6171 "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6172 )?;
6173 writeln!(
6174 self.out,
6175 "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6176 )?;
6177
6178 writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6180 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6181 writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6183 writeln!(self.out, "{l1}}} else {{")?;
6184
6185 writeln!(
6187 self.out,
6188 "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6189 )?;
6190 writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6191
6192 writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6194
6195 writeln!(self.out, "{l2}float2 uv;")?;
6196 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6197 writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6199 writeln!(self.out, "{l2}}} else {{")?;
6200 writeln!(
6202 self.out,
6203 "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6204 )?;
6205 writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6206 writeln!(
6207 self.out,
6208 "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6209 )?;
6210 writeln!(self.out, "{l2}}}")?;
6211
6212 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6213
6214 writeln!(self.out, "{l1}}}")?;
6215 writeln!(self.out, "}}")?;
6216 writeln!(self.out)?;
6217 Ok(())
6218 }
6219
6220 #[allow(clippy::too_many_arguments)]
6221 fn write_wrapped_image_sample(
6222 &mut self,
6223 module: &crate::Module,
6224 func_ctx: &back::FunctionCtx,
6225 image: Handle<crate::Expression>,
6226 _sampler: Handle<crate::Expression>,
6227 _gather: Option<crate::SwizzleComponent>,
6228 _coordinate: Handle<crate::Expression>,
6229 _array_index: Option<Handle<crate::Expression>>,
6230 _offset: Option<Handle<crate::Expression>>,
6231 _level: crate::SampleLevel,
6232 _depth_ref: Option<Handle<crate::Expression>>,
6233 clamp_to_edge: bool,
6234 ) -> BackendResult {
6235 if !clamp_to_edge {
6238 return Ok(());
6239 }
6240 let class = match *func_ctx.resolve_type(image, &module.types) {
6241 crate::TypeInner::Image { class, .. } => class,
6242 _ => unreachable!(),
6243 };
6244 let wrapped = WrappedFunction::ImageSample {
6245 class,
6246 clamp_to_edge: true,
6247 };
6248 if !self.wrapped_functions.insert(wrapped) {
6249 return Ok(());
6250 }
6251 match class {
6252 crate::ImageClass::External => {
6253 writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6254 let l1 = back::Level(1);
6255 let l2 = l1.next();
6256 let l3 = l2.next();
6257 writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6258 writeln!(
6259 self.out,
6260 "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6261 )?;
6262
6263 writeln!(
6271 self.out,
6272 "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6273 )?;
6274 writeln!(
6275 self.out,
6276 "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6277 )?;
6278 writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6279 writeln!(
6280 self.out,
6281 "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6282 )?;
6283 writeln!(
6284 self.out,
6285 "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6286 )?;
6287 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6288 writeln!(
6290 self.out,
6291 "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6292 )?;
6293 writeln!(self.out, "{l1}}} else {{")?;
6294 writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6295 writeln!(
6296 self.out,
6297 "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6298 )?;
6299 writeln!(
6300 self.out,
6301 "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6302 )?;
6303
6304 writeln!(
6306 self.out,
6307 "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6308 )?;
6309 writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6310 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6311 writeln!(
6313 self.out,
6314 "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6315 )?;
6316 writeln!(self.out, "{l2}}} else {{")?;
6317 writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6319 writeln!(
6320 self.out,
6321 "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6322 )?;
6323 writeln!(
6324 self.out,
6325 "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6326 )?;
6327 writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6328 writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6329 writeln!(self.out, "{l2}}}")?;
6330
6331 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6332
6333 writeln!(self.out, "{l1}}}")?;
6334 writeln!(self.out, "}}")?;
6335 writeln!(self.out)?;
6336 }
6337 _ => {
6338 writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6339 let l1 = back::Level(1);
6340 writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6341 writeln!(
6342 self.out,
6343 "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6344 )?;
6345 writeln!(self.out, "}}")?;
6346 writeln!(self.out)?;
6347 }
6348 }
6349 Ok(())
6350 }
6351
6352 fn write_wrapped_image_query(
6353 &mut self,
6354 module: &crate::Module,
6355 func_ctx: &back::FunctionCtx,
6356 image: Handle<crate::Expression>,
6357 query: crate::ImageQuery,
6358 ) -> BackendResult {
6359 if !matches!(query, crate::ImageQuery::Size { .. }) {
6361 return Ok(());
6362 }
6363 let class = match *func_ctx.resolve_type(image, &module.types) {
6364 crate::TypeInner::Image { class, .. } => class,
6365 _ => unreachable!(),
6366 };
6367 if class != crate::ImageClass::External {
6368 return Ok(());
6369 }
6370 let wrapped = WrappedFunction::ImageQuerySize { class };
6371 if !self.wrapped_functions.insert(wrapped) {
6372 return Ok(());
6373 }
6374 writeln!(
6375 self.out,
6376 "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6377 )?;
6378 let l1 = back::Level(1);
6379 let l2 = l1.next();
6380 writeln!(
6381 self.out,
6382 "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6383 )?;
6384 writeln!(self.out, "{l2}return tex.params.size;")?;
6385 writeln!(self.out, "{l1}}} else {{")?;
6386 writeln!(
6388 self.out,
6389 "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6390 )?;
6391 writeln!(self.out, "{l1}}}")?;
6392 writeln!(self.out, "}}")?;
6393 writeln!(self.out)?;
6394 Ok(())
6395 }
6396
6397 fn write_wrapped_cooperative_load(
6398 &mut self,
6399 module: &crate::Module,
6400 func_ctx: &back::FunctionCtx,
6401 columns: crate::CooperativeSize,
6402 rows: crate::CooperativeSize,
6403 pointer: Handle<crate::Expression>,
6404 ) -> BackendResult {
6405 let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6406 let space = ptr_ty.pointer_space().unwrap();
6407 let space_name = space.to_msl_name().unwrap_or_default();
6408 let scalar = ptr_ty
6409 .pointer_base_type()
6410 .unwrap()
6411 .inner_with(&module.types)
6412 .scalar()
6413 .unwrap();
6414 let wrapped = WrappedFunction::CooperativeLoad {
6415 space_name,
6416 columns,
6417 rows,
6418 scalar,
6419 };
6420 if !self.wrapped_functions.insert(wrapped) {
6421 return Ok(());
6422 }
6423 let scalar_name = scalar.to_msl_name();
6424 writeln!(
6425 self.out,
6426 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6427 columns as u32, rows as u32,
6428 )?;
6429 let l1 = back::Level(1);
6430 writeln!(
6431 self.out,
6432 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6433 columns as u32, rows as u32
6434 )?;
6435 let matrix_origin = "0";
6436 writeln!(
6437 self.out,
6438 "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6439 )?;
6440 writeln!(self.out, "{l1}return m;")?;
6441 writeln!(self.out, "}}")?;
6442 writeln!(self.out)?;
6443 Ok(())
6444 }
6445
6446 fn write_wrapped_cooperative_multiply_add(
6447 &mut self,
6448 module: &crate::Module,
6449 func_ctx: &back::FunctionCtx,
6450 space: crate::AddressSpace,
6451 a: Handle<crate::Expression>,
6452 b: Handle<crate::Expression>,
6453 ) -> BackendResult {
6454 let space_name = space.to_msl_name().unwrap_or_default();
6455 let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6456 crate::TypeInner::CooperativeMatrix {
6457 columns,
6458 rows,
6459 scalar,
6460 ..
6461 } => (columns, rows, scalar),
6462 _ => unreachable!(),
6463 };
6464 let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6465 crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6466 _ => unreachable!(),
6467 };
6468 let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6469 space_name,
6470 columns: b_c,
6471 rows: a_r,
6472 intermediate: a_c,
6473 scalar,
6474 };
6475 if !self.wrapped_functions.insert(wrapped) {
6476 return Ok(());
6477 }
6478 let scalar_name = scalar.to_msl_name();
6479 writeln!(
6480 self.out,
6481 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6482 b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6483 )?;
6484 let l1 = back::Level(1);
6485 writeln!(
6486 self.out,
6487 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6488 b_c as u32, a_r as u32
6489 )?;
6490 writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6491 writeln!(self.out, "{l1}return d;")?;
6492 writeln!(self.out, "}}")?;
6493 writeln!(self.out)?;
6494 Ok(())
6495 }
6496
6497 pub(super) fn write_wrapped_functions(
6498 &mut self,
6499 module: &crate::Module,
6500 func_ctx: &back::FunctionCtx,
6501 ) -> BackendResult {
6502 for (expr_handle, expr) in func_ctx.expressions.iter() {
6503 match *expr {
6504 crate::Expression::Unary { op, expr: operand } => {
6505 self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6506 }
6507 crate::Expression::Binary { op, left, right } => {
6508 self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6509 }
6510 crate::Expression::Math {
6511 fun,
6512 arg,
6513 arg1,
6514 arg2,
6515 arg3,
6516 } => {
6517 self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6518 }
6519 crate::Expression::As {
6520 expr,
6521 kind,
6522 convert,
6523 } => {
6524 self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6525 }
6526 crate::Expression::ImageLoad {
6527 image,
6528 coordinate,
6529 array_index,
6530 sample,
6531 level,
6532 } => {
6533 self.write_wrapped_image_load(
6534 module,
6535 func_ctx,
6536 image,
6537 coordinate,
6538 array_index,
6539 sample,
6540 level,
6541 )?;
6542 }
6543 crate::Expression::ImageSample {
6544 image,
6545 sampler,
6546 gather,
6547 coordinate,
6548 array_index,
6549 offset,
6550 level,
6551 depth_ref,
6552 clamp_to_edge,
6553 } => {
6554 self.write_wrapped_image_sample(
6555 module,
6556 func_ctx,
6557 image,
6558 sampler,
6559 gather,
6560 coordinate,
6561 array_index,
6562 offset,
6563 level,
6564 depth_ref,
6565 clamp_to_edge,
6566 )?;
6567 }
6568 crate::Expression::ImageQuery { image, query } => {
6569 self.write_wrapped_image_query(module, func_ctx, image, query)?;
6570 }
6571 crate::Expression::CooperativeLoad {
6572 columns,
6573 rows,
6574 role: _,
6575 ref data,
6576 } => {
6577 self.write_wrapped_cooperative_load(
6578 module,
6579 func_ctx,
6580 columns,
6581 rows,
6582 data.pointer,
6583 )?;
6584 }
6585 crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6586 let space = crate::AddressSpace::Private;
6587 self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
6588 }
6589 _ => {}
6590 }
6591 }
6592
6593 Ok(())
6594 }
6595
6596 fn write_functions(
6598 &mut self,
6599 module: &crate::Module,
6600 mod_info: &valid::ModuleInfo,
6601 options: &Options,
6602 pipeline_options: &PipelineOptions,
6603 ) -> Result<TranslationInfo, Error> {
6604 use back::msl::VertexFormat;
6605
6606 struct AttributeMappingResolved {
6609 ty_name: String,
6610 dimension: u32,
6611 ty_is_int: bool,
6612 name: String,
6613 }
6614 let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6615
6616 struct VertexBufferMappingResolved<'a> {
6617 id: u32,
6618 stride: u32,
6619 step_mode: back::msl::VertexBufferStepMode,
6620 ty_name: String,
6621 param_name: String,
6622 elem_name: String,
6623 attributes: &'a Vec<back::msl::AttributeMapping>,
6624 }
6625 let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6626
6627 struct UnpackingFunction {
6629 name: String,
6630 byte_count: u32,
6631 dimension: u32,
6632 }
6633 let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6634
6635 let mut needs_vertex_id = false;
6641 let v_id = self.namer.call("v_id");
6642
6643 let mut needs_instance_id = false;
6644 let i_id = self.namer.call("i_id");
6645 if pipeline_options.vertex_pulling_transform {
6646 for vbm in &pipeline_options.vertex_buffer_mappings {
6647 let buffer_id = vbm.id;
6648 let buffer_stride = vbm.stride;
6649
6650 assert!(
6651 buffer_stride > 0,
6652 "Vertex pulling requires a non-zero buffer stride."
6653 );
6654
6655 match vbm.step_mode {
6656 back::msl::VertexBufferStepMode::Constant => {}
6657 back::msl::VertexBufferStepMode::ByVertex => {
6658 needs_vertex_id = true;
6659 }
6660 back::msl::VertexBufferStepMode::ByInstance => {
6661 needs_instance_id = true;
6662 }
6663 }
6664
6665 let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6666 let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6667 let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6668
6669 vbm_resolved.push(VertexBufferMappingResolved {
6670 id: buffer_id,
6671 stride: buffer_stride,
6672 step_mode: vbm.step_mode,
6673 ty_name: buffer_ty,
6674 param_name: buffer_param,
6675 elem_name: buffer_elem,
6676 attributes: &vbm.attributes,
6677 });
6678
6679 for attribute in &vbm.attributes {
6681 if unpacking_functions.contains_key(&attribute.format) {
6682 continue;
6683 }
6684 let (name, byte_count, dimension) =
6685 match self.write_unpacking_function(attribute.format) {
6686 Ok((name, byte_count, dimension)) => (name, byte_count, dimension),
6687 _ => {
6688 continue;
6689 }
6690 };
6691 unpacking_functions.insert(
6692 attribute.format,
6693 UnpackingFunction {
6694 name,
6695 byte_count,
6696 dimension,
6697 },
6698 );
6699 }
6700 }
6701 }
6702
6703 let mut pass_through_globals = Vec::new();
6704 for (fun_handle, fun) in module.functions.iter() {
6705 log::trace!(
6706 "function {:?}, handle {:?}",
6707 fun.name.as_deref().unwrap_or("(anonymous)"),
6708 fun_handle
6709 );
6710
6711 let ctx = back::FunctionCtx {
6712 ty: back::FunctionType::Function(fun_handle),
6713 info: &mod_info[fun_handle],
6714 expressions: &fun.expressions,
6715 named_expressions: &fun.named_expressions,
6716 };
6717
6718 writeln!(self.out)?;
6719 self.write_wrapped_functions(module, &ctx)?;
6720
6721 let fun_info = &mod_info[fun_handle];
6722 pass_through_globals.clear();
6723 let mut needs_buffer_sizes = false;
6724 for (handle, var) in module.global_variables.iter() {
6725 if !fun_info[handle].is_empty() {
6726 if var.space.needs_pass_through() {
6727 pass_through_globals.push(handle);
6728 }
6729 needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6730 }
6731 }
6732
6733 let fun_name = &self.names[&NameKey::Function(fun_handle)];
6734 match fun.result {
6735 Some(ref result) => {
6736 let ty_name = TypeContext {
6737 handle: result.ty,
6738 gctx: module.to_ctx(),
6739 names: &self.names,
6740 access: crate::StorageAccess::empty(),
6741 first_time: false,
6742 };
6743 write!(self.out, "{ty_name}")?;
6744 }
6745 None => {
6746 write!(self.out, "void")?;
6747 }
6748 }
6749 writeln!(self.out, " {fun_name}(")?;
6750
6751 for (index, arg) in fun.arguments.iter().enumerate() {
6752 let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6753 let param_type_name = TypeContext {
6754 handle: arg.ty,
6755 gctx: module.to_ctx(),
6756 names: &self.names,
6757 access: crate::StorageAccess::empty(),
6758 first_time: false,
6759 };
6760 let separator = separate(
6761 !pass_through_globals.is_empty()
6762 || index + 1 != fun.arguments.len()
6763 || needs_buffer_sizes,
6764 );
6765 writeln!(
6766 self.out,
6767 "{}{} {}{}",
6768 back::INDENT,
6769 param_type_name,
6770 name,
6771 separator
6772 )?;
6773 }
6774 for (index, &handle) in pass_through_globals.iter().enumerate() {
6775 let tyvar = TypedGlobalVariable {
6776 module,
6777 names: &self.names,
6778 handle,
6779 usage: fun_info[handle],
6780 reference: true,
6781 };
6782 let separator =
6783 separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6784 write!(self.out, "{}", back::INDENT)?;
6785 tyvar.try_fmt(&mut self.out)?;
6786 writeln!(self.out, "{separator}")?;
6787 }
6788
6789 if needs_buffer_sizes {
6790 writeln!(
6791 self.out,
6792 "{}constant _mslBufferSizes& _buffer_sizes",
6793 back::INDENT
6794 )?;
6795 }
6796
6797 writeln!(self.out, ") {{")?;
6798
6799 let guarded_indices =
6800 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6801
6802 let context = StatementContext {
6803 expression: ExpressionContext {
6804 function: fun,
6805 origin: FunctionOrigin::Handle(fun_handle),
6806 info: fun_info,
6807 lang_version: options.lang_version,
6808 policies: options.bounds_check_policies,
6809 guarded_indices,
6810 module,
6811 mod_info,
6812 pipeline_options,
6813 force_loop_bounding: options.force_loop_bounding,
6814 },
6815 result_struct: None,
6816 };
6817
6818 self.put_locals(&context.expression)?;
6819 self.update_expressions_to_bake(fun, fun_info, &context.expression);
6820 self.put_block(back::Level(1), &fun.body, &context)?;
6821 writeln!(self.out, "}}")?;
6822 self.named_expressions.clear();
6823 }
6824
6825 let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6826 .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6827
6828 let mut info = TranslationInfo {
6829 entry_point_names: Vec::with_capacity(ep_range.len()),
6830 };
6831
6832 for ep_index in ep_range {
6833 let ep = &module.entry_points[ep_index];
6834 let fun = &ep.function;
6835 let fun_info = mod_info.get_entry_point(ep_index);
6836 let mut ep_error = None;
6837
6838 let mut v_existing_id = None;
6842 let mut i_existing_id = None;
6843
6844 log::trace!(
6845 "entry point {:?}, index {:?}",
6846 fun.name.as_deref().unwrap_or("(anonymous)"),
6847 ep_index
6848 );
6849
6850 let ctx = back::FunctionCtx {
6851 ty: back::FunctionType::EntryPoint(ep_index as u16),
6852 info: fun_info,
6853 expressions: &fun.expressions,
6854 named_expressions: &fun.named_expressions,
6855 };
6856
6857 self.write_wrapped_functions(module, &ctx)?;
6858
6859 let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6860 crate::ShaderStage::Vertex => (
6861 "vertex",
6862 LocationMode::VertexInput,
6863 LocationMode::VertexOutput,
6864 true,
6865 ),
6866 crate::ShaderStage::Fragment => (
6867 "fragment",
6868 LocationMode::FragmentInput,
6869 LocationMode::FragmentOutput,
6870 false,
6871 ),
6872 crate::ShaderStage::Compute => (
6873 "kernel",
6874 LocationMode::Uniform,
6875 LocationMode::Uniform,
6876 false,
6877 ),
6878 crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(),
6879 };
6880
6881 let do_vertex_pulling = can_vertex_pull
6883 && pipeline_options.vertex_pulling_transform
6884 && !pipeline_options.vertex_buffer_mappings.is_empty();
6885
6886 let needs_buffer_sizes = do_vertex_pulling
6888 || module
6889 .global_variables
6890 .iter()
6891 .filter(|&(handle, _)| !fun_info[handle].is_empty())
6892 .any(|(_, var)| needs_array_length(var.ty, &module.types));
6893
6894 if !options.fake_missing_bindings {
6897 for (var_handle, var) in module.global_variables.iter() {
6898 if fun_info[var_handle].is_empty() {
6899 continue;
6900 }
6901 match var.space {
6902 crate::AddressSpace::Uniform
6903 | crate::AddressSpace::Storage { .. }
6904 | crate::AddressSpace::Handle => {
6905 let br = match var.binding {
6906 Some(ref br) => br,
6907 None => {
6908 let var_name = var.name.clone().unwrap_or_default();
6909 ep_error =
6910 Some(super::EntryPointError::MissingBinding(var_name));
6911 break;
6912 }
6913 };
6914 let target = options.get_resource_binding_target(ep, br);
6915 let good = match target {
6916 Some(target) => {
6917 match module.types[var.ty].inner {
6921 crate::TypeInner::Image {
6922 class: crate::ImageClass::External,
6923 ..
6924 } => target.external_texture.is_some(),
6925 crate::TypeInner::Image { .. } => target.texture.is_some(),
6926 crate::TypeInner::Sampler { .. } => {
6927 target.sampler.is_some()
6928 }
6929 _ => target.buffer.is_some(),
6930 }
6931 }
6932 None => false,
6933 };
6934 if !good {
6935 ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6936 break;
6937 }
6938 }
6939 crate::AddressSpace::Immediate => {
6940 if let Err(e) = options.resolve_immediates(ep) {
6941 ep_error = Some(e);
6942 break;
6943 }
6944 }
6945 crate::AddressSpace::TaskPayload => {
6946 unimplemented!()
6947 }
6948 crate::AddressSpace::Function
6949 | crate::AddressSpace::Private
6950 | crate::AddressSpace::WorkGroup => {}
6951 }
6952 }
6953 if needs_buffer_sizes {
6954 if let Err(err) = options.resolve_sizes_buffer(ep) {
6955 ep_error = Some(err);
6956 }
6957 }
6958 }
6959
6960 if let Some(err) = ep_error {
6961 info.entry_point_names.push(Err(err));
6962 continue;
6963 }
6964 let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
6965 info.entry_point_names.push(Ok(fun_name.clone()));
6966
6967 writeln!(self.out)?;
6968
6969 let mut flattened_member_names = FastHashMap::default();
6975 let mut varyings_namer = proc::Namer::default();
6977
6978 let mut flattened_arguments = Vec::new();
6983 for (arg_index, arg) in fun.arguments.iter().enumerate() {
6984 match module.types[arg.ty].inner {
6985 crate::TypeInner::Struct { ref members, .. } => {
6986 for (member_index, member) in members.iter().enumerate() {
6987 let member_index = member_index as u32;
6988 flattened_arguments.push((
6989 NameKey::StructMember(arg.ty, member_index),
6990 member.ty,
6991 member.binding.as_ref(),
6992 ));
6993 let name_key = NameKey::StructMember(arg.ty, member_index);
6994 let name = match member.binding {
6995 Some(crate::Binding::Location { .. }) => {
6996 if do_vertex_pulling {
6997 self.namer.call(&self.names[&name_key])
6998 } else {
6999 varyings_namer.call(&self.names[&name_key])
7000 }
7001 }
7002 _ => self.namer.call(&self.names[&name_key]),
7003 };
7004 flattened_member_names.insert(name_key, name);
7005 }
7006 }
7007 _ => flattened_arguments.push((
7008 NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7009 arg.ty,
7010 arg.binding.as_ref(),
7011 )),
7012 }
7013 }
7014
7015 let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7020 let varyings_member_name = self.namer.call("varyings");
7021 let mut has_varyings = false;
7022 if !flattened_arguments.is_empty() {
7023 if !do_vertex_pulling {
7024 writeln!(self.out, "struct {stage_in_name} {{")?;
7025 }
7026 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7027 let Some(binding) = binding else {
7028 continue;
7029 };
7030 let name = match *name_key {
7031 NameKey::StructMember(..) => &flattened_member_names[name_key],
7032 _ => &self.names[name_key],
7033 };
7034 let ty_name = TypeContext {
7035 handle: ty,
7036 gctx: module.to_ctx(),
7037 names: &self.names,
7038 access: crate::StorageAccess::empty(),
7039 first_time: false,
7040 };
7041 let resolved = options.resolve_local_binding(binding, in_mode)?;
7042 let location = match *binding {
7043 crate::Binding::Location { location, .. } => Some(location),
7044 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7045 crate::Binding::BuiltIn(_) => continue,
7046 };
7047 if do_vertex_pulling {
7048 let Some(location) = location else {
7049 continue;
7050 };
7051 am_resolved.insert(
7053 location,
7054 AttributeMappingResolved {
7055 ty_name: ty_name.to_string(),
7056 dimension: ty_name.vertex_input_dimension(),
7057 ty_is_int: ty_name.scalar().is_some_and(scalar_is_int),
7058 name: name.to_string(),
7059 },
7060 );
7061 } else {
7062 has_varyings = true;
7063 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7064 resolved.try_fmt(&mut self.out)?;
7065 writeln!(self.out, ";")?;
7066 }
7067 }
7068 if !do_vertex_pulling {
7069 writeln!(self.out, "}};")?;
7070 }
7071 }
7072
7073 let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7076 let result_member_name = self.namer.call("member");
7077 let result_type_name = match fun.result {
7078 Some(ref result) => {
7079 let mut result_members = Vec::new();
7080 if let crate::TypeInner::Struct { ref members, .. } =
7081 module.types[result.ty].inner
7082 {
7083 for (member_index, member) in members.iter().enumerate() {
7084 result_members.push((
7085 &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7086 member.ty,
7087 member.binding.as_ref(),
7088 ));
7089 }
7090 } else {
7091 result_members.push((
7092 &result_member_name,
7093 result.ty,
7094 result.binding.as_ref(),
7095 ));
7096 }
7097
7098 writeln!(self.out, "struct {stage_out_name} {{")?;
7099 let mut has_point_size = false;
7100 for (name, ty, binding) in result_members {
7101 let ty_name = TypeContext {
7102 handle: ty,
7103 gctx: module.to_ctx(),
7104 names: &self.names,
7105 access: crate::StorageAccess::empty(),
7106 first_time: true,
7107 };
7108 let binding = binding.ok_or_else(|| {
7109 Error::GenericValidation("Expected binding, got None".into())
7110 })?;
7111
7112 if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7113 has_point_size = true;
7114 if !pipeline_options.allow_and_force_point_size {
7115 continue;
7116 }
7117 }
7118
7119 let array_len = match module.types[ty].inner {
7120 crate::TypeInner::Array {
7121 size: crate::ArraySize::Constant(size),
7122 ..
7123 } => Some(size),
7124 _ => None,
7125 };
7126 let resolved = options.resolve_local_binding(binding, out_mode)?;
7127 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7128 if let Some(array_len) = array_len {
7129 write!(self.out, " [{array_len}]")?;
7130 }
7131 resolved.try_fmt(&mut self.out)?;
7132 writeln!(self.out, ";")?;
7133 }
7134
7135 if pipeline_options.allow_and_force_point_size
7136 && ep.stage == crate::ShaderStage::Vertex
7137 && !has_point_size
7138 {
7139 writeln!(
7141 self.out,
7142 "{}float _point_size [[point_size]];",
7143 back::INDENT
7144 )?;
7145 }
7146 writeln!(self.out, "}};")?;
7147 &stage_out_name
7148 }
7149 None => "void",
7150 };
7151
7152 if do_vertex_pulling {
7155 for vbm in &vbm_resolved {
7156 let buffer_stride = vbm.stride;
7157 let buffer_ty = &vbm.ty_name;
7158
7159 writeln!(
7163 self.out,
7164 "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7165 )?;
7166 }
7167 }
7168
7169 writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
7171
7172 let mut is_first_argument = true;
7173 let mut separator = || {
7174 if is_first_argument {
7175 is_first_argument = false;
7176 ' '
7177 } else {
7178 ','
7179 }
7180 };
7181
7182 if has_varyings {
7185 writeln!(
7186 self.out,
7187 "{} {stage_in_name} {varyings_member_name} [[stage_in]]",
7188 separator()
7189 )?;
7190 }
7191
7192 let mut local_invocation_id = None;
7193
7194 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7197 let binding = match binding {
7198 Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7199 Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7200 _ => continue,
7201 };
7202 let name = match *name_key {
7203 NameKey::StructMember(..) => &flattened_member_names[name_key],
7204 _ => &self.names[name_key],
7205 };
7206
7207 if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
7208 local_invocation_id = Some(name_key);
7209 }
7210
7211 let ty_name = TypeContext {
7212 handle: ty,
7213 gctx: module.to_ctx(),
7214 names: &self.names,
7215 access: crate::StorageAccess::empty(),
7216 first_time: false,
7217 };
7218
7219 match *binding {
7220 crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7221 v_existing_id = Some(name.clone());
7222 }
7223 crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7224 i_existing_id = Some(name.clone());
7225 }
7226 _ => {}
7227 };
7228
7229 let resolved = options.resolve_local_binding(binding, in_mode)?;
7230 write!(self.out, "{} {ty_name} {name}", separator())?;
7231 resolved.try_fmt(&mut self.out)?;
7232 writeln!(self.out)?;
7233 }
7234
7235 let need_workgroup_variables_initialization =
7236 self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7237
7238 if need_workgroup_variables_initialization && local_invocation_id.is_none() {
7239 writeln!(
7240 self.out,
7241 "{} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]", separator()
7242 )?;
7243 }
7244
7245 for (handle, var) in module.global_variables.iter() {
7250 let usage = fun_info[handle];
7251 if usage.is_empty() || var.space == crate::AddressSpace::Private {
7252 continue;
7253 }
7254
7255 if options.lang_version < (1, 2) {
7256 match var.space {
7257 crate::AddressSpace::Storage { access }
7267 if access.contains(crate::StorageAccess::STORE)
7268 && ep.stage == crate::ShaderStage::Fragment =>
7269 {
7270 return Err(Error::UnsupportedWriteableStorageBuffer)
7271 }
7272 crate::AddressSpace::Handle => {
7273 match module.types[var.ty].inner {
7274 crate::TypeInner::Image {
7275 class: crate::ImageClass::Storage { access, .. },
7276 ..
7277 } => {
7278 if access.contains(crate::StorageAccess::STORE)
7288 && (ep.stage == crate::ShaderStage::Vertex
7289 || ep.stage == crate::ShaderStage::Fragment)
7290 {
7291 return Err(Error::UnsupportedWriteableStorageTexture(
7292 ep.stage,
7293 ));
7294 }
7295
7296 if access.contains(
7297 crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7298 ) {
7299 return Err(Error::UnsupportedRWStorageTexture);
7300 }
7301 }
7302 _ => {}
7303 }
7304 }
7305 _ => {}
7306 }
7307 }
7308
7309 match var.space {
7311 crate::AddressSpace::Handle => match module.types[var.ty].inner {
7312 crate::TypeInner::BindingArray { base, .. } => {
7313 match module.types[base].inner {
7314 crate::TypeInner::Sampler { .. } => {
7315 if options.lang_version < (2, 0) {
7316 return Err(Error::UnsupportedArrayOf(
7317 "samplers".to_string(),
7318 ));
7319 }
7320 }
7321 crate::TypeInner::Image { class, .. } => match class {
7322 crate::ImageClass::Sampled { .. }
7323 | crate::ImageClass::Depth { .. }
7324 | crate::ImageClass::Storage {
7325 access: crate::StorageAccess::LOAD,
7326 ..
7327 } => {
7328 if options.lang_version < (2, 0) {
7333 return Err(Error::UnsupportedArrayOf(
7334 "textures".to_string(),
7335 ));
7336 }
7337 }
7338 crate::ImageClass::Storage {
7339 access: crate::StorageAccess::STORE,
7340 ..
7341 } => {
7342 if options.lang_version < (2, 0) {
7347 return Err(Error::UnsupportedArrayOf(
7348 "write-only textures".to_string(),
7349 ));
7350 }
7351 }
7352 crate::ImageClass::Storage { .. } => {
7353 if options.lang_version < (3, 0) {
7354 return Err(Error::UnsupportedArrayOf(
7355 "read-write textures".to_string(),
7356 ));
7357 }
7358 }
7359 crate::ImageClass::External => {
7360 return Err(Error::UnsupportedArrayOf(
7361 "external textures".to_string(),
7362 ));
7363 }
7364 },
7365 _ => {
7366 return Err(Error::UnsupportedArrayOfType(base));
7367 }
7368 }
7369 }
7370 _ => {}
7371 },
7372 _ => {}
7373 }
7374
7375 let resolved = match var.space {
7377 crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7378 crate::AddressSpace::WorkGroup => None,
7379 _ => options
7380 .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7381 .ok(),
7382 };
7383 if let Some(ref resolved) = resolved {
7384 if resolved.as_inline_sampler(options).is_some() {
7386 continue;
7387 }
7388 }
7389
7390 match module.types[var.ty].inner {
7391 crate::TypeInner::Image {
7392 class: crate::ImageClass::External,
7393 ..
7394 } => {
7395 let target = match resolved {
7399 Some(back::msl::ResolvedBinding::Resource(target)) => {
7400 target.external_texture
7401 }
7402 _ => None,
7403 };
7404
7405 for i in 0..3 {
7406 write!(self.out, "{} ", separator())?;
7407
7408 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7409 handle,
7410 ExternalTextureNameKey::Plane(i),
7411 )];
7412 write!(
7413 self.out,
7414 "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> {plane_name}"
7415 )?;
7416 if let Some(ref target) = target {
7417 write!(self.out, " [[texture({})]]", target.planes[i])?;
7418 }
7419 writeln!(self.out)?;
7420 }
7421 let params_ty_name = &self.names
7422 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7423 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7424 handle,
7425 ExternalTextureNameKey::Params,
7426 )];
7427 write!(self.out, "{} ", separator())?;
7428 write!(self.out, "constant {params_ty_name}& {params_name}")?;
7429 if let Some(ref target) = target {
7430 write!(self.out, " [[buffer({})]]", target.params)?;
7431 }
7432 }
7433 _ => {
7434 let tyvar = TypedGlobalVariable {
7435 module,
7436 names: &self.names,
7437 handle,
7438 usage,
7439 reference: true,
7440 };
7441 write!(self.out, "{} ", separator())?;
7442 tyvar.try_fmt(&mut self.out)?;
7443 if let Some(resolved) = resolved {
7444 resolved.try_fmt(&mut self.out)?;
7445 }
7446 if let Some(value) = var.init {
7447 write!(self.out, " = ")?;
7448 self.put_const_expression(
7449 value,
7450 module,
7451 mod_info,
7452 &module.global_expressions,
7453 )?;
7454 }
7455 }
7456 }
7457 writeln!(self.out)?;
7458 }
7459
7460 if do_vertex_pulling {
7461 if needs_vertex_id && v_existing_id.is_none() {
7462 writeln!(self.out, "{} uint {v_id} [[vertex_id]]", separator())?;
7464 }
7465
7466 if needs_instance_id && i_existing_id.is_none() {
7467 writeln!(self.out, "{} uint {i_id} [[instance_id]]", separator())?;
7468 }
7469
7470 for vbm in &vbm_resolved {
7473 let id = &vbm.id;
7474 let ty_name = &vbm.ty_name;
7475 let param_name = &vbm.param_name;
7476 writeln!(
7477 self.out,
7478 "{} const device {ty_name}* {param_name} [[buffer({id})]]",
7479 separator()
7480 )?;
7481 }
7482 }
7483
7484 if needs_buffer_sizes {
7487 let resolved = options.resolve_sizes_buffer(ep).unwrap();
7489 write!(
7490 self.out,
7491 "{} constant _mslBufferSizes& _buffer_sizes",
7492 separator()
7493 )?;
7494 resolved.try_fmt(&mut self.out)?;
7495 writeln!(self.out)?;
7496 }
7497
7498 writeln!(self.out, ") {{")?;
7500
7501 if do_vertex_pulling {
7503 for vbm in &vbm_resolved {
7506 for attribute in vbm.attributes {
7507 let location = attribute.shader_location;
7508 let am_option = am_resolved.get(&location);
7509 if am_option.is_none() {
7510 continue;
7513 }
7514 let am = am_option.unwrap();
7515 let attribute_ty_name = &am.ty_name;
7516 let attribute_name = &am.name;
7517
7518 writeln!(
7519 self.out,
7520 "{}{attribute_ty_name} {attribute_name} = {{}};",
7521 back::Level(1)
7522 )?;
7523 }
7524
7525 write!(self.out, "{}if (", back::Level(1))?;
7528
7529 let idx = &vbm.id;
7530 let stride = &vbm.stride;
7531 let index_name = match vbm.step_mode {
7532 back::msl::VertexBufferStepMode::Constant => "0",
7533 back::msl::VertexBufferStepMode::ByVertex => {
7534 if let Some(ref name) = v_existing_id {
7535 name
7536 } else {
7537 &v_id
7538 }
7539 }
7540 back::msl::VertexBufferStepMode::ByInstance => {
7541 if let Some(ref name) = i_existing_id {
7542 name
7543 } else {
7544 &i_id
7545 }
7546 }
7547 };
7548 write!(
7549 self.out,
7550 "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7551 )?;
7552
7553 writeln!(self.out, ") {{")?;
7554
7555 let ty_name = &vbm.ty_name;
7557 let elem_name = &vbm.elem_name;
7558 let param_name = &vbm.param_name;
7559
7560 writeln!(
7561 self.out,
7562 "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7563 back::Level(2),
7564 )?;
7565
7566 for attribute in vbm.attributes {
7569 let location = attribute.shader_location;
7570 let Some(am) = am_resolved.get(&location) else {
7571 continue;
7575 };
7576 let attribute_name = &am.name;
7577 let attribute_ty_name = &am.ty_name;
7578
7579 let offset = attribute.offset;
7580 let func = unpacking_functions
7581 .get(&attribute.format)
7582 .expect("Should have generated this unpacking function earlier.");
7583 let func_name = &func.name;
7584
7585 let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7592
7593 if needs_padding_or_truncation != Ordering::Equal {
7594 writeln!(
7597 self.out,
7598 "{}// {attribute_ty_name} <- {:?}",
7599 back::Level(2),
7600 attribute.format
7601 )?;
7602 }
7603
7604 write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7605
7606 if needs_padding_or_truncation == Ordering::Greater {
7607 write!(self.out, "{attribute_ty_name}(")?;
7609 }
7610
7611 write!(self.out, "{func_name}({elem_name}.data[{offset}]",)?;
7613 for i in (offset + 1)..(offset + func.byte_count) {
7614 write!(self.out, ", {elem_name}.data[{i}]")?;
7615 }
7616 write!(self.out, ")")?;
7617
7618 match needs_padding_or_truncation {
7619 Ordering::Greater => {
7620 let zero_value = if am.ty_is_int { "0" } else { "0.0" };
7622 let one_value = if am.ty_is_int { "1" } else { "1.0" };
7623 for i in func.dimension..am.dimension {
7624 write!(
7625 self.out,
7626 ", {}",
7627 if i == 3 { one_value } else { zero_value }
7628 )?;
7629 }
7630 write!(self.out, ")")?;
7631 }
7632 Ordering::Less => {
7633 write!(
7635 self.out,
7636 ".{}",
7637 &"xyzw"[0..usize::try_from(am.dimension).unwrap()]
7638 )?;
7639 }
7640 Ordering::Equal => {}
7641 }
7642
7643 writeln!(self.out, ";")?;
7644 }
7645
7646 writeln!(self.out, "{}}}", back::Level(1))?;
7648 }
7649 }
7650
7651 if need_workgroup_variables_initialization {
7652 self.write_workgroup_variables_initialization(
7653 module,
7654 mod_info,
7655 fun_info,
7656 local_invocation_id,
7657 )?;
7658 }
7659
7660 for (handle, var) in module.global_variables.iter() {
7663 let usage = fun_info[handle];
7664 if usage.is_empty() {
7665 continue;
7666 }
7667 if var.space == crate::AddressSpace::Private {
7668 let tyvar = TypedGlobalVariable {
7669 module,
7670 names: &self.names,
7671 handle,
7672 usage,
7673
7674 reference: false,
7675 };
7676 write!(self.out, "{}", back::INDENT)?;
7677 tyvar.try_fmt(&mut self.out)?;
7678 match var.init {
7679 Some(value) => {
7680 write!(self.out, " = ")?;
7681 self.put_const_expression(
7682 value,
7683 module,
7684 mod_info,
7685 &module.global_expressions,
7686 )?;
7687 writeln!(self.out, ";")?;
7688 }
7689 None => {
7690 writeln!(self.out, " = {{}};")?;
7691 }
7692 };
7693 } else if let Some(ref binding) = var.binding {
7694 let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7695 if let Some(sampler) = resolved.as_inline_sampler(options) {
7696 let name = &self.names[&NameKey::GlobalVariable(handle)];
7698 writeln!(
7699 self.out,
7700 "{}constexpr {}::sampler {}(",
7701 back::INDENT,
7702 NAMESPACE,
7703 name
7704 )?;
7705 self.put_inline_sampler_properties(back::Level(2), sampler)?;
7706 writeln!(self.out, "{});", back::INDENT)?;
7707 } else if let crate::TypeInner::Image {
7708 class: crate::ImageClass::External,
7709 ..
7710 } = module.types[var.ty].inner
7711 {
7712 let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7715 let l1 = back::Level(1);
7716 let l2 = l1.next();
7717 writeln!(
7718 self.out,
7719 "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7720 )?;
7721 for i in 0..3 {
7722 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7723 handle,
7724 ExternalTextureNameKey::Plane(i),
7725 )];
7726 writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7727 }
7728 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7729 handle,
7730 ExternalTextureNameKey::Params,
7731 )];
7732 writeln!(self.out, "{l2}.params = {params_name},")?;
7733 writeln!(self.out, "{l1}}};")?;
7734 }
7735 }
7736 }
7737
7738 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7749 let arg_name =
7750 &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7751 match module.types[arg.ty].inner {
7752 crate::TypeInner::Struct { ref members, .. } => {
7753 let struct_name = &self.names[&NameKey::Type(arg.ty)];
7754 write!(
7755 self.out,
7756 "{}const {} {} = {{ ",
7757 back::INDENT,
7758 struct_name,
7759 arg_name
7760 )?;
7761 for (member_index, member) in members.iter().enumerate() {
7762 let key = NameKey::StructMember(arg.ty, member_index as u32);
7763 let name = &flattened_member_names[&key];
7764 if member_index != 0 {
7765 write!(self.out, ", ")?;
7766 }
7767 if self
7769 .struct_member_pads
7770 .contains(&(arg.ty, member_index as u32))
7771 {
7772 write!(self.out, "{{}}, ")?;
7773 }
7774 if let Some(crate::Binding::Location { .. }) = member.binding {
7775 if has_varyings {
7776 write!(self.out, "{varyings_member_name}.")?;
7777 }
7778 }
7779 write!(self.out, "{name}")?;
7780 }
7781 writeln!(self.out, " }};")?;
7782 }
7783 _ => match arg.binding {
7784 Some(crate::Binding::Location { .. })
7785 | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
7786 if has_varyings {
7787 writeln!(
7788 self.out,
7789 "{}const auto {} = {}.{};",
7790 back::INDENT,
7791 arg_name,
7792 varyings_member_name,
7793 arg_name
7794 )?;
7795 }
7796 }
7797 _ => {}
7798 },
7799 }
7800 }
7801
7802 let guarded_indices =
7803 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
7804
7805 let context = StatementContext {
7806 expression: ExpressionContext {
7807 function: fun,
7808 origin: FunctionOrigin::EntryPoint(ep_index as _),
7809 info: fun_info,
7810 lang_version: options.lang_version,
7811 policies: options.bounds_check_policies,
7812 guarded_indices,
7813 module,
7814 mod_info,
7815 pipeline_options,
7816 force_loop_bounding: options.force_loop_bounding,
7817 },
7818 result_struct: Some(&stage_out_name),
7819 };
7820
7821 self.put_locals(&context.expression)?;
7824 self.update_expressions_to_bake(fun, fun_info, &context.expression);
7825 self.put_block(back::Level(1), &fun.body, &context)?;
7826 writeln!(self.out, "}}")?;
7827 if ep_index + 1 != module.entry_points.len() {
7828 writeln!(self.out)?;
7829 }
7830 self.named_expressions.clear();
7831 }
7832
7833 Ok(info)
7834 }
7835
7836 fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
7837 if flags.is_empty() {
7840 writeln!(
7841 self.out,
7842 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
7843 )?;
7844 }
7845 if flags.contains(crate::Barrier::STORAGE) {
7846 writeln!(
7847 self.out,
7848 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
7849 )?;
7850 }
7851 if flags.contains(crate::Barrier::WORK_GROUP) {
7852 writeln!(
7853 self.out,
7854 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7855 )?;
7856 }
7857 if flags.contains(crate::Barrier::SUB_GROUP) {
7858 writeln!(
7859 self.out,
7860 "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7861 )?;
7862 }
7863 if flags.contains(crate::Barrier::TEXTURE) {
7864 writeln!(
7865 self.out,
7866 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
7867 )?;
7868 }
7869 Ok(())
7870 }
7871}
7872
7873mod workgroup_mem_init {
7876 use crate::EntryPoint;
7877
7878 use super::*;
7879
7880 enum Access {
7881 GlobalVariable(Handle<crate::GlobalVariable>),
7882 StructMember(Handle<crate::Type>, u32),
7883 Array(usize),
7884 }
7885
7886 impl Access {
7887 fn write<W: Write>(
7888 &self,
7889 writer: &mut W,
7890 names: &FastHashMap<NameKey, String>,
7891 ) -> Result<(), core::fmt::Error> {
7892 match *self {
7893 Access::GlobalVariable(handle) => {
7894 write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
7895 }
7896 Access::StructMember(handle, index) => {
7897 write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
7898 }
7899 Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
7900 }
7901 }
7902 }
7903
7904 struct AccessStack {
7905 stack: Vec<Access>,
7906 array_depth: usize,
7907 }
7908
7909 impl AccessStack {
7910 const fn new() -> Self {
7911 Self {
7912 stack: Vec::new(),
7913 array_depth: 0,
7914 }
7915 }
7916
7917 fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
7918 let array_depth = self.array_depth;
7919 self.stack.push(Access::Array(array_depth));
7920 self.array_depth += 1;
7921 let res = cb(self, array_depth);
7922 self.stack.pop();
7923 self.array_depth -= 1;
7924 res
7925 }
7926
7927 fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
7928 self.stack.push(new);
7929 let res = cb(self);
7930 self.stack.pop();
7931 res
7932 }
7933
7934 fn write<W: Write>(
7935 &self,
7936 writer: &mut W,
7937 names: &FastHashMap<NameKey, String>,
7938 ) -> Result<(), core::fmt::Error> {
7939 for next in self.stack.iter() {
7940 next.write(writer, names)?;
7941 }
7942 Ok(())
7943 }
7944 }
7945
7946 impl<W: Write> Writer<W> {
7947 pub(super) fn need_workgroup_variables_initialization(
7948 &mut self,
7949 options: &Options,
7950 ep: &EntryPoint,
7951 module: &crate::Module,
7952 fun_info: &valid::FunctionInfo,
7953 ) -> bool {
7954 options.zero_initialize_workgroup_memory
7955 && ep.stage.compute_like()
7956 && module.global_variables.iter().any(|(handle, var)| {
7957 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7958 })
7959 }
7960
7961 pub(super) fn write_workgroup_variables_initialization(
7962 &mut self,
7963 module: &crate::Module,
7964 module_info: &valid::ModuleInfo,
7965 fun_info: &valid::FunctionInfo,
7966 local_invocation_id: Option<&NameKey>,
7967 ) -> BackendResult {
7968 let level = back::Level(1);
7969
7970 writeln!(
7971 self.out,
7972 "{}if ({}::all({} == {}::uint3(0u))) {{",
7973 level,
7974 NAMESPACE,
7975 local_invocation_id
7976 .map(|name_key| self.names[name_key].as_str())
7977 .unwrap_or("__local_invocation_id"),
7978 NAMESPACE,
7979 )?;
7980
7981 let mut access_stack = AccessStack::new();
7982
7983 let vars = module.global_variables.iter().filter(|&(handle, var)| {
7984 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7985 });
7986
7987 for (handle, var) in vars {
7988 access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
7989 self.write_workgroup_variable_initialization(
7990 module,
7991 module_info,
7992 var.ty,
7993 access_stack,
7994 level.next(),
7995 )
7996 })?;
7997 }
7998
7999 writeln!(self.out, "{level}}}")?;
8000 self.write_barrier(crate::Barrier::WORK_GROUP, level)
8001 }
8002
8003 fn write_workgroup_variable_initialization(
8004 &mut self,
8005 module: &crate::Module,
8006 module_info: &valid::ModuleInfo,
8007 ty: Handle<crate::Type>,
8008 access_stack: &mut AccessStack,
8009 level: back::Level,
8010 ) -> BackendResult {
8011 if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8012 write!(self.out, "{level}")?;
8013 access_stack.write(&mut self.out, &self.names)?;
8014 writeln!(self.out, " = {{}};")?;
8015 } else {
8016 match module.types[ty].inner {
8017 crate::TypeInner::Atomic { .. } => {
8018 write!(
8019 self.out,
8020 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8021 )?;
8022 access_stack.write(&mut self.out, &self.names)?;
8023 writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8024 }
8025 crate::TypeInner::Array { base, size, .. } => {
8026 let count = match size.resolve(module.to_ctx())? {
8027 proc::IndexableLength::Known(count) => count,
8028 proc::IndexableLength::Dynamic => unreachable!(),
8029 };
8030
8031 access_stack.enter_array(|access_stack, array_depth| {
8032 writeln!(
8033 self.out,
8034 "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8035 )?;
8036 self.write_workgroup_variable_initialization(
8037 module,
8038 module_info,
8039 base,
8040 access_stack,
8041 level.next(),
8042 )?;
8043 writeln!(self.out, "{level}}}")?;
8044 BackendResult::Ok(())
8045 })?;
8046 }
8047 crate::TypeInner::Struct { ref members, .. } => {
8048 for (index, member) in members.iter().enumerate() {
8049 access_stack.enter(
8050 Access::StructMember(ty, index as u32),
8051 |access_stack| {
8052 self.write_workgroup_variable_initialization(
8053 module,
8054 module_info,
8055 member.ty,
8056 access_stack,
8057 level,
8058 )
8059 },
8060 )?;
8061 }
8062 }
8063 _ => unreachable!(),
8064 }
8065 }
8066
8067 Ok(())
8068 }
8069 }
8070}
8071
8072impl crate::AtomicFunction {
8073 const fn to_msl(self) -> &'static str {
8074 match self {
8075 Self::Add => "fetch_add",
8076 Self::Subtract => "fetch_sub",
8077 Self::And => "fetch_and",
8078 Self::InclusiveOr => "fetch_or",
8079 Self::ExclusiveOr => "fetch_xor",
8080 Self::Min => "fetch_min",
8081 Self::Max => "fetch_max",
8082 Self::Exchange { compare: None } => "exchange",
8083 Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8084 }
8085 }
8086
8087 fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8088 Ok(match self {
8089 Self::Min => "min",
8090 Self::Max => "max",
8091 _ => Err(Error::FeatureNotImplemented(
8092 "64-bit atomic operation other than min/max".to_string(),
8093 ))?,
8094 })
8095 }
8096}