1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::{
8 cmp::Ordering,
9 fmt::{Display, Error as FmtError, Formatter, Write},
10 iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{
17 ray::RT_NAMESPACE, sampler as sm, Error, LocationMode, Options, PipelineOptions,
18 TranslationInfo, NAMESPACE, WRAPPED_ARRAY_FIELD,
19};
20use crate::{
21 arena::{Handle, HandleSet},
22 back::{
23 self, get_entry_points,
24 msl::{mesh_shader::NestedFunctionInfo, BackendResult, EntryPointArgument},
25 Baked,
26 },
27 common,
28 proc::{
29 self, concrete_int_scalars,
30 index::{self, BoundsCheck},
31 ExternalTextureNameKey, NameKey, TypeResolution,
32 },
33 valid, FastHashMap, FastHashSet,
34};
35
36#[cfg(test)]
37use core::ptr;
38
/// Reference qualifier text `&`; its use site is not in this view — presumably
/// appended when spelling references to atomics (TODO confirm).
const ATOMIC_REFERENCE: &str = "&";

// Identifiers of helper functions and wrapper structs referenced from the
// generated MSL. Only a few of their use sites are visible in this chunk;
// those are noted on the individual constants.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
/// Prefix only — presumably completed with a type-specific suffix elsewhere (confirm).
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
/// Called as `nagaTextureLoadExternal(image, coords)` by `put_image_load` for
/// external textures.
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
/// Called as `nagaTextureDimensionsExternal(image)` by `put_image_size_query`
/// for external textures.
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// Generic wrapper written as `constant NagaArgumentBufferWrapper<T>*` for
/// binding arrays (see the `BindingArray` arm of `TypeContext`'s `Display`).
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Type name written for `ImageClass::External` images.
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
75
76fn put_numeric_type(
85 out: &mut impl Write,
86 scalar: crate::Scalar,
87 sizes: &[crate::VectorSize],
88) -> Result<(), FmtError> {
89 match (scalar, sizes) {
90 (scalar, &[]) => {
91 write!(out, "{}", scalar.to_msl_name())
92 }
93 (scalar, &[rows]) => {
94 write!(
95 out,
96 "{}::{}{}",
97 NAMESPACE,
98 scalar.to_msl_name(),
99 common::vector_size_str(rows)
100 )
101 }
102 (scalar, &[rows, columns]) => {
103 write!(
104 out,
105 "{}::{}{}x{}",
106 NAMESPACE,
107 scalar.to_msl_name(),
108 common::vector_size_str(columns),
109 common::vector_size_str(rows)
110 )
111 }
112 (_, _) => Ok(()), }
114}
115
116const fn scalar_is_int(scalar: crate::Scalar) -> bool {
117 use crate::ScalarKind::*;
118 match scalar.kind {
119 Sint | Uint | AbstractInt | Bool => true,
120 Float | AbstractFloat => false,
121 }
122}
123
/// Prefix used by `ClampedLod` when formatting an image-load expression handle
/// into an identifier.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix used by `Reinterpreted` ahead of the target type name and the
/// `_e`-prefixed expression handle.
const REINTERPRET_PREFIX: &str = "reinterpreted_";
129
130struct ClampedLod(Handle<crate::Expression>);
136
137impl Display for ClampedLod {
138 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
139 self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
140 }
141}
142
143struct ArraySizeMember(Handle<crate::GlobalVariable>);
158
159impl Display for ArraySizeMember {
160 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
161 self.0.write_prefixed(f, "size")
162 }
163}
164
165#[derive(Clone, Copy)]
170struct Reinterpreted<'a> {
171 target_type: &'a str,
172 orig: Handle<crate::Expression>,
173}
174
175impl<'a> Reinterpreted<'a> {
176 const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
177 Self { target_type, orig }
178 }
179}
180
181impl Display for Reinterpreted<'_> {
182 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
183 f.write_str(REINTERPRET_PREFIX)?;
184 f.write_str(self.target_type)?;
185 self.orig.write_prefixed(f, "_e")
186 }
187}
188
189pub(super) struct TypeContext<'a> {
190 pub handle: Handle<crate::Type>,
191 pub gctx: proc::GlobalCtx<'a>,
192 pub names: &'a FastHashMap<NameKey, String>,
193 pub access: crate::StorageAccess,
194 pub first_time: bool,
195}
196
197impl TypeContext<'_> {
198 fn scalar(&self) -> Option<crate::Scalar> {
199 let ty = &self.gctx.types[self.handle];
200 ty.inner.scalar()
201 }
202
203 fn vector_size(&self) -> Option<crate::VectorSize> {
204 let ty = &self.gctx.types[self.handle];
205 match ty.inner {
206 crate::TypeInner::Vector { size, .. } => Some(size),
207 _ => None,
208 }
209 }
210
211 fn unwrap_array(self) -> Self {
212 match self.gctx.types[self.handle].inner {
213 crate::TypeInner::Array { base, .. } => Self {
214 handle: base,
215 ..self
216 },
217 _ => self,
218 }
219 }
220}
221
/// Writes the MSL spelling of the type `self.handle`.
///
/// Types for which `needs_alias()` is true are written by their assigned name
/// from `self.names` unless `first_time` is set; everything else is spelled
/// out structurally.
impl Display for TypeContext<'_> {
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        // Aliased types are referenced by their name, except on the first
        // (defining) spelling.
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            // Spelled `simdgroup_<scalar><columns>x<rows>`; the role is not
            // part of the MSL name.
            crate::TypeInner::CooperativeMatrix {
                columns,
                rows,
                scalar,
                role: _,
            } => {
                write!(
                    out,
                    "{NAMESPACE}::simdgroup_{}{}x{}",
                    scalar.to_msl_name(),
                    columns as u32,
                    rows as u32,
                )
            }
            crate::TypeInner::Pointer { base, space } => {
                // The pointee is always referenced by name (never redefined here).
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // Address spaces without an MSL name (`to_msl_name` returns
                // `None`, e.g. `Handle`) write nothing at all.
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                write!(out, "{space_name} {sub}&")
            }
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                // Same rule as `Pointer`: an unnamed address space writes nothing.
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            crate::TypeInner::Array { base, .. } => {
                // Only the element type is written here; the array length is
                // not part of this spelling.
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                write!(out, "{sub}")
            }
            // Structs always need an alias (see `needs_alias`), so they take
            // the by-name path above unless `first_time` is set; callers must
            // not request a first-time spelling of a struct through this impl.
            crate::TypeInner::Struct { .. } => unreachable!(),
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                // Pick the texture family, multisample suffix, component
                // scalar and access qualifier for each image class.
                let (texture_str, msaa_str, scalar, access) = match class {
                    crate::ImageClass::Sampled { kind, multi } => {
                        // Multisampled textures are declared `read`; others `sample`.
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Storage { format, .. } => {
                        // Access comes from `self.access`, not from the image
                        // class; an empty access set is treated as an invalid module.
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    // External textures use a dedicated wrapper struct instead
                    // of a `metal::texture*` type.
                    crate::ImageClass::External => {
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{}", super::ray::metal_intersector_ty())
            }
            // Binding arrays become a `constant` pointer to the argument-buffer
            // wrapper, parameterized by the element type's name.
            crate::TypeInner::BindingArray { base, .. } => {
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };

                write!(
                    out,
                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                )
            }
        }
    }
}
390
/// A global variable together with the context needed to write its typed
/// declaration (`<type> <name>`) via `try_fmt`.
pub(super) struct TypedGlobalVariable<'a> {
    pub module: &'a crate::Module,
    pub names: &'a FastHashMap<NameKey, String>,
    pub handle: Handle<crate::GlobalVariable>,
    /// How the global is used; if it is never written, spaces that need an
    /// access qualifier get `const`.
    pub usage: valid::GlobalUse,
    /// Whether the declaration should be written as a reference (`&`).
    pub reference: bool,
}

/// The two halves of a typed global declaration produced by `to_parts`.
struct TypedGlobalVariableParts {
    /// Full type spelling, including space, access and reference qualifiers.
    ty_name: String,
    var_name: String,
}
403
404impl TypedGlobalVariable<'_> {
405 fn to_parts(&self) -> Result<TypedGlobalVariableParts, Error> {
406 let var = &self.module.global_variables[self.handle];
407 let name = &self.names[&NameKey::GlobalVariable(self.handle)];
408
409 let storage_access = match var.space {
410 crate::AddressSpace::Storage { access } => access,
411 _ => match self.module.types[var.ty].inner {
412 crate::TypeInner::Image {
413 class: crate::ImageClass::Storage { access, .. },
414 ..
415 } => access,
416 crate::TypeInner::BindingArray { base, .. } => {
417 match self.module.types[base].inner {
418 crate::TypeInner::Image {
419 class: crate::ImageClass::Storage { access, .. },
420 ..
421 } => access,
422 _ => crate::StorageAccess::default(),
423 }
424 }
425 _ => crate::StorageAccess::default(),
426 },
427 };
428 let ty_name = TypeContext {
429 handle: var.ty,
430 gctx: self.module.to_ctx(),
431 names: self.names,
432 access: storage_access,
433 first_time: false,
434 };
435
436 let access = if var.space.needs_access_qualifier()
437 && !self.usage.intersects(valid::GlobalUse::WRITE)
438 {
439 "const"
440 } else {
441 ""
442 };
443 let (coherent, space, access, reference) = match (var.space.to_msl_name(), var.space) {
444 (Some(space), crate::AddressSpace::WorkGroup) => {
445 ("", space, access, if self.reference { "&" } else { "" })
446 }
447 (Some(space), _) if self.reference => {
448 let coherent = if var
449 .memory_decorations
450 .contains(crate::MemoryDecorations::COHERENT)
451 {
452 "coherent "
453 } else {
454 ""
455 };
456 (coherent, space, access, "&")
457 }
458 _ => ("", "", "", ""),
459 };
460
461 let ty = format!(
462 "{coherent}{space}{}{ty_name}{}{access}{reference}",
463 if space.is_empty() { "" } else { " " },
464 if access.is_empty() { "" } else { " " },
465 );
466
467 Ok(TypedGlobalVariableParts {
468 ty_name: ty,
469 var_name: name.clone(),
470 })
471 }
472 pub(super) fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
473 let parts = self.to_parts()?;
474
475 Ok(write!(out, "{} {}", parts.ty_name, parts.var_name)?)
476 }
477}
478
/// Key describing a generated helper function: the operation plus the types
/// it is specialized for.
///
/// Stored in `Writer::wrapped_functions` (a `FastHashSet<WrappedFunction>`),
/// presumably so that each helper is emitted only once per module (confirm at
/// the insertion sites).
#[derive(Eq, PartialEq, Hash)]
pub(super) enum WrappedFunction {
    /// Unary operator helper for an optional vector size and scalar type.
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Binary operator helper keyed by both operand types.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Math-function helper keyed by its argument type.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Conversion helper from `src_scalar` to `dst_scalar`.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    ImageLoad {
        class: crate::ImageClass,
    },
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    ImageQuerySize {
        class: crate::ImageClass,
    },
    /// Cooperative-matrix load helper; `space_name` is an MSL address-space name.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    RayQueryGetIntersection {
        committed: bool,
    },
}
526
/// Metal Shading Language writer: the output sink plus the naming and
/// bookkeeping state used while writing a module.
pub struct Writer<W> {
    /// Destination for the generated text.
    pub(super) out: W,
    /// Names assigned to module items, keyed by `NameKey`.
    pub(super) names: FastHashMap<NameKey, String>,
    pub(super) named_expressions: crate::NamedExpressions,
    /// Expressions that need baking — presumably into named temporaries (confirm).
    pub(super) need_bake_expressions: back::NeedBakeExpressions,
    /// Source of fresh identifiers (e.g. `namer.call("oob")`, `namer.call("loop_bound")`).
    pub(super) namer: proc::Namer,
    /// Helper functions already generated; see `WrappedFunction`.
    pub(super) wrapped_functions: FastHashSet<WrappedFunction>,
    /// Test-only set of raw pointers; presumably used to track recursion in
    /// `put_expression` (confirm).
    #[cfg(test)]
    pub(super) put_expression_stack_pointers: FastHashSet<*const ()>,
    /// Test-only set of raw pointers; presumably used to track recursion when
    /// writing blocks (confirm).
    #[cfg(test)]
    pub(super) put_block_stack_pointers: FastHashSet<*const ()>,
    /// `(type, u32)` pairs; presumably struct members that need padding (confirm).
    pub(super) struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
    /// Starts as `false` in `Writer::new`; its setter is not in this view.
    pub(super) needs_object_memory_barriers: bool,
}
544
545impl crate::Scalar {
546 pub(super) fn to_msl_name(self) -> &'static str {
547 use crate::ScalarKind as Sk;
548 match self {
549 Self {
550 kind: Sk::Float,
551 width: 4,
552 } => "float",
553 Self {
554 kind: Sk::Float,
555 width: 2,
556 } => "half",
557 Self {
558 kind: Sk::Sint,
559 width: 2,
560 } => "short",
561 Self {
562 kind: Sk::Uint,
563 width: 2,
564 } => "ushort",
565 Self {
566 kind: Sk::Sint,
567 width: 4,
568 } => "int",
569 Self {
570 kind: Sk::Uint,
571 width: 4,
572 } => "uint",
573 Self {
574 kind: Sk::Sint,
575 width: 8,
576 } => "long",
577 Self {
578 kind: Sk::Uint,
579 width: 8,
580 } => "ulong",
581 Self {
582 kind: Sk::Bool,
583 width: _,
584 } => "bool",
585 Self {
586 kind: Sk::AbstractInt | Sk::AbstractFloat,
587 width: _,
588 } => unreachable!("Found Abstract scalar kind"),
589 _ => unreachable!("Unsupported scalar kind: {:?}", self),
590 }
591 }
592}
593
/// Returns the list separator `","` when one is needed, otherwise `""`.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
601
602fn should_pack_struct_member(
603 members: &[crate::StructMember],
604 span: u32,
605 index: usize,
606 module: &crate::Module,
607) -> Option<crate::Scalar> {
608 let member = &members[index];
609
610 let ty_inner = &module.types[member.ty].inner;
611 let last_offset = member.offset + ty_inner.size(module.to_ctx());
612 let next_offset = match members.get(index + 1) {
613 Some(next) => next.offset,
614 None => span,
615 };
616 let is_tight = next_offset == last_offset;
617
618 match *ty_inner {
619 crate::TypeInner::Vector {
620 size: crate::VectorSize::Tri,
621 scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
622 } if is_tight => Some(scalar),
623 _ => None,
624 }
625}
626
627fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
628 match arena[ty].inner {
629 crate::TypeInner::Struct { ref members, .. } => {
630 if let Some(member) = members.last() {
631 if let crate::TypeInner::Array {
632 size: crate::ArraySize::Dynamic,
633 ..
634 } = arena[member.ty].inner
635 {
636 return true;
637 }
638 }
639 false
640 }
641 crate::TypeInner::Array {
642 size: crate::ArraySize::Dynamic,
643 ..
644 } => true,
645 _ => false,
646 }
647}
648
649impl crate::AddressSpace {
650 const fn needs_pass_through(&self) -> bool {
654 match *self {
655 Self::Uniform
656 | Self::Storage { .. }
657 | Self::Private
658 | Self::WorkGroup
659 | Self::Immediate
660 | Self::Handle
661 | Self::TaskPayload => true,
662 Self::Function => false,
663 Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
664 }
665 }
666
667 const fn needs_access_qualifier(&self) -> bool {
669 match *self {
670 Self::Storage { .. } => true,
675 Self::TaskPayload => true,
676 Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
677 Self::Private | Self::WorkGroup => false,
679 Self::Uniform | Self::Immediate => false,
681 Self::Handle | Self::Function => false,
683 }
684 }
685
686 const fn to_msl_name(self) -> Option<&'static str> {
687 match self {
688 Self::Handle => None,
689 Self::Uniform | Self::Immediate => Some("constant"),
690 Self::Storage { .. } => Some("device"),
691 Self::Private | Self::Function | Self::RayPayload => Some("thread"),
695 Self::WorkGroup => Some("threadgroup"),
696 Self::TaskPayload => Some("object_data"),
697 Self::IncomingRayPayload => Some("ray_data"),
698 }
699 }
700}
701
702impl crate::Type {
703 const fn needs_alias(&self) -> bool {
705 use crate::TypeInner as Ti;
706
707 match self.inner {
708 Ti::Scalar(_)
710 | Ti::Vector { .. }
711 | Ti::Matrix { .. }
712 | Ti::CooperativeMatrix { .. }
713 | Ti::Atomic(_)
714 | Ti::Pointer { .. }
715 | Ti::ValuePointer { .. } => self.name.is_some(),
716 Ti::Struct { .. } | Ti::Array { .. } => true,
718 Ti::Image { .. }
720 | Ti::Sampler { .. }
721 | Ti::AccelerationStructure { .. }
722 | Ti::RayQuery { .. }
723 | Ti::BindingArray { .. } => false,
724 }
725 }
726}
727
/// Identifies the function whose body is being written: a regular function or
/// an entry point. `NameKeyExt` uses it to pick between the `Function*` and
/// `EntryPoint*` name keys.
#[derive(Clone, Copy)]
enum FunctionOrigin {
    Handle(Handle<crate::Function>),
    EntryPoint(proc::EntryPointIndex),
}
733
734trait NameKeyExt {
735 fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
736 match origin {
737 FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
738 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
739 }
740 }
741
742 fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
747 match origin {
748 FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
749 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
750 }
751 }
752}
753
754impl NameKeyExt for NameKey {}
755
/// How an image access's mip level is written.
///
/// `Direct(expr)` writes the level expression itself. `Restricted(load)` writes
/// the `ClampedLod(load)` identifier; `put_image_load` selects it under the
/// `Restrict` bounds-check policy.
#[derive(Clone, Copy)]
enum LevelOfDetail {
    Direct(Handle<crate::Expression>),
    Restricted(Handle<crate::Expression>),
}

