1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::{
8 cmp::Ordering,
9 fmt::{Display, Error as FmtError, Formatter, Write},
10 iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{
17 ray::RT_NAMESPACE, sampler as sm, Error, LocationMode, Options, PipelineOptions,
18 TranslationInfo, NAMESPACE, WRAPPED_ARRAY_FIELD,
19};
20use crate::{
21 arena::{Handle, HandleSet},
22 back::{
23 self, get_entry_points,
24 msl::{mesh_shader::NestedFunctionInfo, BackendResult, EntryPointArgument},
25 Baked,
26 },
27 common,
28 proc::{
29 self, concrete_int_scalars,
30 index::{self, BoundsCheck},
31 ExternalTextureNameKey, NameKey, TypeResolution,
32 },
33 valid, FastHashMap, FastHashSet,
34};
35
36#[cfg(test)]
37use core::ptr;
38
/// Reference marker (`&`) used when spelling atomic operands.
const ATOMIC_REFERENCE: &str = "&";

/// Name of the generated compare-exchange helper function.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
/// Name of the generated `modf` helper function.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
/// Name of the generated `frexp` helper function.
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
/// Name of the generated `abs` helper function.
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
/// Name of the generated division helper function.
pub(crate) const DIV_FUNCTION: &str = "naga_div";
/// Prefix shared by the generated dot-product helper functions.
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
/// Name of the generated modulo helper function.
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
/// Name of the generated negation helper function.
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
/// Name of the generated float -> i32 conversion helper.
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
/// Name of the generated float -> u32 conversion helper.
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
/// Name of the generated float -> i64 conversion helper.
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
/// Name of the generated float -> u64 conversion helper.
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
/// Name of the generated helper that loads from an external texture.
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
/// Name of the generated helper that queries an external texture's size.
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
/// Name of the generated "sample base level, clamp to edge" helper.
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// Name of the wrapper struct template used for argument-buffer binding arrays.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Name of the struct that stands in for an external texture type.
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
/// Name of the generated cooperative-matrix load helper.
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
/// Name of the generated cooperative-matrix multiply-add helper.
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
75
76fn put_numeric_type(
85 out: &mut impl Write,
86 scalar: crate::Scalar,
87 sizes: &[crate::VectorSize],
88) -> Result<(), FmtError> {
89 match (scalar, sizes) {
90 (scalar, &[]) => {
91 write!(out, "{}", scalar.to_msl_name())
92 }
93 (scalar, &[rows]) => {
94 write!(
95 out,
96 "{}::{}{}",
97 NAMESPACE,
98 scalar.to_msl_name(),
99 common::vector_size_str(rows)
100 )
101 }
102 (scalar, &[rows, columns]) => {
103 write!(
104 out,
105 "{}::{}{}x{}",
106 NAMESPACE,
107 scalar.to_msl_name(),
108 common::vector_size_str(columns),
109 common::vector_size_str(rows)
110 )
111 }
112 (_, _) => Ok(()), }
114}
115
116const fn scalar_is_int(scalar: crate::Scalar) -> bool {
117 use crate::ScalarKind::*;
118 match scalar.kind {
119 Sint | Uint | AbstractInt | Bool => true,
120 Float | AbstractFloat => false,
121 }
122}
123
/// Prefix of the name given to a clamped level-of-detail temporary (see `ClampedLod`).
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix of the name given to a reinterpreted (bit-cast) temporary (see `Reinterpreted`).
const REINTERPRET_PREFIX: &str = "reinterpreted_";
129
130struct ClampedLod(Handle<crate::Expression>);
136
137impl Display for ClampedLod {
138 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
139 self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
140 }
141}
142
143struct ArraySizeMember(Handle<crate::GlobalVariable>);
158
159impl Display for ArraySizeMember {
160 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
161 self.0.write_prefixed(f, "size")
162 }
163}
164
165#[derive(Clone, Copy)]
170struct Reinterpreted<'a> {
171 target_type: &'a str,
172 orig: Handle<crate::Expression>,
173}
174
175impl<'a> Reinterpreted<'a> {
176 const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
177 Self { target_type, orig }
178 }
179}
180
181impl Display for Reinterpreted<'_> {
182 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
183 f.write_str(REINTERPRET_PREFIX)?;
184 f.write_str(self.target_type)?;
185 self.orig.write_prefixed(f, "_e")
186 }
187}
188
189pub(super) struct TypeContext<'a> {
190 pub handle: Handle<crate::Type>,
191 pub gctx: proc::GlobalCtx<'a>,
192 pub names: &'a FastHashMap<NameKey, String>,
193 pub access: crate::StorageAccess,
194 pub first_time: bool,
195}
196
197impl TypeContext<'_> {
198 fn scalar(&self) -> Option<crate::Scalar> {
199 let ty = &self.gctx.types[self.handle];
200 ty.inner.scalar()
201 }
202
203 fn vector_size(&self) -> Option<crate::VectorSize> {
204 let ty = &self.gctx.types[self.handle];
205 match ty.inner {
206 crate::TypeInner::Vector { size, .. } => Some(size),
207 _ => None,
208 }
209 }
210
211 fn unwrap_array(self) -> Self {
212 match self.gctx.types[self.handle].inner {
213 crate::TypeInner::Array { base, .. } => Self {
214 handle: base,
215 ..self
216 },
217 _ => self,
218 }
219 }
220}
221
/// Prints the MSL spelling of the type `self.handle`.
///
/// Types whose `needs_alias()` is true are printed by their registered name
/// from `self.names`, unless `first_time` is set, in which case the underlying
/// definition is spelled out instead.
///
/// # Panics
///
/// * On a `Struct` type when `first_time` is set (structs are only ever
///   printed through their alias here).
/// * On a storage image whose `access` has neither `LOAD` nor `STORE`.
/// * On acceleration-structure / ray-query types requesting vertex return.
impl Display for TypeContext<'_> {
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        // Aliased types are referred to by name once their definition exists.
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            // Printed as `simdgroup_<scalar><columns>x<rows>`; the role is not
            // part of the spelling.
            crate::TypeInner::CooperativeMatrix {
                columns,
                rows,
                scalar,
                role: _,
            } => {
                write!(
                    out,
                    "{NAMESPACE}::simdgroup_{}{}x{}",
                    scalar.to_msl_name(),
                    columns as u32,
                    rows as u32,
                )
            }
            crate::TypeInner::Pointer { base, space } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // Address spaces without an MSL name (e.g. `Handle`) print nothing.
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                write!(out, "{space_name} {sub}&")
            }
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                // Same rule as `Pointer`: no MSL address space, no output.
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            crate::TypeInner::Array { base, .. } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // Only the element type is printed here. NOTE(review): the
                // array length is presumably emitted by the caller — confirm.
                write!(out, "{sub}")
            }
            crate::TypeInner::Struct { .. } => unreachable!(),
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                // Pick the texture family, multisample suffix, texel scalar
                // and `access::` qualifier for each image class.
                let (texture_str, msaa_str, scalar, access) = match class {
                    crate::ImageClass::Sampled { kind, multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Storage { format, .. } => {
                        // The qualifier comes from `self.access`, not from the
                        // image class itself.
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    // External textures are represented by a wrapper struct.
                    crate::ImageClass::External => {
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{}", super::ray::metal_intersector_ty())
            }
            // Binding arrays become a constant pointer to the argument-buffer
            // wrapper template instantiated with the element type.
            crate::TypeInner::BindingArray { base, .. } => {
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };

                write!(
                    out,
                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                )
            }
        }
    }
}
390
391pub(super) struct TypedGlobalVariable<'a> {
392 pub module: &'a crate::Module,
393 pub names: &'a FastHashMap<NameKey, String>,
394 pub handle: Handle<crate::GlobalVariable>,
395 pub usage: valid::GlobalUse,
396 pub reference: bool,
397}
398
399struct TypedGlobalVariableParts {
400 ty_name: String,
401 var_name: String,
402}
403
404impl TypedGlobalVariable<'_> {
405 fn to_parts(&self) -> Result<TypedGlobalVariableParts, Error> {
406 let var = &self.module.global_variables[self.handle];
407 let name = &self.names[&NameKey::GlobalVariable(self.handle)];
408
409 let storage_access = match var.space {
410 crate::AddressSpace::Storage { access } => access,
411 _ => match self.module.types[var.ty].inner {
412 crate::TypeInner::Image {
413 class: crate::ImageClass::Storage { access, .. },
414 ..
415 } => access,
416 crate::TypeInner::BindingArray { base, .. } => {
417 match self.module.types[base].inner {
418 crate::TypeInner::Image {
419 class: crate::ImageClass::Storage { access, .. },
420 ..
421 } => access,
422 _ => crate::StorageAccess::default(),
423 }
424 }
425 _ => crate::StorageAccess::default(),
426 },
427 };
428 let ty_name = TypeContext {
429 handle: var.ty,
430 gctx: self.module.to_ctx(),
431 names: self.names,
432 access: storage_access,
433 first_time: false,
434 };
435
436 let access = if var.space.needs_access_qualifier()
437 && !self.usage.intersects(valid::GlobalUse::WRITE)
438 {
439 "const"
440 } else {
441 ""
442 };
443 let (coherent, space, access, reference) = match (var.space.to_msl_name(), var.space) {
444 (Some(space), crate::AddressSpace::WorkGroup) => {
445 ("", space, access, if self.reference { "&" } else { "" })
446 }
447 (Some(space), _) if self.reference => {
448 let coherent = if var
449 .memory_decorations
450 .contains(crate::MemoryDecorations::COHERENT)
451 {
452 "coherent "
453 } else {
454 ""
455 };
456 (coherent, space, access, "&")
457 }
458 _ => ("", "", "", ""),
459 };
460
461 let ty = format!(
462 "{coherent}{space}{}{ty_name}{}{access}{reference}",
463 if space.is_empty() { "" } else { " " },
464 if access.is_empty() { "" } else { " " },
465 );
466
467 Ok(TypedGlobalVariableParts {
468 ty_name: ty,
469 var_name: name.clone(),
470 })
471 }
472 pub(super) fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
473 let parts = self.to_parts()?;
474
475 Ok(write!(out, "{} {}", parts.ty_name, parts.var_name)?)
476 }
477}
478
/// Key describing one generated helper ("wrapped") function.
///
/// Values are hashed into `Writer::wrapped_functions`; NOTE(review): presumably
/// so each distinct helper is emitted only once — confirm at the insertion sites.
#[derive(Eq, PartialEq, Hash)]
pub(super) enum WrappedFunction {
    /// Helper for a unary operator applied to `ty` (optional vector size, scalar).
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a binary operator with the given operand types.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a math function applied to `arg_ty`.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper converting `src_scalar` (optionally vectorized) to `dst_scalar`.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Image load helper for the given image class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Image sample helper for the given image class.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Image size query helper for the given image class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
    /// Cooperative-matrix load helper from address space `space_name`.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// Cooperative-matrix multiply-add helper.
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// Ray-query intersection getter (committed or candidate).
    RayQueryGetIntersection {
        committed: bool,
    },
}

/// Writer that produces Metal Shading Language source into `out`.
pub struct Writer<W> {
    /// Sink receiving the generated MSL text.
    pub(super) out: W,
    /// Identifiers assigned to module items, locals and temporaries.
    pub(super) names: FastHashMap<NameKey, String>,
    pub(super) named_expressions: crate::NamedExpressions,
    /// NOTE(review): presumably the expressions that must be baked into
    /// temporaries for the function being written — confirm.
    need_bake_expressions: back::NeedBakeExpressions,
    /// Allocates fresh identifiers (e.g. `namer.call("oob")`, `namer.call("loop_bound")`).
    pub(super) namer: proc::Namer,
    /// Helper functions already generated, keyed by `WrappedFunction`.
    pub(super) wrapped_functions: FastHashSet<WrappedFunction>,
    /// Initialized to `true` by `Writer::new`.
    emit_int_div_checks: bool,
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// NOTE(review): presumably (struct type, member index) pairs that need
    /// explicit padding — confirm at use sites.
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
    /// Initialized to `false` by `Writer::new`.
    needs_object_memory_barriers: bool,
}
545
546impl crate::Scalar {
547 pub(super) fn to_msl_name(self) -> &'static str {
548 use crate::ScalarKind as Sk;
549 match self {
550 Self {
551 kind: Sk::Float,
552 width: 4,
553 } => "float",
554 Self {
555 kind: Sk::Float,
556 width: 2,
557 } => "half",
558 Self {
559 kind: Sk::Sint,
560 width: 2,
561 } => "short",
562 Self {
563 kind: Sk::Uint,
564 width: 2,
565 } => "ushort",
566 Self {
567 kind: Sk::Sint,
568 width: 4,
569 } => "int",
570 Self {
571 kind: Sk::Uint,
572 width: 4,
573 } => "uint",
574 Self {
575 kind: Sk::Sint,
576 width: 8,
577 } => "long",
578 Self {
579 kind: Sk::Uint,
580 width: 8,
581 } => "ulong",
582 Self {
583 kind: Sk::Bool,
584 width: _,
585 } => "bool",
586 Self {
587 kind: Sk::AbstractInt | Sk::AbstractFloat,
588 width: _,
589 } => unreachable!("Found Abstract scalar kind"),
590 _ => unreachable!("Unsupported scalar kind: {:?}", self),
591 }
592 }
593}
594
/// Returns `","` when a separator is needed and the empty string otherwise.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
602
603fn should_pack_struct_member(
604 members: &[crate::StructMember],
605 span: u32,
606 index: usize,
607 module: &crate::Module,
608) -> Option<crate::Scalar> {
609 let member = &members[index];
610
611 let ty_inner = &module.types[member.ty].inner;
612 let last_offset = member.offset + ty_inner.size(module.to_ctx());
613 let next_offset = match members.get(index + 1) {
614 Some(next) => next.offset,
615 None => span,
616 };
617 let is_tight = next_offset == last_offset;
618
619 match *ty_inner {
620 crate::TypeInner::Vector {
621 size: crate::VectorSize::Tri,
622 scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
623 } if is_tight => Some(scalar),
624 _ => None,
625 }
626}
627
628fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
629 match arena[ty].inner {
630 crate::TypeInner::Struct { ref members, .. } => {
631 if let Some(member) = members.last() {
632 if let crate::TypeInner::Array {
633 size: crate::ArraySize::Dynamic,
634 ..
635 } = arena[member.ty].inner
636 {
637 return true;
638 }
639 }
640 false
641 }
642 crate::TypeInner::Array {
643 size: crate::ArraySize::Dynamic,
644 ..
645 } => true,
646 _ => false,
647 }
648}
649
650impl crate::AddressSpace {
651 const fn needs_pass_through(&self) -> bool {
655 match *self {
656 Self::Uniform
657 | Self::Storage { .. }
658 | Self::Private
659 | Self::WorkGroup
660 | Self::Immediate
661 | Self::Handle
662 | Self::TaskPayload => true,
663 Self::Function => false,
664 Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
665 }
666 }
667
668 const fn needs_access_qualifier(&self) -> bool {
670 match *self {
671 Self::Storage { .. } => true,
676 Self::TaskPayload => true,
677 Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
678 Self::Private | Self::WorkGroup => false,
680 Self::Uniform | Self::Immediate => false,
682 Self::Handle | Self::Function => false,
684 }
685 }
686
687 const fn to_msl_name(self) -> Option<&'static str> {
688 match self {
689 Self::Handle => None,
690 Self::Uniform | Self::Immediate => Some("constant"),
691 Self::Storage { .. } => Some("device"),
692 Self::Private | Self::Function | Self::RayPayload => Some("thread"),
696 Self::WorkGroup => Some("threadgroup"),
697 Self::TaskPayload => Some("object_data"),
698 Self::IncomingRayPayload => Some("ray_data"),
699 }
700 }
701}
702
703impl crate::Type {
704 const fn needs_alias(&self) -> bool {
706 use crate::TypeInner as Ti;
707
708 match self.inner {
709 Ti::Scalar(_)
711 | Ti::Vector { .. }
712 | Ti::Matrix { .. }
713 | Ti::CooperativeMatrix { .. }
714 | Ti::Atomic(_)
715 | Ti::Pointer { .. }
716 | Ti::ValuePointer { .. } => self.name.is_some(),
717 Ti::Struct { .. } | Ti::Array { .. } => true,
719 Ti::Image { .. }
721 | Ti::Sampler { .. }
722 | Ti::AccelerationStructure { .. }
723 | Ti::RayQuery { .. }
724 | Ti::BindingArray { .. } => false,
725 }
726 }
727}
728
729#[derive(Clone, Copy)]
730enum FunctionOrigin {
731 Handle(Handle<crate::Function>),
732 EntryPoint(proc::EntryPointIndex),
733}
734
735trait NameKeyExt {
736 fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
737 match origin {
738 FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
739 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
740 }
741 }
742
743 fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
748 match origin {
749 FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
750 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
751 }
752 }
753}
754
755impl NameKeyExt for NameKey {}
756
/// How a level-of-detail operand is written.
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Write the given expression itself.
    Direct(Handle<crate::Expression>),
    /// Write the clamped-LOD temporary associated with the given load
    /// expression (printed via `ClampedLod`).
    Restricted(Handle<crate::Expression>),
}

/// The operands that locate a texel for an image access.
struct TexelAddress {
    /// Texel coordinate (scalar or vector).
    coordinate: Handle<crate::Expression>,
    /// Array layer, for arrayed images.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index, for multisampled images.
    sample: Option<Handle<crate::Expression>>,
    /// Mip level, when one is supplied.
    level: Option<LevelOfDetail>,
}
787
/// State shared by expression-writing code for one function body.
pub(super) struct ExpressionContext<'a> {
    /// The function (or entry point body) being written.
    pub(super) function: &'a crate::Function,
    /// Which function this is, used to build per-function `NameKey`s.
    origin: FunctionOrigin,
    /// Validation info for `function`; supplies expression types.
    pub(super) info: &'a valid::FunctionInfo,
    pub(super) module: &'a crate::Module,
    pub(super) mod_info: &'a valid::ModuleInfo,
    pub(super) pipeline_options: &'a PipelineOptions,
    /// Target MSL language version as (major, minor).
    pub(super) lang_version: (u8, u8),
    /// Bounds-check policies (e.g. `image_load`) to apply.
    pub(super) policies: index::BoundsCheckPolicies,

    /// NOTE(review): presumably index expressions whose bounds checks are
    /// handled by a guard — confirm where this set is built.
    pub(super) guarded_indices: HandleSet<crate::Expression>,
    /// When set, loops get an injected iteration bound
    /// (see `gen_force_bounded_loop_statements`).
    pub(super) force_loop_bounding: bool,
    /// NOTE(review): presumably enables the checked integer division helpers
    /// — confirm at use sites.
    emit_int_div_checks: bool,
}
807
808impl<'a> ExpressionContext<'a> {
809 fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
810 self.info[handle].ty.inner_with(&self.module.types)
811 }
812
813 fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
820 let image_ty = self.resolve_type(image);
821 if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
822 class.is_mipmapped() && dim != crate::ImageDimension::D1
823 } else {
824 false
825 }
826 }
827
828 fn choose_bounds_check_policy(
829 &self,
830 pointer: Handle<crate::Expression>,
831 ) -> index::BoundsCheckPolicy {
832 self.policies
833 .choose_policy(pointer, &self.module.types, self.info)
834 }
835
836 fn access_needs_check(
838 &self,
839 base: Handle<crate::Expression>,
840 index: index::GuardedIndex,
841 ) -> Option<index::IndexableLength> {
842 index::access_needs_check(
843 base,
844 index,
845 self.module,
846 &self.function.expressions,
847 self.info,
848 )
849 }
850
851 fn bounds_check_iter(
853 &self,
854 chain: Handle<crate::Expression>,
855 ) -> impl Iterator<Item = BoundsCheck> + '_ {
856 index::bounds_check_iter(chain, self.module, self.function, self.info)
857 }
858
859 fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
861 index::oob_local_types(self.module, self.function, self.info, self.policies)
862 }
863
864 fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
865 match self.function.expressions[expr_handle] {
866 crate::Expression::AccessIndex { base, index } => {
867 let ty = match *self.resolve_type(base) {
868 crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
869 ref ty => ty,
870 };
871 match *ty {
872 crate::TypeInner::Struct {
873 ref members, span, ..
874 } => should_pack_struct_member(members, span, index as usize, self.module),
875 _ => None,
876 }
877 }
878 _ => None,
879 }
880 }
881}
882
/// State for writing statements: the expression context plus statement-level data.
pub(super) struct StatementContext<'a> {
    pub(super) expression: ExpressionContext<'a>,
    /// NOTE(review): presumably the name of the struct used for the entry
    /// point's result, when there is one — confirm at use sites.
    pub(super) result_struct: Option<&'a str>,
}
887
888impl<W: Write> Writer<W> {
889 pub fn new(out: W) -> Self {
891 Writer {
892 out,
893 names: FastHashMap::default(),
894 named_expressions: Default::default(),
895 need_bake_expressions: Default::default(),
896 namer: proc::Namer::default(),
897 wrapped_functions: FastHashSet::default(),
898 emit_int_div_checks: true,
899 #[cfg(test)]
900 put_expression_stack_pointers: Default::default(),
901 #[cfg(test)]
902 put_block_stack_pointers: Default::default(),
903 struct_member_pads: FastHashSet::default(),
904 needs_object_memory_barriers: false,
905 }
906 }
907
908 pub fn finish(self) -> W {
911 self.out
912 }
913
914 fn gen_force_bounded_loop_statements(
1021 &mut self,
1022 level: back::Level,
1023 context: &StatementContext,
1024 ) -> Option<(String, String)> {
1025 if !context.expression.force_loop_bounding {
1026 return None;
1027 }
1028
1029 let loop_bound_name = self.namer.call("loop_bound");
1030 let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
1033 let level = level.next();
1034 let break_and_inc = format!(
1035 "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
1036{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
1037 );
1038
1039 Some((decl, break_and_inc))
1040 }
1041
1042 fn put_call_parameters(
1043 &mut self,
1044 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1045 context: &ExpressionContext,
1046 ) -> BackendResult {
1047 self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1048 writer.put_expression(expr, context, true)
1049 })
1050 }
1051
1052 fn put_call_parameters_impl<C, E>(
1053 &mut self,
1054 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1055 ctx: &C,
1056 put_expression: E,
1057 ) -> BackendResult
1058 where
1059 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1060 {
1061 write!(self.out, "(")?;
1062 for (i, handle) in parameters.enumerate() {
1063 if i != 0 {
1064 write!(self.out, ", ")?;
1065 }
1066 put_expression(self, ctx, handle)?;
1067 }
1068 write!(self.out, ")")?;
1069 Ok(())
1070 }
1071
1072 fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1078 let oob_local_types = context.oob_local_types();
1079 for &ty in oob_local_types.iter() {
1080 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1081 self.names.insert(name_key, self.namer.call("oob"));
1082 }
1083
1084 for (name_key, ty, init) in context
1085 .function
1086 .local_variables
1087 .iter()
1088 .map(|(local_handle, local)| {
1089 let name_key = NameKey::local(context.origin, local_handle);
1090 (name_key, local.ty, local.init)
1091 })
1092 .chain(oob_local_types.iter().map(|&ty| {
1093 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1094 (name_key, ty, None)
1095 }))
1096 {
1097 let ty_name = TypeContext {
1098 handle: ty,
1099 gctx: context.module.to_ctx(),
1100 names: &self.names,
1101 access: crate::StorageAccess::empty(),
1102 first_time: false,
1103 };
1104 write!(
1105 self.out,
1106 "{}{} {}",
1107 back::INDENT,
1108 ty_name,
1109 self.names[&name_key]
1110 )?;
1111 match init {
1112 Some(value) => {
1113 write!(self.out, " = ")?;
1114 self.put_expression(value, context, true)?;
1115 }
1116 None => {
1117 write!(self.out, " = {{}}")?;
1118 }
1119 };
1120 writeln!(self.out, ";")?;
1121 }
1122 Ok(())
1123 }
1124
1125 fn put_level_of_detail(
1126 &mut self,
1127 level: LevelOfDetail,
1128 context: &ExpressionContext,
1129 ) -> BackendResult {
1130 match level {
1131 LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1132 LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1133 }
1134 Ok(())
1135 }
1136
1137 fn put_image_query(
1138 &mut self,
1139 image: Handle<crate::Expression>,
1140 query: &str,
1141 level: Option<LevelOfDetail>,
1142 context: &ExpressionContext,
1143 ) -> BackendResult {
1144 self.put_expression(image, context, false)?;
1145 write!(self.out, ".get_{query}(")?;
1146 if let Some(level) = level {
1147 self.put_level_of_detail(level, context)?;
1148 }
1149 write!(self.out, ")")?;
1150 Ok(())
1151 }
1152
1153 fn put_image_size_query(
1154 &mut self,
1155 image: Handle<crate::Expression>,
1156 level: Option<LevelOfDetail>,
1157 kind: crate::ScalarKind,
1158 context: &ExpressionContext,
1159 ) -> BackendResult {
1160 if let crate::TypeInner::Image {
1161 class: crate::ImageClass::External,
1162 ..
1163 } = *context.resolve_type(image)
1164 {
1165 write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1166 self.put_expression(image, context, true)?;
1167 write!(self.out, ")")?;
1168 return Ok(());
1169 }
1170
1171 let dim = match *context.resolve_type(image) {
1174 crate::TypeInner::Image { dim, .. } => dim,
1175 ref other => unreachable!("Unexpected type {:?}", other),
1176 };
1177 let scalar = crate::Scalar { kind, width: 4 };
1178 let coordinate_type = scalar.to_msl_name();
1179 match dim {
1180 crate::ImageDimension::D1 => {
1181 if kind == crate::ScalarKind::Uint {
1185 self.put_image_query(image, "width", None, context)?;
1187 } else {
1188 write!(self.out, "int(")?;
1190 self.put_image_query(image, "width", None, context)?;
1191 write!(self.out, ")")?;
1192 }
1193 }
1194 crate::ImageDimension::D2 => {
1195 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1196 self.put_image_query(image, "width", level, context)?;
1197 write!(self.out, ", ")?;
1198 self.put_image_query(image, "height", level, context)?;
1199 write!(self.out, ")")?;
1200 }
1201 crate::ImageDimension::D3 => {
1202 write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1203 self.put_image_query(image, "width", level, context)?;
1204 write!(self.out, ", ")?;
1205 self.put_image_query(image, "height", level, context)?;
1206 write!(self.out, ", ")?;
1207 self.put_image_query(image, "depth", level, context)?;
1208 write!(self.out, ")")?;
1209 }
1210 crate::ImageDimension::Cube => {
1211 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1212 self.put_image_query(image, "width", level, context)?;
1213 write!(self.out, ")")?;
1214 }
1215 }
1216 Ok(())
1217 }
1218
1219 fn put_cast_to_uint_scalar_or_vector(
1220 &mut self,
1221 expr: Handle<crate::Expression>,
1222 context: &ExpressionContext,
1223 ) -> BackendResult {
1224 match *context.resolve_type(expr) {
1226 crate::TypeInner::Scalar(_) => {
1227 put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1228 }
1229 crate::TypeInner::Vector { size, .. } => {
1230 put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1231 }
1232 _ => {
1233 return Err(Error::GenericValidation(
1234 "Invalid type for image coordinate".into(),
1235 ))
1236 }
1237 };
1238
1239 write!(self.out, "(")?;
1240 self.put_expression(expr, context, true)?;
1241 write!(self.out, ")")?;
1242 Ok(())
1243 }
1244
    /// Appends the level-of-detail option of a texture sample call, including
    /// its leading `, `.
    ///
    /// `Auto` and `Zero` write nothing. For images where `image_needs_lod` is
    /// false (1D or non-mipmapped), any other level is dropped with a warning.
    /// Otherwise `Exact`, `Bias` and `Gradient` become `level(..)`, `bias(..)`
    /// and `gradient2d(..)` options from `NAMESPACE`.
    fn put_image_sample_level(
        &mut self,
        image: Handle<crate::Expression>,
        level: crate::SampleLevel,
        context: &ExpressionContext,
    ) -> BackendResult {
        let has_levels = context.image_needs_lod(image);
        // Arm order matters: the `!has_levels` guard must come after
        // `Auto`/`Zero` and before the arms that emit an option.
        match level {
            crate::SampleLevel::Auto => {}
            crate::SampleLevel::Zero => {
            }
            _ if !has_levels => {
                log::warn!("1D image can't be sampled with level {level:?}");
            }
            crate::SampleLevel::Exact(h) => {
                write!(self.out, ", {NAMESPACE}::level(")?;
                self.put_expression(h, context, true)?;
                write!(self.out, ")")?;
            }
            crate::SampleLevel::Bias(h) => {
                write!(self.out, ", {NAMESPACE}::bias(")?;
                self.put_expression(h, context, true)?;
                write!(self.out, ")")?;
            }
            // NOTE(review): always emits `gradient2d`, whatever the image dimension.
            crate::SampleLevel::Gradient { x, y } => {
                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
                self.put_expression(x, context, true)?;
                write!(self.out, ", ")?;
                self.put_expression(y, context, true)?;
                write!(self.out, ")")?;
            }
        }
        Ok(())
    }
