1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::{
8 cmp::Ordering,
9 fmt::{Display, Error as FmtError, Formatter, Write},
10 iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{
17 sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo, NAMESPACE,
18 WRAPPED_ARRAY_FIELD,
19};
20use crate::{
21 arena::{Handle, HandleSet},
22 back::{
23 self, get_entry_points,
24 msl::{mesh_shader::NestedFunctionInfo, BackendResult, EntryPointArgument},
25 Baked,
26 },
27 common,
28 proc::{
29 self, concrete_int_scalars,
30 index::{self, BoundsCheck},
31 ExternalTextureNameKey, NameKey, TypeResolution,
32 },
33 valid, FastHashMap, FastHashSet,
34};
35
36#[cfg(test)]
37use core::ptr;
38
// NOTE(review): not referenced in the visible part of this file; presumably
// the reference marker used when atomics are passed by reference — confirm.
const ATOMIC_REFERENCE: &str = "&";

/// MSL namespace that `instance_acceleration_structure` is taken from.
const RT_NAMESPACE: &str = "metal::raytracing";

/// Name emitted for `TypeInner::RayQuery` values.
const RAY_QUERY_TYPE: &str = "_RayQuery";

// Identifiers used by the ray-query support code (defined outside this chunk).
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

/// Feature switch for ray-query code generation; currently disabled.
const RAY_QUERY_MODERN_SUPPORT: bool = false;

// Names of helper functions that the writer generates in the output MSL.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";

/// Generic wrapper struct used as `constant Wrapper<T>*` for binding arrays.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Type name emitted for `ImageClass::External` textures.
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";

// Names of generated cooperative-matrix helper functions.
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
83
84fn put_numeric_type(
93 out: &mut impl Write,
94 scalar: crate::Scalar,
95 sizes: &[crate::VectorSize],
96) -> Result<(), FmtError> {
97 match (scalar, sizes) {
98 (scalar, &[]) => {
99 write!(out, "{}", scalar.to_msl_name())
100 }
101 (scalar, &[rows]) => {
102 write!(
103 out,
104 "{}::{}{}",
105 NAMESPACE,
106 scalar.to_msl_name(),
107 common::vector_size_str(rows)
108 )
109 }
110 (scalar, &[rows, columns]) => {
111 write!(
112 out,
113 "{}::{}{}x{}",
114 NAMESPACE,
115 scalar.to_msl_name(),
116 common::vector_size_str(columns),
117 common::vector_size_str(rows)
118 )
119 }
120 (_, _) => Ok(()), }
122}
123
124const fn scalar_is_int(scalar: crate::Scalar) -> bool {
125 use crate::ScalarKind::*;
126 match scalar.kind {
127 Sint | Uint | AbstractInt | Bool => true,
128 Float | AbstractFloat => false,
129 }
130}
131
/// Prefix of the identifier holding a clamped level of detail (see `ClampedLod`).
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
/// Prefix of identifiers for reinterpreted values (see `Reinterpreted`).
const REINTERPRET_PREFIX: &str = "reinterpreted_";
137
138struct ClampedLod(Handle<crate::Expression>);
144
145impl Display for ClampedLod {
146 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
147 self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
148 }
149}
150
151struct ArraySizeMember(Handle<crate::GlobalVariable>);
166
167impl Display for ArraySizeMember {
168 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
169 self.0.write_prefixed(f, "size")
170 }
171}
172
173#[derive(Clone, Copy)]
178struct Reinterpreted<'a> {
179 target_type: &'a str,
180 orig: Handle<crate::Expression>,
181}
182
183impl<'a> Reinterpreted<'a> {
184 const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
185 Self { target_type, orig }
186 }
187}
188
189impl Display for Reinterpreted<'_> {
190 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
191 f.write_str(REINTERPRET_PREFIX)?;
192 f.write_str(self.target_type)?;
193 self.orig.write_prefixed(f, "_e")
194 }
195}
196
197pub(super) struct TypeContext<'a> {
198 pub handle: Handle<crate::Type>,
199 pub gctx: proc::GlobalCtx<'a>,
200 pub names: &'a FastHashMap<NameKey, String>,
201 pub access: crate::StorageAccess,
202 pub first_time: bool,
203}
204
205impl TypeContext<'_> {
206 fn scalar(&self) -> Option<crate::Scalar> {
207 let ty = &self.gctx.types[self.handle];
208 ty.inner.scalar()
209 }
210
211 fn vector_size(&self) -> Option<crate::VectorSize> {
212 let ty = &self.gctx.types[self.handle];
213 match ty.inner {
214 crate::TypeInner::Vector { size, .. } => Some(size),
215 _ => None,
216 }
217 }
218
219 fn unwrap_array(self) -> Self {
220 match self.gctx.types[self.handle].inner {
221 crate::TypeInner::Array { base, .. } => Self {
222 handle: base,
223 ..self
224 },
225 _ => self,
226 }
227 }
228}
229
/// Renders the MSL spelling of the type `self.handle`.
///
/// Types for which `needs_alias()` is true are printed by their alias name
/// from `self.names` unless `self.first_time` is set. All other types are
/// spelled out structurally. Element, pointee and base types are rendered
/// recursively with `first_time: false`.
impl Display for TypeContext<'_> {
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            crate::TypeInner::CooperativeMatrix {
                columns,
                rows,
                scalar,
                role: _,
            } => {
                // `{NAMESPACE}::simdgroup_<scalar><columns>x<rows>`; the role is not encoded.
                write!(
                    out,
                    "{NAMESPACE}::simdgroup_{}{}x{}",
                    scalar.to_msl_name(),
                    columns as u32,
                    rows as u32,
                )
            }
            crate::TypeInner::Pointer { base, space } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // Pointers into spaces without an MSL name (i.e. `Handle`) print nothing.
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                write!(out, "{space_name} {sub}&")
            }
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            crate::TypeInner::Array { base, .. } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // Only the element type is printed here; the array length is
                // not part of this output.
                write!(out, "{sub}")
            }
            // `needs_alias()` is always true for structs, so this arm is only
            // reachable with `first_time` set.
            // NOTE(review): presumably struct definitions are written by a
            // different code path — confirm.
            crate::TypeInner::Struct { .. } => unreachable!(),
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                // Pick texture family, `_ms` suffix, component scalar and access qualifier.
                let (texture_str, msaa_str, scalar, access) = match class {
                    crate::ImageClass::Sampled { kind, multi } => {
                        // Multisampled textures are declared `read`, others `sample`.
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Storage { format, .. } => {
                        // Storage access comes from `self.access`, not from the image type.
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    crate::ImageClass::External => {
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RAY_QUERY_TYPE}")
            }
            crate::TypeInner::BindingArray { base, .. } => {
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };

                // Binding arrays become a pointer to the generic wrapper struct.
                write!(
                    out,
                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                )
            }
        }
    }
}
398
/// A global variable plus the context needed to print its declaration:
/// `[coherent ][space ]<type>[ const][&] <name>`.
pub(super) struct TypedGlobalVariable<'a> {
    pub module: &'a crate::Module,
    pub names: &'a FastHashMap<NameKey, String>,
    pub handle: Handle<crate::GlobalVariable>,
    // How the global is used; only `WRITE` is consulted, to decide on `const`.
    pub usage: valid::GlobalUse,
    // When set, the address space and a trailing `&` are printed. Workgroup
    // globals always get their address space, and `&` only when this is set.
    pub reference: bool,
}

/// Result of `TypedGlobalVariable::to_parts`: the full qualified type text
/// and the variable name.
struct TypedGlobalVariableParts {
    ty_name: String,
    var_name: String,
}

impl TypedGlobalVariable<'_> {
    /// Builds the qualified type string and the name for this global.
    fn to_parts(&self) -> Result<TypedGlobalVariableParts, Error> {
        let var = &self.module.global_variables[self.handle];
        let name = &self.names[&NameKey::GlobalVariable(self.handle)];

        // Storage access comes from the address space for storage buffers, or
        // from the image class for storage textures (possibly inside a
        // binding array). Everything else uses the default.
        let storage_access = match var.space {
            crate::AddressSpace::Storage { access } => access,
            _ => match self.module.types[var.ty].inner {
                crate::TypeInner::Image {
                    class: crate::ImageClass::Storage { access, .. },
                    ..
                } => access,
                crate::TypeInner::BindingArray { base, .. } => {
                    match self.module.types[base].inner {
                        crate::TypeInner::Image {
                            class: crate::ImageClass::Storage { access, .. },
                            ..
                        } => access,
                        _ => crate::StorageAccess::default(),
                    }
                }
                _ => crate::StorageAccess::default(),
            },
        };
        let ty_name = TypeContext {
            handle: var.ty,
            gctx: self.module.to_ctx(),
            names: self.names,
            access: storage_access,
            first_time: false,
        };

        // `const` only for spaces that take an access qualifier and are never written.
        let access = if var.space.needs_access_qualifier()
            && !self.usage.intersects(valid::GlobalUse::WRITE)
        {
            "const"
        } else {
            ""
        };
        let (coherent, space, access, reference) = match (var.space.to_msl_name(), var.space) {
            // Workgroup globals keep their address space even when not printed as references.
            (Some(space), crate::AddressSpace::WorkGroup) => {
                ("", space, access, if self.reference { "&" } else { "" })
            }
            (Some(space), _) if self.reference => {
                let coherent = if var
                    .memory_decorations
                    .contains(crate::MemoryDecorations::COHERENT)
                {
                    "coherent "
                } else {
                    ""
                };
                (coherent, space, access, "&")
            }
            // No qualifiers at all otherwise (including `const`).
            _ => ("", "", "", ""),
        };

        let ty = format!(
            "{coherent}{space}{}{ty_name}{}{access}{reference}",
            if space.is_empty() { "" } else { " " },
            if access.is_empty() { "" } else { " " },
        );

        Ok(TypedGlobalVariableParts {
            ty_name: ty,
            var_name: name.clone(),
        })
    }
    /// Writes `<qualified type> <name>` to `out`.
    pub(super) fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
        let parts = self.to_parts()?;

        Ok(write!(out, "{} {}", parts.ty_name, parts.var_name)?)
    }
}
486
/// Key describing a generated helper function.
///
/// Values are stored in `Writer::wrapped_functions`; presumably this lets the
/// writer emit each helper only once per distinct key — confirm at use sites.
#[derive(Eq, PartialEq, Hash)]
pub(super) enum WrappedFunction {
    /// Helper for a unary operator on the given (vector size, scalar) type.
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a binary operator with the given operand types.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a math function with the given argument type.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper converting `src_scalar` (optionally a vector) to `dst_scalar`.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Image-load helper for the given image class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Image-sample helper for the given image class.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Image-size-query helper for the given image class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
    /// Cooperative-matrix load helper.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// Cooperative-matrix multiply-add helper.
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
}
531
/// MSL backend writer; emits shader source into `out`.
pub struct Writer<W> {
    /// Destination for the generated MSL text.
    pub(super) out: W,
    /// Names assigned to module items.
    pub(super) names: FastHashMap<NameKey, String>,
    pub(super) named_expressions: crate::NamedExpressions,
    pub(super) need_bake_expressions: back::NeedBakeExpressions,
    /// Generates fresh identifiers (e.g. `loop_bound`, `oob`).
    pub(super) namer: proc::Namer,
    /// Helper functions already generated (see `WrappedFunction`).
    pub(super) wrapped_functions: FastHashSet<WrappedFunction>,
    // Test-only bookkeeping; presumably used to measure recursion stack usage
    // of `put_expression` / `put_block` — confirm in the tests.
    #[cfg(test)]
    pub(super) put_expression_stack_pointers: FastHashSet<*const ()>,
    #[cfg(test)]
    pub(super) put_block_stack_pointers: FastHashSet<*const ()>,
    // (struct type, u32) pairs; presumably struct members that need padding
    // emitted — TODO confirm at use sites.
    pub(super) struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
    // Starts as `false`; set elsewhere (not in this chunk).
    pub(super) needs_object_memory_barriers: bool,
}
549
550impl crate::Scalar {
551 pub(super) fn to_msl_name(self) -> &'static str {
552 use crate::ScalarKind as Sk;
553 match self {
554 Self {
555 kind: Sk::Float,
556 width: 4,
557 } => "float",
558 Self {
559 kind: Sk::Float,
560 width: 2,
561 } => "half",
562 Self {
563 kind: Sk::Sint,
564 width: 4,
565 } => "int",
566 Self {
567 kind: Sk::Uint,
568 width: 4,
569 } => "uint",
570 Self {
571 kind: Sk::Sint,
572 width: 8,
573 } => "long",
574 Self {
575 kind: Sk::Uint,
576 width: 8,
577 } => "ulong",
578 Self {
579 kind: Sk::Bool,
580 width: _,
581 } => "bool",
582 Self {
583 kind: Sk::AbstractInt | Sk::AbstractFloat,
584 width: _,
585 } => unreachable!("Found Abstract scalar kind"),
586 _ => unreachable!("Unsupported scalar kind: {:?}", self),
587 }
588 }
589}
590
/// Returns the list separator `","` when one is needed, otherwise `""`.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
598
599fn should_pack_struct_member(
600 members: &[crate::StructMember],
601 span: u32,
602 index: usize,
603 module: &crate::Module,
604) -> Option<crate::Scalar> {
605 let member = &members[index];
606
607 let ty_inner = &module.types[member.ty].inner;
608 let last_offset = member.offset + ty_inner.size(module.to_ctx());
609 let next_offset = match members.get(index + 1) {
610 Some(next) => next.offset,
611 None => span,
612 };
613 let is_tight = next_offset == last_offset;
614
615 match *ty_inner {
616 crate::TypeInner::Vector {
617 size: crate::VectorSize::Tri,
618 scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
619 } if is_tight => Some(scalar),
620 _ => None,
621 }
622}
623
624fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
625 match arena[ty].inner {
626 crate::TypeInner::Struct { ref members, .. } => {
627 if let Some(member) = members.last() {
628 if let crate::TypeInner::Array {
629 size: crate::ArraySize::Dynamic,
630 ..
631 } = arena[member.ty].inner
632 {
633 return true;
634 }
635 }
636 false
637 }
638 crate::TypeInner::Array {
639 size: crate::ArraySize::Dynamic,
640 ..
641 } => true,
642 _ => false,
643 }
644}
645
646impl crate::AddressSpace {
647 const fn needs_pass_through(&self) -> bool {
651 match *self {
652 Self::Uniform
653 | Self::Storage { .. }
654 | Self::Private
655 | Self::WorkGroup
656 | Self::Immediate
657 | Self::Handle
658 | Self::TaskPayload => true,
659 Self::Function => false,
660 Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
661 }
662 }
663
664 const fn needs_access_qualifier(&self) -> bool {
666 match *self {
667 Self::Storage { .. } => true,
672 Self::TaskPayload => true,
673 Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
674 Self::Private | Self::WorkGroup => false,
676 Self::Uniform | Self::Immediate => false,
678 Self::Handle | Self::Function => false,
680 }
681 }
682
683 const fn to_msl_name(self) -> Option<&'static str> {
684 match self {
685 Self::Handle => None,
686 Self::Uniform | Self::Immediate => Some("constant"),
687 Self::Storage { .. } => Some("device"),
688 Self::Private | Self::Function | Self::RayPayload => Some("thread"),
692 Self::WorkGroup => Some("threadgroup"),
693 Self::TaskPayload => Some("object_data"),
694 Self::IncomingRayPayload => Some("ray_data"),
695 }
696 }
697}
698
699impl crate::Type {
700 const fn needs_alias(&self) -> bool {
702 use crate::TypeInner as Ti;
703
704 match self.inner {
705 Ti::Scalar(_)
707 | Ti::Vector { .. }
708 | Ti::Matrix { .. }
709 | Ti::CooperativeMatrix { .. }
710 | Ti::Atomic(_)
711 | Ti::Pointer { .. }
712 | Ti::ValuePointer { .. } => self.name.is_some(),
713 Ti::Struct { .. } | Ti::Array { .. } => true,
715 Ti::Image { .. }
717 | Ti::Sampler { .. }
718 | Ti::AccelerationStructure { .. }
719 | Ti::RayQuery { .. }
720 | Ti::BindingArray { .. } => false,
721 }
722 }
723}
724
/// Identifies the function being written: a regular function or an entry
/// point. Used to build the right `NameKey` for locals (see `NameKeyExt`).
#[derive(Clone, Copy)]
enum FunctionOrigin {
    Handle(Handle<crate::Function>),
    EntryPoint(proc::EntryPointIndex),
}
730
731trait NameKeyExt {
732 fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
733 match origin {
734 FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
735 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
736 }
737 }
738
739 fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
744 match origin {
745 FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
746 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
747 }
748 }
749}
750
751impl NameKeyExt for NameKey {}
752
/// How the level-of-detail argument of an image access is written.
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Write this expression as-is.
    Direct(Handle<crate::Expression>),
    /// Write the clamped-LOD local for this load expression (`ClampedLod`).
    Restricted(Handle<crate::Expression>),
}

/// Components of a texel address used by image loads and their bounds checks.
struct TexelAddress {
    /// Coordinate expression (scalar or vector).
    coordinate: Handle<crate::Expression>,
    /// Array layer, for arrayed images.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index, for multisampled images.
    sample: Option<Handle<crate::Expression>>,
    /// Mip level, if one is given.
    level: Option<LevelOfDetail>,
}
783
/// Per-function state needed while writing expressions.
pub(super) struct ExpressionContext<'a> {
    function: &'a crate::Function,
    /// Which function/entry point `function` is; used to build local name keys.
    origin: FunctionOrigin,
    info: &'a valid::FunctionInfo,
    module: &'a crate::Module,
    mod_info: &'a valid::ModuleInfo,
    pipeline_options: &'a PipelineOptions,
    /// Target MSL version as (major, minor).
    lang_version: (u8, u8),
    /// Bounds-check policies (e.g. `image_load`) applied when writing accesses.
    policies: index::BoundsCheckPolicies,

    // NOTE(review): presumably index expressions already guarded by a bounds
    // check — confirm where this set is filled.
    guarded_indices: HandleSet<crate::Expression>,
    /// When set, loops get an extra iteration counter that forces termination
    /// (see `Writer::gen_force_bounded_loop_statements`).
    force_loop_bounding: bool,
}
801
802impl<'a> ExpressionContext<'a> {
803 fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
804 self.info[handle].ty.inner_with(&self.module.types)
805 }
806
807 fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
814 let image_ty = self.resolve_type(image);
815 if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
816 class.is_mipmapped() && dim != crate::ImageDimension::D1
817 } else {
818 false
819 }
820 }
821
822 fn choose_bounds_check_policy(
823 &self,
824 pointer: Handle<crate::Expression>,
825 ) -> index::BoundsCheckPolicy {
826 self.policies
827 .choose_policy(pointer, &self.module.types, self.info)
828 }
829
830 fn access_needs_check(
832 &self,
833 base: Handle<crate::Expression>,
834 index: index::GuardedIndex,
835 ) -> Option<index::IndexableLength> {
836 index::access_needs_check(
837 base,
838 index,
839 self.module,
840 &self.function.expressions,
841 self.info,
842 )
843 }
844
845 fn bounds_check_iter(
847 &self,
848 chain: Handle<crate::Expression>,
849 ) -> impl Iterator<Item = BoundsCheck> + '_ {
850 index::bounds_check_iter(chain, self.module, self.function, self.info)
851 }
852
853 fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
855 index::oob_local_types(self.module, self.function, self.info, self.policies)
856 }
857
858 fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
859 match self.function.expressions[expr_handle] {
860 crate::Expression::AccessIndex { base, index } => {
861 let ty = match *self.resolve_type(base) {
862 crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
863 ref ty => ty,
864 };
865 match *ty {
866 crate::TypeInner::Struct {
867 ref members, span, ..
868 } => should_pack_struct_member(members, span, index as usize, self.module),
869 _ => None,
870 }
871 }
872 _ => None,
873 }
874 }
875}
876
/// Per-function state needed while writing statements.
struct StatementContext<'a> {
    /// Expression-level context for the same function.
    expression: ExpressionContext<'a>,
    // NOTE(review): presumably the name of the struct an entry point's result
    // is wrapped in, if any — confirm at use sites.
    result_struct: Option<&'a str>,
}
881
882impl<W: Write> Writer<W> {
883 pub fn new(out: W) -> Self {
885 Writer {
886 out,
887 names: FastHashMap::default(),
888 named_expressions: Default::default(),
889 need_bake_expressions: Default::default(),
890 namer: proc::Namer::default(),
891 wrapped_functions: FastHashSet::default(),
892 #[cfg(test)]
893 put_expression_stack_pointers: Default::default(),
894 #[cfg(test)]
895 put_block_stack_pointers: Default::default(),
896 struct_member_pads: FastHashSet::default(),
897 needs_object_memory_barriers: false,
898 }
899 }
900
901 pub fn finish(self) -> W {
904 self.out
905 }
906
907 fn gen_force_bounded_loop_statements(
1014 &mut self,
1015 level: back::Level,
1016 context: &StatementContext,
1017 ) -> Option<(String, String)> {
1018 if !context.expression.force_loop_bounding {
1019 return None;
1020 }
1021
1022 let loop_bound_name = self.namer.call("loop_bound");
1023 let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
1026 let level = level.next();
1027 let break_and_inc = format!(
1028 "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
1029{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
1030 );
1031
1032 Some((decl, break_and_inc))
1033 }
1034
1035 fn put_call_parameters(
1036 &mut self,
1037 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1038 context: &ExpressionContext,
1039 ) -> BackendResult {
1040 self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1041 writer.put_expression(expr, context, true)
1042 })
1043 }
1044
1045 fn put_call_parameters_impl<C, E>(
1046 &mut self,
1047 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1048 ctx: &C,
1049 put_expression: E,
1050 ) -> BackendResult
1051 where
1052 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1053 {
1054 write!(self.out, "(")?;
1055 for (i, handle) in parameters.enumerate() {
1056 if i != 0 {
1057 write!(self.out, ", ")?;
1058 }
1059 put_expression(self, ctx, handle)?;
1060 }
1061 write!(self.out, ")")?;
1062 Ok(())
1063 }
1064
    /// Declares all locals of the current function, one per line.
    ///
    /// Declares the function's local variables followed by one "oob" local per
    /// type from `context.oob_local_types()`. Each is initialized from its
    /// `init` expression when present, otherwise with `{}`.
    fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
        // Reserve a fresh `oob` name for every type that needs one, before any
        // declaration is written.
        // NOTE(review): presumably these back out-of-bounds accesses under a
        // bounds-check policy — confirm in `proc::index`.
        let oob_local_types = context.oob_local_types();
        for &ty in oob_local_types.iter() {
            let name_key = NameKey::oob_local_for_type(context.origin, ty);
            self.names.insert(name_key, self.namer.call("oob"));
        }

        // Regular locals first, then the oob locals (which never have an initializer).
        for (name_key, ty, init) in context
            .function
            .local_variables
            .iter()
            .map(|(local_handle, local)| {
                let name_key = NameKey::local(context.origin, local_handle);
                (name_key, local.ty, local.init)
            })
            .chain(oob_local_types.iter().map(|&ty| {
                let name_key = NameKey::oob_local_for_type(context.origin, ty);
                (name_key, ty, None)
            }))
        {
            let ty_name = TypeContext {
                handle: ty,
                gctx: context.module.to_ctx(),
                names: &self.names,
                access: crate::StorageAccess::empty(),
                first_time: false,
            };
            write!(
                self.out,
                "{}{} {}",
                back::INDENT,
                ty_name,
                self.names[&name_key]
            )?;
            match init {
                Some(value) => {
                    write!(self.out, " = ")?;
                    self.put_expression(value, context, true)?;
                }
                None => {
                    // Uninitialized locals are value-initialized with `{}`.
                    write!(self.out, " = {{}}")?;
                }
            };
            writeln!(self.out, ";")?;
        }
        Ok(())
    }
1117
1118 fn put_level_of_detail(
1119 &mut self,
1120 level: LevelOfDetail,
1121 context: &ExpressionContext,
1122 ) -> BackendResult {
1123 match level {
1124 LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1125 LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1126 }
1127 Ok(())
1128 }
1129
1130 fn put_image_query(
1131 &mut self,
1132 image: Handle<crate::Expression>,
1133 query: &str,
1134 level: Option<LevelOfDetail>,
1135 context: &ExpressionContext,
1136 ) -> BackendResult {
1137 self.put_expression(image, context, false)?;
1138 write!(self.out, ".get_{query}(")?;
1139 if let Some(level) = level {
1140 self.put_level_of_detail(level, context)?;
1141 }
1142 write!(self.out, ")")?;
1143 Ok(())
1144 }
1145
    /// Writes an expression for the size of `image`, with components of scalar
    /// kind `kind` (32-bit).
    ///
    /// External textures go through `IMAGE_SIZE_EXTERNAL_FUNCTION`. Other
    /// images query `width`/`height`/`depth` according to their dimension,
    /// passing `level` except for 1D images.
    fn put_image_size_query(
        &mut self,
        image: Handle<crate::Expression>,
        level: Option<LevelOfDetail>,
        kind: crate::ScalarKind,
        context: &ExpressionContext,
    ) -> BackendResult {
        if let crate::TypeInner::Image {
            class: crate::ImageClass::External,
            ..
        } = *context.resolve_type(image)
        {
            write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
            self.put_expression(image, context, true)?;
            write!(self.out, ")")?;
            return Ok(());
        }

        let dim = match *context.resolve_type(image) {
            crate::TypeInner::Image { dim, .. } => dim,
            ref other => unreachable!("Unexpected type {:?}", other),
        };
        let scalar = crate::Scalar { kind, width: 4 };
        let coordinate_type = scalar.to_msl_name();
        match dim {
            crate::ImageDimension::D1 => {
                // 1D images are queried without a level argument.
                if kind == crate::ScalarKind::Uint {
                    // Result is a scalar; written as-is, with no wrapping cast.
                    self.put_image_query(image, "width", None, context)?;
                } else {
                    // Non-uint results are wrapped in a plain `int(...)` cast.
                    write!(self.out, "int(")?;
                    self.put_image_query(image, "width", None, context)?;
                    write!(self.out, ")")?;
                }
            }
            crate::ImageDimension::D2 => {
                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
                self.put_image_query(image, "width", level, context)?;
                write!(self.out, ", ")?;
                self.put_image_query(image, "height", level, context)?;
                write!(self.out, ")")?;
            }
            crate::ImageDimension::D3 => {
                write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
                self.put_image_query(image, "width", level, context)?;
                write!(self.out, ", ")?;
                self.put_image_query(image, "height", level, context)?;
                write!(self.out, ", ")?;
                self.put_image_query(image, "depth", level, context)?;
                write!(self.out, ")")?;
            }
            crate::ImageDimension::Cube => {
                // Only the width is queried; the 2-component constructor gets a
                // single argument for both components.
                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
                self.put_image_query(image, "width", level, context)?;
                write!(self.out, ")")?;
            }
        }
        Ok(())
    }