/// Components of a texel address used by the image load and bounds-check
/// writers.
struct TexelAddress {
    coordinate: Handle<crate::Expression>,
    /// Layer index for arrayed images.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index for multisampled images.
    sample: Option<Handle<crate::Expression>>,
    level: Option<LevelOfDetail>,
}
786
/// Read-only context for writing the expressions of one function or entry point.
pub(super) struct ExpressionContext<'a> {
    pub(super) function: &'a crate::Function,
    /// Which function this is; used to build per-function name keys.
    origin: FunctionOrigin,
    /// Validation results for `function`, including resolved expression types.
    pub(super) info: &'a valid::FunctionInfo,
    pub(super) module: &'a crate::Module,
    pub(super) mod_info: &'a valid::ModuleInfo,
    pub(super) pipeline_options: &'a PipelineOptions,
    /// Presumably the target MSL (major, minor) version — confirm.
    pub(super) lang_version: (u8, u8),
    /// Bounds-check policies (e.g. `image_load`) applied while writing.
    pub(super) policies: index::BoundsCheckPolicies,

    /// Expressions whose index use is guarded — presumably already
    /// bounds-checked elsewhere (confirm).
    pub(super) guarded_indices: HandleSet<crate::Expression>,
    /// When set, `gen_force_bounded_loop_statements` emits a loop counter guard.
    pub(super) force_loop_bounding: bool,
}
804
805impl<'a> ExpressionContext<'a> {
806 fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
807 self.info[handle].ty.inner_with(&self.module.types)
808 }
809
810 fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
817 let image_ty = self.resolve_type(image);
818 if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
819 class.is_mipmapped() && dim != crate::ImageDimension::D1
820 } else {
821 false
822 }
823 }
824
825 fn choose_bounds_check_policy(
826 &self,
827 pointer: Handle<crate::Expression>,
828 ) -> index::BoundsCheckPolicy {
829 self.policies
830 .choose_policy(pointer, &self.module.types, self.info)
831 }
832
833 fn access_needs_check(
835 &self,
836 base: Handle<crate::Expression>,
837 index: index::GuardedIndex,
838 ) -> Option<index::IndexableLength> {
839 index::access_needs_check(
840 base,
841 index,
842 self.module,
843 &self.function.expressions,
844 self.info,
845 )
846 }
847
848 fn bounds_check_iter(
850 &self,
851 chain: Handle<crate::Expression>,
852 ) -> impl Iterator<Item = BoundsCheck> + '_ {
853 index::bounds_check_iter(chain, self.module, self.function, self.info)
854 }
855
856 fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
858 index::oob_local_types(self.module, self.function, self.info, self.policies)
859 }
860
861 fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
862 match self.function.expressions[expr_handle] {
863 crate::Expression::AccessIndex { base, index } => {
864 let ty = match *self.resolve_type(base) {
865 crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
866 ref ty => ty,
867 };
868 match *ty {
869 crate::TypeInner::Struct {
870 ref members, span, ..
871 } => should_pack_struct_member(members, span, index as usize, self.module),
872 _ => None,
873 }
874 }
875 _ => None,
876 }
877 }
878}
879
/// Context for writing statements: the expression context plus statement-level
/// information.
pub(super) struct StatementContext<'a> {
    pub(super) expression: ExpressionContext<'a>,
    /// Optional struct name — presumably the entry-point result struct used
    /// when writing returns (confirm).
    pub(super) result_struct: Option<&'a str>,
}
884
885impl<W: Write> Writer<W> {
886 pub fn new(out: W) -> Self {
888 Writer {
889 out,
890 names: FastHashMap::default(),
891 named_expressions: Default::default(),
892 need_bake_expressions: Default::default(),
893 namer: proc::Namer::default(),
894 wrapped_functions: FastHashSet::default(),
895 #[cfg(test)]
896 put_expression_stack_pointers: Default::default(),
897 #[cfg(test)]
898 put_block_stack_pointers: Default::default(),
899 struct_member_pads: FastHashSet::default(),
900 needs_object_memory_barriers: false,
901 }
902 }
903
904 pub fn finish(self) -> W {
907 self.out
908 }
909
910 fn gen_force_bounded_loop_statements(
1017 &mut self,
1018 level: back::Level,
1019 context: &StatementContext,
1020 ) -> Option<(String, String)> {
1021 if !context.expression.force_loop_bounding {
1022 return None;
1023 }
1024
1025 let loop_bound_name = self.namer.call("loop_bound");
1026 let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
1029 let level = level.next();
1030 let break_and_inc = format!(
1031 "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
1032{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
1033 );
1034
1035 Some((decl, break_and_inc))
1036 }
1037
1038 fn put_call_parameters(
1039 &mut self,
1040 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1041 context: &ExpressionContext,
1042 ) -> BackendResult {
1043 self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1044 writer.put_expression(expr, context, true)
1045 })
1046 }
1047
1048 fn put_call_parameters_impl<C, E>(
1049 &mut self,
1050 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1051 ctx: &C,
1052 put_expression: E,
1053 ) -> BackendResult
1054 where
1055 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1056 {
1057 write!(self.out, "(")?;
1058 for (i, handle) in parameters.enumerate() {
1059 if i != 0 {
1060 write!(self.out, ", ")?;
1061 }
1062 put_expression(self, ctx, handle)?;
1063 }
1064 write!(self.out, ")")?;
1065 Ok(())
1066 }
1067
1068 fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1074 let oob_local_types = context.oob_local_types();
1075 for &ty in oob_local_types.iter() {
1076 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1077 self.names.insert(name_key, self.namer.call("oob"));
1078 }
1079
1080 for (name_key, ty, init) in context
1081 .function
1082 .local_variables
1083 .iter()
1084 .map(|(local_handle, local)| {
1085 let name_key = NameKey::local(context.origin, local_handle);
1086 (name_key, local.ty, local.init)
1087 })
1088 .chain(oob_local_types.iter().map(|&ty| {
1089 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1090 (name_key, ty, None)
1091 }))
1092 {
1093 let ty_name = TypeContext {
1094 handle: ty,
1095 gctx: context.module.to_ctx(),
1096 names: &self.names,
1097 access: crate::StorageAccess::empty(),
1098 first_time: false,
1099 };
1100 write!(
1101 self.out,
1102 "{}{} {}",
1103 back::INDENT,
1104 ty_name,
1105 self.names[&name_key]
1106 )?;
1107 match init {
1108 Some(value) => {
1109 write!(self.out, " = ")?;
1110 self.put_expression(value, context, true)?;
1111 }
1112 None => {
1113 write!(self.out, " = {{}}")?;
1114 }
1115 };
1116 writeln!(self.out, ";")?;
1117 }
1118 Ok(())
1119 }
1120
1121 fn put_level_of_detail(
1122 &mut self,
1123 level: LevelOfDetail,
1124 context: &ExpressionContext,
1125 ) -> BackendResult {
1126 match level {
1127 LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1128 LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1129 }
1130 Ok(())
1131 }
1132
1133 fn put_image_query(
1134 &mut self,
1135 image: Handle<crate::Expression>,
1136 query: &str,
1137 level: Option<LevelOfDetail>,
1138 context: &ExpressionContext,
1139 ) -> BackendResult {
1140 self.put_expression(image, context, false)?;
1141 write!(self.out, ".get_{query}(")?;
1142 if let Some(level) = level {
1143 self.put_level_of_detail(level, context)?;
1144 }
1145 write!(self.out, ")")?;
1146 Ok(())
1147 }
1148
1149 fn put_image_size_query(
1150 &mut self,
1151 image: Handle<crate::Expression>,
1152 level: Option<LevelOfDetail>,
1153 kind: crate::ScalarKind,
1154 context: &ExpressionContext,
1155 ) -> BackendResult {
1156 if let crate::TypeInner::Image {
1157 class: crate::ImageClass::External,
1158 ..
1159 } = *context.resolve_type(image)
1160 {
1161 write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1162 self.put_expression(image, context, true)?;
1163 write!(self.out, ")")?;
1164 return Ok(());
1165 }
1166
1167 let dim = match *context.resolve_type(image) {
1170 crate::TypeInner::Image { dim, .. } => dim,
1171 ref other => unreachable!("Unexpected type {:?}", other),
1172 };
1173 let scalar = crate::Scalar { kind, width: 4 };
1174 let coordinate_type = scalar.to_msl_name();
1175 match dim {
1176 crate::ImageDimension::D1 => {
1177 if kind == crate::ScalarKind::Uint {
1181 self.put_image_query(image, "width", None, context)?;
1183 } else {
1184 write!(self.out, "int(")?;
1186 self.put_image_query(image, "width", None, context)?;
1187 write!(self.out, ")")?;
1188 }
1189 }
1190 crate::ImageDimension::D2 => {
1191 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1192 self.put_image_query(image, "width", level, context)?;
1193 write!(self.out, ", ")?;
1194 self.put_image_query(image, "height", level, context)?;
1195 write!(self.out, ")")?;
1196 }
1197 crate::ImageDimension::D3 => {
1198 write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1199 self.put_image_query(image, "width", level, context)?;
1200 write!(self.out, ", ")?;
1201 self.put_image_query(image, "height", level, context)?;
1202 write!(self.out, ", ")?;
1203 self.put_image_query(image, "depth", level, context)?;
1204 write!(self.out, ")")?;
1205 }
1206 crate::ImageDimension::Cube => {
1207 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1208 self.put_image_query(image, "width", level, context)?;
1209 write!(self.out, ")")?;
1210 }
1211 }
1212 Ok(())
1213 }
1214
1215 fn put_cast_to_uint_scalar_or_vector(
1216 &mut self,
1217 expr: Handle<crate::Expression>,
1218 context: &ExpressionContext,
1219 ) -> BackendResult {
1220 match *context.resolve_type(expr) {
1222 crate::TypeInner::Scalar(_) => {
1223 put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1224 }
1225 crate::TypeInner::Vector { size, .. } => {
1226 put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1227 }
1228 _ => {
1229 return Err(Error::GenericValidation(
1230 "Invalid type for image coordinate".into(),
1231 ))
1232 }
1233 };
1234
1235 write!(self.out, "(")?;
1236 self.put_expression(expr, context, true)?;
1237 write!(self.out, ")")?;
1238 Ok(())
1239 }
1240
1241 fn put_image_sample_level(
1242 &mut self,
1243 image: Handle<crate::Expression>,
1244 level: crate::SampleLevel,
1245 context: &ExpressionContext,
1246 ) -> BackendResult {
1247 let has_levels = context.image_needs_lod(image);
1248 match level {
1249 crate::SampleLevel::Auto => {}
1250 crate::SampleLevel::Zero => {
1251 }
1253 _ if !has_levels => {
1254 log::warn!("1D image can't be sampled with level {level:?}");
1255 }
1256 crate::SampleLevel::Exact(h) => {
1257 write!(self.out, ", {NAMESPACE}::level(")?;
1258 self.put_expression(h, context, true)?;
1259 write!(self.out, ")")?;
1260 }
1261 crate::SampleLevel::Bias(h) => {
1262 write!(self.out, ", {NAMESPACE}::bias(")?;
1263 self.put_expression(h, context, true)?;
1264 write!(self.out, ")")?;
1265 }
1266 crate::SampleLevel::Gradient { x, y } => {
1267 write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1268 self.put_expression(x, context, true)?;
1269 write!(self.out, ", ")?;
1270 self.put_expression(y, context, true)?;
1271 write!(self.out, ")")?;
1272 }
1273 }
1274 Ok(())
1275 }
1276
1277 fn put_image_coordinate_limits(
1278 &mut self,
1279 image: Handle<crate::Expression>,
1280 level: Option<LevelOfDetail>,
1281 context: &ExpressionContext,
1282 ) -> BackendResult {
1283 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1284 write!(self.out, " - 1")?;
1285 Ok(())
1286 }
1287
1288 fn put_restricted_scalar_image_index(
1306 &mut self,
1307 image: Handle<crate::Expression>,
1308 index: Handle<crate::Expression>,
1309 limit_method: &str,
1310 context: &ExpressionContext,
1311 ) -> BackendResult {
1312 write!(self.out, "{NAMESPACE}::min(uint(")?;
1313 self.put_expression(index, context, true)?;
1314 write!(self.out, "), ")?;
1315 self.put_expression(image, context, false)?;
1316 write!(self.out, ".{limit_method}() - 1)")?;
1317 Ok(())
1318 }
1319
1320 fn put_restricted_texel_address(
1321 &mut self,
1322 image: Handle<crate::Expression>,
1323 address: &TexelAddress,
1324 context: &ExpressionContext,
1325 ) -> BackendResult {
1326 write!(self.out, "{NAMESPACE}::min(")?;
1328 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1329 write!(self.out, ", ")?;
1330 self.put_image_coordinate_limits(image, address.level, context)?;
1331 write!(self.out, ")")?;
1332
1333 if let Some(array_index) = address.array_index {
1335 write!(self.out, ", ")?;
1336 self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1337 }
1338
1339 if let Some(sample) = address.sample {
1341 write!(self.out, ", ")?;
1342 self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1343 }
1344
1345 if let Some(level) = address.level {
1348 write!(self.out, ", ")?;
1349 self.put_level_of_detail(level, context)?;
1350 }
1351
1352 Ok(())
1353 }
1354
1355 fn put_image_access_bounds_check(
1357 &mut self,
1358 image: Handle<crate::Expression>,
1359 address: &TexelAddress,
1360 context: &ExpressionContext,
1361 ) -> BackendResult {
1362 let mut conjunction = "";
1363
1364 let level = if let Some(level) = address.level {
1367 write!(self.out, "uint(")?;
1368 self.put_level_of_detail(level, context)?;
1369 write!(self.out, ") < ")?;
1370 self.put_expression(image, context, true)?;
1371 write!(self.out, ".get_num_mip_levels()")?;
1372 conjunction = " && ";
1373 Some(level)
1374 } else {
1375 None
1376 };
1377
1378 if let Some(sample) = address.sample {
1380 write!(self.out, "uint(")?;
1381 self.put_expression(sample, context, true)?;
1382 write!(self.out, ") < ")?;
1383 self.put_expression(image, context, true)?;
1384 write!(self.out, ".get_num_samples()")?;
1385 conjunction = " && ";
1386 }
1387
1388 if let Some(array_index) = address.array_index {
1390 write!(self.out, "{conjunction}uint(")?;
1391 self.put_expression(array_index, context, true)?;
1392 write!(self.out, ") < ")?;
1393 self.put_expression(image, context, true)?;
1394 write!(self.out, ".get_array_size()")?;
1395 conjunction = " && ";
1396 }
1397
1398 let coord_is_vector = match *context.resolve_type(address.coordinate) {
1400 crate::TypeInner::Vector { .. } => true,
1401 _ => false,
1402 };
1403 write!(self.out, "{conjunction}")?;
1404 if coord_is_vector {
1405 write!(self.out, "{NAMESPACE}::all(")?;
1406 }
1407 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1408 write!(self.out, " < ")?;
1409 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1410 if coord_is_vector {
1411 write!(self.out, ")")?;
1412 }
1413
1414 Ok(())
1415 }
1416
1417 fn put_image_load(
1418 &mut self,
1419 load: Handle<crate::Expression>,
1420 image: Handle<crate::Expression>,
1421 mut address: TexelAddress,
1422 context: &ExpressionContext,
1423 ) -> BackendResult {
1424 if let crate::TypeInner::Image {
1425 class: crate::ImageClass::External,
1426 ..
1427 } = *context.resolve_type(image)
1428 {
1429 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1430 self.put_expression(image, context, true)?;
1431 write!(self.out, ", ")?;
1432 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1433 write!(self.out, ")")?;
1434 return Ok(());
1435 }
1436
1437 match context.policies.image_load {
1438 proc::BoundsCheckPolicy::Restrict => {
1439 if address.level.is_some() {
1442 address.level = if context.image_needs_lod(image) {
1443 Some(LevelOfDetail::Restricted(load))
1444 } else {
1445 None
1446 }
1447 }
1448
1449 self.put_expression(image, context, false)?;
1450 write!(self.out, ".read(")?;
1451 self.put_restricted_texel_address(image, &address, context)?;
1452 write!(self.out, ")")?;
1453 }
1454 proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1455 write!(self.out, "(")?;
1456 self.put_image_access_bounds_check(image, &address, context)?;
1457 write!(self.out, " ? ")?;
1458 self.put_unchecked_image_load(image, &address, context)?;
1459 write!(self.out, ": DefaultConstructible())")?;
1460 }
1461 proc::BoundsCheckPolicy::Unchecked => {
1462 self.put_unchecked_image_load(image, &address, context)?;
1463 }
1464 }
1465
1466 Ok(())
1467 }
1468
1469 fn put_unchecked_image_load(
1470 &mut self,
1471 image: Handle<crate::Expression>,
1472 address: &TexelAddress,
1473 context: &ExpressionContext,
1474 ) -> BackendResult {
1475 self.put_expression(image, context, false)?;
1476 write!(self.out, ".read(")?;
1477 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1479 if let Some(expr) = address.array_index {
1480 write!(self.out, ", ")?;
1481 self.put_expression(expr, context, true)?;
1482 }
1483 if let Some(sample) = address.sample {
1484 write!(self.out, ", ")?;
1485 self.put_expression(sample, context, true)?;
1486 }
1487 if let Some(level) = address.level {
1488 if context.image_needs_lod(image) {
1489 write!(self.out, ", ")?;
1490 self.put_level_of_detail(level, context)?;
1491 }
1492 }
1493 write!(self.out, ")")?;
1494
1495 Ok(())
1496 }
1497
1498 fn put_image_atomic(
1499 &mut self,
1500 level: back::Level,
1501 image: Handle<crate::Expression>,
1502 address: &TexelAddress,
1503 fun: crate::AtomicFunction,
1504 value: Handle<crate::Expression>,
1505 context: &StatementContext,
1506 ) -> BackendResult {
1507 write!(self.out, "{level}")?;
1508 self.put_expression(image, &context.expression, false)?;
1509 let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1510 fun.to_msl_64_bit()?
1511 } else {
1512 fun.to_msl()
1513 };
1514 write!(self.out, ".atomic_{op}(")?;
1515 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1517 write!(self.out, ", ")?;
1518 self.put_expression(value, &context.expression, true)?;
1519 writeln!(self.out, ");")?;
1520
1521 let value_ty = context.expression.resolve_type(value);
1527 let zero_value = match (value_ty.scalar_kind(), value_ty.scalar_width()) {
1528 (Some(crate::ScalarKind::Sint), _) => "int4(0)",
1529 (_, Some(8)) => "ulong4(0uL)",
1530 _ => "uint4(0u)",
1531 };
1532 let coord_ty = context.expression.resolve_type(address.coordinate);
1533 let x = if matches!(coord_ty, crate::TypeInner::Scalar(_)) {
1534 ""
1535 } else {
1536 ".x"
1537 };
1538 write!(self.out, "{level}if (")?;
1539 self.put_expression(address.coordinate, &context.expression, true)?;
1540 write!(self.out, "{x} == -99999) {{ ")?;
1541 self.put_expression(image, &context.expression, false)?;
1542 write!(self.out, ".write({zero_value}, ")?;
1543 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1544 if let Some(array_index) = address.array_index {
1545 write!(self.out, ", ")?;
1546 self.put_expression(array_index, &context.expression, true)?;
1547 }
1548 writeln!(self.out, "); }}")?;
1549
1550 Ok(())
1551 }
1552
1553 fn put_image_store(
1554 &mut self,
1555 level: back::Level,
1556 image: Handle<crate::Expression>,
1557 address: &TexelAddress,
1558 value: Handle<crate::Expression>,
1559 context: &StatementContext,
1560 ) -> BackendResult {
1561 write!(self.out, "{level}")?;
1562 self.put_expression(image, &context.expression, false)?;
1563 write!(self.out, ".write(")?;
1564 self.put_expression(value, &context.expression, true)?;
1565 write!(self.out, ", ")?;
1566 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1568 if let Some(expr) = address.array_index {
1569 write!(self.out, ", ")?;
1570 self.put_expression(expr, &context.expression, true)?;
1571 }
1572 writeln!(self.out, ");")?;
1573
1574 Ok(())
1575 }
1576
1577 fn put_dynamic_array_max_index(
1588 &mut self,
1589 handle: Handle<crate::GlobalVariable>,
1590 context: &ExpressionContext,
1591 ) -> BackendResult {
1592 let global = &context.module.global_variables[handle];
1593 let (offset, array_ty) = match context.module.types[global.ty].inner {
1594 crate::TypeInner::Struct { ref members, .. } => match members.last() {
1595 Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1596 None => return Err(Error::GenericValidation("Struct has no members".into())),
1597 },
1598 crate::TypeInner::Array {
1599 size: crate::ArraySize::Dynamic,
1600 ..
1601 } => (0, global.ty),
1602 ref ty => {
1603 return Err(Error::GenericValidation(format!(
1604 "Expected type with dynamic array, got {ty:?}"
1605 )))
1606 }
1607 };
1608
1609 let (size, stride) = match context.module.types[array_ty].inner {
1610 crate::TypeInner::Array { base, stride, .. } => (
1611 context.module.types[base]
1612 .inner
1613 .size(context.module.to_ctx()),
1614 stride,
1615 ),
1616 ref ty => {
1617 return Err(Error::GenericValidation(format!(
1618 "Expected array type, got {ty:?}"
1619 )))
1620 }
1621 };
1622
1623 write!(
1636 self.out,
1637 "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1638 member = ArraySizeMember(handle),
1639 offset = offset,
1640 size = size,
1641 stride = stride,
1642 )?;
1643 Ok(())
1644 }
1645
1646 fn put_dot_product<T: Copy>(
1651 &mut self,
1652 arg: T,
1653 arg1: T,
1654 size: usize,
1655 extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1656 ) -> BackendResult {
1657 write!(self.out, "(")?;
1660
1661 for index in 0..size {
1663 write!(self.out, " + ")?;
1666 extractor(self, arg, index)?;
1667 write!(self.out, " * ")?;
1668 extractor(self, arg1, index)?;
1669 }
1670
1671 write!(self.out, ")")?;
1672 Ok(())
1673 }
1674
1675 fn put_pack4x8(
1677 &mut self,
1678 arg: Handle<crate::Expression>,
1679 context: &ExpressionContext<'_>,
1680 was_signed: bool,
1681 clamp_bounds: Option<(&str, &str)>,
1682 ) -> Result<(), Error> {
1683 let write_arg = |this: &mut Self| -> BackendResult {
1684 if let Some((min, max)) = clamp_bounds {
1685 write!(this.out, "{NAMESPACE}::clamp(")?;
1687 this.put_expression(arg, context, true)?;
1688 write!(this.out, ", {min}, {max})")?;
1689 } else {
1690 this.put_expression(arg, context, true)?;
1691 }
1692 Ok(())
1693 };
1694
1695 if context.lang_version >= (2, 1) {
1696 let packed_type = if was_signed {
1697 "packed_char4"
1698 } else {
1699 "packed_uchar4"
1700 };
1701 write!(self.out, "as_type<uint>({packed_type}(")?;
1703 write_arg(self)?;
1704 write!(self.out, "))")?;
1705 } else {
1706 if was_signed {
1708 write!(self.out, "uint(")?;
1709 }
1710 write!(self.out, "(")?;
1711 write_arg(self)?;
1712 write!(self.out, "[0] & 0xFF) | ((")?;
1713 write_arg(self)?;
1714 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1715 write_arg(self)?;
1716 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1717 write_arg(self)?;
1718 write!(self.out, "[3] & 0xFF) << 24)")?;
1719 if was_signed {
1720 write!(self.out, ")")?;
1721 }
1722 }
1723
1724 Ok(())
1725 }
1726
1727 fn put_isign(
1730 &mut self,
1731 arg: Handle<crate::Expression>,
1732 context: &ExpressionContext,
1733 ) -> BackendResult {
1734 write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1735 let scalar = context
1736 .resolve_type(arg)
1737 .scalar()
1738 .expect("put_isign should only be called for args which have an integer scalar type")
1739 .to_msl_name();
1740 match context.resolve_type(arg) {
1741 &crate::TypeInner::Vector { size, .. } => {
1742 let size = common::vector_size_str(size);
1743 write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1744 }
1745 _ => {
1746 write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1747 }
1748 }
1749 write!(self.out, ", (")?;
1750 self.put_expression(arg, context, true)?;
1751 write!(self.out, " > 0)), {scalar}(0), (")?;
1752 self.put_expression(arg, context, true)?;
1753 write!(self.out, " == 0))")?;
1754 Ok(())
1755 }
1756
1757 pub(super) fn put_const_expression(
1758 &mut self,
1759 expr_handle: Handle<crate::Expression>,
1760 module: &crate::Module,
1761 mod_info: &valid::ModuleInfo,
1762 arena: &crate::Arena<crate::Expression>,
1763 ) -> BackendResult {
1764 self.put_possibly_const_expression(
1765 expr_handle,
1766 arena,
1767 module,
1768 mod_info,
1769 &(module, mod_info),
1770 |&(_, mod_info), expr| &mod_info[expr],
1771 |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1772 )
1773 }
1774
1775 fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1776 match literal {
1777 crate::Literal::F64(_) => {
1778 return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1779 }
1780 crate::Literal::F16(value) => {
1781 if value.is_infinite() {
1782 let sign = if value.is_sign_negative() { "-" } else { "" };
1783 write!(self.out, "{sign}INFINITY")?;
1784 } else if value.is_nan() {
1785 write!(self.out, "NAN")?;
1786 } else {
1787 let suffix = if value.fract() == f16::from_f32(0.0) {
1788 ".0h"
1789 } else {
1790 "h"
1791 };
1792 write!(self.out, "{value}{suffix}")?;
1793 }
1794 }
1795 crate::Literal::F32(value) => {
1796 if value.is_infinite() {
1797 let sign = if value.is_sign_negative() { "-" } else { "" };
1798 write!(self.out, "{sign}INFINITY")?;
1799 } else if value.is_nan() {
1800 write!(self.out, "NAN")?;
1801 } else {
1802 let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1803 write!(self.out, "{value}{suffix}")?;
1804 }
1805 }
1806 crate::Literal::U16(value) => {
1807 write!(self.out, "static_cast<ushort>({value})")?;
1808 }
1809 crate::Literal::I16(value) => {
1810 write!(self.out, "static_cast<short>({value})")?;
1811 }
1812 crate::Literal::U32(value) => {
1813 write!(self.out, "{value}u")?;
1814 }
1815 crate::Literal::I32(value) => {
1816 if value == i32::MIN {
1821 write!(self.out, "({} - 1)", value + 1)?;
1822 } else {
1823 write!(self.out, "{value}")?;
1824 }
1825 }
1826 crate::Literal::U64(value) => {
1827 write!(self.out, "{value}uL")?;
1828 }
1829 crate::Literal::I64(value) => {
1830 if value == i64::MIN {
1837 write!(self.out, "({}L - 1L)", value + 1)?;
1838 } else {
1839 write!(self.out, "{value}L")?;
1840 }
1841 }
1842 crate::Literal::Bool(value) => {
1843 write!(self.out, "{value}")?;
1844 }
1845 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1846 return Err(Error::GenericValidation(
1847 "Unsupported abstract literal".into(),
1848 ));
1849 }
1850 }
1851 Ok(())
1852 }
1853
1854 #[allow(clippy::too_many_arguments)]
1855 fn put_possibly_const_expression<C, I, E>(
1856 &mut self,
1857 expr_handle: Handle<crate::Expression>,
1858 expressions: &crate::Arena<crate::Expression>,
1859 module: &crate::Module,
1860 mod_info: &valid::ModuleInfo,
1861 ctx: &C,
1862 get_expr_ty: I,
1863 put_expression: E,
1864 ) -> BackendResult
1865 where
1866 I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1867 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1868 {
1869 match expressions[expr_handle] {
1870 crate::Expression::Literal(literal) => {
1871 self.put_literal(literal)?;
1872 }
1873 crate::Expression::Constant(handle) => {
1874 let constant = &module.constants[handle];
1875 if constant.name.is_some() {
1876 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1877 } else {
1878 self.put_const_expression(
1879 constant.init,
1880 module,
1881 mod_info,
1882 &module.global_expressions,
1883 )?;
1884 }
1885 }
1886 crate::Expression::ZeroValue(ty) => {
1887 let ty_name = TypeContext {
1888 handle: ty,
1889 gctx: module.to_ctx(),
1890 names: &self.names,
1891 access: crate::StorageAccess::empty(),
1892 first_time: false,
1893 };
1894 write!(self.out, "{ty_name} {{}}")?;
1895 }
1896 crate::Expression::Compose { ty, ref components } => {
1897 let ty_name = TypeContext {
1898 handle: ty,
1899 gctx: module.to_ctx(),
1900 names: &self.names,
1901 access: crate::StorageAccess::empty(),
1902 first_time: false,
1903 };
1904 write!(self.out, "{ty_name}")?;
1905 match module.types[ty].inner {
1906 crate::TypeInner::Scalar(_)
1907 | crate::TypeInner::Vector { .. }
1908 | crate::TypeInner::Matrix { .. } => {
1909 self.put_call_parameters_impl(
1910 components.iter().copied(),
1911 ctx,
1912 put_expression,
1913 )?;
1914 }
1915 crate::TypeInner::Array { .. } => {
1916 write!(self.out, " {{{{")?;
1919 for (index, &component) in components.iter().enumerate() {
1920 if index != 0 {
1921 write!(self.out, ", ")?;
1922 }
1923 put_expression(self, ctx, component)?;
1924 }
1925 write!(self.out, "}}}}")?;
1926 }
1927 crate::TypeInner::Struct { .. } => {
1928 write!(self.out, " {{")?;
1929 for (index, &component) in components.iter().enumerate() {
1930 if index != 0 {
1931 write!(self.out, ", ")?;
1932 }
1933 if self.struct_member_pads.contains(&(ty, index as u32)) {
1935 write!(self.out, "{{}}, ")?;
1936 }
1937 put_expression(self, ctx, component)?;
1938 }
1939 write!(self.out, "}}")?;
1940 }
1941 _ => return Err(Error::UnsupportedCompose(ty)),
1942 }
1943 }
1944 crate::Expression::Splat { size, value } => {
1945 let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1946 crate::TypeInner::Scalar(scalar) => scalar,
1947 ref ty => {
1948 return Err(Error::GenericValidation(format!(
1949 "Expected splat value type must be a scalar, got {ty:?}",
1950 )))
1951 }
1952 };
1953 put_numeric_type(&mut self.out, scalar, &[size])?;
1954 write!(self.out, "(")?;
1955 put_expression(self, ctx, value)?;
1956 write!(self.out, ")")?;
1957 }
1958 _ => {
1959 return Err(Error::Override);
1960 }
1961 }
1962
1963 Ok(())
1964 }
1965
1966 pub(super) fn put_expression(
1978 &mut self,
1979 expr_handle: Handle<crate::Expression>,
1980 context: &ExpressionContext,
1981 is_scoped: bool,
1982 ) -> BackendResult {
1983 #[cfg(test)]
1985 self.put_expression_stack_pointers
1986 .insert(ptr::from_ref(&expr_handle).cast());
1987
1988 if let Some(name) = self.named_expressions.get(&expr_handle) {
1989 write!(self.out, "{name}")?;
1990 return Ok(());
1991 }
1992
1993 let expression = &context.function.expressions[expr_handle];
1994 match *expression {
1995 crate::Expression::Literal(_)
1996 | crate::Expression::Constant(_)
1997 | crate::Expression::ZeroValue(_)
1998 | crate::Expression::Compose { .. }
1999 | crate::Expression::Splat { .. } => {
2000 self.put_possibly_const_expression(
2001 expr_handle,
2002 &context.function.expressions,
2003 context.module,
2004 context.mod_info,
2005 context,
2006 |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
2007 |writer, context, expr| writer.put_expression(expr, context, true),
2008 )?;
2009 }
2010 crate::Expression::Override(_) => return Err(Error::Override),
2011 crate::Expression::Access { base, .. }
2012 | crate::Expression::AccessIndex { base, .. } => {
2013 let policy = context.choose_bounds_check_policy(base);
2019 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
2020 && self.put_bounds_checks(
2021 expr_handle,
2022 context,
2023 back::Level(0),
2024 if is_scoped { "" } else { "(" },
2025 )?
2026 {
2027 write!(self.out, " ? ")?;
2028 self.put_access_chain(expr_handle, policy, context)?;
2029 write!(self.out, " : ")?;
2030
2031 if context.resolve_type(base).pointer_space().is_some() {
2032 let result_ty = context.info[expr_handle]
2036 .ty
2037 .inner_with(&context.module.types)
2038 .pointer_base_type();
2039 let result_ty_handle = match result_ty {
2040 Some(TypeResolution::Handle(handle)) => handle,
2041 Some(TypeResolution::Value(_)) => {
2042 unreachable!(
2050 "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
2051 );
2052 }
2053 None => {
2054 unreachable!(
2055 "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
2056 )
2057 }
2058 };
2059 let name_key =
2060 NameKey::oob_local_for_type(context.origin, result_ty_handle);
2061 self.out.write_str(&self.names[&name_key])?;
2062 } else {
2063 write!(self.out, "DefaultConstructible()")?;
2064 }
2065
2066 if !is_scoped {
2067 write!(self.out, ")")?;
2068 }
2069 } else {
2070 self.put_access_chain(expr_handle, policy, context)?;
2071 }
2072 }
2073 crate::Expression::Swizzle {
2074 size,
2075 vector,
2076 pattern,
2077 } => {
2078 self.put_wrapped_expression_for_packed_vec3_access(
2079 vector,
2080 context,
2081 false,
2082 &Self::put_expression,
2083 )?;
2084 write!(self.out, ".")?;
2085 for &sc in pattern[..size as usize].iter() {
2086 write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2087 }
2088 }
2089 crate::Expression::FunctionArgument(index) => {
2090 let name_key = match context.origin {
2091 FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2092 FunctionOrigin::EntryPoint(ep_index) => {
2093 NameKey::EntryPointArgument(ep_index, index)
2094 }
2095 };
2096 let name = &self.names[&name_key];
2097 write!(self.out, "{name}")?;
2098 }
2099 crate::Expression::GlobalVariable(handle) => {
2100 let name = &self.names[&NameKey::GlobalVariable(handle)];
2101 write!(self.out, "{name}")?;
2102 }
2103 crate::Expression::LocalVariable(handle) => {
2104 let name_key = NameKey::local(context.origin, handle);
2105 let name = &self.names[&name_key];
2106 write!(self.out, "{name}")?;
2107 }
2108 crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2109 crate::Expression::ImageSample {
2110 coordinate,
2111 image,
2112 sampler,
2113 clamp_to_edge: true,
2114 gather: None,
2115 array_index: None,
2116 offset: None,
2117 level: crate::SampleLevel::Zero,
2118 depth_ref: None,
2119 } => {
2120 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2121 self.put_expression(image, context, true)?;
2122 write!(self.out, ", ")?;
2123 self.put_expression(sampler, context, true)?;
2124 write!(self.out, ", ")?;
2125 self.put_expression(coordinate, context, true)?;
2126 write!(self.out, ")")?;
2127 }
2128 crate::Expression::ImageSample {
2129 image,
2130 sampler,
2131 gather,
2132 coordinate,
2133 array_index,
2134 offset,
2135 level,
2136 depth_ref,
2137 clamp_to_edge,
2138 } => {
2139 if clamp_to_edge {
2140 return Err(Error::GenericValidation(
2141 "ImageSample::clamp_to_edge should have been validated out".to_string(),
2142 ));
2143 }
2144
2145 let main_op = match gather {
2146 Some(_) => "gather",
2147 None => "sample",
2148 };
2149 let comparison_op = match depth_ref {
2150 Some(_) => "_compare",
2151 None => "",
2152 };
2153 self.put_expression(image, context, false)?;
2154 write!(self.out, ".{main_op}{comparison_op}(")?;
2155 self.put_expression(sampler, context, true)?;
2156 write!(self.out, ", ")?;
2157 self.put_expression(coordinate, context, true)?;
2158 if let Some(expr) = array_index {
2159 write!(self.out, ", ")?;
2160 self.put_expression(expr, context, true)?;
2161 }
2162 if let Some(dref) = depth_ref {
2163 write!(self.out, ", ")?;
2164 self.put_expression(dref, context, true)?;
2165 }
2166
2167 self.put_image_sample_level(image, level, context)?;
2168
2169 if let Some(offset) = offset {
2170 write!(self.out, ", ")?;
2171 self.put_expression(offset, context, true)?;
2172 }
2173
2174 match gather {
2175 None | Some(crate::SwizzleComponent::X) => {}
2176 Some(component) => {
2177 let is_cube_map = match *context.resolve_type(image) {
2178 crate::TypeInner::Image {
2179 dim: crate::ImageDimension::Cube,
2180 ..
2181 } => true,
2182 _ => false,
2183 };
2184 if offset.is_none() && !is_cube_map {
2187 write!(self.out, ", {NAMESPACE}::int2(0)")?;
2188 }
2189 let letter = back::COMPONENTS[component as usize];
2190 write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2191 }
2192 }
2193 write!(self.out, ")")?;
2194 }
2195 crate::Expression::ImageLoad {
2196 image,
2197 coordinate,
2198 array_index,
2199 sample,
2200 level,
2201 } => {
2202 let address = TexelAddress {
2203 coordinate,
2204 array_index,
2205 sample,
2206 level: level.map(LevelOfDetail::Direct),
2207 };
2208 self.put_image_load(expr_handle, image, address, context)?;
2209 }
2210 crate::Expression::ImageQuery { image, query } => match query {
2213 crate::ImageQuery::Size { level } => {
2214 self.put_image_size_query(
2215 image,
2216 level.map(LevelOfDetail::Direct),
2217 crate::ScalarKind::Uint,
2218 context,
2219 )?;
2220 }
2221 crate::ImageQuery::NumLevels => {
2222 self.put_expression(image, context, false)?;
2223 write!(self.out, ".get_num_mip_levels()")?;
2224 }
2225 crate::ImageQuery::NumLayers => {
2226 self.put_expression(image, context, false)?;
2227 write!(self.out, ".get_array_size()")?;
2228 }
2229 crate::ImageQuery::NumSamples => {
2230 self.put_expression(image, context, false)?;
2231 write!(self.out, ".get_num_samples()")?;
2232 }
2233 },
2234 crate::Expression::Unary { op, expr } => {
2235 let op_str = match op {
2236 crate::UnaryOperator::Negate => {
2237 match context.resolve_type(expr).scalar_kind() {
2238 Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2239 _ => "-",
2240 }
2241 }
2242 crate::UnaryOperator::LogicalNot => "!",
2243 crate::UnaryOperator::BitwiseNot => "~",
2244 };
2245 write!(self.out, "{op_str}(")?;
2246 self.put_expression(expr, context, false)?;
2247 write!(self.out, ")")?;
2248 }
2249 crate::Expression::Binary { op, left, right } => {
2250 let kind = context
2251 .resolve_type(left)
2252 .scalar_kind()
2253 .ok_or(Error::UnsupportedBinaryOp(op))?;
2254
2255 if op == crate::BinaryOperator::Divide
2256 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2257 {
2258 write!(self.out, "{DIV_FUNCTION}(")?;
2259 self.put_expression(left, context, true)?;
2260 write!(self.out, ", ")?;
2261 self.put_expression(right, context, true)?;
2262 write!(self.out, ")")?;
2263 } else if op == crate::BinaryOperator::Modulo
2264 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2265 {
2266 write!(self.out, "{MOD_FUNCTION}(")?;
2267 self.put_expression(left, context, true)?;
2268 write!(self.out, ", ")?;
2269 self.put_expression(right, context, true)?;
2270 write!(self.out, ")")?;
2271 } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2272 write!(self.out, "{NAMESPACE}::fmod(")?;
2277 self.put_expression(left, context, true)?;
2278 write!(self.out, ", ")?;
2279 self.put_expression(right, context, true)?;
2280 write!(self.out, ")")?;
2281 } else if (op == crate::BinaryOperator::Add
2282 || op == crate::BinaryOperator::Subtract
2283 || op == crate::BinaryOperator::Multiply)
2284 && kind == crate::ScalarKind::Sint
2285 {
2286 let to_unsigned = |ty: &crate::TypeInner| match *ty {
2287 crate::TypeInner::Scalar(scalar) => {
2288 Ok(crate::TypeInner::Scalar(crate::Scalar {
2289 kind: crate::ScalarKind::Uint,
2290 ..scalar
2291 }))
2292 }
2293 crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2294 size,
2295 scalar: crate::Scalar {
2296 kind: crate::ScalarKind::Uint,
2297 ..scalar
2298 },
2299 }),
2300 _ => Err(Error::UnsupportedBitCast(ty.clone())),
2301 };
2302
2303 self.put_bitcasted_expression(
2308 context.resolve_type(expr_handle),
2309 expr_handle,
2310 context,
2311 &|writer, context, is_scoped| {
2312 writer.put_binop(
2313 op,
2314 left,
2315 right,
2316 context,
2317 is_scoped,
2318 &|writer, expr, context, _is_scoped| {
2319 writer.put_bitcasted_expression(
2320 &to_unsigned(context.resolve_type(expr))?,
2321 expr,
2322 context,
2323 &|writer, context, is_scoped| {
2324 writer.put_expression(expr, context, is_scoped)
2325 },
2326 )
2327 },
2328 )
2329 },
2330 )?;
2331 } else {
2332 self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2333 }
2334 }
2335 crate::Expression::Select {
2336 condition,
2337 accept,
2338 reject,
2339 } => match *context.resolve_type(condition) {
2340 crate::TypeInner::Scalar(crate::Scalar {
2341 kind: crate::ScalarKind::Bool,
2342 ..
2343 }) => {
2344 if !is_scoped {
2345 write!(self.out, "(")?;
2346 }
2347 self.put_expression(condition, context, false)?;
2348 write!(self.out, " ? ")?;
2349 self.put_expression(accept, context, false)?;
2350 write!(self.out, " : ")?;
2351 self.put_expression(reject, context, false)?;
2352 if !is_scoped {
2353 write!(self.out, ")")?;
2354 }
2355 }
2356 crate::TypeInner::Vector {
2357 scalar:
2358 crate::Scalar {
2359 kind: crate::ScalarKind::Bool,
2360 ..
2361 },
2362 ..
2363 } => {
2364 write!(self.out, "{NAMESPACE}::select(")?;
2365 self.put_expression(reject, context, true)?;
2366 write!(self.out, ", ")?;
2367 self.put_expression(accept, context, true)?;
2368 write!(self.out, ", ")?;
2369 self.put_expression(condition, context, true)?;
2370 write!(self.out, ")")?;
2371 }
2372 ref ty => {
2373 return Err(Error::GenericValidation(format!(
2374 "Expected select condition to be a non-bool type, got {ty:?}",
2375 )))
2376 }
2377 },
2378 crate::Expression::Derivative { axis, expr, .. } => {
2379 use crate::DerivativeAxis as Axis;
2380 let op = match axis {
2381 Axis::X => "dfdx",
2382 Axis::Y => "dfdy",
2383 Axis::Width => "fwidth",
2384 };
2385 write!(self.out, "{NAMESPACE}::{op}")?;
2386 self.put_call_parameters(iter::once(expr), context)?;
2387 }
2388 crate::Expression::Relational { fun, argument } => {
2389 let op = match fun {
2390 crate::RelationalFunction::Any => "any",
2391 crate::RelationalFunction::All => "all",
2392 crate::RelationalFunction::IsNan => "isnan",
2393 crate::RelationalFunction::IsInf => "isinf",
2394 };
2395 write!(self.out, "{NAMESPACE}::{op}")?;
2396 self.put_call_parameters(iter::once(argument), context)?;
2397 }
2398 crate::Expression::Math {
2399 fun,
2400 arg,
2401 arg1,
2402 arg2,
2403 arg3,
2404 } => {
2405 use crate::MathFunction as Mf;
2406
2407 let arg_type = context.resolve_type(arg);
2408 let scalar_argument = match arg_type {
2409 &crate::TypeInner::Scalar(_) => true,
2410 _ => false,
2411 };
2412
2413 let fun_name = match fun {
2414 Mf::Abs => "abs",
2416 Mf::Min => "min",
2417 Mf::Max => "max",
2418 Mf::Clamp => "clamp",
2419 Mf::Saturate => "saturate",
2420 Mf::Cos => "cos",
2422 Mf::Cosh => "cosh",
2423 Mf::Sin => "sin",
2424 Mf::Sinh => "sinh",
2425 Mf::Tan => "tan",
2426 Mf::Tanh => "tanh",
2427 Mf::Acos => "acos",
2428 Mf::Asin => "asin",
2429 Mf::Atan => "atan",
2430 Mf::Atan2 => "atan2",
2431 Mf::Asinh => "asinh",
2432 Mf::Acosh => "acosh",
2433 Mf::Atanh => "atanh",
2434 Mf::Radians => "",
2435 Mf::Degrees => "",
2436 Mf::Ceil => "ceil",
2438 Mf::Floor => "floor",
2439 Mf::Round => "rint",
2440 Mf::Fract => "fract",
2441 Mf::Trunc => "trunc",
2442 Mf::Modf => MODF_FUNCTION,
2443 Mf::Frexp => FREXP_FUNCTION,
2444 Mf::Ldexp => "ldexp",
2445 Mf::Exp => "exp",
2447 Mf::Exp2 => "exp2",
2448 Mf::Log => "log",
2449 Mf::Log2 => "log2",
2450 Mf::Pow => "pow",
2451 Mf::Dot => match *context.resolve_type(arg) {
2453 crate::TypeInner::Vector {
2454 scalar:
2455 crate::Scalar {
2456 kind: crate::ScalarKind::Float,
2458 ..
2459 },
2460 ..
2461 } => "dot",
2462 crate::TypeInner::Vector {
2463 size,
2464 scalar:
2465 scalar @ crate::Scalar {
2466 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2467 ..
2468 },
2469 } => {
2470 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2472 write!(self.out, "{fun_name}(")?;
2473 self.put_expression(arg, context, true)?;
2474 write!(self.out, ", ")?;
2475 self.put_expression(arg1.unwrap(), context, true)?;
2476 write!(self.out, ")")?;
2477 return Ok(());
2478 }
2479 _ => unreachable!(
2480 "Correct TypeInner for dot product should be already validated"
2481 ),
2482 },
2483 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2484 if context.lang_version >= (2, 1) {
2485 let packed_type = match fun {
2489 Mf::Dot4I8Packed => "packed_char4",
2490 Mf::Dot4U8Packed => "packed_uchar4",
2491 _ => unreachable!(),
2492 };
2493
2494 return self.put_dot_product(
2495 Reinterpreted::new(packed_type, arg),
2496 Reinterpreted::new(packed_type, arg1.unwrap()),
2497 4,
2498 |writer, arg, index| {
2499 write!(writer.out, "{arg}[{index}]")?;
2502 Ok(())
2503 },
2504 );
2505 } else {
2506 let conversion = match fun {
2510 Mf::Dot4I8Packed => "int",
2511 Mf::Dot4U8Packed => "",
2512 _ => unreachable!(),
2513 };
2514
2515 return self.put_dot_product(
2516 arg,
2517 arg1.unwrap(),
2518 4,
2519 |writer, arg, index| {
2520 write!(writer.out, "({conversion}(")?;
2521 writer.put_expression(arg, context, true)?;
2522 if index == 3 {
2523 write!(writer.out, ") >> 24)")?;
2524 } else {
2525 write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2526 }
2527 Ok(())
2528 },
2529 );
2530 }
2531 }
2532 Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2533 Mf::Cross => "cross",
2534 Mf::Distance => "distance",
2535 Mf::Length if scalar_argument => "abs",
2536 Mf::Length => "length",
2537 Mf::Normalize => "normalize",
2538 Mf::FaceForward => "faceforward",
2539 Mf::Reflect => "reflect",
2540 Mf::Refract => "refract",
2541 Mf::Sign => match arg_type.scalar_kind() {
2543 Some(crate::ScalarKind::Sint) => {
2544 return self.put_isign(arg, context);
2545 }
2546 _ => "sign",
2547 },
2548 Mf::Fma => "fma",
2549 Mf::Mix => "mix",
2550 Mf::Step => "step",
2551 Mf::SmoothStep => "smoothstep",
2552 Mf::Sqrt => "sqrt",
2553 Mf::InverseSqrt => "rsqrt",
2554 Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2555 Mf::Transpose => "transpose",
2556 Mf::Determinant => "determinant",
2557 Mf::QuantizeToF16 => "",
2558 Mf::CountTrailingZeros => "ctz",
2560 Mf::CountLeadingZeros => "clz",
2561 Mf::CountOneBits => "popcount",
2562 Mf::ReverseBits => "reverse_bits",
2563 Mf::ExtractBits => "",
2564 Mf::InsertBits => "",
2565 Mf::FirstTrailingBit => "",
2566 Mf::FirstLeadingBit => "",
2567 Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2569 Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2570 Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2571 Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2572 Mf::Pack2x16float => "",
2573 Mf::Pack4xI8 => "",
2574 Mf::Pack4xU8 => "",
2575 Mf::Pack4xI8Clamp => "",
2576 Mf::Pack4xU8Clamp => "",
2577 Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2579 Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2580 Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2581 Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2582 Mf::Unpack2x16float => "",
2583 Mf::Unpack4xI8 => "",
2584 Mf::Unpack4xU8 => "",
2585 };
2586
2587 match fun {
2588 Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2589 if context.lang_version < (1, 2) {
2598 return Err(Error::UnsupportedFunction(fun_name.to_string()));
2599 }
2600 }
2601 _ => {}
2602 }
2603
2604 match fun {
2605 Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2606 write!(self.out, "{ABS_FUNCTION}(")?;
2607 self.put_expression(arg, context, true)?;
2608 write!(self.out, ")")?;
2609 }
2610 Mf::Distance if scalar_argument => {
2611 write!(self.out, "{NAMESPACE}::abs(")?;
2612 self.put_expression(arg, context, false)?;
2613 write!(self.out, " - ")?;
2614 self.put_expression(arg1.unwrap(), context, false)?;
2615 write!(self.out, ")")?;
2616 }
2617 Mf::FirstTrailingBit => {
2618 let scalar = context.resolve_type(arg).scalar().unwrap();
2619 let constant = scalar.width * 8 + 1;
2620
2621 write!(self.out, "((({NAMESPACE}::ctz(")?;
2622 self.put_expression(arg, context, true)?;
2623 write!(self.out, ") + 1) % {constant}) - 1)")?;
2624 }
2625 Mf::FirstLeadingBit => {
2626 let inner = context.resolve_type(arg);
2627 let scalar = inner.scalar().unwrap();
2628 let constant = scalar.width * 8 - 1;
2629
2630 write!(
2631 self.out,
2632 "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2633 )?;
2634
2635 if scalar.kind == crate::ScalarKind::Sint {
2636 write!(self.out, "{NAMESPACE}::select(")?;
2637 self.put_expression(arg, context, true)?;
2638 write!(self.out, ", ~")?;
2639 self.put_expression(arg, context, true)?;
2640 write!(self.out, ", ")?;
2641 self.put_expression(arg, context, true)?;
2642 write!(self.out, " < 0)")?;
2643 } else {
2644 self.put_expression(arg, context, true)?;
2645 }
2646
2647 write!(self.out, "), ")?;
2648
2649 match *inner {
2651 crate::TypeInner::Vector { size, scalar } => {
2652 let size = common::vector_size_str(size);
2653 let name = scalar.to_msl_name();
2654 write!(self.out, "{name}{size}")?;
2655 }
2656 crate::TypeInner::Scalar(scalar) => {
2657 let name = scalar.to_msl_name();
2658 write!(self.out, "{name}")?;
2659 }
2660 _ => (),
2661 }
2662
2663 write!(self.out, "(-1), ")?;
2664 self.put_expression(arg, context, true)?;
2665 write!(self.out, " == 0")?;
2666 if scalar.kind == crate::ScalarKind::Sint {
2667 write!(self.out, " || ")?;
2668 self.put_expression(arg, context, true)?;
2669 write!(self.out, " == -1")?;
2670 }
2671 write!(self.out, ")")?;
2672 }
2673 Mf::Unpack2x16float => {
2674 write!(self.out, "float2(as_type<half2>(")?;
2675 self.put_expression(arg, context, false)?;
2676 write!(self.out, "))")?;
2677 }
2678 Mf::Pack2x16float => {
2679 write!(self.out, "as_type<uint>(half2(")?;
2680 self.put_expression(arg, context, false)?;
2681 write!(self.out, "))")?;
2682 }
2683 Mf::ExtractBits => {
2684 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2701
2702 write!(self.out, "{NAMESPACE}::extract_bits(")?;
2703 self.put_expression(arg, context, true)?;
2704 write!(self.out, ", {NAMESPACE}::min(")?;
2705 self.put_expression(arg1.unwrap(), context, true)?;
2706 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2707 self.put_expression(arg2.unwrap(), context, true)?;
2708 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2709 self.put_expression(arg1.unwrap(), context, true)?;
2710 write!(self.out, ", {scalar_bits}u)))")?;
2711 }
2712 Mf::InsertBits => {
2713 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2718
2719 write!(self.out, "{NAMESPACE}::insert_bits(")?;
2720 self.put_expression(arg, context, true)?;
2721 write!(self.out, ", ")?;
2722 self.put_expression(arg1.unwrap(), context, true)?;
2723 write!(self.out, ", {NAMESPACE}::min(")?;
2724 self.put_expression(arg2.unwrap(), context, true)?;
2725 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2726 self.put_expression(arg3.unwrap(), context, true)?;
2727 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2728 self.put_expression(arg2.unwrap(), context, true)?;
2729 write!(self.out, ", {scalar_bits}u)))")?;
2730 }
2731 Mf::Radians => {
2732 write!(self.out, "((")?;
2733 self.put_expression(arg, context, false)?;
2734 write!(self.out, ") * 0.017453292519943295474)")?;
2735 }
2736 Mf::Degrees => {
2737 write!(self.out, "((")?;
2738 self.put_expression(arg, context, false)?;
2739 write!(self.out, ") * 57.295779513082322865)")?;
2740 }
2741 Mf::Modf | Mf::Frexp => {
2742 write!(self.out, "{fun_name}")?;
2743 self.put_call_parameters(iter::once(arg), context)?;
2744 }
2745 Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2746 Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2747 Mf::Pack4xI8Clamp => {
2748 self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2749 }
2750 Mf::Pack4xU8Clamp => {
2751 self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2752 }
2753 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2754 let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2755 "u"
2756 } else {
2757 ""
2758 };
2759
2760 if context.lang_version >= (2, 1) {
2761 write!(
2763 self.out,
2764 "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2765 )?;
2766 self.put_expression(arg, context, true)?;
2767 write!(self.out, "))")?;
2768 } else {
2769 write!(self.out, "({sign_prefix}int4(")?;
2771 self.put_expression(arg, context, true)?;
2772 write!(self.out, ", ")?;
2773 self.put_expression(arg, context, true)?;
2774 write!(self.out, " >> 8, ")?;
2775 self.put_expression(arg, context, true)?;
2776 write!(self.out, " >> 16, ")?;
2777 self.put_expression(arg, context, true)?;
2778 write!(self.out, " >> 24) << 24 >> 24)")?;
2779 }
2780 }
2781 Mf::QuantizeToF16 => {
2782 match *context.resolve_type(arg) {
2783 crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2784 crate::TypeInner::Vector { size, .. } => write!(
2785 self.out,
2786 "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2787 size = common::vector_size_str(size),
2788 )?,
2789 _ => unreachable!(
2790 "Correct TypeInner for QuantizeToF16 should be already validated"
2791 ),
2792 };
2793
2794 self.put_expression(arg, context, true)?;
2795 write!(self.out, "))")?;
2796 }
2797 _ => {
2798 write!(self.out, "{NAMESPACE}::{fun_name}")?;
2799 self.put_call_parameters(
2800 iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2801 context,
2802 )?;
2803 }
2804 }
2805 }
2806 crate::Expression::As {
2807 expr,
2808 kind,
2809 convert,
2810 } => match *context.resolve_type(expr) {
2811 crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2812 if src.kind == crate::ScalarKind::Float
2813 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2814 && convert.is_some()
2815 {
2816 let fun_name = match (kind, convert) {
2820 (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2821 (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2822 (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2823 (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2824 _ => unreachable!(),
2825 };
2826 write!(self.out, "{fun_name}(")?;
2827 self.put_expression(expr, context, true)?;
2828 write!(self.out, ")")?;
2829 } else {
2830 let target_scalar = crate::Scalar {
2831 kind,
2832 width: convert.unwrap_or(src.width),
2833 };
2834 let op = match convert {
2835 Some(_) => "static_cast",
2836 None => "as_type",
2837 };
2838 write!(self.out, "{op}<")?;
2839 match *context.resolve_type(expr) {
2840 crate::TypeInner::Vector { size, .. } => {
2841 put_numeric_type(&mut self.out, target_scalar, &[size])?
2842 }
2843 _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2844 };
2845 write!(self.out, ">(")?;
2846 self.put_expression(expr, context, true)?;
2847 write!(self.out, ")")?;
2848 }
2849 }
2850 crate::TypeInner::Matrix {
2851 columns,
2852 rows,
2853 scalar,
2854 } => {
2855 let target_scalar = crate::Scalar {
2856 kind,
2857 width: convert.unwrap_or(scalar.width),
2858 };
2859 put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2860 write!(self.out, "(")?;
2861 self.put_expression(expr, context, true)?;
2862 write!(self.out, ")")?;
2863 }
2864 ref ty => {
2865 return Err(Error::GenericValidation(format!(
2866 "Unsupported type for As: {ty:?}"
2867 )))
2868 }
2869 },
2870 crate::Expression::CallResult(_)
2872 | crate::Expression::AtomicResult { .. }
2873 | crate::Expression::WorkGroupUniformLoadResult { .. }
2874 | crate::Expression::SubgroupBallotResult
2875 | crate::Expression::SubgroupOperationResult { .. }
2876 | crate::Expression::RayQueryProceedResult => {
2877 unreachable!()
2878 }
2879 crate::Expression::ArrayLength(expr) => {
2880 let global = match context.function.expressions[expr] {
2882 crate::Expression::AccessIndex { base, .. } => {
2883 match context.function.expressions[base] {
2884 crate::Expression::GlobalVariable(handle) => handle,
2885 ref ex => {
2886 return Err(Error::GenericValidation(format!(
2887 "Expected global variable in AccessIndex, got {ex:?}"
2888 )))
2889 }
2890 }
2891 }
2892 crate::Expression::GlobalVariable(handle) => handle,
2893 ref ex => {
2894 return Err(Error::GenericValidation(format!(
2895 "Unexpected expression in ArrayLength, got {ex:?}"
2896 )))
2897 }
2898 };
2899
2900 if !is_scoped {
2901 write!(self.out, "(")?;
2902 }
2903 write!(self.out, "1 + ")?;
2904 self.put_dynamic_array_max_index(global, context)?;
2905 if !is_scoped {
2906 write!(self.out, ")")?;
2907 }
2908 }
2909 crate::Expression::RayQueryVertexPositions { .. } => {
2910 unimplemented!()
2911 }
2912 crate::Expression::RayQueryGetIntersection { query, committed } => {
2913 if context.lang_version < (2, 4) {
2914 return Err(Error::UnsupportedRayTracing);
2915 }
2916
2917 write!(
2918 self.out,
2919 "{}_{committed}(",
2920 super::ray::INTERSECTION_FUNCTION_NAME
2921 )?;
2922 self.put_expression(query, context, true)?;
2923 write!(self.out, ")")?;
2924 }
2925 crate::Expression::CooperativeLoad { ref data, .. } => {
2926 if context.lang_version < (2, 3) {
2927 return Err(Error::UnsupportedCooperativeMatrix);
2928 }
2929 write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2930 write!(self.out, "&")?;
2931 self.put_access_chain(data.pointer, context.policies.index, context)?;
2932 write!(self.out, ", ")?;
2933 self.put_expression(data.stride, context, true)?;
2934 write!(self.out, ", {})", data.row_major)?;
2935 }
2936 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2937 if context.lang_version < (2, 3) {
2938 return Err(Error::UnsupportedCooperativeMatrix);
2939 }
2940 write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2941 self.put_expression(a, context, true)?;
2942 write!(self.out, ", ")?;
2943 self.put_expression(b, context, true)?;
2944 write!(self.out, ", ")?;
2945 self.put_expression(c, context, true)?;
2946 write!(self.out, ")")?;
2947 }
2948 }
2949 Ok(())
2950 }
2951
2952 fn put_binop<F>(
2955 &mut self,
2956 op: crate::BinaryOperator,
2957 left: Handle<crate::Expression>,
2958 right: Handle<crate::Expression>,
2959 context: &ExpressionContext,
2960 is_scoped: bool,
2961 put_expression: &F,
2962 ) -> BackendResult
2963 where
2964 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2965 {
2966 let op_str = back::binary_operation_str(op);
2967
2968 if !is_scoped {
2969 write!(self.out, "(")?;
2970 }
2971
2972 if op == crate::BinaryOperator::Multiply
2975 && matches!(
2976 context.resolve_type(right),
2977 &crate::TypeInner::Matrix { .. }
2978 )
2979 {
2980 self.put_wrapped_expression_for_packed_vec3_access(
2981 left,
2982 context,
2983 false,
2984 put_expression,
2985 )?;
2986 } else {
2987 put_expression(self, left, context, false)?;
2988 }
2989
2990 write!(self.out, " {op_str} ")?;
2991
2992 if op == crate::BinaryOperator::Multiply
2994 && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
2995 {
2996 self.put_wrapped_expression_for_packed_vec3_access(
2997 right,
2998 context,
2999 false,
3000 put_expression,
3001 )?;
3002 } else {
3003 put_expression(self, right, context, false)?;
3004 }
3005
3006 if !is_scoped {
3007 write!(self.out, ")")?;
3008 }
3009
3010 Ok(())
3011 }
3012
3013 fn put_wrapped_expression_for_packed_vec3_access<F>(
3015 &mut self,
3016 expr_handle: Handle<crate::Expression>,
3017 context: &ExpressionContext,
3018 is_scoped: bool,
3019 put_expression: &F,
3020 ) -> BackendResult
3021 where
3022 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
3023 {
3024 if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
3025 write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
3026 put_expression(self, expr_handle, context, is_scoped)?;
3027 write!(self.out, ")")?;
3028 } else {
3029 put_expression(self, expr_handle, context, is_scoped)?;
3030 }
3031 Ok(())
3032 }
3033
3034 fn put_bitcasted_expression<F>(
3037 &mut self,
3038 cast_to: &crate::TypeInner,
3039 inner_expr: Handle<crate::Expression>,
3040 context: &ExpressionContext,
3041 put_expression: &F,
3042 ) -> BackendResult
3043 where
3044 F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
3045 {
3046 let needs_truncation = match *cast_to {
3051 crate::TypeInner::Scalar(scalar) => scalar.width < 4,
3052 crate::TypeInner::Vector { scalar, .. } => scalar.width < 4,
3053 _ => false,
3054 };
3055
3056 write!(self.out, "as_type<")?;
3057 match *cast_to {
3058 crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
3059 crate::TypeInner::Vector { size, scalar } => {
3060 put_numeric_type(&mut self.out, scalar, &[size])?
3061 }
3062 _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
3063 };
3064 write!(self.out, ">(")?;
3065
3066 if needs_truncation {
3067 write!(self.out, "static_cast<")?;
3068 let unsigned_scalar = match *cast_to {
3070 crate::TypeInner::Scalar(scalar) => crate::Scalar {
3071 kind: crate::ScalarKind::Uint,
3072 ..scalar
3073 },
3074 crate::TypeInner::Vector { scalar, .. } => crate::Scalar {
3075 kind: crate::ScalarKind::Uint,
3076 ..scalar
3077 },
3078 _ => unreachable!(),
3079 };
3080 match *cast_to {
3081 crate::TypeInner::Scalar(_) => {
3082 put_numeric_type(&mut self.out, unsigned_scalar, &[])?
3083 }
3084 crate::TypeInner::Vector { size, .. } => {
3085 put_numeric_type(&mut self.out, unsigned_scalar, &[size])?
3086 }
3087 _ => unreachable!(),
3088 };
3089 write!(self.out, ">(")?;
3090 }
3091
3092 if let Some(scalar) = context.get_packed_vec_kind(inner_expr) {
3094 put_numeric_type(&mut self.out, scalar, &[crate::VectorSize::Tri])?;
3095 write!(self.out, "(")?;
3096 put_expression(self, context, true)?;
3097 write!(self.out, ")")?;
3098 } else {
3099 put_expression(self, context, true)?;
3100 }
3101
3102 if needs_truncation {
3103 write!(self.out, ")")?;
3104 }
3105
3106 write!(self.out, ")")?;
3107 Ok(())
3108 }
3109
3110 fn put_index(
3112 &mut self,
3113 index: index::GuardedIndex,
3114 context: &ExpressionContext,
3115 is_scoped: bool,
3116 ) -> BackendResult {
3117 match index {
3118 index::GuardedIndex::Expression(expr) => {
3119 self.put_expression(expr, context, is_scoped)?
3120 }
3121 index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3122 }
3123 Ok(())
3124 }
3125
3126 fn put_bounds_checks(
3156 &mut self,
3157 chain: Handle<crate::Expression>,
3158 context: &ExpressionContext,
3159 level: back::Level,
3160 prefix: &'static str,
3161 ) -> Result<bool, Error> {
3162 let mut check_written = false;
3163
3164 for item in context.bounds_check_iter(chain) {
3166 let BoundsCheck {
3167 base,
3168 index,
3169 length,
3170 } = item;
3171
3172 if check_written {
3173 write!(self.out, " && ")?;
3174 } else {
3175 write!(self.out, "{level}{prefix}")?;
3176 check_written = true;
3177 }
3178
3179 write!(self.out, "uint(")?;
3183 self.put_index(index, context, true)?;
3184 self.out.write_str(") < ")?;
3185 match length {
3186 index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3187 index::IndexableLength::Dynamic => {
3188 let global = context.function.originating_global(base).ok_or_else(|| {
3189 Error::GenericValidation("Could not find originating global".into())
3190 })?;
3191 write!(self.out, "1 + ")?;
3192 self.put_dynamic_array_max_index(global, context)?
3193 }
3194 }
3195 }
3196
3197 Ok(check_written)
3198 }
3199
3200 fn put_access_chain(
3220 &mut self,
3221 chain: Handle<crate::Expression>,
3222 policy: index::BoundsCheckPolicy,
3223 context: &ExpressionContext,
3224 ) -> BackendResult {
3225 match context.function.expressions[chain] {
3226 crate::Expression::Access { base, index } => {
3227 let mut base_ty = context.resolve_type(base);
3228
3229 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3231 base_ty = &context.module.types[base].inner;
3232 }
3233
3234 self.put_subscripted_access_chain(
3235 base,
3236 base_ty,
3237 index::GuardedIndex::Expression(index),
3238 policy,
3239 context,
3240 )?;
3241 }
3242 crate::Expression::AccessIndex { base, index } => {
3243 let base_resolution = &context.info[base].ty;
3244 let mut base_ty = base_resolution.inner_with(&context.module.types);
3245 let mut base_ty_handle = base_resolution.handle();
3246
3247 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3249 base_ty = &context.module.types[base].inner;
3250 base_ty_handle = Some(base);
3251 }
3252
3253 match *base_ty {
3257 crate::TypeInner::Struct { .. } => {
3258 let base_ty = base_ty_handle.unwrap();
3259 self.put_access_chain(base, policy, context)?;
3260 let name = &self.names[&NameKey::StructMember(base_ty, index)];
3261 write!(self.out, ".{name}")?;
3262 }
3263 crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3264 self.put_access_chain(base, policy, context)?;
3265 if context.get_packed_vec_kind(base).is_some() {
3268 write!(self.out, "[{index}]")?;
3269 } else {
3270 write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3271 }
3272 }
3273 _ => {
3274 self.put_subscripted_access_chain(
3275 base,
3276 base_ty,
3277 index::GuardedIndex::Known(index),
3278 policy,
3279 context,
3280 )?;
3281 }
3282 }
3283 }
3284 _ => self.put_expression(chain, context, false)?,
3285 }
3286
3287 Ok(())
3288 }
3289
3290 fn put_subscripted_access_chain(
3307 &mut self,
3308 base: Handle<crate::Expression>,
3309 base_ty: &crate::TypeInner,
3310 index: index::GuardedIndex,
3311 policy: index::BoundsCheckPolicy,
3312 context: &ExpressionContext,
3313 ) -> BackendResult {
3314 let accessing_wrapped_array = match *base_ty {
3315 crate::TypeInner::Array {
3316 size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3317 ..
3318 } => true,
3319 _ => false,
3320 };
3321 let accessing_wrapped_binding_array =
3322 matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3323
3324 self.put_access_chain(base, policy, context)?;
3325 if accessing_wrapped_array {
3326 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3327 }
3328 write!(self.out, "[")?;
3329
3330 let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3332 context.access_needs_check(base, index)
3333 } else {
3334 None
3335 };
3336 if let Some(limit) = restriction_needed {
3337 write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3338 self.put_index(index, context, true)?;
3339 write!(self.out, "), ")?;
3340 match limit {
3341 index::IndexableLength::Known(limit) => {
3342 write!(self.out, "{}u", limit - 1)?;
3343 }
3344 index::IndexableLength::Dynamic => {
3345 let global = context.function.originating_global(base).ok_or_else(|| {
3346 Error::GenericValidation("Could not find originating global".into())
3347 })?;
3348 self.put_dynamic_array_max_index(global, context)?;
3349 }
3350 }
3351 write!(self.out, ")")?;
3352 } else {
3353 self.put_index(index, context, true)?;
3354 }
3355
3356 write!(self.out, "]")?;
3357
3358 if accessing_wrapped_binding_array {
3359 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3360 }
3361
3362 Ok(())
3363 }
3364
3365 fn put_load(
3366 &mut self,
3367 pointer: Handle<crate::Expression>,
3368 context: &ExpressionContext,
3369 is_scoped: bool,
3370 ) -> BackendResult {
3371 let policy = context.choose_bounds_check_policy(pointer);
3374 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3375 && self.put_bounds_checks(
3376 pointer,
3377 context,
3378 back::Level(0),
3379 if is_scoped { "" } else { "(" },
3380 )?
3381 {
3382 write!(self.out, " ? ")?;
3383 self.put_unchecked_load(pointer, policy, context)?;
3384 write!(self.out, " : DefaultConstructible()")?;
3385
3386 if !is_scoped {
3387 write!(self.out, ")")?;
3388 }
3389 } else {
3390 self.put_unchecked_load(pointer, policy, context)?;
3391 }
3392
3393 Ok(())
3394 }
3395
3396 fn put_unchecked_load(
3397 &mut self,
3398 pointer: Handle<crate::Expression>,
3399 policy: index::BoundsCheckPolicy,
3400 context: &ExpressionContext,
3401 ) -> BackendResult {
3402 let is_atomic_pointer = context
3403 .resolve_type(pointer)
3404 .is_atomic_pointer(&context.module.types);
3405
3406 if is_atomic_pointer {
3407 write!(
3408 self.out,
3409 "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3410 )?;
3411 self.put_access_chain(pointer, policy, context)?;
3412 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3413 } else {
3414 self.put_access_chain(pointer, policy, context)?;
3418 }
3419
3420 Ok(())
3421 }
3422
3423 fn put_return_value(
3424 &mut self,
3425 level: back::Level,
3426 expr_handle: Handle<crate::Expression>,
3427 result_struct: Option<&str>,
3428 context: &ExpressionContext,
3429 ) -> BackendResult {
3430 match result_struct {
3431 Some(struct_name) => {
3432 let mut has_point_size = false;
3433 let result_ty = context.function.result.as_ref().unwrap().ty;
3434 match context.module.types[result_ty].inner {
3435 crate::TypeInner::Struct { ref members, .. } => {
3436 let tmp = "_tmp";
3437 write!(self.out, "{level}const auto {tmp} = ")?;
3438 self.put_expression(expr_handle, context, true)?;
3439 writeln!(self.out, ";")?;
3440 write!(self.out, "{level}return {struct_name} {{")?;
3441
3442 let mut is_first = true;
3443
3444 for (index, member) in members.iter().enumerate() {
3445 if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3446 member.binding
3447 {
3448 has_point_size = true;
3449 if !context.pipeline_options.allow_and_force_point_size {
3450 continue;
3451 }
3452 }
3453
3454 let comma = if is_first { "" } else { "," };
3455 is_first = false;
3456 let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3457 if let crate::TypeInner::Array {
3461 size: crate::ArraySize::Constant(size),
3462 ..
3463 } = context.module.types[member.ty].inner
3464 {
3465 write!(self.out, "{comma} {{")?;
3466 for j in 0..size.get() {
3467 if j != 0 {
3468 write!(self.out, ",")?;
3469 }
3470 write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3471 }
3472 write!(self.out, "}}")?;
3473 } else {
3474 write!(self.out, "{comma} {tmp}.{name}")?;
3475 }
3476 }
3477 }
3478 _ => {
3479 write!(self.out, "{level}return {struct_name} {{ ")?;
3480 self.put_expression(expr_handle, context, true)?;
3481 }
3482 }
3483
3484 if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3485 let stage = context.module.entry_points[ep_index as usize].stage;
3486 if context.pipeline_options.allow_and_force_point_size
3487 && stage == crate::ShaderStage::Vertex
3488 && !has_point_size
3489 {
3490 write!(self.out, ", 1.0")?;
3492 }
3493 }
3494 write!(self.out, " }}")?;
3495 }
3496 None => {
3497 write!(self.out, "{level}return ")?;
3498 self.put_expression(expr_handle, context, true)?;
3499 }
3500 }
3501 writeln!(self.out, ";")?;
3502 Ok(())
3503 }
3504
3505 fn update_expressions_to_bake(
3510 &mut self,
3511 func: &crate::Function,
3512 info: &valid::FunctionInfo,
3513 context: &ExpressionContext,
3514 ) {
3515 use crate::Expression;
3516 self.need_bake_expressions.clear();
3517
3518 for (expr_handle, expr) in func.expressions.iter() {
3519 let expr_info = &info[expr_handle];
3522 let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3523 if min_ref_count <= expr_info.ref_count {
3524 self.need_bake_expressions.insert(expr_handle);
3525 } else {
3526 match expr_info.ty {
3527 TypeResolution::Handle(h)
3529 if Some(h) == context.module.special_types.ray_desc =>
3530 {
3531 self.need_bake_expressions.insert(expr_handle);
3532 }
3533 _ => {}
3534 }
3535 }
3536
3537 if let Expression::Math {
3538 fun,
3539 arg,
3540 arg1,
3541 arg2,
3542 ..
3543 } = *expr
3544 {
3545 match fun {
3546 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3556 self.need_bake_expressions.insert(arg);
3557 self.need_bake_expressions.insert(arg1.unwrap());
3558 }
3559 crate::MathFunction::FirstLeadingBit => {
3560 self.need_bake_expressions.insert(arg);
3561 }
3562 crate::MathFunction::Pack4xI8
3563 | crate::MathFunction::Pack4xU8
3564 | crate::MathFunction::Pack4xI8Clamp
3565 | crate::MathFunction::Pack4xU8Clamp
3566 | crate::MathFunction::Unpack4xI8
3567 | crate::MathFunction::Unpack4xU8 => {
3568 if context.lang_version < (2, 1) {
3571 self.need_bake_expressions.insert(arg);
3572 }
3573 }
3574 crate::MathFunction::ExtractBits => {
3575 self.need_bake_expressions.insert(arg1.unwrap());
3577 }
3578 crate::MathFunction::InsertBits => {
3579 self.need_bake_expressions.insert(arg2.unwrap());
3581 }
3582 crate::MathFunction::Sign => {
3583 let inner = context.resolve_type(expr_handle);
3588 if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3589 self.need_bake_expressions.insert(arg);
3590 }
3591 }
3592 _ => {}
3593 }
3594 }
3595 }
3596 }
3597
3598 pub(super) fn start_baking_expression(
3599 &mut self,
3600 handle: Handle<crate::Expression>,
3601 context: &ExpressionContext,
3602 name: &str,
3603 ) -> BackendResult {
3604 match context.info[handle].ty {
3605 TypeResolution::Handle(ty_handle) => {
3606 let ty_name = TypeContext {
3607 handle: ty_handle,
3608 gctx: context.module.to_ctx(),
3609 names: &self.names,
3610 access: crate::StorageAccess::empty(),
3611 first_time: false,
3612 };
3613 write!(self.out, "{ty_name}")?;
3614 }
3615 TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3616 put_numeric_type(&mut self.out, scalar, &[])?;
3617 }
3618 TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3619 put_numeric_type(&mut self.out, scalar, &[size])?;
3620 }
3621 TypeResolution::Value(crate::TypeInner::Matrix {
3622 columns,
3623 rows,
3624 scalar,
3625 }) => {
3626 put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3627 }
3628 TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3629 columns,
3630 rows,
3631 scalar,
3632 role: _,
3633 }) => {
3634 write!(
3635 self.out,
3636 "{}::simdgroup_{}{}x{}",
3637 NAMESPACE,
3638 scalar.to_msl_name(),
3639 columns as u32,
3640 rows as u32,
3641 )?;
3642 }
3643 TypeResolution::Value(ref other) => {
3644 log::warn!("Type {other:?} isn't a known local");
3645 return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3646 }
3647 }
3648
3649 write!(self.out, " {name} = ")?;
3651
3652 Ok(())
3653 }
3654
3655 fn put_cache_restricted_level(
3668 &mut self,
3669 load: Handle<crate::Expression>,
3670 image: Handle<crate::Expression>,
3671 mip_level: Option<Handle<crate::Expression>>,
3672 indent: back::Level,
3673 context: &StatementContext,
3674 ) -> BackendResult {
3675 let level_of_detail = match mip_level {
3678 Some(level) => level,
3679 None => return Ok(()),
3680 };
3681
3682 if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3683 || !context.expression.image_needs_lod(image)
3684 {
3685 return Ok(());
3686 }
3687
3688 write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3689 self.put_restricted_scalar_image_index(
3690 image,
3691 level_of_detail,
3692 "get_num_mip_levels",
3693 &context.expression,
3694 )?;
3695 writeln!(self.out, ";")?;
3696
3697 Ok(())
3698 }
3699
3700 fn put_casting_to_packed_chars(
3706 &mut self,
3707 fun: crate::MathFunction,
3708 arg0: Handle<crate::Expression>,
3709 arg1: Handle<crate::Expression>,
3710 indent: back::Level,
3711 context: &StatementContext<'_>,
3712 ) -> Result<(), Error> {
3713 let packed_type = match fun {
3714 crate::MathFunction::Dot4I8Packed => "packed_char4",
3715 crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3716 _ => unreachable!(),
3717 };
3718
3719 for arg in [arg0, arg1] {
3720 write!(
3721 self.out,
3722 "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3723 Reinterpreted::new(packed_type, arg)
3724 )?;
3725 self.put_expression(arg, &context.expression, true)?;
3726 writeln!(self.out, ");")?;
3727 }
3728
3729 Ok(())
3730 }
3731
3732 fn put_block(
3733 &mut self,
3734 level: back::Level,
3735 statements: &[crate::Statement],
3736 context: &StatementContext,
3737 ) -> BackendResult {
3738 #[cfg(test)]
3740 self.put_block_stack_pointers
3741 .insert(ptr::from_ref(&level).cast());
3742
3743 for statement in statements {
3744 log::trace!("statement[{}] {:?}", level.0, statement);
3745 match *statement {
3746 crate::Statement::Emit(ref range) => {
3747 for handle in range.clone() {
3748 use crate::MathFunction as Mf;
3749
3750 match context.expression.function.expressions[handle] {
3751 crate::Expression::ImageLoad {
3754 image,
3755 level: mip_level,
3756 ..
3757 } => {
3758 self.put_cache_restricted_level(
3759 handle, image, mip_level, level, context,
3760 )?;
3761 }
3762
3763 crate::Expression::Math {
3772 fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3773 arg,
3774 arg1,
3775 ..
3776 } if context.expression.lang_version >= (2, 1) => {
3777 self.put_casting_to_packed_chars(
3778 fun,
3779 arg,
3780 arg1.unwrap(),
3781 level,
3782 context,
3783 )?;
3784 }
3785
3786 _ => (),
3787 }
3788
3789 let ptr_class = context.expression.resolve_type(handle).pointer_space();
3790 let expr_name = if ptr_class.is_some() {
3791 None } else if let Some(name) =
3793 context.expression.function.named_expressions.get(&handle)
3794 {
3795 Some(self.namer.call(name))
3805 } else {
3806 let bake = if context.expression.guarded_indices.contains(handle) {
3810 true
3811 } else {
3812 self.need_bake_expressions.contains(&handle)
3813 };
3814
3815 if bake {
3816 Some(Baked(handle).to_string())
3817 } else {
3818 None
3819 }
3820 };
3821
3822 if let Some(name) = expr_name {
3823 write!(self.out, "{level}")?;
3824 self.start_baking_expression(handle, &context.expression, &name)?;
3825 self.put_expression(handle, &context.expression, true)?;
3826 self.named_expressions.insert(handle, name);
3827 writeln!(self.out, ";")?;
3828 }
3829 }
3830 }
3831 crate::Statement::Block(ref block) => {
3832 if !block.is_empty() {
3833 writeln!(self.out, "{level}{{")?;
3834 self.put_block(level.next(), block, context)?;
3835 writeln!(self.out, "{level}}}")?;
3836 }
3837 }
3838 crate::Statement::If {
3839 condition,
3840 ref accept,
3841 ref reject,
3842 } => {
3843 write!(self.out, "{level}if (")?;
3844 self.put_expression(condition, &context.expression, true)?;
3845 writeln!(self.out, ") {{")?;
3846 self.put_block(level.next(), accept, context)?;
3847 if !reject.is_empty() {
3848 writeln!(self.out, "{level}}} else {{")?;
3849 self.put_block(level.next(), reject, context)?;
3850 }
3851 writeln!(self.out, "{level}}}")?;
3852 }
3853 crate::Statement::Switch {
3854 selector,
3855 ref cases,
3856 } => {
3857 write!(self.out, "{level}switch(")?;
3858 self.put_expression(selector, &context.expression, true)?;
3859 writeln!(self.out, ") {{")?;
3860 let lcase = level.next();
3861 for case in cases.iter() {
3862 match case.value {
3863 crate::SwitchValue::I32(value) => {
3864 write!(self.out, "{lcase}case {value}:")?;
3865 }
3866 crate::SwitchValue::U32(value) => {
3867 write!(self.out, "{lcase}case {value}u:")?;
3868 }
3869 crate::SwitchValue::Default => {
3870 write!(self.out, "{lcase}default:")?;
3871 }
3872 }
3873
3874 let write_block_braces = !(case.fall_through && case.body.is_empty());
3875 if write_block_braces {
3876 writeln!(self.out, " {{")?;
3877 } else {
3878 writeln!(self.out)?;
3879 }
3880
3881 self.put_block(lcase.next(), &case.body, context)?;
3882 if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3883 {
3884 writeln!(self.out, "{}break;", lcase.next())?;
3885 }
3886
3887 if write_block_braces {
3888 writeln!(self.out, "{lcase}}}")?;
3889 }
3890 }
3891 writeln!(self.out, "{level}}}")?;
3892 }
3893 crate::Statement::Loop {
3894 ref body,
3895 ref continuing,
3896 break_if,
3897 } => {
3898 let force_loop_bound_statements =
3899 self.gen_force_bounded_loop_statements(level, context);
3900 let gate_name = (!continuing.is_empty() || break_if.is_some())
3901 .then(|| self.namer.call("loop_init"));
3902
3903 if let Some((ref decl, _)) = force_loop_bound_statements {
3904 writeln!(self.out, "{decl}")?;
3905 }
3906 if let Some(ref gate_name) = gate_name {
3907 writeln!(self.out, "{level}bool {gate_name} = true;")?;
3908 }
3909
3910 writeln!(self.out, "{level}while(true) {{",)?;
3911 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3912 writeln!(self.out, "{break_and_inc}")?;
3913 }
3914 if let Some(ref gate_name) = gate_name {
3915 let lif = level.next();
3916 let lcontinuing = lif.next();
3917 writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3918 self.put_block(lcontinuing, continuing, context)?;
3919 if let Some(condition) = break_if {
3920 write!(self.out, "{lcontinuing}if (")?;
3921 self.put_expression(condition, &context.expression, true)?;
3922 writeln!(self.out, ") {{")?;
3923 writeln!(self.out, "{}break;", lcontinuing.next())?;
3924 writeln!(self.out, "{lcontinuing}}}")?;
3925 }
3926 writeln!(self.out, "{lif}}}")?;
3927 writeln!(self.out, "{lif}{gate_name} = false;")?;
3928 }
3929 self.put_block(level.next(), body, context)?;
3930
3931 writeln!(self.out, "{level}}}")?;
3932 }
3933 crate::Statement::Break => {
3934 writeln!(self.out, "{level}break;")?;
3935 }
3936 crate::Statement::Continue => {
3937 writeln!(self.out, "{level}continue;")?;
3938 }
3939 crate::Statement::Return {
3940 value: Some(expr_handle),
3941 } => {
3942 self.put_return_value(
3943 level,
3944 expr_handle,
3945 context.result_struct,
3946 &context.expression,
3947 )?;
3948 }
3949 crate::Statement::Return { value: None } => {
3950 writeln!(self.out, "{level}return;")?;
3951 }
3952 crate::Statement::Kill => {
3953 writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3954 }
3955 crate::Statement::ControlBarrier(flags)
3956 | crate::Statement::MemoryBarrier(flags) => {
3957 self.write_barrier(flags, level)?;
3958 }
3959 crate::Statement::Store { pointer, value } => {
3960 self.put_store(pointer, value, level, context)?
3961 }
3962 crate::Statement::ImageStore {
3963 image,
3964 coordinate,
3965 array_index,
3966 value,
3967 } => {
3968 let address = TexelAddress {
3969 coordinate,
3970 array_index,
3971 sample: None,
3972 level: None,
3973 };
3974 self.put_image_store(level, image, &address, value, context)?
3975 }
3976 crate::Statement::Call {
3977 function,
3978 ref arguments,
3979 result,
3980 } => {
3981 write!(self.out, "{level}")?;
3982 if let Some(expr) = result {
3983 let name = Baked(expr).to_string();
3984 self.start_baking_expression(expr, &context.expression, &name)?;
3985 self.named_expressions.insert(expr, name);
3986 }
3987 let fun_name = &self.names[&NameKey::Function(function)];
3988 write!(self.out, "{fun_name}(")?;
3989 for (i, &handle) in arguments.iter().enumerate() {
3991 if i != 0 {
3992 write!(self.out, ", ")?;
3993 }
3994 self.put_expression(handle, &context.expression, true)?;
3995 }
3996 let mut separate = !arguments.is_empty();
3998 let fun_info = &context.expression.mod_info[function];
3999 let mut needs_buffer_sizes = false;
4000 for (handle, var) in context.expression.module.global_variables.iter() {
4001 if fun_info[handle].is_empty() {
4002 continue;
4003 }
4004 if var.space.needs_pass_through() {
4005 let name = &self.names[&NameKey::GlobalVariable(handle)];
4006 if separate {
4007 write!(self.out, ", ")?;
4008 } else {
4009 separate = true;
4010 }
4011 write!(self.out, "{name}")?;
4012 }
4013 needs_buffer_sizes |=
4014 needs_array_length(var.ty, &context.expression.module.types);
4015 }
4016 if needs_buffer_sizes {
4017 if separate {
4018 write!(self.out, ", ")?;
4019 }
4020 write!(self.out, "_buffer_sizes")?;
4021 }
4022
4023 writeln!(self.out, ");")?;
4025 }
4026 crate::Statement::Atomic {
4027 pointer,
4028 ref fun,
4029 value,
4030 result,
4031 } => {
4032 let context = &context.expression;
4033
4034 write!(self.out, "{level}")?;
4039 let fun_key = if let Some(result) = result {
4040 let res_name = Baked(result).to_string();
4041 self.start_baking_expression(result, context, &res_name)?;
4042 self.named_expressions.insert(result, res_name);
4043 fun.to_msl()
4044 } else if context.resolve_type(value).scalar_width() == Some(8) {
4045 fun.to_msl_64_bit()?
4046 } else {
4047 fun.to_msl()
4048 };
4049
4050 let policy = context.choose_bounds_check_policy(pointer);
4054 let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4055 && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
4056
4057 if checked {
4059 write!(self.out, " ? ")?;
4060 }
4061
4062 match *fun {
4064 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4065 write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
4066 self.put_access_chain(pointer, policy, context)?;
4067 write!(self.out, ", ")?;
4068 self.put_expression(cmp, context, true)?;
4069 write!(self.out, ", ")?;
4070 self.put_expression(value, context, true)?;
4071 write!(self.out, ")")?;
4072 }
4073 _ => {
4074 write!(
4075 self.out,
4076 "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
4077 )?;
4078 self.put_access_chain(pointer, policy, context)?;
4079 write!(self.out, ", ")?;
4080 self.put_expression(value, context, true)?;
4081 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
4082 }
4083 }
4084
4085 if checked {
4087 write!(self.out, " : DefaultConstructible()")?;
4088 }
4089
4090 writeln!(self.out, ";")?;
4092 }
4093 crate::Statement::ImageAtomic {
4094 image,
4095 coordinate,
4096 array_index,
4097 fun,
4098 value,
4099 } => {
4100 let address = TexelAddress {
4101 coordinate,
4102 array_index,
4103 sample: None,
4104 level: None,
4105 };
4106 self.put_image_atomic(level, image, &address, fun, value, context)?
4107 }
4108 crate::Statement::WorkGroupUniformLoad { pointer, result } => {
4109 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4110
4111 write!(self.out, "{level}")?;
4112 let name = self.namer.call("");
4113 self.start_baking_expression(result, &context.expression, &name)?;
4114 self.put_load(pointer, &context.expression, true)?;
4115 self.named_expressions.insert(result, name);
4116
4117 writeln!(self.out, ";")?;
4118 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4119 }
4120 crate::Statement::RayQuery { query, ref fun } => {
4121 self.write_ray_query_stmt(level, context, query, fun)?;
4122 }
4123 crate::Statement::SubgroupBallot { result, predicate } => {
4124 write!(self.out, "{level}")?;
4125 let name = self.namer.call("");
4126 self.start_baking_expression(result, &context.expression, &name)?;
4127 self.named_expressions.insert(result, name);
4128 write!(
4129 self.out,
4130 "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4131 )?;
4132 if let Some(predicate) = predicate {
4133 self.put_expression(predicate, &context.expression, true)?;
4134 } else {
4135 write!(self.out, "true")?;
4136 }
4137 writeln!(self.out, "), 0, 0, 0);")?;
4138 }
4139 crate::Statement::SubgroupCollectiveOperation {
4140 op,
4141 collective_op,
4142 argument,
4143 result,
4144 } => {
4145 write!(self.out, "{level}")?;
4146 let name = self.namer.call("");
4147 self.start_baking_expression(result, &context.expression, &name)?;
4148 self.named_expressions.insert(result, name);
4149 match (collective_op, op) {
4150 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4151 write!(self.out, "{NAMESPACE}::simd_all(")?
4152 }
4153 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4154 write!(self.out, "{NAMESPACE}::simd_any(")?
4155 }
4156 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4157 write!(self.out, "{NAMESPACE}::simd_sum(")?
4158 }
4159 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4160 write!(self.out, "{NAMESPACE}::simd_product(")?
4161 }
4162 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4163 write!(self.out, "{NAMESPACE}::simd_max(")?
4164 }
4165 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4166 write!(self.out, "{NAMESPACE}::simd_min(")?
4167 }
4168 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4169 write!(self.out, "{NAMESPACE}::simd_and(")?
4170 }
4171 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4172 write!(self.out, "{NAMESPACE}::simd_or(")?
4173 }
4174 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4175 write!(self.out, "{NAMESPACE}::simd_xor(")?
4176 }
4177 (
4178 crate::CollectiveOperation::ExclusiveScan,
4179 crate::SubgroupOperation::Add,
4180 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4181 (
4182 crate::CollectiveOperation::ExclusiveScan,
4183 crate::SubgroupOperation::Mul,
4184 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4185 (
4186 crate::CollectiveOperation::InclusiveScan,
4187 crate::SubgroupOperation::Add,
4188 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4189 (
4190 crate::CollectiveOperation::InclusiveScan,
4191 crate::SubgroupOperation::Mul,
4192 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4193 _ => unimplemented!(),
4194 }
4195 self.put_expression(argument, &context.expression, true)?;
4196 writeln!(self.out, ");")?;
4197 }
4198 crate::Statement::SubgroupGather {
4199 mode,
4200 argument,
4201 result,
4202 } => {
4203 write!(self.out, "{level}")?;
4204 let name = self.namer.call("");
4205 self.start_baking_expression(result, &context.expression, &name)?;
4206 self.named_expressions.insert(result, name);
4207 match mode {
4208 crate::GatherMode::BroadcastFirst => {
4209 write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4210 }
4211 crate::GatherMode::Broadcast(_) => {
4212 write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4213 }
4214 crate::GatherMode::Shuffle(_) => {
4215 write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4216 }
4217 crate::GatherMode::ShuffleDown(_) => {
4218 write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4219 }
4220 crate::GatherMode::ShuffleUp(_) => {
4221 write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4222 }
4223 crate::GatherMode::ShuffleXor(_) => {
4224 write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4225 }
4226 crate::GatherMode::QuadBroadcast(_) => {
4227 write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4228 }
4229 crate::GatherMode::QuadSwap(_) => {
4230 write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4231 }
4232 }
4233 self.put_expression(argument, &context.expression, true)?;
4234 match mode {
4235 crate::GatherMode::BroadcastFirst => {}
4236 crate::GatherMode::Broadcast(index)
4237 | crate::GatherMode::Shuffle(index)
4238 | crate::GatherMode::ShuffleDown(index)
4239 | crate::GatherMode::ShuffleUp(index)
4240 | crate::GatherMode::ShuffleXor(index)
4241 | crate::GatherMode::QuadBroadcast(index) => {
4242 write!(self.out, ", ")?;
4243 self.put_expression(index, &context.expression, true)?;
4244 }
4245 crate::GatherMode::QuadSwap(direction) => {
4246 write!(self.out, ", ")?;
4247 match direction {
4248 crate::Direction::X => {
4249 write!(self.out, "1u")?;
4250 }
4251 crate::Direction::Y => {
4252 write!(self.out, "2u")?;
4253 }
4254 crate::Direction::Diagonal => {
4255 write!(self.out, "3u")?;
4256 }
4257 }
4258 }
4259 }
4260 writeln!(self.out, ");")?;
4261 }
4262 crate::Statement::CooperativeStore { target, ref data } => {
4263 write!(self.out, "{level}simdgroup_store(")?;
4264 self.put_expression(target, &context.expression, true)?;
4265 write!(self.out, ", &")?;
4266 self.put_access_chain(
4267 data.pointer,
4268 context.expression.policies.index,
4269 &context.expression,
4270 )?;
4271 write!(self.out, ", ")?;
4272 self.put_expression(data.stride, &context.expression, true)?;
4273 if data.row_major {
4274 let matrix_origin = "0";
4275 let transpose = true;
4276 write!(self.out, ", {matrix_origin}, {transpose}")?;
4277 }
4278 writeln!(self.out, ");")?;
4279 }
4280 crate::Statement::RayPipelineFunction(_) => unreachable!(),
4281 }
4282 }
4283
4284 for statement in statements {
4287 if let crate::Statement::Emit(ref range) = *statement {
4288 for handle in range.clone() {
4289 self.named_expressions.shift_remove(&handle);
4290 }
4291 }
4292 }
4293 Ok(())
4294 }
4295
4296 fn put_store(
4297 &mut self,
4298 pointer: Handle<crate::Expression>,
4299 value: Handle<crate::Expression>,
4300 level: back::Level,
4301 context: &StatementContext,
4302 ) -> BackendResult {
4303 let policy = context.expression.choose_bounds_check_policy(pointer);
4304 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4305 && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4306 {
4307 writeln!(self.out, ") {{")?;
4308 self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4309 writeln!(self.out, "{level}}}")?;
4310 } else {
4311 self.put_unchecked_store(pointer, value, policy, level, context)?;
4312 }
4313
4314 Ok(())
4315 }
4316
4317 fn put_unchecked_store(
4318 &mut self,
4319 pointer: Handle<crate::Expression>,
4320 value: Handle<crate::Expression>,
4321 policy: index::BoundsCheckPolicy,
4322 level: back::Level,
4323 context: &StatementContext,
4324 ) -> BackendResult {
4325 let is_atomic_pointer = context
4326 .expression
4327 .resolve_type(pointer)
4328 .is_atomic_pointer(&context.expression.module.types);
4329
4330 if is_atomic_pointer {
4331 write!(
4332 self.out,
4333 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4334 )?;
4335 self.put_access_chain(pointer, policy, &context.expression)?;
4336 write!(self.out, ", ")?;
4337 self.put_expression(value, &context.expression, true)?;
4338 writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4339 } else {
4340 write!(self.out, "{level}")?;
4341 self.put_access_chain(pointer, policy, &context.expression)?;
4342 write!(self.out, " = ")?;
4343 self.put_expression(value, &context.expression, true)?;
4344 writeln!(self.out, ";")?;
4345 }
4346
4347 Ok(())
4348 }
4349
4350 pub fn write(
4351 &mut self,
4352 module: &crate::Module,
4353 info: &valid::ModuleInfo,
4354 options: &Options,
4355 pipeline_options: &PipelineOptions,
4356 ) -> Result<TranslationInfo, Error> {
4357 self.names.clear();
4358 self.namer.reset(
4359 module,
4360 &super::keywords::RESERVED_SET,
4361 proc::KeywordSet::empty(),
4362 proc::CaseInsensitiveKeywordSet::empty(),
4363 &[
4364 CLAMPED_LOD_LOAD_PREFIX,
4365 super::ray::INTERSECTION_FUNCTION_NAME,
4366 ],
4367 &mut self.names,
4368 );
4369 self.wrapped_functions.clear();
4370 self.struct_member_pads.clear();
4371
4372 writeln!(
4373 self.out,
4374 "// language: metal{}.{}",
4375 options.lang_version.0, options.lang_version.1
4376 )?;
4377 writeln!(self.out, "#include <metal_stdlib>")?;
4378 writeln!(self.out, "#include <simd/simd.h>")?;
4379 writeln!(self.out)?;
4380 writeln!(self.out, "using {NAMESPACE}::uint;")?;
4382
4383 if module.uses_mesh_shaders() && options.lang_version < (3, 0) {
4384 return Err(Error::UnsupportedMeshShader);
4385 }
4386 self.needs_object_memory_barriers = module
4387 .entry_points
4388 .iter()
4389 .any(|e| e.stage == crate::ShaderStage::Task && e.task_payload.is_some());
4390
4391 if module.special_types.ray_desc.is_some()
4392 || module.special_types.ray_intersection.is_some()
4393 {
4394 if options.lang_version < (2, 4) {
4395 return Err(Error::UnsupportedRayTracing);
4396 }
4397 }
4398
4399 if options
4400 .bounds_check_policies
4401 .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4402 {
4403 self.put_default_constructible()?;
4404 }
4405 writeln!(self.out)?;
4406
4407 {
4408 let globals: Vec<Handle<crate::GlobalVariable>> = module
4411 .global_variables
4412 .iter()
4413 .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4414 .map(|(handle, _)| handle)
4415 .collect();
4416
4417 let mut buffer_indices = vec![];
4418 for vbm in &pipeline_options.vertex_buffer_mappings {
4419 buffer_indices.push(vbm.id);
4420 }
4421
4422 if !globals.is_empty() || !buffer_indices.is_empty() {
4423 writeln!(self.out, "struct _mslBufferSizes {{")?;
4424
4425 for global in globals {
4426 writeln!(
4427 self.out,
4428 "{}uint {};",
4429 back::INDENT,
4430 ArraySizeMember(global)
4431 )?;
4432 }
4433
4434 for idx in buffer_indices {
4435 writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4436 }
4437
4438 writeln!(self.out, "}};")?;
4439 writeln!(self.out)?;
4440 }
4441 };
4442
4443 self.write_type_defs(module)?;
4444 self.write_global_constants(module, info)?;
4445 self.write_functions(module, info, options, pipeline_options)
4446 }
4447
4448 fn put_default_constructible(&mut self) -> BackendResult {
4461 let tab = back::INDENT;
4462 writeln!(self.out, "struct DefaultConstructible {{")?;
4463 writeln!(self.out, "{tab}template<typename T>")?;
4464 writeln!(self.out, "{tab}operator T() && {{")?;
4465 writeln!(self.out, "{tab}{tab}return T {{}};")?;
4466 writeln!(self.out, "{tab}}}")?;
4467 writeln!(self.out, "}};")?;
4468 Ok(())
4469 }
4470
4471 fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4472 let mut generated_argument_buffer_wrapper = false;
4473 let mut generated_external_texture_wrapper = false;
4474 for (handle, ty) in module.types.iter() {
4475 match ty.inner {
4476 crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4477 writeln!(self.out, "template <typename T>")?;
4478 writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4479 writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4480 writeln!(self.out, "}};")?;
4481 generated_argument_buffer_wrapper = true;
4482 }
4483 crate::TypeInner::Image {
4484 class: crate::ImageClass::External,
4485 ..
4486 } if !generated_external_texture_wrapper => {
4487 let params_ty_name = &self.names
4488 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4489 writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4490 writeln!(
4491 self.out,
4492 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4493 back::INDENT
4494 )?;
4495 writeln!(
4496 self.out,
4497 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4498 back::INDENT
4499 )?;
4500 writeln!(
4501 self.out,
4502 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4503 back::INDENT
4504 )?;
4505 writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4506 writeln!(self.out, "}};")?;
4507 generated_external_texture_wrapper = true;
4508 }
4509 _ => {}
4510 }
4511
4512 if !ty.needs_alias() {
4513 continue;
4514 }
4515 let name = &self.names[&NameKey::Type(handle)];
4516 match ty.inner {
4517 crate::TypeInner::Array {
4531 base,
4532 size,
4533 stride: _,
4534 } => {
4535 let base_name = TypeContext {
4536 handle: base,
4537 gctx: module.to_ctx(),
4538 names: &self.names,
4539 access: crate::StorageAccess::empty(),
4540 first_time: false,
4541 };
4542
4543 match size.resolve(module.to_ctx())? {
4544 proc::IndexableLength::Known(size) => {
4545 writeln!(self.out, "struct {name} {{")?;
4546 writeln!(
4547 self.out,
4548 "{}{} {}[{}];",
4549 back::INDENT,
4550 base_name,
4551 WRAPPED_ARRAY_FIELD,
4552 size
4553 )?;
4554 writeln!(self.out, "}};")?;
4555 }
4556 proc::IndexableLength::Dynamic => {
4557 writeln!(self.out, "typedef {base_name} {name}[1];")?;
4558 }
4559 }
4560 }
4561 crate::TypeInner::Struct {
4562 ref members, span, ..
4563 } => {
4564 writeln!(self.out, "struct {name} {{")?;
4565 let mut last_offset = 0;
4566 for (index, member) in members.iter().enumerate() {
4567 if member.offset > last_offset {
4568 self.struct_member_pads.insert((handle, index as u32));
4569 let pad = member.offset - last_offset;
4570 writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4571 }
4572 let ty_inner = &module.types[member.ty].inner;
4573 last_offset = member.offset + ty_inner.size(module.to_ctx());
4574
4575 let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4576
4577 match should_pack_struct_member(members, span, index, module) {
4579 Some(scalar) => {
4580 writeln!(
4581 self.out,
4582 "{}{}::packed_{}3 {};",
4583 back::INDENT,
4584 NAMESPACE,
4585 scalar.to_msl_name(),
4586 member_name
4587 )?;
4588 }
4589 None => {
4590 let base_name = TypeContext {
4591 handle: member.ty,
4592 gctx: module.to_ctx(),
4593 names: &self.names,
4594 access: crate::StorageAccess::empty(),
4595 first_time: false,
4596 };
4597 writeln!(
4598 self.out,
4599 "{}{} {};",
4600 back::INDENT,
4601 base_name,
4602 member_name
4603 )?;
4604
4605 if let crate::TypeInner::Vector {
4607 size: crate::VectorSize::Tri,
4608 scalar,
4609 } = *ty_inner
4610 {
4611 last_offset += scalar.width as u32;
4612 }
4613 }
4614 }
4615 }
4616 if last_offset < span {
4617 let pad = span - last_offset;
4618 writeln!(
4619 self.out,
4620 "{}char _pad{}[{}];",
4621 back::INDENT,
4622 members.len(),
4623 pad
4624 )?;
4625 }
4626 writeln!(self.out, "}};")?;
4627 }
4628 _ => {
4629 let ty_name = TypeContext {
4630 handle,
4631 gctx: module.to_ctx(),
4632 names: &self.names,
4633 access: crate::StorageAccess::empty(),
4634 first_time: true,
4635 };
4636 writeln!(self.out, "typedef {ty_name} {name};")?;
4637 }
4638 }
4639 }
4640
4641 for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4643 match type_key {
4644 &crate::PredeclaredType::ModfResult { size, scalar }
4645 | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4646 let arg_type_name_owner;
4647 let arg_type_name = if let Some(size) = size {
4648 arg_type_name_owner = format!(
4649 "{NAMESPACE}::{}{}",
4650 if scalar.width == 8 { "double" } else { "float" },
4651 size as u8
4652 );
4653 &arg_type_name_owner
4654 } else if scalar.width == 8 {
4655 "double"
4656 } else {
4657 "float"
4658 };
4659
4660 let other_type_name_owner;
4661 let (defined_func_name, called_func_name, other_type_name) =
4662 if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4663 (MODF_FUNCTION, "modf", arg_type_name)
4664 } else {
4665 let other_type_name = if let Some(size) = size {
4666 other_type_name_owner = format!("int{}", size as u8);
4667 &other_type_name_owner
4668 } else {
4669 "int"
4670 };
4671 (FREXP_FUNCTION, "frexp", other_type_name)
4672 };
4673
4674 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4675
4676 writeln!(self.out)?;
4677 writeln!(
4678 self.out,
4679 "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4680 {other_type_name} other;
4681 {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4682 return {struct_name}{{ fract, other }};
4683}}"
4684 )?;
4685 }
4686 &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4687 let arg_type_name = scalar.to_msl_name();
4688 let called_func_name = "atomic_compare_exchange_weak_explicit";
4689 let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4690 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4691
4692 writeln!(self.out)?;
4693
4694 for address_space_name in ["device", "threadgroup"] {
4695 writeln!(
4696 self.out,
4697 "\
4698template <typename A>
4699{struct_name} {defined_func_name}(
4700 {address_space_name} A *atomic_ptr,
4701 {arg_type_name} cmp,
4702 {arg_type_name} v
4703) {{
4704 bool swapped = {NAMESPACE}::{called_func_name}(
4705 atomic_ptr, &cmp, v,
4706 metal::memory_order_relaxed, metal::memory_order_relaxed
4707 );
4708 return {struct_name}{{cmp, swapped}};
4709}}"
4710 )?;
4711 }
4712 }
4713 }
4714 }
4715
4716 Ok(())
4717 }
4718
4719 fn write_global_constants(
4721 &mut self,
4722 module: &crate::Module,
4723 mod_info: &valid::ModuleInfo,
4724 ) -> BackendResult {
4725 let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4726
4727 for (handle, constant) in constants {
4728 let ty_name = TypeContext {
4729 handle: constant.ty,
4730 gctx: module.to_ctx(),
4731 names: &self.names,
4732 access: crate::StorageAccess::empty(),
4733 first_time: false,
4734 };
4735 let name = &self.names[&NameKey::Constant(handle)];
4736 write!(self.out, "constant {ty_name} {name} = ")?;
4737 self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4738 writeln!(self.out, ";")?;
4739 }
4740
4741 Ok(())
4742 }
4743
4744 fn put_inline_sampler_properties(
4745 &mut self,
4746 level: back::Level,
4747 sampler: &sm::InlineSampler,
4748 ) -> BackendResult {
4749 for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4750 writeln!(
4751 self.out,
4752 "{}{}::{}_address::{},",
4753 level,
4754 NAMESPACE,
4755 letter,
4756 address.as_str(),
4757 )?;
4758 }
4759 writeln!(
4760 self.out,
4761 "{}{}::mag_filter::{},",
4762 level,
4763 NAMESPACE,
4764 sampler.mag_filter.as_str(),
4765 )?;
4766 writeln!(
4767 self.out,
4768 "{}{}::min_filter::{},",
4769 level,
4770 NAMESPACE,
4771 sampler.min_filter.as_str(),
4772 )?;
4773 if let Some(filter) = sampler.mip_filter {
4774 writeln!(
4775 self.out,
4776 "{}{}::mip_filter::{},",
4777 level,
4778 NAMESPACE,
4779 filter.as_str(),
4780 )?;
4781 }
4782 if sampler.border_color != sm::BorderColor::TransparentBlack {
4784 writeln!(
4785 self.out,
4786 "{}{}::border_color::{},",
4787 level,
4788 NAMESPACE,
4789 sampler.border_color.as_str(),
4790 )?;
4791 }
4792 if false {
4796 if let Some(ref lod) = sampler.lod_clamp {
4797 writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4798 }
4799 if let Some(aniso) = sampler.max_anisotropy {
4800 writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4801 }
4802 }
4803 if sampler.compare_func != sm::CompareFunc::Never {
4804 writeln!(
4805 self.out,
4806 "{}{}::compare_func::{},",
4807 level,
4808 NAMESPACE,
4809 sampler.compare_func.as_str(),
4810 )?;
4811 }
4812 writeln!(
4813 self.out,
4814 "{}{}::coord::{}",
4815 level,
4816 NAMESPACE,
4817 sampler.coord.as_str()
4818 )?;
4819 Ok(())
4820 }
4821
4822 fn write_unpacking_function(
4823 &mut self,
4824 format: back::msl::VertexFormat,
4825 ) -> Result<(String, u32, Option<crate::VectorSize>, crate::Scalar), Error> {
4826 use crate::{Scalar, VectorSize};
4827 use back::msl::VertexFormat::*;
4828 match format {
4829 Uint8 => {
4830 let name = self.namer.call("unpackUint8");
4831 writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4832 writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4833 writeln!(self.out, "}}")?;
4834 Ok((name, 1, None, Scalar::U32))
4835 }
4836 Uint8x2 => {
4837 let name = self.namer.call("unpackUint8x2");
4838 writeln!(
4839 self.out,
4840 "metal::uint2 {name}(metal::uchar b0, \
4841 metal::uchar b1) {{"
4842 )?;
4843 writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4844 writeln!(self.out, "}}")?;
4845 Ok((name, 2, Some(VectorSize::Bi), Scalar::U32))
4846 }
4847 Uint8x4 => {
4848 let name = self.namer.call("unpackUint8x4");
4849 writeln!(
4850 self.out,
4851 "metal::uint4 {name}(metal::uchar b0, \
4852 metal::uchar b1, \
4853 metal::uchar b2, \
4854 metal::uchar b3) {{"
4855 )?;
4856 writeln!(
4857 self.out,
4858 "{}return metal::uint4(b0, b1, b2, b3);",
4859 back::INDENT
4860 )?;
4861 writeln!(self.out, "}}")?;
4862 Ok((name, 4, Some(VectorSize::Quad), Scalar::U32))
4863 }
4864 Sint8 => {
4865 let name = self.namer.call("unpackSint8");
4866 writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4867 writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4868 writeln!(self.out, "}}")?;
4869 Ok((name, 1, None, Scalar::I32))
4870 }
4871 Sint8x2 => {
4872 let name = self.namer.call("unpackSint8x2");
4873 writeln!(
4874 self.out,
4875 "metal::int2 {name}(metal::uchar b0, \
4876 metal::uchar b1) {{"
4877 )?;
4878 writeln!(
4879 self.out,
4880 "{}return metal::int2(as_type<char>(b0), \
4881 as_type<char>(b1));",
4882 back::INDENT
4883 )?;
4884 writeln!(self.out, "}}")?;
4885 Ok((name, 2, Some(VectorSize::Bi), Scalar::I32))
4886 }
4887 Sint8x4 => {
4888 let name = self.namer.call("unpackSint8x4");
4889 writeln!(
4890 self.out,
4891 "metal::int4 {name}(metal::uchar b0, \
4892 metal::uchar b1, \
4893 metal::uchar b2, \
4894 metal::uchar b3) {{"
4895 )?;
4896 writeln!(
4897 self.out,
4898 "{}return metal::int4(as_type<char>(b0), \
4899 as_type<char>(b1), \
4900 as_type<char>(b2), \
4901 as_type<char>(b3));",
4902 back::INDENT
4903 )?;
4904 writeln!(self.out, "}}")?;
4905 Ok((name, 4, Some(VectorSize::Quad), Scalar::I32))
4906 }
4907 Unorm8 => {
4908 let name = self.namer.call("unpackUnorm8");
4909 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4910 writeln!(
4911 self.out,
4912 "{}return float(float(b0) / 255.0f);",
4913 back::INDENT
4914 )?;
4915 writeln!(self.out, "}}")?;
4916 Ok((name, 1, None, Scalar::F32))
4917 }
4918 Unorm8x2 => {
4919 let name = self.namer.call("unpackUnorm8x2");
4920 writeln!(
4921 self.out,
4922 "metal::float2 {name}(metal::uchar b0, \
4923 metal::uchar b1) {{"
4924 )?;
4925 writeln!(
4926 self.out,
4927 "{}return metal::float2(float(b0) / 255.0f, \
4928 float(b1) / 255.0f);",
4929 back::INDENT
4930 )?;
4931 writeln!(self.out, "}}")?;
4932 Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
4933 }
4934 Unorm8x4 => {
4935 let name = self.namer.call("unpackUnorm8x4");
4936 writeln!(
4937 self.out,
4938 "metal::float4 {name}(metal::uchar b0, \
4939 metal::uchar b1, \
4940 metal::uchar b2, \
4941 metal::uchar b3) {{"
4942 )?;
4943 writeln!(
4944 self.out,
4945 "{}return metal::float4(float(b0) / 255.0f, \
4946 float(b1) / 255.0f, \
4947 float(b2) / 255.0f, \
4948 float(b3) / 255.0f);",
4949 back::INDENT
4950 )?;
4951 writeln!(self.out, "}}")?;
4952 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
4953 }
4954 Snorm8 => {
4955 let name = self.namer.call("unpackSnorm8");
4956 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4957 writeln!(
4958 self.out,
4959 "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4960 back::INDENT
4961 )?;
4962 writeln!(self.out, "}}")?;
4963 Ok((name, 1, None, Scalar::F32))
4964 }
4965 Snorm8x2 => {
4966 let name = self.namer.call("unpackSnorm8x2");
4967 writeln!(
4968 self.out,
4969 "metal::float2 {name}(metal::uchar b0, \
4970 metal::uchar b1) {{"
4971 )?;
4972 writeln!(
4973 self.out,
4974 "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4975 metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
4976 back::INDENT
4977 )?;
4978 writeln!(self.out, "}}")?;
4979 Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
4980 }
4981 Snorm8x4 => {
4982 let name = self.namer.call("unpackSnorm8x4");
4983 writeln!(
4984 self.out,
4985 "metal::float4 {name}(metal::uchar b0, \
4986 metal::uchar b1, \
4987 metal::uchar b2, \
4988 metal::uchar b3) {{"
4989 )?;
4990 writeln!(
4991 self.out,
4992 "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4993 metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
4994 metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
4995 metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
4996 back::INDENT
4997 )?;
4998 writeln!(self.out, "}}")?;
4999 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5000 }
5001 Uint16 => {
5002 let name = self.namer.call("unpackUint16");
5003 writeln!(
5004 self.out,
5005 "metal::uint {name}(metal::uint b0, \
5006 metal::uint b1) {{"
5007 )?;
5008 writeln!(
5009 self.out,
5010 "{}return metal::uint(b1 << 8 | b0);",
5011 back::INDENT
5012 )?;
5013 writeln!(self.out, "}}")?;
5014 Ok((name, 2, None, Scalar::U32))
5015 }
5016 Uint16x2 => {
5017 let name = self.namer.call("unpackUint16x2");
5018 writeln!(
5019 self.out,
5020 "metal::uint2 {name}(metal::uint b0, \
5021 metal::uint b1, \
5022 metal::uint b2, \
5023 metal::uint b3) {{"
5024 )?;
5025 writeln!(
5026 self.out,
5027 "{}return metal::uint2(b1 << 8 | b0, \
5028 b3 << 8 | b2);",
5029 back::INDENT
5030 )?;
5031 writeln!(self.out, "}}")?;
5032 Ok((name, 4, Some(VectorSize::Bi), Scalar::U32))
5033 }
5034 Uint16x4 => {
5035 let name = self.namer.call("unpackUint16x4");
5036 writeln!(
5037 self.out,
5038 "metal::uint4 {name}(metal::uint b0, \
5039 metal::uint b1, \
5040 metal::uint b2, \
5041 metal::uint b3, \
5042 metal::uint b4, \
5043 metal::uint b5, \
5044 metal::uint b6, \
5045 metal::uint b7) {{"
5046 )?;
5047 writeln!(
5048 self.out,
5049 "{}return metal::uint4(b1 << 8 | b0, \
5050 b3 << 8 | b2, \
5051 b5 << 8 | b4, \
5052 b7 << 8 | b6);",
5053 back::INDENT
5054 )?;
5055 writeln!(self.out, "}}")?;
5056 Ok((name, 8, Some(VectorSize::Quad), Scalar::U32))
5057 }
5058 Sint16 => {
5059 let name = self.namer.call("unpackSint16");
5060 writeln!(
5061 self.out,
5062 "int {name}(metal::ushort b0, \
5063 metal::ushort b1) {{"
5064 )?;
5065 writeln!(
5066 self.out,
5067 "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5068 back::INDENT
5069 )?;
5070 writeln!(self.out, "}}")?;
5071 Ok((name, 2, None, Scalar::I32))
5072 }
5073 Sint16x2 => {
5074 let name = self.namer.call("unpackSint16x2");
5075 writeln!(
5076 self.out,
5077 "metal::int2 {name}(metal::ushort b0, \
5078 metal::ushort b1, \
5079 metal::ushort b2, \
5080 metal::ushort b3) {{"
5081 )?;
5082 writeln!(
5083 self.out,
5084 "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5085 as_type<short>(metal::ushort(b3 << 8 | b2)));",
5086 back::INDENT
5087 )?;
5088 writeln!(self.out, "}}")?;
5089 Ok((name, 4, Some(VectorSize::Bi), Scalar::I32))
5090 }
5091 Sint16x4 => {
5092 let name = self.namer.call("unpackSint16x4");
5093 writeln!(
5094 self.out,
5095 "metal::int4 {name}(metal::ushort b0, \
5096 metal::ushort b1, \
5097 metal::ushort b2, \
5098 metal::ushort b3, \
5099 metal::ushort b4, \
5100 metal::ushort b5, \
5101 metal::ushort b6, \
5102 metal::ushort b7) {{"
5103 )?;
5104 writeln!(
5105 self.out,
5106 "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5107 as_type<short>(metal::ushort(b3 << 8 | b2)), \
5108 as_type<short>(metal::ushort(b5 << 8 | b4)), \
5109 as_type<short>(metal::ushort(b7 << 8 | b6)));",
5110 back::INDENT
5111 )?;
5112 writeln!(self.out, "}}")?;
5113 Ok((name, 8, Some(VectorSize::Quad), Scalar::I32))
5114 }
5115 Unorm16 => {
5116 let name = self.namer.call("unpackUnorm16");
5117 writeln!(
5118 self.out,
5119 "float {name}(metal::ushort b0, \
5120 metal::ushort b1) {{"
5121 )?;
5122 writeln!(
5123 self.out,
5124 "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5125 back::INDENT
5126 )?;
5127 writeln!(self.out, "}}")?;
5128 Ok((name, 2, None, Scalar::F32))
5129 }
5130 Unorm16x2 => {
5131 let name = self.namer.call("unpackUnorm16x2");
5132 writeln!(
5133 self.out,
5134 "metal::float2 {name}(metal::ushort b0, \
5135 metal::ushort b1, \
5136 metal::ushort b2, \
5137 metal::ushort b3) {{"
5138 )?;
5139 writeln!(
5140 self.out,
5141 "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5142 float(b3 << 8 | b2) / 65535.0f);",
5143 back::INDENT
5144 )?;
5145 writeln!(self.out, "}}")?;
5146 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5147 }
5148 Unorm16x4 => {
5149 let name = self.namer.call("unpackUnorm16x4");
5150 writeln!(
5151 self.out,
5152 "metal::float4 {name}(metal::ushort b0, \
5153 metal::ushort b1, \
5154 metal::ushort b2, \
5155 metal::ushort b3, \
5156 metal::ushort b4, \
5157 metal::ushort b5, \
5158 metal::ushort b6, \
5159 metal::ushort b7) {{"
5160 )?;
5161 writeln!(
5162 self.out,
5163 "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5164 float(b3 << 8 | b2) / 65535.0f, \
5165 float(b5 << 8 | b4) / 65535.0f, \
5166 float(b7 << 8 | b6) / 65535.0f);",
5167 back::INDENT
5168 )?;
5169 writeln!(self.out, "}}")?;
5170 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5171 }
5172 Snorm16 => {
5173 let name = self.namer.call("unpackSnorm16");
5174 writeln!(
5175 self.out,
5176 "float {name}(metal::ushort b0, \
5177 metal::ushort b1) {{"
5178 )?;
5179 writeln!(
5180 self.out,
5181 "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5182 back::INDENT
5183 )?;
5184 writeln!(self.out, "}}")?;
5185 Ok((name, 2, None, Scalar::F32))
5186 }
5187 Snorm16x2 => {
5188 let name = self.namer.call("unpackSnorm16x2");
5189 writeln!(
5190 self.out,
5191 "metal::float2 {name}(uint b0, \
5192 uint b1, \
5193 uint b2, \
5194 uint b3) {{"
5195 )?;
5196 writeln!(
5197 self.out,
5198 "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5199 back::INDENT
5200 )?;
5201 writeln!(self.out, "}}")?;
5202 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5203 }
5204 Snorm16x4 => {
5205 let name = self.namer.call("unpackSnorm16x4");
5206 writeln!(
5207 self.out,
5208 "metal::float4 {name}(uint b0, \
5209 uint b1, \
5210 uint b2, \
5211 uint b3, \
5212 uint b4, \
5213 uint b5, \
5214 uint b6, \
5215 uint b7) {{"
5216 )?;
5217 writeln!(
5218 self.out,
5219 "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5220 metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5221 back::INDENT
5222 )?;
5223 writeln!(self.out, "}}")?;
5224 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5225 }
5226 Float16 => {
5227 let name = self.namer.call("unpackFloat16");
5228 writeln!(
5229 self.out,
5230 "float {name}(metal::ushort b0, \
5231 metal::ushort b1) {{"
5232 )?;
5233 writeln!(
5234 self.out,
5235 "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5236 back::INDENT
5237 )?;
5238 writeln!(self.out, "}}")?;
5239 Ok((name, 2, None, Scalar::F32))
5240 }
5241 Float16x2 => {
5242 let name = self.namer.call("unpackFloat16x2");
5243 writeln!(
5244 self.out,
5245 "metal::float2 {name}(metal::ushort b0, \
5246 metal::ushort b1, \
5247 metal::ushort b2, \
5248 metal::ushort b3) {{"
5249 )?;
5250 writeln!(
5251 self.out,
5252 "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5253 as_type<half>(metal::ushort(b3 << 8 | b2)));",
5254 back::INDENT
5255 )?;
5256 writeln!(self.out, "}}")?;
5257 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5258 }
5259 Float16x4 => {
5260 let name = self.namer.call("unpackFloat16x4");
5261 writeln!(
5262 self.out,
5263 "metal::float4 {name}(metal::ushort b0, \
5264 metal::ushort b1, \
5265 metal::ushort b2, \
5266 metal::ushort b3, \
5267 metal::ushort b4, \
5268 metal::ushort b5, \
5269 metal::ushort b6, \
5270 metal::ushort b7) {{"
5271 )?;
5272 writeln!(
5273 self.out,
5274 "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5275 as_type<half>(metal::ushort(b3 << 8 | b2)), \
5276 as_type<half>(metal::ushort(b5 << 8 | b4)), \
5277 as_type<half>(metal::ushort(b7 << 8 | b6)));",
5278 back::INDENT
5279 )?;
5280 writeln!(self.out, "}}")?;
5281 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5282 }
5283 Float32 => {
5284 let name = self.namer.call("unpackFloat32");
5285 writeln!(
5286 self.out,
5287 "float {name}(uint b0, \
5288 uint b1, \
5289 uint b2, \
5290 uint b3) {{"
5291 )?;
5292 writeln!(
5293 self.out,
5294 "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5295 back::INDENT
5296 )?;
5297 writeln!(self.out, "}}")?;
5298 Ok((name, 4, None, Scalar::F32))
5299 }
5300 Float32x2 => {
5301 let name = self.namer.call("unpackFloat32x2");
5302 writeln!(
5303 self.out,
5304 "metal::float2 {name}(uint b0, \
5305 uint b1, \
5306 uint b2, \
5307 uint b3, \
5308 uint b4, \
5309 uint b5, \
5310 uint b6, \
5311 uint b7) {{"
5312 )?;
5313 writeln!(
5314 self.out,
5315 "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5316 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5317 back::INDENT
5318 )?;
5319 writeln!(self.out, "}}")?;
5320 Ok((name, 8, Some(VectorSize::Bi), Scalar::F32))
5321 }
5322 Float32x3 => {
5323 let name = self.namer.call("unpackFloat32x3");
5324 writeln!(
5325 self.out,
5326 "metal::float3 {name}(uint b0, \
5327 uint b1, \
5328 uint b2, \
5329 uint b3, \
5330 uint b4, \
5331 uint b5, \
5332 uint b6, \
5333 uint b7, \
5334 uint b8, \
5335 uint b9, \
5336 uint b10, \
5337 uint b11) {{"
5338 )?;
5339 writeln!(
5340 self.out,
5341 "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5342 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5343 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5344 back::INDENT
5345 )?;
5346 writeln!(self.out, "}}")?;
5347 Ok((name, 12, Some(VectorSize::Tri), Scalar::F32))
5348 }
5349 Float32x4 => {
5350 let name = self.namer.call("unpackFloat32x4");
5351 writeln!(
5352 self.out,
5353 "metal::float4 {name}(uint b0, \
5354 uint b1, \
5355 uint b2, \
5356 uint b3, \
5357 uint b4, \
5358 uint b5, \
5359 uint b6, \
5360 uint b7, \
5361 uint b8, \
5362 uint b9, \
5363 uint b10, \
5364 uint b11, \
5365 uint b12, \
5366 uint b13, \
5367 uint b14, \
5368 uint b15) {{"
5369 )?;
5370 writeln!(
5371 self.out,
5372 "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5373 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5374 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5375 as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5376 back::INDENT
5377 )?;
5378 writeln!(self.out, "}}")?;
5379 Ok((name, 16, Some(VectorSize::Quad), Scalar::F32))
5380 }
5381 Uint32 => {
5382 let name = self.namer.call("unpackUint32");
5383 writeln!(
5384 self.out,
5385 "uint {name}(uint b0, \
5386 uint b1, \
5387 uint b2, \
5388 uint b3) {{"
5389 )?;
5390 writeln!(
5391 self.out,
5392 "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5393 back::INDENT
5394 )?;
5395 writeln!(self.out, "}}")?;
5396 Ok((name, 4, None, Scalar::U32))
5397 }
5398 Uint32x2 => {
5399 let name = self.namer.call("unpackUint32x2");
5400 writeln!(
5401 self.out,
5402 "uint2 {name}(uint b0, \
5403 uint b1, \
5404 uint b2, \
5405 uint b3, \
5406 uint b4, \
5407 uint b5, \
5408 uint b6, \
5409 uint b7) {{"
5410 )?;
5411 writeln!(
5412 self.out,
5413 "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5414 (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5415 back::INDENT
5416 )?;
5417 writeln!(self.out, "}}")?;
5418 Ok((name, 8, Some(VectorSize::Bi), Scalar::U32))
5419 }
5420 Uint32x3 => {
5421 let name = self.namer.call("unpackUint32x3");
5422 writeln!(
5423 self.out,
5424 "uint3 {name}(uint b0, \
5425 uint b1, \
5426 uint b2, \
5427 uint b3, \
5428 uint b4, \
5429 uint b5, \
5430 uint b6, \
5431 uint b7, \
5432 uint b8, \
5433 uint b9, \
5434 uint b10, \
5435 uint b11) {{"
5436 )?;
5437 writeln!(
5438 self.out,
5439 "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5440 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5441 (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5442 back::INDENT
5443 )?;
5444 writeln!(self.out, "}}")?;
5445 Ok((name, 12, Some(VectorSize::Tri), Scalar::U32))
5446 }
5447 Uint32x4 => {
5448 let name = self.namer.call("unpackUint32x4");
5449 writeln!(
5450 self.out,
5451 "{NAMESPACE}::uint4 {name}(uint b0, \
5452 uint b1, \
5453 uint b2, \
5454 uint b3, \
5455 uint b4, \
5456 uint b5, \
5457 uint b6, \
5458 uint b7, \
5459 uint b8, \
5460 uint b9, \
5461 uint b10, \
5462 uint b11, \
5463 uint b12, \
5464 uint b13, \
5465 uint b14, \
5466 uint b15) {{"
5467 )?;
5468 writeln!(
5469 self.out,
5470 "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5471 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5472 (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5473 (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5474 back::INDENT
5475 )?;
5476 writeln!(self.out, "}}")?;
5477 Ok((name, 16, Some(VectorSize::Quad), Scalar::U32))
5478 }
5479 Sint32 => {
5480 let name = self.namer.call("unpackSint32");
5481 writeln!(
5482 self.out,
5483 "int {name}(uint b0, \
5484 uint b1, \
5485 uint b2, \
5486 uint b3) {{"
5487 )?;
5488 writeln!(
5489 self.out,
5490 "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5491 back::INDENT
5492 )?;
5493 writeln!(self.out, "}}")?;
5494 Ok((name, 4, None, Scalar::I32))
5495 }
5496 Sint32x2 => {
5497 let name = self.namer.call("unpackSint32x2");
5498 writeln!(
5499 self.out,
5500 "metal::int2 {name}(uint b0, \
5501 uint b1, \
5502 uint b2, \
5503 uint b3, \
5504 uint b4, \
5505 uint b5, \
5506 uint b6, \
5507 uint b7) {{"
5508 )?;
5509 writeln!(
5510 self.out,
5511 "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5512 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5513 back::INDENT
5514 )?;
5515 writeln!(self.out, "}}")?;
5516 Ok((name, 8, Some(VectorSize::Bi), Scalar::I32))
5517 }
5518 Sint32x3 => {
5519 let name = self.namer.call("unpackSint32x3");
5520 writeln!(
5521 self.out,
5522 "metal::int3 {name}(uint b0, \
5523 uint b1, \
5524 uint b2, \
5525 uint b3, \
5526 uint b4, \
5527 uint b5, \
5528 uint b6, \
5529 uint b7, \
5530 uint b8, \
5531 uint b9, \
5532 uint b10, \
5533 uint b11) {{"
5534 )?;
5535 writeln!(
5536 self.out,
5537 "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5538 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5539 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5540 back::INDENT
5541 )?;
5542 writeln!(self.out, "}}")?;
5543 Ok((name, 12, Some(VectorSize::Tri), Scalar::I32))
5544 }
5545 Sint32x4 => {
5546 let name = self.namer.call("unpackSint32x4");
5547 writeln!(
5548 self.out,
5549 "metal::int4 {name}(uint b0, \
5550 uint b1, \
5551 uint b2, \
5552 uint b3, \
5553 uint b4, \
5554 uint b5, \
5555 uint b6, \
5556 uint b7, \
5557 uint b8, \
5558 uint b9, \
5559 uint b10, \
5560 uint b11, \
5561 uint b12, \
5562 uint b13, \
5563 uint b14, \
5564 uint b15) {{"
5565 )?;
5566 writeln!(
5567 self.out,
5568 "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5569 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5570 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5571 as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5572 back::INDENT
5573 )?;
5574 writeln!(self.out, "}}")?;
5575 Ok((name, 16, Some(VectorSize::Quad), Scalar::I32))
5576 }
5577 Unorm10_10_10_2 => {
5578 let name = self.namer.call("unpackUnorm10_10_10_2");
5579 writeln!(
5580 self.out,
5581 "metal::float4 {name}(uint b0, \
5582 uint b1, \
5583 uint b2, \
5584 uint b3) {{"
5585 )?;
5586 writeln!(
5587 self.out,
5588 "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5600 back::INDENT
5601 )?;
5602 writeln!(self.out, "}}")?;
5603 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5604 }
5605 Unorm8x4Bgra => {
5606 let name = self.namer.call("unpackUnorm8x4Bgra");
5607 writeln!(
5608 self.out,
5609 "metal::float4 {name}(metal::uchar b0, \
5610 metal::uchar b1, \
5611 metal::uchar b2, \
5612 metal::uchar b3) {{"
5613 )?;
5614 writeln!(
5615 self.out,
5616 "{}return metal::float4(float(b2) / 255.0f, \
5617 float(b1) / 255.0f, \
5618 float(b0) / 255.0f, \
5619 float(b3) / 255.0f);",
5620 back::INDENT
5621 )?;
5622 writeln!(self.out, "}}")?;
5623 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5624 }
5625 }
5626 }
5627
5628 fn write_wrapped_unary_op(
5629 &mut self,
5630 module: &crate::Module,
5631 func_ctx: &back::FunctionCtx,
5632 op: crate::UnaryOperator,
5633 operand: Handle<crate::Expression>,
5634 ) -> BackendResult {
5635 let operand_ty = func_ctx.resolve_type(operand, &module.types);
5636 match op {
5637 crate::UnaryOperator::Negate
5644 if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5645 {
5646 let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5647 return Ok(());
5648 };
5649 let wrapped = WrappedFunction::UnaryOp {
5650 op,
5651 ty: (vector_size, scalar),
5652 };
5653 if !self.wrapped_functions.insert(wrapped) {
5654 return Ok(());
5655 }
5656
5657 let unsigned_scalar = crate::Scalar {
5658 kind: crate::ScalarKind::Uint,
5659 ..scalar
5660 };
5661 let mut type_name = String::new();
5662 let mut unsigned_type_name = String::new();
5663 match vector_size {
5664 None => {
5665 put_numeric_type(&mut type_name, scalar, &[])?;
5666 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5667 }
5668 Some(size) => {
5669 put_numeric_type(&mut type_name, scalar, &[size])?;
5670 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5671 }
5672 };
5673
5674 writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5675 let level = back::Level(1);
5676 if scalar.width < 4 {
5680 writeln!(
5681 self.out,
5682 "{level}return as_type<{type_name}>(static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val)));"
5683 )?;
5684 } else {
5685 writeln!(
5686 self.out,
5687 "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5688 )?;
5689 }
5690 writeln!(self.out, "}}")?;
5691 writeln!(self.out)?;
5692 }
5693 _ => {}
5694 }
5695 Ok(())
5696 }
5697
5698 fn write_wrapped_binary_op(
5699 &mut self,
5700 module: &crate::Module,
5701 func_ctx: &back::FunctionCtx,
5702 expr: Handle<crate::Expression>,
5703 op: crate::BinaryOperator,
5704 left: Handle<crate::Expression>,
5705 right: Handle<crate::Expression>,
5706 ) -> BackendResult {
5707 let expr_ty = func_ctx.resolve_type(expr, &module.types);
5708 let left_ty = func_ctx.resolve_type(left, &module.types);
5709 let right_ty = func_ctx.resolve_type(right, &module.types);
5710 match (op, expr_ty.scalar_kind()) {
5711 (
5718 crate::BinaryOperator::Divide,
5719 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5720 ) => {
5721 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5722 return Ok(());
5723 };
5724 let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5725 return Ok(());
5726 };
5727 let wrapped = WrappedFunction::BinaryOp {
5728 op,
5729 left_ty: left_wrapped_ty,
5730 right_ty: right_wrapped_ty,
5731 };
5732 if !self.wrapped_functions.insert(wrapped) {
5733 return Ok(());
5734 }
5735
5736 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5737 return Ok(());
5738 };
5739 let mut type_name = String::new();
5740 match vector_size {
5741 None => put_numeric_type(&mut type_name, scalar, &[])?,
5742 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5743 };
5744 writeln!(
5745 self.out,
5746 "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5747 )?;
5748 let level = back::Level(1);
5749 let (lp, rp) = if scalar.width < 4 {
5753 (format!("{type_name}("), ")".to_string())
5754 } else {
5755 (String::new(), String::new())
5756 };
5757 match scalar.kind {
5758 crate::ScalarKind::Sint => {
5759 let min_val = match scalar.width {
5760 2 => crate::Literal::I16(i16::MIN),
5761 4 => crate::Literal::I32(i32::MIN),
5762 8 => crate::Literal::I64(i64::MIN),
5763 _ => {
5764 return Err(Error::GenericValidation(format!(
5765 "Unexpected width for scalar {scalar:?}"
5766 )));
5767 }
5768 };
5769 write!(
5770 self.out,
5771 "{level}return lhs / metal::select(rhs, {lp}1{rp}, (lhs == "
5772 )?;
5773 self.put_literal(min_val)?;
5774 writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?
5775 }
5776 crate::ScalarKind::Uint => {
5777 let suffix = if scalar.width < 4 { "" } else { "u" };
5778 writeln!(
5779 self.out,
5780 "{level}return lhs / metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5781 )?
5782 }
5783 _ => unreachable!(),
5784 }
5785 writeln!(self.out, "}}")?;
5786 writeln!(self.out)?;
5787 }
5788 (
5801 crate::BinaryOperator::Modulo,
5802 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5803 ) => {
5804 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5805 return Ok(());
5806 };
5807 let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5808 else {
5809 return Ok(());
5810 };
5811 let wrapped = WrappedFunction::BinaryOp {
5812 op,
5813 left_ty: left_wrapped_ty,
5814 right_ty: (right_vector_size, right_scalar),
5815 };
5816 if !self.wrapped_functions.insert(wrapped) {
5817 return Ok(());
5818 }
5819
5820 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5821 return Ok(());
5822 };
5823 let mut type_name = String::new();
5824 match vector_size {
5825 None => put_numeric_type(&mut type_name, scalar, &[])?,
5826 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5827 };
5828 let mut rhs_type_name = String::new();
5829 match right_vector_size {
5830 None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5831 Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5832 };
5833
5834 writeln!(
5835 self.out,
5836 "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5837 )?;
5838 let level = back::Level(1);
5839 let (lp, rp) = if scalar.width < 4 {
5840 (format!("{type_name}("), ")".to_string())
5841 } else {
5842 (String::new(), String::new())
5843 };
5844 match scalar.kind {
5845 crate::ScalarKind::Sint => {
5846 let min_val = match scalar.width {
5847 2 => crate::Literal::I16(i16::MIN),
5848 4 => crate::Literal::I32(i32::MIN),
5849 8 => crate::Literal::I64(i64::MIN),
5850 _ => {
5851 return Err(Error::GenericValidation(format!(
5852 "Unexpected width for scalar {scalar:?}"
5853 )));
5854 }
5855 };
5856 write!(
5857 self.out,
5858 "{level}{rhs_type_name} divisor = metal::select(rhs, {lp}1{rp}, (lhs == "
5859 )?;
5860 self.put_literal(min_val)?;
5861 writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?;
5862 writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5863 }
5864 crate::ScalarKind::Uint => {
5865 let suffix = if scalar.width < 4 { "" } else { "u" };
5866 writeln!(
5867 self.out,
5868 "{level}return lhs % metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5869 )?
5870 }
5871 _ => unreachable!(),
5872 }
5873 writeln!(self.out, "}}")?;
5874 writeln!(self.out)?;
5875 }
5876 _ => {}
5877 }
5878 Ok(())
5879 }
5880
5881 fn get_dot_wrapper_function_helper_name(
5887 &self,
5888 scalar: crate::Scalar,
5889 size: crate::VectorSize,
5890 ) -> String {
5891 debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5893
5894 let type_name = scalar.to_msl_name();
5895 let size_suffix = common::vector_size_str(size);
5896 format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5897 }
5898
5899 #[allow(clippy::too_many_arguments)]
5900 fn write_wrapped_math_function(
5901 &mut self,
5902 module: &crate::Module,
5903 func_ctx: &back::FunctionCtx,
5904 fun: crate::MathFunction,
5905 arg: Handle<crate::Expression>,
5906 _arg1: Option<Handle<crate::Expression>>,
5907 _arg2: Option<Handle<crate::Expression>>,
5908 _arg3: Option<Handle<crate::Expression>>,
5909 ) -> BackendResult {
5910 let arg_ty = func_ctx.resolve_type(arg, &module.types);
5911 match fun {
5912 crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5920 let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5921 return Ok(());
5922 };
5923 let wrapped = WrappedFunction::Math {
5924 fun,
5925 arg_ty: (vector_size, scalar),
5926 };
5927 if !self.wrapped_functions.insert(wrapped) {
5928 return Ok(());
5929 }
5930
5931 let unsigned_scalar = crate::Scalar {
5932 kind: crate::ScalarKind::Uint,
5933 ..scalar
5934 };
5935 let mut type_name = String::new();
5936 let mut unsigned_type_name = String::new();
5937 match vector_size {
5938 None => {
5939 put_numeric_type(&mut type_name, scalar, &[])?;
5940 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5941 }
5942 Some(size) => {
5943 put_numeric_type(&mut type_name, scalar, &[size])?;
5944 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5945 }
5946 };
5947
5948 writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5949 let level = back::Level(1);
5950 let zero = if scalar.width < 4 {
5951 format!("{type_name}(0)")
5952 } else {
5953 "0".to_string()
5954 };
5955 let neg_expr = if scalar.width < 4 {
5956 format!(
5957 "static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val))"
5958 )
5959 } else {
5960 format!("-as_type<{unsigned_type_name}>(val)")
5961 };
5962 writeln!(self.out, "{level}return metal::select(as_type<{type_name}>({neg_expr}), val, val >= {zero});")?;
5963 writeln!(self.out, "}}")?;
5964 writeln!(self.out)?;
5965 }
5966
5967 crate::MathFunction::Dot => match *arg_ty {
5968 crate::TypeInner::Vector { size, scalar }
5969 if matches!(
5970 scalar.kind,
5971 crate::ScalarKind::Sint | crate::ScalarKind::Uint
5972 ) =>
5973 {
5974 let wrapped = WrappedFunction::Math {
5976 fun,
5977 arg_ty: (Some(size), scalar),
5978 };
5979 if !self.wrapped_functions.insert(wrapped) {
5980 return Ok(());
5981 }
5982
5983 let mut vec_ty = String::new();
5984 put_numeric_type(&mut vec_ty, scalar, &[size])?;
5985 let mut ret_ty = String::new();
5986 put_numeric_type(&mut ret_ty, scalar, &[])?;
5987
5988 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
5989
5990 writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
5992 let level = back::Level(1);
5993 write!(self.out, "{level}return ")?;
5994 self.put_dot_product("a", "b", size as usize, |writer, name, index| {
5995 write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
5996 Ok(())
5997 })?;
5998 writeln!(self.out, ";")?;
5999 writeln!(self.out, "}}")?;
6000 writeln!(self.out)?;
6001 }
6002 _ => {}
6003 },
6004
6005 _ => {}
6006 }
6007 Ok(())
6008 }
6009
6010 fn write_wrapped_cast(
6011 &mut self,
6012 module: &crate::Module,
6013 func_ctx: &back::FunctionCtx,
6014 expr: Handle<crate::Expression>,
6015 kind: crate::ScalarKind,
6016 convert: Option<crate::Bytes>,
6017 ) -> BackendResult {
6018 let src_ty = func_ctx.resolve_type(expr, &module.types);
6029 let Some(width) = convert else {
6030 return Ok(());
6031 };
6032 let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6033 return Ok(());
6034 };
6035 let dst_scalar = crate::Scalar { kind, width };
6036 if src_scalar.kind != crate::ScalarKind::Float
6037 || (dst_scalar.kind != crate::ScalarKind::Sint
6038 && dst_scalar.kind != crate::ScalarKind::Uint)
6039 {
6040 return Ok(());
6041 }
6042 let wrapped = WrappedFunction::Cast {
6043 src_scalar,
6044 vector_size,
6045 dst_scalar,
6046 };
6047 if !self.wrapped_functions.insert(wrapped) {
6048 return Ok(());
6049 }
6050 let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6051
6052 let mut src_type_name = String::new();
6053 match vector_size {
6054 None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6055 Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6056 };
6057 let mut dst_type_name = String::new();
6058 match vector_size {
6059 None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6060 Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6061 };
6062 let fun_name = match dst_scalar {
6063 crate::Scalar::I32 => F2I32_FUNCTION,
6064 crate::Scalar::U32 => F2U32_FUNCTION,
6065 crate::Scalar::I64 => F2I64_FUNCTION,
6066 crate::Scalar::U64 => F2U64_FUNCTION,
6067 _ => unreachable!(),
6068 };
6069
6070 writeln!(
6071 self.out,
6072 "{dst_type_name} {fun_name}({src_type_name} value) {{"
6073 )?;
6074 let level = back::Level(1);
6075 write!(
6076 self.out,
6077 "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6078 )?;
6079 self.put_literal(min)?;
6080 write!(self.out, ", ")?;
6081 self.put_literal(max)?;
6082 writeln!(self.out, "));")?;
6083 writeln!(self.out, "}}")?;
6084 writeln!(self.out)?;
6085 Ok(())
6086 }
6087
6088 fn write_convert_yuv_to_rgb_and_return(
6096 &mut self,
6097 level: back::Level,
6098 y: &str,
6099 uv: &str,
6100 params: &str,
6101 ) -> BackendResult {
6102 let l1 = level;
6103 let l2 = l1.next();
6104
6105 writeln!(
6107 self.out,
6108 "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6109 )?;
6110
6111 writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6114 writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6115 writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6116 writeln!(
6117 self.out,
6118 "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6119 )?;
6120
6121 writeln!(
6124 self.out,
6125 "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6126 )?;
6127
6128 writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6131 writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6132 writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6133 writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6134
6135 writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6136 Ok(())
6137 }
6138
6139 #[allow(clippy::too_many_arguments)]
6140 fn write_wrapped_image_load(
6141 &mut self,
6142 module: &crate::Module,
6143 func_ctx: &back::FunctionCtx,
6144 image: Handle<crate::Expression>,
6145 _coordinate: Handle<crate::Expression>,
6146 _array_index: Option<Handle<crate::Expression>>,
6147 _sample: Option<Handle<crate::Expression>>,
6148 _level: Option<Handle<crate::Expression>>,
6149 ) -> BackendResult {
6150 let class = match *func_ctx.resolve_type(image, &module.types) {
6152 crate::TypeInner::Image { class, .. } => class,
6153 _ => unreachable!(),
6154 };
6155 if class != crate::ImageClass::External {
6156 return Ok(());
6157 }
6158 let wrapped = WrappedFunction::ImageLoad { class };
6159 if !self.wrapped_functions.insert(wrapped) {
6160 return Ok(());
6161 }
6162
6163 writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6164 let l1 = back::Level(1);
6165 let l2 = l1.next();
6166 let l3 = l2.next();
6167 writeln!(
6168 self.out,
6169 "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6170 )?;
6171 writeln!(
6175 self.out,
6176 "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6177 )?;
6178 writeln!(
6179 self.out,
6180 "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6181 )?;
6182
6183 writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6185 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6186 writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6188 writeln!(self.out, "{l1}}} else {{")?;
6189
6190 writeln!(
6192 self.out,
6193 "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6194 )?;
6195 writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6196
6197 writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6199
6200 writeln!(self.out, "{l2}float2 uv;")?;
6201 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6202 writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6204 writeln!(self.out, "{l2}}} else {{")?;
6205 writeln!(
6207 self.out,
6208 "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6209 )?;
6210 writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6211 writeln!(
6212 self.out,
6213 "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6214 )?;
6215 writeln!(self.out, "{l2}}}")?;
6216
6217 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6218
6219 writeln!(self.out, "{l1}}}")?;
6220 writeln!(self.out, "}}")?;
6221 writeln!(self.out)?;
6222 Ok(())
6223 }
6224
6225 #[allow(clippy::too_many_arguments)]
6226 fn write_wrapped_image_sample(
6227 &mut self,
6228 module: &crate::Module,
6229 func_ctx: &back::FunctionCtx,
6230 image: Handle<crate::Expression>,
6231 _sampler: Handle<crate::Expression>,
6232 _gather: Option<crate::SwizzleComponent>,
6233 _coordinate: Handle<crate::Expression>,
6234 _array_index: Option<Handle<crate::Expression>>,
6235 _offset: Option<Handle<crate::Expression>>,
6236 _level: crate::SampleLevel,
6237 _depth_ref: Option<Handle<crate::Expression>>,
6238 clamp_to_edge: bool,
6239 ) -> BackendResult {
6240 if !clamp_to_edge {
6243 return Ok(());
6244 }
6245 let class = match *func_ctx.resolve_type(image, &module.types) {
6246 crate::TypeInner::Image { class, .. } => class,
6247 _ => unreachable!(),
6248 };
6249 let wrapped = WrappedFunction::ImageSample {
6250 class,
6251 clamp_to_edge: true,
6252 };
6253 if !self.wrapped_functions.insert(wrapped) {
6254 return Ok(());
6255 }
6256 match class {
6257 crate::ImageClass::External => {
6258 writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6259 let l1 = back::Level(1);
6260 let l2 = l1.next();
6261 let l3 = l2.next();
6262 writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6263 writeln!(
6264 self.out,
6265 "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6266 )?;
6267
6268 writeln!(
6276 self.out,
6277 "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6278 )?;
6279 writeln!(
6280 self.out,
6281 "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6282 )?;
6283 writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6284 writeln!(
6285 self.out,
6286 "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6287 )?;
6288 writeln!(
6289 self.out,
6290 "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6291 )?;
6292 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6293 writeln!(
6295 self.out,
6296 "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6297 )?;
6298 writeln!(self.out, "{l1}}} else {{")?;
6299 writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6300 writeln!(
6301 self.out,
6302 "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6303 )?;
6304 writeln!(
6305 self.out,
6306 "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6307 )?;
6308
6309 writeln!(
6311 self.out,
6312 "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6313 )?;
6314 writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6315 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6316 writeln!(
6318 self.out,
6319 "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6320 )?;
6321 writeln!(self.out, "{l2}}} else {{")?;
6322 writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6324 writeln!(
6325 self.out,
6326 "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6327 )?;
6328 writeln!(
6329 self.out,
6330 "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6331 )?;
6332 writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6333 writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6334 writeln!(self.out, "{l2}}}")?;
6335
6336 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6337
6338 writeln!(self.out, "{l1}}}")?;
6339 writeln!(self.out, "}}")?;
6340 writeln!(self.out)?;
6341 }
6342 _ => {
6343 writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6344 let l1 = back::Level(1);
6345 writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6346 writeln!(
6347 self.out,
6348 "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6349 )?;
6350 writeln!(self.out, "}}")?;
6351 writeln!(self.out)?;
6352 }
6353 }
6354 Ok(())
6355 }
6356
6357 fn write_wrapped_image_query(
6358 &mut self,
6359 module: &crate::Module,
6360 func_ctx: &back::FunctionCtx,
6361 image: Handle<crate::Expression>,
6362 query: crate::ImageQuery,
6363 ) -> BackendResult {
6364 if !matches!(query, crate::ImageQuery::Size { .. }) {
6366 return Ok(());
6367 }
6368 let class = match *func_ctx.resolve_type(image, &module.types) {
6369 crate::TypeInner::Image { class, .. } => class,
6370 _ => unreachable!(),
6371 };
6372 if class != crate::ImageClass::External {
6373 return Ok(());
6374 }
6375 let wrapped = WrappedFunction::ImageQuerySize { class };
6376 if !self.wrapped_functions.insert(wrapped) {
6377 return Ok(());
6378 }
6379 writeln!(
6380 self.out,
6381 "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6382 )?;
6383 let l1 = back::Level(1);
6384 let l2 = l1.next();
6385 writeln!(
6386 self.out,
6387 "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6388 )?;
6389 writeln!(self.out, "{l2}return tex.params.size;")?;
6390 writeln!(self.out, "{l1}}} else {{")?;
6391 writeln!(
6393 self.out,
6394 "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6395 )?;
6396 writeln!(self.out, "{l1}}}")?;
6397 writeln!(self.out, "}}")?;
6398 writeln!(self.out)?;
6399 Ok(())
6400 }
6401
6402 fn write_wrapped_cooperative_load(
6403 &mut self,
6404 module: &crate::Module,
6405 func_ctx: &back::FunctionCtx,
6406 columns: crate::CooperativeSize,
6407 rows: crate::CooperativeSize,
6408 pointer: Handle<crate::Expression>,
6409 ) -> BackendResult {
6410 let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6411 let space = ptr_ty.pointer_space().unwrap();
6412 let space_name = space.to_msl_name().unwrap_or_default();
6413 let scalar = ptr_ty
6414 .pointer_base_type()
6415 .unwrap()
6416 .inner_with(&module.types)
6417 .scalar()
6418 .unwrap();
6419 let wrapped = WrappedFunction::CooperativeLoad {
6420 space_name,
6421 columns,
6422 rows,
6423 scalar,
6424 };
6425 if !self.wrapped_functions.insert(wrapped) {
6426 return Ok(());
6427 }
6428 let scalar_name = scalar.to_msl_name();
6429 writeln!(
6430 self.out,
6431 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6432 columns as u32, rows as u32,
6433 )?;
6434 let l1 = back::Level(1);
6435 writeln!(
6436 self.out,
6437 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6438 columns as u32, rows as u32
6439 )?;
6440 let matrix_origin = "0";
6441 writeln!(
6442 self.out,
6443 "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6444 )?;
6445 writeln!(self.out, "{l1}return m;")?;
6446 writeln!(self.out, "}}")?;
6447 writeln!(self.out)?;
6448 Ok(())
6449 }
6450
6451 fn write_wrapped_cooperative_multiply_add(
6452 &mut self,
6453 module: &crate::Module,
6454 func_ctx: &back::FunctionCtx,
6455 space: crate::AddressSpace,
6456 a: Handle<crate::Expression>,
6457 b: Handle<crate::Expression>,
6458 ) -> BackendResult {
6459 let space_name = space.to_msl_name().unwrap_or_default();
6460 let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6461 crate::TypeInner::CooperativeMatrix {
6462 columns,
6463 rows,
6464 scalar,
6465 ..
6466 } => (columns, rows, scalar),
6467 _ => unreachable!(),
6468 };
6469 let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6470 crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6471 _ => unreachable!(),
6472 };
6473 let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6474 space_name,
6475 columns: b_c,
6476 rows: a_r,
6477 intermediate: a_c,
6478 scalar,
6479 };
6480 if !self.wrapped_functions.insert(wrapped) {
6481 return Ok(());
6482 }
6483 let scalar_name = scalar.to_msl_name();
6484 writeln!(
6485 self.out,
6486 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6487 b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6488 )?;
6489 let l1 = back::Level(1);
6490 writeln!(
6491 self.out,
6492 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6493 b_c as u32, a_r as u32
6494 )?;
6495 writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6496 writeln!(self.out, "{l1}return d;")?;
6497 writeln!(self.out, "}}")?;
6498 writeln!(self.out)?;
6499 Ok(())
6500 }
6501
6502 pub(super) fn write_wrapped_functions(
6503 &mut self,
6504 module: &crate::Module,
6505 func_ctx: &back::FunctionCtx,
6506 ) -> BackendResult {
6507 for (expr_handle, expr) in func_ctx.expressions.iter() {
6508 match *expr {
6509 crate::Expression::Unary { op, expr: operand } => {
6510 self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6511 }
6512 crate::Expression::Binary { op, left, right } => {
6513 self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6514 }
6515 crate::Expression::Math {
6516 fun,
6517 arg,
6518 arg1,
6519 arg2,
6520 arg3,
6521 } => {
6522 self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6523 }
6524 crate::Expression::As {
6525 expr,
6526 kind,
6527 convert,
6528 } => {
6529 self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6530 }
6531 crate::Expression::ImageLoad {
6532 image,
6533 coordinate,
6534 array_index,
6535 sample,
6536 level,
6537 } => {
6538 self.write_wrapped_image_load(
6539 module,
6540 func_ctx,
6541 image,
6542 coordinate,
6543 array_index,
6544 sample,
6545 level,
6546 )?;
6547 }
6548 crate::Expression::ImageSample {
6549 image,
6550 sampler,
6551 gather,
6552 coordinate,
6553 array_index,
6554 offset,
6555 level,
6556 depth_ref,
6557 clamp_to_edge,
6558 } => {
6559 self.write_wrapped_image_sample(
6560 module,
6561 func_ctx,
6562 image,
6563 sampler,
6564 gather,
6565 coordinate,
6566 array_index,
6567 offset,
6568 level,
6569 depth_ref,
6570 clamp_to_edge,
6571 )?;
6572 }
6573 crate::Expression::ImageQuery { image, query } => {
6574 self.write_wrapped_image_query(module, func_ctx, image, query)?;
6575 }
6576 crate::Expression::CooperativeLoad {
6577 columns,
6578 rows,
6579 role: _,
6580 ref data,
6581 } => {
6582 self.write_wrapped_cooperative_load(
6583 module,
6584 func_ctx,
6585 columns,
6586 rows,
6587 data.pointer,
6588 )?;
6589 }
6590 crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6591 let space = crate::AddressSpace::Private;
6592 self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
6593 }
6594 crate::Expression::RayQueryGetIntersection { committed, .. } => {
6595 self.write_rq_get_intersection_function(module, committed)?;
6596 }
6597 _ => {}
6598 }
6599 }
6600
6601 Ok(())
6602 }
6603
6604 fn write_functions(
6606 &mut self,
6607 module: &crate::Module,
6608 mod_info: &valid::ModuleInfo,
6609 options: &Options,
6610 pipeline_options: &PipelineOptions,
6611 ) -> Result<TranslationInfo, Error> {
6612 use back::msl::VertexFormat;
6613
6614 struct AttributeMappingResolved {
6617 ty_name: String,
6618 dimension: Option<crate::VectorSize>,
6619 scalar: crate::Scalar,
6620 name: String,
6621 }
6622 let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6623
6624 struct VertexBufferMappingResolved<'a> {
6625 id: u32,
6626 stride: u32,
6627 step_mode: back::msl::VertexBufferStepMode,
6628 ty_name: String,
6629 param_name: String,
6630 elem_name: String,
6631 attributes: &'a Vec<back::msl::AttributeMapping>,
6632 }
6633 let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6634
6635 struct UnpackingFunction {
6637 name: String,
6638 byte_count: u32,
6639 dimension: Option<crate::VectorSize>,
6640 scalar: crate::Scalar,
6641 }
6642 let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6643
6644 let mut needs_vertex_id = false;
6650 let v_id = self.namer.call("v_id");
6651
6652 let mut needs_instance_id = false;
6653 let i_id = self.namer.call("i_id");
6654 if pipeline_options.vertex_pulling_transform {
6655 for vbm in &pipeline_options.vertex_buffer_mappings {
6656 let buffer_id = vbm.id;
6657 let buffer_stride = vbm.stride;
6658
6659 assert!(
6660 buffer_stride > 0,
6661 "Vertex pulling requires a non-zero buffer stride."
6662 );
6663
6664 match vbm.step_mode {
6665 back::msl::VertexBufferStepMode::Constant => {}
6666 back::msl::VertexBufferStepMode::ByVertex => {
6667 needs_vertex_id = true;
6668 }
6669 back::msl::VertexBufferStepMode::ByInstance => {
6670 needs_instance_id = true;
6671 }
6672 }
6673
6674 let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6675 let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6676 let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6677
6678 vbm_resolved.push(VertexBufferMappingResolved {
6679 id: buffer_id,
6680 stride: buffer_stride,
6681 step_mode: vbm.step_mode,
6682 ty_name: buffer_ty,
6683 param_name: buffer_param,
6684 elem_name: buffer_elem,
6685 attributes: &vbm.attributes,
6686 });
6687
6688 for attribute in &vbm.attributes {
6690 if unpacking_functions.contains_key(&attribute.format) {
6691 continue;
6692 }
6693 let (name, byte_count, dimension, scalar) =
6694 match self.write_unpacking_function(attribute.format) {
6695 Ok((name, byte_count, dimension, scalar)) => {
6696 (name, byte_count, dimension, scalar)
6697 }
6698 _ => {
6699 continue;
6700 }
6701 };
6702 unpacking_functions.insert(
6703 attribute.format,
6704 UnpackingFunction {
6705 name,
6706 byte_count,
6707 dimension,
6708 scalar,
6709 },
6710 );
6711 }
6712 }
6713 }
6714
6715 let mut pass_through_globals = Vec::new();
6716 for (fun_handle, fun) in module.functions.iter() {
6717 log::trace!(
6718 "function {:?}, handle {:?}",
6719 fun.name.as_deref().unwrap_or("(anonymous)"),
6720 fun_handle
6721 );
6722
6723 let ctx = back::FunctionCtx {
6724 ty: back::FunctionType::Function(fun_handle),
6725 info: &mod_info[fun_handle],
6726 expressions: &fun.expressions,
6727 named_expressions: &fun.named_expressions,
6728 };
6729
6730 writeln!(self.out)?;
6731 self.write_wrapped_functions(module, &ctx)?;
6732
6733 let fun_info = &mod_info[fun_handle];
6734 pass_through_globals.clear();
6735 let mut needs_buffer_sizes = false;
6736 for (handle, var) in module.global_variables.iter() {
6737 if !fun_info[handle].is_empty() {
6738 if var.space.needs_pass_through() {
6739 pass_through_globals.push(handle);
6740 }
6741 needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6742 }
6743 }
6744
6745 let fun_name = &self.names[&NameKey::Function(fun_handle)];
6746 match fun.result {
6747 Some(ref result) => {
6748 let ty_name = TypeContext {
6749 handle: result.ty,
6750 gctx: module.to_ctx(),
6751 names: &self.names,
6752 access: crate::StorageAccess::empty(),
6753 first_time: false,
6754 };
6755 write!(self.out, "{ty_name}")?;
6756 }
6757 None => {
6758 write!(self.out, "void")?;
6759 }
6760 }
6761 writeln!(self.out, " {fun_name}(")?;
6762
6763 for (index, arg) in fun.arguments.iter().enumerate() {
6764 let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6765 let param_type_name = TypeContext {
6766 handle: arg.ty,
6767 gctx: module.to_ctx(),
6768 names: &self.names,
6769 access: crate::StorageAccess::empty(),
6770 first_time: false,
6771 };
6772 let separator = separate(
6773 !pass_through_globals.is_empty()
6774 || index + 1 != fun.arguments.len()
6775 || needs_buffer_sizes,
6776 );
6777 writeln!(
6778 self.out,
6779 "{}{} {}{}",
6780 back::INDENT,
6781 param_type_name,
6782 name,
6783 separator
6784 )?;
6785 }
6786 for (index, &handle) in pass_through_globals.iter().enumerate() {
6787 let tyvar = TypedGlobalVariable {
6788 module,
6789 names: &self.names,
6790 handle,
6791 usage: fun_info[handle],
6792 reference: true,
6793 };
6794 let separator =
6795 separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6796 write!(self.out, "{}", back::INDENT)?;
6797 tyvar.try_fmt(&mut self.out)?;
6798 writeln!(self.out, "{separator}")?;
6799 }
6800
6801 if needs_buffer_sizes {
6802 writeln!(
6803 self.out,
6804 "{}constant _mslBufferSizes& _buffer_sizes",
6805 back::INDENT
6806 )?;
6807 }
6808
6809 writeln!(self.out, ") {{")?;
6810
6811 let guarded_indices =
6812 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6813
6814 let context = StatementContext {
6815 expression: ExpressionContext {
6816 function: fun,
6817 origin: FunctionOrigin::Handle(fun_handle),
6818 info: fun_info,
6819 lang_version: options.lang_version,
6820 policies: options.bounds_check_policies,
6821 guarded_indices,
6822 module,
6823 mod_info,
6824 pipeline_options,
6825 force_loop_bounding: options.force_loop_bounding,
6826 },
6827 result_struct: None,
6828 };
6829
6830 self.put_locals(&context.expression)?;
6831 self.update_expressions_to_bake(fun, fun_info, &context.expression);
6832 self.put_block(back::Level(1), &fun.body, &context)?;
6833 writeln!(self.out, "}}")?;
6834 self.named_expressions.clear();
6835 }
6836
6837 let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6838 .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6839
6840 let mut info = TranslationInfo {
6841 entry_point_names: Vec::with_capacity(ep_range.len()),
6842 };
6843
6844 for ep_index in ep_range {
6845 let ep = &module.entry_points[ep_index];
6846 let fun = &ep.function;
6847 let fun_info = mod_info.get_entry_point(ep_index);
6848 let mut ep_error = None;
6849
6850 let mut v_existing_id = None;
6854 let mut i_existing_id = None;
6855
6856 log::trace!(
6857 "entry point {:?}, index {:?}",
6858 fun.name.as_deref().unwrap_or("(anonymous)"),
6859 ep_index
6860 );
6861
6862 let ctx = back::FunctionCtx {
6863 ty: back::FunctionType::EntryPoint(ep_index as u16),
6864 info: fun_info,
6865 expressions: &fun.expressions,
6866 named_expressions: &fun.named_expressions,
6867 };
6868
6869 self.write_wrapped_functions(module, &ctx)?;
6870
6871 let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6872 crate::ShaderStage::Vertex => (
6873 Some("vertex"),
6874 LocationMode::VertexInput,
6875 LocationMode::VertexOutput,
6876 true,
6877 ),
6878 crate::ShaderStage::Fragment => (
6879 Some("fragment"),
6880 LocationMode::FragmentInput,
6881 LocationMode::FragmentOutput,
6882 false,
6883 ),
6884 crate::ShaderStage::Compute => (
6885 Some("kernel"),
6886 LocationMode::Uniform,
6887 LocationMode::Uniform,
6888 false,
6889 ),
6890 crate::ShaderStage::Task => {
6891 (None, LocationMode::Uniform, LocationMode::Uniform, false)
6892 }
6893 crate::ShaderStage::Mesh => {
6894 (None, LocationMode::Uniform, LocationMode::MeshOutput, false)
6895 }
6896 crate::ShaderStage::RayGeneration
6897 | crate::ShaderStage::AnyHit
6898 | crate::ShaderStage::ClosestHit
6899 | crate::ShaderStage::Miss => unimplemented!(),
6900 };
6901
6902 let do_vertex_pulling = can_vertex_pull
6904 && pipeline_options.vertex_pulling_transform
6905 && !pipeline_options.vertex_buffer_mappings.is_empty();
6906
6907 let needs_buffer_sizes = do_vertex_pulling
6909 || module
6910 .global_variables
6911 .iter()
6912 .filter(|&(handle, _)| !fun_info[handle].is_empty())
6913 .any(|(_, var)| needs_array_length(var.ty, &module.types));
6914
6915 if !options.fake_missing_bindings {
6918 for (var_handle, var) in module.global_variables.iter() {
6919 if fun_info[var_handle].is_empty() {
6920 continue;
6921 }
6922 match var.space {
6923 crate::AddressSpace::Uniform
6924 | crate::AddressSpace::Storage { .. }
6925 | crate::AddressSpace::Handle => {
6926 let br = match var.binding {
6927 Some(ref br) => br,
6928 None => {
6929 let var_name = var.name.clone().unwrap_or_default();
6930 ep_error =
6931 Some(super::EntryPointError::MissingBinding(var_name));
6932 break;
6933 }
6934 };
6935 let target = options.get_resource_binding_target(ep, br);
6936 let good = match target {
6937 Some(target) => {
6938 match module.types[var.ty].inner {
6942 crate::TypeInner::Image {
6943 class: crate::ImageClass::External,
6944 ..
6945 } => target.external_texture.is_some(),
6946 crate::TypeInner::Image { .. } => target.texture.is_some(),
6947 crate::TypeInner::Sampler { .. } => {
6948 target.sampler.is_some()
6949 }
6950 _ => target.buffer.is_some(),
6951 }
6952 }
6953 None => false,
6954 };
6955 if !good {
6956 ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6957 break;
6958 }
6959 }
6960 crate::AddressSpace::Immediate => {
6961 if let Err(e) = options.resolve_immediates(ep) {
6962 ep_error = Some(e);
6963 break;
6964 }
6965 }
6966 crate::AddressSpace::Function
6967 | crate::AddressSpace::Private
6968 | crate::AddressSpace::WorkGroup
6969 | crate::AddressSpace::TaskPayload => {}
6970 crate::AddressSpace::RayPayload
6971 | crate::AddressSpace::IncomingRayPayload => unimplemented!(),
6972 }
6973 }
6974 if needs_buffer_sizes {
6975 if let Err(err) = options.resolve_sizes_buffer(ep) {
6976 ep_error = Some(err);
6977 }
6978 }
6979 }
6980
6981 if let Some(err) = ep_error {
6982 info.entry_point_names.push(Err(err));
6983 continue;
6984 }
6985 let fun_name = self.names[&NameKey::EntryPoint(ep_index as _)].clone();
6986 info.entry_point_names.push(Ok(fun_name.clone()));
6987
6988 writeln!(self.out)?;
6989
6990 let mut flattened_member_names = FastHashMap::default();
6996 let mut varyings_namer = proc::Namer::default();
6998
6999 let mut empty_names = FastHashMap::default(); varyings_namer.reset(
7001 module,
7002 &super::keywords::RESERVED_SET,
7003 proc::KeywordSet::empty(),
7004 proc::CaseInsensitiveKeywordSet::empty(),
7005 &[CLAMPED_LOD_LOAD_PREFIX],
7006 &mut empty_names,
7007 );
7008
7009 let mut flattened_arguments = Vec::new();
7014 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7015 match module.types[arg.ty].inner {
7016 crate::TypeInner::Struct { ref members, .. } => {
7017 for (member_index, member) in members.iter().enumerate() {
7018 let member_index = member_index as u32;
7019 flattened_arguments.push((
7020 NameKey::StructMember(arg.ty, member_index),
7021 member.ty,
7022 member.binding.as_ref(),
7023 ));
7024 let name_key = NameKey::StructMember(arg.ty, member_index);
7025 let name = match member.binding {
7026 Some(crate::Binding::Location { .. }) => {
7027 if do_vertex_pulling {
7028 self.namer.call(&self.names[&name_key])
7029 } else {
7030 varyings_namer.call(&self.names[&name_key])
7031 }
7032 }
7033 _ => self.namer.call(&self.names[&name_key]),
7034 };
7035 flattened_member_names.insert(name_key, name);
7036 }
7037 }
7038 _ => flattened_arguments.push((
7039 NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7040 arg.ty,
7041 arg.binding.as_ref(),
7042 )),
7043 }
7044 }
7045
7046 let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7051 let varyings_member_name = self.namer.call("varyings");
7052 let mut has_varyings = false;
7053
7054 if !flattened_arguments.is_empty() {
7055 if !do_vertex_pulling {
7056 writeln!(self.out, "struct {stage_in_name} {{")?;
7057 }
7058 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7059 let Some(binding) = binding else {
7060 continue;
7061 };
7062 let name = match *name_key {
7063 NameKey::StructMember(..) => &flattened_member_names[name_key],
7064 _ => &self.names[name_key],
7065 };
7066 let ty_name = TypeContext {
7067 handle: ty,
7068 gctx: module.to_ctx(),
7069 names: &self.names,
7070 access: crate::StorageAccess::empty(),
7071 first_time: false,
7072 };
7073 let resolved = options.resolve_local_binding(binding, in_mode)?;
7074 let location = match *binding {
7075 crate::Binding::Location { location, .. } => Some(location),
7076 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7077 crate::Binding::BuiltIn(_) => continue,
7078 };
7079 if do_vertex_pulling {
7080 let Some(location) = location else {
7081 continue;
7082 };
7083 am_resolved.insert(
7085 location,
7086 AttributeMappingResolved {
7087 ty_name: ty_name.to_string(),
7088 dimension: ty_name.vector_size(),
7089 scalar: ty_name.scalar().unwrap(),
7090 name: name.to_string(),
7091 },
7092 );
7093 } else {
7094 has_varyings = true;
7095 if let super::ResolvedBinding::User {
7096 prefix,
7097 index,
7098 interpolation: Some(super::ResolvedInterpolation::PerVertex),
7099 } = resolved
7100 {
7101 if options.lang_version < (4, 0) {
7102 return Err(Error::PerVertexNotSupported);
7103 }
7104 write!(
7105 self.out,
7106 "{}{NAMESPACE}::vertex_value<{}> {name} [[user({prefix}{index})]]",
7107 back::INDENT,
7108 ty_name.unwrap_array()
7109 )?;
7110 } else {
7111 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7112 resolved.try_fmt(&mut self.out)?;
7113 }
7114 writeln!(self.out, ";")?;
7115 }
7116 }
7117 if !do_vertex_pulling {
7118 writeln!(self.out, "}};")?;
7119 }
7120 }
7121
7122 let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7125 let result_member_name = self.namer.call("member");
7126 let result_type_name = match fun.result {
7127 Some(ref result) if ep.stage != crate::ShaderStage::Task => {
7128 let mut result_members = Vec::new();
7129 if let crate::TypeInner::Struct { ref members, .. } =
7130 module.types[result.ty].inner
7131 {
7132 for (member_index, member) in members.iter().enumerate() {
7133 result_members.push((
7134 &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7135 member.ty,
7136 member.binding.as_ref(),
7137 ));
7138 }
7139 } else {
7140 result_members.push((
7141 &result_member_name,
7142 result.ty,
7143 result.binding.as_ref(),
7144 ));
7145 }
7146
7147 writeln!(self.out, "struct {stage_out_name} {{")?;
7148 let mut has_point_size = false;
7149 for (name, ty, binding) in result_members {
7150 let ty_name = TypeContext {
7151 handle: ty,
7152 gctx: module.to_ctx(),
7153 names: &self.names,
7154 access: crate::StorageAccess::empty(),
7155 first_time: true,
7156 };
7157 let binding = binding.ok_or_else(|| {
7158 Error::GenericValidation("Expected binding, got None".into())
7159 })?;
7160
7161 if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7162 has_point_size = true;
7163 if !pipeline_options.allow_and_force_point_size {
7164 continue;
7165 }
7166 }
7167
7168 let array_len = match module.types[ty].inner {
7169 crate::TypeInner::Array {
7170 size: crate::ArraySize::Constant(size),
7171 ..
7172 } => Some(size),
7173 _ => None,
7174 };
7175 let resolved = options.resolve_local_binding(binding, out_mode)?;
7176 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7177 resolved.try_fmt(&mut self.out)?;
7178 if let Some(array_len) = array_len {
7179 write!(self.out, " [{array_len}]")?;
7180 }
7181 writeln!(self.out, ";")?;
7182 }
7183
7184 if pipeline_options.allow_and_force_point_size
7185 && ep.stage == crate::ShaderStage::Vertex
7186 && !has_point_size
7187 {
7188 writeln!(
7190 self.out,
7191 "{}float _point_size [[point_size]];",
7192 back::INDENT
7193 )?;
7194 }
7195 writeln!(self.out, "}};")?;
7196 &stage_out_name
7197 }
7198 Some(ref result) if ep.stage == crate::ShaderStage::Task => {
7199 assert_eq!(
7200 module.types[result.ty].inner,
7201 crate::TypeInner::Vector {
7202 size: crate::VectorSize::Tri,
7203 scalar: crate::Scalar::U32
7204 }
7205 );
7206
7207 "metal::uint3"
7208 }
7209 _ => "void",
7210 };
7211
7212 let out_mesh_info = if let Some(ref mesh_info) = ep.mesh_info {
7213 Some(self.write_mesh_output_types(
7214 mesh_info,
7215 &fun_name,
7216 module,
7217 pipeline_options.allow_and_force_point_size,
7218 options,
7219 )?)
7220 } else {
7221 None
7222 };
7223
7224 if do_vertex_pulling {
7227 for vbm in &vbm_resolved {
7228 let buffer_stride = vbm.stride;
7229 let buffer_ty = &vbm.ty_name;
7230
7231 writeln!(
7235 self.out,
7236 "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7237 )?;
7238 }
7239 }
7240
7241 let is_wrapped = matches!(
7242 ep.stage,
7243 crate::ShaderStage::Task | crate::ShaderStage::Mesh
7244 );
7245 let fun_name = fun_name.clone();
7246 let nested_fun_name = if is_wrapped {
7247 self.namer.call(&format!("_{fun_name}"))
7248 } else {
7249 fun_name.clone()
7250 };
7251
7252 if ep.stage == crate::ShaderStage::Compute && options.lang_version >= (2, 1) {
7254 let total_threads =
7255 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2];
7256 write!(
7257 self.out,
7258 "[[max_total_threads_per_threadgroup({total_threads})]] "
7259 )?;
7260 }
7261
7262 if let Some(em_str) = em_str {
7264 write!(self.out, "{em_str} ")?;
7265 }
7266 writeln!(self.out, "{result_type_name} {nested_fun_name}(")?;
7267
7268 let mut args = Vec::new();
7269
7270 if has_varyings {
7273 args.push(EntryPointArgument {
7274 ty_name: stage_in_name,
7275 name: varyings_member_name.clone(),
7276 binding: " [[stage_in]]".to_string(),
7277 init: None,
7278 });
7279 }
7280
7281 let mut local_invocation_index = None;
7282
7283 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7286 let binding = match binding {
7287 Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7288 Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7289 _ => continue,
7290 };
7291 let name = match *name_key {
7292 NameKey::StructMember(..) => &flattened_member_names[name_key],
7293 _ => &self.names[name_key],
7294 };
7295
7296 if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
7297 local_invocation_index = Some(name_key);
7298 }
7299
7300 let ty_name = TypeContext {
7301 handle: ty,
7302 gctx: module.to_ctx(),
7303 names: &self.names,
7304 access: crate::StorageAccess::empty(),
7305 first_time: false,
7306 };
7307
7308 match *binding {
7309 crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7310 v_existing_id = Some(name.clone());
7311 }
7312 crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7313 i_existing_id = Some(name.clone());
7314 }
7315 _ => {}
7316 };
7317
7318 let resolved = options.resolve_local_binding(binding, in_mode)?;
7319 let mut binding = String::new();
7320 resolved.try_fmt(&mut binding)?;
7321
7322 args.push(EntryPointArgument {
7323 ty_name: format!("{ty_name}"),
7324 name: name.clone(),
7325 binding,
7326 init: None,
7327 });
7328 }
7329
7330 let need_workgroup_variables_initialization =
7331 self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7332
7333 if local_invocation_index.is_none()
7334 && (need_workgroup_variables_initialization
7335 || ep.stage == crate::ShaderStage::Task
7336 || ep.stage == crate::ShaderStage::Mesh)
7337 {
7338 args.push(EntryPointArgument {
7339 ty_name: "uint".to_string(),
7340 name: "__local_invocation_index".to_string(),
7341 binding: " [[thread_index_in_threadgroup]]".to_string(),
7342 init: None,
7343 });
7344 }
7345
7346 for (handle, var) in module.global_variables.iter() {
7351 let usage = fun_info[handle];
7352 if usage.is_empty() || var.space == crate::AddressSpace::Private {
7353 continue;
7354 }
7355
7356 if options.lang_version < (1, 2) {
7357 match var.space {
7358 crate::AddressSpace::Storage { access }
7368 if access.contains(crate::StorageAccess::STORE)
7369 && ep.stage == crate::ShaderStage::Fragment =>
7370 {
7371 return Err(Error::UnsupportedWritableStorageBuffer)
7372 }
7373 crate::AddressSpace::Handle => {
7374 match module.types[var.ty].inner {
7375 crate::TypeInner::Image {
7376 class: crate::ImageClass::Storage { access, .. },
7377 ..
7378 } => {
7379 if access.contains(crate::StorageAccess::STORE)
7389 && (ep.stage == crate::ShaderStage::Vertex
7390 || ep.stage == crate::ShaderStage::Fragment)
7391 {
7392 return Err(Error::UnsupportedWritableStorageTexture(
7393 ep.stage,
7394 ));
7395 }
7396
7397 if access.contains(
7398 crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7399 ) {
7400 return Err(Error::UnsupportedRWStorageTexture);
7401 }
7402 }
7403 _ => {}
7404 }
7405 }
7406 _ => {}
7407 }
7408 }
7409
7410 match var.space {
7412 crate::AddressSpace::Handle => match module.types[var.ty].inner {
7413 crate::TypeInner::BindingArray { base, .. } => {
7414 match module.types[base].inner {
7415 crate::TypeInner::Sampler { .. } => {
7416 if options.lang_version < (2, 0) {
7417 return Err(Error::UnsupportedArrayOf(
7418 "samplers".to_string(),
7419 ));
7420 }
7421 }
7422 crate::TypeInner::Image { class, .. } => match class {
7423 crate::ImageClass::Sampled { .. }
7424 | crate::ImageClass::Depth { .. }
7425 | crate::ImageClass::Storage {
7426 access: crate::StorageAccess::LOAD,
7427 ..
7428 } => {
7429 if options.lang_version < (2, 0) {
7434 return Err(Error::UnsupportedArrayOf(
7435 "textures".to_string(),
7436 ));
7437 }
7438 }
7439 crate::ImageClass::Storage {
7440 access: crate::StorageAccess::STORE,
7441 ..
7442 } => {
7443 if options.lang_version < (2, 0) {
7448 return Err(Error::UnsupportedArrayOf(
7449 "write-only textures".to_string(),
7450 ));
7451 }
7452 }
7453 crate::ImageClass::Storage { .. } => {
7454 if options.lang_version < (3, 0) {
7455 return Err(Error::UnsupportedArrayOf(
7456 "read-write textures".to_string(),
7457 ));
7458 }
7459 }
7460 crate::ImageClass::External => {
7461 return Err(Error::UnsupportedArrayOf(
7462 "external textures".to_string(),
7463 ));
7464 }
7465 },
7466 _ => {
7467 return Err(Error::UnsupportedArrayOfType(base));
7468 }
7469 }
7470 }
7471 _ => {}
7472 },
7473 _ => {}
7474 }
7475
7476 let resolved = match var.space {
7478 crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7479 crate::AddressSpace::WorkGroup => None,
7480 crate::AddressSpace::TaskPayload => Some(back::msl::ResolvedBinding::Payload),
7481 _ => options
7482 .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7483 .ok(),
7484 };
7485 if let Some(ref resolved) = resolved {
7486 if resolved.as_inline_sampler(options).is_some() {
7488 continue;
7489 }
7490 }
7491
7492 match module.types[var.ty].inner {
7493 crate::TypeInner::Image {
7494 class: crate::ImageClass::External,
7495 ..
7496 } => {
7497 let target = match resolved {
7501 Some(back::msl::ResolvedBinding::Resource(target)) => {
7502 target.external_texture
7503 }
7504 _ => None,
7505 };
7506
7507 for i in 0..3 {
7508 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7509 handle,
7510 ExternalTextureNameKey::Plane(i),
7511 )];
7512 let ty_name = format!(
7513 "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample>"
7514 );
7515 let name = plane_name.clone();
7516 let binding = if let Some(ref target) = target {
7517 format!(" [[texture({})]]", target.planes[i])
7518 } else {
7519 String::new()
7520 };
7521 args.push(EntryPointArgument {
7522 ty_name,
7523 name,
7524 binding,
7525 init: None,
7526 });
7527 }
7528 let params_ty_name = &self.names
7529 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7530 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7531 handle,
7532 ExternalTextureNameKey::Params,
7533 )];
7534 let binding = if let Some(ref target) = target {
7535 format!(" [[buffer({})]]", target.params)
7536 } else {
7537 String::new()
7538 };
7539
7540 args.push(EntryPointArgument {
7541 ty_name: format!("constant {params_ty_name}&"),
7542 name: params_name.clone(),
7543 binding,
7544 init: None,
7545 });
7546 }
7547 _ => {
7548 if var.space == crate::AddressSpace::WorkGroup
7549 && ep.stage == crate::ShaderStage::Mesh
7550 {
7551 continue;
7552 }
7553 let tyvar = TypedGlobalVariable {
7554 module,
7555 names: &self.names,
7556 handle,
7557 usage,
7558 reference: true,
7559 };
7560 let parts = tyvar.to_parts()?;
7561 let mut binding = String::new();
7562 if let Some(resolved) = resolved {
7563 resolved.try_fmt(&mut binding)?;
7564 }
7565 args.push(EntryPointArgument {
7566 ty_name: parts.ty_name,
7567 name: parts.var_name,
7568 binding,
7569 init: var.init,
7570 });
7571 }
7572 }
7573 }
7574
7575 if do_vertex_pulling {
7576 if needs_vertex_id && v_existing_id.is_none() {
7577 args.push(EntryPointArgument {
7579 ty_name: "uint".to_string(),
7580 name: v_id.clone(),
7581 binding: " [[vertex_id]]".to_string(),
7582 init: None,
7583 });
7584 }
7585
7586 if needs_instance_id && i_existing_id.is_none() {
7587 args.push(EntryPointArgument {
7588 ty_name: "uint".to_string(),
7589 name: i_id.clone(),
7590 binding: " [[instance_id]]".to_string(),
7591 init: None,
7592 });
7593 }
7594
7595 for vbm in &vbm_resolved {
7598 let id = &vbm.id;
7599 let ty_name = &vbm.ty_name;
7600 let param_name = &vbm.param_name;
7601 args.push(EntryPointArgument {
7602 ty_name: format!("const device {ty_name}*"),
7603 name: param_name.clone(),
7604 binding: format!(" [[buffer({id})]]"),
7605 init: None,
7606 });
7607 }
7608 }
7609
7610 if needs_buffer_sizes {
7613 let resolved = options.resolve_sizes_buffer(ep).unwrap();
7615 let mut binding = String::new();
7616 resolved.try_fmt(&mut binding)?;
7617 args.push(EntryPointArgument {
7618 ty_name: "constant _mslBufferSizes&".to_string(),
7619 name: "_buffer_sizes".to_string(),
7620 binding,
7621 init: None,
7622 });
7623 }
7624
7625 let mut is_first_arg = true;
7626 for arg in &args {
7627 if is_first_arg {
7628 write!(self.out, " ")?;
7629 } else {
7630 write!(self.out, ", ")?;
7631 }
7632 is_first_arg = false;
7633 write!(self.out, "{} {}", arg.ty_name, arg.name)?;
7634 if !is_wrapped {
7635 write!(self.out, "{}", arg.binding)?;
7636 if let Some(init) = arg.init {
7637 write!(self.out, " = ")?;
7638 self.put_const_expression(
7639 init,
7640 module,
7641 mod_info,
7642 &module.global_expressions,
7643 )?;
7644 }
7645 }
7646 writeln!(self.out)?;
7647 }
7648 if ep.stage == crate::ShaderStage::Mesh {
7649 for (handle, var) in module.global_variables.iter() {
7650 if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
7651 continue;
7652 }
7653 if is_first_arg {
7654 write!(self.out, " ")?;
7655 } else {
7656 write!(self.out, ", ")?;
7657 }
7658 let ty_context = TypeContext {
7659 handle: module.global_variables[handle].ty,
7660 gctx: module.to_ctx(),
7661 names: &self.names,
7662 access: crate::StorageAccess::empty(),
7663 first_time: false,
7664 };
7665 writeln!(
7666 self.out,
7667 "threadgroup {ty_context}& {}",
7668 self.names[&NameKey::GlobalVariable(handle)]
7669 )?;
7670 }
7671 }
7672
7673 writeln!(self.out, ") {{")?;
7675
7676 if do_vertex_pulling {
7678 for vbm in &vbm_resolved {
7681 for attribute in vbm.attributes {
7682 let location = attribute.shader_location;
7683 let am_option = am_resolved.get(&location);
7684 if am_option.is_none() {
7685 continue;
7688 }
7689 let am = am_option.unwrap();
7690 let attribute_ty_name = &am.ty_name;
7691 let attribute_name = &am.name;
7692
7693 writeln!(
7694 self.out,
7695 "{}{attribute_ty_name} {attribute_name} = {{}};",
7696 back::Level(1)
7697 )?;
7698 }
7699
7700 write!(self.out, "{}if (", back::Level(1))?;
7703
7704 let idx = &vbm.id;
7705 let stride = &vbm.stride;
7706 let index_name = match vbm.step_mode {
7707 back::msl::VertexBufferStepMode::Constant => "0",
7708 back::msl::VertexBufferStepMode::ByVertex => {
7709 if let Some(ref name) = v_existing_id {
7710 name
7711 } else {
7712 &v_id
7713 }
7714 }
7715 back::msl::VertexBufferStepMode::ByInstance => {
7716 if let Some(ref name) = i_existing_id {
7717 name
7718 } else {
7719 &i_id
7720 }
7721 }
7722 };
7723 write!(
7724 self.out,
7725 "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7726 )?;
7727
7728 writeln!(self.out, ") {{")?;
7729
7730 let ty_name = &vbm.ty_name;
7732 let elem_name = &vbm.elem_name;
7733 let param_name = &vbm.param_name;
7734
7735 writeln!(
7736 self.out,
7737 "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7738 back::Level(2),
7739 )?;
7740
7741 for attribute in vbm.attributes {
7744 let location = attribute.shader_location;
7745 let Some(am) = am_resolved.get(&location) else {
7746 continue;
7750 };
7751 let attribute_name = &am.name;
7752 let attribute_ty_name = &am.ty_name;
7753
7754 let offset = attribute.offset;
7755 let func = unpacking_functions
7756 .get(&attribute.format)
7757 .expect("Should have generated this unpacking function earlier.");
7758 let func_name = &func.name;
7759
7760 let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7766
7767 let needs_conversion = am.scalar != func.scalar;
7770
7771 if needs_padding_or_truncation != Ordering::Equal {
7772 writeln!(
7775 self.out,
7776 "{}// {attribute_ty_name} <- {:?}",
7777 back::Level(2),
7778 attribute.format
7779 )?;
7780 }
7781
7782 write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7783
7784 if needs_padding_or_truncation == Ordering::Greater {
7785 write!(self.out, "{attribute_ty_name}(")?;
7787 }
7788
7789 if needs_conversion {
7791 put_numeric_type(&mut self.out, am.scalar, func.dimension.as_slice())?;
7792 write!(self.out, "(")?;
7793 }
7794 write!(self.out, "{func_name}({elem_name}.data[{offset}]")?;
7795 for i in (offset + 1)..(offset + func.byte_count) {
7796 write!(self.out, ", {elem_name}.data[{i}]")?;
7797 }
7798 write!(self.out, ")")?;
7799 if needs_conversion {
7800 write!(self.out, ")")?;
7801 }
7802
7803 match needs_padding_or_truncation {
7804 Ordering::Greater => {
7805 let ty_is_int = scalar_is_int(am.scalar);
7807 let zero_value = if ty_is_int { "0" } else { "0.0" };
7808 let one_value = if ty_is_int { "1" } else { "1.0" };
7809 for i in func.dimension.map_or(1, u8::from)
7810 ..am.dimension.map_or(1, u8::from)
7811 {
7812 write!(
7813 self.out,
7814 ", {}",
7815 if i == 3 { one_value } else { zero_value }
7816 )?;
7817 }
7818 }
7819 Ordering::Less => {
7820 write!(
7822 self.out,
7823 ".{}",
7824 &"xyzw"[0..usize::from(am.dimension.map_or(1, u8::from))]
7825 )?;
7826 }
7827 Ordering::Equal => {}
7828 }
7829
7830 if needs_padding_or_truncation == Ordering::Greater {
7831 write!(self.out, ")")?;
7832 }
7833
7834 writeln!(self.out, ";")?;
7835 }
7836
7837 writeln!(self.out, "{}}}", back::Level(1))?;
7839 }
7840 }
7841
7842 for (handle, var) in module.global_variables.iter() {
7845 let usage = fun_info[handle];
7846 if usage.is_empty() {
7847 continue;
7848 }
7849 if var.space == crate::AddressSpace::Private {
7850 let tyvar = TypedGlobalVariable {
7851 module,
7852 names: &self.names,
7853 handle,
7854 usage,
7855
7856 reference: false,
7857 };
7858 write!(self.out, "{}", back::INDENT)?;
7859 tyvar.try_fmt(&mut self.out)?;
7860 match var.init {
7861 Some(value) => {
7862 write!(self.out, " = ")?;
7863 self.put_const_expression(
7864 value,
7865 module,
7866 mod_info,
7867 &module.global_expressions,
7868 )?;
7869 writeln!(self.out, ";")?;
7870 }
7871 None => {
7872 writeln!(self.out, " = {{}};")?;
7873 }
7874 };
7875 } else if let Some(ref binding) = var.binding {
7876 let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7877 if let Some(sampler) = resolved.as_inline_sampler(options) {
7878 let name = &self.names[&NameKey::GlobalVariable(handle)];
7880 writeln!(
7881 self.out,
7882 "{}constexpr {}::sampler {}(",
7883 back::INDENT,
7884 NAMESPACE,
7885 name
7886 )?;
7887 self.put_inline_sampler_properties(back::Level(2), sampler)?;
7888 writeln!(self.out, "{});", back::INDENT)?;
7889 } else if let crate::TypeInner::Image {
7890 class: crate::ImageClass::External,
7891 ..
7892 } = module.types[var.ty].inner
7893 {
7894 let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7897 let l1 = back::Level(1);
7898 let l2 = l1.next();
7899 writeln!(
7900 self.out,
7901 "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7902 )?;
7903 for i in 0..3 {
7904 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7905 handle,
7906 ExternalTextureNameKey::Plane(i),
7907 )];
7908 writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7909 }
7910 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7911 handle,
7912 ExternalTextureNameKey::Params,
7913 )];
7914 writeln!(self.out, "{l2}.params = {params_name},")?;
7915 writeln!(self.out, "{l1}}};")?;
7916 }
7917 }
7918 }
7919
7920 if need_workgroup_variables_initialization {
7921 self.write_workgroup_variables_initialization(
7922 module,
7923 mod_info,
7924 fun_info,
7925 local_invocation_index,
7926 ep.stage,
7927 )?;
7928 }
7929
7930 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7941 let arg_name =
7942 &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7943 match module.types[arg.ty].inner {
7944 crate::TypeInner::Struct { ref members, .. } => {
7945 let struct_name = &self.names[&NameKey::Type(arg.ty)];
7946 write!(
7947 self.out,
7948 "{}const {} {} = {{ ",
7949 back::INDENT,
7950 struct_name,
7951 arg_name
7952 )?;
7953 for (member_index, member) in members.iter().enumerate() {
7954 let key = NameKey::StructMember(arg.ty, member_index as u32);
7955 let name = &flattened_member_names[&key];
7956 if member_index != 0 {
7957 write!(self.out, ", ")?;
7958 }
7959 if self
7961 .struct_member_pads
7962 .contains(&(arg.ty, member_index as u32))
7963 {
7964 write!(self.out, "{{}}, ")?;
7965 }
7966 match member.binding {
7967 Some(crate::Binding::Location {
7968 interpolation: Some(crate::Interpolation::PerVertex),
7969 ..
7970 }) => {
7971 writeln!(
7972 self.out,
7973 "{0}{{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }}",
7974 back::INDENT,
7975 varyings_member_name,
7976 arg_name,
7977 )?;
7978 continue;
7979 }
7980 Some(crate::Binding::Location { .. }) => {
7981 if has_varyings {
7982 write!(self.out, "{varyings_member_name}.")?;
7983 }
7984 }
7985 _ => (),
7986 }
7987 write!(self.out, "{name}")?;
7988 }
7989 writeln!(self.out, " }};")?;
7990 }
7991 _ => match arg.binding {
7992 Some(crate::Binding::Location {
7993 interpolation: Some(crate::Interpolation::PerVertex),
7994 ..
7995 }) => {
7996 let ty_name = TypeContext {
7997 handle: arg.ty,
7998 gctx: module.to_ctx(),
7999 names: &self.names,
8000 access: crate::StorageAccess::empty(),
8001 first_time: false,
8002 };
8003 writeln!(
8004 self.out,
8005 "{0}const {ty_name} {arg_name} = {{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }};",
8006 back::INDENT,
8007 varyings_member_name,
8008 arg_name,
8009 )?;
8010 }
8011 Some(crate::Binding::Location { .. })
8012 | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
8013 if has_varyings {
8014 writeln!(
8015 self.out,
8016 "{}const auto {} = {}.{};",
8017 back::INDENT,
8018 arg_name,
8019 varyings_member_name,
8020 arg_name
8021 )?;
8022 }
8023 }
8024 _ => {}
8025 },
8026 }
8027 }
8028
8029 let guarded_indices =
8030 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
8031
8032 let context = StatementContext {
8033 expression: ExpressionContext {
8034 function: fun,
8035 origin: FunctionOrigin::EntryPoint(ep_index as _),
8036 info: fun_info,
8037 lang_version: options.lang_version,
8038 policies: options.bounds_check_policies,
8039 guarded_indices,
8040 module,
8041 mod_info,
8042 pipeline_options,
8043 force_loop_bounding: options.force_loop_bounding,
8044 },
8045 result_struct: if ep.stage == crate::ShaderStage::Task {
8046 None
8047 } else {
8048 Some(&stage_out_name)
8049 },
8050 };
8051
8052 self.put_locals(&context.expression)?;
8055 self.update_expressions_to_bake(fun, fun_info, &context.expression);
8056 self.put_block(back::Level(1), &fun.body, &context)?;
8057 writeln!(self.out, "}}")?;
8058 if ep_index + 1 != module.entry_points.len() {
8059 writeln!(self.out)?;
8060 }
8061 self.named_expressions.clear();
8062
8063 if is_wrapped {
8064 self.write_wrapper_function(NestedFunctionInfo {
8065 options,
8066 ep,
8067 module,
8068 mod_info,
8069 fun_info,
8070 args,
8071 local_invocation_index,
8072 nested_name: &nested_fun_name,
8073 outer_name: &fun_name,
8074 out_mesh_info,
8075 })?;
8076 }
8077 }
8078
8079 Ok(info)
8080 }
8081
8082 pub(super) fn write_barrier(
8083 &mut self,
8084 flags: crate::Barrier,
8085 level: back::Level,
8086 ) -> BackendResult {
8087 if flags.is_empty() {
8090 writeln!(
8091 self.out,
8092 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
8093 )?;
8094 }
8095 if flags.contains(crate::Barrier::STORAGE) {
8096 writeln!(
8097 self.out,
8098 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
8099 )?;
8100 }
8101 if flags.contains(crate::Barrier::WORK_GROUP) {
8102 writeln!(
8103 self.out,
8104 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8105 )?;
8106 if self.needs_object_memory_barriers {
8107 writeln!(
8108 self.out,
8109 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_object_data);",
8110 )?;
8111 }
8112 }
8113 if flags.contains(crate::Barrier::SUB_GROUP) {
8114 writeln!(
8115 self.out,
8116 "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8117 )?;
8118 }
8119 if flags.contains(crate::Barrier::TEXTURE) {
8120 writeln!(
8121 self.out,
8122 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
8123 )?;
8124 }
8125 Ok(())
8126 }
8127}
8128
8129mod workgroup_mem_init {
8132 use crate::EntryPoint;
8133
8134 use super::*;
8135
8136 enum Access {
8137 GlobalVariable(Handle<crate::GlobalVariable>),
8138 StructMember(Handle<crate::Type>, u32),
8139 Array(usize),
8140 }
8141
8142 impl Access {
8143 fn write<W: Write>(
8144 &self,
8145 writer: &mut W,
8146 names: &FastHashMap<NameKey, String>,
8147 ) -> Result<(), core::fmt::Error> {
8148 match *self {
8149 Access::GlobalVariable(handle) => {
8150 write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
8151 }
8152 Access::StructMember(handle, index) => {
8153 write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
8154 }
8155 Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
8156 }
8157 }
8158 }
8159
8160 struct AccessStack {
8161 stack: Vec<Access>,
8162 array_depth: usize,
8163 }
8164
8165 impl AccessStack {
8166 const fn new() -> Self {
8167 Self {
8168 stack: Vec::new(),
8169 array_depth: 0,
8170 }
8171 }
8172
8173 fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
8174 let array_depth = self.array_depth;
8175 self.stack.push(Access::Array(array_depth));
8176 self.array_depth += 1;
8177 let res = cb(self, array_depth);
8178 self.stack.pop();
8179 self.array_depth -= 1;
8180 res
8181 }
8182
8183 fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
8184 self.stack.push(new);
8185 let res = cb(self);
8186 self.stack.pop();
8187 res
8188 }
8189
8190 fn write<W: Write>(
8191 &self,
8192 writer: &mut W,
8193 names: &FastHashMap<NameKey, String>,
8194 ) -> Result<(), core::fmt::Error> {
8195 for next in self.stack.iter() {
8196 next.write(writer, names)?;
8197 }
8198 Ok(())
8199 }
8200 }
8201
8202 impl<W: Write> Writer<W> {
8203 pub(super) fn need_workgroup_variables_initialization(
8204 &mut self,
8205 options: &Options,
8206 ep: &EntryPoint,
8207 module: &crate::Module,
8208 fun_info: &valid::FunctionInfo,
8209 ) -> bool {
8210 let is_task = ep.stage == crate::ShaderStage::Task;
8211 options.zero_initialize_workgroup_memory
8212 && ep.stage.compute_like()
8213 && module.global_variables.iter().any(|(handle, var)| {
8214 let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8215 || (var.space == crate::AddressSpace::TaskPayload && is_task);
8216 !fun_info[handle].is_empty() && is_right_address_space
8217 })
8218 }
8219
8220 pub fn write_workgroup_variables_initialization(
8221 &mut self,
8222 module: &crate::Module,
8223 module_info: &valid::ModuleInfo,
8224 fun_info: &valid::FunctionInfo,
8225 local_invocation_index: Option<&NameKey>,
8226 stage: crate::ShaderStage,
8227 ) -> BackendResult {
8228 let level = back::Level(1);
8229
8230 writeln!(
8231 self.out,
8232 "{}if ({} == 0u) {{",
8233 level,
8234 local_invocation_index
8235 .map(|name_key| self.names[name_key].as_str())
8236 .unwrap_or("__local_invocation_index"),
8237 )?;
8238
8239 let mut access_stack = AccessStack::new();
8240
8241 let is_task = stage == crate::ShaderStage::Task;
8242 let vars = module.global_variables.iter().filter(|&(handle, var)| {
8243 let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8244 || (var.space == crate::AddressSpace::TaskPayload && is_task);
8245 !fun_info[handle].is_empty() && is_right_address_space
8246 });
8247
8248 for (handle, var) in vars {
8249 access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
8250 self.write_workgroup_variable_initialization(
8251 module,
8252 module_info,
8253 var.ty,
8254 access_stack,
8255 level.next(),
8256 )
8257 })?;
8258 }
8259
8260 writeln!(self.out, "{level}}}")?;
8261 self.write_barrier(crate::Barrier::WORK_GROUP, level)
8262 }
8263
8264 fn write_workgroup_variable_initialization(
8265 &mut self,
8266 module: &crate::Module,
8267 module_info: &valid::ModuleInfo,
8268 ty: Handle<crate::Type>,
8269 access_stack: &mut AccessStack,
8270 level: back::Level,
8271 ) -> BackendResult {
8272 if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8273 write!(self.out, "{level}")?;
8274 access_stack.write(&mut self.out, &self.names)?;
8275 writeln!(self.out, " = {{}};")?;
8276 } else {
8277 match module.types[ty].inner {
8278 crate::TypeInner::Atomic { .. } => {
8279 write!(
8280 self.out,
8281 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8282 )?;
8283 access_stack.write(&mut self.out, &self.names)?;
8284 writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8285 }
8286 crate::TypeInner::Array { base, size, .. } => {
8287 let count = match size.resolve(module.to_ctx())? {
8288 proc::IndexableLength::Known(count) => count,
8289 proc::IndexableLength::Dynamic => unreachable!(),
8290 };
8291
8292 access_stack.enter_array(|access_stack, array_depth| {
8293 writeln!(
8294 self.out,
8295 "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8296 )?;
8297 self.write_workgroup_variable_initialization(
8298 module,
8299 module_info,
8300 base,
8301 access_stack,
8302 level.next(),
8303 )?;
8304 writeln!(self.out, "{level}}}")?;
8305 BackendResult::Ok(())
8306 })?;
8307 }
8308 crate::TypeInner::Struct { ref members, .. } => {
8309 for (index, member) in members.iter().enumerate() {
8310 access_stack.enter(
8311 Access::StructMember(ty, index as u32),
8312 |access_stack| {
8313 self.write_workgroup_variable_initialization(
8314 module,
8315 module_info,
8316 member.ty,
8317 access_stack,
8318 level,
8319 )
8320 },
8321 )?;
8322 }
8323 }
8324 _ => unreachable!(),
8325 }
8326 }
8327
8328 Ok(())
8329 }
8330 }
8331}
8332
8333impl crate::AtomicFunction {
8334 const fn to_msl(self) -> &'static str {
8335 match self {
8336 Self::Add => "fetch_add",
8337 Self::Subtract => "fetch_sub",
8338 Self::And => "fetch_and",
8339 Self::InclusiveOr => "fetch_or",
8340 Self::ExclusiveOr => "fetch_xor",
8341 Self::Min => "fetch_min",
8342 Self::Max => "fetch_max",
8343 Self::Exchange { compare: None } => "exchange",
8344 Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8345 }
8346 }
8347
8348 fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8349 Ok(match self {
8350 Self::Min => "min",
8351 Self::Max => "max",
8352 _ => Err(Error::FeatureNotImplemented(
8353 "64-bit atomic operation other than min/max".to_string(),
8354 ))?,
8355 })
8356 }
8357}