1280
1281 fn put_image_coordinate_limits(
1282 &mut self,
1283 image: Handle<crate::Expression>,
1284 level: Option<LevelOfDetail>,
1285 context: &ExpressionContext,
1286 ) -> BackendResult {
1287 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1288 write!(self.out, " - 1")?;
1289 Ok(())
1290 }
1291
1292 fn put_restricted_scalar_image_index(
1310 &mut self,
1311 image: Handle<crate::Expression>,
1312 index: Handle<crate::Expression>,
1313 limit_method: &str,
1314 context: &ExpressionContext,
1315 ) -> BackendResult {
1316 write!(self.out, "{NAMESPACE}::min(uint(")?;
1317 self.put_expression(index, context, true)?;
1318 write!(self.out, "), ")?;
1319 self.put_expression(image, context, false)?;
1320 write!(self.out, ".{limit_method}() - 1)")?;
1321 Ok(())
1322 }
1323
1324 fn put_restricted_texel_address(
1325 &mut self,
1326 image: Handle<crate::Expression>,
1327 address: &TexelAddress,
1328 context: &ExpressionContext,
1329 ) -> BackendResult {
1330 write!(self.out, "{NAMESPACE}::min(")?;
1332 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1333 write!(self.out, ", ")?;
1334 self.put_image_coordinate_limits(image, address.level, context)?;
1335 write!(self.out, ")")?;
1336
1337 if let Some(array_index) = address.array_index {
1339 write!(self.out, ", ")?;
1340 self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1341 }
1342
1343 if let Some(sample) = address.sample {
1345 write!(self.out, ", ")?;
1346 self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1347 }
1348
1349 if let Some(level) = address.level {
1352 write!(self.out, ", ")?;
1353 self.put_level_of_detail(level, context)?;
1354 }
1355
1356 Ok(())
1357 }
1358
1359 fn put_image_access_bounds_check(
1361 &mut self,
1362 image: Handle<crate::Expression>,
1363 address: &TexelAddress,
1364 context: &ExpressionContext,
1365 ) -> BackendResult {
1366 let mut conjunction = "";
1367
1368 let level = if let Some(level) = address.level {
1371 write!(self.out, "uint(")?;
1372 self.put_level_of_detail(level, context)?;
1373 write!(self.out, ") < ")?;
1374 self.put_expression(image, context, true)?;
1375 write!(self.out, ".get_num_mip_levels()")?;
1376 conjunction = " && ";
1377 Some(level)
1378 } else {
1379 None
1380 };
1381
1382 if let Some(sample) = address.sample {
1384 write!(self.out, "uint(")?;
1385 self.put_expression(sample, context, true)?;
1386 write!(self.out, ") < ")?;
1387 self.put_expression(image, context, true)?;
1388 write!(self.out, ".get_num_samples()")?;
1389 conjunction = " && ";
1390 }
1391
1392 if let Some(array_index) = address.array_index {
1394 write!(self.out, "{conjunction}uint(")?;
1395 self.put_expression(array_index, context, true)?;
1396 write!(self.out, ") < ")?;
1397 self.put_expression(image, context, true)?;
1398 write!(self.out, ".get_array_size()")?;
1399 conjunction = " && ";
1400 }
1401
1402 let coord_is_vector = match *context.resolve_type(address.coordinate) {
1404 crate::TypeInner::Vector { .. } => true,
1405 _ => false,
1406 };
1407 write!(self.out, "{conjunction}")?;
1408 if coord_is_vector {
1409 write!(self.out, "{NAMESPACE}::all(")?;
1410 }
1411 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1412 write!(self.out, " < ")?;
1413 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1414 if coord_is_vector {
1415 write!(self.out, ")")?;
1416 }
1417
1418 Ok(())
1419 }
1420
1421 fn put_image_load(
1422 &mut self,
1423 load: Handle<crate::Expression>,
1424 image: Handle<crate::Expression>,
1425 mut address: TexelAddress,
1426 context: &ExpressionContext,
1427 ) -> BackendResult {
1428 if let crate::TypeInner::Image {
1429 class: crate::ImageClass::External,
1430 ..
1431 } = *context.resolve_type(image)
1432 {
1433 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1434 self.put_expression(image, context, true)?;
1435 write!(self.out, ", ")?;
1436 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1437 write!(self.out, ")")?;
1438 return Ok(());
1439 }
1440
1441 match context.policies.image_load {
1442 proc::BoundsCheckPolicy::Restrict => {
1443 if address.level.is_some() {
1446 address.level = if context.image_needs_lod(image) {
1447 Some(LevelOfDetail::Restricted(load))
1448 } else {
1449 None
1450 }
1451 }
1452
1453 self.put_expression(image, context, false)?;
1454 write!(self.out, ".read(")?;
1455 self.put_restricted_texel_address(image, &address, context)?;
1456 write!(self.out, ")")?;
1457 }
1458 proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1459 write!(self.out, "(")?;
1460 self.put_image_access_bounds_check(image, &address, context)?;
1461 write!(self.out, " ? ")?;
1462 self.put_unchecked_image_load(image, &address, context)?;
1463 write!(self.out, ": DefaultConstructible())")?;
1464 }
1465 proc::BoundsCheckPolicy::Unchecked => {
1466 self.put_unchecked_image_load(image, &address, context)?;
1467 }
1468 }
1469
1470 Ok(())
1471 }
1472
1473 fn put_unchecked_image_load(
1474 &mut self,
1475 image: Handle<crate::Expression>,
1476 address: &TexelAddress,
1477 context: &ExpressionContext,
1478 ) -> BackendResult {
1479 self.put_expression(image, context, false)?;
1480 write!(self.out, ".read(")?;
1481 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1483 if let Some(expr) = address.array_index {
1484 write!(self.out, ", ")?;
1485 self.put_expression(expr, context, true)?;
1486 }
1487 if let Some(sample) = address.sample {
1488 write!(self.out, ", ")?;
1489 self.put_expression(sample, context, true)?;
1490 }
1491 if let Some(level) = address.level {
1492 if context.image_needs_lod(image) {
1493 write!(self.out, ", ")?;
1494 self.put_level_of_detail(level, context)?;
1495 }
1496 }
1497 write!(self.out, ")")?;
1498
1499 Ok(())
1500 }
1501
1502 fn put_image_atomic(
1503 &mut self,
1504 level: back::Level,
1505 image: Handle<crate::Expression>,
1506 address: &TexelAddress,
1507 fun: crate::AtomicFunction,
1508 value: Handle<crate::Expression>,
1509 context: &StatementContext,
1510 ) -> BackendResult {
1511 write!(self.out, "{level}")?;
1512 self.put_expression(image, &context.expression, false)?;
1513 let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1514 fun.to_msl_64_bit()?
1515 } else {
1516 fun.to_msl()
1517 };
1518 write!(self.out, ".atomic_{op}(")?;
1519 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1521 write!(self.out, ", ")?;
1522 self.put_expression(value, &context.expression, true)?;
1523 writeln!(self.out, ");")?;
1524
1525 let value_ty = context.expression.resolve_type(value);
1531 let zero_value = match (value_ty.scalar_kind(), value_ty.scalar_width()) {
1532 (Some(crate::ScalarKind::Sint), _) => "int4(0)",
1533 (_, Some(8)) => "ulong4(0uL)",
1534 _ => "uint4(0u)",
1535 };
1536 let coord_ty = context.expression.resolve_type(address.coordinate);
1537 let x = if matches!(coord_ty, crate::TypeInner::Scalar(_)) {
1538 ""
1539 } else {
1540 ".x"
1541 };
1542 write!(self.out, "{level}if (")?;
1543 self.put_expression(address.coordinate, &context.expression, true)?;
1544 write!(self.out, "{x} == -99999) {{ ")?;
1545 self.put_expression(image, &context.expression, false)?;
1546 write!(self.out, ".write({zero_value}, ")?;
1547 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1548 if let Some(array_index) = address.array_index {
1549 write!(self.out, ", ")?;
1550 self.put_expression(array_index, &context.expression, true)?;
1551 }
1552 writeln!(self.out, "); }}")?;
1553
1554 Ok(())
1555 }
1556
1557 fn put_image_store(
1558 &mut self,
1559 level: back::Level,
1560 image: Handle<crate::Expression>,
1561 address: &TexelAddress,
1562 value: Handle<crate::Expression>,
1563 context: &StatementContext,
1564 ) -> BackendResult {
1565 write!(self.out, "{level}")?;
1566 self.put_expression(image, &context.expression, false)?;
1567 write!(self.out, ".write(")?;
1568 self.put_expression(value, &context.expression, true)?;
1569 write!(self.out, ", ")?;
1570 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1572 if let Some(expr) = address.array_index {
1573 write!(self.out, ", ")?;
1574 self.put_expression(expr, &context.expression, true)?;
1575 }
1576 writeln!(self.out, ");")?;
1577
1578 Ok(())
1579 }
1580
1581 fn put_dynamic_array_max_index(
1592 &mut self,
1593 handle: Handle<crate::GlobalVariable>,
1594 context: &ExpressionContext,
1595 ) -> BackendResult {
1596 let global = &context.module.global_variables[handle];
1597 let (offset, array_ty) = match context.module.types[global.ty].inner {
1598 crate::TypeInner::Struct { ref members, .. } => match members.last() {
1599 Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1600 None => return Err(Error::GenericValidation("Struct has no members".into())),
1601 },
1602 crate::TypeInner::Array {
1603 size: crate::ArraySize::Dynamic,
1604 ..
1605 } => (0, global.ty),
1606 ref ty => {
1607 return Err(Error::GenericValidation(format!(
1608 "Expected type with dynamic array, got {ty:?}"
1609 )))
1610 }
1611 };
1612
1613 let (size, stride) = match context.module.types[array_ty].inner {
1614 crate::TypeInner::Array { base, stride, .. } => (
1615 context.module.types[base]
1616 .inner
1617 .size(context.module.to_ctx()),
1618 stride,
1619 ),
1620 ref ty => {
1621 return Err(Error::GenericValidation(format!(
1622 "Expected array type, got {ty:?}"
1623 )))
1624 }
1625 };
1626
1627 write!(
1640 self.out,
1641 "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1642 member = ArraySizeMember(handle),
1643 offset = offset,
1644 size = size,
1645 stride = stride,
1646 )?;
1647 Ok(())
1648 }
1649
1650 fn put_dot_product<T: Copy>(
1655 &mut self,
1656 arg: T,
1657 arg1: T,
1658 size: usize,
1659 extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1660 ) -> BackendResult {
1661 write!(self.out, "(")?;
1664
1665 for index in 0..size {
1667 write!(self.out, " + ")?;
1670 extractor(self, arg, index)?;
1671 write!(self.out, " * ")?;
1672 extractor(self, arg1, index)?;
1673 }
1674
1675 write!(self.out, ")")?;
1676 Ok(())
1677 }
1678
1679 fn put_pack4x8(
1681 &mut self,
1682 arg: Handle<crate::Expression>,
1683 context: &ExpressionContext<'_>,
1684 was_signed: bool,
1685 clamp_bounds: Option<(&str, &str)>,
1686 ) -> Result<(), Error> {
1687 let write_arg = |this: &mut Self| -> BackendResult {
1688 if let Some((min, max)) = clamp_bounds {
1689 write!(this.out, "{NAMESPACE}::clamp(")?;
1691 this.put_expression(arg, context, true)?;
1692 write!(this.out, ", {min}, {max})")?;
1693 } else {
1694 this.put_expression(arg, context, true)?;
1695 }
1696 Ok(())
1697 };
1698
1699 if context.lang_version >= (2, 1) {
1700 let packed_type = if was_signed {
1701 "packed_char4"
1702 } else {
1703 "packed_uchar4"
1704 };
1705 write!(self.out, "as_type<uint>({packed_type}(")?;
1707 write_arg(self)?;
1708 write!(self.out, "))")?;
1709 } else {
1710 if was_signed {
1712 write!(self.out, "uint(")?;
1713 }
1714 write!(self.out, "(")?;
1715 write_arg(self)?;
1716 write!(self.out, "[0] & 0xFF) | ((")?;
1717 write_arg(self)?;
1718 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1719 write_arg(self)?;
1720 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1721 write_arg(self)?;
1722 write!(self.out, "[3] & 0xFF) << 24)")?;
1723 if was_signed {
1724 write!(self.out, ")")?;
1725 }
1726 }
1727
1728 Ok(())
1729 }
1730
1731 fn put_isign(
1734 &mut self,
1735 arg: Handle<crate::Expression>,
1736 context: &ExpressionContext,
1737 ) -> BackendResult {
1738 write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1739 let scalar = context
1740 .resolve_type(arg)
1741 .scalar()
1742 .expect("put_isign should only be called for args which have an integer scalar type")
1743 .to_msl_name();
1744 match context.resolve_type(arg) {
1745 &crate::TypeInner::Vector { size, .. } => {
1746 let size = common::vector_size_str(size);
1747 write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1748 }
1749 _ => {
1750 write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1751 }
1752 }
1753 write!(self.out, ", (")?;
1754 self.put_expression(arg, context, true)?;
1755 write!(self.out, " > 0)), {scalar}(0), (")?;
1756 self.put_expression(arg, context, true)?;
1757 write!(self.out, " == 0))")?;
1758 Ok(())
1759 }
1760
1761 pub(super) fn put_const_expression(
1762 &mut self,
1763 expr_handle: Handle<crate::Expression>,
1764 module: &crate::Module,
1765 mod_info: &valid::ModuleInfo,
1766 arena: &crate::Arena<crate::Expression>,
1767 ) -> BackendResult {
1768 self.put_possibly_const_expression(
1769 expr_handle,
1770 arena,
1771 module,
1772 mod_info,
1773 &(module, mod_info),
1774 |&(_, mod_info), expr| &mod_info[expr],
1775 |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1776 )
1777 }
1778
1779 fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1780 match literal {
1781 crate::Literal::F64(_) => {
1782 return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1783 }
1784 crate::Literal::F16(value) => {
1785 if value.is_infinite() {
1786 let sign = if value.is_sign_negative() { "-" } else { "" };
1787 write!(self.out, "{sign}INFINITY")?;
1788 } else if value.is_nan() {
1789 write!(self.out, "NAN")?;
1790 } else {
1791 let suffix = if value.fract() == f16::from_f32(0.0) {
1792 ".0h"
1793 } else {
1794 "h"
1795 };
1796 write!(self.out, "{value}{suffix}")?;
1797 }
1798 }
1799 crate::Literal::F32(value) => {
1800 if value.is_infinite() {
1801 let sign = if value.is_sign_negative() { "-" } else { "" };
1802 write!(self.out, "{sign}INFINITY")?;
1803 } else if value.is_nan() {
1804 write!(self.out, "NAN")?;
1805 } else {
1806 let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1807 write!(self.out, "{value}{suffix}")?;
1808 }
1809 }
1810 crate::Literal::U16(value) => {
1811 write!(self.out, "static_cast<ushort>({value})")?;
1812 }
1813 crate::Literal::I16(value) => {
1814 write!(self.out, "static_cast<short>({value})")?;
1815 }
1816 crate::Literal::U32(value) => {
1817 write!(self.out, "{value}u")?;
1818 }
1819 crate::Literal::I32(value) => {
1820 if value == i32::MIN {
1825 write!(self.out, "({} - 1)", value + 1)?;
1826 } else {
1827 write!(self.out, "{value}")?;
1828 }
1829 }
1830 crate::Literal::U64(value) => {
1831 write!(self.out, "{value}uL")?;
1832 }
1833 crate::Literal::I64(value) => {
1834 if value == i64::MIN {
1841 write!(self.out, "({}L - 1L)", value + 1)?;
1842 } else {
1843 write!(self.out, "{value}L")?;
1844 }
1845 }
1846 crate::Literal::Bool(value) => {
1847 write!(self.out, "{value}")?;
1848 }
1849 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1850 return Err(Error::GenericValidation(
1851 "Unsupported abstract literal".into(),
1852 ));
1853 }
1854 }
1855 Ok(())
1856 }
1857
    /// Write an expression of one of the kinds that may appear in a constant
    /// context: literal, constant, zero value, composition or splat.
    ///
    /// Both `put_const_expression` and `put_expression` call this, so the
    /// details of the surrounding context are abstracted away:
    /// - `expressions` is the arena that `expr_handle` belongs to,
    /// - `ctx` is opaque state passed back to both callbacks,
    /// - `get_expr_ty` returns the resolved type of an operand (used by `Splat`),
    /// - `put_expression` writes an operand (used for `Compose` components and
    ///   the `Splat` value).
    ///
    /// # Errors
    /// - `Error::UnsupportedCompose` for a `Compose` of a type other than
    ///   scalar, vector, matrix, array or struct.
    /// - `Error::GenericValidation` for a `Splat` whose value is not a scalar.
    /// - `Error::Override` for any other expression kind.
    // More parameters than clippy's default `too_many_arguments` limit allows.
    #[allow(clippy::too_many_arguments)]
    fn put_possibly_const_expression<C, I, E>(
        &mut self,
        expr_handle: Handle<crate::Expression>,
        expressions: &crate::Arena<crate::Expression>,
        module: &crate::Module,
        mod_info: &valid::ModuleInfo,
        ctx: &C,
        get_expr_ty: I,
        put_expression: E,
    ) -> BackendResult
    where
        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
    {
        match expressions[expr_handle] {
            crate::Expression::Literal(literal) => {
                self.put_literal(literal)?;
            }
            crate::Expression::Constant(handle) => {
                let constant = &module.constants[handle];
                // Named constants are referenced by name. Anonymous ones are
                // inlined by writing their initializer from the global arena.
                if constant.name.is_some() {
                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
                } else {
                    self.put_const_expression(
                        constant.init,
                        module,
                        mod_info,
                        &module.global_expressions,
                    )?;
                }
            }
            crate::Expression::ZeroValue(ty) => {
                let ty_name = TypeContext {
                    handle: ty,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                // `T {}` is C++ value/aggregate initialization, which zero-fills.
                write!(self.out, "{ty_name} {{}}")?;
            }
            crate::Expression::Compose { ty, ref components } => {
                let ty_name = TypeContext {
                    handle: ty,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                write!(self.out, "{ty_name}")?;
                match module.types[ty].inner {
                    // Scalars, vectors and matrices use constructor-call syntax.
                    crate::TypeInner::Scalar(_)
                    | crate::TypeInner::Vector { .. }
                    | crate::TypeInner::Matrix { .. } => {
                        self.put_call_parameters_impl(
                            components.iter().copied(),
                            ctx,
                            put_expression,
                        )?;
                    }
                    crate::TypeInner::Array { .. } => {
                        // `{{{{` / `}}}}` emit `{{` / `}}`: two brace levels.
                        // NOTE(review): the outer level presumably initializes
                        // the struct wrapping the array (cf. `WRAPPED_ARRAY_FIELD`),
                        // the inner one the array itself. Confirm.
                        write!(self.out, " {{{{")?;
                        for (index, &component) in components.iter().enumerate() {
                            if index != 0 {
                                write!(self.out, ", ")?;
                            }
                            put_expression(self, ctx, component)?;
                        }
                        write!(self.out, "}}}}")?;
                    }
                    crate::TypeInner::Struct { .. } => {
                        write!(self.out, " {{")?;
                        for (index, &component) in components.iter().enumerate() {
                            if index != 0 {
                                write!(self.out, ", ")?;
                            }
                            // Members listed in `struct_member_pads` get an extra
                            // empty initializer `{}` written before their value,
                            // presumably for a padding field preceding them.
                            if self.struct_member_pads.contains(&(ty, index as u32)) {
                                write!(self.out, "{{}}, ")?;
                            }
                            put_expression(self, ctx, component)?;
                        }
                        write!(self.out, "}}")?;
                    }
                    _ => return Err(Error::UnsupportedCompose(ty)),
                }
            }
            crate::Expression::Splat { size, value } => {
                // Emitted as `<vector type>(value)`, e.g. `float3(v)`.
                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
                    crate::TypeInner::Scalar(scalar) => scalar,
                    ref ty => {
                        return Err(Error::GenericValidation(format!(
                            "Expected splat value type must be a scalar, got {ty:?}",
                        )))
                    }
                };
                put_numeric_type(&mut self.out, scalar, &[size])?;
                write!(self.out, "(")?;
                put_expression(self, ctx, value)?;
                write!(self.out, ")")?;
            }
            // Every other expression kind is rejected here.
            _ => {
                return Err(Error::Override);
            }
        }

        Ok(())
    }
1969
1970 pub(super) fn put_expression(
1982 &mut self,
1983 expr_handle: Handle<crate::Expression>,
1984 context: &ExpressionContext,
1985 is_scoped: bool,
1986 ) -> BackendResult {
1987 #[cfg(test)]
1989 self.put_expression_stack_pointers
1990 .insert(ptr::from_ref(&expr_handle).cast());
1991
1992 if let Some(name) = self.named_expressions.get(&expr_handle) {
1993 write!(self.out, "{name}")?;
1994 return Ok(());
1995 }
1996
1997 let expression = &context.function.expressions[expr_handle];
1998 match *expression {
1999 crate::Expression::Literal(_)
2000 | crate::Expression::Constant(_)
2001 | crate::Expression::ZeroValue(_)
2002 | crate::Expression::Compose { .. }
2003 | crate::Expression::Splat { .. } => {
2004 self.put_possibly_const_expression(
2005 expr_handle,
2006 &context.function.expressions,
2007 context.module,
2008 context.mod_info,
2009 context,
2010 |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
2011 |writer, context, expr| writer.put_expression(expr, context, true),
2012 )?;
2013 }
2014 crate::Expression::Override(_) => return Err(Error::Override),
2015 crate::Expression::Access { base, .. }
2016 | crate::Expression::AccessIndex { base, .. } => {
2017 let policy = context.choose_bounds_check_policy(base);
2023 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
2024 && self.put_bounds_checks(
2025 expr_handle,
2026 context,
2027 back::Level(0),
2028 if is_scoped { "" } else { "(" },
2029 )?
2030 {
2031 write!(self.out, " ? ")?;
2032 self.put_access_chain(expr_handle, policy, context)?;
2033 write!(self.out, " : ")?;
2034
2035 if context.resolve_type(base).pointer_space().is_some() {
2036 let result_ty = context.info[expr_handle]
2040 .ty
2041 .inner_with(&context.module.types)
2042 .pointer_base_type();
2043 let result_ty_handle = match result_ty {
2044 Some(TypeResolution::Handle(handle)) => handle,
2045 Some(TypeResolution::Value(_)) => {
2046 unreachable!(
2054 "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
2055 );
2056 }
2057 None => {
2058 unreachable!(
2059 "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
2060 )
2061 }
2062 };
2063 let name_key =
2064 NameKey::oob_local_for_type(context.origin, result_ty_handle);
2065 self.out.write_str(&self.names[&name_key])?;
2066 } else {
2067 write!(self.out, "DefaultConstructible()")?;
2068 }
2069
2070 if !is_scoped {
2071 write!(self.out, ")")?;
2072 }
2073 } else {
2074 self.put_access_chain(expr_handle, policy, context)?;
2075 }
2076 }
2077 crate::Expression::Swizzle {
2078 size,
2079 vector,
2080 pattern,
2081 } => {
2082 self.put_wrapped_expression_for_packed_vec3_access(
2083 vector,
2084 context,
2085 false,
2086 &Self::put_expression,
2087 )?;
2088 write!(self.out, ".")?;
2089 for &sc in pattern[..size as usize].iter() {
2090 write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2091 }
2092 }
2093 crate::Expression::FunctionArgument(index) => {
2094 let name_key = match context.origin {
2095 FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2096 FunctionOrigin::EntryPoint(ep_index) => {
2097 NameKey::EntryPointArgument(ep_index, index)
2098 }
2099 };
2100 let name = &self.names[&name_key];
2101 write!(self.out, "{name}")?;
2102 }
2103 crate::Expression::GlobalVariable(handle) => {
2104 let name = &self.names[&NameKey::GlobalVariable(handle)];
2105 write!(self.out, "{name}")?;
2106 }
2107 crate::Expression::LocalVariable(handle) => {
2108 let name_key = NameKey::local(context.origin, handle);
2109 let name = &self.names[&name_key];
2110 write!(self.out, "{name}")?;
2111 }
2112 crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2113 crate::Expression::ImageSample {
2114 coordinate,
2115 image,
2116 sampler,
2117 clamp_to_edge: true,
2118 gather: None,
2119 array_index: None,
2120 offset: None,
2121 level: crate::SampleLevel::Zero,
2122 depth_ref: None,
2123 } => {
2124 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2125 self.put_expression(image, context, true)?;
2126 write!(self.out, ", ")?;
2127 self.put_expression(sampler, context, true)?;
2128 write!(self.out, ", ")?;
2129 self.put_expression(coordinate, context, true)?;
2130 write!(self.out, ")")?;
2131 }
2132 crate::Expression::ImageSample {
2133 image,
2134 sampler,
2135 gather,
2136 coordinate,
2137 array_index,
2138 offset,
2139 level,
2140 depth_ref,
2141 clamp_to_edge,
2142 } => {
2143 if clamp_to_edge {
2144 return Err(Error::GenericValidation(
2145 "ImageSample::clamp_to_edge should have been validated out".to_string(),
2146 ));
2147 }
2148
2149 let main_op = match gather {
2150 Some(_) => "gather",
2151 None => "sample",
2152 };
2153 let comparison_op = match depth_ref {
2154 Some(_) => "_compare",
2155 None => "",
2156 };
2157 self.put_expression(image, context, false)?;
2158 write!(self.out, ".{main_op}{comparison_op}(")?;
2159 self.put_expression(sampler, context, true)?;
2160 write!(self.out, ", ")?;
2161 self.put_expression(coordinate, context, true)?;
2162 if let Some(expr) = array_index {
2163 write!(self.out, ", ")?;
2164 self.put_expression(expr, context, true)?;
2165 }
2166 if let Some(dref) = depth_ref {
2167 write!(self.out, ", ")?;
2168 self.put_expression(dref, context, true)?;
2169 }
2170
2171 self.put_image_sample_level(image, level, context)?;
2172
2173 if let Some(offset) = offset {
2174 write!(self.out, ", ")?;
2175 self.put_expression(offset, context, true)?;
2176 }
2177
2178 match gather {
2179 None | Some(crate::SwizzleComponent::X) => {}
2180 Some(component) => {
2181 let is_cube_map = match *context.resolve_type(image) {
2182 crate::TypeInner::Image {
2183 dim: crate::ImageDimension::Cube,
2184 ..
2185 } => true,
2186 _ => false,
2187 };
2188 if offset.is_none() && !is_cube_map {
2191 write!(self.out, ", {NAMESPACE}::int2(0)")?;
2192 }
2193 let letter = back::COMPONENTS[component as usize];
2194 write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2195 }
2196 }
2197 write!(self.out, ")")?;
2198 }
2199 crate::Expression::ImageLoad {
2200 image,
2201 coordinate,
2202 array_index,
2203 sample,
2204 level,
2205 } => {
2206 let address = TexelAddress {
2207 coordinate,
2208 array_index,
2209 sample,
2210 level: level.map(LevelOfDetail::Direct),
2211 };
2212 self.put_image_load(expr_handle, image, address, context)?;
2213 }
2214 crate::Expression::ImageQuery { image, query } => match query {
2217 crate::ImageQuery::Size { level } => {
2218 self.put_image_size_query(
2219 image,
2220 level.map(LevelOfDetail::Direct),
2221 crate::ScalarKind::Uint,
2222 context,
2223 )?;
2224 }
2225 crate::ImageQuery::NumLevels => {
2226 self.put_expression(image, context, false)?;
2227 write!(self.out, ".get_num_mip_levels()")?;
2228 }
2229 crate::ImageQuery::NumLayers => {
2230 self.put_expression(image, context, false)?;
2231 write!(self.out, ".get_array_size()")?;
2232 }
2233 crate::ImageQuery::NumSamples => {
2234 self.put_expression(image, context, false)?;
2235 write!(self.out, ".get_num_samples()")?;
2236 }
2237 },
2238 crate::Expression::Unary { op, expr } => {
2239 let op_str = match op {
2240 crate::UnaryOperator::Negate => {
2241 match context.resolve_type(expr).scalar_kind() {
2242 Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2243 _ => "-",
2244 }
2245 }
2246 crate::UnaryOperator::LogicalNot => "!",
2247 crate::UnaryOperator::BitwiseNot => "~",
2248 };
2249 write!(self.out, "{op_str}(")?;
2250 self.put_expression(expr, context, false)?;
2251 write!(self.out, ")")?;
2252 }
2253 crate::Expression::Binary { op, left, right } => {
2254 let kind = context
2255 .resolve_type(left)
2256 .scalar_kind()
2257 .ok_or(Error::UnsupportedBinaryOp(op))?;
2258
2259 if op == crate::BinaryOperator::Divide
2260 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2261 && context.emit_int_div_checks
2262 {
2263 write!(self.out, "{DIV_FUNCTION}(")?;
2264 self.put_expression(left, context, true)?;
2265 write!(self.out, ", ")?;
2266 self.put_expression(right, context, true)?;
2267 write!(self.out, ")")?;
2268 } else if op == crate::BinaryOperator::Modulo
2269 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2270 && context.emit_int_div_checks
2271 {
2272 write!(self.out, "{MOD_FUNCTION}(")?;
2273 self.put_expression(left, context, true)?;
2274 write!(self.out, ", ")?;
2275 self.put_expression(right, context, true)?;
2276 write!(self.out, ")")?;
2277 } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2278 write!(self.out, "{NAMESPACE}::fmod(")?;
2283 self.put_expression(left, context, true)?;
2284 write!(self.out, ", ")?;
2285 self.put_expression(right, context, true)?;
2286 write!(self.out, ")")?;
2287 } else if (op == crate::BinaryOperator::Add
2288 || op == crate::BinaryOperator::Subtract
2289 || op == crate::BinaryOperator::Multiply)
2290 && kind == crate::ScalarKind::Sint
2291 {
2292 let to_unsigned = |ty: &crate::TypeInner| match *ty {
2293 crate::TypeInner::Scalar(scalar) => {
2294 Ok(crate::TypeInner::Scalar(crate::Scalar {
2295 kind: crate::ScalarKind::Uint,
2296 ..scalar
2297 }))
2298 }
2299 crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2300 size,
2301 scalar: crate::Scalar {
2302 kind: crate::ScalarKind::Uint,
2303 ..scalar
2304 },
2305 }),
2306 _ => Err(Error::UnsupportedBitCast(ty.clone())),
2307 };
2308
2309 self.put_bitcasted_expression(
2314 context.resolve_type(expr_handle),
2315 expr_handle,
2316 context,
2317 &|writer, context, is_scoped| {
2318 writer.put_binop(
2319 op,
2320 left,
2321 right,
2322 context,
2323 is_scoped,
2324 &|writer, expr, context, _is_scoped| {
2325 writer.put_bitcasted_expression(
2326 &to_unsigned(context.resolve_type(expr))?,
2327 expr,
2328 context,
2329 &|writer, context, is_scoped| {
2330 writer.put_expression(expr, context, is_scoped)
2331 },
2332 )
2333 },
2334 )
2335 },
2336 )?;
2337 } else {
2338 self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2339 }
2340 }
2341 crate::Expression::Select {
2342 condition,
2343 accept,
2344 reject,
2345 } => match *context.resolve_type(condition) {
2346 crate::TypeInner::Scalar(crate::Scalar {
2347 kind: crate::ScalarKind::Bool,
2348 ..
2349 }) => {
2350 if !is_scoped {
2351 write!(self.out, "(")?;
2352 }
2353 self.put_expression(condition, context, false)?;
2354 write!(self.out, " ? ")?;
2355 self.put_expression(accept, context, false)?;
2356 write!(self.out, " : ")?;
2357 self.put_expression(reject, context, false)?;
2358 if !is_scoped {
2359 write!(self.out, ")")?;
2360 }
2361 }
2362 crate::TypeInner::Vector {
2363 scalar:
2364 crate::Scalar {
2365 kind: crate::ScalarKind::Bool,
2366 ..
2367 },
2368 ..
2369 } => {
2370 write!(self.out, "{NAMESPACE}::select(")?;
2371 self.put_expression(reject, context, true)?;
2372 write!(self.out, ", ")?;
2373 self.put_expression(accept, context, true)?;
2374 write!(self.out, ", ")?;
2375 self.put_expression(condition, context, true)?;
2376 write!(self.out, ")")?;
2377 }
2378 ref ty => {
2379 return Err(Error::GenericValidation(format!(
2380 "Expected select condition to be a non-bool type, got {ty:?}",
2381 )))
2382 }
2383 },
2384 crate::Expression::Derivative { axis, expr, .. } => {
2385 use crate::DerivativeAxis as Axis;
2386 let op = match axis {
2387 Axis::X => "dfdx",
2388 Axis::Y => "dfdy",
2389 Axis::Width => "fwidth",
2390 };
2391 write!(self.out, "{NAMESPACE}::{op}")?;
2392 self.put_call_parameters(iter::once(expr), context)?;
2393 }
2394 crate::Expression::Relational { fun, argument } => {
2395 let op = match fun {
2396 crate::RelationalFunction::Any => "any",
2397 crate::RelationalFunction::All => "all",
2398 crate::RelationalFunction::IsNan => "isnan",
2399 crate::RelationalFunction::IsInf => "isinf",
2400 };
2401 write!(self.out, "{NAMESPACE}::{op}")?;
2402 self.put_call_parameters(iter::once(argument), context)?;
2403 }
2404 crate::Expression::Math {
2405 fun,
2406 arg,
2407 arg1,
2408 arg2,
2409 arg3,
2410 } => {
2411 use crate::MathFunction as Mf;
2412
2413 let arg_type = context.resolve_type(arg);
2414 let scalar_argument = match arg_type {
2415 &crate::TypeInner::Scalar(_) => true,
2416 _ => false,
2417 };
2418
2419 let fun_name = match fun {
2420 Mf::Abs => "abs",
2422 Mf::Min => "min",
2423 Mf::Max => "max",
2424 Mf::Clamp => "clamp",
2425 Mf::Saturate => "saturate",
2426 Mf::Cos => "cos",
2428 Mf::Cosh => "cosh",
2429 Mf::Sin => "sin",
2430 Mf::Sinh => "sinh",
2431 Mf::Tan => "tan",
2432 Mf::Tanh => "tanh",
2433 Mf::Acos => "acos",
2434 Mf::Asin => "asin",
2435 Mf::Atan => "atan",
2436 Mf::Atan2 => "atan2",
2437 Mf::Asinh => "asinh",
2438 Mf::Acosh => "acosh",
2439 Mf::Atanh => "atanh",
2440 Mf::Radians => "",
2441 Mf::Degrees => "",
2442 Mf::Ceil => "ceil",
2444 Mf::Floor => "floor",
2445 Mf::Round => "rint",
2446 Mf::Fract => "fract",
2447 Mf::Trunc => "trunc",
2448 Mf::Modf => MODF_FUNCTION,
2449 Mf::Frexp => FREXP_FUNCTION,
2450 Mf::Ldexp => "ldexp",
2451 Mf::Exp => "exp",
2453 Mf::Exp2 => "exp2",
2454 Mf::Log => "log",
2455 Mf::Log2 => "log2",
2456 Mf::Pow => "pow",
2457 Mf::Dot => match *context.resolve_type(arg) {
2459 crate::TypeInner::Vector {
2460 scalar:
2461 crate::Scalar {
2462 kind: crate::ScalarKind::Float,
2464 ..
2465 },
2466 ..
2467 } => "dot",
2468 crate::TypeInner::Vector {
2469 size,
2470 scalar:
2471 scalar @ crate::Scalar {
2472 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2473 ..
2474 },
2475 } => {
2476 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2478 write!(self.out, "{fun_name}(")?;
2479 self.put_expression(arg, context, true)?;
2480 write!(self.out, ", ")?;
2481 self.put_expression(arg1.unwrap(), context, true)?;
2482 write!(self.out, ")")?;
2483 return Ok(());
2484 }
2485 _ => unreachable!(
2486 "Correct TypeInner for dot product should be already validated"
2487 ),
2488 },
2489 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2490 if context.lang_version >= (2, 1) {
2491 let packed_type = match fun {
2495 Mf::Dot4I8Packed => "packed_char4",
2496 Mf::Dot4U8Packed => "packed_uchar4",
2497 _ => unreachable!(),
2498 };
2499
2500 return self.put_dot_product(
2501 Reinterpreted::new(packed_type, arg),
2502 Reinterpreted::new(packed_type, arg1.unwrap()),
2503 4,
2504 |writer, arg, index| {
2505 write!(writer.out, "{arg}[{index}]")?;
2508 Ok(())
2509 },
2510 );
2511 } else {
2512 let conversion = match fun {
2516 Mf::Dot4I8Packed => "int",
2517 Mf::Dot4U8Packed => "",
2518 _ => unreachable!(),
2519 };
2520
2521 return self.put_dot_product(
2522 arg,
2523 arg1.unwrap(),
2524 4,
2525 |writer, arg, index| {
2526 write!(writer.out, "({conversion}(")?;
2527 writer.put_expression(arg, context, true)?;
2528 if index == 3 {
2529 write!(writer.out, ") >> 24)")?;
2530 } else {
2531 write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2532 }
2533 Ok(())
2534 },
2535 );
2536 }
2537 }
2538 Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2539 Mf::Cross => "cross",
2540 Mf::Distance => "distance",
2541 Mf::Length if scalar_argument => "abs",
2542 Mf::Length => "length",
2543 Mf::Normalize => "normalize",
2544 Mf::FaceForward => "faceforward",
2545 Mf::Reflect => "reflect",
2546 Mf::Refract => "refract",
2547 Mf::Sign => match arg_type.scalar_kind() {
2549 Some(crate::ScalarKind::Sint) => {
2550 return self.put_isign(arg, context);
2551 }
2552 _ => "sign",
2553 },
2554 Mf::Fma => "fma",
2555 Mf::Mix => "mix",
2556 Mf::Step => "step",
2557 Mf::SmoothStep => "smoothstep",
2558 Mf::Sqrt => "sqrt",
2559 Mf::InverseSqrt => "rsqrt",
2560 Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2561 Mf::Transpose => "transpose",
2562 Mf::Determinant => "determinant",
2563 Mf::QuantizeToF16 => "",
2564 Mf::CountTrailingZeros => "ctz",
2566 Mf::CountLeadingZeros => "clz",
2567 Mf::CountOneBits => "popcount",
2568 Mf::ReverseBits => "reverse_bits",
2569 Mf::ExtractBits => "",
2570 Mf::InsertBits => "",
2571 Mf::FirstTrailingBit => "",
2572 Mf::FirstLeadingBit => "",
2573 Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2575 Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2576 Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2577 Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2578 Mf::Pack2x16float => "",
2579 Mf::Pack4xI8 => "",
2580 Mf::Pack4xU8 => "",
2581 Mf::Pack4xI8Clamp => "",
2582 Mf::Pack4xU8Clamp => "",
2583 Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2585 Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2586 Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2587 Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2588 Mf::Unpack2x16float => "",
2589 Mf::Unpack4xI8 => "",
2590 Mf::Unpack4xU8 => "",
2591 };
2592
2593 match fun {
2594 Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2595 if context.lang_version < (1, 2) {
2604 return Err(Error::UnsupportedFunction(fun_name.to_string()));
2605 }
2606 }
2607 _ => {}
2608 }
2609
2610 match fun {
2611 Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2612 write!(self.out, "{ABS_FUNCTION}(")?;
2613 self.put_expression(arg, context, true)?;
2614 write!(self.out, ")")?;
2615 }
2616 Mf::Distance if scalar_argument => {
2617 write!(self.out, "{NAMESPACE}::abs(")?;
2618 self.put_expression(arg, context, false)?;
2619 write!(self.out, " - ")?;
2620 self.put_expression(arg1.unwrap(), context, false)?;
2621 write!(self.out, ")")?;
2622 }
2623 Mf::FirstTrailingBit => {
2624 let scalar = context.resolve_type(arg).scalar().unwrap();
2625 let constant = scalar.width * 8 + 1;
2626
2627 write!(self.out, "((({NAMESPACE}::ctz(")?;
2628 self.put_expression(arg, context, true)?;
2629 write!(self.out, ") + 1) % {constant}) - 1)")?;
2630 }
2631 Mf::FirstLeadingBit => {
2632 let inner = context.resolve_type(arg);
2633 let scalar = inner.scalar().unwrap();
2634 let constant = scalar.width * 8 - 1;
2635
2636 write!(
2637 self.out,
2638 "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2639 )?;
2640
2641 if scalar.kind == crate::ScalarKind::Sint {
2642 write!(self.out, "{NAMESPACE}::select(")?;
2643 self.put_expression(arg, context, true)?;
2644 write!(self.out, ", ~")?;
2645 self.put_expression(arg, context, true)?;
2646 write!(self.out, ", ")?;
2647 self.put_expression(arg, context, true)?;
2648 write!(self.out, " < 0)")?;
2649 } else {
2650 self.put_expression(arg, context, true)?;
2651 }
2652
2653 write!(self.out, "), ")?;
2654
2655 match *inner {
2657 crate::TypeInner::Vector { size, scalar } => {
2658 let size = common::vector_size_str(size);
2659 let name = scalar.to_msl_name();
2660 write!(self.out, "{name}{size}")?;
2661 }
2662 crate::TypeInner::Scalar(scalar) => {
2663 let name = scalar.to_msl_name();
2664 write!(self.out, "{name}")?;
2665 }
2666 _ => (),
2667 }
2668
2669 write!(self.out, "(-1), ")?;
2670 self.put_expression(arg, context, true)?;
2671 write!(self.out, " == 0")?;
2672 if scalar.kind == crate::ScalarKind::Sint {
2673 write!(self.out, " || ")?;
2674 self.put_expression(arg, context, true)?;
2675 write!(self.out, " == -1")?;
2676 }
2677 write!(self.out, ")")?;
2678 }
2679 Mf::Unpack2x16float => {
2680 write!(self.out, "float2(as_type<half2>(")?;
2681 self.put_expression(arg, context, false)?;
2682 write!(self.out, "))")?;
2683 }
2684 Mf::Pack2x16float => {
2685 write!(self.out, "as_type<uint>(half2(")?;
2686 self.put_expression(arg, context, false)?;
2687 write!(self.out, "))")?;
2688 }
2689 Mf::ExtractBits => {
2690 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2707
2708 write!(self.out, "{NAMESPACE}::extract_bits(")?;
2709 self.put_expression(arg, context, true)?;
2710 write!(self.out, ", {NAMESPACE}::min(")?;
2711 self.put_expression(arg1.unwrap(), context, true)?;
2712 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2713 self.put_expression(arg2.unwrap(), context, true)?;
2714 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2715 self.put_expression(arg1.unwrap(), context, true)?;
2716 write!(self.out, ", {scalar_bits}u)))")?;
2717 }
2718 Mf::InsertBits => {
2719 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2724
2725 write!(self.out, "{NAMESPACE}::insert_bits(")?;
2726 self.put_expression(arg, context, true)?;
2727 write!(self.out, ", ")?;
2728 self.put_expression(arg1.unwrap(), context, true)?;
2729 write!(self.out, ", {NAMESPACE}::min(")?;
2730 self.put_expression(arg2.unwrap(), context, true)?;
2731 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2732 self.put_expression(arg3.unwrap(), context, true)?;
2733 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2734 self.put_expression(arg2.unwrap(), context, true)?;
2735 write!(self.out, ", {scalar_bits}u)))")?;
2736 }
2737 Mf::Radians => {
2738 write!(self.out, "((")?;
2739 self.put_expression(arg, context, false)?;
2740 write!(self.out, ") * 0.017453292519943295474)")?;
2741 }
2742 Mf::Degrees => {
2743 write!(self.out, "((")?;
2744 self.put_expression(arg, context, false)?;
2745 write!(self.out, ") * 57.295779513082322865)")?;
2746 }
2747 Mf::Modf | Mf::Frexp => {
2748 write!(self.out, "{fun_name}")?;
2749 self.put_call_parameters(iter::once(arg), context)?;
2750 }
2751 Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2752 Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2753 Mf::Pack4xI8Clamp => {
2754 self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2755 }
2756 Mf::Pack4xU8Clamp => {
2757 self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2758 }
2759 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2760 let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2761 "u"
2762 } else {
2763 ""
2764 };
2765
2766 if context.lang_version >= (2, 1) {
2767 write!(
2769 self.out,
2770 "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2771 )?;
2772 self.put_expression(arg, context, true)?;
2773 write!(self.out, "))")?;
2774 } else {
2775 write!(self.out, "({sign_prefix}int4(")?;
2777 self.put_expression(arg, context, true)?;
2778 write!(self.out, ", ")?;
2779 self.put_expression(arg, context, true)?;
2780 write!(self.out, " >> 8, ")?;
2781 self.put_expression(arg, context, true)?;
2782 write!(self.out, " >> 16, ")?;
2783 self.put_expression(arg, context, true)?;
2784 write!(self.out, " >> 24) << 24 >> 24)")?;
2785 }
2786 }
2787 Mf::QuantizeToF16 => {
2788 match *context.resolve_type(arg) {
2789 crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2790 crate::TypeInner::Vector { size, .. } => write!(
2791 self.out,
2792 "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2793 size = common::vector_size_str(size),
2794 )?,
2795 _ => unreachable!(
2796 "Correct TypeInner for QuantizeToF16 should be already validated"
2797 ),
2798 };
2799
2800 self.put_expression(arg, context, true)?;
2801 write!(self.out, "))")?;
2802 }
2803 _ => {
2804 write!(self.out, "{NAMESPACE}::{fun_name}")?;
2805 self.put_call_parameters(
2806 iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2807 context,
2808 )?;
2809 }
2810 }
2811 }
2812 crate::Expression::As {
2813 expr,
2814 kind,
2815 convert,
2816 } => match *context.resolve_type(expr) {
2817 crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2818 if src.kind == crate::ScalarKind::Float
2819 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2820 && convert.is_some()
2821 {
2822 let fun_name = match (kind, convert) {
2826 (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2827 (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2828 (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2829 (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2830 _ => unreachable!(),
2831 };
2832 write!(self.out, "{fun_name}(")?;
2833 self.put_expression(expr, context, true)?;
2834 write!(self.out, ")")?;
2835 } else {
2836 let target_scalar = crate::Scalar {
2837 kind,
2838 width: convert.unwrap_or(src.width),
2839 };
2840 let op = match convert {
2841 Some(_) => "static_cast",
2842 None => "as_type",
2843 };
2844 write!(self.out, "{op}<")?;
2845 match *context.resolve_type(expr) {
2846 crate::TypeInner::Vector { size, .. } => {
2847 put_numeric_type(&mut self.out, target_scalar, &[size])?
2848 }
2849 _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2850 };
2851 write!(self.out, ">(")?;
2852 self.put_expression(expr, context, true)?;
2853 write!(self.out, ")")?;
2854 }
2855 }
2856 crate::TypeInner::Matrix {
2857 columns,
2858 rows,
2859 scalar,
2860 } => {
2861 let target_scalar = crate::Scalar {
2862 kind,
2863 width: convert.unwrap_or(scalar.width),
2864 };
2865 put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2866 write!(self.out, "(")?;
2867 self.put_expression(expr, context, true)?;
2868 write!(self.out, ")")?;
2869 }
2870 ref ty => {
2871 return Err(Error::GenericValidation(format!(
2872 "Unsupported type for As: {ty:?}"
2873 )))
2874 }
2875 },
2876 crate::Expression::CallResult(_)
2878 | crate::Expression::AtomicResult { .. }
2879 | crate::Expression::WorkGroupUniformLoadResult { .. }
2880 | crate::Expression::SubgroupBallotResult
2881 | crate::Expression::SubgroupOperationResult { .. }
2882 | crate::Expression::RayQueryProceedResult => {
2883 unreachable!()
2884 }
2885 crate::Expression::ArrayLength(expr) => {
2886 let global = match context.function.expressions[expr] {
2888 crate::Expression::AccessIndex { base, .. } => {
2889 match context.function.expressions[base] {
2890 crate::Expression::GlobalVariable(handle) => handle,
2891 ref ex => {
2892 return Err(Error::GenericValidation(format!(
2893 "Expected global variable in AccessIndex, got {ex:?}"
2894 )))
2895 }
2896 }
2897 }
2898 crate::Expression::GlobalVariable(handle) => handle,
2899 ref ex => {
2900 return Err(Error::GenericValidation(format!(
2901 "Unexpected expression in ArrayLength, got {ex:?}"
2902 )))
2903 }
2904 };
2905
2906 if !is_scoped {
2907 write!(self.out, "(")?;
2908 }
2909 write!(self.out, "1 + ")?;
2910 self.put_dynamic_array_max_index(global, context)?;
2911 if !is_scoped {
2912 write!(self.out, ")")?;
2913 }
2914 }
2915 crate::Expression::RayQueryVertexPositions { .. } => {
2916 unimplemented!()
2917 }
2918 crate::Expression::RayQueryGetIntersection { query, committed } => {
2919 if context.lang_version < (2, 4) {
2920 return Err(Error::UnsupportedRayTracing);
2921 }
2922
2923 write!(
2924 self.out,
2925 "{}_{committed}(",
2926 super::ray::INTERSECTION_FUNCTION_NAME
2927 )?;
2928 self.put_expression(query, context, true)?;
2929 write!(self.out, ")")?;
2930 }
2931 crate::Expression::CooperativeLoad { ref data, .. } => {
2932 if context.lang_version < (2, 3) {
2933 return Err(Error::UnsupportedCooperativeMatrix);
2934 }
2935 write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2936 write!(self.out, "&")?;
2937 self.put_access_chain(data.pointer, context.policies.index, context)?;
2938 write!(self.out, ", ")?;
2939 self.put_expression(data.stride, context, true)?;
2940 write!(self.out, ", {})", !data.row_major)?;
2947 }
2948 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2949 if context.lang_version < (2, 3) {
2950 return Err(Error::UnsupportedCooperativeMatrix);
2951 }
2952 write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2953 self.put_expression(a, context, true)?;
2954 write!(self.out, ", ")?;
2955 self.put_expression(b, context, true)?;
2956 write!(self.out, ", ")?;
2957 self.put_expression(c, context, true)?;
2958 write!(self.out, ")")?;
2959 }
2960 }
2961 Ok(())
2962 }
2963
2964 fn put_binop<F>(
2967 &mut self,
2968 op: crate::BinaryOperator,
2969 left: Handle<crate::Expression>,
2970 right: Handle<crate::Expression>,
2971 context: &ExpressionContext,
2972 is_scoped: bool,
2973 put_expression: &F,
2974 ) -> BackendResult
2975 where
2976 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2977 {
2978 let op_str = back::binary_operation_str(op);
2979
2980 if !is_scoped {
2981 write!(self.out, "(")?;
2982 }
2983
2984 if op == crate::BinaryOperator::Multiply
2987 && matches!(
2988 context.resolve_type(right),
2989 &crate::TypeInner::Matrix { .. }
2990 )
2991 {
2992 self.put_wrapped_expression_for_packed_vec3_access(
2993 left,
2994 context,
2995 false,
2996 put_expression,
2997 )?;
2998 } else {
2999 put_expression(self, left, context, false)?;
3000 }
3001
3002 write!(self.out, " {op_str} ")?;
3003
3004 if op == crate::BinaryOperator::Multiply
3006 && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
3007 {
3008 self.put_wrapped_expression_for_packed_vec3_access(
3009 right,
3010 context,
3011 false,
3012 put_expression,
3013 )?;
3014 } else {
3015 put_expression(self, right, context, false)?;
3016 }
3017
3018 if !is_scoped {
3019 write!(self.out, ")")?;
3020 }
3021
3022 Ok(())
3023 }
3024
3025 fn put_wrapped_expression_for_packed_vec3_access<F>(
3027 &mut self,
3028 expr_handle: Handle<crate::Expression>,
3029 context: &ExpressionContext,
3030 is_scoped: bool,
3031 put_expression: &F,
3032 ) -> BackendResult
3033 where
3034 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
3035 {
3036 if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
3037 write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
3038 put_expression(self, expr_handle, context, is_scoped)?;
3039 write!(self.out, ")")?;
3040 } else {
3041 put_expression(self, expr_handle, context, is_scoped)?;
3042 }
3043 Ok(())
3044 }
3045
3046 fn put_bitcasted_expression<F>(
3049 &mut self,
3050 cast_to: &crate::TypeInner,
3051 inner_expr: Handle<crate::Expression>,
3052 context: &ExpressionContext,
3053 put_expression: &F,
3054 ) -> BackendResult
3055 where
3056 F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
3057 {
3058 let needs_truncation = match *cast_to {
3063 crate::TypeInner::Scalar(scalar) => scalar.width < 4,
3064 crate::TypeInner::Vector { scalar, .. } => scalar.width < 4,
3065 _ => false,
3066 };
3067
3068 write!(self.out, "as_type<")?;
3069 match *cast_to {
3070 crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
3071 crate::TypeInner::Vector { size, scalar } => {
3072 put_numeric_type(&mut self.out, scalar, &[size])?
3073 }
3074 _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
3075 };
3076 write!(self.out, ">(")?;
3077
3078 if needs_truncation {
3079 write!(self.out, "static_cast<")?;
3080 let unsigned_scalar = match *cast_to {
3082 crate::TypeInner::Scalar(scalar) => crate::Scalar {
3083 kind: crate::ScalarKind::Uint,
3084 ..scalar
3085 },
3086 crate::TypeInner::Vector { scalar, .. } => crate::Scalar {
3087 kind: crate::ScalarKind::Uint,
3088 ..scalar
3089 },
3090 _ => unreachable!(),
3091 };
3092 match *cast_to {
3093 crate::TypeInner::Scalar(_) => {
3094 put_numeric_type(&mut self.out, unsigned_scalar, &[])?
3095 }
3096 crate::TypeInner::Vector { size, .. } => {
3097 put_numeric_type(&mut self.out, unsigned_scalar, &[size])?
3098 }
3099 _ => unreachable!(),
3100 };
3101 write!(self.out, ">(")?;
3102 }
3103
3104 if let Some(scalar) = context.get_packed_vec_kind(inner_expr) {
3106 put_numeric_type(&mut self.out, scalar, &[crate::VectorSize::Tri])?;
3107 write!(self.out, "(")?;
3108 put_expression(self, context, true)?;
3109 write!(self.out, ")")?;
3110 } else {
3111 put_expression(self, context, true)?;
3112 }
3113
3114 if needs_truncation {
3115 write!(self.out, ")")?;
3116 }
3117
3118 write!(self.out, ")")?;
3119 Ok(())
3120 }
3121
3122 fn put_index(
3124 &mut self,
3125 index: index::GuardedIndex,
3126 context: &ExpressionContext,
3127 is_scoped: bool,
3128 ) -> BackendResult {
3129 match index {
3130 index::GuardedIndex::Expression(expr) => {
3131 self.put_expression(expr, context, is_scoped)?
3132 }
3133 index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3134 }
3135 Ok(())
3136 }
3137
3138 fn put_bounds_checks(
3168 &mut self,
3169 chain: Handle<crate::Expression>,
3170 context: &ExpressionContext,
3171 level: back::Level,
3172 prefix: &'static str,
3173 ) -> Result<bool, Error> {
3174 let mut check_written = false;
3175
3176 for item in context.bounds_check_iter(chain) {
3178 let BoundsCheck {
3179 base,
3180 index,
3181 length,
3182 } = item;
3183
3184 if check_written {
3185 write!(self.out, " && ")?;
3186 } else {
3187 write!(self.out, "{level}{prefix}")?;
3188 check_written = true;
3189 }
3190
3191 write!(self.out, "uint(")?;
3195 self.put_index(index, context, true)?;
3196 self.out.write_str(") < ")?;
3197 match length {
3198 index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3199 index::IndexableLength::Dynamic => {
3200 let global = context.function.originating_global(base).ok_or_else(|| {
3201 Error::GenericValidation("Could not find originating global".into())
3202 })?;
3203 write!(self.out, "1 + ")?;
3204 self.put_dynamic_array_max_index(global, context)?
3205 }
3206 }
3207 }
3208
3209 Ok(check_written)
3210 }
3211
3212 fn put_access_chain(
3232 &mut self,
3233 chain: Handle<crate::Expression>,
3234 policy: index::BoundsCheckPolicy,
3235 context: &ExpressionContext,
3236 ) -> BackendResult {
3237 match context.function.expressions[chain] {
3238 crate::Expression::Access { base, index } => {
3239 let mut base_ty = context.resolve_type(base);
3240
3241 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3243 base_ty = &context.module.types[base].inner;
3244 }
3245
3246 self.put_subscripted_access_chain(
3247 base,
3248 base_ty,
3249 index::GuardedIndex::Expression(index),
3250 policy,
3251 context,
3252 )?;
3253 }
3254 crate::Expression::AccessIndex { base, index } => {
3255 let base_resolution = &context.info[base].ty;
3256 let mut base_ty = base_resolution.inner_with(&context.module.types);
3257 let mut base_ty_handle = base_resolution.handle();
3258
3259 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3261 base_ty = &context.module.types[base].inner;
3262 base_ty_handle = Some(base);
3263 }
3264
3265 match *base_ty {
3269 crate::TypeInner::Struct { .. } => {
3270 let base_ty = base_ty_handle.unwrap();
3271 self.put_access_chain(base, policy, context)?;
3272 let name = &self.names[&NameKey::StructMember(base_ty, index)];
3273 write!(self.out, ".{name}")?;
3274 }
3275 crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3276 self.put_access_chain(base, policy, context)?;
3277 if context.get_packed_vec_kind(base).is_some() {
3280 write!(self.out, "[{index}]")?;
3281 } else {
3282 write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3283 }
3284 }
3285 _ => {
3286 self.put_subscripted_access_chain(
3287 base,
3288 base_ty,
3289 index::GuardedIndex::Known(index),
3290 policy,
3291 context,
3292 )?;
3293 }
3294 }
3295 }
3296 _ => self.put_expression(chain, context, false)?,
3297 }
3298
3299 Ok(())
3300 }
3301
3302 fn put_subscripted_access_chain(
3319 &mut self,
3320 base: Handle<crate::Expression>,
3321 base_ty: &crate::TypeInner,
3322 index: index::GuardedIndex,
3323 policy: index::BoundsCheckPolicy,
3324 context: &ExpressionContext,
3325 ) -> BackendResult {
3326 let accessing_wrapped_array = match *base_ty {
3327 crate::TypeInner::Array {
3328 size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3329 ..
3330 } => true,
3331 _ => false,
3332 };
3333 let accessing_wrapped_binding_array =
3334 matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3335
3336 self.put_access_chain(base, policy, context)?;
3337 if accessing_wrapped_array {
3338 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3339 }
3340 write!(self.out, "[")?;
3341
3342 let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3344 context.access_needs_check(base, index)
3345 } else {
3346 None
3347 };
3348 if let Some(limit) = restriction_needed {
3349 write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3350 self.put_index(index, context, true)?;
3351 write!(self.out, "), ")?;
3352 match limit {
3353 index::IndexableLength::Known(limit) => {
3354 write!(self.out, "{}u", limit - 1)?;
3355 }
3356 index::IndexableLength::Dynamic => {
3357 let global = context.function.originating_global(base).ok_or_else(|| {
3358 Error::GenericValidation("Could not find originating global".into())
3359 })?;
3360 self.put_dynamic_array_max_index(global, context)?;
3361 }
3362 }
3363 write!(self.out, ")")?;
3364 } else {
3365 self.put_index(index, context, true)?;
3366 }
3367
3368 write!(self.out, "]")?;
3369
3370 if accessing_wrapped_binding_array {
3371 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3372 }
3373
3374 Ok(())
3375 }
3376
3377 fn put_load(
3378 &mut self,
3379 pointer: Handle<crate::Expression>,
3380 context: &ExpressionContext,
3381 is_scoped: bool,
3382 ) -> BackendResult {
3383 let policy = context.choose_bounds_check_policy(pointer);
3386 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3387 && self.put_bounds_checks(
3388 pointer,
3389 context,
3390 back::Level(0),
3391 if is_scoped { "" } else { "(" },
3392 )?
3393 {
3394 write!(self.out, " ? ")?;
3395 self.put_unchecked_load(pointer, policy, context)?;
3396 write!(self.out, " : DefaultConstructible()")?;
3397
3398 if !is_scoped {
3399 write!(self.out, ")")?;
3400 }
3401 } else {
3402 self.put_unchecked_load(pointer, policy, context)?;
3403 }
3404
3405 Ok(())
3406 }
3407
3408 fn put_unchecked_load(
3409 &mut self,
3410 pointer: Handle<crate::Expression>,
3411 policy: index::BoundsCheckPolicy,
3412 context: &ExpressionContext,
3413 ) -> BackendResult {
3414 let is_atomic_pointer = context
3415 .resolve_type(pointer)
3416 .is_atomic_pointer(&context.module.types);
3417
3418 if is_atomic_pointer {
3419 write!(
3420 self.out,
3421 "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3422 )?;
3423 self.put_access_chain(pointer, policy, context)?;
3424 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3425 } else {
3426 self.put_access_chain(pointer, policy, context)?;
3430 }
3431
3432 Ok(())
3433 }
3434
3435 fn put_return_value(
3436 &mut self,
3437 level: back::Level,
3438 expr_handle: Handle<crate::Expression>,
3439 result_struct: Option<&str>,
3440 context: &ExpressionContext,
3441 ) -> BackendResult {
3442 match result_struct {
3443 Some(struct_name) => {
3444 let mut has_point_size = false;
3445 let result_ty = context.function.result.as_ref().unwrap().ty;
3446 match context.module.types[result_ty].inner {
3447 crate::TypeInner::Struct { ref members, .. } => {
3448 let tmp = "_tmp";
3449 write!(self.out, "{level}const auto {tmp} = ")?;
3450 self.put_expression(expr_handle, context, true)?;
3451 writeln!(self.out, ";")?;
3452 write!(self.out, "{level}return {struct_name} {{")?;
3453
3454 let mut is_first = true;
3455
3456 for (index, member) in members.iter().enumerate() {
3457 if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3458 member.binding
3459 {
3460 has_point_size = true;
3461 if !context.pipeline_options.allow_and_force_point_size {
3462 continue;
3463 }
3464 }
3465
3466 let comma = if is_first { "" } else { "," };
3467 is_first = false;
3468 let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3469 if let crate::TypeInner::Array {
3473 size: crate::ArraySize::Constant(size),
3474 ..
3475 } = context.module.types[member.ty].inner
3476 {
3477 write!(self.out, "{comma} {{")?;
3478 for j in 0..size.get() {
3479 if j != 0 {
3480 write!(self.out, ",")?;
3481 }
3482 write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3483 }
3484 write!(self.out, "}}")?;
3485 } else {
3486 write!(self.out, "{comma} {tmp}.{name}")?;
3487 }
3488 }
3489 }
3490 _ => {
3491 write!(self.out, "{level}return {struct_name} {{ ")?;
3492 self.put_expression(expr_handle, context, true)?;
3493 }
3494 }
3495
3496 if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3497 let stage = context.module.entry_points[ep_index as usize].stage;
3498 if context.pipeline_options.allow_and_force_point_size
3499 && stage == crate::ShaderStage::Vertex
3500 && !has_point_size
3501 {
3502 write!(self.out, ", 1.0")?;
3504 }
3505 }
3506 write!(self.out, " }}")?;
3507 }
3508 None => {
3509 write!(self.out, "{level}return ")?;
3510 self.put_expression(expr_handle, context, true)?;
3511 }
3512 }
3513 writeln!(self.out, ";")?;
3514 Ok(())
3515 }
3516
3517 fn update_expressions_to_bake(
3522 &mut self,
3523 func: &crate::Function,
3524 info: &valid::FunctionInfo,
3525 context: &ExpressionContext,
3526 ) {
3527 use crate::Expression;
3528 self.need_bake_expressions.clear();
3529
3530 for (expr_handle, expr) in func.expressions.iter() {
3531 let expr_info = &info[expr_handle];
3534 let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3535 if min_ref_count <= expr_info.ref_count {
3536 self.need_bake_expressions.insert(expr_handle);
3537 } else {
3538 match expr_info.ty {
3539 TypeResolution::Handle(h)
3541 if Some(h) == context.module.special_types.ray_desc =>
3542 {
3543 self.need_bake_expressions.insert(expr_handle);
3544 }
3545 _ => {}
3546 }
3547 }
3548
3549 if let Expression::Math {
3550 fun,
3551 arg,
3552 arg1,
3553 arg2,
3554 ..
3555 } = *expr
3556 {
3557 match fun {
3558 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3568 self.need_bake_expressions.insert(arg);
3569 self.need_bake_expressions.insert(arg1.unwrap());
3570 }
3571 crate::MathFunction::FirstLeadingBit => {
3572 self.need_bake_expressions.insert(arg);
3573 }
3574 crate::MathFunction::Pack4xI8
3575 | crate::MathFunction::Pack4xU8
3576 | crate::MathFunction::Pack4xI8Clamp
3577 | crate::MathFunction::Pack4xU8Clamp
3578 | crate::MathFunction::Unpack4xI8
3579 | crate::MathFunction::Unpack4xU8 => {
3580 if context.lang_version < (2, 1) {
3583 self.need_bake_expressions.insert(arg);
3584 }
3585 }
3586 crate::MathFunction::ExtractBits => {
3587 self.need_bake_expressions.insert(arg1.unwrap());
3589 }
3590 crate::MathFunction::InsertBits => {
3591 self.need_bake_expressions.insert(arg2.unwrap());
3593 }
3594 crate::MathFunction::Sign => {
3595 let inner = context.resolve_type(expr_handle);
3600 if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3601 self.need_bake_expressions.insert(arg);
3602 }
3603 }
3604 _ => {}
3605 }
3606 }
3607 }
3608 }
3609
3610 pub(super) fn start_baking_expression(
3611 &mut self,
3612 handle: Handle<crate::Expression>,
3613 context: &ExpressionContext,
3614 name: &str,
3615 ) -> BackendResult {
3616 match context.info[handle].ty {
3617 TypeResolution::Handle(ty_handle) => {
3618 let ty_name = TypeContext {
3619 handle: ty_handle,
3620 gctx: context.module.to_ctx(),
3621 names: &self.names,
3622 access: crate::StorageAccess::empty(),
3623 first_time: false,
3624 };
3625 write!(self.out, "{ty_name}")?;
3626 }
3627 TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3628 put_numeric_type(&mut self.out, scalar, &[])?;
3629 }
3630 TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3631 put_numeric_type(&mut self.out, scalar, &[size])?;
3632 }
3633 TypeResolution::Value(crate::TypeInner::Matrix {
3634 columns,
3635 rows,
3636 scalar,
3637 }) => {
3638 put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3639 }
3640 TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3641 columns,
3642 rows,
3643 scalar,
3644 role: _,
3645 }) => {
3646 write!(
3647 self.out,
3648 "{}::simdgroup_{}{}x{}",
3649 NAMESPACE,
3650 scalar.to_msl_name(),
3651 columns as u32,
3652 rows as u32,
3653 )?;
3654 }
3655 TypeResolution::Value(ref other) => {
3656 log::warn!("Type {other:?} isn't a known local");
3657 return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3658 }
3659 }
3660
3661 write!(self.out, " {name} = ")?;
3663
3664 Ok(())
3665 }
3666
3667 fn put_cache_restricted_level(
3680 &mut self,
3681 load: Handle<crate::Expression>,
3682 image: Handle<crate::Expression>,
3683 mip_level: Option<Handle<crate::Expression>>,
3684 indent: back::Level,
3685 context: &StatementContext,
3686 ) -> BackendResult {
3687 let level_of_detail = match mip_level {
3690 Some(level) => level,
3691 None => return Ok(()),
3692 };
3693
3694 if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3695 || !context.expression.image_needs_lod(image)
3696 {
3697 return Ok(());
3698 }
3699
3700 write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3701 self.put_restricted_scalar_image_index(
3702 image,
3703 level_of_detail,
3704 "get_num_mip_levels",
3705 &context.expression,
3706 )?;
3707 writeln!(self.out, ";")?;
3708
3709 Ok(())
3710 }
3711
3712 fn put_casting_to_packed_chars(
3718 &mut self,
3719 fun: crate::MathFunction,
3720 arg0: Handle<crate::Expression>,
3721 arg1: Handle<crate::Expression>,
3722 indent: back::Level,
3723 context: &StatementContext<'_>,
3724 ) -> Result<(), Error> {
3725 let packed_type = match fun {
3726 crate::MathFunction::Dot4I8Packed => "packed_char4",
3727 crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3728 _ => unreachable!(),
3729 };
3730
3731 for arg in [arg0, arg1] {
3732 write!(
3733 self.out,
3734 "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3735 Reinterpreted::new(packed_type, arg)
3736 )?;
3737 self.put_expression(arg, &context.expression, true)?;
3738 writeln!(self.out, ");")?;
3739 }
3740
3741 Ok(())
3742 }
3743
3744 fn put_block(
3745 &mut self,
3746 level: back::Level,
3747 statements: &[crate::Statement],
3748 context: &StatementContext,
3749 ) -> BackendResult {
3750 #[cfg(test)]
3752 self.put_block_stack_pointers
3753 .insert(ptr::from_ref(&level).cast());
3754
3755 for statement in statements {
3756 log::trace!("statement[{}] {:?}", level.0, statement);
3757 match *statement {
3758 crate::Statement::Emit(ref range) => {
3759 for handle in range.clone() {
3760 use crate::MathFunction as Mf;
3761
3762 match context.expression.function.expressions[handle] {
3763 crate::Expression::ImageLoad {
3766 image,
3767 level: mip_level,
3768 ..
3769 } => {
3770 self.put_cache_restricted_level(
3771 handle, image, mip_level, level, context,
3772 )?;
3773 }
3774
3775 crate::Expression::Math {
3784 fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3785 arg,
3786 arg1,
3787 ..
3788 } if context.expression.lang_version >= (2, 1) => {
3789 self.put_casting_to_packed_chars(
3790 fun,
3791 arg,
3792 arg1.unwrap(),
3793 level,
3794 context,
3795 )?;
3796 }
3797
3798 _ => (),
3799 }
3800
3801 let ptr_class = context.expression.resolve_type(handle).pointer_space();
3802 let expr_name = if ptr_class.is_some() {
3803 None } else if let Some(name) =
3805 context.expression.function.named_expressions.get(&handle)
3806 {
3807 Some(self.namer.call(name))
3817 } else {
3818 let bake = if context.expression.guarded_indices.contains(handle) {
3822 true
3823 } else {
3824 self.need_bake_expressions.contains(&handle)
3825 };
3826
3827 if bake {
3828 Some(Baked(handle).to_string())
3829 } else {
3830 None
3831 }
3832 };
3833
3834 if let Some(name) = expr_name {
3835 write!(self.out, "{level}")?;
3836 self.start_baking_expression(handle, &context.expression, &name)?;
3837 self.put_expression(handle, &context.expression, true)?;
3838 self.named_expressions.insert(handle, name);
3839 writeln!(self.out, ";")?;
3840 }
3841 }
3842 }
3843 crate::Statement::Block(ref block) => {
3844 if !block.is_empty() {
3845 writeln!(self.out, "{level}{{")?;
3846 self.put_block(level.next(), block, context)?;
3847 writeln!(self.out, "{level}}}")?;
3848 }
3849 }
3850 crate::Statement::If {
3851 condition,
3852 ref accept,
3853 ref reject,
3854 } => {
3855 write!(self.out, "{level}if (")?;
3856 self.put_expression(condition, &context.expression, true)?;
3857 writeln!(self.out, ") {{")?;
3858 self.put_block(level.next(), accept, context)?;
3859 if !reject.is_empty() {
3860 writeln!(self.out, "{level}}} else {{")?;
3861 self.put_block(level.next(), reject, context)?;
3862 }
3863 writeln!(self.out, "{level}}}")?;
3864 }
3865 crate::Statement::Switch {
3866 selector,
3867 ref cases,
3868 } => {
3869 write!(self.out, "{level}switch(")?;
3870 self.put_expression(selector, &context.expression, true)?;
3871 writeln!(self.out, ") {{")?;
3872 let lcase = level.next();
3873 for case in cases.iter() {
3874 match case.value {
3875 crate::SwitchValue::I32(value) => {
3876 write!(self.out, "{lcase}case {value}:")?;
3877 }
3878 crate::SwitchValue::U32(value) => {
3879 write!(self.out, "{lcase}case {value}u:")?;
3880 }
3881 crate::SwitchValue::Default => {
3882 write!(self.out, "{lcase}default:")?;
3883 }
3884 }
3885
3886 let write_block_braces = !(case.fall_through && case.body.is_empty());
3887 if write_block_braces {
3888 writeln!(self.out, " {{")?;
3889 } else {
3890 writeln!(self.out)?;
3891 }
3892
3893 self.put_block(lcase.next(), &case.body, context)?;
3894 if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3895 {
3896 writeln!(self.out, "{}break;", lcase.next())?;
3897 }
3898
3899 if write_block_braces {
3900 writeln!(self.out, "{lcase}}}")?;
3901 }
3902 }
3903 writeln!(self.out, "{level}}}")?;
3904 }
3905 crate::Statement::Loop {
3906 ref body,
3907 ref continuing,
3908 break_if,
3909 } => {
3910 let force_loop_bound_statements =
3911 self.gen_force_bounded_loop_statements(level, context);
3912 let gate_name = (!continuing.is_empty() || break_if.is_some())
3913 .then(|| self.namer.call("loop_init"));
3914
3915 if let Some((ref decl, _)) = force_loop_bound_statements {
3916 writeln!(self.out, "{decl}")?;
3917 }
3918 if let Some(ref gate_name) = gate_name {
3919 writeln!(self.out, "{level}bool {gate_name} = true;")?;
3920 }
3921
3922 writeln!(self.out, "{level}while(true) {{",)?;
3923 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3924 writeln!(self.out, "{break_and_inc}")?;
3925 }
3926 if let Some(ref gate_name) = gate_name {
3927 let lif = level.next();
3928 let lcontinuing = lif.next();
3929 writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3930 self.put_block(lcontinuing, continuing, context)?;
3931 if let Some(condition) = break_if {
3932 write!(self.out, "{lcontinuing}if (")?;
3933 self.put_expression(condition, &context.expression, true)?;
3934 writeln!(self.out, ") {{")?;
3935 writeln!(self.out, "{}break;", lcontinuing.next())?;
3936 writeln!(self.out, "{lcontinuing}}}")?;
3937 }
3938 writeln!(self.out, "{lif}}}")?;
3939 writeln!(self.out, "{lif}{gate_name} = false;")?;
3940 }
3941 self.put_block(level.next(), body, context)?;
3942
3943 writeln!(self.out, "{level}}}")?;
3944 }
3945 crate::Statement::Break => {
3946 writeln!(self.out, "{level}break;")?;
3947 }
3948 crate::Statement::Continue => {
3949 writeln!(self.out, "{level}continue;")?;
3950 }
3951 crate::Statement::Return {
3952 value: Some(expr_handle),
3953 } => {
3954 self.put_return_value(
3955 level,
3956 expr_handle,
3957 context.result_struct,
3958 &context.expression,
3959 )?;
3960 }
3961 crate::Statement::Return { value: None } => {
3962 writeln!(self.out, "{level}return;")?;
3963 }
3964 crate::Statement::Kill => {
3965 writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3966 }
3967 crate::Statement::ControlBarrier(flags)
3968 | crate::Statement::MemoryBarrier(flags) => {
3969 self.write_barrier(flags, level)?;
3970 }
3971 crate::Statement::Store { pointer, value } => {
3972 self.put_store(pointer, value, level, context)?
3973 }
3974 crate::Statement::ImageStore {
3975 image,
3976 coordinate,
3977 array_index,
3978 value,
3979 } => {
3980 let address = TexelAddress {
3981 coordinate,
3982 array_index,
3983 sample: None,
3984 level: None,
3985 };
3986 self.put_image_store(level, image, &address, value, context)?
3987 }
3988 crate::Statement::Call {
3989 function,
3990 ref arguments,
3991 result,
3992 } => {
3993 write!(self.out, "{level}")?;
3994 if let Some(expr) = result {
3995 let name = Baked(expr).to_string();
3996 self.start_baking_expression(expr, &context.expression, &name)?;
3997 self.named_expressions.insert(expr, name);
3998 }
3999 let fun_name = &self.names[&NameKey::Function(function)];
4000 write!(self.out, "{fun_name}(")?;
4001 for (i, &handle) in arguments.iter().enumerate() {
4003 if i != 0 {
4004 write!(self.out, ", ")?;
4005 }
4006 self.put_expression(handle, &context.expression, true)?;
4007 }
4008 let mut separate = !arguments.is_empty();
4010 let fun_info = &context.expression.mod_info[function];
4011 let mut needs_buffer_sizes = false;
4012 for (handle, var) in context.expression.module.global_variables.iter() {
4013 if fun_info[handle].is_empty() {
4014 continue;
4015 }
4016 if var.space.needs_pass_through() {
4017 let name = &self.names[&NameKey::GlobalVariable(handle)];
4018 if separate {
4019 write!(self.out, ", ")?;
4020 } else {
4021 separate = true;
4022 }
4023 write!(self.out, "{name}")?;
4024 }
4025 needs_buffer_sizes |=
4026 needs_array_length(var.ty, &context.expression.module.types);
4027 }
4028 if needs_buffer_sizes {
4029 if separate {
4030 write!(self.out, ", ")?;
4031 }
4032 write!(self.out, "_buffer_sizes")?;
4033 }
4034
4035 writeln!(self.out, ");")?;
4037 }
4038 crate::Statement::Atomic {
4039 pointer,
4040 ref fun,
4041 value,
4042 result,
4043 } => {
4044 let context = &context.expression;
4045
4046 write!(self.out, "{level}")?;
4051 let fun_key = if let Some(result) = result {
4052 let res_name = Baked(result).to_string();
4053 self.start_baking_expression(result, context, &res_name)?;
4054 self.named_expressions.insert(result, res_name);
4055 fun.to_msl()
4056 } else if context.resolve_type(value).scalar_width() == Some(8) {
4057 fun.to_msl_64_bit()?
4058 } else {
4059 fun.to_msl()
4060 };
4061
4062 let policy = context.choose_bounds_check_policy(pointer);
4066 let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4067 && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
4068
4069 if checked {
4071 write!(self.out, " ? ")?;
4072 }
4073
4074 match *fun {
4076 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4077 write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
4078 self.put_access_chain(pointer, policy, context)?;
4079 write!(self.out, ", ")?;
4080 self.put_expression(cmp, context, true)?;
4081 write!(self.out, ", ")?;
4082 self.put_expression(value, context, true)?;
4083 write!(self.out, ")")?;
4084 }
4085 _ => {
4086 write!(
4087 self.out,
4088 "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
4089 )?;
4090 self.put_access_chain(pointer, policy, context)?;
4091 write!(self.out, ", ")?;
4092 self.put_expression(value, context, true)?;
4093 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
4094 }
4095 }
4096
4097 if checked {
4099 write!(self.out, " : DefaultConstructible()")?;
4100 }
4101
4102 writeln!(self.out, ";")?;
4104 }
4105 crate::Statement::ImageAtomic {
4106 image,
4107 coordinate,
4108 array_index,
4109 fun,
4110 value,
4111 } => {
4112 let address = TexelAddress {
4113 coordinate,
4114 array_index,
4115 sample: None,
4116 level: None,
4117 };
4118 self.put_image_atomic(level, image, &address, fun, value, context)?
4119 }
4120 crate::Statement::WorkGroupUniformLoad { pointer, result } => {
4121 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4122
4123 write!(self.out, "{level}")?;
4124 let name = self.namer.call("");
4125 self.start_baking_expression(result, &context.expression, &name)?;
4126 self.put_load(pointer, &context.expression, true)?;
4127 self.named_expressions.insert(result, name);
4128
4129 writeln!(self.out, ";")?;
4130 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4131 }
4132 crate::Statement::RayQuery { query, ref fun } => {
4133 self.write_ray_query_stmt(level, context, query, fun)?;
4134 }
4135 crate::Statement::SubgroupBallot { result, predicate } => {
4136 write!(self.out, "{level}")?;
4137 let name = self.namer.call("");
4138 self.start_baking_expression(result, &context.expression, &name)?;
4139 self.named_expressions.insert(result, name);
4140 write!(
4141 self.out,
4142 "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4143 )?;
4144 if let Some(predicate) = predicate {
4145 self.put_expression(predicate, &context.expression, true)?;
4146 } else {
4147 write!(self.out, "true")?;
4148 }
4149 writeln!(self.out, "), 0, 0, 0);")?;
4150 }
4151 crate::Statement::SubgroupCollectiveOperation {
4152 op,
4153 collective_op,
4154 argument,
4155 result,
4156 } => {
4157 write!(self.out, "{level}")?;
4158 let name = self.namer.call("");
4159 self.start_baking_expression(result, &context.expression, &name)?;
4160 self.named_expressions.insert(result, name);
4161 match (collective_op, op) {
4162 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4163 write!(self.out, "{NAMESPACE}::simd_all(")?
4164 }
4165 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4166 write!(self.out, "{NAMESPACE}::simd_any(")?
4167 }
4168 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4169 write!(self.out, "{NAMESPACE}::simd_sum(")?
4170 }
4171 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4172 write!(self.out, "{NAMESPACE}::simd_product(")?
4173 }
4174 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4175 write!(self.out, "{NAMESPACE}::simd_max(")?
4176 }
4177 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4178 write!(self.out, "{NAMESPACE}::simd_min(")?
4179 }
4180 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4181 write!(self.out, "{NAMESPACE}::simd_and(")?
4182 }
4183 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4184 write!(self.out, "{NAMESPACE}::simd_or(")?
4185 }
4186 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4187 write!(self.out, "{NAMESPACE}::simd_xor(")?
4188 }
4189 (
4190 crate::CollectiveOperation::ExclusiveScan,
4191 crate::SubgroupOperation::Add,
4192 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4193 (
4194 crate::CollectiveOperation::ExclusiveScan,
4195 crate::SubgroupOperation::Mul,
4196 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4197 (
4198 crate::CollectiveOperation::InclusiveScan,
4199 crate::SubgroupOperation::Add,
4200 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4201 (
4202 crate::CollectiveOperation::InclusiveScan,
4203 crate::SubgroupOperation::Mul,
4204 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4205 _ => unimplemented!(),
4206 }
4207 self.put_expression(argument, &context.expression, true)?;
4208 writeln!(self.out, ");")?;
4209 }
4210 crate::Statement::SubgroupGather {
4211 mode,
4212 argument,
4213 result,
4214 } => {
4215 write!(self.out, "{level}")?;
4216 let name = self.namer.call("");
4217 self.start_baking_expression(result, &context.expression, &name)?;
4218 self.named_expressions.insert(result, name);
4219 match mode {
4220 crate::GatherMode::BroadcastFirst => {
4221 write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4222 }
4223 crate::GatherMode::Broadcast(_) => {
4224 write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4225 }
4226 crate::GatherMode::Shuffle(_) => {
4227 write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4228 }
4229 crate::GatherMode::ShuffleDown(_) => {
4230 write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4231 }
4232 crate::GatherMode::ShuffleUp(_) => {
4233 write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4234 }
4235 crate::GatherMode::ShuffleXor(_) => {
4236 write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4237 }
4238 crate::GatherMode::QuadBroadcast(_) => {
4239 write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4240 }
4241 crate::GatherMode::QuadSwap(_) => {
4242 write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4243 }
4244 }
4245 self.put_expression(argument, &context.expression, true)?;
4246 match mode {
4247 crate::GatherMode::BroadcastFirst => {}
4248 crate::GatherMode::Broadcast(index)
4249 | crate::GatherMode::Shuffle(index)
4250 | crate::GatherMode::ShuffleDown(index)
4251 | crate::GatherMode::ShuffleUp(index)
4252 | crate::GatherMode::ShuffleXor(index)
4253 | crate::GatherMode::QuadBroadcast(index) => {
4254 write!(self.out, ", ")?;
4255 self.put_expression(index, &context.expression, true)?;
4256 }
4257 crate::GatherMode::QuadSwap(direction) => {
4258 write!(self.out, ", ")?;
4259 match direction {
4260 crate::Direction::X => {
4261 write!(self.out, "1u")?;
4262 }
4263 crate::Direction::Y => {
4264 write!(self.out, "2u")?;
4265 }
4266 crate::Direction::Diagonal => {
4267 write!(self.out, "3u")?;
4268 }
4269 }
4270 }
4271 }
4272 writeln!(self.out, ");")?;
4273 }
4274 crate::Statement::CooperativeStore { target, ref data } => {
4275 write!(self.out, "{level}simdgroup_store(")?;
4276 self.put_expression(target, &context.expression, true)?;
4277 write!(self.out, ", &")?;
4278 self.put_access_chain(
4279 data.pointer,
4280 context.expression.policies.index,
4281 &context.expression,
4282 )?;
4283 write!(self.out, ", ")?;
4284 self.put_expression(data.stride, &context.expression, true)?;
4285 if !data.row_major {
4290 let matrix_origin = "0";
4291 let transpose = true;
4292 write!(self.out, ", {matrix_origin}, {transpose}")?;
4293 }
4294 writeln!(self.out, ");")?;
4295 }
4296 crate::Statement::RayPipelineFunction(_) => unreachable!(),
4297 }
4298 }
4299
4300 for statement in statements {
4303 if let crate::Statement::Emit(ref range) = *statement {
4304 for handle in range.clone() {
4305 self.named_expressions.shift_remove(&handle);
4306 }
4307 }
4308 }
4309 Ok(())
4310 }
4311
4312 fn put_store(
4313 &mut self,
4314 pointer: Handle<crate::Expression>,
4315 value: Handle<crate::Expression>,
4316 level: back::Level,
4317 context: &StatementContext,
4318 ) -> BackendResult {
4319 let policy = context.expression.choose_bounds_check_policy(pointer);
4320 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4321 && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4322 {
4323 writeln!(self.out, ") {{")?;
4324 self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4325 writeln!(self.out, "{level}}}")?;
4326 } else {
4327 self.put_unchecked_store(pointer, value, policy, level, context)?;
4328 }
4329
4330 Ok(())
4331 }
4332
4333 fn put_unchecked_store(
4334 &mut self,
4335 pointer: Handle<crate::Expression>,
4336 value: Handle<crate::Expression>,
4337 policy: index::BoundsCheckPolicy,
4338 level: back::Level,
4339 context: &StatementContext,
4340 ) -> BackendResult {
4341 let is_atomic_pointer = context
4342 .expression
4343 .resolve_type(pointer)
4344 .is_atomic_pointer(&context.expression.module.types);
4345
4346 if is_atomic_pointer {
4347 write!(
4348 self.out,
4349 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4350 )?;
4351 self.put_access_chain(pointer, policy, &context.expression)?;
4352 write!(self.out, ", ")?;
4353 self.put_expression(value, &context.expression, true)?;
4354 writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4355 } else {
4356 write!(self.out, "{level}")?;
4357 self.put_access_chain(pointer, policy, &context.expression)?;
4358 write!(self.out, " = ")?;
4359 self.put_expression(value, &context.expression, true)?;
4360 writeln!(self.out, ";")?;
4361 }
4362
4363 Ok(())
4364 }
4365
4366 pub fn write(
4367 &mut self,
4368 module: &crate::Module,
4369 info: &valid::ModuleInfo,
4370 options: &Options,
4371 pipeline_options: &PipelineOptions,
4372 ) -> Result<TranslationInfo, Error> {
4373 self.emit_int_div_checks = options.emit_int_div_checks;
4374 self.names.clear();
4375 self.namer.reset(
4376 module,
4377 &super::keywords::RESERVED_SET,
4378 proc::KeywordSet::empty(),
4379 proc::CaseInsensitiveKeywordSet::empty(),
4380 &[
4381 CLAMPED_LOD_LOAD_PREFIX,
4382 super::ray::INTERSECTION_FUNCTION_NAME,
4383 ],
4384 &mut self.names,
4385 );
4386 self.wrapped_functions.clear();
4387 self.struct_member_pads.clear();
4388
4389 writeln!(
4390 self.out,
4391 "// language: metal{}.{}",
4392 options.lang_version.0, options.lang_version.1
4393 )?;
4394 writeln!(self.out, "#include <metal_stdlib>")?;
4395 writeln!(self.out, "#include <simd/simd.h>")?;
4396 writeln!(self.out)?;
4397 writeln!(self.out, "using {NAMESPACE}::uint;")?;
4399
4400 if module.uses_mesh_shaders() && options.lang_version < (3, 0) {
4401 return Err(Error::UnsupportedMeshShader);
4402 }
4403 self.needs_object_memory_barriers = module
4404 .entry_points
4405 .iter()
4406 .any(|e| e.stage == crate::ShaderStage::Task && e.task_payload.is_some());
4407
4408 if module.special_types.ray_desc.is_some()
4409 || module.special_types.ray_intersection.is_some()
4410 {
4411 if options.lang_version < (2, 4) {
4412 return Err(Error::UnsupportedRayTracing);
4413 }
4414 }
4415
4416 if options
4417 .bounds_check_policies
4418 .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4419 {
4420 self.put_default_constructible()?;
4421 }
4422 writeln!(self.out)?;
4423
4424 {
4425 let globals: Vec<Handle<crate::GlobalVariable>> = module
4428 .global_variables
4429 .iter()
4430 .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4431 .map(|(handle, _)| handle)
4432 .collect();
4433
4434 let mut buffer_indices = vec![];
4435 for vbm in &pipeline_options.vertex_buffer_mappings {
4436 buffer_indices.push(vbm.id);
4437 }
4438
4439 if !globals.is_empty() || !buffer_indices.is_empty() {
4440 writeln!(self.out, "struct _mslBufferSizes {{")?;
4441
4442 for global in globals {
4443 writeln!(
4444 self.out,
4445 "{}uint {};",
4446 back::INDENT,
4447 ArraySizeMember(global)
4448 )?;
4449 }
4450
4451 for idx in buffer_indices {
4452 writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4453 }
4454
4455 writeln!(self.out, "}};")?;
4456 writeln!(self.out)?;
4457 }
4458 };
4459
4460 self.write_type_defs(module)?;
4461 self.write_global_constants(module, info)?;
4462 self.write_functions(module, info, options, pipeline_options)
4463 }
4464
4465 fn put_default_constructible(&mut self) -> BackendResult {
4478 let tab = back::INDENT;
4479 writeln!(self.out, "struct DefaultConstructible {{")?;
4480 writeln!(self.out, "{tab}template<typename T>")?;
4481 writeln!(self.out, "{tab}operator T() && {{")?;
4482 writeln!(self.out, "{tab}{tab}return T {{}};")?;
4483 writeln!(self.out, "{tab}}}")?;
4484 writeln!(self.out, "}};")?;
4485 Ok(())
4486 }
4487
4488 fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4489 let mut generated_argument_buffer_wrapper = false;
4490 let mut generated_external_texture_wrapper = false;
4491 for (handle, ty) in module.types.iter() {
4492 match ty.inner {
4493 crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4494 writeln!(self.out, "template <typename T>")?;
4495 writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4496 writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4497 writeln!(self.out, "}};")?;
4498 generated_argument_buffer_wrapper = true;
4499 }
4500 crate::TypeInner::Image {
4501 class: crate::ImageClass::External,
4502 ..
4503 } if !generated_external_texture_wrapper => {
4504 let params_ty_name = &self.names
4505 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4506 writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4507 writeln!(
4508 self.out,
4509 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4510 back::INDENT
4511 )?;
4512 writeln!(
4513 self.out,
4514 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4515 back::INDENT
4516 )?;
4517 writeln!(
4518 self.out,
4519 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4520 back::INDENT
4521 )?;
4522 writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4523 writeln!(self.out, "}};")?;
4524 generated_external_texture_wrapper = true;
4525 }
4526 _ => {}
4527 }
4528
4529 if !ty.needs_alias() {
4530 continue;
4531 }
4532 let name = &self.names[&NameKey::Type(handle)];
4533 match ty.inner {
4534 crate::TypeInner::Array {
4548 base,
4549 size,
4550 stride: _,
4551 } => {
4552 let base_name = TypeContext {
4553 handle: base,
4554 gctx: module.to_ctx(),
4555 names: &self.names,
4556 access: crate::StorageAccess::empty(),
4557 first_time: false,
4558 };
4559
4560 match size.resolve(module.to_ctx())? {
4561 proc::IndexableLength::Known(size) => {
4562 writeln!(self.out, "struct {name} {{")?;
4563 writeln!(
4564 self.out,
4565 "{}{} {}[{}];",
4566 back::INDENT,
4567 base_name,
4568 WRAPPED_ARRAY_FIELD,
4569 size
4570 )?;
4571 writeln!(self.out, "}};")?;
4572 }
4573 proc::IndexableLength::Dynamic => {
4574 writeln!(self.out, "typedef {base_name} {name}[1];")?;
4575 }
4576 }
4577 }
4578 crate::TypeInner::Struct {
4579 ref members, span, ..
4580 } => {
4581 writeln!(self.out, "struct {name} {{")?;
4582 let mut last_offset = 0;
4583 for (index, member) in members.iter().enumerate() {
4584 if member.offset > last_offset {
4585 self.struct_member_pads.insert((handle, index as u32));
4586 let pad = member.offset - last_offset;
4587 writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4588 }
4589 let ty_inner = &module.types[member.ty].inner;
4590 last_offset = member.offset + ty_inner.size(module.to_ctx());
4591
4592 let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4593
4594 match should_pack_struct_member(members, span, index, module) {
4596 Some(scalar) => {
4597 writeln!(
4598 self.out,
4599 "{}{}::packed_{}3 {};",
4600 back::INDENT,
4601 NAMESPACE,
4602 scalar.to_msl_name(),
4603 member_name
4604 )?;
4605 }
4606 None => {
4607 let base_name = TypeContext {
4608 handle: member.ty,
4609 gctx: module.to_ctx(),
4610 names: &self.names,
4611 access: crate::StorageAccess::empty(),
4612 first_time: false,
4613 };
4614 writeln!(
4615 self.out,
4616 "{}{} {};",
4617 back::INDENT,
4618 base_name,
4619 member_name
4620 )?;
4621
4622 if let crate::TypeInner::Vector {
4624 size: crate::VectorSize::Tri,
4625 scalar,
4626 } = *ty_inner
4627 {
4628 last_offset += scalar.width as u32;
4629 }
4630 }
4631 }
4632 }
4633 if last_offset < span {
4634 let pad = span - last_offset;
4635 writeln!(
4636 self.out,
4637 "{}char _pad{}[{}];",
4638 back::INDENT,
4639 members.len(),
4640 pad
4641 )?;
4642 }
4643 writeln!(self.out, "}};")?;
4644 }
4645 _ => {
4646 let ty_name = TypeContext {
4647 handle,
4648 gctx: module.to_ctx(),
4649 names: &self.names,
4650 access: crate::StorageAccess::empty(),
4651 first_time: true,
4652 };
4653 writeln!(self.out, "typedef {ty_name} {name};")?;
4654 }
4655 }
4656 }
4657
4658 for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4660 match type_key {
4661 &crate::PredeclaredType::ModfResult { size, scalar }
4662 | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4663 let arg_type_name_owner;
4664 let arg_type_name = if let Some(size) = size {
4665 arg_type_name_owner = format!(
4666 "{NAMESPACE}::{}{}",
4667 if scalar.width == 8 { "double" } else { "float" },
4668 size as u8
4669 );
4670 &arg_type_name_owner
4671 } else if scalar.width == 8 {
4672 "double"
4673 } else {
4674 "float"
4675 };
4676
4677 let other_type_name_owner;
4678 let (defined_func_name, called_func_name, other_type_name) =
4679 if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4680 (MODF_FUNCTION, "modf", arg_type_name)
4681 } else {
4682 let other_type_name = if let Some(size) = size {
4683 other_type_name_owner = format!("int{}", size as u8);
4684 &other_type_name_owner
4685 } else {
4686 "int"
4687 };
4688 (FREXP_FUNCTION, "frexp", other_type_name)
4689 };
4690
4691 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4692
4693 writeln!(self.out)?;
4694 writeln!(
4695 self.out,
4696 "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4697 {other_type_name} other;
4698 {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4699 return {struct_name}{{ fract, other }};
4700}}"
4701 )?;
4702 }
4703 &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4704 let arg_type_name = scalar.to_msl_name();
4705 let called_func_name = "atomic_compare_exchange_weak_explicit";
4706 let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4707 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4708
4709 writeln!(self.out)?;
4710
4711 for address_space_name in ["device", "threadgroup"] {
4712 writeln!(
4713 self.out,
4714 "\
4715template <typename A>
4716{struct_name} {defined_func_name}(
4717 {address_space_name} A *atomic_ptr,
4718 {arg_type_name} cmp,
4719 {arg_type_name} v
4720) {{
4721 bool swapped = {NAMESPACE}::{called_func_name}(
4722 atomic_ptr, &cmp, v,
4723 metal::memory_order_relaxed, metal::memory_order_relaxed
4724 );
4725 return {struct_name}{{cmp, swapped}};
4726}}"
4727 )?;
4728 }
4729 }
4730 }
4731 }
4732
4733 Ok(())
4734 }
4735
4736 fn write_global_constants(
4738 &mut self,
4739 module: &crate::Module,
4740 mod_info: &valid::ModuleInfo,
4741 ) -> BackendResult {
4742 let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4743
4744 for (handle, constant) in constants {
4745 let ty_name = TypeContext {
4746 handle: constant.ty,
4747 gctx: module.to_ctx(),
4748 names: &self.names,
4749 access: crate::StorageAccess::empty(),
4750 first_time: false,
4751 };
4752 let name = &self.names[&NameKey::Constant(handle)];
4753 write!(self.out, "constant {ty_name} {name} = ")?;
4754 self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4755 writeln!(self.out, ";")?;
4756 }
4757
4758 Ok(())
4759 }
4760
4761 fn put_inline_sampler_properties(
4762 &mut self,
4763 level: back::Level,
4764 sampler: &sm::InlineSampler,
4765 ) -> BackendResult {
4766 for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4767 writeln!(
4768 self.out,
4769 "{}{}::{}_address::{},",
4770 level,
4771 NAMESPACE,
4772 letter,
4773 address.as_str(),
4774 )?;
4775 }
4776 writeln!(
4777 self.out,
4778 "{}{}::mag_filter::{},",
4779 level,
4780 NAMESPACE,
4781 sampler.mag_filter.as_str(),
4782 )?;
4783 writeln!(
4784 self.out,
4785 "{}{}::min_filter::{},",
4786 level,
4787 NAMESPACE,
4788 sampler.min_filter.as_str(),
4789 )?;
4790 if let Some(filter) = sampler.mip_filter {
4791 writeln!(
4792 self.out,
4793 "{}{}::mip_filter::{},",
4794 level,
4795 NAMESPACE,
4796 filter.as_str(),
4797 )?;
4798 }
4799 if sampler.border_color != sm::BorderColor::TransparentBlack {
4801 writeln!(
4802 self.out,
4803 "{}{}::border_color::{},",
4804 level,
4805 NAMESPACE,
4806 sampler.border_color.as_str(),
4807 )?;
4808 }
4809 if false {
4813 if let Some(ref lod) = sampler.lod_clamp {
4814 writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4815 }
4816 if let Some(aniso) = sampler.max_anisotropy {
4817 writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4818 }
4819 }
4820 if sampler.compare_func != sm::CompareFunc::Never {
4821 writeln!(
4822 self.out,
4823 "{}{}::compare_func::{},",
4824 level,
4825 NAMESPACE,
4826 sampler.compare_func.as_str(),
4827 )?;
4828 }
4829 writeln!(
4830 self.out,
4831 "{}{}::coord::{}",
4832 level,
4833 NAMESPACE,
4834 sampler.coord.as_str()
4835 )?;
4836 Ok(())
4837 }
4838
4839 fn write_unpacking_function(
4840 &mut self,
4841 format: back::msl::VertexFormat,
4842 ) -> Result<(String, u32, Option<crate::VectorSize>, crate::Scalar), Error> {
4843 use crate::{Scalar, VectorSize};
4844 use back::msl::VertexFormat::*;
4845 match format {
4846 Uint8 => {
4847 let name = self.namer.call("unpackUint8");
4848 writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4849 writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4850 writeln!(self.out, "}}")?;
4851 Ok((name, 1, None, Scalar::U32))
4852 }
4853 Uint8x2 => {
4854 let name = self.namer.call("unpackUint8x2");
4855 writeln!(
4856 self.out,
4857 "metal::uint2 {name}(metal::uchar b0, \
4858 metal::uchar b1) {{"
4859 )?;
4860 writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4861 writeln!(self.out, "}}")?;
4862 Ok((name, 2, Some(VectorSize::Bi), Scalar::U32))
4863 }
4864 Uint8x4 => {
4865 let name = self.namer.call("unpackUint8x4");
4866 writeln!(
4867 self.out,
4868 "metal::uint4 {name}(metal::uchar b0, \
4869 metal::uchar b1, \
4870 metal::uchar b2, \
4871 metal::uchar b3) {{"
4872 )?;
4873 writeln!(
4874 self.out,
4875 "{}return metal::uint4(b0, b1, b2, b3);",
4876 back::INDENT
4877 )?;
4878 writeln!(self.out, "}}")?;
4879 Ok((name, 4, Some(VectorSize::Quad), Scalar::U32))
4880 }
4881 Sint8 => {
4882 let name = self.namer.call("unpackSint8");
4883 writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4884 writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4885 writeln!(self.out, "}}")?;
4886 Ok((name, 1, None, Scalar::I32))
4887 }
4888 Sint8x2 => {
4889 let name = self.namer.call("unpackSint8x2");
4890 writeln!(
4891 self.out,
4892 "metal::int2 {name}(metal::uchar b0, \
4893 metal::uchar b1) {{"
4894 )?;
4895 writeln!(
4896 self.out,
4897 "{}return metal::int2(as_type<char>(b0), \
4898 as_type<char>(b1));",
4899 back::INDENT
4900 )?;
4901 writeln!(self.out, "}}")?;
4902 Ok((name, 2, Some(VectorSize::Bi), Scalar::I32))
4903 }
4904 Sint8x4 => {
4905 let name = self.namer.call("unpackSint8x4");
4906 writeln!(
4907 self.out,
4908 "metal::int4 {name}(metal::uchar b0, \
4909 metal::uchar b1, \
4910 metal::uchar b2, \
4911 metal::uchar b3) {{"
4912 )?;
4913 writeln!(
4914 self.out,
4915 "{}return metal::int4(as_type<char>(b0), \
4916 as_type<char>(b1), \
4917 as_type<char>(b2), \
4918 as_type<char>(b3));",
4919 back::INDENT
4920 )?;
4921 writeln!(self.out, "}}")?;
4922 Ok((name, 4, Some(VectorSize::Quad), Scalar::I32))
4923 }
4924 Unorm8 => {
4925 let name = self.namer.call("unpackUnorm8");
4926 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4927 writeln!(
4928 self.out,
4929 "{}return float(float(b0) / 255.0f);",
4930 back::INDENT
4931 )?;
4932 writeln!(self.out, "}}")?;
4933 Ok((name, 1, None, Scalar::F32))
4934 }
4935 Unorm8x2 => {
4936 let name = self.namer.call("unpackUnorm8x2");
4937 writeln!(
4938 self.out,
4939 "metal::float2 {name}(metal::uchar b0, \
4940 metal::uchar b1) {{"
4941 )?;
4942 writeln!(
4943 self.out,
4944 "{}return metal::float2(float(b0) / 255.0f, \
4945 float(b1) / 255.0f);",
4946 back::INDENT
4947 )?;
4948 writeln!(self.out, "}}")?;
4949 Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
4950 }
4951 Unorm8x4 => {
4952 let name = self.namer.call("unpackUnorm8x4");
4953 writeln!(
4954 self.out,
4955 "metal::float4 {name}(metal::uchar b0, \
4956 metal::uchar b1, \
4957 metal::uchar b2, \
4958 metal::uchar b3) {{"
4959 )?;
4960 writeln!(
4961 self.out,
4962 "{}return metal::float4(float(b0) / 255.0f, \
4963 float(b1) / 255.0f, \
4964 float(b2) / 255.0f, \
4965 float(b3) / 255.0f);",
4966 back::INDENT
4967 )?;
4968 writeln!(self.out, "}}")?;
4969 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
4970 }
4971 Snorm8 => {
4972 let name = self.namer.call("unpackSnorm8");
4973 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4974 writeln!(
4975 self.out,
4976 "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4977 back::INDENT
4978 )?;
4979 writeln!(self.out, "}}")?;
4980 Ok((name, 1, None, Scalar::F32))
4981 }
4982 Snorm8x2 => {
4983 let name = self.namer.call("unpackSnorm8x2");
4984 writeln!(
4985 self.out,
4986 "metal::float2 {name}(metal::uchar b0, \
4987 metal::uchar b1) {{"
4988 )?;
4989 writeln!(
4990 self.out,
4991 "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4992 metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
4993 back::INDENT
4994 )?;
4995 writeln!(self.out, "}}")?;
4996 Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
4997 }
4998 Snorm8x4 => {
4999 let name = self.namer.call("unpackSnorm8x4");
5000 writeln!(
5001 self.out,
5002 "metal::float4 {name}(metal::uchar b0, \
5003 metal::uchar b1, \
5004 metal::uchar b2, \
5005 metal::uchar b3) {{"
5006 )?;
5007 writeln!(
5008 self.out,
5009 "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5010 metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
5011 metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
5012 metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
5013 back::INDENT
5014 )?;
5015 writeln!(self.out, "}}")?;
5016 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5017 }
5018 Uint16 => {
5019 let name = self.namer.call("unpackUint16");
5020 writeln!(
5021 self.out,
5022 "metal::uint {name}(metal::uint b0, \
5023 metal::uint b1) {{"
5024 )?;
5025 writeln!(
5026 self.out,
5027 "{}return metal::uint(b1 << 8 | b0);",
5028 back::INDENT
5029 )?;
5030 writeln!(self.out, "}}")?;
5031 Ok((name, 2, None, Scalar::U32))
5032 }
5033 Uint16x2 => {
5034 let name = self.namer.call("unpackUint16x2");
5035 writeln!(
5036 self.out,
5037 "metal::uint2 {name}(metal::uint b0, \
5038 metal::uint b1, \
5039 metal::uint b2, \
5040 metal::uint b3) {{"
5041 )?;
5042 writeln!(
5043 self.out,
5044 "{}return metal::uint2(b1 << 8 | b0, \
5045 b3 << 8 | b2);",
5046 back::INDENT
5047 )?;
5048 writeln!(self.out, "}}")?;
5049 Ok((name, 4, Some(VectorSize::Bi), Scalar::U32))
5050 }
5051 Uint16x4 => {
5052 let name = self.namer.call("unpackUint16x4");
5053 writeln!(
5054 self.out,
5055 "metal::uint4 {name}(metal::uint b0, \
5056 metal::uint b1, \
5057 metal::uint b2, \
5058 metal::uint b3, \
5059 metal::uint b4, \
5060 metal::uint b5, \
5061 metal::uint b6, \
5062 metal::uint b7) {{"
5063 )?;
5064 writeln!(
5065 self.out,
5066 "{}return metal::uint4(b1 << 8 | b0, \
5067 b3 << 8 | b2, \
5068 b5 << 8 | b4, \
5069 b7 << 8 | b6);",
5070 back::INDENT
5071 )?;
5072 writeln!(self.out, "}}")?;
5073 Ok((name, 8, Some(VectorSize::Quad), Scalar::U32))
5074 }
5075 Sint16 => {
5076 let name = self.namer.call("unpackSint16");
5077 writeln!(
5078 self.out,
5079 "int {name}(metal::ushort b0, \
5080 metal::ushort b1) {{"
5081 )?;
5082 writeln!(
5083 self.out,
5084 "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5085 back::INDENT
5086 )?;
5087 writeln!(self.out, "}}")?;
5088 Ok((name, 2, None, Scalar::I32))
5089 }
5090 Sint16x2 => {
5091 let name = self.namer.call("unpackSint16x2");
5092 writeln!(
5093 self.out,
5094 "metal::int2 {name}(metal::ushort b0, \
5095 metal::ushort b1, \
5096 metal::ushort b2, \
5097 metal::ushort b3) {{"
5098 )?;
5099 writeln!(
5100 self.out,
5101 "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5102 as_type<short>(metal::ushort(b3 << 8 | b2)));",
5103 back::INDENT
5104 )?;
5105 writeln!(self.out, "}}")?;
5106 Ok((name, 4, Some(VectorSize::Bi), Scalar::I32))
5107 }
5108 Sint16x4 => {
5109 let name = self.namer.call("unpackSint16x4");
5110 writeln!(
5111 self.out,
5112 "metal::int4 {name}(metal::ushort b0, \
5113 metal::ushort b1, \
5114 metal::ushort b2, \
5115 metal::ushort b3, \
5116 metal::ushort b4, \
5117 metal::ushort b5, \
5118 metal::ushort b6, \
5119 metal::ushort b7) {{"
5120 )?;
5121 writeln!(
5122 self.out,
5123 "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5124 as_type<short>(metal::ushort(b3 << 8 | b2)), \
5125 as_type<short>(metal::ushort(b5 << 8 | b4)), \
5126 as_type<short>(metal::ushort(b7 << 8 | b6)));",
5127 back::INDENT
5128 )?;
5129 writeln!(self.out, "}}")?;
5130 Ok((name, 8, Some(VectorSize::Quad), Scalar::I32))
5131 }
5132 Unorm16 => {
5133 let name = self.namer.call("unpackUnorm16");
5134 writeln!(
5135 self.out,
5136 "float {name}(metal::ushort b0, \
5137 metal::ushort b1) {{"
5138 )?;
5139 writeln!(
5140 self.out,
5141 "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5142 back::INDENT
5143 )?;
5144 writeln!(self.out, "}}")?;
5145 Ok((name, 2, None, Scalar::F32))
5146 }
5147 Unorm16x2 => {
5148 let name = self.namer.call("unpackUnorm16x2");
5149 writeln!(
5150 self.out,
5151 "metal::float2 {name}(metal::ushort b0, \
5152 metal::ushort b1, \
5153 metal::ushort b2, \
5154 metal::ushort b3) {{"
5155 )?;
5156 writeln!(
5157 self.out,
5158 "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5159 float(b3 << 8 | b2) / 65535.0f);",
5160 back::INDENT
5161 )?;
5162 writeln!(self.out, "}}")?;
5163 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5164 }
5165 Unorm16x4 => {
5166 let name = self.namer.call("unpackUnorm16x4");
5167 writeln!(
5168 self.out,
5169 "metal::float4 {name}(metal::ushort b0, \
5170 metal::ushort b1, \
5171 metal::ushort b2, \
5172 metal::ushort b3, \
5173 metal::ushort b4, \
5174 metal::ushort b5, \
5175 metal::ushort b6, \
5176 metal::ushort b7) {{"
5177 )?;
5178 writeln!(
5179 self.out,
5180 "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5181 float(b3 << 8 | b2) / 65535.0f, \
5182 float(b5 << 8 | b4) / 65535.0f, \
5183 float(b7 << 8 | b6) / 65535.0f);",
5184 back::INDENT
5185 )?;
5186 writeln!(self.out, "}}")?;
5187 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5188 }
5189 Snorm16 => {
5190 let name = self.namer.call("unpackSnorm16");
5191 writeln!(
5192 self.out,
5193 "float {name}(metal::ushort b0, \
5194 metal::ushort b1) {{"
5195 )?;
5196 writeln!(
5197 self.out,
5198 "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5199 back::INDENT
5200 )?;
5201 writeln!(self.out, "}}")?;
5202 Ok((name, 2, None, Scalar::F32))
5203 }
5204 Snorm16x2 => {
5205 let name = self.namer.call("unpackSnorm16x2");
5206 writeln!(
5207 self.out,
5208 "metal::float2 {name}(uint b0, \
5209 uint b1, \
5210 uint b2, \
5211 uint b3) {{"
5212 )?;
5213 writeln!(
5214 self.out,
5215 "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5216 back::INDENT
5217 )?;
5218 writeln!(self.out, "}}")?;
5219 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5220 }
5221 Snorm16x4 => {
5222 let name = self.namer.call("unpackSnorm16x4");
5223 writeln!(
5224 self.out,
5225 "metal::float4 {name}(uint b0, \
5226 uint b1, \
5227 uint b2, \
5228 uint b3, \
5229 uint b4, \
5230 uint b5, \
5231 uint b6, \
5232 uint b7) {{"
5233 )?;
5234 writeln!(
5235 self.out,
5236 "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5237 metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5238 back::INDENT
5239 )?;
5240 writeln!(self.out, "}}")?;
5241 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5242 }
5243 Float16 => {
5244 let name = self.namer.call("unpackFloat16");
5245 writeln!(
5246 self.out,
5247 "float {name}(metal::ushort b0, \
5248 metal::ushort b1) {{"
5249 )?;
5250 writeln!(
5251 self.out,
5252 "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5253 back::INDENT
5254 )?;
5255 writeln!(self.out, "}}")?;
5256 Ok((name, 2, None, Scalar::F32))
5257 }
5258 Float16x2 => {
5259 let name = self.namer.call("unpackFloat16x2");
5260 writeln!(
5261 self.out,
5262 "metal::float2 {name}(metal::ushort b0, \
5263 metal::ushort b1, \
5264 metal::ushort b2, \
5265 metal::ushort b3) {{"
5266 )?;
5267 writeln!(
5268 self.out,
5269 "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5270 as_type<half>(metal::ushort(b3 << 8 | b2)));",
5271 back::INDENT
5272 )?;
5273 writeln!(self.out, "}}")?;
5274 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5275 }
5276 Float16x4 => {
5277 let name = self.namer.call("unpackFloat16x4");
5278 writeln!(
5279 self.out,
5280 "metal::float4 {name}(metal::ushort b0, \
5281 metal::ushort b1, \
5282 metal::ushort b2, \
5283 metal::ushort b3, \
5284 metal::ushort b4, \
5285 metal::ushort b5, \
5286 metal::ushort b6, \
5287 metal::ushort b7) {{"
5288 )?;
5289 writeln!(
5290 self.out,
5291 "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5292 as_type<half>(metal::ushort(b3 << 8 | b2)), \
5293 as_type<half>(metal::ushort(b5 << 8 | b4)), \
5294 as_type<half>(metal::ushort(b7 << 8 | b6)));",
5295 back::INDENT
5296 )?;
5297 writeln!(self.out, "}}")?;
5298 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5299 }
5300 Float32 => {
5301 let name = self.namer.call("unpackFloat32");
5302 writeln!(
5303 self.out,
5304 "float {name}(uint b0, \
5305 uint b1, \
5306 uint b2, \
5307 uint b3) {{"
5308 )?;
5309 writeln!(
5310 self.out,
5311 "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5312 back::INDENT
5313 )?;
5314 writeln!(self.out, "}}")?;
5315 Ok((name, 4, None, Scalar::F32))
5316 }
5317 Float32x2 => {
5318 let name = self.namer.call("unpackFloat32x2");
5319 writeln!(
5320 self.out,
5321 "metal::float2 {name}(uint b0, \
5322 uint b1, \
5323 uint b2, \
5324 uint b3, \
5325 uint b4, \
5326 uint b5, \
5327 uint b6, \
5328 uint b7) {{"
5329 )?;
5330 writeln!(
5331 self.out,
5332 "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5333 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5334 back::INDENT
5335 )?;
5336 writeln!(self.out, "}}")?;
5337 Ok((name, 8, Some(VectorSize::Bi), Scalar::F32))
5338 }
5339 Float32x3 => {
5340 let name = self.namer.call("unpackFloat32x3");
5341 writeln!(
5342 self.out,
5343 "metal::float3 {name}(uint b0, \
5344 uint b1, \
5345 uint b2, \
5346 uint b3, \
5347 uint b4, \
5348 uint b5, \
5349 uint b6, \
5350 uint b7, \
5351 uint b8, \
5352 uint b9, \
5353 uint b10, \
5354 uint b11) {{"
5355 )?;
5356 writeln!(
5357 self.out,
5358 "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5359 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5360 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5361 back::INDENT
5362 )?;
5363 writeln!(self.out, "}}")?;
5364 Ok((name, 12, Some(VectorSize::Tri), Scalar::F32))
5365 }
5366 Float32x4 => {
5367 let name = self.namer.call("unpackFloat32x4");
5368 writeln!(
5369 self.out,
5370 "metal::float4 {name}(uint b0, \
5371 uint b1, \
5372 uint b2, \
5373 uint b3, \
5374 uint b4, \
5375 uint b5, \
5376 uint b6, \
5377 uint b7, \
5378 uint b8, \
5379 uint b9, \
5380 uint b10, \
5381 uint b11, \
5382 uint b12, \
5383 uint b13, \
5384 uint b14, \
5385 uint b15) {{"
5386 )?;
5387 writeln!(
5388 self.out,
5389 "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5390 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5391 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5392 as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5393 back::INDENT
5394 )?;
5395 writeln!(self.out, "}}")?;
5396 Ok((name, 16, Some(VectorSize::Quad), Scalar::F32))
5397 }
5398 Uint32 => {
5399 let name = self.namer.call("unpackUint32");
5400 writeln!(
5401 self.out,
5402 "uint {name}(uint b0, \
5403 uint b1, \
5404 uint b2, \
5405 uint b3) {{"
5406 )?;
5407 writeln!(
5408 self.out,
5409 "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5410 back::INDENT
5411 )?;
5412 writeln!(self.out, "}}")?;
5413 Ok((name, 4, None, Scalar::U32))
5414 }
5415 Uint32x2 => {
5416 let name = self.namer.call("unpackUint32x2");
5417 writeln!(
5418 self.out,
5419 "uint2 {name}(uint b0, \
5420 uint b1, \
5421 uint b2, \
5422 uint b3, \
5423 uint b4, \
5424 uint b5, \
5425 uint b6, \
5426 uint b7) {{"
5427 )?;
5428 writeln!(
5429 self.out,
5430 "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5431 (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5432 back::INDENT
5433 )?;
5434 writeln!(self.out, "}}")?;
5435 Ok((name, 8, Some(VectorSize::Bi), Scalar::U32))
5436 }
5437 Uint32x3 => {
5438 let name = self.namer.call("unpackUint32x3");
5439 writeln!(
5440 self.out,
5441 "uint3 {name}(uint b0, \
5442 uint b1, \
5443 uint b2, \
5444 uint b3, \
5445 uint b4, \
5446 uint b5, \
5447 uint b6, \
5448 uint b7, \
5449 uint b8, \
5450 uint b9, \
5451 uint b10, \
5452 uint b11) {{"
5453 )?;
5454 writeln!(
5455 self.out,
5456 "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5457 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5458 (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5459 back::INDENT
5460 )?;
5461 writeln!(self.out, "}}")?;
5462 Ok((name, 12, Some(VectorSize::Tri), Scalar::U32))
5463 }
5464 Uint32x4 => {
5465 let name = self.namer.call("unpackUint32x4");
5466 writeln!(
5467 self.out,
5468 "{NAMESPACE}::uint4 {name}(uint b0, \
5469 uint b1, \
5470 uint b2, \
5471 uint b3, \
5472 uint b4, \
5473 uint b5, \
5474 uint b6, \
5475 uint b7, \
5476 uint b8, \
5477 uint b9, \
5478 uint b10, \
5479 uint b11, \
5480 uint b12, \
5481 uint b13, \
5482 uint b14, \
5483 uint b15) {{"
5484 )?;
5485 writeln!(
5486 self.out,
5487 "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5488 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5489 (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5490 (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5491 back::INDENT
5492 )?;
5493 writeln!(self.out, "}}")?;
5494 Ok((name, 16, Some(VectorSize::Quad), Scalar::U32))
5495 }
5496 Sint32 => {
5497 let name = self.namer.call("unpackSint32");
5498 writeln!(
5499 self.out,
5500 "int {name}(uint b0, \
5501 uint b1, \
5502 uint b2, \
5503 uint b3) {{"
5504 )?;
5505 writeln!(
5506 self.out,
5507 "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5508 back::INDENT
5509 )?;
5510 writeln!(self.out, "}}")?;
5511 Ok((name, 4, None, Scalar::I32))
5512 }
5513 Sint32x2 => {
5514 let name = self.namer.call("unpackSint32x2");
5515 writeln!(
5516 self.out,
5517 "metal::int2 {name}(uint b0, \
5518 uint b1, \
5519 uint b2, \
5520 uint b3, \
5521 uint b4, \
5522 uint b5, \
5523 uint b6, \
5524 uint b7) {{"
5525 )?;
5526 writeln!(
5527 self.out,
5528 "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5529 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5530 back::INDENT
5531 )?;
5532 writeln!(self.out, "}}")?;
5533 Ok((name, 8, Some(VectorSize::Bi), Scalar::I32))
5534 }
5535 Sint32x3 => {
5536 let name = self.namer.call("unpackSint32x3");
5537 writeln!(
5538 self.out,
5539 "metal::int3 {name}(uint b0, \
5540 uint b1, \
5541 uint b2, \
5542 uint b3, \
5543 uint b4, \
5544 uint b5, \
5545 uint b6, \
5546 uint b7, \
5547 uint b8, \
5548 uint b9, \
5549 uint b10, \
5550 uint b11) {{"
5551 )?;
5552 writeln!(
5553 self.out,
5554 "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5555 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5556 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5557 back::INDENT
5558 )?;
5559 writeln!(self.out, "}}")?;
5560 Ok((name, 12, Some(VectorSize::Tri), Scalar::I32))
5561 }
5562 Sint32x4 => {
5563 let name = self.namer.call("unpackSint32x4");
5564 writeln!(
5565 self.out,
5566 "metal::int4 {name}(uint b0, \
5567 uint b1, \
5568 uint b2, \
5569 uint b3, \
5570 uint b4, \
5571 uint b5, \
5572 uint b6, \
5573 uint b7, \
5574 uint b8, \
5575 uint b9, \
5576 uint b10, \
5577 uint b11, \
5578 uint b12, \
5579 uint b13, \
5580 uint b14, \
5581 uint b15) {{"
5582 )?;
5583 writeln!(
5584 self.out,
5585 "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5586 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5587 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5588 as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5589 back::INDENT
5590 )?;
5591 writeln!(self.out, "}}")?;
5592 Ok((name, 16, Some(VectorSize::Quad), Scalar::I32))
5593 }
5594 Unorm10_10_10_2 => {
5595 let name = self.namer.call("unpackUnorm10_10_10_2");
5596 writeln!(
5597 self.out,
5598 "metal::float4 {name}(uint b0, \
5599 uint b1, \
5600 uint b2, \
5601 uint b3) {{"
5602 )?;
5603 writeln!(
5604 self.out,
5605 "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5617 back::INDENT
5618 )?;
5619 writeln!(self.out, "}}")?;
5620 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5621 }
5622 Unorm8x4Bgra => {
5623 let name = self.namer.call("unpackUnorm8x4Bgra");
5624 writeln!(
5625 self.out,
5626 "metal::float4 {name}(metal::uchar b0, \
5627 metal::uchar b1, \
5628 metal::uchar b2, \
5629 metal::uchar b3) {{"
5630 )?;
5631 writeln!(
5632 self.out,
5633 "{}return metal::float4(float(b2) / 255.0f, \
5634 float(b1) / 255.0f, \
5635 float(b0) / 255.0f, \
5636 float(b3) / 255.0f);",
5637 back::INDENT
5638 )?;
5639 writeln!(self.out, "}}")?;
5640 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5641 }
5642 }
5643 }
5644
5645 fn write_wrapped_unary_op(
5646 &mut self,
5647 module: &crate::Module,
5648 func_ctx: &back::FunctionCtx,
5649 op: crate::UnaryOperator,
5650 operand: Handle<crate::Expression>,
5651 ) -> BackendResult {
5652 let operand_ty = func_ctx.resolve_type(operand, &module.types);
5653 match op {
5654 crate::UnaryOperator::Negate
5661 if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5662 {
5663 let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5664 return Ok(());
5665 };
5666 let wrapped = WrappedFunction::UnaryOp {
5667 op,
5668 ty: (vector_size, scalar),
5669 };
5670 if !self.wrapped_functions.insert(wrapped) {
5671 return Ok(());
5672 }
5673
5674 let unsigned_scalar = crate::Scalar {
5675 kind: crate::ScalarKind::Uint,
5676 ..scalar
5677 };
5678 let mut type_name = String::new();
5679 let mut unsigned_type_name = String::new();
5680 match vector_size {
5681 None => {
5682 put_numeric_type(&mut type_name, scalar, &[])?;
5683 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5684 }
5685 Some(size) => {
5686 put_numeric_type(&mut type_name, scalar, &[size])?;
5687 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5688 }
5689 };
5690
5691 writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5692 let level = back::Level(1);
5693 if scalar.width < 4 {
5697 writeln!(
5698 self.out,
5699 "{level}return as_type<{type_name}>(static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val)));"
5700 )?;
5701 } else {
5702 writeln!(
5703 self.out,
5704 "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5705 )?;
5706 }
5707 writeln!(self.out, "}}")?;
5708 writeln!(self.out)?;
5709 }
5710 _ => {}
5711 }
5712 Ok(())
5713 }
5714
5715 fn write_wrapped_binary_op(
5716 &mut self,
5717 module: &crate::Module,
5718 func_ctx: &back::FunctionCtx,
5719 expr: Handle<crate::Expression>,
5720 op: crate::BinaryOperator,
5721 left: Handle<crate::Expression>,
5722 right: Handle<crate::Expression>,
5723 ) -> BackendResult {
5724 let expr_ty = func_ctx.resolve_type(expr, &module.types);
5725 let left_ty = func_ctx.resolve_type(left, &module.types);
5726 let right_ty = func_ctx.resolve_type(right, &module.types);
5727 match (op, expr_ty.scalar_kind()) {
5728 (
5735 crate::BinaryOperator::Divide,
5736 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5737 ) if self.emit_int_div_checks => {
5738 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5739 return Ok(());
5740 };
5741 let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5742 return Ok(());
5743 };
5744 let wrapped = WrappedFunction::BinaryOp {
5745 op,
5746 left_ty: left_wrapped_ty,
5747 right_ty: right_wrapped_ty,
5748 };
5749 if !self.wrapped_functions.insert(wrapped) {
5750 return Ok(());
5751 }
5752
5753 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5754 return Ok(());
5755 };
5756 let mut type_name = String::new();
5757 match vector_size {
5758 None => put_numeric_type(&mut type_name, scalar, &[])?,
5759 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5760 };
5761 writeln!(
5762 self.out,
5763 "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5764 )?;
5765 let level = back::Level(1);
5766 let (lp, rp) = if scalar.width < 4 {
5770 (format!("{type_name}("), ")".to_string())
5771 } else {
5772 (String::new(), String::new())
5773 };
5774 match scalar.kind {
5775 crate::ScalarKind::Sint => {
5776 let min_val = match scalar.width {
5777 2 => crate::Literal::I16(i16::MIN),
5778 4 => crate::Literal::I32(i32::MIN),
5779 8 => crate::Literal::I64(i64::MIN),
5780 _ => {
5781 return Err(Error::GenericValidation(format!(
5782 "Unexpected width for scalar {scalar:?}"
5783 )));
5784 }
5785 };
5786 write!(
5787 self.out,
5788 "{level}return lhs / metal::select(rhs, {lp}1{rp}, (lhs == "
5789 )?;
5790 self.put_literal(min_val)?;
5791 writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?
5792 }
5793 crate::ScalarKind::Uint => {
5794 let suffix = if scalar.width < 4 { "" } else { "u" };
5795 writeln!(
5796 self.out,
5797 "{level}return lhs / metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5798 )?
5799 }
5800 _ => unreachable!(),
5801 }
5802 writeln!(self.out, "}}")?;
5803 writeln!(self.out)?;
5804 }
5805 (
5818 crate::BinaryOperator::Modulo,
5819 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5820 ) if self.emit_int_div_checks => {
5821 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5822 return Ok(());
5823 };
5824 let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5825 else {
5826 return Ok(());
5827 };
5828 let wrapped = WrappedFunction::BinaryOp {
5829 op,
5830 left_ty: left_wrapped_ty,
5831 right_ty: (right_vector_size, right_scalar),
5832 };
5833 if !self.wrapped_functions.insert(wrapped) {
5834 return Ok(());
5835 }
5836
5837 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5838 return Ok(());
5839 };
5840 let mut type_name = String::new();
5841 match vector_size {
5842 None => put_numeric_type(&mut type_name, scalar, &[])?,
5843 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5844 };
5845 let mut rhs_type_name = String::new();
5846 match right_vector_size {
5847 None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5848 Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5849 };
5850
5851 writeln!(
5852 self.out,
5853 "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5854 )?;
5855 let level = back::Level(1);
5856 let (lp, rp) = if scalar.width < 4 {
5857 (format!("{type_name}("), ")".to_string())
5858 } else {
5859 (String::new(), String::new())
5860 };
5861 match scalar.kind {
5862 crate::ScalarKind::Sint => {
5863 let min_val = match scalar.width {
5864 2 => crate::Literal::I16(i16::MIN),
5865 4 => crate::Literal::I32(i32::MIN),
5866 8 => crate::Literal::I64(i64::MIN),
5867 _ => {
5868 return Err(Error::GenericValidation(format!(
5869 "Unexpected width for scalar {scalar:?}"
5870 )));
5871 }
5872 };
5873 write!(
5874 self.out,
5875 "{level}{rhs_type_name} divisor = metal::select(rhs, {lp}1{rp}, (lhs == "
5876 )?;
5877 self.put_literal(min_val)?;
5878 writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?;
5879 writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5880 }
5881 crate::ScalarKind::Uint => {
5882 let suffix = if scalar.width < 4 { "" } else { "u" };
5883 writeln!(
5884 self.out,
5885 "{level}return lhs % metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5886 )?
5887 }
5888 _ => unreachable!(),
5889 }
5890 writeln!(self.out, "}}")?;
5891 writeln!(self.out)?;
5892 }
5893 _ => {}
5894 }
5895 Ok(())
5896 }
5897
5898 fn get_dot_wrapper_function_helper_name(
5904 &self,
5905 scalar: crate::Scalar,
5906 size: crate::VectorSize,
5907 ) -> String {
5908 debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5910
5911 let type_name = scalar.to_msl_name();
5912 let size_suffix = common::vector_size_str(size);
5913 format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5914 }
5915
5916 #[allow(clippy::too_many_arguments)]
5917 fn write_wrapped_math_function(
5918 &mut self,
5919 module: &crate::Module,
5920 func_ctx: &back::FunctionCtx,
5921 fun: crate::MathFunction,
5922 arg: Handle<crate::Expression>,
5923 _arg1: Option<Handle<crate::Expression>>,
5924 _arg2: Option<Handle<crate::Expression>>,
5925 _arg3: Option<Handle<crate::Expression>>,
5926 ) -> BackendResult {
5927 let arg_ty = func_ctx.resolve_type(arg, &module.types);
5928 match fun {
5929 crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5937 let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5938 return Ok(());
5939 };
5940 let wrapped = WrappedFunction::Math {
5941 fun,
5942 arg_ty: (vector_size, scalar),
5943 };
5944 if !self.wrapped_functions.insert(wrapped) {
5945 return Ok(());
5946 }
5947
5948 let unsigned_scalar = crate::Scalar {
5949 kind: crate::ScalarKind::Uint,
5950 ..scalar
5951 };
5952 let mut type_name = String::new();
5953 let mut unsigned_type_name = String::new();
5954 match vector_size {
5955 None => {
5956 put_numeric_type(&mut type_name, scalar, &[])?;
5957 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5958 }
5959 Some(size) => {
5960 put_numeric_type(&mut type_name, scalar, &[size])?;
5961 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5962 }
5963 };
5964
5965 writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5966 let level = back::Level(1);
5967 let zero = if scalar.width < 4 {
5968 format!("{type_name}(0)")
5969 } else {
5970 "0".to_string()
5971 };
5972 let neg_expr = if scalar.width < 4 {
5973 format!(
5974 "static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val))"
5975 )
5976 } else {
5977 format!("-as_type<{unsigned_type_name}>(val)")
5978 };
5979 writeln!(self.out, "{level}return metal::select(as_type<{type_name}>({neg_expr}), val, val >= {zero});")?;
5980 writeln!(self.out, "}}")?;
5981 writeln!(self.out)?;
5982 }
5983
5984 crate::MathFunction::Dot => match *arg_ty {
5985 crate::TypeInner::Vector { size, scalar }
5986 if matches!(
5987 scalar.kind,
5988 crate::ScalarKind::Sint | crate::ScalarKind::Uint
5989 ) =>
5990 {
5991 let wrapped = WrappedFunction::Math {
5993 fun,
5994 arg_ty: (Some(size), scalar),
5995 };
5996 if !self.wrapped_functions.insert(wrapped) {
5997 return Ok(());
5998 }
5999
6000 let mut vec_ty = String::new();
6001 put_numeric_type(&mut vec_ty, scalar, &[size])?;
6002 let mut ret_ty = String::new();
6003 put_numeric_type(&mut ret_ty, scalar, &[])?;
6004
6005 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
6006
6007 writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
6009 let level = back::Level(1);
6010 write!(self.out, "{level}return ")?;
6011 self.put_dot_product("a", "b", size as usize, |writer, name, index| {
6012 write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
6013 Ok(())
6014 })?;
6015 writeln!(self.out, ";")?;
6016 writeln!(self.out, "}}")?;
6017 writeln!(self.out)?;
6018 }
6019 _ => {}
6020 },
6021
6022 _ => {}
6023 }
6024 Ok(())
6025 }
6026
6027 fn write_wrapped_cast(
6028 &mut self,
6029 module: &crate::Module,
6030 func_ctx: &back::FunctionCtx,
6031 expr: Handle<crate::Expression>,
6032 kind: crate::ScalarKind,
6033 convert: Option<crate::Bytes>,
6034 ) -> BackendResult {
6035 let src_ty = func_ctx.resolve_type(expr, &module.types);
6046 let Some(width) = convert else {
6047 return Ok(());
6048 };
6049 let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6050 return Ok(());
6051 };
6052 let dst_scalar = crate::Scalar { kind, width };
6053 if src_scalar.kind != crate::ScalarKind::Float
6054 || (dst_scalar.kind != crate::ScalarKind::Sint
6055 && dst_scalar.kind != crate::ScalarKind::Uint)
6056 {
6057 return Ok(());
6058 }
6059 let wrapped = WrappedFunction::Cast {
6060 src_scalar,
6061 vector_size,
6062 dst_scalar,
6063 };
6064 if !self.wrapped_functions.insert(wrapped) {
6065 return Ok(());
6066 }
6067 let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6068
6069 let mut src_type_name = String::new();
6070 match vector_size {
6071 None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6072 Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6073 };
6074 let mut dst_type_name = String::new();
6075 match vector_size {
6076 None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6077 Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6078 };
6079 let fun_name = match dst_scalar {
6080 crate::Scalar::I32 => F2I32_FUNCTION,
6081 crate::Scalar::U32 => F2U32_FUNCTION,
6082 crate::Scalar::I64 => F2I64_FUNCTION,
6083 crate::Scalar::U64 => F2U64_FUNCTION,
6084 _ => unreachable!(),
6085 };
6086
6087 writeln!(
6088 self.out,
6089 "{dst_type_name} {fun_name}({src_type_name} value) {{"
6090 )?;
6091 let level = back::Level(1);
6092 write!(
6093 self.out,
6094 "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6095 )?;
6096 self.put_literal(min)?;
6097 write!(self.out, ", ")?;
6098 self.put_literal(max)?;
6099 writeln!(self.out, "));")?;
6100 writeln!(self.out, "}}")?;
6101 writeln!(self.out)?;
6102 Ok(())
6103 }
6104
6105 fn write_convert_yuv_to_rgb_and_return(
6113 &mut self,
6114 level: back::Level,
6115 y: &str,
6116 uv: &str,
6117 params: &str,
6118 ) -> BackendResult {
6119 let l1 = level;
6120 let l2 = l1.next();
6121
6122 writeln!(
6124 self.out,
6125 "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6126 )?;
6127
6128 writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6131 writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6132 writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6133 writeln!(
6134 self.out,
6135 "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6136 )?;
6137
6138 writeln!(
6141 self.out,
6142 "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6143 )?;
6144
6145 writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6148 writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6149 writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6150 writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6151
6152 writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6153 Ok(())
6154 }
6155
6156 #[allow(clippy::too_many_arguments)]
6157 fn write_wrapped_image_load(
6158 &mut self,
6159 module: &crate::Module,
6160 func_ctx: &back::FunctionCtx,
6161 image: Handle<crate::Expression>,
6162 _coordinate: Handle<crate::Expression>,
6163 _array_index: Option<Handle<crate::Expression>>,
6164 _sample: Option<Handle<crate::Expression>>,
6165 _level: Option<Handle<crate::Expression>>,
6166 ) -> BackendResult {
6167 let class = match *func_ctx.resolve_type(image, &module.types) {
6169 crate::TypeInner::Image { class, .. } => class,
6170 _ => unreachable!(),
6171 };
6172 if class != crate::ImageClass::External {
6173 return Ok(());
6174 }
6175 let wrapped = WrappedFunction::ImageLoad { class };
6176 if !self.wrapped_functions.insert(wrapped) {
6177 return Ok(());
6178 }
6179
6180 writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6181 let l1 = back::Level(1);
6182 let l2 = l1.next();
6183 let l3 = l2.next();
6184 writeln!(
6185 self.out,
6186 "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6187 )?;
6188 writeln!(
6192 self.out,
6193 "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6194 )?;
6195 writeln!(
6196 self.out,
6197 "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6198 )?;
6199
6200 writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6202 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6203 writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6205 writeln!(self.out, "{l1}}} else {{")?;
6206
6207 writeln!(
6209 self.out,
6210 "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6211 )?;
6212 writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6213
6214 writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6216
6217 writeln!(self.out, "{l2}float2 uv;")?;
6218 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6219 writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6221 writeln!(self.out, "{l2}}} else {{")?;
6222 writeln!(
6224 self.out,
6225 "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6226 )?;
6227 writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6228 writeln!(
6229 self.out,
6230 "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6231 )?;
6232 writeln!(self.out, "{l2}}}")?;
6233
6234 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6235
6236 writeln!(self.out, "{l1}}}")?;
6237 writeln!(self.out, "}}")?;
6238 writeln!(self.out)?;
6239 Ok(())
6240 }
6241
6242 #[allow(clippy::too_many_arguments)]
6243 fn write_wrapped_image_sample(
6244 &mut self,
6245 module: &crate::Module,
6246 func_ctx: &back::FunctionCtx,
6247 image: Handle<crate::Expression>,
6248 _sampler: Handle<crate::Expression>,
6249 _gather: Option<crate::SwizzleComponent>,
6250 _coordinate: Handle<crate::Expression>,
6251 _array_index: Option<Handle<crate::Expression>>,
6252 _offset: Option<Handle<crate::Expression>>,
6253 _level: crate::SampleLevel,
6254 _depth_ref: Option<Handle<crate::Expression>>,
6255 clamp_to_edge: bool,
6256 ) -> BackendResult {
6257 if !clamp_to_edge {
6260 return Ok(());
6261 }
6262 let class = match *func_ctx.resolve_type(image, &module.types) {
6263 crate::TypeInner::Image { class, .. } => class,
6264 _ => unreachable!(),
6265 };
6266 let wrapped = WrappedFunction::ImageSample {
6267 class,
6268 clamp_to_edge: true,
6269 };
6270 if !self.wrapped_functions.insert(wrapped) {
6271 return Ok(());
6272 }
6273 match class {
6274 crate::ImageClass::External => {
6275 writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6276 let l1 = back::Level(1);
6277 let l2 = l1.next();
6278 let l3 = l2.next();
6279 writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6280 writeln!(
6281 self.out,
6282 "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6283 )?;
6284
6285 writeln!(
6293 self.out,
6294 "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6295 )?;
6296 writeln!(
6297 self.out,
6298 "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6299 )?;
6300 writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6301 writeln!(
6302 self.out,
6303 "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6304 )?;
6305 writeln!(
6306 self.out,
6307 "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6308 )?;
6309 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6310 writeln!(
6312 self.out,
6313 "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6314 )?;
6315 writeln!(self.out, "{l1}}} else {{")?;
6316 writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6317 writeln!(
6318 self.out,
6319 "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6320 )?;
6321 writeln!(
6322 self.out,
6323 "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6324 )?;
6325
6326 writeln!(
6328 self.out,
6329 "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6330 )?;
6331 writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6332 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6333 writeln!(
6335 self.out,
6336 "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6337 )?;
6338 writeln!(self.out, "{l2}}} else {{")?;
6339 writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6341 writeln!(
6342 self.out,
6343 "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6344 )?;
6345 writeln!(
6346 self.out,
6347 "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6348 )?;
6349 writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6350 writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6351 writeln!(self.out, "{l2}}}")?;
6352
6353 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6354
6355 writeln!(self.out, "{l1}}}")?;
6356 writeln!(self.out, "}}")?;
6357 writeln!(self.out)?;
6358 }
6359 _ => {
6360 writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6361 let l1 = back::Level(1);
6362 writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6363 writeln!(
6364 self.out,
6365 "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6366 )?;
6367 writeln!(self.out, "}}")?;
6368 writeln!(self.out)?;
6369 }
6370 }
6371 Ok(())
6372 }
6373
6374 fn write_wrapped_image_query(
6375 &mut self,
6376 module: &crate::Module,
6377 func_ctx: &back::FunctionCtx,
6378 image: Handle<crate::Expression>,
6379 query: crate::ImageQuery,
6380 ) -> BackendResult {
6381 if !matches!(query, crate::ImageQuery::Size { .. }) {
6383 return Ok(());
6384 }
6385 let class = match *func_ctx.resolve_type(image, &module.types) {
6386 crate::TypeInner::Image { class, .. } => class,
6387 _ => unreachable!(),
6388 };
6389 if class != crate::ImageClass::External {
6390 return Ok(());
6391 }
6392 let wrapped = WrappedFunction::ImageQuerySize { class };
6393 if !self.wrapped_functions.insert(wrapped) {
6394 return Ok(());
6395 }
6396 writeln!(
6397 self.out,
6398 "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6399 )?;
6400 let l1 = back::Level(1);
6401 let l2 = l1.next();
6402 writeln!(
6403 self.out,
6404 "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6405 )?;
6406 writeln!(self.out, "{l2}return tex.params.size;")?;
6407 writeln!(self.out, "{l1}}} else {{")?;
6408 writeln!(
6410 self.out,
6411 "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6412 )?;
6413 writeln!(self.out, "{l1}}}")?;
6414 writeln!(self.out, "}}")?;
6415 writeln!(self.out)?;
6416 Ok(())
6417 }
6418
6419 fn write_wrapped_cooperative_load(
6420 &mut self,
6421 module: &crate::Module,
6422 func_ctx: &back::FunctionCtx,
6423 columns: crate::CooperativeSize,
6424 rows: crate::CooperativeSize,
6425 pointer: Handle<crate::Expression>,
6426 ) -> BackendResult {
6427 let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6428 let space = ptr_ty.pointer_space().unwrap();
6429 let space_name = space.to_msl_name().unwrap_or_default();
6430 let scalar = ptr_ty
6431 .pointer_base_type()
6432 .unwrap()
6433 .inner_with(&module.types)
6434 .scalar()
6435 .unwrap();
6436 let wrapped = WrappedFunction::CooperativeLoad {
6437 space_name,
6438 columns,
6439 rows,
6440 scalar,
6441 };
6442 if !self.wrapped_functions.insert(wrapped) {
6443 return Ok(());
6444 }
6445 let scalar_name = scalar.to_msl_name();
6446 writeln!(
6447 self.out,
6448 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6449 columns as u32, rows as u32,
6450 )?;
6451 let l1 = back::Level(1);
6452 writeln!(
6453 self.out,
6454 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6455 columns as u32, rows as u32
6456 )?;
6457 let matrix_origin = "0";
6458 writeln!(
6459 self.out,
6460 "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6461 )?;
6462 writeln!(self.out, "{l1}return m;")?;
6463 writeln!(self.out, "}}")?;
6464 writeln!(self.out)?;
6465 Ok(())
6466 }
6467
6468 fn write_wrapped_cooperative_multiply_add(
6469 &mut self,
6470 module: &crate::Module,
6471 func_ctx: &back::FunctionCtx,
6472 space: crate::AddressSpace,
6473 a: Handle<crate::Expression>,
6474 b: Handle<crate::Expression>,
6475 ) -> BackendResult {
6476 let space_name = space.to_msl_name().unwrap_or_default();
6477 let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6478 crate::TypeInner::CooperativeMatrix {
6479 columns,
6480 rows,
6481 scalar,
6482 ..
6483 } => (columns, rows, scalar),
6484 _ => unreachable!(),
6485 };
6486 let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6487 crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6488 _ => unreachable!(),
6489 };
6490 let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6491 space_name,
6492 columns: b_c,
6493 rows: a_r,
6494 intermediate: a_c,
6495 scalar,
6496 };
6497 if !self.wrapped_functions.insert(wrapped) {
6498 return Ok(());
6499 }
6500 let scalar_name = scalar.to_msl_name();
6501 writeln!(
6502 self.out,
6503 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6504 b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6505 )?;
6506 let l1 = back::Level(1);
6507 writeln!(
6508 self.out,
6509 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6510 b_c as u32, a_r as u32
6511 )?;
6512 writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6513 writeln!(self.out, "{l1}return d;")?;
6514 writeln!(self.out, "}}")?;
6515 writeln!(self.out)?;
6516 Ok(())
6517 }
6518
6519 pub(super) fn write_wrapped_functions(
6520 &mut self,
6521 module: &crate::Module,
6522 func_ctx: &back::FunctionCtx,
6523 ) -> BackendResult {
6524 for (expr_handle, expr) in func_ctx.expressions.iter() {
6525 match *expr {
6526 crate::Expression::Unary { op, expr: operand } => {
6527 self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6528 }
6529 crate::Expression::Binary { op, left, right } => {
6530 self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6531 }
6532 crate::Expression::Math {
6533 fun,
6534 arg,
6535 arg1,
6536 arg2,
6537 arg3,
6538 } => {
6539 self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6540 }
6541 crate::Expression::As {
6542 expr,
6543 kind,
6544 convert,
6545 } => {
6546 self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6547 }
6548 crate::Expression::ImageLoad {
6549 image,
6550 coordinate,
6551 array_index,
6552 sample,
6553 level,
6554 } => {
6555 self.write_wrapped_image_load(
6556 module,
6557 func_ctx,
6558 image,
6559 coordinate,
6560 array_index,
6561 sample,
6562 level,
6563 )?;
6564 }
6565 crate::Expression::ImageSample {
6566 image,
6567 sampler,
6568 gather,
6569 coordinate,
6570 array_index,
6571 offset,
6572 level,
6573 depth_ref,
6574 clamp_to_edge,
6575 } => {
6576 self.write_wrapped_image_sample(
6577 module,
6578 func_ctx,
6579 image,
6580 sampler,
6581 gather,
6582 coordinate,
6583 array_index,
6584 offset,
6585 level,
6586 depth_ref,
6587 clamp_to_edge,
6588 )?;
6589 }
6590 crate::Expression::ImageQuery { image, query } => {
6591 self.write_wrapped_image_query(module, func_ctx, image, query)?;
6592 }
6593 crate::Expression::CooperativeLoad {
6594 columns,
6595 rows,
6596 role: _,
6597 ref data,
6598 } => {
6599 self.write_wrapped_cooperative_load(
6600 module,
6601 func_ctx,
6602 columns,
6603 rows,
6604 data.pointer,
6605 )?;
6606 }
6607 crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6608 let space = crate::AddressSpace::Private;
6609 self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
6610 }
6611 crate::Expression::RayQueryGetIntersection { committed, .. } => {
6612 self.write_rq_get_intersection_function(module, committed)?;
6613 }
6614 _ => {}
6615 }
6616 }
6617
6618 Ok(())
6619 }
6620
6621 fn write_functions(
6623 &mut self,
6624 module: &crate::Module,
6625 mod_info: &valid::ModuleInfo,
6626 options: &Options,
6627 pipeline_options: &PipelineOptions,
6628 ) -> Result<TranslationInfo, Error> {
6629 use back::msl::VertexFormat;
6630
6631 struct AttributeMappingResolved {
6634 ty_name: String,
6635 dimension: Option<crate::VectorSize>,
6636 scalar: crate::Scalar,
6637 name: String,
6638 }
6639 let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6640
6641 struct VertexBufferMappingResolved<'a> {
6642 id: u32,
6643 stride: u32,
6644 step_mode: back::msl::VertexBufferStepMode,
6645 ty_name: String,
6646 param_name: String,
6647 elem_name: String,
6648 attributes: &'a Vec<back::msl::AttributeMapping>,
6649 }
6650 let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6651
6652 struct UnpackingFunction {
6654 name: String,
6655 byte_count: u32,
6656 dimension: Option<crate::VectorSize>,
6657 scalar: crate::Scalar,
6658 }
6659 let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6660
6661 let mut needs_vertex_id = false;
6667 let v_id = self.namer.call("v_id");
6668
6669 let mut needs_instance_id = false;
6670 let i_id = self.namer.call("i_id");
6671 if pipeline_options.vertex_pulling_transform {
6672 for vbm in &pipeline_options.vertex_buffer_mappings {
6673 let buffer_id = vbm.id;
6674 let buffer_stride = vbm.stride;
6675
6676 assert!(
6677 buffer_stride > 0,
6678 "Vertex pulling requires a non-zero buffer stride."
6679 );
6680
6681 match vbm.step_mode {
6682 back::msl::VertexBufferStepMode::Constant => {}
6683 back::msl::VertexBufferStepMode::ByVertex => {
6684 needs_vertex_id = true;
6685 }
6686 back::msl::VertexBufferStepMode::ByInstance => {
6687 needs_instance_id = true;
6688 }
6689 }
6690
6691 let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6692 let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6693 let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6694
6695 vbm_resolved.push(VertexBufferMappingResolved {
6696 id: buffer_id,
6697 stride: buffer_stride,
6698 step_mode: vbm.step_mode,
6699 ty_name: buffer_ty,
6700 param_name: buffer_param,
6701 elem_name: buffer_elem,
6702 attributes: &vbm.attributes,
6703 });
6704
6705 for attribute in &vbm.attributes {
6707 if unpacking_functions.contains_key(&attribute.format) {
6708 continue;
6709 }
6710 let (name, byte_count, dimension, scalar) =
6711 match self.write_unpacking_function(attribute.format) {
6712 Ok((name, byte_count, dimension, scalar)) => {
6713 (name, byte_count, dimension, scalar)
6714 }
6715 _ => {
6716 continue;
6717 }
6718 };
6719 unpacking_functions.insert(
6720 attribute.format,
6721 UnpackingFunction {
6722 name,
6723 byte_count,
6724 dimension,
6725 scalar,
6726 },
6727 );
6728 }
6729 }
6730 }
6731
6732 let mut pass_through_globals = Vec::new();
6733 for (fun_handle, fun) in module.functions.iter() {
6734 log::trace!(
6735 "function {:?}, handle {:?}",
6736 fun.name.as_deref().unwrap_or("(anonymous)"),
6737 fun_handle
6738 );
6739
6740 let ctx = back::FunctionCtx {
6741 ty: back::FunctionType::Function(fun_handle),
6742 info: &mod_info[fun_handle],
6743 expressions: &fun.expressions,
6744 named_expressions: &fun.named_expressions,
6745 };
6746
6747 writeln!(self.out)?;
6748 self.write_wrapped_functions(module, &ctx)?;
6749
6750 let fun_info = &mod_info[fun_handle];
6751 pass_through_globals.clear();
6752 let mut needs_buffer_sizes = false;
6753 for (handle, var) in module.global_variables.iter() {
6754 if !fun_info[handle].is_empty() {
6755 if var.space.needs_pass_through() {
6756 pass_through_globals.push(handle);
6757 }
6758 needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6759 }
6760 }
6761
6762 let fun_name = &self.names[&NameKey::Function(fun_handle)];
6763 match fun.result {
6764 Some(ref result) => {
6765 let ty_name = TypeContext {
6766 handle: result.ty,
6767 gctx: module.to_ctx(),
6768 names: &self.names,
6769 access: crate::StorageAccess::empty(),
6770 first_time: false,
6771 };
6772 write!(self.out, "{ty_name}")?;
6773 }
6774 None => {
6775 write!(self.out, "void")?;
6776 }
6777 }
6778 writeln!(self.out, " {fun_name}(")?;
6779
6780 for (index, arg) in fun.arguments.iter().enumerate() {
6781 let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6782 let param_type_name = TypeContext {
6783 handle: arg.ty,
6784 gctx: module.to_ctx(),
6785 names: &self.names,
6786 access: crate::StorageAccess::empty(),
6787 first_time: false,
6788 };
6789 let separator = separate(
6790 !pass_through_globals.is_empty()
6791 || index + 1 != fun.arguments.len()
6792 || needs_buffer_sizes,
6793 );
6794 writeln!(
6795 self.out,
6796 "{}{} {}{}",
6797 back::INDENT,
6798 param_type_name,
6799 name,
6800 separator
6801 )?;
6802 }
6803 for (index, &handle) in pass_through_globals.iter().enumerate() {
6804 let tyvar = TypedGlobalVariable {
6805 module,
6806 names: &self.names,
6807 handle,
6808 usage: fun_info[handle],
6809 reference: true,
6810 };
6811 let separator =
6812 separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6813 write!(self.out, "{}", back::INDENT)?;
6814 tyvar.try_fmt(&mut self.out)?;
6815 writeln!(self.out, "{separator}")?;
6816 }
6817
6818 if needs_buffer_sizes {
6819 writeln!(
6820 self.out,
6821 "{}constant _mslBufferSizes& _buffer_sizes",
6822 back::INDENT
6823 )?;
6824 }
6825
6826 writeln!(self.out, ") {{")?;
6827
6828 let guarded_indices =
6829 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6830
6831 let context = StatementContext {
6832 expression: ExpressionContext {
6833 function: fun,
6834 origin: FunctionOrigin::Handle(fun_handle),
6835 info: fun_info,
6836 lang_version: options.lang_version,
6837 policies: options.bounds_check_policies,
6838 guarded_indices,
6839 module,
6840 mod_info,
6841 pipeline_options,
6842 force_loop_bounding: options.force_loop_bounding,
6843 emit_int_div_checks: options.emit_int_div_checks,
6844 },
6845 result_struct: None,
6846 };
6847
6848 self.put_locals(&context.expression)?;
6849 self.update_expressions_to_bake(fun, fun_info, &context.expression);
6850 self.put_block(back::Level(1), &fun.body, &context)?;
6851 writeln!(self.out, "}}")?;
6852 self.named_expressions.clear();
6853 }
6854
6855 let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6856 .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6857
6858 let mut info = TranslationInfo {
6859 entry_point_names: Vec::with_capacity(ep_range.len()),
6860 };
6861
6862 for ep_index in ep_range {
6863 let ep = &module.entry_points[ep_index];
6864 let fun = &ep.function;
6865 let fun_info = mod_info.get_entry_point(ep_index);
6866 let mut ep_error = None;
6867
6868 let mut v_existing_id = None;
6872 let mut i_existing_id = None;
6873
6874 log::trace!(
6875 "entry point {:?}, index {:?}",
6876 fun.name.as_deref().unwrap_or("(anonymous)"),
6877 ep_index
6878 );
6879
6880 let ctx = back::FunctionCtx {
6881 ty: back::FunctionType::EntryPoint(ep_index as u16),
6882 info: fun_info,
6883 expressions: &fun.expressions,
6884 named_expressions: &fun.named_expressions,
6885 };
6886
6887 self.write_wrapped_functions(module, &ctx)?;
6888
6889 let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6890 crate::ShaderStage::Vertex => (
6891 Some("vertex"),
6892 LocationMode::VertexInput,
6893 LocationMode::VertexOutput,
6894 true,
6895 ),
6896 crate::ShaderStage::Fragment => (
6897 Some("fragment"),
6898 LocationMode::FragmentInput,
6899 LocationMode::FragmentOutput,
6900 false,
6901 ),
6902 crate::ShaderStage::Compute => (
6903 Some("kernel"),
6904 LocationMode::Uniform,
6905 LocationMode::Uniform,
6906 false,
6907 ),
6908 crate::ShaderStage::Task => {
6909 (None, LocationMode::Uniform, LocationMode::Uniform, false)
6910 }
6911 crate::ShaderStage::Mesh => {
6912 (None, LocationMode::Uniform, LocationMode::MeshOutput, false)
6913 }
6914 crate::ShaderStage::RayGeneration
6915 | crate::ShaderStage::AnyHit
6916 | crate::ShaderStage::ClosestHit
6917 | crate::ShaderStage::Miss => unimplemented!(),
6918 };
6919
6920 let do_vertex_pulling = can_vertex_pull
6922 && pipeline_options.vertex_pulling_transform
6923 && !pipeline_options.vertex_buffer_mappings.is_empty();
6924
6925 let needs_buffer_sizes = do_vertex_pulling
6927 || module
6928 .global_variables
6929 .iter()
6930 .filter(|&(handle, _)| !fun_info[handle].is_empty())
6931 .any(|(_, var)| needs_array_length(var.ty, &module.types));
6932
6933 if !options.fake_missing_bindings {
6936 for (var_handle, var) in module.global_variables.iter() {
6937 if fun_info[var_handle].is_empty() {
6938 continue;
6939 }
6940 match var.space {
6941 crate::AddressSpace::Uniform
6942 | crate::AddressSpace::Storage { .. }
6943 | crate::AddressSpace::Handle => {
6944 let br = match var.binding {
6945 Some(ref br) => br,
6946 None => {
6947 let var_name = var.name.clone().unwrap_or_default();
6948 ep_error =
6949 Some(super::EntryPointError::MissingBinding(var_name));
6950 break;
6951 }
6952 };
6953 let target = options.get_resource_binding_target(ep, br);
6954 let good = match target {
6955 Some(target) => {
6956 match module.types[var.ty].inner {
6960 crate::TypeInner::Image {
6961 class: crate::ImageClass::External,
6962 ..
6963 } => target.external_texture.is_some(),
6964 crate::TypeInner::Image { .. } => target.texture.is_some(),
6965 crate::TypeInner::Sampler { .. } => {
6966 target.sampler.is_some()
6967 }
6968 _ => target.buffer.is_some(),
6969 }
6970 }
6971 None => false,
6972 };
6973 if !good {
6974 ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6975 break;
6976 }
6977 }
6978 crate::AddressSpace::Immediate => {
6979 if let Err(e) = options.resolve_immediates(ep) {
6980 ep_error = Some(e);
6981 break;
6982 }
6983 }
6984 crate::AddressSpace::Function
6985 | crate::AddressSpace::Private
6986 | crate::AddressSpace::WorkGroup
6987 | crate::AddressSpace::TaskPayload => {}
6988 crate::AddressSpace::RayPayload
6989 | crate::AddressSpace::IncomingRayPayload => unimplemented!(),
6990 }
6991 }
6992 if needs_buffer_sizes {
6993 if let Err(err) = options.resolve_sizes_buffer(ep) {
6994 ep_error = Some(err);
6995 }
6996 }
6997 }
6998
6999 if let Some(err) = ep_error {
7000 info.entry_point_names.push(Err(err));
7001 continue;
7002 }
7003 let fun_name = self.names[&NameKey::EntryPoint(ep_index as _)].clone();
7004 info.entry_point_names.push(Ok(fun_name.clone()));
7005
7006 writeln!(self.out)?;
7007
7008 let mut flattened_member_names = FastHashMap::default();
7014 let mut varyings_namer = proc::Namer::default();
7016
7017 let mut empty_names = FastHashMap::default(); varyings_namer.reset(
7019 module,
7020 &super::keywords::RESERVED_SET,
7021 proc::KeywordSet::empty(),
7022 proc::CaseInsensitiveKeywordSet::empty(),
7023 &[CLAMPED_LOD_LOAD_PREFIX],
7024 &mut empty_names,
7025 );
7026
7027 let mut flattened_arguments = Vec::new();
7032 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7033 match module.types[arg.ty].inner {
7034 crate::TypeInner::Struct { ref members, .. } => {
7035 for (member_index, member) in members.iter().enumerate() {
7036 let member_index = member_index as u32;
7037 flattened_arguments.push((
7038 NameKey::StructMember(arg.ty, member_index),
7039 member.ty,
7040 member.binding.as_ref(),
7041 ));
7042 let name_key = NameKey::StructMember(arg.ty, member_index);
7043 let name = match member.binding {
7044 Some(crate::Binding::Location { .. }) => {
7045 if do_vertex_pulling {
7046 self.namer.call(&self.names[&name_key])
7047 } else {
7048 varyings_namer.call(&self.names[&name_key])
7049 }
7050 }
7051 _ => self.namer.call(&self.names[&name_key]),
7052 };
7053 flattened_member_names.insert(name_key, name);
7054 }
7055 }
7056 _ => flattened_arguments.push((
7057 NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7058 arg.ty,
7059 arg.binding.as_ref(),
7060 )),
7061 }
7062 }
7063
7064 let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7069 let varyings_member_name = self.namer.call("varyings");
7070 let mut has_varyings = false;
7071
7072 if !flattened_arguments.is_empty() {
7073 if !do_vertex_pulling {
7074 writeln!(self.out, "struct {stage_in_name} {{")?;
7075 }
7076 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7077 let Some(binding) = binding else {
7078 continue;
7079 };
7080 let name = match *name_key {
7081 NameKey::StructMember(..) => &flattened_member_names[name_key],
7082 _ => &self.names[name_key],
7083 };
7084 let ty_name = TypeContext {
7085 handle: ty,
7086 gctx: module.to_ctx(),
7087 names: &self.names,
7088 access: crate::StorageAccess::empty(),
7089 first_time: false,
7090 };
7091 let resolved = options.resolve_local_binding(binding, in_mode)?;
7092 let location = match *binding {
7093 crate::Binding::Location { location, .. } => Some(location),
7094 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7095 crate::Binding::BuiltIn(_) => continue,
7096 };
7097 if do_vertex_pulling {
7098 let Some(location) = location else {
7099 continue;
7100 };
7101 am_resolved.insert(
7103 location,
7104 AttributeMappingResolved {
7105 ty_name: ty_name.to_string(),
7106 dimension: ty_name.vector_size(),
7107 scalar: ty_name.scalar().unwrap(),
7108 name: name.to_string(),
7109 },
7110 );
7111 } else {
7112 has_varyings = true;
7113 if let super::ResolvedBinding::User {
7114 prefix,
7115 index,
7116 interpolation: Some(super::ResolvedInterpolation::PerVertex),
7117 } = resolved
7118 {
7119 if options.lang_version < (4, 0) {
7120 return Err(Error::PerVertexNotSupported);
7121 }
7122 write!(
7123 self.out,
7124 "{}{NAMESPACE}::vertex_value<{}> {name} [[user({prefix}{index})]]",
7125 back::INDENT,
7126 ty_name.unwrap_array()
7127 )?;
7128 } else {
7129 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7130 resolved.try_fmt(&mut self.out)?;
7131 }
7132 writeln!(self.out, ";")?;
7133 }
7134 }
7135 if !do_vertex_pulling {
7136 writeln!(self.out, "}};")?;
7137 }
7138 }
7139
7140 let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7143 let result_member_name = self.namer.call("member");
7144 let result_type_name = match fun.result {
7145 Some(ref result) if ep.stage != crate::ShaderStage::Task => {
7146 let mut result_members = Vec::new();
7147 if let crate::TypeInner::Struct { ref members, .. } =
7148 module.types[result.ty].inner
7149 {
7150 for (member_index, member) in members.iter().enumerate() {
7151 result_members.push((
7152 &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7153 member.ty,
7154 member.binding.as_ref(),
7155 ));
7156 }
7157 } else {
7158 result_members.push((
7159 &result_member_name,
7160 result.ty,
7161 result.binding.as_ref(),
7162 ));
7163 }
7164
7165 writeln!(self.out, "struct {stage_out_name} {{")?;
7166 let mut has_point_size = false;
7167 for (name, ty, binding) in result_members {
7168 let ty_name = TypeContext {
7169 handle: ty,
7170 gctx: module.to_ctx(),
7171 names: &self.names,
7172 access: crate::StorageAccess::empty(),
7173 first_time: true,
7174 };
7175 let binding = binding.ok_or_else(|| {
7176 Error::GenericValidation("Expected binding, got None".into())
7177 })?;
7178
7179 if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7180 has_point_size = true;
7181 if !pipeline_options.allow_and_force_point_size {
7182 continue;
7183 }
7184 }
7185
7186 let array_len = match module.types[ty].inner {
7187 crate::TypeInner::Array {
7188 size: crate::ArraySize::Constant(size),
7189 ..
7190 } => Some(size),
7191 _ => None,
7192 };
7193 let resolved = options.resolve_local_binding(binding, out_mode)?;
7194 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7195 resolved.try_fmt(&mut self.out)?;
7196 if let Some(array_len) = array_len {
7197 write!(self.out, " [{array_len}]")?;
7198 }
7199 writeln!(self.out, ";")?;
7200 }
7201
7202 if pipeline_options.allow_and_force_point_size
7203 && ep.stage == crate::ShaderStage::Vertex
7204 && !has_point_size
7205 {
7206 writeln!(
7208 self.out,
7209 "{}float _point_size [[point_size]];",
7210 back::INDENT
7211 )?;
7212 }
7213 writeln!(self.out, "}};")?;
7214 &stage_out_name
7215 }
7216 Some(ref result) if ep.stage == crate::ShaderStage::Task => {
7217 assert_eq!(
7218 module.types[result.ty].inner,
7219 crate::TypeInner::Vector {
7220 size: crate::VectorSize::Tri,
7221 scalar: crate::Scalar::U32
7222 }
7223 );
7224
7225 "metal::uint3"
7226 }
7227 _ => "void",
7228 };
7229
7230 let out_mesh_info = if let Some(ref mesh_info) = ep.mesh_info {
7231 Some(self.write_mesh_output_types(
7232 mesh_info,
7233 &fun_name,
7234 module,
7235 pipeline_options.allow_and_force_point_size,
7236 options,
7237 )?)
7238 } else {
7239 None
7240 };
7241
7242 if do_vertex_pulling {
7245 for vbm in &vbm_resolved {
7246 let buffer_stride = vbm.stride;
7247 let buffer_ty = &vbm.ty_name;
7248
7249 writeln!(
7253 self.out,
7254 "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7255 )?;
7256 }
7257 }
7258
7259 let is_wrapped = matches!(
7260 ep.stage,
7261 crate::ShaderStage::Task | crate::ShaderStage::Mesh
7262 );
7263 let fun_name = fun_name.clone();
7264 let nested_fun_name = if is_wrapped {
7265 self.namer.call(&format!("_{fun_name}"))
7266 } else {
7267 fun_name.clone()
7268 };
7269
7270 if ep.stage == crate::ShaderStage::Compute && options.lang_version >= (2, 1) {
7272 let total_threads =
7273 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2];
7274 write!(
7275 self.out,
7276 "[[max_total_threads_per_threadgroup({total_threads})]] "
7277 )?;
7278 }
7279
7280 if let Some(em_str) = em_str {
7282 write!(self.out, "{em_str} ")?;
7283 }
7284 writeln!(self.out, "{result_type_name} {nested_fun_name}(")?;
7285
7286 let mut args = Vec::new();
7287
7288 if has_varyings {
7291 args.push(EntryPointArgument {
7292 ty_name: stage_in_name,
7293 name: varyings_member_name.clone(),
7294 binding: " [[stage_in]]".to_string(),
7295 init: None,
7296 });
7297 }
7298
7299 let mut local_invocation_index = None;
7300
7301 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7304 let binding = match binding {
7305 Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7306 Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7307 _ => continue,
7308 };
7309 let name = match *name_key {
7310 NameKey::StructMember(..) => &flattened_member_names[name_key],
7311 _ => &self.names[name_key],
7312 };
7313
7314 if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
7315 local_invocation_index = Some(name_key);
7316 }
7317
7318 let ty_name = TypeContext {
7319 handle: ty,
7320 gctx: module.to_ctx(),
7321 names: &self.names,
7322 access: crate::StorageAccess::empty(),
7323 first_time: false,
7324 };
7325
7326 match *binding {
7327 crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7328 v_existing_id = Some(name.clone());
7329 }
7330 crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7331 i_existing_id = Some(name.clone());
7332 }
7333 _ => {}
7334 };
7335
7336 let resolved = options.resolve_local_binding(binding, in_mode)?;
7337 let mut binding = String::new();
7338 resolved.try_fmt(&mut binding)?;
7339
7340 args.push(EntryPointArgument {
7341 ty_name: format!("{ty_name}"),
7342 name: name.clone(),
7343 binding,
7344 init: None,
7345 });
7346 }
7347
7348 let need_workgroup_variables_initialization =
7349 self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7350
7351 if local_invocation_index.is_none()
7352 && (need_workgroup_variables_initialization
7353 || ep.stage == crate::ShaderStage::Task
7354 || ep.stage == crate::ShaderStage::Mesh)
7355 {
7356 args.push(EntryPointArgument {
7357 ty_name: "uint".to_string(),
7358 name: "__local_invocation_index".to_string(),
7359 binding: " [[thread_index_in_threadgroup]]".to_string(),
7360 init: None,
7361 });
7362 }
7363
7364 for (handle, var) in module.global_variables.iter() {
7369 let usage = fun_info[handle];
7370 if usage.is_empty() || var.space == crate::AddressSpace::Private {
7371 continue;
7372 }
7373
7374 if options.lang_version < (1, 2) {
7375 match var.space {
7376 crate::AddressSpace::Storage { access }
7386 if access.contains(crate::StorageAccess::STORE)
7387 && ep.stage == crate::ShaderStage::Fragment =>
7388 {
7389 return Err(Error::UnsupportedWritableStorageBuffer)
7390 }
7391 crate::AddressSpace::Handle => {
7392 match module.types[var.ty].inner {
7393 crate::TypeInner::Image {
7394 class: crate::ImageClass::Storage { access, .. },
7395 ..
7396 } => {
7397 if access.contains(crate::StorageAccess::STORE)
7407 && (ep.stage == crate::ShaderStage::Vertex
7408 || ep.stage == crate::ShaderStage::Fragment)
7409 {
7410 return Err(Error::UnsupportedWritableStorageTexture(
7411 ep.stage,
7412 ));
7413 }
7414
7415 if access.contains(
7416 crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7417 ) {
7418 return Err(Error::UnsupportedRWStorageTexture);
7419 }
7420 }
7421 _ => {}
7422 }
7423 }
7424 _ => {}
7425 }
7426 }
7427
7428 match var.space {
7430 crate::AddressSpace::Handle => match module.types[var.ty].inner {
7431 crate::TypeInner::BindingArray { base, .. } => {
7432 match module.types[base].inner {
7433 crate::TypeInner::Sampler { .. } => {
7434 if options.lang_version < (2, 0) {
7435 return Err(Error::UnsupportedArrayOf(
7436 "samplers".to_string(),
7437 ));
7438 }
7439 }
7440 crate::TypeInner::Image { class, .. } => match class {
7441 crate::ImageClass::Sampled { .. }
7442 | crate::ImageClass::Depth { .. }
7443 | crate::ImageClass::Storage {
7444 access: crate::StorageAccess::LOAD,
7445 ..
7446 } => {
7447 if options.lang_version < (2, 0) {
7452 return Err(Error::UnsupportedArrayOf(
7453 "textures".to_string(),
7454 ));
7455 }
7456 }
7457 crate::ImageClass::Storage {
7458 access: crate::StorageAccess::STORE,
7459 ..
7460 } => {
7461 if options.lang_version < (2, 0) {
7466 return Err(Error::UnsupportedArrayOf(
7467 "write-only textures".to_string(),
7468 ));
7469 }
7470 }
7471 crate::ImageClass::Storage { .. } => {
7472 if options.lang_version < (3, 0) {
7473 return Err(Error::UnsupportedArrayOf(
7474 "read-write textures".to_string(),
7475 ));
7476 }
7477 }
7478 crate::ImageClass::External => {
7479 return Err(Error::UnsupportedArrayOf(
7480 "external textures".to_string(),
7481 ));
7482 }
7483 },
7484 _ => {
7485 return Err(Error::UnsupportedArrayOfType(base));
7486 }
7487 }
7488 }
7489 _ => {}
7490 },
7491 _ => {}
7492 }
7493
7494 let resolved = match var.space {
7496 crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7497 crate::AddressSpace::WorkGroup => None,
7498 crate::AddressSpace::TaskPayload => Some(back::msl::ResolvedBinding::Payload),
7499 _ => options
7500 .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7501 .ok(),
7502 };
7503 if let Some(ref resolved) = resolved {
7504 if resolved.as_inline_sampler(options).is_some() {
7506 continue;
7507 }
7508 }
7509
7510 match module.types[var.ty].inner {
7511 crate::TypeInner::Image {
7512 class: crate::ImageClass::External,
7513 ..
7514 } => {
7515 let target = match resolved {
7519 Some(back::msl::ResolvedBinding::Resource(target)) => {
7520 target.external_texture
7521 }
7522 _ => None,
7523 };
7524
7525 for i in 0..3 {
7526 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7527 handle,
7528 ExternalTextureNameKey::Plane(i),
7529 )];
7530 let ty_name = format!(
7531 "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample>"
7532 );
7533 let name = plane_name.clone();
7534 let binding = if let Some(ref target) = target {
7535 format!(" [[texture({})]]", target.planes[i])
7536 } else {
7537 String::new()
7538 };
7539 args.push(EntryPointArgument {
7540 ty_name,
7541 name,
7542 binding,
7543 init: None,
7544 });
7545 }
7546 let params_ty_name = &self.names
7547 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7548 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7549 handle,
7550 ExternalTextureNameKey::Params,
7551 )];
7552 let binding = if let Some(ref target) = target {
7553 format!(" [[buffer({})]]", target.params)
7554 } else {
7555 String::new()
7556 };
7557
7558 args.push(EntryPointArgument {
7559 ty_name: format!("constant {params_ty_name}&"),
7560 name: params_name.clone(),
7561 binding,
7562 init: None,
7563 });
7564 }
7565 _ => {
7566 if var.space == crate::AddressSpace::WorkGroup
7567 && ep.stage == crate::ShaderStage::Mesh
7568 {
7569 continue;
7570 }
7571 let tyvar = TypedGlobalVariable {
7572 module,
7573 names: &self.names,
7574 handle,
7575 usage,
7576 reference: true,
7577 };
7578 let parts = tyvar.to_parts()?;
7579 let mut binding = String::new();
7580 if let Some(resolved) = resolved {
7581 resolved.try_fmt(&mut binding)?;
7582 }
7583 args.push(EntryPointArgument {
7584 ty_name: parts.ty_name,
7585 name: parts.var_name,
7586 binding,
7587 init: var.init,
7588 });
7589 }
7590 }
7591 }
7592
7593 if do_vertex_pulling {
7594 if needs_vertex_id && v_existing_id.is_none() {
7595 args.push(EntryPointArgument {
7597 ty_name: "uint".to_string(),
7598 name: v_id.clone(),
7599 binding: " [[vertex_id]]".to_string(),
7600 init: None,
7601 });
7602 }
7603
7604 if needs_instance_id && i_existing_id.is_none() {
7605 args.push(EntryPointArgument {
7606 ty_name: "uint".to_string(),
7607 name: i_id.clone(),
7608 binding: " [[instance_id]]".to_string(),
7609 init: None,
7610 });
7611 }
7612
7613 for vbm in &vbm_resolved {
7616 let id = &vbm.id;
7617 let ty_name = &vbm.ty_name;
7618 let param_name = &vbm.param_name;
7619 args.push(EntryPointArgument {
7620 ty_name: format!("const device {ty_name}*"),
7621 name: param_name.clone(),
7622 binding: format!(" [[buffer({id})]]"),
7623 init: None,
7624 });
7625 }
7626 }
7627
7628 if needs_buffer_sizes {
7631 let resolved = options.resolve_sizes_buffer(ep).unwrap();
7633 let mut binding = String::new();
7634 resolved.try_fmt(&mut binding)?;
7635 args.push(EntryPointArgument {
7636 ty_name: "constant _mslBufferSizes&".to_string(),
7637 name: "_buffer_sizes".to_string(),
7638 binding,
7639 init: None,
7640 });
7641 }
7642
7643 let mut is_first_arg = true;
7644 for arg in &args {
7645 if is_first_arg {
7646 write!(self.out, " ")?;
7647 } else {
7648 write!(self.out, ", ")?;
7649 }
7650 is_first_arg = false;
7651 write!(self.out, "{} {}", arg.ty_name, arg.name)?;
7652 if !is_wrapped {
7653 write!(self.out, "{}", arg.binding)?;
7654 if let Some(init) = arg.init {
7655 write!(self.out, " = ")?;
7656 self.put_const_expression(
7657 init,
7658 module,
7659 mod_info,
7660 &module.global_expressions,
7661 )?;
7662 }
7663 }
7664 writeln!(self.out)?;
7665 }
7666 if ep.stage == crate::ShaderStage::Mesh {
7667 for (handle, var) in module.global_variables.iter() {
7668 if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
7669 continue;
7670 }
7671 if is_first_arg {
7672 write!(self.out, " ")?;
7673 } else {
7674 write!(self.out, ", ")?;
7675 }
7676 let ty_context = TypeContext {
7677 handle: module.global_variables[handle].ty,
7678 gctx: module.to_ctx(),
7679 names: &self.names,
7680 access: crate::StorageAccess::empty(),
7681 first_time: false,
7682 };
7683 writeln!(
7684 self.out,
7685 "threadgroup {ty_context}& {}",
7686 self.names[&NameKey::GlobalVariable(handle)]
7687 )?;
7688 }
7689 }
7690
7691 writeln!(self.out, ") {{")?;
7693
7694 if do_vertex_pulling {
7696 for vbm in &vbm_resolved {
7699 for attribute in vbm.attributes {
7700 let location = attribute.shader_location;
7701 let am_option = am_resolved.get(&location);
7702 if am_option.is_none() {
7703 continue;
7706 }
7707 let am = am_option.unwrap();
7708 let attribute_ty_name = &am.ty_name;
7709 let attribute_name = &am.name;
7710
7711 writeln!(
7712 self.out,
7713 "{}{attribute_ty_name} {attribute_name} = {{}};",
7714 back::Level(1)
7715 )?;
7716 }
7717
7718 write!(self.out, "{}if (", back::Level(1))?;
7721
7722 let idx = &vbm.id;
7723 let stride = &vbm.stride;
7724 let index_name = match vbm.step_mode {
7725 back::msl::VertexBufferStepMode::Constant => "0",
7726 back::msl::VertexBufferStepMode::ByVertex => {
7727 if let Some(ref name) = v_existing_id {
7728 name
7729 } else {
7730 &v_id
7731 }
7732 }
7733 back::msl::VertexBufferStepMode::ByInstance => {
7734 if let Some(ref name) = i_existing_id {
7735 name
7736 } else {
7737 &i_id
7738 }
7739 }
7740 };
7741 write!(
7742 self.out,
7743 "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7744 )?;
7745
7746 writeln!(self.out, ") {{")?;
7747
7748 let ty_name = &vbm.ty_name;
7750 let elem_name = &vbm.elem_name;
7751 let param_name = &vbm.param_name;
7752
7753 writeln!(
7754 self.out,
7755 "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7756 back::Level(2),
7757 )?;
7758
7759 for attribute in vbm.attributes {
7762 let location = attribute.shader_location;
7763 let Some(am) = am_resolved.get(&location) else {
7764 continue;
7768 };
7769 let attribute_name = &am.name;
7770 let attribute_ty_name = &am.ty_name;
7771
7772 let offset = attribute.offset;
7773 let func = unpacking_functions
7774 .get(&attribute.format)
7775 .expect("Should have generated this unpacking function earlier.");
7776 let func_name = &func.name;
7777
7778 let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7784
7785 let needs_conversion = am.scalar != func.scalar;
7788
7789 if needs_padding_or_truncation != Ordering::Equal {
7790 writeln!(
7793 self.out,
7794 "{}// {attribute_ty_name} <- {:?}",
7795 back::Level(2),
7796 attribute.format
7797 )?;
7798 }
7799
7800 write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7801
7802 if needs_padding_or_truncation == Ordering::Greater {
7803 write!(self.out, "{attribute_ty_name}(")?;
7805 }
7806
7807 if needs_conversion {
7809 put_numeric_type(&mut self.out, am.scalar, func.dimension.as_slice())?;
7810 write!(self.out, "(")?;
7811 }
7812 write!(self.out, "{func_name}({elem_name}.data[{offset}]")?;
7813 for i in (offset + 1)..(offset + func.byte_count) {
7814 write!(self.out, ", {elem_name}.data[{i}]")?;
7815 }
7816 write!(self.out, ")")?;
7817 if needs_conversion {
7818 write!(self.out, ")")?;
7819 }
7820
7821 match needs_padding_or_truncation {
7822 Ordering::Greater => {
7823 let ty_is_int = scalar_is_int(am.scalar);
7825 let zero_value = if ty_is_int { "0" } else { "0.0" };
7826 let one_value = if ty_is_int { "1" } else { "1.0" };
7827 for i in func.dimension.map_or(1, u8::from)
7828 ..am.dimension.map_or(1, u8::from)
7829 {
7830 write!(
7831 self.out,
7832 ", {}",
7833 if i == 3 { one_value } else { zero_value }
7834 )?;
7835 }
7836 }
7837 Ordering::Less => {
7838 write!(
7840 self.out,
7841 ".{}",
7842 &"xyzw"[0..usize::from(am.dimension.map_or(1, u8::from))]
7843 )?;
7844 }
7845 Ordering::Equal => {}
7846 }
7847
7848 if needs_padding_or_truncation == Ordering::Greater {
7849 write!(self.out, ")")?;
7850 }
7851
7852 writeln!(self.out, ";")?;
7853 }
7854
7855 writeln!(self.out, "{}}}", back::Level(1))?;
7857 }
7858 }
7859
7860 for (handle, var) in module.global_variables.iter() {
7863 let usage = fun_info[handle];
7864 if usage.is_empty() {
7865 continue;
7866 }
7867 if var.space == crate::AddressSpace::Private {
7868 let tyvar = TypedGlobalVariable {
7869 module,
7870 names: &self.names,
7871 handle,
7872 usage,
7873
7874 reference: false,
7875 };
7876 write!(self.out, "{}", back::INDENT)?;
7877 tyvar.try_fmt(&mut self.out)?;
7878 match var.init {
7879 Some(value) => {
7880 write!(self.out, " = ")?;
7881 self.put_const_expression(
7882 value,
7883 module,
7884 mod_info,
7885 &module.global_expressions,
7886 )?;
7887 writeln!(self.out, ";")?;
7888 }
7889 None => {
7890 writeln!(self.out, " = {{}};")?;
7891 }
7892 };
7893 } else if let Some(ref binding) = var.binding {
7894 let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7895 if let Some(sampler) = resolved.as_inline_sampler(options) {
7896 let name = &self.names[&NameKey::GlobalVariable(handle)];
7898 writeln!(
7899 self.out,
7900 "{}constexpr {}::sampler {}(",
7901 back::INDENT,
7902 NAMESPACE,
7903 name
7904 )?;
7905 self.put_inline_sampler_properties(back::Level(2), sampler)?;
7906 writeln!(self.out, "{});", back::INDENT)?;
7907 } else if let crate::TypeInner::Image {
7908 class: crate::ImageClass::External,
7909 ..
7910 } = module.types[var.ty].inner
7911 {
7912 let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7915 let l1 = back::Level(1);
7916 let l2 = l1.next();
7917 writeln!(
7918 self.out,
7919 "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7920 )?;
7921 for i in 0..3 {
7922 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7923 handle,
7924 ExternalTextureNameKey::Plane(i),
7925 )];
7926 writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7927 }
7928 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7929 handle,
7930 ExternalTextureNameKey::Params,
7931 )];
7932 writeln!(self.out, "{l2}.params = {params_name},")?;
7933 writeln!(self.out, "{l1}}};")?;
7934 }
7935 }
7936 }
7937
7938 if need_workgroup_variables_initialization {
7939 self.write_workgroup_variables_initialization(
7940 module,
7941 mod_info,
7942 fun_info,
7943 local_invocation_index,
7944 ep.stage,
7945 )?;
7946 }
7947
7948 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7959 let arg_name =
7960 &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7961 match module.types[arg.ty].inner {
7962 crate::TypeInner::Struct { ref members, .. } => {
7963 let struct_name = &self.names[&NameKey::Type(arg.ty)];
7964 write!(
7965 self.out,
7966 "{}const {} {} = {{ ",
7967 back::INDENT,
7968 struct_name,
7969 arg_name
7970 )?;
7971 for (member_index, member) in members.iter().enumerate() {
7972 let key = NameKey::StructMember(arg.ty, member_index as u32);
7973 let name = &flattened_member_names[&key];
7974 if member_index != 0 {
7975 write!(self.out, ", ")?;
7976 }
7977 if self
7979 .struct_member_pads
7980 .contains(&(arg.ty, member_index as u32))
7981 {
7982 write!(self.out, "{{}}, ")?;
7983 }
7984 match member.binding {
7985 Some(crate::Binding::Location {
7986 interpolation: Some(crate::Interpolation::PerVertex),
7987 ..
7988 }) => {
7989 writeln!(
7990 self.out,
7991 "{0}{{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }}",
7992 back::INDENT,
7993 varyings_member_name,
7994 arg_name,
7995 )?;
7996 continue;
7997 }
7998 Some(crate::Binding::Location { .. }) => {
7999 if has_varyings {
8000 write!(self.out, "{varyings_member_name}.")?;
8001 }
8002 }
8003 _ => (),
8004 }
8005 write!(self.out, "{name}")?;
8006 }
8007 writeln!(self.out, " }};")?;
8008 }
8009 _ => match arg.binding {
8010 Some(crate::Binding::Location {
8011 interpolation: Some(crate::Interpolation::PerVertex),
8012 ..
8013 }) => {
8014 let ty_name = TypeContext {
8015 handle: arg.ty,
8016 gctx: module.to_ctx(),
8017 names: &self.names,
8018 access: crate::StorageAccess::empty(),
8019 first_time: false,
8020 };
8021 writeln!(
8022 self.out,
8023 "{0}const {ty_name} {arg_name} = {{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }};",
8024 back::INDENT,
8025 varyings_member_name,
8026 arg_name,
8027 )?;
8028 }
8029 Some(crate::Binding::Location { .. })
8030 | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
8031 if has_varyings {
8032 writeln!(
8033 self.out,
8034 "{}const auto {} = {}.{};",
8035 back::INDENT,
8036 arg_name,
8037 varyings_member_name,
8038 arg_name
8039 )?;
8040 }
8041 }
8042 _ => {}
8043 },
8044 }
8045 }
8046
8047 let guarded_indices =
8048 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
8049
8050 let context = StatementContext {
8051 expression: ExpressionContext {
8052 function: fun,
8053 origin: FunctionOrigin::EntryPoint(ep_index as _),
8054 info: fun_info,
8055 lang_version: options.lang_version,
8056 policies: options.bounds_check_policies,
8057 guarded_indices,
8058 module,
8059 mod_info,
8060 pipeline_options,
8061 force_loop_bounding: options.force_loop_bounding,
8062 emit_int_div_checks: options.emit_int_div_checks,
8063 },
8064 result_struct: if ep.stage == crate::ShaderStage::Task {
8065 None
8066 } else {
8067 Some(&stage_out_name)
8068 },
8069 };
8070
8071 self.put_locals(&context.expression)?;
8074 self.update_expressions_to_bake(fun, fun_info, &context.expression);
8075 self.put_block(back::Level(1), &fun.body, &context)?;
8076 writeln!(self.out, "}}")?;
8077 if ep_index + 1 != module.entry_points.len() {
8078 writeln!(self.out)?;
8079 }
8080 self.named_expressions.clear();
8081
8082 if is_wrapped {
8083 self.write_wrapper_function(NestedFunctionInfo {
8084 options,
8085 ep,
8086 module,
8087 mod_info,
8088 fun_info,
8089 args,
8090 local_invocation_index,
8091 nested_name: &nested_fun_name,
8092 outer_name: &fun_name,
8093 out_mesh_info,
8094 })?;
8095 }
8096 }
8097
8098 Ok(info)
8099 }
8100
8101 pub(super) fn write_barrier(
8102 &mut self,
8103 flags: crate::Barrier,
8104 level: back::Level,
8105 ) -> BackendResult {
8106 if flags.is_empty() {
8109 writeln!(
8110 self.out,
8111 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
8112 )?;
8113 }
8114 if flags.contains(crate::Barrier::STORAGE) {
8115 writeln!(
8116 self.out,
8117 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
8118 )?;
8119 }
8120 if flags.contains(crate::Barrier::WORK_GROUP) {
8121 writeln!(
8122 self.out,
8123 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8124 )?;
8125 if self.needs_object_memory_barriers {
8126 writeln!(
8127 self.out,
8128 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_object_data);",
8129 )?;
8130 }
8131 }
8132 if flags.contains(crate::Barrier::SUB_GROUP) {
8133 writeln!(
8134 self.out,
8135 "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8136 )?;
8137 }
8138 if flags.contains(crate::Barrier::TEXTURE) {
8139 writeln!(
8140 self.out,
8141 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
8142 )?;
8143 }
8144 Ok(())
8145 }
8146}
8147
8148mod workgroup_mem_init {
8151 use crate::EntryPoint;
8152
8153 use super::*;
8154
8155 enum Access {
8156 GlobalVariable(Handle<crate::GlobalVariable>),
8157 StructMember(Handle<crate::Type>, u32),
8158 Array(usize),
8159 }
8160
8161 impl Access {
8162 fn write<W: Write>(
8163 &self,
8164 writer: &mut W,
8165 names: &FastHashMap<NameKey, String>,
8166 ) -> Result<(), core::fmt::Error> {
8167 match *self {
8168 Access::GlobalVariable(handle) => {
8169 write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
8170 }
8171 Access::StructMember(handle, index) => {
8172 write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
8173 }
8174 Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
8175 }
8176 }
8177 }
8178
8179 struct AccessStack {
8180 stack: Vec<Access>,
8181 array_depth: usize,
8182 }
8183
8184 impl AccessStack {
8185 const fn new() -> Self {
8186 Self {
8187 stack: Vec::new(),
8188 array_depth: 0,
8189 }
8190 }
8191
8192 fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
8193 let array_depth = self.array_depth;
8194 self.stack.push(Access::Array(array_depth));
8195 self.array_depth += 1;
8196 let res = cb(self, array_depth);
8197 self.stack.pop();
8198 self.array_depth -= 1;
8199 res
8200 }
8201
8202 fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
8203 self.stack.push(new);
8204 let res = cb(self);
8205 self.stack.pop();
8206 res
8207 }
8208
8209 fn write<W: Write>(
8210 &self,
8211 writer: &mut W,
8212 names: &FastHashMap<NameKey, String>,
8213 ) -> Result<(), core::fmt::Error> {
8214 for next in self.stack.iter() {
8215 next.write(writer, names)?;
8216 }
8217 Ok(())
8218 }
8219 }
8220
8221 impl<W: Write> Writer<W> {
8222 pub(super) fn need_workgroup_variables_initialization(
8223 &mut self,
8224 options: &Options,
8225 ep: &EntryPoint,
8226 module: &crate::Module,
8227 fun_info: &valid::FunctionInfo,
8228 ) -> bool {
8229 let is_task = ep.stage == crate::ShaderStage::Task;
8230 options.zero_initialize_workgroup_memory
8231 && ep.stage.compute_like()
8232 && module.global_variables.iter().any(|(handle, var)| {
8233 let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8234 || (var.space == crate::AddressSpace::TaskPayload && is_task);
8235 !fun_info[handle].is_empty() && is_right_address_space
8236 })
8237 }
8238
8239 pub fn write_workgroup_variables_initialization(
8240 &mut self,
8241 module: &crate::Module,
8242 module_info: &valid::ModuleInfo,
8243 fun_info: &valid::FunctionInfo,
8244 local_invocation_index: Option<&NameKey>,
8245 stage: crate::ShaderStage,
8246 ) -> BackendResult {
8247 let level = back::Level(1);
8248
8249 writeln!(
8250 self.out,
8251 "{}if ({} == 0u) {{",
8252 level,
8253 local_invocation_index
8254 .map(|name_key| self.names[name_key].as_str())
8255 .unwrap_or("__local_invocation_index"),
8256 )?;
8257
8258 let mut access_stack = AccessStack::new();
8259
8260 let is_task = stage == crate::ShaderStage::Task;
8261 let vars = module.global_variables.iter().filter(|&(handle, var)| {
8262 let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8263 || (var.space == crate::AddressSpace::TaskPayload && is_task);
8264 !fun_info[handle].is_empty() && is_right_address_space
8265 });
8266
8267 for (handle, var) in vars {
8268 access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
8269 self.write_workgroup_variable_initialization(
8270 module,
8271 module_info,
8272 var.ty,
8273 access_stack,
8274 level.next(),
8275 )
8276 })?;
8277 }
8278
8279 writeln!(self.out, "{level}}}")?;
8280 self.write_barrier(crate::Barrier::WORK_GROUP, level)
8281 }
8282
8283 fn write_workgroup_variable_initialization(
8284 &mut self,
8285 module: &crate::Module,
8286 module_info: &valid::ModuleInfo,
8287 ty: Handle<crate::Type>,
8288 access_stack: &mut AccessStack,
8289 level: back::Level,
8290 ) -> BackendResult {
8291 if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8292 write!(self.out, "{level}")?;
8293 access_stack.write(&mut self.out, &self.names)?;
8294 writeln!(self.out, " = {{}};")?;
8295 } else {
8296 match module.types[ty].inner {
8297 crate::TypeInner::Atomic { .. } => {
8298 write!(
8299 self.out,
8300 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8301 )?;
8302 access_stack.write(&mut self.out, &self.names)?;
8303 writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8304 }
8305 crate::TypeInner::Array { base, size, .. } => {
8306 let count = match size.resolve(module.to_ctx())? {
8307 proc::IndexableLength::Known(count) => count,
8308 proc::IndexableLength::Dynamic => unreachable!(),
8309 };
8310
8311 access_stack.enter_array(|access_stack, array_depth| {
8312 writeln!(
8313 self.out,
8314 "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8315 )?;
8316 self.write_workgroup_variable_initialization(
8317 module,
8318 module_info,
8319 base,
8320 access_stack,
8321 level.next(),
8322 )?;
8323 writeln!(self.out, "{level}}}")?;
8324 BackendResult::Ok(())
8325 })?;
8326 }
8327 crate::TypeInner::Struct { ref members, .. } => {
8328 for (index, member) in members.iter().enumerate() {
8329 access_stack.enter(
8330 Access::StructMember(ty, index as u32),
8331 |access_stack| {
8332 self.write_workgroup_variable_initialization(
8333 module,
8334 module_info,
8335 member.ty,
8336 access_stack,
8337 level,
8338 )
8339 },
8340 )?;
8341 }
8342 }
8343 _ => unreachable!(),
8344 }
8345 }
8346
8347 Ok(())
8348 }
8349 }
8350}
8351
8352impl crate::AtomicFunction {
8353 const fn to_msl(self) -> &'static str {
8354 match self {
8355 Self::Add => "fetch_add",
8356 Self::Subtract => "fetch_sub",
8357 Self::And => "fetch_and",
8358 Self::InclusiveOr => "fetch_or",
8359 Self::ExclusiveOr => "fetch_xor",
8360 Self::Min => "fetch_min",
8361 Self::Max => "fetch_max",
8362 Self::Exchange { compare: None } => "exchange",
8363 Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8364 }
8365 }
8366
8367 fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8368 Ok(match self {
8369 Self::Min => "min",
8370 Self::Max => "max",
8371 _ => Err(Error::FeatureNotImplemented(
8372 "64-bit atomic operation other than min/max".to_string(),
8373 ))?,
8374 })
8375 }
8376}