1211
1212 fn put_cast_to_uint_scalar_or_vector(
1213 &mut self,
1214 expr: Handle<crate::Expression>,
1215 context: &ExpressionContext,
1216 ) -> BackendResult {
1217 match *context.resolve_type(expr) {
1219 crate::TypeInner::Scalar(_) => {
1220 put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1221 }
1222 crate::TypeInner::Vector { size, .. } => {
1223 put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1224 }
1225 _ => {
1226 return Err(Error::GenericValidation(
1227 "Invalid type for image coordinate".into(),
1228 ))
1229 }
1230 };
1231
1232 write!(self.out, "(")?;
1233 self.put_expression(expr, context, true)?;
1234 write!(self.out, ")")?;
1235 Ok(())
1236 }
1237
1238 fn put_image_sample_level(
1239 &mut self,
1240 image: Handle<crate::Expression>,
1241 level: crate::SampleLevel,
1242 context: &ExpressionContext,
1243 ) -> BackendResult {
1244 let has_levels = context.image_needs_lod(image);
1245 match level {
1246 crate::SampleLevel::Auto => {}
1247 crate::SampleLevel::Zero => {
1248 }
1250 _ if !has_levels => {
1251 log::warn!("1D image can't be sampled with level {level:?}");
1252 }
1253 crate::SampleLevel::Exact(h) => {
1254 write!(self.out, ", {NAMESPACE}::level(")?;
1255 self.put_expression(h, context, true)?;
1256 write!(self.out, ")")?;
1257 }
1258 crate::SampleLevel::Bias(h) => {
1259 write!(self.out, ", {NAMESPACE}::bias(")?;
1260 self.put_expression(h, context, true)?;
1261 write!(self.out, ")")?;
1262 }
1263 crate::SampleLevel::Gradient { x, y } => {
1264 write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1265 self.put_expression(x, context, true)?;
1266 write!(self.out, ", ")?;
1267 self.put_expression(y, context, true)?;
1268 write!(self.out, ")")?;
1269 }
1270 }
1271 Ok(())
1272 }
1273
1274 fn put_image_coordinate_limits(
1275 &mut self,
1276 image: Handle<crate::Expression>,
1277 level: Option<LevelOfDetail>,
1278 context: &ExpressionContext,
1279 ) -> BackendResult {
1280 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1281 write!(self.out, " - 1")?;
1282 Ok(())
1283 }
1284
1285 fn put_restricted_scalar_image_index(
1303 &mut self,
1304 image: Handle<crate::Expression>,
1305 index: Handle<crate::Expression>,
1306 limit_method: &str,
1307 context: &ExpressionContext,
1308 ) -> BackendResult {
1309 write!(self.out, "{NAMESPACE}::min(uint(")?;
1310 self.put_expression(index, context, true)?;
1311 write!(self.out, "), ")?;
1312 self.put_expression(image, context, false)?;
1313 write!(self.out, ".{limit_method}() - 1)")?;
1314 Ok(())
1315 }
1316
1317 fn put_restricted_texel_address(
1318 &mut self,
1319 image: Handle<crate::Expression>,
1320 address: &TexelAddress,
1321 context: &ExpressionContext,
1322 ) -> BackendResult {
1323 write!(self.out, "{NAMESPACE}::min(")?;
1325 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1326 write!(self.out, ", ")?;
1327 self.put_image_coordinate_limits(image, address.level, context)?;
1328 write!(self.out, ")")?;
1329
1330 if let Some(array_index) = address.array_index {
1332 write!(self.out, ", ")?;
1333 self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1334 }
1335
1336 if let Some(sample) = address.sample {
1338 write!(self.out, ", ")?;
1339 self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1340 }
1341
1342 if let Some(level) = address.level {
1345 write!(self.out, ", ")?;
1346 self.put_level_of_detail(level, context)?;
1347 }
1348
1349 Ok(())
1350 }
1351
1352 fn put_image_access_bounds_check(
1354 &mut self,
1355 image: Handle<crate::Expression>,
1356 address: &TexelAddress,
1357 context: &ExpressionContext,
1358 ) -> BackendResult {
1359 let mut conjunction = "";
1360
1361 let level = if let Some(level) = address.level {
1364 write!(self.out, "uint(")?;
1365 self.put_level_of_detail(level, context)?;
1366 write!(self.out, ") < ")?;
1367 self.put_expression(image, context, true)?;
1368 write!(self.out, ".get_num_mip_levels()")?;
1369 conjunction = " && ";
1370 Some(level)
1371 } else {
1372 None
1373 };
1374
1375 if let Some(sample) = address.sample {
1377 write!(self.out, "uint(")?;
1378 self.put_expression(sample, context, true)?;
1379 write!(self.out, ") < ")?;
1380 self.put_expression(image, context, true)?;
1381 write!(self.out, ".get_num_samples()")?;
1382 conjunction = " && ";
1383 }
1384
1385 if let Some(array_index) = address.array_index {
1387 write!(self.out, "{conjunction}uint(")?;
1388 self.put_expression(array_index, context, true)?;
1389 write!(self.out, ") < ")?;
1390 self.put_expression(image, context, true)?;
1391 write!(self.out, ".get_array_size()")?;
1392 conjunction = " && ";
1393 }
1394
1395 let coord_is_vector = match *context.resolve_type(address.coordinate) {
1397 crate::TypeInner::Vector { .. } => true,
1398 _ => false,
1399 };
1400 write!(self.out, "{conjunction}")?;
1401 if coord_is_vector {
1402 write!(self.out, "{NAMESPACE}::all(")?;
1403 }
1404 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1405 write!(self.out, " < ")?;
1406 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1407 if coord_is_vector {
1408 write!(self.out, ")")?;
1409 }
1410
1411 Ok(())
1412 }
1413
1414 fn put_image_load(
1415 &mut self,
1416 load: Handle<crate::Expression>,
1417 image: Handle<crate::Expression>,
1418 mut address: TexelAddress,
1419 context: &ExpressionContext,
1420 ) -> BackendResult {
1421 if let crate::TypeInner::Image {
1422 class: crate::ImageClass::External,
1423 ..
1424 } = *context.resolve_type(image)
1425 {
1426 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1427 self.put_expression(image, context, true)?;
1428 write!(self.out, ", ")?;
1429 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1430 write!(self.out, ")")?;
1431 return Ok(());
1432 }
1433
1434 match context.policies.image_load {
1435 proc::BoundsCheckPolicy::Restrict => {
1436 if address.level.is_some() {
1439 address.level = if context.image_needs_lod(image) {
1440 Some(LevelOfDetail::Restricted(load))
1441 } else {
1442 None
1443 }
1444 }
1445
1446 self.put_expression(image, context, false)?;
1447 write!(self.out, ".read(")?;
1448 self.put_restricted_texel_address(image, &address, context)?;
1449 write!(self.out, ")")?;
1450 }
1451 proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1452 write!(self.out, "(")?;
1453 self.put_image_access_bounds_check(image, &address, context)?;
1454 write!(self.out, " ? ")?;
1455 self.put_unchecked_image_load(image, &address, context)?;
1456 write!(self.out, ": DefaultConstructible())")?;
1457 }
1458 proc::BoundsCheckPolicy::Unchecked => {
1459 self.put_unchecked_image_load(image, &address, context)?;
1460 }
1461 }
1462
1463 Ok(())
1464 }
1465
1466 fn put_unchecked_image_load(
1467 &mut self,
1468 image: Handle<crate::Expression>,
1469 address: &TexelAddress,
1470 context: &ExpressionContext,
1471 ) -> BackendResult {
1472 self.put_expression(image, context, false)?;
1473 write!(self.out, ".read(")?;
1474 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1476 if let Some(expr) = address.array_index {
1477 write!(self.out, ", ")?;
1478 self.put_expression(expr, context, true)?;
1479 }
1480 if let Some(sample) = address.sample {
1481 write!(self.out, ", ")?;
1482 self.put_expression(sample, context, true)?;
1483 }
1484 if let Some(level) = address.level {
1485 if context.image_needs_lod(image) {
1486 write!(self.out, ", ")?;
1487 self.put_level_of_detail(level, context)?;
1488 }
1489 }
1490 write!(self.out, ")")?;
1491
1492 Ok(())
1493 }
1494
1495 fn put_image_atomic(
1496 &mut self,
1497 level: back::Level,
1498 image: Handle<crate::Expression>,
1499 address: &TexelAddress,
1500 fun: crate::AtomicFunction,
1501 value: Handle<crate::Expression>,
1502 context: &StatementContext,
1503 ) -> BackendResult {
1504 write!(self.out, "{level}")?;
1505 self.put_expression(image, &context.expression, false)?;
1506 let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1507 fun.to_msl_64_bit()?
1508 } else {
1509 fun.to_msl()
1510 };
1511 write!(self.out, ".atomic_{op}(")?;
1512 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1514 write!(self.out, ", ")?;
1515 self.put_expression(value, &context.expression, true)?;
1516 writeln!(self.out, ");")?;
1517
1518 let value_ty = context.expression.resolve_type(value);
1524 let zero_value = match (value_ty.scalar_kind(), value_ty.scalar_width()) {
1525 (Some(crate::ScalarKind::Sint), _) => "int4(0)",
1526 (_, Some(8)) => "ulong4(0uL)",
1527 _ => "uint4(0u)",
1528 };
1529 let coord_ty = context.expression.resolve_type(address.coordinate);
1530 let x = if matches!(coord_ty, crate::TypeInner::Scalar(_)) {
1531 ""
1532 } else {
1533 ".x"
1534 };
1535 write!(self.out, "{level}if (")?;
1536 self.put_expression(address.coordinate, &context.expression, true)?;
1537 write!(self.out, "{x} == -99999) {{ ")?;
1538 self.put_expression(image, &context.expression, false)?;
1539 write!(self.out, ".write({zero_value}, ")?;
1540 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1541 if let Some(array_index) = address.array_index {
1542 write!(self.out, ", ")?;
1543 self.put_expression(array_index, &context.expression, true)?;
1544 }
1545 writeln!(self.out, "); }}")?;
1546
1547 Ok(())
1548 }
1549
1550 fn put_image_store(
1551 &mut self,
1552 level: back::Level,
1553 image: Handle<crate::Expression>,
1554 address: &TexelAddress,
1555 value: Handle<crate::Expression>,
1556 context: &StatementContext,
1557 ) -> BackendResult {
1558 write!(self.out, "{level}")?;
1559 self.put_expression(image, &context.expression, false)?;
1560 write!(self.out, ".write(")?;
1561 self.put_expression(value, &context.expression, true)?;
1562 write!(self.out, ", ")?;
1563 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1565 if let Some(expr) = address.array_index {
1566 write!(self.out, ", ")?;
1567 self.put_expression(expr, &context.expression, true)?;
1568 }
1569 writeln!(self.out, ");")?;
1570
1571 Ok(())
1572 }
1573
1574 fn put_dynamic_array_max_index(
1585 &mut self,
1586 handle: Handle<crate::GlobalVariable>,
1587 context: &ExpressionContext,
1588 ) -> BackendResult {
1589 let global = &context.module.global_variables[handle];
1590 let (offset, array_ty) = match context.module.types[global.ty].inner {
1591 crate::TypeInner::Struct { ref members, .. } => match members.last() {
1592 Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1593 None => return Err(Error::GenericValidation("Struct has no members".into())),
1594 },
1595 crate::TypeInner::Array {
1596 size: crate::ArraySize::Dynamic,
1597 ..
1598 } => (0, global.ty),
1599 ref ty => {
1600 return Err(Error::GenericValidation(format!(
1601 "Expected type with dynamic array, got {ty:?}"
1602 )))
1603 }
1604 };
1605
1606 let (size, stride) = match context.module.types[array_ty].inner {
1607 crate::TypeInner::Array { base, stride, .. } => (
1608 context.module.types[base]
1609 .inner
1610 .size(context.module.to_ctx()),
1611 stride,
1612 ),
1613 ref ty => {
1614 return Err(Error::GenericValidation(format!(
1615 "Expected array type, got {ty:?}"
1616 )))
1617 }
1618 };
1619
1620 write!(
1633 self.out,
1634 "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1635 member = ArraySizeMember(handle),
1636 offset = offset,
1637 size = size,
1638 stride = stride,
1639 )?;
1640 Ok(())
1641 }
1642
1643 fn put_dot_product<T: Copy>(
1648 &mut self,
1649 arg: T,
1650 arg1: T,
1651 size: usize,
1652 extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1653 ) -> BackendResult {
1654 write!(self.out, "(")?;
1657
1658 for index in 0..size {
1660 write!(self.out, " + ")?;
1663 extractor(self, arg, index)?;
1664 write!(self.out, " * ")?;
1665 extractor(self, arg1, index)?;
1666 }
1667
1668 write!(self.out, ")")?;
1669 Ok(())
1670 }
1671
1672 fn put_pack4x8(
1674 &mut self,
1675 arg: Handle<crate::Expression>,
1676 context: &ExpressionContext<'_>,
1677 was_signed: bool,
1678 clamp_bounds: Option<(&str, &str)>,
1679 ) -> Result<(), Error> {
1680 let write_arg = |this: &mut Self| -> BackendResult {
1681 if let Some((min, max)) = clamp_bounds {
1682 write!(this.out, "{NAMESPACE}::clamp(")?;
1684 this.put_expression(arg, context, true)?;
1685 write!(this.out, ", {min}, {max})")?;
1686 } else {
1687 this.put_expression(arg, context, true)?;
1688 }
1689 Ok(())
1690 };
1691
1692 if context.lang_version >= (2, 1) {
1693 let packed_type = if was_signed {
1694 "packed_char4"
1695 } else {
1696 "packed_uchar4"
1697 };
1698 write!(self.out, "as_type<uint>({packed_type}(")?;
1700 write_arg(self)?;
1701 write!(self.out, "))")?;
1702 } else {
1703 if was_signed {
1705 write!(self.out, "uint(")?;
1706 }
1707 write!(self.out, "(")?;
1708 write_arg(self)?;
1709 write!(self.out, "[0] & 0xFF) | ((")?;
1710 write_arg(self)?;
1711 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1712 write_arg(self)?;
1713 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1714 write_arg(self)?;
1715 write!(self.out, "[3] & 0xFF) << 24)")?;
1716 if was_signed {
1717 write!(self.out, ")")?;
1718 }
1719 }
1720
1721 Ok(())
1722 }
1723
1724 fn put_isign(
1727 &mut self,
1728 arg: Handle<crate::Expression>,
1729 context: &ExpressionContext,
1730 ) -> BackendResult {
1731 write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1732 let scalar = context
1733 .resolve_type(arg)
1734 .scalar()
1735 .expect("put_isign should only be called for args which have an integer scalar type")
1736 .to_msl_name();
1737 match context.resolve_type(arg) {
1738 &crate::TypeInner::Vector { size, .. } => {
1739 let size = common::vector_size_str(size);
1740 write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1741 }
1742 _ => {
1743 write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1744 }
1745 }
1746 write!(self.out, ", (")?;
1747 self.put_expression(arg, context, true)?;
1748 write!(self.out, " > 0)), {scalar}(0), (")?;
1749 self.put_expression(arg, context, true)?;
1750 write!(self.out, " == 0))")?;
1751 Ok(())
1752 }
1753
1754 pub(super) fn put_const_expression(
1755 &mut self,
1756 expr_handle: Handle<crate::Expression>,
1757 module: &crate::Module,
1758 mod_info: &valid::ModuleInfo,
1759 arena: &crate::Arena<crate::Expression>,
1760 ) -> BackendResult {
1761 self.put_possibly_const_expression(
1762 expr_handle,
1763 arena,
1764 module,
1765 mod_info,
1766 &(module, mod_info),
1767 |&(_, mod_info), expr| &mod_info[expr],
1768 |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1769 )
1770 }
1771
1772 fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1773 match literal {
1774 crate::Literal::F64(_) => {
1775 return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1776 }
1777 crate::Literal::F16(value) => {
1778 if value.is_infinite() {
1779 let sign = if value.is_sign_negative() { "-" } else { "" };
1780 write!(self.out, "{sign}INFINITY")?;
1781 } else if value.is_nan() {
1782 write!(self.out, "NAN")?;
1783 } else {
1784 let suffix = if value.fract() == f16::from_f32(0.0) {
1785 ".0h"
1786 } else {
1787 "h"
1788 };
1789 write!(self.out, "{value}{suffix}")?;
1790 }
1791 }
1792 crate::Literal::F32(value) => {
1793 if value.is_infinite() {
1794 let sign = if value.is_sign_negative() { "-" } else { "" };
1795 write!(self.out, "{sign}INFINITY")?;
1796 } else if value.is_nan() {
1797 write!(self.out, "NAN")?;
1798 } else {
1799 let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1800 write!(self.out, "{value}{suffix}")?;
1801 }
1802 }
1803 crate::Literal::U32(value) => {
1804 write!(self.out, "{value}u")?;
1805 }
1806 crate::Literal::I32(value) => {
1807 if value == i32::MIN {
1812 write!(self.out, "({} - 1)", value + 1)?;
1813 } else {
1814 write!(self.out, "{value}")?;
1815 }
1816 }
1817 crate::Literal::U64(value) => {
1818 write!(self.out, "{value}uL")?;
1819 }
1820 crate::Literal::I64(value) => {
1821 if value == i64::MIN {
1828 write!(self.out, "({}L - 1L)", value + 1)?;
1829 } else {
1830 write!(self.out, "{value}L")?;
1831 }
1832 }
1833 crate::Literal::Bool(value) => {
1834 write!(self.out, "{value}")?;
1835 }
1836 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1837 return Err(Error::GenericValidation(
1838 "Unsupported abstract literal".into(),
1839 ));
1840 }
1841 }
1842 Ok(())
1843 }
1844
/// Write an expression that can appear in both constant and runtime contexts:
/// `Literal`, `Constant`, `ZeroValue`, `Compose`, or `Splat`.
///
/// The operand handling is supplied by the caller, so the same code serves
/// module-scope and function-body expressions:
///
/// - `expressions` is the arena that `expr_handle` indexes into.
/// - `ctx` is passed unchanged to both callbacks.
/// - `get_expr_ty` returns an operand's type resolution (used by `Splat`).
/// - `put_expression` writes an operand (used by `Compose` and `Splat`).
///
/// # Errors
///
/// Returns `Error::UnsupportedCompose` for a `Compose` whose type is not a
/// scalar, vector, matrix, array, or struct. Returns `Error::GenericValidation`
/// if a `Splat` operand is not a scalar. Returns `Error::Override` for every
/// other expression kind. Errors from the callbacks and nested writers
/// propagate unchanged.
#[allow(clippy::too_many_arguments)]
fn put_possibly_const_expression<C, I, E>(
    &mut self,
    expr_handle: Handle<crate::Expression>,
    expressions: &crate::Arena<crate::Expression>,
    module: &crate::Module,
    mod_info: &valid::ModuleInfo,
    ctx: &C,
    get_expr_ty: I,
    put_expression: E,
) -> BackendResult
where
    I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
    E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
{
    match expressions[expr_handle] {
        crate::Expression::Literal(literal) => {
            self.put_literal(literal)?;
        }
        crate::Expression::Constant(handle) => {
            // Named constants are referenced by their assigned name. Unnamed
            // ones are inlined by writing their initializer expression from
            // the module's global expression arena.
            let constant = &module.constants[handle];
            if constant.name.is_some() {
                write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
            } else {
                self.put_const_expression(
                    constant.init,
                    module,
                    mod_info,
                    &module.global_expressions,
                )?;
            }
        }
        crate::Expression::ZeroValue(ty) => {
            // Emitted as `TypeName {}` (empty brace initialization).
            let ty_name = TypeContext {
                handle: ty,
                gctx: module.to_ctx(),
                names: &self.names,
                access: crate::StorageAccess::empty(),
                first_time: false,
            };
            write!(self.out, "{ty_name} {{}}")?;
        }
        crate::Expression::Compose { ty, ref components } => {
            let ty_name = TypeContext {
                handle: ty,
                gctx: module.to_ctx(),
                names: &self.names,
                access: crate::StorageAccess::empty(),
                first_time: false,
            };
            write!(self.out, "{ty_name}")?;
            match module.types[ty].inner {
                // Scalars, vectors and matrices use constructor-call syntax:
                // `TypeName(a, b, ...)`.
                crate::TypeInner::Scalar(_)
                | crate::TypeInner::Vector { .. }
                | crate::TypeInner::Matrix { .. } => {
                    self.put_call_parameters_impl(
                        components.iter().copied(),
                        ctx,
                        put_expression,
                    )?;
                }
                crate::TypeInner::Array { .. } => {
                    // Two levels of braces (` {{a, b}}` after format escaping).
                    // NOTE(review): presumably the outer level initializes a
                    // wrapper struct around the MSL array (cf.
                    // `WRAPPED_ARRAY_FIELD`); confirm against the type writer.
                    write!(self.out, " {{{{")?;
                    for (index, &component) in components.iter().enumerate() {
                        if index != 0 {
                            write!(self.out, ", ")?;
                        }
                        put_expression(self, ctx, component)?;
                    }
                    write!(self.out, "}}}}")?;
                }
                crate::TypeInner::Struct { .. } => {
                    write!(self.out, " {{")?;
                    for (index, &component) in components.iter().enumerate() {
                        if index != 0 {
                            write!(self.out, ", ")?;
                        }
                        // Members recorded in `struct_member_pads` get an extra
                        // `{}` initializer first (presumably for a padding
                        // field emitted before that member).
                        if self.struct_member_pads.contains(&(ty, index as u32)) {
                            write!(self.out, "{{}}, ")?;
                        }
                        put_expression(self, ctx, component)?;
                    }
                    write!(self.out, "}}")?;
                }
                _ => return Err(Error::UnsupportedCompose(ty)),
            }
        }
        crate::Expression::Splat { size, value } => {
            // Written as `<numeric type of scalar x size>(value)`.
            let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
                crate::TypeInner::Scalar(scalar) => scalar,
                ref ty => {
                    return Err(Error::GenericValidation(format!(
                        "Expected splat value type must be a scalar, got {ty:?}",
                    )))
                }
            };
            put_numeric_type(&mut self.out, scalar, &[size])?;
            write!(self.out, "(")?;
            put_expression(self, ctx, value)?;
            write!(self.out, ")")?;
        }
        _ => {
            return Err(Error::Override);
        }
    }

    Ok(())
}
1956
1957 pub(super) fn put_expression(
1969 &mut self,
1970 expr_handle: Handle<crate::Expression>,
1971 context: &ExpressionContext,
1972 is_scoped: bool,
1973 ) -> BackendResult {
1974 #[cfg(test)]
1976 self.put_expression_stack_pointers
1977 .insert(ptr::from_ref(&expr_handle).cast());
1978
1979 if let Some(name) = self.named_expressions.get(&expr_handle) {
1980 write!(self.out, "{name}")?;
1981 return Ok(());
1982 }
1983
1984 let expression = &context.function.expressions[expr_handle];
1985 match *expression {
1986 crate::Expression::Literal(_)
1987 | crate::Expression::Constant(_)
1988 | crate::Expression::ZeroValue(_)
1989 | crate::Expression::Compose { .. }
1990 | crate::Expression::Splat { .. } => {
1991 self.put_possibly_const_expression(
1992 expr_handle,
1993 &context.function.expressions,
1994 context.module,
1995 context.mod_info,
1996 context,
1997 |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1998 |writer, context, expr| writer.put_expression(expr, context, true),
1999 )?;
2000 }
2001 crate::Expression::Override(_) => return Err(Error::Override),
2002 crate::Expression::Access { base, .. }
2003 | crate::Expression::AccessIndex { base, .. } => {
2004 let policy = context.choose_bounds_check_policy(base);
2010 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
2011 && self.put_bounds_checks(
2012 expr_handle,
2013 context,
2014 back::Level(0),
2015 if is_scoped { "" } else { "(" },
2016 )?
2017 {
2018 write!(self.out, " ? ")?;
2019 self.put_access_chain(expr_handle, policy, context)?;
2020 write!(self.out, " : ")?;
2021
2022 if context.resolve_type(base).pointer_space().is_some() {
2023 let result_ty = context.info[expr_handle]
2027 .ty
2028 .inner_with(&context.module.types)
2029 .pointer_base_type();
2030 let result_ty_handle = match result_ty {
2031 Some(TypeResolution::Handle(handle)) => handle,
2032 Some(TypeResolution::Value(_)) => {
2033 unreachable!(
2041 "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
2042 );
2043 }
2044 None => {
2045 unreachable!(
2046 "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
2047 )
2048 }
2049 };
2050 let name_key =
2051 NameKey::oob_local_for_type(context.origin, result_ty_handle);
2052 self.out.write_str(&self.names[&name_key])?;
2053 } else {
2054 write!(self.out, "DefaultConstructible()")?;
2055 }
2056
2057 if !is_scoped {
2058 write!(self.out, ")")?;
2059 }
2060 } else {
2061 self.put_access_chain(expr_handle, policy, context)?;
2062 }
2063 }
2064 crate::Expression::Swizzle {
2065 size,
2066 vector,
2067 pattern,
2068 } => {
2069 self.put_wrapped_expression_for_packed_vec3_access(
2070 vector,
2071 context,
2072 false,
2073 &Self::put_expression,
2074 )?;
2075 write!(self.out, ".")?;
2076 for &sc in pattern[..size as usize].iter() {
2077 write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2078 }
2079 }
2080 crate::Expression::FunctionArgument(index) => {
2081 let name_key = match context.origin {
2082 FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2083 FunctionOrigin::EntryPoint(ep_index) => {
2084 NameKey::EntryPointArgument(ep_index, index)
2085 }
2086 };
2087 let name = &self.names[&name_key];
2088 write!(self.out, "{name}")?;
2089 }
2090 crate::Expression::GlobalVariable(handle) => {
2091 let name = &self.names[&NameKey::GlobalVariable(handle)];
2092 write!(self.out, "{name}")?;
2093 }
2094 crate::Expression::LocalVariable(handle) => {
2095 let name_key = NameKey::local(context.origin, handle);
2096 let name = &self.names[&name_key];
2097 write!(self.out, "{name}")?;
2098 }
2099 crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2100 crate::Expression::ImageSample {
2101 coordinate,
2102 image,
2103 sampler,
2104 clamp_to_edge: true,
2105 gather: None,
2106 array_index: None,
2107 offset: None,
2108 level: crate::SampleLevel::Zero,
2109 depth_ref: None,
2110 } => {
2111 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2112 self.put_expression(image, context, true)?;
2113 write!(self.out, ", ")?;
2114 self.put_expression(sampler, context, true)?;
2115 write!(self.out, ", ")?;
2116 self.put_expression(coordinate, context, true)?;
2117 write!(self.out, ")")?;
2118 }
2119 crate::Expression::ImageSample {
2120 image,
2121 sampler,
2122 gather,
2123 coordinate,
2124 array_index,
2125 offset,
2126 level,
2127 depth_ref,
2128 clamp_to_edge,
2129 } => {
2130 if clamp_to_edge {
2131 return Err(Error::GenericValidation(
2132 "ImageSample::clamp_to_edge should have been validated out".to_string(),
2133 ));
2134 }
2135
2136 let main_op = match gather {
2137 Some(_) => "gather",
2138 None => "sample",
2139 };
2140 let comparison_op = match depth_ref {
2141 Some(_) => "_compare",
2142 None => "",
2143 };
2144 self.put_expression(image, context, false)?;
2145 write!(self.out, ".{main_op}{comparison_op}(")?;
2146 self.put_expression(sampler, context, true)?;
2147 write!(self.out, ", ")?;
2148 self.put_expression(coordinate, context, true)?;
2149 if let Some(expr) = array_index {
2150 write!(self.out, ", ")?;
2151 self.put_expression(expr, context, true)?;
2152 }
2153 if let Some(dref) = depth_ref {
2154 write!(self.out, ", ")?;
2155 self.put_expression(dref, context, true)?;
2156 }
2157
2158 self.put_image_sample_level(image, level, context)?;
2159
2160 if let Some(offset) = offset {
2161 write!(self.out, ", ")?;
2162 self.put_expression(offset, context, true)?;
2163 }
2164
2165 match gather {
2166 None | Some(crate::SwizzleComponent::X) => {}
2167 Some(component) => {
2168 let is_cube_map = match *context.resolve_type(image) {
2169 crate::TypeInner::Image {
2170 dim: crate::ImageDimension::Cube,
2171 ..
2172 } => true,
2173 _ => false,
2174 };
2175 if offset.is_none() && !is_cube_map {
2178 write!(self.out, ", {NAMESPACE}::int2(0)")?;
2179 }
2180 let letter = back::COMPONENTS[component as usize];
2181 write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2182 }
2183 }
2184 write!(self.out, ")")?;
2185 }
2186 crate::Expression::ImageLoad {
2187 image,
2188 coordinate,
2189 array_index,
2190 sample,
2191 level,
2192 } => {
2193 let address = TexelAddress {
2194 coordinate,
2195 array_index,
2196 sample,
2197 level: level.map(LevelOfDetail::Direct),
2198 };
2199 self.put_image_load(expr_handle, image, address, context)?;
2200 }
2201 crate::Expression::ImageQuery { image, query } => match query {
2204 crate::ImageQuery::Size { level } => {
2205 self.put_image_size_query(
2206 image,
2207 level.map(LevelOfDetail::Direct),
2208 crate::ScalarKind::Uint,
2209 context,
2210 )?;
2211 }
2212 crate::ImageQuery::NumLevels => {
2213 self.put_expression(image, context, false)?;
2214 write!(self.out, ".get_num_mip_levels()")?;
2215 }
2216 crate::ImageQuery::NumLayers => {
2217 self.put_expression(image, context, false)?;
2218 write!(self.out, ".get_array_size()")?;
2219 }
2220 crate::ImageQuery::NumSamples => {
2221 self.put_expression(image, context, false)?;
2222 write!(self.out, ".get_num_samples()")?;
2223 }
2224 },
2225 crate::Expression::Unary { op, expr } => {
2226 let op_str = match op {
2227 crate::UnaryOperator::Negate => {
2228 match context.resolve_type(expr).scalar_kind() {
2229 Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2230 _ => "-",
2231 }
2232 }
2233 crate::UnaryOperator::LogicalNot => "!",
2234 crate::UnaryOperator::BitwiseNot => "~",
2235 };
2236 write!(self.out, "{op_str}(")?;
2237 self.put_expression(expr, context, false)?;
2238 write!(self.out, ")")?;
2239 }
2240 crate::Expression::Binary { op, left, right } => {
2241 let kind = context
2242 .resolve_type(left)
2243 .scalar_kind()
2244 .ok_or(Error::UnsupportedBinaryOp(op))?;
2245
2246 if op == crate::BinaryOperator::Divide
2247 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2248 {
2249 write!(self.out, "{DIV_FUNCTION}(")?;
2250 self.put_expression(left, context, true)?;
2251 write!(self.out, ", ")?;
2252 self.put_expression(right, context, true)?;
2253 write!(self.out, ")")?;
2254 } else if op == crate::BinaryOperator::Modulo
2255 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2256 {
2257 write!(self.out, "{MOD_FUNCTION}(")?;
2258 self.put_expression(left, context, true)?;
2259 write!(self.out, ", ")?;
2260 self.put_expression(right, context, true)?;
2261 write!(self.out, ")")?;
2262 } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2263 write!(self.out, "{NAMESPACE}::fmod(")?;
2268 self.put_expression(left, context, true)?;
2269 write!(self.out, ", ")?;
2270 self.put_expression(right, context, true)?;
2271 write!(self.out, ")")?;
2272 } else if (op == crate::BinaryOperator::Add
2273 || op == crate::BinaryOperator::Subtract
2274 || op == crate::BinaryOperator::Multiply)
2275 && kind == crate::ScalarKind::Sint
2276 {
2277 let to_unsigned = |ty: &crate::TypeInner| match *ty {
2278 crate::TypeInner::Scalar(scalar) => {
2279 Ok(crate::TypeInner::Scalar(crate::Scalar {
2280 kind: crate::ScalarKind::Uint,
2281 ..scalar
2282 }))
2283 }
2284 crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2285 size,
2286 scalar: crate::Scalar {
2287 kind: crate::ScalarKind::Uint,
2288 ..scalar
2289 },
2290 }),
2291 _ => Err(Error::UnsupportedBitCast(ty.clone())),
2292 };
2293
2294 self.put_bitcasted_expression(
2299 context.resolve_type(expr_handle),
2300 expr_handle,
2301 context,
2302 &|writer, context, is_scoped| {
2303 writer.put_binop(
2304 op,
2305 left,
2306 right,
2307 context,
2308 is_scoped,
2309 &|writer, expr, context, _is_scoped| {
2310 writer.put_bitcasted_expression(
2311 &to_unsigned(context.resolve_type(expr))?,
2312 expr,
2313 context,
2314 &|writer, context, is_scoped| {
2315 writer.put_expression(expr, context, is_scoped)
2316 },
2317 )
2318 },
2319 )
2320 },
2321 )?;
2322 } else {
2323 self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2324 }
2325 }
2326 crate::Expression::Select {
2327 condition,
2328 accept,
2329 reject,
2330 } => match *context.resolve_type(condition) {
2331 crate::TypeInner::Scalar(crate::Scalar {
2332 kind: crate::ScalarKind::Bool,
2333 ..
2334 }) => {
2335 if !is_scoped {
2336 write!(self.out, "(")?;
2337 }
2338 self.put_expression(condition, context, false)?;
2339 write!(self.out, " ? ")?;
2340 self.put_expression(accept, context, false)?;
2341 write!(self.out, " : ")?;
2342 self.put_expression(reject, context, false)?;
2343 if !is_scoped {
2344 write!(self.out, ")")?;
2345 }
2346 }
2347 crate::TypeInner::Vector {
2348 scalar:
2349 crate::Scalar {
2350 kind: crate::ScalarKind::Bool,
2351 ..
2352 },
2353 ..
2354 } => {
2355 write!(self.out, "{NAMESPACE}::select(")?;
2356 self.put_expression(reject, context, true)?;
2357 write!(self.out, ", ")?;
2358 self.put_expression(accept, context, true)?;
2359 write!(self.out, ", ")?;
2360 self.put_expression(condition, context, true)?;
2361 write!(self.out, ")")?;
2362 }
2363 ref ty => {
2364 return Err(Error::GenericValidation(format!(
2365 "Expected select condition to be a non-bool type, got {ty:?}",
2366 )))
2367 }
2368 },
2369 crate::Expression::Derivative { axis, expr, .. } => {
2370 use crate::DerivativeAxis as Axis;
2371 let op = match axis {
2372 Axis::X => "dfdx",
2373 Axis::Y => "dfdy",
2374 Axis::Width => "fwidth",
2375 };
2376 write!(self.out, "{NAMESPACE}::{op}")?;
2377 self.put_call_parameters(iter::once(expr), context)?;
2378 }
2379 crate::Expression::Relational { fun, argument } => {
2380 let op = match fun {
2381 crate::RelationalFunction::Any => "any",
2382 crate::RelationalFunction::All => "all",
2383 crate::RelationalFunction::IsNan => "isnan",
2384 crate::RelationalFunction::IsInf => "isinf",
2385 };
2386 write!(self.out, "{NAMESPACE}::{op}")?;
2387 self.put_call_parameters(iter::once(argument), context)?;
2388 }
2389 crate::Expression::Math {
2390 fun,
2391 arg,
2392 arg1,
2393 arg2,
2394 arg3,
2395 } => {
2396 use crate::MathFunction as Mf;
2397
2398 let arg_type = context.resolve_type(arg);
2399 let scalar_argument = match arg_type {
2400 &crate::TypeInner::Scalar(_) => true,
2401 _ => false,
2402 };
2403
2404 let fun_name = match fun {
2405 Mf::Abs => "abs",
2407 Mf::Min => "min",
2408 Mf::Max => "max",
2409 Mf::Clamp => "clamp",
2410 Mf::Saturate => "saturate",
2411 Mf::Cos => "cos",
2413 Mf::Cosh => "cosh",
2414 Mf::Sin => "sin",
2415 Mf::Sinh => "sinh",
2416 Mf::Tan => "tan",
2417 Mf::Tanh => "tanh",
2418 Mf::Acos => "acos",
2419 Mf::Asin => "asin",
2420 Mf::Atan => "atan",
2421 Mf::Atan2 => "atan2",
2422 Mf::Asinh => "asinh",
2423 Mf::Acosh => "acosh",
2424 Mf::Atanh => "atanh",
2425 Mf::Radians => "",
2426 Mf::Degrees => "",
2427 Mf::Ceil => "ceil",
2429 Mf::Floor => "floor",
2430 Mf::Round => "rint",
2431 Mf::Fract => "fract",
2432 Mf::Trunc => "trunc",
2433 Mf::Modf => MODF_FUNCTION,
2434 Mf::Frexp => FREXP_FUNCTION,
2435 Mf::Ldexp => "ldexp",
2436 Mf::Exp => "exp",
2438 Mf::Exp2 => "exp2",
2439 Mf::Log => "log",
2440 Mf::Log2 => "log2",
2441 Mf::Pow => "pow",
2442 Mf::Dot => match *context.resolve_type(arg) {
2444 crate::TypeInner::Vector {
2445 scalar:
2446 crate::Scalar {
2447 kind: crate::ScalarKind::Float,
2449 ..
2450 },
2451 ..
2452 } => "dot",
2453 crate::TypeInner::Vector {
2454 size,
2455 scalar:
2456 scalar @ crate::Scalar {
2457 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2458 ..
2459 },
2460 } => {
2461 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2463 write!(self.out, "{fun_name}(")?;
2464 self.put_expression(arg, context, true)?;
2465 write!(self.out, ", ")?;
2466 self.put_expression(arg1.unwrap(), context, true)?;
2467 write!(self.out, ")")?;
2468 return Ok(());
2469 }
2470 _ => unreachable!(
2471 "Correct TypeInner for dot product should be already validated"
2472 ),
2473 },
2474 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2475 if context.lang_version >= (2, 1) {
2476 let packed_type = match fun {
2480 Mf::Dot4I8Packed => "packed_char4",
2481 Mf::Dot4U8Packed => "packed_uchar4",
2482 _ => unreachable!(),
2483 };
2484
2485 return self.put_dot_product(
2486 Reinterpreted::new(packed_type, arg),
2487 Reinterpreted::new(packed_type, arg1.unwrap()),
2488 4,
2489 |writer, arg, index| {
2490 write!(writer.out, "{arg}[{index}]")?;
2493 Ok(())
2494 },
2495 );
2496 } else {
2497 let conversion = match fun {
2501 Mf::Dot4I8Packed => "int",
2502 Mf::Dot4U8Packed => "",
2503 _ => unreachable!(),
2504 };
2505
2506 return self.put_dot_product(
2507 arg,
2508 arg1.unwrap(),
2509 4,
2510 |writer, arg, index| {
2511 write!(writer.out, "({conversion}(")?;
2512 writer.put_expression(arg, context, true)?;
2513 if index == 3 {
2514 write!(writer.out, ") >> 24)")?;
2515 } else {
2516 write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2517 }
2518 Ok(())
2519 },
2520 );
2521 }
2522 }
2523 Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2524 Mf::Cross => "cross",
2525 Mf::Distance => "distance",
2526 Mf::Length if scalar_argument => "abs",
2527 Mf::Length => "length",
2528 Mf::Normalize => "normalize",
2529 Mf::FaceForward => "faceforward",
2530 Mf::Reflect => "reflect",
2531 Mf::Refract => "refract",
2532 Mf::Sign => match arg_type.scalar_kind() {
2534 Some(crate::ScalarKind::Sint) => {
2535 return self.put_isign(arg, context);
2536 }
2537 _ => "sign",
2538 },
2539 Mf::Fma => "fma",
2540 Mf::Mix => "mix",
2541 Mf::Step => "step",
2542 Mf::SmoothStep => "smoothstep",
2543 Mf::Sqrt => "sqrt",
2544 Mf::InverseSqrt => "rsqrt",
2545 Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2546 Mf::Transpose => "transpose",
2547 Mf::Determinant => "determinant",
2548 Mf::QuantizeToF16 => "",
2549 Mf::CountTrailingZeros => "ctz",
2551 Mf::CountLeadingZeros => "clz",
2552 Mf::CountOneBits => "popcount",
2553 Mf::ReverseBits => "reverse_bits",
2554 Mf::ExtractBits => "",
2555 Mf::InsertBits => "",
2556 Mf::FirstTrailingBit => "",
2557 Mf::FirstLeadingBit => "",
2558 Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2560 Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2561 Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2562 Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2563 Mf::Pack2x16float => "",
2564 Mf::Pack4xI8 => "",
2565 Mf::Pack4xU8 => "",
2566 Mf::Pack4xI8Clamp => "",
2567 Mf::Pack4xU8Clamp => "",
2568 Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2570 Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2571 Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2572 Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2573 Mf::Unpack2x16float => "",
2574 Mf::Unpack4xI8 => "",
2575 Mf::Unpack4xU8 => "",
2576 };
2577
2578 match fun {
2579 Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2580 if context.lang_version < (1, 2) {
2589 return Err(Error::UnsupportedFunction(fun_name.to_string()));
2590 }
2591 }
2592 _ => {}
2593 }
2594
2595 match fun {
2596 Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2597 write!(self.out, "{ABS_FUNCTION}(")?;
2598 self.put_expression(arg, context, true)?;
2599 write!(self.out, ")")?;
2600 }
2601 Mf::Distance if scalar_argument => {
2602 write!(self.out, "{NAMESPACE}::abs(")?;
2603 self.put_expression(arg, context, false)?;
2604 write!(self.out, " - ")?;
2605 self.put_expression(arg1.unwrap(), context, false)?;
2606 write!(self.out, ")")?;
2607 }
2608 Mf::FirstTrailingBit => {
2609 let scalar = context.resolve_type(arg).scalar().unwrap();
2610 let constant = scalar.width * 8 + 1;
2611
2612 write!(self.out, "((({NAMESPACE}::ctz(")?;
2613 self.put_expression(arg, context, true)?;
2614 write!(self.out, ") + 1) % {constant}) - 1)")?;
2615 }
2616 Mf::FirstLeadingBit => {
2617 let inner = context.resolve_type(arg);
2618 let scalar = inner.scalar().unwrap();
2619 let constant = scalar.width * 8 - 1;
2620
2621 write!(
2622 self.out,
2623 "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2624 )?;
2625
2626 if scalar.kind == crate::ScalarKind::Sint {
2627 write!(self.out, "{NAMESPACE}::select(")?;
2628 self.put_expression(arg, context, true)?;
2629 write!(self.out, ", ~")?;
2630 self.put_expression(arg, context, true)?;
2631 write!(self.out, ", ")?;
2632 self.put_expression(arg, context, true)?;
2633 write!(self.out, " < 0)")?;
2634 } else {
2635 self.put_expression(arg, context, true)?;
2636 }
2637
2638 write!(self.out, "), ")?;
2639
2640 match *inner {
2642 crate::TypeInner::Vector { size, scalar } => {
2643 let size = common::vector_size_str(size);
2644 let name = scalar.to_msl_name();
2645 write!(self.out, "{name}{size}")?;
2646 }
2647 crate::TypeInner::Scalar(scalar) => {
2648 let name = scalar.to_msl_name();
2649 write!(self.out, "{name}")?;
2650 }
2651 _ => (),
2652 }
2653
2654 write!(self.out, "(-1), ")?;
2655 self.put_expression(arg, context, true)?;
2656 write!(self.out, " == 0")?;
2657 if scalar.kind == crate::ScalarKind::Sint {
2658 write!(self.out, " || ")?;
2659 self.put_expression(arg, context, true)?;
2660 write!(self.out, " == -1")?;
2661 }
2662 write!(self.out, ")")?;
2663 }
2664 Mf::Unpack2x16float => {
2665 write!(self.out, "float2(as_type<half2>(")?;
2666 self.put_expression(arg, context, false)?;
2667 write!(self.out, "))")?;
2668 }
2669 Mf::Pack2x16float => {
2670 write!(self.out, "as_type<uint>(half2(")?;
2671 self.put_expression(arg, context, false)?;
2672 write!(self.out, "))")?;
2673 }
2674 Mf::ExtractBits => {
2675 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2692
2693 write!(self.out, "{NAMESPACE}::extract_bits(")?;
2694 self.put_expression(arg, context, true)?;
2695 write!(self.out, ", {NAMESPACE}::min(")?;
2696 self.put_expression(arg1.unwrap(), context, true)?;
2697 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2698 self.put_expression(arg2.unwrap(), context, true)?;
2699 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2700 self.put_expression(arg1.unwrap(), context, true)?;
2701 write!(self.out, ", {scalar_bits}u)))")?;
2702 }
2703 Mf::InsertBits => {
2704 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2709
2710 write!(self.out, "{NAMESPACE}::insert_bits(")?;
2711 self.put_expression(arg, context, true)?;
2712 write!(self.out, ", ")?;
2713 self.put_expression(arg1.unwrap(), context, true)?;
2714 write!(self.out, ", {NAMESPACE}::min(")?;
2715 self.put_expression(arg2.unwrap(), context, true)?;
2716 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2717 self.put_expression(arg3.unwrap(), context, true)?;
2718 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2719 self.put_expression(arg2.unwrap(), context, true)?;
2720 write!(self.out, ", {scalar_bits}u)))")?;
2721 }
2722 Mf::Radians => {
2723 write!(self.out, "((")?;
2724 self.put_expression(arg, context, false)?;
2725 write!(self.out, ") * 0.017453292519943295474)")?;
2726 }
2727 Mf::Degrees => {
2728 write!(self.out, "((")?;
2729 self.put_expression(arg, context, false)?;
2730 write!(self.out, ") * 57.295779513082322865)")?;
2731 }
2732 Mf::Modf | Mf::Frexp => {
2733 write!(self.out, "{fun_name}")?;
2734 self.put_call_parameters(iter::once(arg), context)?;
2735 }
2736 Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2737 Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2738 Mf::Pack4xI8Clamp => {
2739 self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2740 }
2741 Mf::Pack4xU8Clamp => {
2742 self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2743 }
2744 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2745 let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2746 "u"
2747 } else {
2748 ""
2749 };
2750
2751 if context.lang_version >= (2, 1) {
2752 write!(
2754 self.out,
2755 "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2756 )?;
2757 self.put_expression(arg, context, true)?;
2758 write!(self.out, "))")?;
2759 } else {
2760 write!(self.out, "({sign_prefix}int4(")?;
2762 self.put_expression(arg, context, true)?;
2763 write!(self.out, ", ")?;
2764 self.put_expression(arg, context, true)?;
2765 write!(self.out, " >> 8, ")?;
2766 self.put_expression(arg, context, true)?;
2767 write!(self.out, " >> 16, ")?;
2768 self.put_expression(arg, context, true)?;
2769 write!(self.out, " >> 24) << 24 >> 24)")?;
2770 }
2771 }
2772 Mf::QuantizeToF16 => {
2773 match *context.resolve_type(arg) {
2774 crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2775 crate::TypeInner::Vector { size, .. } => write!(
2776 self.out,
2777 "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2778 size = common::vector_size_str(size),
2779 )?,
2780 _ => unreachable!(
2781 "Correct TypeInner for QuantizeToF16 should be already validated"
2782 ),
2783 };
2784
2785 self.put_expression(arg, context, true)?;
2786 write!(self.out, "))")?;
2787 }
2788 _ => {
2789 write!(self.out, "{NAMESPACE}::{fun_name}")?;
2790 self.put_call_parameters(
2791 iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2792 context,
2793 )?;
2794 }
2795 }
2796 }
2797 crate::Expression::As {
2798 expr,
2799 kind,
2800 convert,
2801 } => match *context.resolve_type(expr) {
2802 crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2803 if src.kind == crate::ScalarKind::Float
2804 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2805 && convert.is_some()
2806 {
2807 let fun_name = match (kind, convert) {
2811 (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2812 (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2813 (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2814 (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2815 _ => unreachable!(),
2816 };
2817 write!(self.out, "{fun_name}(")?;
2818 self.put_expression(expr, context, true)?;
2819 write!(self.out, ")")?;
2820 } else {
2821 let target_scalar = crate::Scalar {
2822 kind,
2823 width: convert.unwrap_or(src.width),
2824 };
2825 let op = match convert {
2826 Some(_) => "static_cast",
2827 None => "as_type",
2828 };
2829 write!(self.out, "{op}<")?;
2830 match *context.resolve_type(expr) {
2831 crate::TypeInner::Vector { size, .. } => {
2832 put_numeric_type(&mut self.out, target_scalar, &[size])?
2833 }
2834 _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2835 };
2836 write!(self.out, ">(")?;
2837 self.put_expression(expr, context, true)?;
2838 write!(self.out, ")")?;
2839 }
2840 }
2841 crate::TypeInner::Matrix {
2842 columns,
2843 rows,
2844 scalar,
2845 } => {
2846 let target_scalar = crate::Scalar {
2847 kind,
2848 width: convert.unwrap_or(scalar.width),
2849 };
2850 put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2851 write!(self.out, "(")?;
2852 self.put_expression(expr, context, true)?;
2853 write!(self.out, ")")?;
2854 }
2855 ref ty => {
2856 return Err(Error::GenericValidation(format!(
2857 "Unsupported type for As: {ty:?}"
2858 )))
2859 }
2860 },
2861 crate::Expression::CallResult(_)
2863 | crate::Expression::AtomicResult { .. }
2864 | crate::Expression::WorkGroupUniformLoadResult { .. }
2865 | crate::Expression::SubgroupBallotResult
2866 | crate::Expression::SubgroupOperationResult { .. }
2867 | crate::Expression::RayQueryProceedResult => {
2868 unreachable!()
2869 }
2870 crate::Expression::ArrayLength(expr) => {
2871 let global = match context.function.expressions[expr] {
2873 crate::Expression::AccessIndex { base, .. } => {
2874 match context.function.expressions[base] {
2875 crate::Expression::GlobalVariable(handle) => handle,
2876 ref ex => {
2877 return Err(Error::GenericValidation(format!(
2878 "Expected global variable in AccessIndex, got {ex:?}"
2879 )))
2880 }
2881 }
2882 }
2883 crate::Expression::GlobalVariable(handle) => handle,
2884 ref ex => {
2885 return Err(Error::GenericValidation(format!(
2886 "Unexpected expression in ArrayLength, got {ex:?}"
2887 )))
2888 }
2889 };
2890
2891 if !is_scoped {
2892 write!(self.out, "(")?;
2893 }
2894 write!(self.out, "1 + ")?;
2895 self.put_dynamic_array_max_index(global, context)?;
2896 if !is_scoped {
2897 write!(self.out, ")")?;
2898 }
2899 }
2900 crate::Expression::RayQueryVertexPositions { .. } => {
2901 unimplemented!()
2902 }
2903 crate::Expression::RayQueryGetIntersection {
2904 query,
2905 committed: _,
2906 } => {
2907 if context.lang_version < (2, 4) {
2908 return Err(Error::UnsupportedRayTracing);
2909 }
2910
2911 let ty = context.module.special_types.ray_intersection.unwrap();
2912 let type_name = &self.names[&NameKey::Type(ty)];
2913 write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2914 self.put_expression(query, context, true)?;
2915 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2916 let fields = [
2917 "distance",
2918 "user_instance_id", "instance_id",
2920 "", "geometry_id",
2922 "primitive_id",
2923 "triangle_barycentric_coord",
2924 "triangle_front_facing",
2925 "", "object_to_world_transform", "world_to_object_transform", ];
2929 for field in fields {
2930 write!(self.out, ", ")?;
2931 if field.is_empty() {
2932 write!(self.out, "{{}}")?;
2933 } else {
2934 self.put_expression(query, context, true)?;
2935 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2936 }
2937 }
2938 write!(self.out, "}}")?;
2939 }
2940 crate::Expression::CooperativeLoad { ref data, .. } => {
2941 if context.lang_version < (2, 3) {
2942 return Err(Error::UnsupportedCooperativeMatrix);
2943 }
2944 write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2945 write!(self.out, "&")?;
2946 self.put_access_chain(data.pointer, context.policies.index, context)?;
2947 write!(self.out, ", ")?;
2948 self.put_expression(data.stride, context, true)?;
2949 write!(self.out, ", {})", data.row_major)?;
2950 }
2951 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2952 if context.lang_version < (2, 3) {
2953 return Err(Error::UnsupportedCooperativeMatrix);
2954 }
2955 write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2956 self.put_expression(a, context, true)?;
2957 write!(self.out, ", ")?;
2958 self.put_expression(b, context, true)?;
2959 write!(self.out, ", ")?;
2960 self.put_expression(c, context, true)?;
2961 write!(self.out, ")")?;
2962 }
2963 }
2964 Ok(())
2965 }
2966
2967 fn put_binop<F>(
2970 &mut self,
2971 op: crate::BinaryOperator,
2972 left: Handle<crate::Expression>,
2973 right: Handle<crate::Expression>,
2974 context: &ExpressionContext,
2975 is_scoped: bool,
2976 put_expression: &F,
2977 ) -> BackendResult
2978 where
2979 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2980 {
2981 let op_str = back::binary_operation_str(op);
2982
2983 if !is_scoped {
2984 write!(self.out, "(")?;
2985 }
2986
2987 if op == crate::BinaryOperator::Multiply
2990 && matches!(
2991 context.resolve_type(right),
2992 &crate::TypeInner::Matrix { .. }
2993 )
2994 {
2995 self.put_wrapped_expression_for_packed_vec3_access(
2996 left,
2997 context,
2998 false,
2999 put_expression,
3000 )?;
3001 } else {
3002 put_expression(self, left, context, false)?;
3003 }
3004
3005 write!(self.out, " {op_str} ")?;
3006
3007 if op == crate::BinaryOperator::Multiply
3009 && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
3010 {
3011 self.put_wrapped_expression_for_packed_vec3_access(
3012 right,
3013 context,
3014 false,
3015 put_expression,
3016 )?;
3017 } else {
3018 put_expression(self, right, context, false)?;
3019 }
3020
3021 if !is_scoped {
3022 write!(self.out, ")")?;
3023 }
3024
3025 Ok(())
3026 }
3027
3028 fn put_wrapped_expression_for_packed_vec3_access<F>(
3030 &mut self,
3031 expr_handle: Handle<crate::Expression>,
3032 context: &ExpressionContext,
3033 is_scoped: bool,
3034 put_expression: &F,
3035 ) -> BackendResult
3036 where
3037 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
3038 {
3039 if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
3040 write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
3041 put_expression(self, expr_handle, context, is_scoped)?;
3042 write!(self.out, ")")?;
3043 } else {
3044 put_expression(self, expr_handle, context, is_scoped)?;
3045 }
3046 Ok(())
3047 }
3048
3049 fn put_bitcasted_expression<F>(
3052 &mut self,
3053 cast_to: &crate::TypeInner,
3054 inner_expr: Handle<crate::Expression>,
3055 context: &ExpressionContext,
3056 put_expression: &F,
3057 ) -> BackendResult
3058 where
3059 F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
3060 {
3061 write!(self.out, "as_type<")?;
3062 match *cast_to {
3063 crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
3064 crate::TypeInner::Vector { size, scalar } => {
3065 put_numeric_type(&mut self.out, scalar, &[size])?
3066 }
3067 _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
3068 };
3069 write!(self.out, ">(")?;
3070
3071 if let Some(scalar) = context.get_packed_vec_kind(inner_expr) {
3073 put_numeric_type(&mut self.out, scalar, &[crate::VectorSize::Tri])?;
3074 write!(self.out, "(")?;
3075 put_expression(self, context, true)?;
3076 write!(self.out, ")")?;
3077 } else {
3078 put_expression(self, context, true)?;
3079 }
3080
3081 write!(self.out, ")")?;
3082 Ok(())
3083 }
3084
3085 fn put_index(
3087 &mut self,
3088 index: index::GuardedIndex,
3089 context: &ExpressionContext,
3090 is_scoped: bool,
3091 ) -> BackendResult {
3092 match index {
3093 index::GuardedIndex::Expression(expr) => {
3094 self.put_expression(expr, context, is_scoped)?
3095 }
3096 index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3097 }
3098 Ok(())
3099 }
3100
3101 fn put_bounds_checks(
3131 &mut self,
3132 chain: Handle<crate::Expression>,
3133 context: &ExpressionContext,
3134 level: back::Level,
3135 prefix: &'static str,
3136 ) -> Result<bool, Error> {
3137 let mut check_written = false;
3138
3139 for item in context.bounds_check_iter(chain) {
3141 let BoundsCheck {
3142 base,
3143 index,
3144 length,
3145 } = item;
3146
3147 if check_written {
3148 write!(self.out, " && ")?;
3149 } else {
3150 write!(self.out, "{level}{prefix}")?;
3151 check_written = true;
3152 }
3153
3154 write!(self.out, "uint(")?;
3158 self.put_index(index, context, true)?;
3159 self.out.write_str(") < ")?;
3160 match length {
3161 index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3162 index::IndexableLength::Dynamic => {
3163 let global = context.function.originating_global(base).ok_or_else(|| {
3164 Error::GenericValidation("Could not find originating global".into())
3165 })?;
3166 write!(self.out, "1 + ")?;
3167 self.put_dynamic_array_max_index(global, context)?
3168 }
3169 }
3170 }
3171
3172 Ok(check_written)
3173 }
3174
3175 fn put_access_chain(
3195 &mut self,
3196 chain: Handle<crate::Expression>,
3197 policy: index::BoundsCheckPolicy,
3198 context: &ExpressionContext,
3199 ) -> BackendResult {
3200 match context.function.expressions[chain] {
3201 crate::Expression::Access { base, index } => {
3202 let mut base_ty = context.resolve_type(base);
3203
3204 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3206 base_ty = &context.module.types[base].inner;
3207 }
3208
3209 self.put_subscripted_access_chain(
3210 base,
3211 base_ty,
3212 index::GuardedIndex::Expression(index),
3213 policy,
3214 context,
3215 )?;
3216 }
3217 crate::Expression::AccessIndex { base, index } => {
3218 let base_resolution = &context.info[base].ty;
3219 let mut base_ty = base_resolution.inner_with(&context.module.types);
3220 let mut base_ty_handle = base_resolution.handle();
3221
3222 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3224 base_ty = &context.module.types[base].inner;
3225 base_ty_handle = Some(base);
3226 }
3227
3228 match *base_ty {
3232 crate::TypeInner::Struct { .. } => {
3233 let base_ty = base_ty_handle.unwrap();
3234 self.put_access_chain(base, policy, context)?;
3235 let name = &self.names[&NameKey::StructMember(base_ty, index)];
3236 write!(self.out, ".{name}")?;
3237 }
3238 crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3239 self.put_access_chain(base, policy, context)?;
3240 if context.get_packed_vec_kind(base).is_some() {
3243 write!(self.out, "[{index}]")?;
3244 } else {
3245 write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3246 }
3247 }
3248 _ => {
3249 self.put_subscripted_access_chain(
3250 base,
3251 base_ty,
3252 index::GuardedIndex::Known(index),
3253 policy,
3254 context,
3255 )?;
3256 }
3257 }
3258 }
3259 _ => self.put_expression(chain, context, false)?,
3260 }
3261
3262 Ok(())
3263 }
3264
3265 fn put_subscripted_access_chain(
3282 &mut self,
3283 base: Handle<crate::Expression>,
3284 base_ty: &crate::TypeInner,
3285 index: index::GuardedIndex,
3286 policy: index::BoundsCheckPolicy,
3287 context: &ExpressionContext,
3288 ) -> BackendResult {
3289 let accessing_wrapped_array = match *base_ty {
3290 crate::TypeInner::Array {
3291 size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3292 ..
3293 } => true,
3294 _ => false,
3295 };
3296 let accessing_wrapped_binding_array =
3297 matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3298
3299 self.put_access_chain(base, policy, context)?;
3300 if accessing_wrapped_array {
3301 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3302 }
3303 write!(self.out, "[")?;
3304
3305 let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3307 context.access_needs_check(base, index)
3308 } else {
3309 None
3310 };
3311 if let Some(limit) = restriction_needed {
3312 write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3313 self.put_index(index, context, true)?;
3314 write!(self.out, "), ")?;
3315 match limit {
3316 index::IndexableLength::Known(limit) => {
3317 write!(self.out, "{}u", limit - 1)?;
3318 }
3319 index::IndexableLength::Dynamic => {
3320 let global = context.function.originating_global(base).ok_or_else(|| {
3321 Error::GenericValidation("Could not find originating global".into())
3322 })?;
3323 self.put_dynamic_array_max_index(global, context)?;
3324 }
3325 }
3326 write!(self.out, ")")?;
3327 } else {
3328 self.put_index(index, context, true)?;
3329 }
3330
3331 write!(self.out, "]")?;
3332
3333 if accessing_wrapped_binding_array {
3334 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3335 }
3336
3337 Ok(())
3338 }
3339
3340 fn put_load(
3341 &mut self,
3342 pointer: Handle<crate::Expression>,
3343 context: &ExpressionContext,
3344 is_scoped: bool,
3345 ) -> BackendResult {
3346 let policy = context.choose_bounds_check_policy(pointer);
3349 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3350 && self.put_bounds_checks(
3351 pointer,
3352 context,
3353 back::Level(0),
3354 if is_scoped { "" } else { "(" },
3355 )?
3356 {
3357 write!(self.out, " ? ")?;
3358 self.put_unchecked_load(pointer, policy, context)?;
3359 write!(self.out, " : DefaultConstructible()")?;
3360
3361 if !is_scoped {
3362 write!(self.out, ")")?;
3363 }
3364 } else {
3365 self.put_unchecked_load(pointer, policy, context)?;
3366 }
3367
3368 Ok(())
3369 }
3370
3371 fn put_unchecked_load(
3372 &mut self,
3373 pointer: Handle<crate::Expression>,
3374 policy: index::BoundsCheckPolicy,
3375 context: &ExpressionContext,
3376 ) -> BackendResult {
3377 let is_atomic_pointer = context
3378 .resolve_type(pointer)
3379 .is_atomic_pointer(&context.module.types);
3380
3381 if is_atomic_pointer {
3382 write!(
3383 self.out,
3384 "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3385 )?;
3386 self.put_access_chain(pointer, policy, context)?;
3387 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3388 } else {
3389 self.put_access_chain(pointer, policy, context)?;
3393 }
3394
3395 Ok(())
3396 }
3397
3398 fn put_return_value(
3399 &mut self,
3400 level: back::Level,
3401 expr_handle: Handle<crate::Expression>,
3402 result_struct: Option<&str>,
3403 context: &ExpressionContext,
3404 ) -> BackendResult {
3405 match result_struct {
3406 Some(struct_name) => {
3407 let mut has_point_size = false;
3408 let result_ty = context.function.result.as_ref().unwrap().ty;
3409 match context.module.types[result_ty].inner {
3410 crate::TypeInner::Struct { ref members, .. } => {
3411 let tmp = "_tmp";
3412 write!(self.out, "{level}const auto {tmp} = ")?;
3413 self.put_expression(expr_handle, context, true)?;
3414 writeln!(self.out, ";")?;
3415 write!(self.out, "{level}return {struct_name} {{")?;
3416
3417 let mut is_first = true;
3418
3419 for (index, member) in members.iter().enumerate() {
3420 if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3421 member.binding
3422 {
3423 has_point_size = true;
3424 if !context.pipeline_options.allow_and_force_point_size {
3425 continue;
3426 }
3427 }
3428
3429 let comma = if is_first { "" } else { "," };
3430 is_first = false;
3431 let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3432 if let crate::TypeInner::Array {
3436 size: crate::ArraySize::Constant(size),
3437 ..
3438 } = context.module.types[member.ty].inner
3439 {
3440 write!(self.out, "{comma} {{")?;
3441 for j in 0..size.get() {
3442 if j != 0 {
3443 write!(self.out, ",")?;
3444 }
3445 write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3446 }
3447 write!(self.out, "}}")?;
3448 } else {
3449 write!(self.out, "{comma} {tmp}.{name}")?;
3450 }
3451 }
3452 }
3453 _ => {
3454 write!(self.out, "{level}return {struct_name} {{ ")?;
3455 self.put_expression(expr_handle, context, true)?;
3456 }
3457 }
3458
3459 if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3460 let stage = context.module.entry_points[ep_index as usize].stage;
3461 if context.pipeline_options.allow_and_force_point_size
3462 && stage == crate::ShaderStage::Vertex
3463 && !has_point_size
3464 {
3465 write!(self.out, ", 1.0")?;
3467 }
3468 }
3469 write!(self.out, " }}")?;
3470 }
3471 None => {
3472 write!(self.out, "{level}return ")?;
3473 self.put_expression(expr_handle, context, true)?;
3474 }
3475 }
3476 writeln!(self.out, ";")?;
3477 Ok(())
3478 }
3479
3480 fn update_expressions_to_bake(
3485 &mut self,
3486 func: &crate::Function,
3487 info: &valid::FunctionInfo,
3488 context: &ExpressionContext,
3489 ) {
3490 use crate::Expression;
3491 self.need_bake_expressions.clear();
3492
3493 for (expr_handle, expr) in func.expressions.iter() {
3494 let expr_info = &info[expr_handle];
3497 let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3498 if min_ref_count <= expr_info.ref_count {
3499 self.need_bake_expressions.insert(expr_handle);
3500 } else {
3501 match expr_info.ty {
3502 TypeResolution::Handle(h)
3504 if Some(h) == context.module.special_types.ray_desc =>
3505 {
3506 self.need_bake_expressions.insert(expr_handle);
3507 }
3508 _ => {}
3509 }
3510 }
3511
3512 if let Expression::Math {
3513 fun,
3514 arg,
3515 arg1,
3516 arg2,
3517 ..
3518 } = *expr
3519 {
3520 match fun {
3521 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3531 self.need_bake_expressions.insert(arg);
3532 self.need_bake_expressions.insert(arg1.unwrap());
3533 }
3534 crate::MathFunction::FirstLeadingBit => {
3535 self.need_bake_expressions.insert(arg);
3536 }
3537 crate::MathFunction::Pack4xI8
3538 | crate::MathFunction::Pack4xU8
3539 | crate::MathFunction::Pack4xI8Clamp
3540 | crate::MathFunction::Pack4xU8Clamp
3541 | crate::MathFunction::Unpack4xI8
3542 | crate::MathFunction::Unpack4xU8 => {
3543 if context.lang_version < (2, 1) {
3546 self.need_bake_expressions.insert(arg);
3547 }
3548 }
3549 crate::MathFunction::ExtractBits => {
3550 self.need_bake_expressions.insert(arg1.unwrap());
3552 }
3553 crate::MathFunction::InsertBits => {
3554 self.need_bake_expressions.insert(arg2.unwrap());
3556 }
3557 crate::MathFunction::Sign => {
3558 let inner = context.resolve_type(expr_handle);
3563 if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3564 self.need_bake_expressions.insert(arg);
3565 }
3566 }
3567 _ => {}
3568 }
3569 }
3570 }
3571 }
3572
3573 fn start_baking_expression(
3574 &mut self,
3575 handle: Handle<crate::Expression>,
3576 context: &ExpressionContext,
3577 name: &str,
3578 ) -> BackendResult {
3579 match context.info[handle].ty {
3580 TypeResolution::Handle(ty_handle) => {
3581 let ty_name = TypeContext {
3582 handle: ty_handle,
3583 gctx: context.module.to_ctx(),
3584 names: &self.names,
3585 access: crate::StorageAccess::empty(),
3586 first_time: false,
3587 };
3588 write!(self.out, "{ty_name}")?;
3589 }
3590 TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3591 put_numeric_type(&mut self.out, scalar, &[])?;
3592 }
3593 TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3594 put_numeric_type(&mut self.out, scalar, &[size])?;
3595 }
3596 TypeResolution::Value(crate::TypeInner::Matrix {
3597 columns,
3598 rows,
3599 scalar,
3600 }) => {
3601 put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3602 }
3603 TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3604 columns,
3605 rows,
3606 scalar,
3607 role: _,
3608 }) => {
3609 write!(
3610 self.out,
3611 "{}::simdgroup_{}{}x{}",
3612 NAMESPACE,
3613 scalar.to_msl_name(),
3614 columns as u32,
3615 rows as u32,
3616 )?;
3617 }
3618 TypeResolution::Value(ref other) => {
3619 log::warn!("Type {other:?} isn't a known local");
3620 return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3621 }
3622 }
3623
3624 write!(self.out, " {name} = ")?;
3626
3627 Ok(())
3628 }
3629
3630 fn put_cache_restricted_level(
3643 &mut self,
3644 load: Handle<crate::Expression>,
3645 image: Handle<crate::Expression>,
3646 mip_level: Option<Handle<crate::Expression>>,
3647 indent: back::Level,
3648 context: &StatementContext,
3649 ) -> BackendResult {
3650 let level_of_detail = match mip_level {
3653 Some(level) => level,
3654 None => return Ok(()),
3655 };
3656
3657 if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3658 || !context.expression.image_needs_lod(image)
3659 {
3660 return Ok(());
3661 }
3662
3663 write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3664 self.put_restricted_scalar_image_index(
3665 image,
3666 level_of_detail,
3667 "get_num_mip_levels",
3668 &context.expression,
3669 )?;
3670 writeln!(self.out, ";")?;
3671
3672 Ok(())
3673 }
3674
3675 fn put_casting_to_packed_chars(
3681 &mut self,
3682 fun: crate::MathFunction,
3683 arg0: Handle<crate::Expression>,
3684 arg1: Handle<crate::Expression>,
3685 indent: back::Level,
3686 context: &StatementContext<'_>,
3687 ) -> Result<(), Error> {
3688 let packed_type = match fun {
3689 crate::MathFunction::Dot4I8Packed => "packed_char4",
3690 crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3691 _ => unreachable!(),
3692 };
3693
3694 for arg in [arg0, arg1] {
3695 write!(
3696 self.out,
3697 "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3698 Reinterpreted::new(packed_type, arg)
3699 )?;
3700 self.put_expression(arg, &context.expression, true)?;
3701 writeln!(self.out, ");")?;
3702 }
3703
3704 Ok(())
3705 }
3706
3707 fn put_block(
3708 &mut self,
3709 level: back::Level,
3710 statements: &[crate::Statement],
3711 context: &StatementContext,
3712 ) -> BackendResult {
3713 #[cfg(test)]
3715 self.put_block_stack_pointers
3716 .insert(ptr::from_ref(&level).cast());
3717
3718 for statement in statements {
3719 log::trace!("statement[{}] {:?}", level.0, statement);
3720 match *statement {
3721 crate::Statement::Emit(ref range) => {
3722 for handle in range.clone() {
3723 use crate::MathFunction as Mf;
3724
3725 match context.expression.function.expressions[handle] {
3726 crate::Expression::ImageLoad {
3729 image,
3730 level: mip_level,
3731 ..
3732 } => {
3733 self.put_cache_restricted_level(
3734 handle, image, mip_level, level, context,
3735 )?;
3736 }
3737
3738 crate::Expression::Math {
3747 fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3748 arg,
3749 arg1,
3750 ..
3751 } if context.expression.lang_version >= (2, 1) => {
3752 self.put_casting_to_packed_chars(
3753 fun,
3754 arg,
3755 arg1.unwrap(),
3756 level,
3757 context,
3758 )?;
3759 }
3760
3761 _ => (),
3762 }
3763
3764 let ptr_class = context.expression.resolve_type(handle).pointer_space();
3765 let expr_name = if ptr_class.is_some() {
3766 None } else if let Some(name) =
3768 context.expression.function.named_expressions.get(&handle)
3769 {
3770 Some(self.namer.call(name))
3780 } else {
3781 let bake = if context.expression.guarded_indices.contains(handle) {
3785 true
3786 } else {
3787 self.need_bake_expressions.contains(&handle)
3788 };
3789
3790 if bake {
3791 Some(Baked(handle).to_string())
3792 } else {
3793 None
3794 }
3795 };
3796
3797 if let Some(name) = expr_name {
3798 write!(self.out, "{level}")?;
3799 self.start_baking_expression(handle, &context.expression, &name)?;
3800 self.put_expression(handle, &context.expression, true)?;
3801 self.named_expressions.insert(handle, name);
3802 writeln!(self.out, ";")?;
3803 }
3804 }
3805 }
3806 crate::Statement::Block(ref block) => {
3807 if !block.is_empty() {
3808 writeln!(self.out, "{level}{{")?;
3809 self.put_block(level.next(), block, context)?;
3810 writeln!(self.out, "{level}}}")?;
3811 }
3812 }
3813 crate::Statement::If {
3814 condition,
3815 ref accept,
3816 ref reject,
3817 } => {
3818 write!(self.out, "{level}if (")?;
3819 self.put_expression(condition, &context.expression, true)?;
3820 writeln!(self.out, ") {{")?;
3821 self.put_block(level.next(), accept, context)?;
3822 if !reject.is_empty() {
3823 writeln!(self.out, "{level}}} else {{")?;
3824 self.put_block(level.next(), reject, context)?;
3825 }
3826 writeln!(self.out, "{level}}}")?;
3827 }
3828 crate::Statement::Switch {
3829 selector,
3830 ref cases,
3831 } => {
3832 write!(self.out, "{level}switch(")?;
3833 self.put_expression(selector, &context.expression, true)?;
3834 writeln!(self.out, ") {{")?;
3835 let lcase = level.next();
3836 for case in cases.iter() {
3837 match case.value {
3838 crate::SwitchValue::I32(value) => {
3839 write!(self.out, "{lcase}case {value}:")?;
3840 }
3841 crate::SwitchValue::U32(value) => {
3842 write!(self.out, "{lcase}case {value}u:")?;
3843 }
3844 crate::SwitchValue::Default => {
3845 write!(self.out, "{lcase}default:")?;
3846 }
3847 }
3848
3849 let write_block_braces = !(case.fall_through && case.body.is_empty());
3850 if write_block_braces {
3851 writeln!(self.out, " {{")?;
3852 } else {
3853 writeln!(self.out)?;
3854 }
3855
3856 self.put_block(lcase.next(), &case.body, context)?;
3857 if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3858 {
3859 writeln!(self.out, "{}break;", lcase.next())?;
3860 }
3861
3862 if write_block_braces {
3863 writeln!(self.out, "{lcase}}}")?;
3864 }
3865 }
3866 writeln!(self.out, "{level}}}")?;
3867 }
3868 crate::Statement::Loop {
3869 ref body,
3870 ref continuing,
3871 break_if,
3872 } => {
3873 let force_loop_bound_statements =
3874 self.gen_force_bounded_loop_statements(level, context);
3875 let gate_name = (!continuing.is_empty() || break_if.is_some())
3876 .then(|| self.namer.call("loop_init"));
3877
3878 if let Some((ref decl, _)) = force_loop_bound_statements {
3879 writeln!(self.out, "{decl}")?;
3880 }
3881 if let Some(ref gate_name) = gate_name {
3882 writeln!(self.out, "{level}bool {gate_name} = true;")?;
3883 }
3884
3885 writeln!(self.out, "{level}while(true) {{",)?;
3886 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3887 writeln!(self.out, "{break_and_inc}")?;
3888 }
3889 if let Some(ref gate_name) = gate_name {
3890 let lif = level.next();
3891 let lcontinuing = lif.next();
3892 writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3893 self.put_block(lcontinuing, continuing, context)?;
3894 if let Some(condition) = break_if {
3895 write!(self.out, "{lcontinuing}if (")?;
3896 self.put_expression(condition, &context.expression, true)?;
3897 writeln!(self.out, ") {{")?;
3898 writeln!(self.out, "{}break;", lcontinuing.next())?;
3899 writeln!(self.out, "{lcontinuing}}}")?;
3900 }
3901 writeln!(self.out, "{lif}}}")?;
3902 writeln!(self.out, "{lif}{gate_name} = false;")?;
3903 }
3904 self.put_block(level.next(), body, context)?;
3905
3906 writeln!(self.out, "{level}}}")?;
3907 }
3908 crate::Statement::Break => {
3909 writeln!(self.out, "{level}break;")?;
3910 }
3911 crate::Statement::Continue => {
3912 writeln!(self.out, "{level}continue;")?;
3913 }
3914 crate::Statement::Return {
3915 value: Some(expr_handle),
3916 } => {
3917 self.put_return_value(
3918 level,
3919 expr_handle,
3920 context.result_struct,
3921 &context.expression,
3922 )?;
3923 }
3924 crate::Statement::Return { value: None } => {
3925 writeln!(self.out, "{level}return;")?;
3926 }
3927 crate::Statement::Kill => {
3928 writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3929 }
3930 crate::Statement::ControlBarrier(flags)
3931 | crate::Statement::MemoryBarrier(flags) => {
3932 self.write_barrier(flags, level)?;
3933 }
3934 crate::Statement::Store { pointer, value } => {
3935 self.put_store(pointer, value, level, context)?
3936 }
3937 crate::Statement::ImageStore {
3938 image,
3939 coordinate,
3940 array_index,
3941 value,
3942 } => {
3943 let address = TexelAddress {
3944 coordinate,
3945 array_index,
3946 sample: None,
3947 level: None,
3948 };
3949 self.put_image_store(level, image, &address, value, context)?
3950 }
3951 crate::Statement::Call {
3952 function,
3953 ref arguments,
3954 result,
3955 } => {
3956 write!(self.out, "{level}")?;
3957 if let Some(expr) = result {
3958 let name = Baked(expr).to_string();
3959 self.start_baking_expression(expr, &context.expression, &name)?;
3960 self.named_expressions.insert(expr, name);
3961 }
3962 let fun_name = &self.names[&NameKey::Function(function)];
3963 write!(self.out, "{fun_name}(")?;
3964 for (i, &handle) in arguments.iter().enumerate() {
3966 if i != 0 {
3967 write!(self.out, ", ")?;
3968 }
3969 self.put_expression(handle, &context.expression, true)?;
3970 }
3971 let mut separate = !arguments.is_empty();
3973 let fun_info = &context.expression.mod_info[function];
3974 let mut needs_buffer_sizes = false;
3975 for (handle, var) in context.expression.module.global_variables.iter() {
3976 if fun_info[handle].is_empty() {
3977 continue;
3978 }
3979 if var.space.needs_pass_through() {
3980 let name = &self.names[&NameKey::GlobalVariable(handle)];
3981 if separate {
3982 write!(self.out, ", ")?;
3983 } else {
3984 separate = true;
3985 }
3986 write!(self.out, "{name}")?;
3987 }
3988 needs_buffer_sizes |=
3989 needs_array_length(var.ty, &context.expression.module.types);
3990 }
3991 if needs_buffer_sizes {
3992 if separate {
3993 write!(self.out, ", ")?;
3994 }
3995 write!(self.out, "_buffer_sizes")?;
3996 }
3997
3998 writeln!(self.out, ");")?;
4000 }
4001 crate::Statement::Atomic {
4002 pointer,
4003 ref fun,
4004 value,
4005 result,
4006 } => {
4007 let context = &context.expression;
4008
4009 write!(self.out, "{level}")?;
4014 let fun_key = if let Some(result) = result {
4015 let res_name = Baked(result).to_string();
4016 self.start_baking_expression(result, context, &res_name)?;
4017 self.named_expressions.insert(result, res_name);
4018 fun.to_msl()
4019 } else if context.resolve_type(value).scalar_width() == Some(8) {
4020 fun.to_msl_64_bit()?
4021 } else {
4022 fun.to_msl()
4023 };
4024
4025 let policy = context.choose_bounds_check_policy(pointer);
4029 let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4030 && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
4031
4032 if checked {
4034 write!(self.out, " ? ")?;
4035 }
4036
4037 match *fun {
4039 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4040 write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
4041 self.put_access_chain(pointer, policy, context)?;
4042 write!(self.out, ", ")?;
4043 self.put_expression(cmp, context, true)?;
4044 write!(self.out, ", ")?;
4045 self.put_expression(value, context, true)?;
4046 write!(self.out, ")")?;
4047 }
4048 _ => {
4049 write!(
4050 self.out,
4051 "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
4052 )?;
4053 self.put_access_chain(pointer, policy, context)?;
4054 write!(self.out, ", ")?;
4055 self.put_expression(value, context, true)?;
4056 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
4057 }
4058 }
4059
4060 if checked {
4062 write!(self.out, " : DefaultConstructible()")?;
4063 }
4064
4065 writeln!(self.out, ";")?;
4067 }
4068 crate::Statement::ImageAtomic {
4069 image,
4070 coordinate,
4071 array_index,
4072 fun,
4073 value,
4074 } => {
4075 let address = TexelAddress {
4076 coordinate,
4077 array_index,
4078 sample: None,
4079 level: None,
4080 };
4081 self.put_image_atomic(level, image, &address, fun, value, context)?
4082 }
4083 crate::Statement::WorkGroupUniformLoad { pointer, result } => {
4084 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4085
4086 write!(self.out, "{level}")?;
4087 let name = self.namer.call("");
4088 self.start_baking_expression(result, &context.expression, &name)?;
4089 self.put_load(pointer, &context.expression, true)?;
4090 self.named_expressions.insert(result, name);
4091
4092 writeln!(self.out, ";")?;
4093 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4094 }
4095 crate::Statement::RayQuery { query, ref fun } => {
4096 if context.expression.lang_version < (2, 4) {
4097 return Err(Error::UnsupportedRayTracing);
4098 }
4099
4100 match *fun {
4101 crate::RayQueryFunction::Initialize {
4102 acceleration_structure,
4103 descriptor,
4104 } => {
4105 write!(self.out, "{level}")?;
4107 self.put_expression(query, &context.expression, true)?;
4108 writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
4109 {
4110 let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
4111 let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
4112 write!(self.out, "{level}")?;
4113 self.put_expression(query, &context.expression, true)?;
4114 write!(
4115 self.out,
4116 ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
4117 )?;
4118 self.put_expression(descriptor, &context.expression, true)?;
4119 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
4120 self.put_expression(descriptor, &context.expression, true)?;
4121 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
4122 writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
4123 }
4124 {
4125 let f_opaque = back::RayFlag::OPAQUE.bits();
4126 let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
4127 write!(self.out, "{level}")?;
4128 self.put_expression(query, &context.expression, true)?;
4129 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
4130 self.put_expression(descriptor, &context.expression, true)?;
4131 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
4132 self.put_expression(descriptor, &context.expression, true)?;
4133 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
4134 writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
4135 }
4136 {
4137 let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
4138 write!(self.out, "{level}")?;
4139 self.put_expression(query, &context.expression, true)?;
4140 write!(
4141 self.out,
4142 ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
4143 )?;
4144 self.put_expression(descriptor, &context.expression, true)?;
4145 writeln!(self.out, ".flags & {flag}) != 0);")?;
4146 }
4147
4148 write!(self.out, "{level}")?;
4149 self.put_expression(query, &context.expression, true)?;
4150 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
4151 self.put_expression(query, &context.expression, true)?;
4152 write!(
4153 self.out,
4154 ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
4155 )?;
4156 self.put_expression(descriptor, &context.expression, true)?;
4157 write!(self.out, ".origin, ")?;
4158 self.put_expression(descriptor, &context.expression, true)?;
4159 write!(self.out, ".dir, ")?;
4160 self.put_expression(descriptor, &context.expression, true)?;
4161 write!(self.out, ".tmin, ")?;
4162 self.put_expression(descriptor, &context.expression, true)?;
4163 write!(self.out, ".tmax), ")?;
4164 self.put_expression(acceleration_structure, &context.expression, true)?;
4165 write!(self.out, ", ")?;
4166 self.put_expression(descriptor, &context.expression, true)?;
4167 write!(self.out, ".cull_mask);")?;
4168
4169 write!(self.out, "{level}")?;
4170 self.put_expression(query, &context.expression, true)?;
4171 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
4172 }
4173 crate::RayQueryFunction::Proceed { result } => {
4174 write!(self.out, "{level}")?;
4175 let name = Baked(result).to_string();
4176 self.start_baking_expression(result, &context.expression, &name)?;
4177 self.named_expressions.insert(result, name);
4178 self.put_expression(query, &context.expression, true)?;
4179 writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
4180 if RAY_QUERY_MODERN_SUPPORT {
4181 write!(self.out, "{level}")?;
4182 self.put_expression(query, &context.expression, true)?;
4183 writeln!(self.out, ".?.next();")?;
4184 }
4185 }
4186 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
4187 if RAY_QUERY_MODERN_SUPPORT {
4188 write!(self.out, "{level}")?;
4189 self.put_expression(query, &context.expression, true)?;
4190 write!(self.out, ".?.commit_bounding_box_intersection(")?;
4191 self.put_expression(hit_t, &context.expression, true)?;
4192 writeln!(self.out, ");")?;
4193 } else {
4194 log::warn!("Ray Query GenerateIntersection is not yet supported");
4195 }
4196 }
4197 crate::RayQueryFunction::ConfirmIntersection => {
4198 if RAY_QUERY_MODERN_SUPPORT {
4199 write!(self.out, "{level}")?;
4200 self.put_expression(query, &context.expression, true)?;
4201 writeln!(self.out, ".?.commit_triangle_intersection();")?;
4202 } else {
4203 log::warn!("Ray Query ConfirmIntersection is not yet supported");
4204 }
4205 }
4206 crate::RayQueryFunction::Terminate => {
4207 if RAY_QUERY_MODERN_SUPPORT {
4208 write!(self.out, "{level}")?;
4209 self.put_expression(query, &context.expression, true)?;
4210 writeln!(self.out, ".?.abort();")?;
4211 }
4212 write!(self.out, "{level}")?;
4213 self.put_expression(query, &context.expression, true)?;
4214 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
4215 }
4216 }
4217 }
4218 crate::Statement::SubgroupBallot { result, predicate } => {
4219 write!(self.out, "{level}")?;
4220 let name = self.namer.call("");
4221 self.start_baking_expression(result, &context.expression, &name)?;
4222 self.named_expressions.insert(result, name);
4223 write!(
4224 self.out,
4225 "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4226 )?;
4227 if let Some(predicate) = predicate {
4228 self.put_expression(predicate, &context.expression, true)?;
4229 } else {
4230 write!(self.out, "true")?;
4231 }
4232 writeln!(self.out, "), 0, 0, 0);")?;
4233 }
4234 crate::Statement::SubgroupCollectiveOperation {
4235 op,
4236 collective_op,
4237 argument,
4238 result,
4239 } => {
4240 write!(self.out, "{level}")?;
4241 let name = self.namer.call("");
4242 self.start_baking_expression(result, &context.expression, &name)?;
4243 self.named_expressions.insert(result, name);
4244 match (collective_op, op) {
4245 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4246 write!(self.out, "{NAMESPACE}::simd_all(")?
4247 }
4248 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4249 write!(self.out, "{NAMESPACE}::simd_any(")?
4250 }
4251 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4252 write!(self.out, "{NAMESPACE}::simd_sum(")?
4253 }
4254 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4255 write!(self.out, "{NAMESPACE}::simd_product(")?
4256 }
4257 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4258 write!(self.out, "{NAMESPACE}::simd_max(")?
4259 }
4260 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4261 write!(self.out, "{NAMESPACE}::simd_min(")?
4262 }
4263 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4264 write!(self.out, "{NAMESPACE}::simd_and(")?
4265 }
4266 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4267 write!(self.out, "{NAMESPACE}::simd_or(")?
4268 }
4269 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4270 write!(self.out, "{NAMESPACE}::simd_xor(")?
4271 }
4272 (
4273 crate::CollectiveOperation::ExclusiveScan,
4274 crate::SubgroupOperation::Add,
4275 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4276 (
4277 crate::CollectiveOperation::ExclusiveScan,
4278 crate::SubgroupOperation::Mul,
4279 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4280 (
4281 crate::CollectiveOperation::InclusiveScan,
4282 crate::SubgroupOperation::Add,
4283 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4284 (
4285 crate::CollectiveOperation::InclusiveScan,
4286 crate::SubgroupOperation::Mul,
4287 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4288 _ => unimplemented!(),
4289 }
4290 self.put_expression(argument, &context.expression, true)?;
4291 writeln!(self.out, ");")?;
4292 }
4293 crate::Statement::SubgroupGather {
4294 mode,
4295 argument,
4296 result,
4297 } => {
4298 write!(self.out, "{level}")?;
4299 let name = self.namer.call("");
4300 self.start_baking_expression(result, &context.expression, &name)?;
4301 self.named_expressions.insert(result, name);
4302 match mode {
4303 crate::GatherMode::BroadcastFirst => {
4304 write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4305 }
4306 crate::GatherMode::Broadcast(_) => {
4307 write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4308 }
4309 crate::GatherMode::Shuffle(_) => {
4310 write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4311 }
4312 crate::GatherMode::ShuffleDown(_) => {
4313 write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4314 }
4315 crate::GatherMode::ShuffleUp(_) => {
4316 write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4317 }
4318 crate::GatherMode::ShuffleXor(_) => {
4319 write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4320 }
4321 crate::GatherMode::QuadBroadcast(_) => {
4322 write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4323 }
4324 crate::GatherMode::QuadSwap(_) => {
4325 write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4326 }
4327 }
4328 self.put_expression(argument, &context.expression, true)?;
4329 match mode {
4330 crate::GatherMode::BroadcastFirst => {}
4331 crate::GatherMode::Broadcast(index)
4332 | crate::GatherMode::Shuffle(index)
4333 | crate::GatherMode::ShuffleDown(index)
4334 | crate::GatherMode::ShuffleUp(index)
4335 | crate::GatherMode::ShuffleXor(index)
4336 | crate::GatherMode::QuadBroadcast(index) => {
4337 write!(self.out, ", ")?;
4338 self.put_expression(index, &context.expression, true)?;
4339 }
4340 crate::GatherMode::QuadSwap(direction) => {
4341 write!(self.out, ", ")?;
4342 match direction {
4343 crate::Direction::X => {
4344 write!(self.out, "1u")?;
4345 }
4346 crate::Direction::Y => {
4347 write!(self.out, "2u")?;
4348 }
4349 crate::Direction::Diagonal => {
4350 write!(self.out, "3u")?;
4351 }
4352 }
4353 }
4354 }
4355 writeln!(self.out, ");")?;
4356 }
4357 crate::Statement::CooperativeStore { target, ref data } => {
4358 write!(self.out, "{level}simdgroup_store(")?;
4359 self.put_expression(target, &context.expression, true)?;
4360 write!(self.out, ", &")?;
4361 self.put_access_chain(
4362 data.pointer,
4363 context.expression.policies.index,
4364 &context.expression,
4365 )?;
4366 write!(self.out, ", ")?;
4367 self.put_expression(data.stride, &context.expression, true)?;
4368 if data.row_major {
4369 let matrix_origin = "0";
4370 let transpose = true;
4371 write!(self.out, ", {matrix_origin}, {transpose}")?;
4372 }
4373 writeln!(self.out, ");")?;
4374 }
4375 crate::Statement::RayPipelineFunction(_) => unreachable!(),
4376 }
4377 }
4378
4379 for statement in statements {
4382 if let crate::Statement::Emit(ref range) = *statement {
4383 for handle in range.clone() {
4384 self.named_expressions.shift_remove(&handle);
4385 }
4386 }
4387 }
4388 Ok(())
4389 }
4390
4391 fn put_store(
4392 &mut self,
4393 pointer: Handle<crate::Expression>,
4394 value: Handle<crate::Expression>,
4395 level: back::Level,
4396 context: &StatementContext,
4397 ) -> BackendResult {
4398 let policy = context.expression.choose_bounds_check_policy(pointer);
4399 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4400 && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4401 {
4402 writeln!(self.out, ") {{")?;
4403 self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4404 writeln!(self.out, "{level}}}")?;
4405 } else {
4406 self.put_unchecked_store(pointer, value, policy, level, context)?;
4407 }
4408
4409 Ok(())
4410 }
4411
4412 fn put_unchecked_store(
4413 &mut self,
4414 pointer: Handle<crate::Expression>,
4415 value: Handle<crate::Expression>,
4416 policy: index::BoundsCheckPolicy,
4417 level: back::Level,
4418 context: &StatementContext,
4419 ) -> BackendResult {
4420 let is_atomic_pointer = context
4421 .expression
4422 .resolve_type(pointer)
4423 .is_atomic_pointer(&context.expression.module.types);
4424
4425 if is_atomic_pointer {
4426 write!(
4427 self.out,
4428 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4429 )?;
4430 self.put_access_chain(pointer, policy, &context.expression)?;
4431 write!(self.out, ", ")?;
4432 self.put_expression(value, &context.expression, true)?;
4433 writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4434 } else {
4435 write!(self.out, "{level}")?;
4436 self.put_access_chain(pointer, policy, &context.expression)?;
4437 write!(self.out, " = ")?;
4438 self.put_expression(value, &context.expression, true)?;
4439 writeln!(self.out, ";")?;
4440 }
4441
4442 Ok(())
4443 }
4444
4445 pub fn write(
4446 &mut self,
4447 module: &crate::Module,
4448 info: &valid::ModuleInfo,
4449 options: &Options,
4450 pipeline_options: &PipelineOptions,
4451 ) -> Result<TranslationInfo, Error> {
4452 self.names.clear();
4453 self.namer.reset(
4454 module,
4455 &super::keywords::RESERVED_SET,
4456 proc::KeywordSet::empty(),
4457 proc::CaseInsensitiveKeywordSet::empty(),
4458 &[CLAMPED_LOD_LOAD_PREFIX],
4459 &mut self.names,
4460 );
4461 self.wrapped_functions.clear();
4462 self.struct_member_pads.clear();
4463
4464 writeln!(
4465 self.out,
4466 "// language: metal{}.{}",
4467 options.lang_version.0, options.lang_version.1
4468 )?;
4469 writeln!(self.out, "#include <metal_stdlib>")?;
4470 writeln!(self.out, "#include <simd/simd.h>")?;
4471 writeln!(self.out)?;
4472 writeln!(self.out, "using {NAMESPACE}::uint;")?;
4474
4475 if module.uses_mesh_shaders() && options.lang_version < (3, 0) {
4476 return Err(Error::UnsupportedMeshShader);
4477 }
4478 self.needs_object_memory_barriers = module
4479 .entry_points
4480 .iter()
4481 .any(|e| e.stage == crate::ShaderStage::Task && e.task_payload.is_some());
4482
4483 let mut uses_ray_query = false;
4484 for (_, ty) in module.types.iter() {
4485 match ty.inner {
4486 crate::TypeInner::AccelerationStructure { .. } => {
4487 if options.lang_version < (2, 4) {
4488 return Err(Error::UnsupportedRayTracing);
4489 }
4490 }
4491 crate::TypeInner::RayQuery { .. } => {
4492 if options.lang_version < (2, 4) {
4493 return Err(Error::UnsupportedRayTracing);
4494 }
4495 uses_ray_query = true;
4496 }
4497 _ => (),
4498 }
4499 }
4500
4501 if module.special_types.ray_desc.is_some()
4502 || module.special_types.ray_intersection.is_some()
4503 {
4504 if options.lang_version < (2, 4) {
4505 return Err(Error::UnsupportedRayTracing);
4506 }
4507 }
4508
4509 if uses_ray_query {
4510 self.put_ray_query_type()?;
4511 }
4512
4513 if options
4514 .bounds_check_policies
4515 .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4516 {
4517 self.put_default_constructible()?;
4518 }
4519 writeln!(self.out)?;
4520
4521 {
4522 let globals: Vec<Handle<crate::GlobalVariable>> = module
4525 .global_variables
4526 .iter()
4527 .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4528 .map(|(handle, _)| handle)
4529 .collect();
4530
4531 let mut buffer_indices = vec![];
4532 for vbm in &pipeline_options.vertex_buffer_mappings {
4533 buffer_indices.push(vbm.id);
4534 }
4535
4536 if !globals.is_empty() || !buffer_indices.is_empty() {
4537 writeln!(self.out, "struct _mslBufferSizes {{")?;
4538
4539 for global in globals {
4540 writeln!(
4541 self.out,
4542 "{}uint {};",
4543 back::INDENT,
4544 ArraySizeMember(global)
4545 )?;
4546 }
4547
4548 for idx in buffer_indices {
4549 writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4550 }
4551
4552 writeln!(self.out, "}};")?;
4553 writeln!(self.out)?;
4554 }
4555 };
4556
4557 self.write_type_defs(module)?;
4558 self.write_global_constants(module, info)?;
4559 self.write_functions(module, info, options, pipeline_options)
4560 }
4561
4562 fn put_default_constructible(&mut self) -> BackendResult {
4575 let tab = back::INDENT;
4576 writeln!(self.out, "struct DefaultConstructible {{")?;
4577 writeln!(self.out, "{tab}template<typename T>")?;
4578 writeln!(self.out, "{tab}operator T() && {{")?;
4579 writeln!(self.out, "{tab}{tab}return T {{}};")?;
4580 writeln!(self.out, "{tab}}}")?;
4581 writeln!(self.out, "}};")?;
4582 Ok(())
4583 }
4584
4585 fn put_ray_query_type(&mut self) -> BackendResult {
4586 let tab = back::INDENT;
4587 writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
4588 let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
4589 writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
4590 writeln!(
4591 self.out,
4592 "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
4593 )?;
4594 writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
4595 writeln!(self.out, "}};")?;
4596 writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
4597 let v_triangle = back::RayIntersectionType::Triangle as u32;
4598 let v_bbox = back::RayIntersectionType::BoundingBox as u32;
4599 writeln!(
4600 self.out,
4601 "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
4602 )?;
4603 writeln!(
4604 self.out,
4605 "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
4606 )?;
4607 writeln!(self.out, "}}")?;
4608 Ok(())
4609 }
4610
4611 fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4612 let mut generated_argument_buffer_wrapper = false;
4613 let mut generated_external_texture_wrapper = false;
4614 for (handle, ty) in module.types.iter() {
4615 match ty.inner {
4616 crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4617 writeln!(self.out, "template <typename T>")?;
4618 writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4619 writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4620 writeln!(self.out, "}};")?;
4621 generated_argument_buffer_wrapper = true;
4622 }
4623 crate::TypeInner::Image {
4624 class: crate::ImageClass::External,
4625 ..
4626 } if !generated_external_texture_wrapper => {
4627 let params_ty_name = &self.names
4628 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4629 writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4630 writeln!(
4631 self.out,
4632 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4633 back::INDENT
4634 )?;
4635 writeln!(
4636 self.out,
4637 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4638 back::INDENT
4639 )?;
4640 writeln!(
4641 self.out,
4642 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4643 back::INDENT
4644 )?;
4645 writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4646 writeln!(self.out, "}};")?;
4647 generated_external_texture_wrapper = true;
4648 }
4649 _ => {}
4650 }
4651
4652 if !ty.needs_alias() {
4653 continue;
4654 }
4655 let name = &self.names[&NameKey::Type(handle)];
4656 match ty.inner {
4657 crate::TypeInner::Array {
4671 base,
4672 size,
4673 stride: _,
4674 } => {
4675 let base_name = TypeContext {
4676 handle: base,
4677 gctx: module.to_ctx(),
4678 names: &self.names,
4679 access: crate::StorageAccess::empty(),
4680 first_time: false,
4681 };
4682
4683 match size.resolve(module.to_ctx())? {
4684 proc::IndexableLength::Known(size) => {
4685 writeln!(self.out, "struct {name} {{")?;
4686 writeln!(
4687 self.out,
4688 "{}{} {}[{}];",
4689 back::INDENT,
4690 base_name,
4691 WRAPPED_ARRAY_FIELD,
4692 size
4693 )?;
4694 writeln!(self.out, "}};")?;
4695 }
4696 proc::IndexableLength::Dynamic => {
4697 writeln!(self.out, "typedef {base_name} {name}[1];")?;
4698 }
4699 }
4700 }
4701 crate::TypeInner::Struct {
4702 ref members, span, ..
4703 } => {
4704 writeln!(self.out, "struct {name} {{")?;
4705 let mut last_offset = 0;
4706 for (index, member) in members.iter().enumerate() {
4707 if member.offset > last_offset {
4708 self.struct_member_pads.insert((handle, index as u32));
4709 let pad = member.offset - last_offset;
4710 writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4711 }
4712 let ty_inner = &module.types[member.ty].inner;
4713 last_offset = member.offset + ty_inner.size(module.to_ctx());
4714
4715 let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4716
4717 match should_pack_struct_member(members, span, index, module) {
4719 Some(scalar) => {
4720 writeln!(
4721 self.out,
4722 "{}{}::packed_{}3 {};",
4723 back::INDENT,
4724 NAMESPACE,
4725 scalar.to_msl_name(),
4726 member_name
4727 )?;
4728 }
4729 None => {
4730 let base_name = TypeContext {
4731 handle: member.ty,
4732 gctx: module.to_ctx(),
4733 names: &self.names,
4734 access: crate::StorageAccess::empty(),
4735 first_time: false,
4736 };
4737 writeln!(
4738 self.out,
4739 "{}{} {};",
4740 back::INDENT,
4741 base_name,
4742 member_name
4743 )?;
4744
4745 if let crate::TypeInner::Vector {
4747 size: crate::VectorSize::Tri,
4748 scalar,
4749 } = *ty_inner
4750 {
4751 last_offset += scalar.width as u32;
4752 }
4753 }
4754 }
4755 }
4756 if last_offset < span {
4757 let pad = span - last_offset;
4758 writeln!(
4759 self.out,
4760 "{}char _pad{}[{}];",
4761 back::INDENT,
4762 members.len(),
4763 pad
4764 )?;
4765 }
4766 writeln!(self.out, "}};")?;
4767 }
4768 _ => {
4769 let ty_name = TypeContext {
4770 handle,
4771 gctx: module.to_ctx(),
4772 names: &self.names,
4773 access: crate::StorageAccess::empty(),
4774 first_time: true,
4775 };
4776 writeln!(self.out, "typedef {ty_name} {name};")?;
4777 }
4778 }
4779 }
4780
4781 for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4783 match type_key {
4784 &crate::PredeclaredType::ModfResult { size, scalar }
4785 | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4786 let arg_type_name_owner;
4787 let arg_type_name = if let Some(size) = size {
4788 arg_type_name_owner = format!(
4789 "{NAMESPACE}::{}{}",
4790 if scalar.width == 8 { "double" } else { "float" },
4791 size as u8
4792 );
4793 &arg_type_name_owner
4794 } else if scalar.width == 8 {
4795 "double"
4796 } else {
4797 "float"
4798 };
4799
4800 let other_type_name_owner;
4801 let (defined_func_name, called_func_name, other_type_name) =
4802 if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4803 (MODF_FUNCTION, "modf", arg_type_name)
4804 } else {
4805 let other_type_name = if let Some(size) = size {
4806 other_type_name_owner = format!("int{}", size as u8);
4807 &other_type_name_owner
4808 } else {
4809 "int"
4810 };
4811 (FREXP_FUNCTION, "frexp", other_type_name)
4812 };
4813
4814 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4815
4816 writeln!(self.out)?;
4817 writeln!(
4818 self.out,
4819 "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4820 {other_type_name} other;
4821 {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4822 return {struct_name}{{ fract, other }};
4823}}"
4824 )?;
4825 }
4826 &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4827 let arg_type_name = scalar.to_msl_name();
4828 let called_func_name = "atomic_compare_exchange_weak_explicit";
4829 let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4830 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4831
4832 writeln!(self.out)?;
4833
4834 for address_space_name in ["device", "threadgroup"] {
4835 writeln!(
4836 self.out,
4837 "\
4838template <typename A>
4839{struct_name} {defined_func_name}(
4840 {address_space_name} A *atomic_ptr,
4841 {arg_type_name} cmp,
4842 {arg_type_name} v
4843) {{
4844 bool swapped = {NAMESPACE}::{called_func_name}(
4845 atomic_ptr, &cmp, v,
4846 metal::memory_order_relaxed, metal::memory_order_relaxed
4847 );
4848 return {struct_name}{{cmp, swapped}};
4849}}"
4850 )?;
4851 }
4852 }
4853 }
4854 }
4855
4856 Ok(())
4857 }
4858
4859 fn write_global_constants(
4861 &mut self,
4862 module: &crate::Module,
4863 mod_info: &valid::ModuleInfo,
4864 ) -> BackendResult {
4865 let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4866
4867 for (handle, constant) in constants {
4868 let ty_name = TypeContext {
4869 handle: constant.ty,
4870 gctx: module.to_ctx(),
4871 names: &self.names,
4872 access: crate::StorageAccess::empty(),
4873 first_time: false,
4874 };
4875 let name = &self.names[&NameKey::Constant(handle)];
4876 write!(self.out, "constant {ty_name} {name} = ")?;
4877 self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4878 writeln!(self.out, ";")?;
4879 }
4880
4881 Ok(())
4882 }
4883
4884 fn put_inline_sampler_properties(
4885 &mut self,
4886 level: back::Level,
4887 sampler: &sm::InlineSampler,
4888 ) -> BackendResult {
4889 for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4890 writeln!(
4891 self.out,
4892 "{}{}::{}_address::{},",
4893 level,
4894 NAMESPACE,
4895 letter,
4896 address.as_str(),
4897 )?;
4898 }
4899 writeln!(
4900 self.out,
4901 "{}{}::mag_filter::{},",
4902 level,
4903 NAMESPACE,
4904 sampler.mag_filter.as_str(),
4905 )?;
4906 writeln!(
4907 self.out,
4908 "{}{}::min_filter::{},",
4909 level,
4910 NAMESPACE,
4911 sampler.min_filter.as_str(),
4912 )?;
4913 if let Some(filter) = sampler.mip_filter {
4914 writeln!(
4915 self.out,
4916 "{}{}::mip_filter::{},",
4917 level,
4918 NAMESPACE,
4919 filter.as_str(),
4920 )?;
4921 }
4922 if sampler.border_color != sm::BorderColor::TransparentBlack {
4924 writeln!(
4925 self.out,
4926 "{}{}::border_color::{},",
4927 level,
4928 NAMESPACE,
4929 sampler.border_color.as_str(),
4930 )?;
4931 }
4932 if false {
4936 if let Some(ref lod) = sampler.lod_clamp {
4937 writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4938 }
4939 if let Some(aniso) = sampler.max_anisotropy {
4940 writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4941 }
4942 }
4943 if sampler.compare_func != sm::CompareFunc::Never {
4944 writeln!(
4945 self.out,
4946 "{}{}::compare_func::{},",
4947 level,
4948 NAMESPACE,
4949 sampler.compare_func.as_str(),
4950 )?;
4951 }
4952 writeln!(
4953 self.out,
4954 "{}{}::coord::{}",
4955 level,
4956 NAMESPACE,
4957 sampler.coord.as_str()
4958 )?;
4959 Ok(())
4960 }
4961
    /// Emits an MSL helper function that rebuilds a vertex attribute of the given
    /// `format` from its individual bytes, and reports how to call that helper.
    ///
    /// The helper's name is reserved through `self.namer`. Every generated helper
    /// takes one parameter per byte (`b0`, `b1`, ...). Multi-byte values are assembled
    /// with `b0` as the least significant byte (`b1 << 8 | b0`,
    /// `b3 << 24 | b2 << 16 | b1 << 8 | b0`).
    ///
    /// Returns `(name, byte_count, vector_size, scalar)`:
    /// - `name`: the helper's MSL name,
    /// - `byte_count`: how many byte parameters the helper takes,
    /// - `vector_size`: the size of the returned vector, or `None` for a scalar,
    /// - `scalar`: the scalar type of the returned value.
    ///
    /// # Errors
    /// Propagates failures from writing to `self.out`.
    fn write_unpacking_function(
        &mut self,
        format: back::msl::VertexFormat,
    ) -> Result<(String, u32, Option<crate::VectorSize>, crate::Scalar), Error> {
        use crate::{Scalar, VectorSize};
        use back::msl::VertexFormat::*;
        match format {
            // 8-bit unsigned integers: each byte widens directly to `uint`.
            Uint8 => {
                let name = self.namer.call("unpackUint8");
                writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
                writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
                writeln!(self.out, "}}")?;
                Ok((name, 1, None, Scalar::U32))
            }
            Uint8x2 => {
                let name = self.namer.call("unpackUint8x2");
                writeln!(
                    self.out,
                    "metal::uint2 {name}(metal::uchar b0, \
                                         metal::uchar b1) {{"
                )?;
                writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, Some(VectorSize::Bi), Scalar::U32))
            }
            Uint8x4 => {
                let name = self.namer.call("unpackUint8x4");
                writeln!(
                    self.out,
                    "metal::uint4 {name}(metal::uchar b0, \
                                         metal::uchar b1, \
                                         metal::uchar b2, \
                                         metal::uchar b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::uint4(b0, b1, b2, b3);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Quad), Scalar::U32))
            }
            // 8-bit signed integers: `as_type<char>` reinterprets each unsigned byte
            // as signed before it is widened to `int`.
            Sint8 => {
                let name = self.namer.call("unpackSint8");
                writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
                writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
                writeln!(self.out, "}}")?;
                Ok((name, 1, None, Scalar::I32))
            }
            Sint8x2 => {
                let name = self.namer.call("unpackSint8x2");
                writeln!(
                    self.out,
                    "metal::int2 {name}(metal::uchar b0, \
                                        metal::uchar b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::int2(as_type<char>(b0), \
                                          as_type<char>(b1));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, Some(VectorSize::Bi), Scalar::I32))
            }
            Sint8x4 => {
                let name = self.namer.call("unpackSint8x4");
                writeln!(
                    self.out,
                    "metal::int4 {name}(metal::uchar b0, \
                                        metal::uchar b1, \
                                        metal::uchar b2, \
                                        metal::uchar b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::int4(as_type<char>(b0), \
                                          as_type<char>(b1), \
                                          as_type<char>(b2), \
                                          as_type<char>(b3));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Quad), Scalar::I32))
            }
            // 8-bit unsigned normalized: each byte is divided by 255 to land in [0, 1].
            Unorm8 => {
                let name = self.namer.call("unpackUnorm8");
                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
                writeln!(
                    self.out,
                    "{}return float(float(b0) / 255.0f);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 1, None, Scalar::F32))
            }
            Unorm8x2 => {
                let name = self.namer.call("unpackUnorm8x2");
                writeln!(
                    self.out,
                    "metal::float2 {name}(metal::uchar b0, \
                                          metal::uchar b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float2(float(b0) / 255.0f, \
                                            float(b1) / 255.0f);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
            }
            Unorm8x4 => {
                let name = self.namer.call("unpackUnorm8x4");
                writeln!(
                    self.out,
                    "metal::float4 {name}(metal::uchar b0, \
                                          metal::uchar b1, \
                                          metal::uchar b2, \
                                          metal::uchar b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float4(float(b0) / 255.0f, \
                                            float(b1) / 255.0f, \
                                            float(b2) / 255.0f, \
                                            float(b3) / 255.0f);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
            }
            // 8-bit signed normalized: the signed byte is divided by 127, and
            // `max(-1.0f, ...)` clamps the -128 case (which would be below -1) to -1.0.
            Snorm8 => {
                let name = self.namer.call("unpackSnorm8");
                writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
                writeln!(
                    self.out,
                    "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 1, None, Scalar::F32))
            }
            Snorm8x2 => {
                let name = self.namer.call("unpackSnorm8x2");
                writeln!(
                    self.out,
                    "metal::float2 {name}(metal::uchar b0, \
                                          metal::uchar b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
            }
            Snorm8x4 => {
                let name = self.namer.call("unpackSnorm8x4");
                writeln!(
                    self.out,
                    "metal::float4 {name}(metal::uchar b0, \
                                          metal::uchar b1, \
                                          metal::uchar b2, \
                                          metal::uchar b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
                                            metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
                                            metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
                                            metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
            }
            // 16-bit unsigned integers: two bytes per component, `b1 << 8 | b0`.
            Uint16 => {
                let name = self.namer.call("unpackUint16");
                writeln!(
                    self.out,
                    "metal::uint {name}(metal::uint b0, \
                                        metal::uint b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::uint(b1 << 8 | b0);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, None, Scalar::U32))
            }
            Uint16x2 => {
                let name = self.namer.call("unpackUint16x2");
                writeln!(
                    self.out,
                    "metal::uint2 {name}(metal::uint b0, \
                                         metal::uint b1, \
                                         metal::uint b2, \
                                         metal::uint b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::uint2(b1 << 8 | b0, \
                                           b3 << 8 | b2);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Bi), Scalar::U32))
            }
            Uint16x4 => {
                let name = self.namer.call("unpackUint16x4");
                writeln!(
                    self.out,
                    "metal::uint4 {name}(metal::uint b0, \
                                         metal::uint b1, \
                                         metal::uint b2, \
                                         metal::uint b3, \
                                         metal::uint b4, \
                                         metal::uint b5, \
                                         metal::uint b6, \
                                         metal::uint b7) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::uint4(b1 << 8 | b0, \
                                           b3 << 8 | b2, \
                                           b5 << 8 | b4, \
                                           b7 << 8 | b6);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 8, Some(VectorSize::Quad), Scalar::U32))
            }
            // 16-bit signed integers: the assembled value is narrowed to `ushort`,
            // then reinterpreted as `short` so the sign bit is honored when widening.
            Sint16 => {
                let name = self.namer.call("unpackSint16");
                writeln!(
                    self.out,
                    "int {name}(metal::ushort b0, \
                                metal::ushort b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, None, Scalar::I32))
            }
            Sint16x2 => {
                let name = self.namer.call("unpackSint16x2");
                writeln!(
                    self.out,
                    "metal::int2 {name}(metal::ushort b0, \
                                        metal::ushort b1, \
                                        metal::ushort b2, \
                                        metal::ushort b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
                                          as_type<short>(metal::ushort(b3 << 8 | b2)));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Bi), Scalar::I32))
            }
            Sint16x4 => {
                let name = self.namer.call("unpackSint16x4");
                writeln!(
                    self.out,
                    "metal::int4 {name}(metal::ushort b0, \
                                        metal::ushort b1, \
                                        metal::ushort b2, \
                                        metal::ushort b3, \
                                        metal::ushort b4, \
                                        metal::ushort b5, \
                                        metal::ushort b6, \
                                        metal::ushort b7) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
                                          as_type<short>(metal::ushort(b3 << 8 | b2)), \
                                          as_type<short>(metal::ushort(b5 << 8 | b4)), \
                                          as_type<short>(metal::ushort(b7 << 8 | b6)));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 8, Some(VectorSize::Quad), Scalar::I32))
            }
            // 16-bit unsigned normalized: the assembled value is divided by 65535.
            Unorm16 => {
                let name = self.namer.call("unpackUnorm16");
                writeln!(
                    self.out,
                    "float {name}(metal::ushort b0, \
                                  metal::ushort b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return float(float(b1 << 8 | b0) / 65535.0f);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, None, Scalar::F32))
            }
            Unorm16x2 => {
                let name = self.namer.call("unpackUnorm16x2");
                writeln!(
                    self.out,
                    "metal::float2 {name}(metal::ushort b0, \
                                          metal::ushort b1, \
                                          metal::ushort b2, \
                                          metal::ushort b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
                                            float(b3 << 8 | b2) / 65535.0f);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
            }
            Unorm16x4 => {
                let name = self.namer.call("unpackUnorm16x4");
                writeln!(
                    self.out,
                    "metal::float4 {name}(metal::ushort b0, \
                                          metal::ushort b1, \
                                          metal::ushort b2, \
                                          metal::ushort b3, \
                                          metal::ushort b4, \
                                          metal::ushort b5, \
                                          metal::ushort b6, \
                                          metal::ushort b7) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
                                            float(b3 << 8 | b2) / 65535.0f, \
                                            float(b5 << 8 | b4) / 65535.0f, \
                                            float(b7 << 8 | b6) / 65535.0f);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
            }
            // 16-bit signed normalized: decoding is delegated to the Metal builtin
            // `unpack_snorm2x16_to_float`. For a single component only the low 16 bits
            // are populated and `.x` is kept.
            // NOTE(review): relies on `.x` being decoded from the low half of the word.
            Snorm16 => {
                let name = self.namer.call("unpackSnorm16");
                writeln!(
                    self.out,
                    "float {name}(metal::ushort b0, \
                                  metal::ushort b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, None, Scalar::F32))
            }
            // Two components fit in one 32-bit word, which is decoded in a single call.
            Snorm16x2 => {
                let name = self.namer.call("unpackSnorm16x2");
                writeln!(
                    self.out,
                    "metal::float2 {name}(uint b0, \
                                          uint b1, \
                                          uint b2, \
                                          uint b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
            }
            // Four components span two 32-bit words; their two `float2` results are
            // concatenated into a `float4`.
            Snorm16x4 => {
                let name = self.namer.call("unpackSnorm16x4");
                writeln!(
                    self.out,
                    "metal::float4 {name}(uint b0, \
                                          uint b1, \
                                          uint b2, \
                                          uint b3, \
                                          uint b4, \
                                          uint b5, \
                                          uint b6, \
                                          uint b7) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
                                            metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
            }
            // Half-precision floats: the 16 assembled bits are reinterpreted as `half`
            // and then widened to `float` for the returned value.
            Float16 => {
                let name = self.namer.call("unpackFloat16");
                writeln!(
                    self.out,
                    "float {name}(metal::ushort b0, \
                                  metal::ushort b1) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 2, None, Scalar::F32))
            }
            Float16x2 => {
                let name = self.namer.call("unpackFloat16x2");
                writeln!(
                    self.out,
                    "metal::float2 {name}(metal::ushort b0, \
                                          metal::ushort b1, \
                                          metal::ushort b2, \
                                          metal::ushort b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
                                            as_type<half>(metal::ushort(b3 << 8 | b2)));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
            }
            Float16x4 => {
                let name = self.namer.call("unpackFloat16x4");
                writeln!(
                    self.out,
                    "metal::float4 {name}(metal::ushort b0, \
                                          metal::ushort b1, \
                                          metal::ushort b2, \
                                          metal::ushort b3, \
                                          metal::ushort b4, \
                                          metal::ushort b5, \
                                          metal::ushort b6, \
                                          metal::ushort b7) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
                                            as_type<half>(metal::ushort(b3 << 8 | b2)), \
                                            as_type<half>(metal::ushort(b5 << 8 | b4)), \
                                            as_type<half>(metal::ushort(b7 << 8 | b6)));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
            }
            // 32-bit floats: four bytes per component form a 32-bit word that is
            // reinterpreted as `float`.
            Float32 => {
                let name = self.namer.call("unpackFloat32");
                writeln!(
                    self.out,
                    "float {name}(uint b0, \
                                  uint b1, \
                                  uint b2, \
                                  uint b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, None, Scalar::F32))
            }
            Float32x2 => {
                let name = self.namer.call("unpackFloat32x2");
                writeln!(
                    self.out,
                    "metal::float2 {name}(uint b0, \
                                          uint b1, \
                                          uint b2, \
                                          uint b3, \
                                          uint b4, \
                                          uint b5, \
                                          uint b6, \
                                          uint b7) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 8, Some(VectorSize::Bi), Scalar::F32))
            }
            Float32x3 => {
                let name = self.namer.call("unpackFloat32x3");
                writeln!(
                    self.out,
                    "metal::float3 {name}(uint b0, \
                                          uint b1, \
                                          uint b2, \
                                          uint b3, \
                                          uint b4, \
                                          uint b5, \
                                          uint b6, \
                                          uint b7, \
                                          uint b8, \
                                          uint b9, \
                                          uint b10, \
                                          uint b11) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 12, Some(VectorSize::Tri), Scalar::F32))
            }
            Float32x4 => {
                let name = self.namer.call("unpackFloat32x4");
                writeln!(
                    self.out,
                    "metal::float4 {name}(uint b0, \
                                          uint b1, \
                                          uint b2, \
                                          uint b3, \
                                          uint b4, \
                                          uint b5, \
                                          uint b6, \
                                          uint b7, \
                                          uint b8, \
                                          uint b9, \
                                          uint b10, \
                                          uint b11, \
                                          uint b12, \
                                          uint b13, \
                                          uint b14, \
                                          uint b15) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
                                            as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
                                            as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
                                            as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 16, Some(VectorSize::Quad), Scalar::F32))
            }
            // 32-bit unsigned integers: the assembled word is returned as-is.
            Uint32 => {
                let name = self.namer.call("unpackUint32");
                writeln!(
                    self.out,
                    "uint {name}(uint b0, \
                                 uint b1, \
                                 uint b2, \
                                 uint b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, None, Scalar::U32))
            }
            Uint32x2 => {
                let name = self.namer.call("unpackUint32x2");
                writeln!(
                    self.out,
                    "uint2 {name}(uint b0, \
                                  uint b1, \
                                  uint b2, \
                                  uint b3, \
                                  uint b4, \
                                  uint b5, \
                                  uint b6, \
                                  uint b7) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 8, Some(VectorSize::Bi), Scalar::U32))
            }
            Uint32x3 => {
                let name = self.namer.call("unpackUint32x3");
                writeln!(
                    self.out,
                    "uint3 {name}(uint b0, \
                                  uint b1, \
                                  uint b2, \
                                  uint b3, \
                                  uint b4, \
                                  uint b5, \
                                  uint b6, \
                                  uint b7, \
                                  uint b8, \
                                  uint b9, \
                                  uint b10, \
                                  uint b11) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
                                    (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
                                    (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 12, Some(VectorSize::Tri), Scalar::U32))
            }
            Uint32x4 => {
                let name = self.namer.call("unpackUint32x4");
                writeln!(
                    self.out,
                    "{NAMESPACE}::uint4 {name}(uint b0, \
                                               uint b1, \
                                               uint b2, \
                                               uint b3, \
                                               uint b4, \
                                               uint b5, \
                                               uint b6, \
                                               uint b7, \
                                               uint b8, \
                                               uint b9, \
                                               uint b10, \
                                               uint b11, \
                                               uint b12, \
                                               uint b13, \
                                               uint b14, \
                                               uint b15) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
                                                 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
                                                 (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
                                                 (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 16, Some(VectorSize::Quad), Scalar::U32))
            }
            // 32-bit signed integers: the assembled word is reinterpreted as `int`.
            Sint32 => {
                let name = self.namer.call("unpackSint32");
                writeln!(
                    self.out,
                    "int {name}(uint b0, \
                                uint b1, \
                                uint b2, \
                                uint b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, None, Scalar::I32))
            }
            Sint32x2 => {
                let name = self.namer.call("unpackSint32x2");
                writeln!(
                    self.out,
                    "metal::int2 {name}(uint b0, \
                                        uint b1, \
                                        uint b2, \
                                        uint b3, \
                                        uint b4, \
                                        uint b5, \
                                        uint b6, \
                                        uint b7) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 8, Some(VectorSize::Bi), Scalar::I32))
            }
            Sint32x3 => {
                let name = self.namer.call("unpackSint32x3");
                writeln!(
                    self.out,
                    "metal::int3 {name}(uint b0, \
                                        uint b1, \
                                        uint b2, \
                                        uint b3, \
                                        uint b4, \
                                        uint b5, \
                                        uint b6, \
                                        uint b7, \
                                        uint b8, \
                                        uint b9, \
                                        uint b10, \
                                        uint b11) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 12, Some(VectorSize::Tri), Scalar::I32))
            }
            Sint32x4 => {
                let name = self.namer.call("unpackSint32x4");
                writeln!(
                    self.out,
                    "metal::int4 {name}(uint b0, \
                                        uint b1, \
                                        uint b2, \
                                        uint b3, \
                                        uint b4, \
                                        uint b5, \
                                        uint b6, \
                                        uint b7, \
                                        uint b8, \
                                        uint b9, \
                                        uint b10, \
                                        uint b11, \
                                        uint b12, \
                                        uint b13, \
                                        uint b14, \
                                        uint b15) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
                                          as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
                                          as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
                                          as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 16, Some(VectorSize::Quad), Scalar::I32))
            }
            // Packed 10/10/10/2 unsigned normalized: the four bytes form one 32-bit word,
            // which is decoded by the Metal builtin `unpack_unorm10a2_to_float`.
            Unorm10_10_10_2 => {
                let name = self.namer.call("unpackUnorm10_10_10_2");
                writeln!(
                    self.out,
                    "metal::float4 {name}(uint b0, \
                                          uint b1, \
                                          uint b2, \
                                          uint b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
            }
            // Like `Unorm8x4`, but the bytes are swizzled as (b2, b1, b0, b3) so the
            // blue/red channels, stored first/third per the format name, come back swapped.
            Unorm8x4Bgra => {
                let name = self.namer.call("unpackUnorm8x4Bgra");
                writeln!(
                    self.out,
                    "metal::float4 {name}(metal::uchar b0, \
                                          metal::uchar b1, \
                                          metal::uchar b2, \
                                          metal::uchar b3) {{"
                )?;
                writeln!(
                    self.out,
                    "{}return metal::float4(float(b2) / 255.0f, \
                                            float(b1) / 255.0f, \
                                            float(b0) / 255.0f, \
                                            float(b3) / 255.0f);",
                    back::INDENT
                )?;
                writeln!(self.out, "}}")?;
                Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
            }
        }
    }
5767
5768 fn write_wrapped_unary_op(
5769 &mut self,
5770 module: &crate::Module,
5771 func_ctx: &back::FunctionCtx,
5772 op: crate::UnaryOperator,
5773 operand: Handle<crate::Expression>,
5774 ) -> BackendResult {
5775 let operand_ty = func_ctx.resolve_type(operand, &module.types);
5776 match op {
5777 crate::UnaryOperator::Negate
5784 if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5785 {
5786 let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5787 return Ok(());
5788 };
5789 let wrapped = WrappedFunction::UnaryOp {
5790 op,
5791 ty: (vector_size, scalar),
5792 };
5793 if !self.wrapped_functions.insert(wrapped) {
5794 return Ok(());
5795 }
5796
5797 let unsigned_scalar = crate::Scalar {
5798 kind: crate::ScalarKind::Uint,
5799 ..scalar
5800 };
5801 let mut type_name = String::new();
5802 let mut unsigned_type_name = String::new();
5803 match vector_size {
5804 None => {
5805 put_numeric_type(&mut type_name, scalar, &[])?;
5806 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5807 }
5808 Some(size) => {
5809 put_numeric_type(&mut type_name, scalar, &[size])?;
5810 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5811 }
5812 };
5813
5814 writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5815 let level = back::Level(1);
5816 writeln!(
5817 self.out,
5818 "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5819 )?;
5820 writeln!(self.out, "}}")?;
5821 writeln!(self.out)?;
5822 }
5823 _ => {}
5824 }
5825 Ok(())
5826 }
5827
5828 fn write_wrapped_binary_op(
5829 &mut self,
5830 module: &crate::Module,
5831 func_ctx: &back::FunctionCtx,
5832 expr: Handle<crate::Expression>,
5833 op: crate::BinaryOperator,
5834 left: Handle<crate::Expression>,
5835 right: Handle<crate::Expression>,
5836 ) -> BackendResult {
5837 let expr_ty = func_ctx.resolve_type(expr, &module.types);
5838 let left_ty = func_ctx.resolve_type(left, &module.types);
5839 let right_ty = func_ctx.resolve_type(right, &module.types);
5840 match (op, expr_ty.scalar_kind()) {
5841 (
5848 crate::BinaryOperator::Divide,
5849 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5850 ) => {
5851 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5852 return Ok(());
5853 };
5854 let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5855 return Ok(());
5856 };
5857 let wrapped = WrappedFunction::BinaryOp {
5858 op,
5859 left_ty: left_wrapped_ty,
5860 right_ty: right_wrapped_ty,
5861 };
5862 if !self.wrapped_functions.insert(wrapped) {
5863 return Ok(());
5864 }
5865
5866 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5867 return Ok(());
5868 };
5869 let mut type_name = String::new();
5870 match vector_size {
5871 None => put_numeric_type(&mut type_name, scalar, &[])?,
5872 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5873 };
5874 writeln!(
5875 self.out,
5876 "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5877 )?;
5878 let level = back::Level(1);
5879 match scalar.kind {
5880 crate::ScalarKind::Sint => {
5881 let min_val = match scalar.width {
5882 4 => crate::Literal::I32(i32::MIN),
5883 8 => crate::Literal::I64(i64::MIN),
5884 _ => {
5885 return Err(Error::GenericValidation(format!(
5886 "Unexpected width for scalar {scalar:?}"
5887 )));
5888 }
5889 };
5890 write!(
5891 self.out,
5892 "{level}return lhs / metal::select(rhs, 1, (lhs == "
5893 )?;
5894 self.put_literal(min_val)?;
5895 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
5896 }
5897 crate::ScalarKind::Uint => writeln!(
5898 self.out,
5899 "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
5900 )?,
5901 _ => unreachable!(),
5902 }
5903 writeln!(self.out, "}}")?;
5904 writeln!(self.out)?;
5905 }
5906 (
5919 crate::BinaryOperator::Modulo,
5920 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5921 ) => {
5922 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5923 return Ok(());
5924 };
5925 let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5926 else {
5927 return Ok(());
5928 };
5929 let wrapped = WrappedFunction::BinaryOp {
5930 op,
5931 left_ty: left_wrapped_ty,
5932 right_ty: (right_vector_size, right_scalar),
5933 };
5934 if !self.wrapped_functions.insert(wrapped) {
5935 return Ok(());
5936 }
5937
5938 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5939 return Ok(());
5940 };
5941 let mut type_name = String::new();
5942 match vector_size {
5943 None => put_numeric_type(&mut type_name, scalar, &[])?,
5944 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5945 };
5946 let mut rhs_type_name = String::new();
5947 match right_vector_size {
5948 None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5949 Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5950 };
5951
5952 writeln!(
5953 self.out,
5954 "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5955 )?;
5956 let level = back::Level(1);
5957 match scalar.kind {
5958 crate::ScalarKind::Sint => {
5959 let min_val = match scalar.width {
5960 4 => crate::Literal::I32(i32::MIN),
5961 8 => crate::Literal::I64(i64::MIN),
5962 _ => {
5963 return Err(Error::GenericValidation(format!(
5964 "Unexpected width for scalar {scalar:?}"
5965 )));
5966 }
5967 };
5968 write!(
5969 self.out,
5970 "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
5971 )?;
5972 self.put_literal(min_val)?;
5973 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
5974 writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5975 }
5976 crate::ScalarKind::Uint => writeln!(
5977 self.out,
5978 "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
5979 )?,
5980 _ => unreachable!(),
5981 }
5982 writeln!(self.out, "}}")?;
5983 writeln!(self.out)?;
5984 }
5985 _ => {}
5986 }
5987 Ok(())
5988 }
5989
5990 fn get_dot_wrapper_function_helper_name(
5996 &self,
5997 scalar: crate::Scalar,
5998 size: crate::VectorSize,
5999 ) -> String {
6000 debug_assert!(concrete_int_scalars().any(|s| s == scalar));
6002
6003 let type_name = scalar.to_msl_name();
6004 let size_suffix = common::vector_size_str(size);
6005 format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
6006 }
6007
6008 #[allow(clippy::too_many_arguments)]
6009 fn write_wrapped_math_function(
6010 &mut self,
6011 module: &crate::Module,
6012 func_ctx: &back::FunctionCtx,
6013 fun: crate::MathFunction,
6014 arg: Handle<crate::Expression>,
6015 _arg1: Option<Handle<crate::Expression>>,
6016 _arg2: Option<Handle<crate::Expression>>,
6017 _arg3: Option<Handle<crate::Expression>>,
6018 ) -> BackendResult {
6019 let arg_ty = func_ctx.resolve_type(arg, &module.types);
6020 match fun {
6021 crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
6029 let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
6030 return Ok(());
6031 };
6032 let wrapped = WrappedFunction::Math {
6033 fun,
6034 arg_ty: (vector_size, scalar),
6035 };
6036 if !self.wrapped_functions.insert(wrapped) {
6037 return Ok(());
6038 }
6039
6040 let unsigned_scalar = crate::Scalar {
6041 kind: crate::ScalarKind::Uint,
6042 ..scalar
6043 };
6044 let mut type_name = String::new();
6045 let mut unsigned_type_name = String::new();
6046 match vector_size {
6047 None => {
6048 put_numeric_type(&mut type_name, scalar, &[])?;
6049 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
6050 }
6051 Some(size) => {
6052 put_numeric_type(&mut type_name, scalar, &[size])?;
6053 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
6054 }
6055 };
6056
6057 writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
6058 let level = back::Level(1);
6059 writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
6060 writeln!(self.out, "}}")?;
6061 writeln!(self.out)?;
6062 }
6063
6064 crate::MathFunction::Dot => match *arg_ty {
6065 crate::TypeInner::Vector { size, scalar }
6066 if matches!(
6067 scalar.kind,
6068 crate::ScalarKind::Sint | crate::ScalarKind::Uint
6069 ) =>
6070 {
6071 let wrapped = WrappedFunction::Math {
6073 fun,
6074 arg_ty: (Some(size), scalar),
6075 };
6076 if !self.wrapped_functions.insert(wrapped) {
6077 return Ok(());
6078 }
6079
6080 let mut vec_ty = String::new();
6081 put_numeric_type(&mut vec_ty, scalar, &[size])?;
6082 let mut ret_ty = String::new();
6083 put_numeric_type(&mut ret_ty, scalar, &[])?;
6084
6085 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
6086
6087 writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
6089 let level = back::Level(1);
6090 write!(self.out, "{level}return ")?;
6091 self.put_dot_product("a", "b", size as usize, |writer, name, index| {
6092 write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
6093 Ok(())
6094 })?;
6095 writeln!(self.out, ";")?;
6096 writeln!(self.out, "}}")?;
6097 writeln!(self.out)?;
6098 }
6099 _ => {}
6100 },
6101
6102 _ => {}
6103 }
6104 Ok(())
6105 }
6106
6107 fn write_wrapped_cast(
6108 &mut self,
6109 module: &crate::Module,
6110 func_ctx: &back::FunctionCtx,
6111 expr: Handle<crate::Expression>,
6112 kind: crate::ScalarKind,
6113 convert: Option<crate::Bytes>,
6114 ) -> BackendResult {
6115 let src_ty = func_ctx.resolve_type(expr, &module.types);
6126 let Some(width) = convert else {
6127 return Ok(());
6128 };
6129 let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6130 return Ok(());
6131 };
6132 let dst_scalar = crate::Scalar { kind, width };
6133 if src_scalar.kind != crate::ScalarKind::Float
6134 || (dst_scalar.kind != crate::ScalarKind::Sint
6135 && dst_scalar.kind != crate::ScalarKind::Uint)
6136 {
6137 return Ok(());
6138 }
6139 let wrapped = WrappedFunction::Cast {
6140 src_scalar,
6141 vector_size,
6142 dst_scalar,
6143 };
6144 if !self.wrapped_functions.insert(wrapped) {
6145 return Ok(());
6146 }
6147 let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6148
6149 let mut src_type_name = String::new();
6150 match vector_size {
6151 None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6152 Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6153 };
6154 let mut dst_type_name = String::new();
6155 match vector_size {
6156 None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6157 Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6158 };
6159 let fun_name = match dst_scalar {
6160 crate::Scalar::I32 => F2I32_FUNCTION,
6161 crate::Scalar::U32 => F2U32_FUNCTION,
6162 crate::Scalar::I64 => F2I64_FUNCTION,
6163 crate::Scalar::U64 => F2U64_FUNCTION,
6164 _ => unreachable!(),
6165 };
6166
6167 writeln!(
6168 self.out,
6169 "{dst_type_name} {fun_name}({src_type_name} value) {{"
6170 )?;
6171 let level = back::Level(1);
6172 write!(
6173 self.out,
6174 "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6175 )?;
6176 self.put_literal(min)?;
6177 write!(self.out, ", ")?;
6178 self.put_literal(max)?;
6179 writeln!(self.out, "));")?;
6180 writeln!(self.out, "}}")?;
6181 writeln!(self.out)?;
6182 Ok(())
6183 }
6184
6185 fn write_convert_yuv_to_rgb_and_return(
6193 &mut self,
6194 level: back::Level,
6195 y: &str,
6196 uv: &str,
6197 params: &str,
6198 ) -> BackendResult {
6199 let l1 = level;
6200 let l2 = l1.next();
6201
6202 writeln!(
6204 self.out,
6205 "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6206 )?;
6207
6208 writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6211 writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6212 writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6213 writeln!(
6214 self.out,
6215 "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6216 )?;
6217
6218 writeln!(
6221 self.out,
6222 "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6223 )?;
6224
6225 writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6228 writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6229 writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6230 writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6231
6232 writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6233 Ok(())
6234 }
6235
    /// Emits the MSL helper `IMAGE_LOAD_EXTERNAL_FUNCTION`, which backs image
    /// loads from external textures.
    ///
    /// Only `ImageClass::External` images get a helper; any other class is a
    /// no-op. Emission is de-duplicated through `self.wrapped_functions`. The
    /// coordinate/array-index/sample/level operands are unused here. The
    /// emitted helper takes the external texture wrapper and `uint2`
    /// coordinates as its own parameters.
    ///
    /// The generated helper works in these steps:
    /// * Clamp `coords` to `params.size`, or to plane 0's size when every
    ///   component of `params.size` is zero.
    /// * Map the clamped coordinates through `params.load_transform`.
    /// * With one plane, return plane 0's texel directly.
    /// * Otherwise, read Y from plane 0 and UV from plane 1 (two planes), or
    ///   from planes 1 and 2 (any other plane count).
    /// * Finally, convert YUV to RGB.
    #[allow(clippy::too_many_arguments)]
    fn write_wrapped_image_load(
        &mut self,
        module: &crate::Module,
        func_ctx: &back::FunctionCtx,
        image: Handle<crate::Expression>,
        _coordinate: Handle<crate::Expression>,
        _array_index: Option<Handle<crate::Expression>>,
        _sample: Option<Handle<crate::Expression>>,
        _level: Option<Handle<crate::Expression>>,
    ) -> BackendResult {
        let class = match *func_ctx.resolve_type(image, &module.types) {
            crate::TypeInner::Image { class, .. } => class,
            _ => unreachable!(),
        };
        // Non-external images are loaded inline and need no helper.
        if class != crate::ImageClass::External {
            return Ok(());
        }
        let wrapped = WrappedFunction::ImageLoad { class };
        // `insert` returns false when this helper has already been emitted.
        if !self.wrapped_functions.insert(wrapped) {
            return Ok(());
        }

        writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
        let l1 = back::Level(1);
        let l2 = l1.next();
        let l3 = l2.next();
        writeln!(
            self.out,
            "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
        )?;
        // A zero `params.size` means "no crop": fall back to plane 0's size.
        writeln!(
            self.out,
            "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
        )?;
        // Keep the load inside the (cropped) texture.
        writeln!(
            self.out,
            "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
        )?;

        // Map the clamped coordinates into plane 0's texel space.
        writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
        writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
        // Single plane: the texel is returned as-is, with no YUV conversion.
        writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
        writeln!(self.out, "{l1}}} else {{")?;

        // Scale plane 0 coordinates to plane 1's (possibly subsampled) size.
        writeln!(
            self.out,
            "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
        )?;
        writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;

        // Luma always comes from plane 0.
        writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;

        writeln!(self.out, "{l2}float2 uv;")?;
        writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
        // Two planes: plane 1 holds interleaved chroma.
        writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
        writeln!(self.out, "{l2}}} else {{")?;
        // Otherwise chroma is split across planes 1 and 2.
        // NOTE(review): the two `plane2_*` lines below are written at `l2`
        // while the rest of this else-body uses `l3`, so they come out
        // under-indented in the generated MSL. Left unchanged here in case
        // generated output is compared against snapshots. Confirm before
        // changing.
        writeln!(
            self.out,
            "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
        )?;
        writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
        writeln!(
            self.out,
            "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
        )?;
        writeln!(self.out, "{l2}}}")?;

        // Emits the YUV -> RGB conversion and the helper's final `return`.
        self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;

        writeln!(self.out, "{l1}}}")?;
        writeln!(self.out, "}}")?;
        writeln!(self.out)?;
        Ok(())
    }
6321
6322 #[allow(clippy::too_many_arguments)]
6323 fn write_wrapped_image_sample(
6324 &mut self,
6325 module: &crate::Module,
6326 func_ctx: &back::FunctionCtx,
6327 image: Handle<crate::Expression>,
6328 _sampler: Handle<crate::Expression>,
6329 _gather: Option<crate::SwizzleComponent>,
6330 _coordinate: Handle<crate::Expression>,
6331 _array_index: Option<Handle<crate::Expression>>,
6332 _offset: Option<Handle<crate::Expression>>,
6333 _level: crate::SampleLevel,
6334 _depth_ref: Option<Handle<crate::Expression>>,
6335 clamp_to_edge: bool,
6336 ) -> BackendResult {
6337 if !clamp_to_edge {
6340 return Ok(());
6341 }
6342 let class = match *func_ctx.resolve_type(image, &module.types) {
6343 crate::TypeInner::Image { class, .. } => class,
6344 _ => unreachable!(),
6345 };
6346 let wrapped = WrappedFunction::ImageSample {
6347 class,
6348 clamp_to_edge: true,
6349 };
6350 if !self.wrapped_functions.insert(wrapped) {
6351 return Ok(());
6352 }
6353 match class {
6354 crate::ImageClass::External => {
6355 writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6356 let l1 = back::Level(1);
6357 let l2 = l1.next();
6358 let l3 = l2.next();
6359 writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6360 writeln!(
6361 self.out,
6362 "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6363 )?;
6364
6365 writeln!(
6373 self.out,
6374 "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6375 )?;
6376 writeln!(
6377 self.out,
6378 "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6379 )?;
6380 writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6381 writeln!(
6382 self.out,
6383 "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6384 )?;
6385 writeln!(
6386 self.out,
6387 "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6388 )?;
6389 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6390 writeln!(
6392 self.out,
6393 "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6394 )?;
6395 writeln!(self.out, "{l1}}} else {{")?;
6396 writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6397 writeln!(
6398 self.out,
6399 "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6400 )?;
6401 writeln!(
6402 self.out,
6403 "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6404 )?;
6405
6406 writeln!(
6408 self.out,
6409 "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6410 )?;
6411 writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6412 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6413 writeln!(
6415 self.out,
6416 "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6417 )?;
6418 writeln!(self.out, "{l2}}} else {{")?;
6419 writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6421 writeln!(
6422 self.out,
6423 "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6424 )?;
6425 writeln!(
6426 self.out,
6427 "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6428 )?;
6429 writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6430 writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6431 writeln!(self.out, "{l2}}}")?;
6432
6433 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6434
6435 writeln!(self.out, "{l1}}}")?;
6436 writeln!(self.out, "}}")?;
6437 writeln!(self.out)?;
6438 }
6439 _ => {
6440 writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6441 let l1 = back::Level(1);
6442 writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6443 writeln!(
6444 self.out,
6445 "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6446 )?;
6447 writeln!(self.out, "}}")?;
6448 writeln!(self.out)?;
6449 }
6450 }
6451 Ok(())
6452 }
6453
6454 fn write_wrapped_image_query(
6455 &mut self,
6456 module: &crate::Module,
6457 func_ctx: &back::FunctionCtx,
6458 image: Handle<crate::Expression>,
6459 query: crate::ImageQuery,
6460 ) -> BackendResult {
6461 if !matches!(query, crate::ImageQuery::Size { .. }) {
6463 return Ok(());
6464 }
6465 let class = match *func_ctx.resolve_type(image, &module.types) {
6466 crate::TypeInner::Image { class, .. } => class,
6467 _ => unreachable!(),
6468 };
6469 if class != crate::ImageClass::External {
6470 return Ok(());
6471 }
6472 let wrapped = WrappedFunction::ImageQuerySize { class };
6473 if !self.wrapped_functions.insert(wrapped) {
6474 return Ok(());
6475 }
6476 writeln!(
6477 self.out,
6478 "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6479 )?;
6480 let l1 = back::Level(1);
6481 let l2 = l1.next();
6482 writeln!(
6483 self.out,
6484 "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6485 )?;
6486 writeln!(self.out, "{l2}return tex.params.size;")?;
6487 writeln!(self.out, "{l1}}} else {{")?;
6488 writeln!(
6490 self.out,
6491 "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6492 )?;
6493 writeln!(self.out, "{l1}}}")?;
6494 writeln!(self.out, "}}")?;
6495 writeln!(self.out)?;
6496 Ok(())
6497 }
6498
6499 fn write_wrapped_cooperative_load(
6500 &mut self,
6501 module: &crate::Module,
6502 func_ctx: &back::FunctionCtx,
6503 columns: crate::CooperativeSize,
6504 rows: crate::CooperativeSize,
6505 pointer: Handle<crate::Expression>,
6506 ) -> BackendResult {
6507 let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6508 let space = ptr_ty.pointer_space().unwrap();
6509 let space_name = space.to_msl_name().unwrap_or_default();
6510 let scalar = ptr_ty
6511 .pointer_base_type()
6512 .unwrap()
6513 .inner_with(&module.types)
6514 .scalar()
6515 .unwrap();
6516 let wrapped = WrappedFunction::CooperativeLoad {
6517 space_name,
6518 columns,
6519 rows,
6520 scalar,
6521 };
6522 if !self.wrapped_functions.insert(wrapped) {
6523 return Ok(());
6524 }
6525 let scalar_name = scalar.to_msl_name();
6526 writeln!(
6527 self.out,
6528 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6529 columns as u32, rows as u32,
6530 )?;
6531 let l1 = back::Level(1);
6532 writeln!(
6533 self.out,
6534 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6535 columns as u32, rows as u32
6536 )?;
6537 let matrix_origin = "0";
6538 writeln!(
6539 self.out,
6540 "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6541 )?;
6542 writeln!(self.out, "{l1}return m;")?;
6543 writeln!(self.out, "}}")?;
6544 writeln!(self.out)?;
6545 Ok(())
6546 }
6547
6548 fn write_wrapped_cooperative_multiply_add(
6549 &mut self,
6550 module: &crate::Module,
6551 func_ctx: &back::FunctionCtx,
6552 space: crate::AddressSpace,
6553 a: Handle<crate::Expression>,
6554 b: Handle<crate::Expression>,
6555 ) -> BackendResult {
6556 let space_name = space.to_msl_name().unwrap_or_default();
6557 let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6558 crate::TypeInner::CooperativeMatrix {
6559 columns,
6560 rows,
6561 scalar,
6562 ..
6563 } => (columns, rows, scalar),
6564 _ => unreachable!(),
6565 };
6566 let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6567 crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6568 _ => unreachable!(),
6569 };
6570 let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6571 space_name,
6572 columns: b_c,
6573 rows: a_r,
6574 intermediate: a_c,
6575 scalar,
6576 };
6577 if !self.wrapped_functions.insert(wrapped) {
6578 return Ok(());
6579 }
6580 let scalar_name = scalar.to_msl_name();
6581 writeln!(
6582 self.out,
6583 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6584 b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6585 )?;
6586 let l1 = back::Level(1);
6587 writeln!(
6588 self.out,
6589 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6590 b_c as u32, a_r as u32
6591 )?;
6592 writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6593 writeln!(self.out, "{l1}return d;")?;
6594 writeln!(self.out, "}}")?;
6595 writeln!(self.out)?;
6596 Ok(())
6597 }
6598
6599 pub(super) fn write_wrapped_functions(
6600 &mut self,
6601 module: &crate::Module,
6602 func_ctx: &back::FunctionCtx,
6603 ) -> BackendResult {
6604 for (expr_handle, expr) in func_ctx.expressions.iter() {
6605 match *expr {
6606 crate::Expression::Unary { op, expr: operand } => {
6607 self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6608 }
6609 crate::Expression::Binary { op, left, right } => {
6610 self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6611 }
6612 crate::Expression::Math {
6613 fun,
6614 arg,
6615 arg1,
6616 arg2,
6617 arg3,
6618 } => {
6619 self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6620 }
6621 crate::Expression::As {
6622 expr,
6623 kind,
6624 convert,
6625 } => {
6626 self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6627 }
6628 crate::Expression::ImageLoad {
6629 image,
6630 coordinate,
6631 array_index,
6632 sample,
6633 level,
6634 } => {
6635 self.write_wrapped_image_load(
6636 module,
6637 func_ctx,
6638 image,
6639 coordinate,
6640 array_index,
6641 sample,
6642 level,
6643 )?;
6644 }
6645 crate::Expression::ImageSample {
6646 image,
6647 sampler,
6648 gather,
6649 coordinate,
6650 array_index,
6651 offset,
6652 level,
6653 depth_ref,
6654 clamp_to_edge,
6655 } => {
6656 self.write_wrapped_image_sample(
6657 module,
6658 func_ctx,
6659 image,
6660 sampler,
6661 gather,
6662 coordinate,
6663 array_index,
6664 offset,
6665 level,
6666 depth_ref,
6667 clamp_to_edge,
6668 )?;
6669 }
6670 crate::Expression::ImageQuery { image, query } => {
6671 self.write_wrapped_image_query(module, func_ctx, image, query)?;
6672 }
6673 crate::Expression::CooperativeLoad {
6674 columns,
6675 rows,
6676 role: _,
6677 ref data,
6678 } => {
6679 self.write_wrapped_cooperative_load(
6680 module,
6681 func_ctx,
6682 columns,
6683 rows,
6684 data.pointer,
6685 )?;
6686 }
6687 crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6688 let space = crate::AddressSpace::Private;
6689 self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
6690 }
6691 _ => {}
6692 }
6693 }
6694
6695 Ok(())
6696 }
6697
6698 fn write_functions(
6700 &mut self,
6701 module: &crate::Module,
6702 mod_info: &valid::ModuleInfo,
6703 options: &Options,
6704 pipeline_options: &PipelineOptions,
6705 ) -> Result<TranslationInfo, Error> {
6706 use back::msl::VertexFormat;
6707
6708 struct AttributeMappingResolved {
6711 ty_name: String,
6712 dimension: Option<crate::VectorSize>,
6713 scalar: crate::Scalar,
6714 name: String,
6715 }
6716 let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6717
6718 struct VertexBufferMappingResolved<'a> {
6719 id: u32,
6720 stride: u32,
6721 step_mode: back::msl::VertexBufferStepMode,
6722 ty_name: String,
6723 param_name: String,
6724 elem_name: String,
6725 attributes: &'a Vec<back::msl::AttributeMapping>,
6726 }
6727 let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6728
6729 struct UnpackingFunction {
6731 name: String,
6732 byte_count: u32,
6733 dimension: Option<crate::VectorSize>,
6734 scalar: crate::Scalar,
6735 }
6736 let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6737
6738 let mut needs_vertex_id = false;
6744 let v_id = self.namer.call("v_id");
6745
6746 let mut needs_instance_id = false;
6747 let i_id = self.namer.call("i_id");
6748 if pipeline_options.vertex_pulling_transform {
6749 for vbm in &pipeline_options.vertex_buffer_mappings {
6750 let buffer_id = vbm.id;
6751 let buffer_stride = vbm.stride;
6752
6753 assert!(
6754 buffer_stride > 0,
6755 "Vertex pulling requires a non-zero buffer stride."
6756 );
6757
6758 match vbm.step_mode {
6759 back::msl::VertexBufferStepMode::Constant => {}
6760 back::msl::VertexBufferStepMode::ByVertex => {
6761 needs_vertex_id = true;
6762 }
6763 back::msl::VertexBufferStepMode::ByInstance => {
6764 needs_instance_id = true;
6765 }
6766 }
6767
6768 let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6769 let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6770 let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6771
6772 vbm_resolved.push(VertexBufferMappingResolved {
6773 id: buffer_id,
6774 stride: buffer_stride,
6775 step_mode: vbm.step_mode,
6776 ty_name: buffer_ty,
6777 param_name: buffer_param,
6778 elem_name: buffer_elem,
6779 attributes: &vbm.attributes,
6780 });
6781
6782 for attribute in &vbm.attributes {
6784 if unpacking_functions.contains_key(&attribute.format) {
6785 continue;
6786 }
6787 let (name, byte_count, dimension, scalar) =
6788 match self.write_unpacking_function(attribute.format) {
6789 Ok((name, byte_count, dimension, scalar)) => {
6790 (name, byte_count, dimension, scalar)
6791 }
6792 _ => {
6793 continue;
6794 }
6795 };
6796 unpacking_functions.insert(
6797 attribute.format,
6798 UnpackingFunction {
6799 name,
6800 byte_count,
6801 dimension,
6802 scalar,
6803 },
6804 );
6805 }
6806 }
6807 }
6808
6809 let mut pass_through_globals = Vec::new();
6810 for (fun_handle, fun) in module.functions.iter() {
6811 log::trace!(
6812 "function {:?}, handle {:?}",
6813 fun.name.as_deref().unwrap_or("(anonymous)"),
6814 fun_handle
6815 );
6816
6817 let ctx = back::FunctionCtx {
6818 ty: back::FunctionType::Function(fun_handle),
6819 info: &mod_info[fun_handle],
6820 expressions: &fun.expressions,
6821 named_expressions: &fun.named_expressions,
6822 };
6823
6824 writeln!(self.out)?;
6825 self.write_wrapped_functions(module, &ctx)?;
6826
6827 let fun_info = &mod_info[fun_handle];
6828 pass_through_globals.clear();
6829 let mut needs_buffer_sizes = false;
6830 for (handle, var) in module.global_variables.iter() {
6831 if !fun_info[handle].is_empty() {
6832 if var.space.needs_pass_through() {
6833 pass_through_globals.push(handle);
6834 }
6835 needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6836 }
6837 }
6838
6839 let fun_name = &self.names[&NameKey::Function(fun_handle)];
6840 match fun.result {
6841 Some(ref result) => {
6842 let ty_name = TypeContext {
6843 handle: result.ty,
6844 gctx: module.to_ctx(),
6845 names: &self.names,
6846 access: crate::StorageAccess::empty(),
6847 first_time: false,
6848 };
6849 write!(self.out, "{ty_name}")?;
6850 }
6851 None => {
6852 write!(self.out, "void")?;
6853 }
6854 }
6855 writeln!(self.out, " {fun_name}(")?;
6856
6857 for (index, arg) in fun.arguments.iter().enumerate() {
6858 let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6859 let param_type_name = TypeContext {
6860 handle: arg.ty,
6861 gctx: module.to_ctx(),
6862 names: &self.names,
6863 access: crate::StorageAccess::empty(),
6864 first_time: false,
6865 };
6866 let separator = separate(
6867 !pass_through_globals.is_empty()
6868 || index + 1 != fun.arguments.len()
6869 || needs_buffer_sizes,
6870 );
6871 writeln!(
6872 self.out,
6873 "{}{} {}{}",
6874 back::INDENT,
6875 param_type_name,
6876 name,
6877 separator
6878 )?;
6879 }
6880 for (index, &handle) in pass_through_globals.iter().enumerate() {
6881 let tyvar = TypedGlobalVariable {
6882 module,
6883 names: &self.names,
6884 handle,
6885 usage: fun_info[handle],
6886 reference: true,
6887 };
6888 let separator =
6889 separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6890 write!(self.out, "{}", back::INDENT)?;
6891 tyvar.try_fmt(&mut self.out)?;
6892 writeln!(self.out, "{separator}")?;
6893 }
6894
6895 if needs_buffer_sizes {
6896 writeln!(
6897 self.out,
6898 "{}constant _mslBufferSizes& _buffer_sizes",
6899 back::INDENT
6900 )?;
6901 }
6902
6903 writeln!(self.out, ") {{")?;
6904
6905 let guarded_indices =
6906 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6907
6908 let context = StatementContext {
6909 expression: ExpressionContext {
6910 function: fun,
6911 origin: FunctionOrigin::Handle(fun_handle),
6912 info: fun_info,
6913 lang_version: options.lang_version,
6914 policies: options.bounds_check_policies,
6915 guarded_indices,
6916 module,
6917 mod_info,
6918 pipeline_options,
6919 force_loop_bounding: options.force_loop_bounding,
6920 },
6921 result_struct: None,
6922 };
6923
6924 self.put_locals(&context.expression)?;
6925 self.update_expressions_to_bake(fun, fun_info, &context.expression);
6926 self.put_block(back::Level(1), &fun.body, &context)?;
6927 writeln!(self.out, "}}")?;
6928 self.named_expressions.clear();
6929 }
6930
6931 let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6932 .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6933
6934 let mut info = TranslationInfo {
6935 entry_point_names: Vec::with_capacity(ep_range.len()),
6936 };
6937
6938 for ep_index in ep_range {
6939 let ep = &module.entry_points[ep_index];
6940 let fun = &ep.function;
6941 let fun_info = mod_info.get_entry_point(ep_index);
6942 let mut ep_error = None;
6943
6944 let mut v_existing_id = None;
6948 let mut i_existing_id = None;
6949
6950 log::trace!(
6951 "entry point {:?}, index {:?}",
6952 fun.name.as_deref().unwrap_or("(anonymous)"),
6953 ep_index
6954 );
6955
6956 let ctx = back::FunctionCtx {
6957 ty: back::FunctionType::EntryPoint(ep_index as u16),
6958 info: fun_info,
6959 expressions: &fun.expressions,
6960 named_expressions: &fun.named_expressions,
6961 };
6962
6963 self.write_wrapped_functions(module, &ctx)?;
6964
6965 let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6966 crate::ShaderStage::Vertex => (
6967 Some("vertex"),
6968 LocationMode::VertexInput,
6969 LocationMode::VertexOutput,
6970 true,
6971 ),
6972 crate::ShaderStage::Fragment => (
6973 Some("fragment"),
6974 LocationMode::FragmentInput,
6975 LocationMode::FragmentOutput,
6976 false,
6977 ),
6978 crate::ShaderStage::Compute => (
6979 Some("kernel"),
6980 LocationMode::Uniform,
6981 LocationMode::Uniform,
6982 false,
6983 ),
6984 crate::ShaderStage::Task => {
6985 (None, LocationMode::Uniform, LocationMode::Uniform, false)
6986 }
6987 crate::ShaderStage::Mesh => {
6988 (None, LocationMode::Uniform, LocationMode::MeshOutput, false)
6989 }
6990 crate::ShaderStage::RayGeneration
6991 | crate::ShaderStage::AnyHit
6992 | crate::ShaderStage::ClosestHit
6993 | crate::ShaderStage::Miss => unimplemented!(),
6994 };
6995
6996 let do_vertex_pulling = can_vertex_pull
6998 && pipeline_options.vertex_pulling_transform
6999 && !pipeline_options.vertex_buffer_mappings.is_empty();
7000
7001 let needs_buffer_sizes = do_vertex_pulling
7003 || module
7004 .global_variables
7005 .iter()
7006 .filter(|&(handle, _)| !fun_info[handle].is_empty())
7007 .any(|(_, var)| needs_array_length(var.ty, &module.types));
7008
7009 if !options.fake_missing_bindings {
7012 for (var_handle, var) in module.global_variables.iter() {
7013 if fun_info[var_handle].is_empty() {
7014 continue;
7015 }
7016 match var.space {
7017 crate::AddressSpace::Uniform
7018 | crate::AddressSpace::Storage { .. }
7019 | crate::AddressSpace::Handle => {
7020 let br = match var.binding {
7021 Some(ref br) => br,
7022 None => {
7023 let var_name = var.name.clone().unwrap_or_default();
7024 ep_error =
7025 Some(super::EntryPointError::MissingBinding(var_name));
7026 break;
7027 }
7028 };
7029 let target = options.get_resource_binding_target(ep, br);
7030 let good = match target {
7031 Some(target) => {
7032 match module.types[var.ty].inner {
7036 crate::TypeInner::Image {
7037 class: crate::ImageClass::External,
7038 ..
7039 } => target.external_texture.is_some(),
7040 crate::TypeInner::Image { .. } => target.texture.is_some(),
7041 crate::TypeInner::Sampler { .. } => {
7042 target.sampler.is_some()
7043 }
7044 _ => target.buffer.is_some(),
7045 }
7046 }
7047 None => false,
7048 };
7049 if !good {
7050 ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
7051 break;
7052 }
7053 }
7054 crate::AddressSpace::Immediate => {
7055 if let Err(e) = options.resolve_immediates(ep) {
7056 ep_error = Some(e);
7057 break;
7058 }
7059 }
7060 crate::AddressSpace::Function
7061 | crate::AddressSpace::Private
7062 | crate::AddressSpace::WorkGroup
7063 | crate::AddressSpace::TaskPayload => {}
7064 crate::AddressSpace::RayPayload
7065 | crate::AddressSpace::IncomingRayPayload => unimplemented!(),
7066 }
7067 }
7068 if needs_buffer_sizes {
7069 if let Err(err) = options.resolve_sizes_buffer(ep) {
7070 ep_error = Some(err);
7071 }
7072 }
7073 }
7074
7075 if let Some(err) = ep_error {
7076 info.entry_point_names.push(Err(err));
7077 continue;
7078 }
7079 let fun_name = self.names[&NameKey::EntryPoint(ep_index as _)].clone();
7080 info.entry_point_names.push(Ok(fun_name.clone()));
7081
7082 writeln!(self.out)?;
7083
7084 let mut flattened_member_names = FastHashMap::default();
7090 let mut varyings_namer = proc::Namer::default();
7092
7093 let mut empty_names = FastHashMap::default(); varyings_namer.reset(
7095 module,
7096 &super::keywords::RESERVED_SET,
7097 proc::KeywordSet::empty(),
7098 proc::CaseInsensitiveKeywordSet::empty(),
7099 &[CLAMPED_LOD_LOAD_PREFIX],
7100 &mut empty_names,
7101 );
7102
7103 let mut flattened_arguments = Vec::new();
7108 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7109 match module.types[arg.ty].inner {
7110 crate::TypeInner::Struct { ref members, .. } => {
7111 for (member_index, member) in members.iter().enumerate() {
7112 let member_index = member_index as u32;
7113 flattened_arguments.push((
7114 NameKey::StructMember(arg.ty, member_index),
7115 member.ty,
7116 member.binding.as_ref(),
7117 ));
7118 let name_key = NameKey::StructMember(arg.ty, member_index);
7119 let name = match member.binding {
7120 Some(crate::Binding::Location { .. }) => {
7121 if do_vertex_pulling {
7122 self.namer.call(&self.names[&name_key])
7123 } else {
7124 varyings_namer.call(&self.names[&name_key])
7125 }
7126 }
7127 _ => self.namer.call(&self.names[&name_key]),
7128 };
7129 flattened_member_names.insert(name_key, name);
7130 }
7131 }
7132 _ => flattened_arguments.push((
7133 NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7134 arg.ty,
7135 arg.binding.as_ref(),
7136 )),
7137 }
7138 }
7139
7140 let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7145 let varyings_member_name = self.namer.call("varyings");
7146 let mut has_varyings = false;
7147
7148 if !flattened_arguments.is_empty() {
7149 if !do_vertex_pulling {
7150 writeln!(self.out, "struct {stage_in_name} {{")?;
7151 }
7152 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7153 let Some(binding) = binding else {
7154 continue;
7155 };
7156 let name = match *name_key {
7157 NameKey::StructMember(..) => &flattened_member_names[name_key],
7158 _ => &self.names[name_key],
7159 };
7160 let ty_name = TypeContext {
7161 handle: ty,
7162 gctx: module.to_ctx(),
7163 names: &self.names,
7164 access: crate::StorageAccess::empty(),
7165 first_time: false,
7166 };
7167 let resolved = options.resolve_local_binding(binding, in_mode)?;
7168 let location = match *binding {
7169 crate::Binding::Location { location, .. } => Some(location),
7170 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7171 crate::Binding::BuiltIn(_) => continue,
7172 };
7173 if do_vertex_pulling {
7174 let Some(location) = location else {
7175 continue;
7176 };
7177 am_resolved.insert(
7179 location,
7180 AttributeMappingResolved {
7181 ty_name: ty_name.to_string(),
7182 dimension: ty_name.vector_size(),
7183 scalar: ty_name.scalar().unwrap(),
7184 name: name.to_string(),
7185 },
7186 );
7187 } else {
7188 has_varyings = true;
7189 if let super::ResolvedBinding::User {
7190 prefix,
7191 index,
7192 interpolation: Some(super::ResolvedInterpolation::PerVertex),
7193 } = resolved
7194 {
7195 if options.lang_version < (4, 0) {
7196 return Err(Error::PerVertexNotSupported);
7197 }
7198 write!(
7199 self.out,
7200 "{}{NAMESPACE}::vertex_value<{}> {name} [[user({prefix}{index})]]",
7201 back::INDENT,
7202 ty_name.unwrap_array()
7203 )?;
7204 } else {
7205 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7206 resolved.try_fmt(&mut self.out)?;
7207 }
7208 writeln!(self.out, ";")?;
7209 }
7210 }
7211 if !do_vertex_pulling {
7212 writeln!(self.out, "}};")?;
7213 }
7214 }
7215
7216 let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7219 let result_member_name = self.namer.call("member");
7220 let result_type_name = match fun.result {
7221 Some(ref result) if ep.stage != crate::ShaderStage::Task => {
7222 let mut result_members = Vec::new();
7223 if let crate::TypeInner::Struct { ref members, .. } =
7224 module.types[result.ty].inner
7225 {
7226 for (member_index, member) in members.iter().enumerate() {
7227 result_members.push((
7228 &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7229 member.ty,
7230 member.binding.as_ref(),
7231 ));
7232 }
7233 } else {
7234 result_members.push((
7235 &result_member_name,
7236 result.ty,
7237 result.binding.as_ref(),
7238 ));
7239 }
7240
7241 writeln!(self.out, "struct {stage_out_name} {{")?;
7242 let mut has_point_size = false;
7243 for (name, ty, binding) in result_members {
7244 let ty_name = TypeContext {
7245 handle: ty,
7246 gctx: module.to_ctx(),
7247 names: &self.names,
7248 access: crate::StorageAccess::empty(),
7249 first_time: true,
7250 };
7251 let binding = binding.ok_or_else(|| {
7252 Error::GenericValidation("Expected binding, got None".into())
7253 })?;
7254
7255 if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7256 has_point_size = true;
7257 if !pipeline_options.allow_and_force_point_size {
7258 continue;
7259 }
7260 }
7261
7262 let array_len = match module.types[ty].inner {
7263 crate::TypeInner::Array {
7264 size: crate::ArraySize::Constant(size),
7265 ..
7266 } => Some(size),
7267 _ => None,
7268 };
7269 let resolved = options.resolve_local_binding(binding, out_mode)?;
7270 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7271 resolved.try_fmt(&mut self.out)?;
7272 if let Some(array_len) = array_len {
7273 write!(self.out, " [{array_len}]")?;
7274 }
7275 writeln!(self.out, ";")?;
7276 }
7277
7278 if pipeline_options.allow_and_force_point_size
7279 && ep.stage == crate::ShaderStage::Vertex
7280 && !has_point_size
7281 {
7282 writeln!(
7284 self.out,
7285 "{}float _point_size [[point_size]];",
7286 back::INDENT
7287 )?;
7288 }
7289 writeln!(self.out, "}};")?;
7290 &stage_out_name
7291 }
7292 Some(ref result) if ep.stage == crate::ShaderStage::Task => {
7293 assert_eq!(
7294 module.types[result.ty].inner,
7295 crate::TypeInner::Vector {
7296 size: crate::VectorSize::Tri,
7297 scalar: crate::Scalar::U32
7298 }
7299 );
7300
7301 "metal::uint3"
7302 }
7303 _ => "void",
7304 };
7305
7306 let out_mesh_info = if let Some(ref mesh_info) = ep.mesh_info {
7307 Some(self.write_mesh_output_types(
7308 mesh_info,
7309 &fun_name,
7310 module,
7311 pipeline_options.allow_and_force_point_size,
7312 options,
7313 )?)
7314 } else {
7315 None
7316 };
7317
7318 if do_vertex_pulling {
7321 for vbm in &vbm_resolved {
7322 let buffer_stride = vbm.stride;
7323 let buffer_ty = &vbm.ty_name;
7324
7325 writeln!(
7329 self.out,
7330 "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7331 )?;
7332 }
7333 }
7334
7335 let is_wrapped = matches!(
7336 ep.stage,
7337 crate::ShaderStage::Task | crate::ShaderStage::Mesh
7338 );
7339 let fun_name = fun_name.clone();
7340 let nested_fun_name = if is_wrapped {
7341 self.namer.call(&format!("_{fun_name}"))
7342 } else {
7343 fun_name.clone()
7344 };
7345
7346 if ep.stage == crate::ShaderStage::Compute && options.lang_version >= (2, 1) {
7348 let total_threads =
7349 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2];
7350 write!(
7351 self.out,
7352 "[[max_total_threads_per_threadgroup({total_threads})]] "
7353 )?;
7354 }
7355
7356 if let Some(em_str) = em_str {
7358 write!(self.out, "{em_str} ")?;
7359 }
7360 writeln!(self.out, "{result_type_name} {nested_fun_name}(")?;
7361
7362 let mut args = Vec::new();
7363
7364 if has_varyings {
7367 args.push(EntryPointArgument {
7368 ty_name: stage_in_name,
7369 name: varyings_member_name.clone(),
7370 binding: " [[stage_in]]".to_string(),
7371 init: None,
7372 });
7373 }
7374
7375 let mut local_invocation_index = None;
7376
7377 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7380 let binding = match binding {
7381 Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7382 Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7383 _ => continue,
7384 };
7385 let name = match *name_key {
7386 NameKey::StructMember(..) => &flattened_member_names[name_key],
7387 _ => &self.names[name_key],
7388 };
7389
7390 if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
7391 local_invocation_index = Some(name_key);
7392 }
7393
7394 let ty_name = TypeContext {
7395 handle: ty,
7396 gctx: module.to_ctx(),
7397 names: &self.names,
7398 access: crate::StorageAccess::empty(),
7399 first_time: false,
7400 };
7401
7402 match *binding {
7403 crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7404 v_existing_id = Some(name.clone());
7405 }
7406 crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7407 i_existing_id = Some(name.clone());
7408 }
7409 _ => {}
7410 };
7411
7412 let resolved = options.resolve_local_binding(binding, in_mode)?;
7413 let mut binding = String::new();
7414 resolved.try_fmt(&mut binding)?;
7415
7416 args.push(EntryPointArgument {
7417 ty_name: format!("{ty_name}"),
7418 name: name.clone(),
7419 binding,
7420 init: None,
7421 });
7422 }
7423
7424 let need_workgroup_variables_initialization =
7425 self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7426
7427 if local_invocation_index.is_none()
7428 && (need_workgroup_variables_initialization
7429 || ep.stage == crate::ShaderStage::Task
7430 || ep.stage == crate::ShaderStage::Mesh)
7431 {
7432 args.push(EntryPointArgument {
7433 ty_name: "uint".to_string(),
7434 name: "__local_invocation_index".to_string(),
7435 binding: " [[thread_index_in_threadgroup]]".to_string(),
7436 init: None,
7437 });
7438 }
7439
7440 for (handle, var) in module.global_variables.iter() {
7445 let usage = fun_info[handle];
7446 if usage.is_empty() || var.space == crate::AddressSpace::Private {
7447 continue;
7448 }
7449
7450 if options.lang_version < (1, 2) {
7451 match var.space {
7452 crate::AddressSpace::Storage { access }
7462 if access.contains(crate::StorageAccess::STORE)
7463 && ep.stage == crate::ShaderStage::Fragment =>
7464 {
7465 return Err(Error::UnsupportedWritableStorageBuffer)
7466 }
7467 crate::AddressSpace::Handle => {
7468 match module.types[var.ty].inner {
7469 crate::TypeInner::Image {
7470 class: crate::ImageClass::Storage { access, .. },
7471 ..
7472 } => {
7473 if access.contains(crate::StorageAccess::STORE)
7483 && (ep.stage == crate::ShaderStage::Vertex
7484 || ep.stage == crate::ShaderStage::Fragment)
7485 {
7486 return Err(Error::UnsupportedWritableStorageTexture(
7487 ep.stage,
7488 ));
7489 }
7490
7491 if access.contains(
7492 crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7493 ) {
7494 return Err(Error::UnsupportedRWStorageTexture);
7495 }
7496 }
7497 _ => {}
7498 }
7499 }
7500 _ => {}
7501 }
7502 }
7503
7504 match var.space {
7506 crate::AddressSpace::Handle => match module.types[var.ty].inner {
7507 crate::TypeInner::BindingArray { base, .. } => {
7508 match module.types[base].inner {
7509 crate::TypeInner::Sampler { .. } => {
7510 if options.lang_version < (2, 0) {
7511 return Err(Error::UnsupportedArrayOf(
7512 "samplers".to_string(),
7513 ));
7514 }
7515 }
7516 crate::TypeInner::Image { class, .. } => match class {
7517 crate::ImageClass::Sampled { .. }
7518 | crate::ImageClass::Depth { .. }
7519 | crate::ImageClass::Storage {
7520 access: crate::StorageAccess::LOAD,
7521 ..
7522 } => {
7523 if options.lang_version < (2, 0) {
7528 return Err(Error::UnsupportedArrayOf(
7529 "textures".to_string(),
7530 ));
7531 }
7532 }
7533 crate::ImageClass::Storage {
7534 access: crate::StorageAccess::STORE,
7535 ..
7536 } => {
7537 if options.lang_version < (2, 0) {
7542 return Err(Error::UnsupportedArrayOf(
7543 "write-only textures".to_string(),
7544 ));
7545 }
7546 }
7547 crate::ImageClass::Storage { .. } => {
7548 if options.lang_version < (3, 0) {
7549 return Err(Error::UnsupportedArrayOf(
7550 "read-write textures".to_string(),
7551 ));
7552 }
7553 }
7554 crate::ImageClass::External => {
7555 return Err(Error::UnsupportedArrayOf(
7556 "external textures".to_string(),
7557 ));
7558 }
7559 },
7560 _ => {
7561 return Err(Error::UnsupportedArrayOfType(base));
7562 }
7563 }
7564 }
7565 _ => {}
7566 },
7567 _ => {}
7568 }
7569
7570 let resolved = match var.space {
7572 crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7573 crate::AddressSpace::WorkGroup => None,
7574 crate::AddressSpace::TaskPayload => Some(back::msl::ResolvedBinding::Payload),
7575 _ => options
7576 .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7577 .ok(),
7578 };
7579 if let Some(ref resolved) = resolved {
7580 if resolved.as_inline_sampler(options).is_some() {
7582 continue;
7583 }
7584 }
7585
7586 match module.types[var.ty].inner {
7587 crate::TypeInner::Image {
7588 class: crate::ImageClass::External,
7589 ..
7590 } => {
7591 let target = match resolved {
7595 Some(back::msl::ResolvedBinding::Resource(target)) => {
7596 target.external_texture
7597 }
7598 _ => None,
7599 };
7600
7601 for i in 0..3 {
7602 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7603 handle,
7604 ExternalTextureNameKey::Plane(i),
7605 )];
7606 let ty_name = format!(
7607 "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample>"
7608 );
7609 let name = plane_name.clone();
7610 let binding = if let Some(ref target) = target {
7611 format!(" [[texture({})]]", target.planes[i])
7612 } else {
7613 String::new()
7614 };
7615 args.push(EntryPointArgument {
7616 ty_name,
7617 name,
7618 binding,
7619 init: None,
7620 });
7621 }
7622 let params_ty_name = &self.names
7623 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7624 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7625 handle,
7626 ExternalTextureNameKey::Params,
7627 )];
7628 let binding = if let Some(ref target) = target {
7629 format!(" [[buffer({})]]", target.params)
7630 } else {
7631 String::new()
7632 };
7633
7634 args.push(EntryPointArgument {
7635 ty_name: format!("constant {params_ty_name}&"),
7636 name: params_name.clone(),
7637 binding,
7638 init: None,
7639 });
7640 }
7641 _ => {
7642 if var.space == crate::AddressSpace::WorkGroup
7643 && ep.stage == crate::ShaderStage::Mesh
7644 {
7645 continue;
7646 }
7647 let tyvar = TypedGlobalVariable {
7648 module,
7649 names: &self.names,
7650 handle,
7651 usage,
7652 reference: true,
7653 };
7654 let parts = tyvar.to_parts()?;
7655 let mut binding = String::new();
7656 if let Some(resolved) = resolved {
7657 resolved.try_fmt(&mut binding)?;
7658 }
7659 args.push(EntryPointArgument {
7660 ty_name: parts.ty_name,
7661 name: parts.var_name,
7662 binding,
7663 init: var.init,
7664 });
7665 }
7666 }
7667 }
7668
7669 if do_vertex_pulling {
7670 if needs_vertex_id && v_existing_id.is_none() {
7671 args.push(EntryPointArgument {
7673 ty_name: "uint".to_string(),
7674 name: v_id.clone(),
7675 binding: " [[vertex_id]]".to_string(),
7676 init: None,
7677 });
7678 }
7679
7680 if needs_instance_id && i_existing_id.is_none() {
7681 args.push(EntryPointArgument {
7682 ty_name: "uint".to_string(),
7683 name: i_id.clone(),
7684 binding: " [[instance_id]]".to_string(),
7685 init: None,
7686 });
7687 }
7688
7689 for vbm in &vbm_resolved {
7692 let id = &vbm.id;
7693 let ty_name = &vbm.ty_name;
7694 let param_name = &vbm.param_name;
7695 args.push(EntryPointArgument {
7696 ty_name: format!("const device {ty_name}*"),
7697 name: param_name.clone(),
7698 binding: format!(" [[buffer({id})]]"),
7699 init: None,
7700 });
7701 }
7702 }
7703
7704 if needs_buffer_sizes {
7707 let resolved = options.resolve_sizes_buffer(ep).unwrap();
7709 let mut binding = String::new();
7710 resolved.try_fmt(&mut binding)?;
7711 args.push(EntryPointArgument {
7712 ty_name: "constant _mslBufferSizes&".to_string(),
7713 name: "_buffer_sizes".to_string(),
7714 binding,
7715 init: None,
7716 });
7717 }
7718
7719 let mut is_first_arg = true;
7720 for arg in &args {
7721 if is_first_arg {
7722 write!(self.out, " ")?;
7723 } else {
7724 write!(self.out, ", ")?;
7725 }
7726 is_first_arg = false;
7727 write!(self.out, "{} {}", arg.ty_name, arg.name)?;
7728 if !is_wrapped {
7729 write!(self.out, "{}", arg.binding)?;
7730 if let Some(init) = arg.init {
7731 write!(self.out, " = ")?;
7732 self.put_const_expression(
7733 init,
7734 module,
7735 mod_info,
7736 &module.global_expressions,
7737 )?;
7738 }
7739 }
7740 writeln!(self.out)?;
7741 }
7742 if ep.stage == crate::ShaderStage::Mesh {
7743 for (handle, var) in module.global_variables.iter() {
7744 if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
7745 continue;
7746 }
7747 if is_first_arg {
7748 write!(self.out, " ")?;
7749 } else {
7750 write!(self.out, ", ")?;
7751 }
7752 let ty_context = TypeContext {
7753 handle: module.global_variables[handle].ty,
7754 gctx: module.to_ctx(),
7755 names: &self.names,
7756 access: crate::StorageAccess::empty(),
7757 first_time: false,
7758 };
7759 writeln!(
7760 self.out,
7761 "threadgroup {ty_context}& {}",
7762 self.names[&NameKey::GlobalVariable(handle)]
7763 )?;
7764 }
7765 }
7766
7767 writeln!(self.out, ") {{")?;
7769
7770 if do_vertex_pulling {
7772 for vbm in &vbm_resolved {
7775 for attribute in vbm.attributes {
7776 let location = attribute.shader_location;
7777 let am_option = am_resolved.get(&location);
7778 if am_option.is_none() {
7779 continue;
7782 }
7783 let am = am_option.unwrap();
7784 let attribute_ty_name = &am.ty_name;
7785 let attribute_name = &am.name;
7786
7787 writeln!(
7788 self.out,
7789 "{}{attribute_ty_name} {attribute_name} = {{}};",
7790 back::Level(1)
7791 )?;
7792 }
7793
7794 write!(self.out, "{}if (", back::Level(1))?;
7797
7798 let idx = &vbm.id;
7799 let stride = &vbm.stride;
7800 let index_name = match vbm.step_mode {
7801 back::msl::VertexBufferStepMode::Constant => "0",
7802 back::msl::VertexBufferStepMode::ByVertex => {
7803 if let Some(ref name) = v_existing_id {
7804 name
7805 } else {
7806 &v_id
7807 }
7808 }
7809 back::msl::VertexBufferStepMode::ByInstance => {
7810 if let Some(ref name) = i_existing_id {
7811 name
7812 } else {
7813 &i_id
7814 }
7815 }
7816 };
7817 write!(
7818 self.out,
7819 "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7820 )?;
7821
7822 writeln!(self.out, ") {{")?;
7823
7824 let ty_name = &vbm.ty_name;
7826 let elem_name = &vbm.elem_name;
7827 let param_name = &vbm.param_name;
7828
7829 writeln!(
7830 self.out,
7831 "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7832 back::Level(2),
7833 )?;
7834
7835 for attribute in vbm.attributes {
7838 let location = attribute.shader_location;
7839 let Some(am) = am_resolved.get(&location) else {
7840 continue;
7844 };
7845 let attribute_name = &am.name;
7846 let attribute_ty_name = &am.ty_name;
7847
7848 let offset = attribute.offset;
7849 let func = unpacking_functions
7850 .get(&attribute.format)
7851 .expect("Should have generated this unpacking function earlier.");
7852 let func_name = &func.name;
7853
7854 let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7860
7861 let needs_conversion = am.scalar != func.scalar;
7864
7865 if needs_padding_or_truncation != Ordering::Equal {
7866 writeln!(
7869 self.out,
7870 "{}// {attribute_ty_name} <- {:?}",
7871 back::Level(2),
7872 attribute.format
7873 )?;
7874 }
7875
7876 write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7877
7878 if needs_padding_or_truncation == Ordering::Greater {
7879 write!(self.out, "{attribute_ty_name}(")?;
7881 }
7882
7883 if needs_conversion {
7885 put_numeric_type(&mut self.out, am.scalar, func.dimension.as_slice())?;
7886 write!(self.out, "(")?;
7887 }
7888 write!(self.out, "{func_name}({elem_name}.data[{offset}]")?;
7889 for i in (offset + 1)..(offset + func.byte_count) {
7890 write!(self.out, ", {elem_name}.data[{i}]")?;
7891 }
7892 write!(self.out, ")")?;
7893 if needs_conversion {
7894 write!(self.out, ")")?;
7895 }
7896
7897 match needs_padding_or_truncation {
7898 Ordering::Greater => {
7899 let ty_is_int = scalar_is_int(am.scalar);
7901 let zero_value = if ty_is_int { "0" } else { "0.0" };
7902 let one_value = if ty_is_int { "1" } else { "1.0" };
7903 for i in func.dimension.map_or(1, u8::from)
7904 ..am.dimension.map_or(1, u8::from)
7905 {
7906 write!(
7907 self.out,
7908 ", {}",
7909 if i == 3 { one_value } else { zero_value }
7910 )?;
7911 }
7912 }
7913 Ordering::Less => {
7914 write!(
7916 self.out,
7917 ".{}",
7918 &"xyzw"[0..usize::from(am.dimension.map_or(1, u8::from))]
7919 )?;
7920 }
7921 Ordering::Equal => {}
7922 }
7923
7924 if needs_padding_or_truncation == Ordering::Greater {
7925 write!(self.out, ")")?;
7926 }
7927
7928 writeln!(self.out, ";")?;
7929 }
7930
7931 writeln!(self.out, "{}}}", back::Level(1))?;
7933 }
7934 }
7935
7936 for (handle, var) in module.global_variables.iter() {
7939 let usage = fun_info[handle];
7940 if usage.is_empty() {
7941 continue;
7942 }
7943 if var.space == crate::AddressSpace::Private {
7944 let tyvar = TypedGlobalVariable {
7945 module,
7946 names: &self.names,
7947 handle,
7948 usage,
7949
7950 reference: false,
7951 };
7952 write!(self.out, "{}", back::INDENT)?;
7953 tyvar.try_fmt(&mut self.out)?;
7954 match var.init {
7955 Some(value) => {
7956 write!(self.out, " = ")?;
7957 self.put_const_expression(
7958 value,
7959 module,
7960 mod_info,
7961 &module.global_expressions,
7962 )?;
7963 writeln!(self.out, ";")?;
7964 }
7965 None => {
7966 writeln!(self.out, " = {{}};")?;
7967 }
7968 };
7969 } else if let Some(ref binding) = var.binding {
7970 let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7971 if let Some(sampler) = resolved.as_inline_sampler(options) {
7972 let name = &self.names[&NameKey::GlobalVariable(handle)];
7974 writeln!(
7975 self.out,
7976 "{}constexpr {}::sampler {}(",
7977 back::INDENT,
7978 NAMESPACE,
7979 name
7980 )?;
7981 self.put_inline_sampler_properties(back::Level(2), sampler)?;
7982 writeln!(self.out, "{});", back::INDENT)?;
7983 } else if let crate::TypeInner::Image {
7984 class: crate::ImageClass::External,
7985 ..
7986 } = module.types[var.ty].inner
7987 {
7988 let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7991 let l1 = back::Level(1);
7992 let l2 = l1.next();
7993 writeln!(
7994 self.out,
7995 "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7996 )?;
7997 for i in 0..3 {
7998 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7999 handle,
8000 ExternalTextureNameKey::Plane(i),
8001 )];
8002 writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
8003 }
8004 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
8005 handle,
8006 ExternalTextureNameKey::Params,
8007 )];
8008 writeln!(self.out, "{l2}.params = {params_name},")?;
8009 writeln!(self.out, "{l1}}};")?;
8010 }
8011 }
8012 }
8013
8014 if need_workgroup_variables_initialization {
8015 self.write_workgroup_variables_initialization(
8016 module,
8017 mod_info,
8018 fun_info,
8019 local_invocation_index,
8020 ep.stage,
8021 )?;
8022 }
8023
8024 for (arg_index, arg) in fun.arguments.iter().enumerate() {
8035 let arg_name =
8036 &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
8037 match module.types[arg.ty].inner {
8038 crate::TypeInner::Struct { ref members, .. } => {
8039 let struct_name = &self.names[&NameKey::Type(arg.ty)];
8040 write!(
8041 self.out,
8042 "{}const {} {} = {{ ",
8043 back::INDENT,
8044 struct_name,
8045 arg_name
8046 )?;
8047 for (member_index, member) in members.iter().enumerate() {
8048 let key = NameKey::StructMember(arg.ty, member_index as u32);
8049 let name = &flattened_member_names[&key];
8050 if member_index != 0 {
8051 write!(self.out, ", ")?;
8052 }
8053 if self
8055 .struct_member_pads
8056 .contains(&(arg.ty, member_index as u32))
8057 {
8058 write!(self.out, "{{}}, ")?;
8059 }
8060 match member.binding {
8061 Some(crate::Binding::Location {
8062 interpolation: Some(crate::Interpolation::PerVertex),
8063 ..
8064 }) => {
8065 writeln!(
8066 self.out,
8067 "{0}{{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }}",
8068 back::INDENT,
8069 varyings_member_name,
8070 arg_name,
8071 )?;
8072 continue;
8073 }
8074 Some(crate::Binding::Location { .. }) => {
8075 if has_varyings {
8076 write!(self.out, "{varyings_member_name}.")?;
8077 }
8078 }
8079 _ => (),
8080 }
8081 write!(self.out, "{name}")?;
8082 }
8083 writeln!(self.out, " }};")?;
8084 }
8085 _ => match arg.binding {
8086 Some(crate::Binding::Location {
8087 interpolation: Some(crate::Interpolation::PerVertex),
8088 ..
8089 }) => {
8090 let ty_name = TypeContext {
8091 handle: arg.ty,
8092 gctx: module.to_ctx(),
8093 names: &self.names,
8094 access: crate::StorageAccess::empty(),
8095 first_time: false,
8096 };
8097 writeln!(
8098 self.out,
8099 "{0}const {ty_name} {arg_name} = {{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }};",
8100 back::INDENT,
8101 varyings_member_name,
8102 arg_name,
8103 )?;
8104 }
8105 Some(crate::Binding::Location { .. })
8106 | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
8107 if has_varyings {
8108 writeln!(
8109 self.out,
8110 "{}const auto {} = {}.{};",
8111 back::INDENT,
8112 arg_name,
8113 varyings_member_name,
8114 arg_name
8115 )?;
8116 }
8117 }
8118 _ => {}
8119 },
8120 }
8121 }
8122
8123 let guarded_indices =
8124 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
8125
8126 let context = StatementContext {
8127 expression: ExpressionContext {
8128 function: fun,
8129 origin: FunctionOrigin::EntryPoint(ep_index as _),
8130 info: fun_info,
8131 lang_version: options.lang_version,
8132 policies: options.bounds_check_policies,
8133 guarded_indices,
8134 module,
8135 mod_info,
8136 pipeline_options,
8137 force_loop_bounding: options.force_loop_bounding,
8138 },
8139 result_struct: if ep.stage == crate::ShaderStage::Task {
8140 None
8141 } else {
8142 Some(&stage_out_name)
8143 },
8144 };
8145
8146 self.put_locals(&context.expression)?;
8149 self.update_expressions_to_bake(fun, fun_info, &context.expression);
8150 self.put_block(back::Level(1), &fun.body, &context)?;
8151 writeln!(self.out, "}}")?;
8152 if ep_index + 1 != module.entry_points.len() {
8153 writeln!(self.out)?;
8154 }
8155 self.named_expressions.clear();
8156
8157 if is_wrapped {
8158 self.write_wrapper_function(NestedFunctionInfo {
8159 options,
8160 ep,
8161 module,
8162 mod_info,
8163 fun_info,
8164 args,
8165 local_invocation_index,
8166 nested_name: &nested_fun_name,
8167 outer_name: &fun_name,
8168 out_mesh_info,
8169 })?;
8170 }
8171 }
8172
8173 Ok(info)
8174 }
8175
8176 pub(super) fn write_barrier(
8177 &mut self,
8178 flags: crate::Barrier,
8179 level: back::Level,
8180 ) -> BackendResult {
8181 if flags.is_empty() {
8184 writeln!(
8185 self.out,
8186 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
8187 )?;
8188 }
8189 if flags.contains(crate::Barrier::STORAGE) {
8190 writeln!(
8191 self.out,
8192 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
8193 )?;
8194 }
8195 if flags.contains(crate::Barrier::WORK_GROUP) {
8196 writeln!(
8197 self.out,
8198 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8199 )?;
8200 if self.needs_object_memory_barriers {
8201 writeln!(
8202 self.out,
8203 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_object_data);",
8204 )?;
8205 }
8206 }
8207 if flags.contains(crate::Barrier::SUB_GROUP) {
8208 writeln!(
8209 self.out,
8210 "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8211 )?;
8212 }
8213 if flags.contains(crate::Barrier::TEXTURE) {
8214 writeln!(
8215 self.out,
8216 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
8217 )?;
8218 }
8219 Ok(())
8220 }
8221}
8222
8223mod workgroup_mem_init {
8226 use crate::EntryPoint;
8227
8228 use super::*;
8229
8230 enum Access {
8231 GlobalVariable(Handle<crate::GlobalVariable>),
8232 StructMember(Handle<crate::Type>, u32),
8233 Array(usize),
8234 }
8235
8236 impl Access {
8237 fn write<W: Write>(
8238 &self,
8239 writer: &mut W,
8240 names: &FastHashMap<NameKey, String>,
8241 ) -> Result<(), core::fmt::Error> {
8242 match *self {
8243 Access::GlobalVariable(handle) => {
8244 write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
8245 }
8246 Access::StructMember(handle, index) => {
8247 write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
8248 }
8249 Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
8250 }
8251 }
8252 }
8253
8254 struct AccessStack {
8255 stack: Vec<Access>,
8256 array_depth: usize,
8257 }
8258
8259 impl AccessStack {
8260 const fn new() -> Self {
8261 Self {
8262 stack: Vec::new(),
8263 array_depth: 0,
8264 }
8265 }
8266
8267 fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
8268 let array_depth = self.array_depth;
8269 self.stack.push(Access::Array(array_depth));
8270 self.array_depth += 1;
8271 let res = cb(self, array_depth);
8272 self.stack.pop();
8273 self.array_depth -= 1;
8274 res
8275 }
8276
8277 fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
8278 self.stack.push(new);
8279 let res = cb(self);
8280 self.stack.pop();
8281 res
8282 }
8283
8284 fn write<W: Write>(
8285 &self,
8286 writer: &mut W,
8287 names: &FastHashMap<NameKey, String>,
8288 ) -> Result<(), core::fmt::Error> {
8289 for next in self.stack.iter() {
8290 next.write(writer, names)?;
8291 }
8292 Ok(())
8293 }
8294 }
8295
8296 impl<W: Write> Writer<W> {
8297 pub(super) fn need_workgroup_variables_initialization(
8298 &mut self,
8299 options: &Options,
8300 ep: &EntryPoint,
8301 module: &crate::Module,
8302 fun_info: &valid::FunctionInfo,
8303 ) -> bool {
8304 let is_task = ep.stage == crate::ShaderStage::Task;
8305 options.zero_initialize_workgroup_memory
8306 && ep.stage.compute_like()
8307 && module.global_variables.iter().any(|(handle, var)| {
8308 let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8309 || (var.space == crate::AddressSpace::TaskPayload && is_task);
8310 !fun_info[handle].is_empty() && is_right_address_space
8311 })
8312 }
8313
8314 pub fn write_workgroup_variables_initialization(
8315 &mut self,
8316 module: &crate::Module,
8317 module_info: &valid::ModuleInfo,
8318 fun_info: &valid::FunctionInfo,
8319 local_invocation_index: Option<&NameKey>,
8320 stage: crate::ShaderStage,
8321 ) -> BackendResult {
8322 let level = back::Level(1);
8323
8324 writeln!(
8325 self.out,
8326 "{}if ({} == 0u) {{",
8327 level,
8328 local_invocation_index
8329 .map(|name_key| self.names[name_key].as_str())
8330 .unwrap_or("__local_invocation_index"),
8331 )?;
8332
8333 let mut access_stack = AccessStack::new();
8334
8335 let is_task = stage == crate::ShaderStage::Task;
8336 let vars = module.global_variables.iter().filter(|&(handle, var)| {
8337 let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8338 || (var.space == crate::AddressSpace::TaskPayload && is_task);
8339 !fun_info[handle].is_empty() && is_right_address_space
8340 });
8341
8342 for (handle, var) in vars {
8343 access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
8344 self.write_workgroup_variable_initialization(
8345 module,
8346 module_info,
8347 var.ty,
8348 access_stack,
8349 level.next(),
8350 )
8351 })?;
8352 }
8353
8354 writeln!(self.out, "{level}}}")?;
8355 self.write_barrier(crate::Barrier::WORK_GROUP, level)
8356 }
8357
8358 fn write_workgroup_variable_initialization(
8359 &mut self,
8360 module: &crate::Module,
8361 module_info: &valid::ModuleInfo,
8362 ty: Handle<crate::Type>,
8363 access_stack: &mut AccessStack,
8364 level: back::Level,
8365 ) -> BackendResult {
8366 if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8367 write!(self.out, "{level}")?;
8368 access_stack.write(&mut self.out, &self.names)?;
8369 writeln!(self.out, " = {{}};")?;
8370 } else {
8371 match module.types[ty].inner {
8372 crate::TypeInner::Atomic { .. } => {
8373 write!(
8374 self.out,
8375 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8376 )?;
8377 access_stack.write(&mut self.out, &self.names)?;
8378 writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8379 }
8380 crate::TypeInner::Array { base, size, .. } => {
8381 let count = match size.resolve(module.to_ctx())? {
8382 proc::IndexableLength::Known(count) => count,
8383 proc::IndexableLength::Dynamic => unreachable!(),
8384 };
8385
8386 access_stack.enter_array(|access_stack, array_depth| {
8387 writeln!(
8388 self.out,
8389 "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8390 )?;
8391 self.write_workgroup_variable_initialization(
8392 module,
8393 module_info,
8394 base,
8395 access_stack,
8396 level.next(),
8397 )?;
8398 writeln!(self.out, "{level}}}")?;
8399 BackendResult::Ok(())
8400 })?;
8401 }
8402 crate::TypeInner::Struct { ref members, .. } => {
8403 for (index, member) in members.iter().enumerate() {
8404 access_stack.enter(
8405 Access::StructMember(ty, index as u32),
8406 |access_stack| {
8407 self.write_workgroup_variable_initialization(
8408 module,
8409 module_info,
8410 member.ty,
8411 access_stack,
8412 level,
8413 )
8414 },
8415 )?;
8416 }
8417 }
8418 _ => unreachable!(),
8419 }
8420 }
8421
8422 Ok(())
8423 }
8424 }
8425}
8426
8427impl crate::AtomicFunction {
8428 const fn to_msl(self) -> &'static str {
8429 match self {
8430 Self::Add => "fetch_add",
8431 Self::Subtract => "fetch_sub",
8432 Self::And => "fetch_and",
8433 Self::InclusiveOr => "fetch_or",
8434 Self::ExclusiveOr => "fetch_xor",
8435 Self::Min => "fetch_min",
8436 Self::Max => "fetch_max",
8437 Self::Exchange { compare: None } => "exchange",
8438 Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8439 }
8440 }
8441
8442 fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8443 Ok(match self {
8444 Self::Min => "min",
8445 Self::Max => "max",
8446 _ => Err(Error::FeatureNotImplemented(
8447 "64-bit atomic operation other than min/max".to_string(),
8448 ))?,
8449 })
8450 }
8451}