1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::{
8 cmp::Ordering,
9 fmt::{Display, Error as FmtError, Formatter, Write},
10 iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{
17 ray::RT_NAMESPACE, sampler as sm, Error, LocationMode, Options, PipelineOptions,
18 TranslationInfo, NAMESPACE, WRAPPED_ARRAY_FIELD,
19};
20use crate::{
21 arena::{Handle, HandleSet},
22 back::{
23 self, get_entry_points,
24 msl::{mesh_shader::NestedFunctionInfo, BackendResult, EntryPointArgument},
25 Baked,
26 },
27 common,
28 proc::{
29 self, concrete_int_scalars,
30 index::{self, BoundsCheck},
31 ExternalTextureNameKey, NameKey, TypeResolution,
32 },
33 valid, FastHashMap, FastHashSet,
34};
35
36#[cfg(test)]
37use core::ptr;
38
// Literal text and identifiers used by this backend when producing MSL.
// Only the string values are defined here; the code that emits or declares
// the named helpers lives outside this section.

/// The text `"&"`.
/// NOTE(review): presumably written in front of atomic operands so they are
/// passed by address — confirm against the code that emits atomics.
const ATOMIC_REFERENCE: &str = "&";

// Names of generated helper functions. Judging by the identifiers these wrap
// atomic compare-exchange, `modf`/`frexp`, integer abs/div/mod/neg, dot
// products and float-to-integer casts — TODO confirm against their emitters.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
/// A prefix rather than a full name; the suffix is chosen by the caller.
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";

// Names of the helpers called for external textures. The load and size
// helpers are invoked by `put_image_load` and `put_image_size_query`.
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";

/// Name of the template struct written around the element type of a binding
/// array (see the `BindingArray` arm of `TypeContext`'s `Display` impl).
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Type name written for external-class images.
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";

// Names of cooperative-matrix helpers — NOTE(review): emitters not shown here.
pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";
75
76fn put_numeric_type(
85 out: &mut impl Write,
86 scalar: crate::Scalar,
87 sizes: &[crate::VectorSize],
88) -> Result<(), FmtError> {
89 match (scalar, sizes) {
90 (scalar, &[]) => {
91 write!(out, "{}", scalar.to_msl_name())
92 }
93 (scalar, &[rows]) => {
94 write!(
95 out,
96 "{}::{}{}",
97 NAMESPACE,
98 scalar.to_msl_name(),
99 common::vector_size_str(rows)
100 )
101 }
102 (scalar, &[rows, columns]) => {
103 write!(
104 out,
105 "{}::{}{}x{}",
106 NAMESPACE,
107 scalar.to_msl_name(),
108 common::vector_size_str(columns),
109 common::vector_size_str(rows)
110 )
111 }
112 (_, _) => Ok(()), }
114}
115
116const fn scalar_is_int(scalar: crate::Scalar) -> bool {
117 use crate::ScalarKind::*;
118 match scalar.kind {
119 Sint | Uint | AbstractInt | Bool => true,
120 Float | AbstractFloat => false,
121 }
122}
123
/// Prefix passed to `write_prefixed` by [`ClampedLod`] when it formats the
/// name of a clamped level-of-detail value.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Leading text of every name formatted by [`Reinterpreted`].
const REINTERPRET_PREFIX: &str = "reinterpreted_";
129
130struct ClampedLod(Handle<crate::Expression>);
136
137impl Display for ClampedLod {
138 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
139 self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
140 }
141}
142
143struct ArraySizeMember(Handle<crate::GlobalVariable>);
158
159impl Display for ArraySizeMember {
160 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
161 self.0.write_prefixed(f, "size")
162 }
163}
164
165#[derive(Clone, Copy)]
170struct Reinterpreted<'a> {
171 target_type: &'a str,
172 orig: Handle<crate::Expression>,
173}
174
175impl<'a> Reinterpreted<'a> {
176 const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
177 Self { target_type, orig }
178 }
179}
180
181impl Display for Reinterpreted<'_> {
182 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
183 f.write_str(REINTERPRET_PREFIX)?;
184 f.write_str(self.target_type)?;
185 self.orig.write_prefixed(f, "_e")
186 }
187}
188
/// Everything needed to write the MSL spelling of one type; the spelling is
/// produced by this struct's `Display` implementation.
pub(super) struct TypeContext<'a> {
    /// The type to write.
    pub handle: Handle<crate::Type>,
    /// Provides the type arena that `handle` (and nested handles) index into.
    pub gctx: proc::GlobalCtx<'a>,
    /// Name table; looked up with `NameKey::Type` when a type is written by
    /// its name.
    pub names: &'a FastHashMap<NameKey, String>,
    /// Storage access flags; consulted only to pick the access qualifier of
    /// storage-class images.
    pub access: crate::StorageAccess,
    /// When `false`, a type for which `needs_alias()` holds is written as its
    /// registered name. When `true`, the structural spelling is written
    /// instead.
    pub first_time: bool,
}
196
197impl TypeContext<'_> {
198 fn scalar(&self) -> Option<crate::Scalar> {
199 let ty = &self.gctx.types[self.handle];
200 ty.inner.scalar()
201 }
202
203 fn vector_size(&self) -> Option<crate::VectorSize> {
204 let ty = &self.gctx.types[self.handle];
205 match ty.inner {
206 crate::TypeInner::Vector { size, .. } => Some(size),
207 _ => None,
208 }
209 }
210
211 fn unwrap_array(self) -> Self {
212 match self.gctx.types[self.handle].inner {
213 crate::TypeInner::Array { base, .. } => Self {
214 handle: base,
215 ..self
216 },
217 _ => self,
218 }
219 }
220}
221
222impl Display for TypeContext<'_> {
223 fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
224 let ty = &self.gctx.types[self.handle];
225 if ty.needs_alias() && !self.first_time {
226 let name = &self.names[&NameKey::Type(self.handle)];
227 return write!(out, "{name}");
228 }
229
230 match ty.inner {
231 crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
232 crate::TypeInner::Atomic(scalar) => {
233 write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
234 }
235 crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
236 crate::TypeInner::Matrix {
237 columns,
238 rows,
239 scalar,
240 } => put_numeric_type(out, scalar, &[rows, columns]),
241 crate::TypeInner::CooperativeMatrix {
243 columns,
244 rows,
245 scalar,
246 role: _,
247 } => {
248 write!(
249 out,
250 "{NAMESPACE}::simdgroup_{}{}x{}",
251 scalar.to_msl_name(),
252 columns as u32,
253 rows as u32,
254 )
255 }
256 crate::TypeInner::Pointer { base, space } => {
257 let sub = Self {
258 handle: base,
259 first_time: false,
260 ..*self
261 };
262 let space_name = match space.to_msl_name() {
263 Some(name) => name,
264 None => return Ok(()),
265 };
266 write!(out, "{space_name} {sub}&")
267 }
268 crate::TypeInner::ValuePointer {
269 size,
270 scalar,
271 space,
272 } => {
273 match space.to_msl_name() {
274 Some(name) => write!(out, "{name} ")?,
275 None => return Ok(()),
276 };
277 match size {
278 Some(rows) => put_numeric_type(out, scalar, &[rows])?,
279 None => put_numeric_type(out, scalar, &[])?,
280 };
281
282 write!(out, "&")
283 }
284 crate::TypeInner::Array { base, .. } => {
285 let sub = Self {
286 handle: base,
287 first_time: false,
288 ..*self
289 };
290 write!(out, "{sub}")
293 }
294 crate::TypeInner::Struct { .. } => unreachable!(),
295 crate::TypeInner::Image {
296 dim,
297 arrayed,
298 class,
299 } => {
300 let dim_str = match dim {
301 crate::ImageDimension::D1 => "1d",
302 crate::ImageDimension::D2 => "2d",
303 crate::ImageDimension::D3 => "3d",
304 crate::ImageDimension::Cube => "cube",
305 };
306 let (texture_str, msaa_str, scalar, access) = match class {
307 crate::ImageClass::Sampled { kind, multi } => {
308 let (msaa_str, access) = if multi {
309 ("_ms", "read")
310 } else {
311 ("", "sample")
312 };
313 let scalar = crate::Scalar { kind, width: 4 };
314 ("texture", msaa_str, scalar, access)
315 }
316 crate::ImageClass::Depth { multi } => {
317 let (msaa_str, access) = if multi {
318 ("_ms", "read")
319 } else {
320 ("", "sample")
321 };
322 let scalar = crate::Scalar {
323 kind: crate::ScalarKind::Float,
324 width: 4,
325 };
326 ("depth", msaa_str, scalar, access)
327 }
328 crate::ImageClass::Storage { format, .. } => {
329 let access = if self
330 .access
331 .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
332 {
333 "read_write"
334 } else if self.access.contains(crate::StorageAccess::STORE) {
335 "write"
336 } else if self.access.contains(crate::StorageAccess::LOAD) {
337 "read"
338 } else {
339 log::warn!(
340 "Storage access for {:?} (name '{}'): {:?}",
341 self.handle,
342 ty.name.as_deref().unwrap_or_default(),
343 self.access
344 );
345 unreachable!("module is not valid");
346 };
347 ("texture", "", format.into(), access)
348 }
349 crate::ImageClass::External => {
350 return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
351 }
352 };
353 let base_name = scalar.to_msl_name();
354 let array_str = if arrayed { "_array" } else { "" };
355 write!(
356 out,
357 "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
358 )
359 }
360 crate::TypeInner::Sampler { comparison: _ } => {
361 write!(out, "{NAMESPACE}::sampler")
362 }
363 crate::TypeInner::AccelerationStructure { vertex_return } => {
364 if vertex_return {
365 unimplemented!("metal does not support vertex ray hit return")
366 }
367 write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
368 }
369 crate::TypeInner::RayQuery { vertex_return } => {
370 if vertex_return {
371 unimplemented!("metal does not support vertex ray hit return")
372 }
373 write!(out, "{}", super::ray::metal_intersector_ty())
374 }
375 crate::TypeInner::BindingArray { base, .. } => {
376 let base_tyname = Self {
377 handle: base,
378 first_time: false,
379 ..*self
380 };
381
382 write!(
383 out,
384 "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
385 )
386 }
387 }
388 }
389}
390
/// Inputs for writing a global variable as `<type> <name>` in MSL
/// (see `try_fmt`).
pub(super) struct TypedGlobalVariable<'a> {
    /// Module owning the global variable and its type.
    pub module: &'a crate::Module,
    /// Name table; must contain `NameKey::GlobalVariable(handle)`.
    pub names: &'a FastHashMap<NameKey, String>,
    /// The global variable to write.
    pub handle: Handle<crate::GlobalVariable>,
    /// How the variable is used; absence of `WRITE` can add a `const`
    /// qualifier for address spaces that need an access qualifier.
    pub usage: valid::GlobalUse,
    /// Whether to write the variable as a reference (`&`) together with its
    /// address-space qualifier.
    pub reference: bool,
}
398
/// The two pieces of a global variable declaration, as produced by
/// `TypedGlobalVariable::to_parts`.
struct TypedGlobalVariableParts {
    /// Full type text, including any qualifiers and reference marker.
    ty_name: String,
    /// The variable's name, copied from the name table.
    var_name: String,
}
403
404impl TypedGlobalVariable<'_> {
405 fn to_parts(&self) -> Result<TypedGlobalVariableParts, Error> {
406 let var = &self.module.global_variables[self.handle];
407 let name = &self.names[&NameKey::GlobalVariable(self.handle)];
408
409 let storage_access = match var.space {
410 crate::AddressSpace::Storage { access } => access,
411 _ => match self.module.types[var.ty].inner {
412 crate::TypeInner::Image {
413 class: crate::ImageClass::Storage { access, .. },
414 ..
415 } => access,
416 crate::TypeInner::BindingArray { base, .. } => {
417 match self.module.types[base].inner {
418 crate::TypeInner::Image {
419 class: crate::ImageClass::Storage { access, .. },
420 ..
421 } => access,
422 _ => crate::StorageAccess::default(),
423 }
424 }
425 _ => crate::StorageAccess::default(),
426 },
427 };
428 let ty_name = TypeContext {
429 handle: var.ty,
430 gctx: self.module.to_ctx(),
431 names: self.names,
432 access: storage_access,
433 first_time: false,
434 };
435
436 let access = if var.space.needs_access_qualifier()
437 && !self.usage.intersects(valid::GlobalUse::WRITE)
438 {
439 "const"
440 } else {
441 ""
442 };
443 let (coherent, space, access, reference) = match (var.space.to_msl_name(), var.space) {
444 (Some(space), crate::AddressSpace::WorkGroup) => {
445 ("", space, access, if self.reference { "&" } else { "" })
446 }
447 (Some(space), _) if self.reference => {
448 let coherent = if var
449 .memory_decorations
450 .contains(crate::MemoryDecorations::COHERENT)
451 {
452 "coherent "
453 } else {
454 ""
455 };
456 (coherent, space, access, "&")
457 }
458 _ => ("", "", "", ""),
459 };
460
461 let ty = format!(
462 "{coherent}{space}{}{ty_name}{}{access}{reference}",
463 if space.is_empty() { "" } else { " " },
464 if access.is_empty() { "" } else { " " },
465 );
466
467 Ok(TypedGlobalVariableParts {
468 ty_name: ty,
469 var_name: name.clone(),
470 })
471 }
472 pub(super) fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
473 let parts = self.to_parts()?;
474
475 Ok(write!(out, "{} {}", parts.ty_name, parts.var_name)?)
476 }
477}
478
/// Identifies one generated helper ("wrapped") function by the operation and
/// operand types it is specialized for.
///
/// Values are hashable and comparable, and `Writer` keeps a set of them in
/// `wrapped_functions`.
/// NOTE(review): presumably that set records which helpers were already
/// emitted — confirm at the use sites.
#[derive(Eq, PartialEq, Hash)]
pub(super) enum WrappedFunction {
    /// A unary operator applied to an optional-vector-size / scalar type.
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// A binary operator with separately described left and right operands.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// A math function specialized by its argument type.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// A cast between scalar types, optionally over a vector.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// An image load for the given image class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// An image sample for the given image class.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// An image size query for the given image class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
    /// A cooperative-matrix load, keyed by address-space name, shape and
    /// scalar.
    CooperativeLoad {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// A cooperative-matrix multiply-add, keyed by address-space name, shape
    /// (including the intermediate dimension) and scalar.
    CooperativeMultiplyAdd {
        space_name: &'static str,
        columns: crate::CooperativeSize,
        rows: crate::CooperativeSize,
        intermediate: crate::CooperativeSize,
        scalar: crate::Scalar,
    },
    /// A ray-query intersection getter, for committed or candidate hits.
    RayQueryGetIntersection {
        committed: bool,
    },
}
526
/// MSL writer state wrapping an output sink `W`.
///
/// Create it with `Writer::new` and recover the sink with `Writer::finish`.
pub struct Writer<W> {
    /// Sink that all generated text is written to.
    pub(super) out: W,
    /// Names assigned to IR entities, keyed by `NameKey`.
    pub(super) names: FastHashMap<NameKey, String>,
    /// NOTE(review): by its type, names for expressions that get their own
    /// declaration — confirm at the use sites.
    pub(super) named_expressions: crate::NamedExpressions,
    /// NOTE(review): by its type, expressions that must be baked into
    /// temporaries — confirm at the use sites.
    pub(super) need_bake_expressions: back::NeedBakeExpressions,
    /// Generator of fresh identifiers (e.g. `"oob"`, `"loop_bound"`).
    pub(super) namer: proc::Namer,
    /// Set of helper-function keys; see `WrappedFunction`.
    pub(super) wrapped_functions: FastHashSet<WrappedFunction>,
    /// Test-only set of raw pointers.
    /// NOTE(review): presumably used to measure recursion depth in
    /// `put_expression` — confirm in the tests.
    #[cfg(test)]
    pub(super) put_expression_stack_pointers: FastHashSet<*const ()>,
    /// Test-only set of raw pointers; the `put_block` counterpart of
    /// `put_expression_stack_pointers`.
    #[cfg(test)]
    pub(super) put_block_stack_pointers: FastHashSet<*const ()>,
    /// Set of (struct type, `u32`) pairs.
    /// NOTE(review): the `u32` is presumably a member index that needs
    /// padding — confirm.
    pub(super) struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
    /// Flag initialized to `false` by `new`; its consumers are outside this
    /// section.
    pub(super) needs_object_memory_barriers: bool,
}
544
545impl crate::Scalar {
546 pub(super) fn to_msl_name(self) -> &'static str {
547 use crate::ScalarKind as Sk;
548 match self {
549 Self {
550 kind: Sk::Float,
551 width: 4,
552 } => "float",
553 Self {
554 kind: Sk::Float,
555 width: 2,
556 } => "half",
557 Self {
558 kind: Sk::Sint,
559 width: 2,
560 } => "short",
561 Self {
562 kind: Sk::Uint,
563 width: 2,
564 } => "ushort",
565 Self {
566 kind: Sk::Sint,
567 width: 4,
568 } => "int",
569 Self {
570 kind: Sk::Uint,
571 width: 4,
572 } => "uint",
573 Self {
574 kind: Sk::Sint,
575 width: 8,
576 } => "long",
577 Self {
578 kind: Sk::Uint,
579 width: 8,
580 } => "ulong",
581 Self {
582 kind: Sk::Bool,
583 width: _,
584 } => "bool",
585 Self {
586 kind: Sk::AbstractInt | Sk::AbstractFloat,
587 width: _,
588 } => unreachable!("Found Abstract scalar kind"),
589 _ => unreachable!("Unsupported scalar kind: {:?}", self),
590 }
591 }
592}
593
/// Returns `","` when a separator is needed and the empty string otherwise.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
601
602fn should_pack_struct_member(
603 members: &[crate::StructMember],
604 span: u32,
605 index: usize,
606 module: &crate::Module,
607) -> Option<crate::Scalar> {
608 let member = &members[index];
609
610 let ty_inner = &module.types[member.ty].inner;
611 let last_offset = member.offset + ty_inner.size(module.to_ctx());
612 let next_offset = match members.get(index + 1) {
613 Some(next) => next.offset,
614 None => span,
615 };
616 let is_tight = next_offset == last_offset;
617
618 match *ty_inner {
619 crate::TypeInner::Vector {
620 size: crate::VectorSize::Tri,
621 scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
622 } if is_tight => Some(scalar),
623 _ => None,
624 }
625}
626
627fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
628 match arena[ty].inner {
629 crate::TypeInner::Struct { ref members, .. } => {
630 if let Some(member) = members.last() {
631 if let crate::TypeInner::Array {
632 size: crate::ArraySize::Dynamic,
633 ..
634 } = arena[member.ty].inner
635 {
636 return true;
637 }
638 }
639 false
640 }
641 crate::TypeInner::Array {
642 size: crate::ArraySize::Dynamic,
643 ..
644 } => true,
645 _ => false,
646 }
647}
648
649impl crate::AddressSpace {
650 const fn needs_pass_through(&self) -> bool {
654 match *self {
655 Self::Uniform
656 | Self::Storage { .. }
657 | Self::Private
658 | Self::WorkGroup
659 | Self::Immediate
660 | Self::Handle
661 | Self::TaskPayload => true,
662 Self::Function => false,
663 Self::RayPayload | Self::IncomingRayPayload => unreachable!(),
664 }
665 }
666
667 const fn needs_access_qualifier(&self) -> bool {
669 match *self {
670 Self::Storage { .. } => true,
675 Self::TaskPayload => true,
676 Self::RayPayload | Self::IncomingRayPayload => unimplemented!(),
677 Self::Private | Self::WorkGroup => false,
679 Self::Uniform | Self::Immediate => false,
681 Self::Handle | Self::Function => false,
683 }
684 }
685
686 const fn to_msl_name(self) -> Option<&'static str> {
687 match self {
688 Self::Handle => None,
689 Self::Uniform | Self::Immediate => Some("constant"),
690 Self::Storage { .. } => Some("device"),
691 Self::Private | Self::Function | Self::RayPayload => Some("thread"),
695 Self::WorkGroup => Some("threadgroup"),
696 Self::TaskPayload => Some("object_data"),
697 Self::IncomingRayPayload => Some("ray_data"),
698 }
699 }
700}
701
702impl crate::Type {
703 const fn needs_alias(&self) -> bool {
705 use crate::TypeInner as Ti;
706
707 match self.inner {
708 Ti::Scalar(_)
710 | Ti::Vector { .. }
711 | Ti::Matrix { .. }
712 | Ti::CooperativeMatrix { .. }
713 | Ti::Atomic(_)
714 | Ti::Pointer { .. }
715 | Ti::ValuePointer { .. } => self.name.is_some(),
716 Ti::Struct { .. } | Ti::Array { .. } => true,
718 Ti::Image { .. }
720 | Ti::Sampler { .. }
721 | Ti::AccelerationStructure { .. }
722 | Ti::RayQuery { .. }
723 | Ti::BindingArray { .. } => false,
724 }
725 }
726}
727
/// Says whether a function body belongs to a regular function or to an entry
/// point; `NameKeyExt` uses it to pick the matching `NameKey` variant.
#[derive(Clone, Copy)]
enum FunctionOrigin {
    /// A regular function, identified by its handle.
    Handle(Handle<crate::Function>),
    /// An entry point, identified by its index.
    EntryPoint(proc::EntryPointIndex),
}
733
734trait NameKeyExt {
735 fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
736 match origin {
737 FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
738 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
739 }
740 }
741
742 fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
747 match origin {
748 FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
749 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
750 }
751 }
752}
753
754impl NameKeyExt for NameKey {}
755
/// How a level-of-detail argument is written (see `put_level_of_detail`).
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Write the given expression itself.
    Direct(Handle<crate::Expression>),
    /// Write the `ClampedLod` name derived from the given expression handle
    /// instead of the expression.
    Restricted(Handle<crate::Expression>),
}
770
/// The components that locate a texel in an image access.
struct TexelAddress {
    /// Coordinate expression; a scalar or a vector.
    coordinate: Handle<crate::Expression>,
    /// Array layer index, if any.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index, if any.
    sample: Option<Handle<crate::Expression>>,
    /// Level of detail, if any.
    level: Option<LevelOfDetail>,
}
786
/// Read-only context needed to write the expressions of one function.
pub(super) struct ExpressionContext<'a> {
    /// The function whose expressions are being written.
    pub(super) function: &'a crate::Function,
    /// Whether `function` is a regular function or an entry point; used to
    /// build name keys for locals.
    origin: FunctionOrigin,
    /// Validation info for `function`; indexed by expression handle to
    /// resolve types.
    pub(super) info: &'a valid::FunctionInfo,
    /// The module that owns `function`.
    pub(super) module: &'a crate::Module,
    /// Validation info for the whole module.
    pub(super) mod_info: &'a valid::ModuleInfo,
    /// Pipeline options in effect.
    pub(super) pipeline_options: &'a PipelineOptions,
    /// Target language version as a pair of numbers.
    /// NOTE(review): presumably (major, minor) — confirm.
    pub(super) lang_version: (u8, u8),
    /// Bounds-check policies; e.g. `image_load` selects how image loads are
    /// guarded.
    pub(super) policies: index::BoundsCheckPolicies,

    /// Set of expression handles.
    /// NOTE(review): presumably index expressions that need a bounds-check
    /// guard — confirm at the use sites.
    pub(super) guarded_indices: HandleSet<crate::Expression>,
    /// When `false`, `gen_force_bounded_loop_statements` produces nothing.
    pub(super) force_loop_bounding: bool,
}
804
805impl<'a> ExpressionContext<'a> {
806 fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
807 self.info[handle].ty.inner_with(&self.module.types)
808 }
809
810 fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
817 let image_ty = self.resolve_type(image);
818 if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
819 class.is_mipmapped() && dim != crate::ImageDimension::D1
820 } else {
821 false
822 }
823 }
824
825 fn choose_bounds_check_policy(
826 &self,
827 pointer: Handle<crate::Expression>,
828 ) -> index::BoundsCheckPolicy {
829 self.policies
830 .choose_policy(pointer, &self.module.types, self.info)
831 }
832
833 fn access_needs_check(
835 &self,
836 base: Handle<crate::Expression>,
837 index: index::GuardedIndex,
838 ) -> Option<index::IndexableLength> {
839 index::access_needs_check(
840 base,
841 index,
842 self.module,
843 &self.function.expressions,
844 self.info,
845 )
846 }
847
848 fn bounds_check_iter(
850 &self,
851 chain: Handle<crate::Expression>,
852 ) -> impl Iterator<Item = BoundsCheck> + '_ {
853 index::bounds_check_iter(chain, self.module, self.function, self.info)
854 }
855
856 fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
858 index::oob_local_types(self.module, self.function, self.info, self.policies)
859 }
860
861 fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
862 match self.function.expressions[expr_handle] {
863 crate::Expression::AccessIndex { base, index } => {
864 let ty = match *self.resolve_type(base) {
865 crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
866 ref ty => ty,
867 };
868 match *ty {
869 crate::TypeInner::Struct {
870 ref members, span, ..
871 } => should_pack_struct_member(members, span, index as usize, self.module),
872 _ => None,
873 }
874 }
875 _ => None,
876 }
877 }
878}
879
/// Context for writing statements: an expression context plus
/// statement-level information.
pub(super) struct StatementContext<'a> {
    /// Context used for every expression inside the statements.
    pub(super) expression: ExpressionContext<'a>,
    /// Optional struct name.
    /// NOTE(review): presumably the name of the struct that wraps the
    /// function's return value — confirm at the use sites.
    pub(super) result_struct: Option<&'a str>,
}
884
885impl<W: Write> Writer<W> {
886 pub fn new(out: W) -> Self {
888 Writer {
889 out,
890 names: FastHashMap::default(),
891 named_expressions: Default::default(),
892 need_bake_expressions: Default::default(),
893 namer: proc::Namer::default(),
894 wrapped_functions: FastHashSet::default(),
895 #[cfg(test)]
896 put_expression_stack_pointers: Default::default(),
897 #[cfg(test)]
898 put_block_stack_pointers: Default::default(),
899 struct_member_pads: FastHashSet::default(),
900 needs_object_memory_barriers: false,
901 }
902 }
903
904 pub fn finish(self) -> W {
907 self.out
908 }
909
910 fn gen_force_bounded_loop_statements(
1017 &mut self,
1018 level: back::Level,
1019 context: &StatementContext,
1020 ) -> Option<(String, String)> {
1021 if !context.expression.force_loop_bounding {
1022 return None;
1023 }
1024
1025 let loop_bound_name = self.namer.call("loop_bound");
1026 let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
1029 let level = level.next();
1030 let break_and_inc = format!(
1031 "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
1032{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
1033 );
1034
1035 Some((decl, break_and_inc))
1036 }
1037
1038 fn put_call_parameters(
1039 &mut self,
1040 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1041 context: &ExpressionContext,
1042 ) -> BackendResult {
1043 self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
1044 writer.put_expression(expr, context, true)
1045 })
1046 }
1047
1048 fn put_call_parameters_impl<C, E>(
1049 &mut self,
1050 parameters: impl Iterator<Item = Handle<crate::Expression>>,
1051 ctx: &C,
1052 put_expression: E,
1053 ) -> BackendResult
1054 where
1055 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1056 {
1057 write!(self.out, "(")?;
1058 for (i, handle) in parameters.enumerate() {
1059 if i != 0 {
1060 write!(self.out, ", ")?;
1061 }
1062 put_expression(self, ctx, handle)?;
1063 }
1064 write!(self.out, ")")?;
1065 Ok(())
1066 }
1067
1068 fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1074 let oob_local_types = context.oob_local_types();
1075 for &ty in oob_local_types.iter() {
1076 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1077 self.names.insert(name_key, self.namer.call("oob"));
1078 }
1079
1080 for (name_key, ty, init) in context
1081 .function
1082 .local_variables
1083 .iter()
1084 .map(|(local_handle, local)| {
1085 let name_key = NameKey::local(context.origin, local_handle);
1086 (name_key, local.ty, local.init)
1087 })
1088 .chain(oob_local_types.iter().map(|&ty| {
1089 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1090 (name_key, ty, None)
1091 }))
1092 {
1093 let ty_name = TypeContext {
1094 handle: ty,
1095 gctx: context.module.to_ctx(),
1096 names: &self.names,
1097 access: crate::StorageAccess::empty(),
1098 first_time: false,
1099 };
1100 write!(
1101 self.out,
1102 "{}{} {}",
1103 back::INDENT,
1104 ty_name,
1105 self.names[&name_key]
1106 )?;
1107 match init {
1108 Some(value) => {
1109 write!(self.out, " = ")?;
1110 self.put_expression(value, context, true)?;
1111 }
1112 None => {
1113 write!(self.out, " = {{}}")?;
1114 }
1115 };
1116 writeln!(self.out, ";")?;
1117 }
1118 Ok(())
1119 }
1120
1121 fn put_level_of_detail(
1122 &mut self,
1123 level: LevelOfDetail,
1124 context: &ExpressionContext,
1125 ) -> BackendResult {
1126 match level {
1127 LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1128 LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1129 }
1130 Ok(())
1131 }
1132
1133 fn put_image_query(
1134 &mut self,
1135 image: Handle<crate::Expression>,
1136 query: &str,
1137 level: Option<LevelOfDetail>,
1138 context: &ExpressionContext,
1139 ) -> BackendResult {
1140 self.put_expression(image, context, false)?;
1141 write!(self.out, ".get_{query}(")?;
1142 if let Some(level) = level {
1143 self.put_level_of_detail(level, context)?;
1144 }
1145 write!(self.out, ")")?;
1146 Ok(())
1147 }
1148
1149 fn put_image_size_query(
1150 &mut self,
1151 image: Handle<crate::Expression>,
1152 level: Option<LevelOfDetail>,
1153 kind: crate::ScalarKind,
1154 context: &ExpressionContext,
1155 ) -> BackendResult {
1156 if let crate::TypeInner::Image {
1157 class: crate::ImageClass::External,
1158 ..
1159 } = *context.resolve_type(image)
1160 {
1161 write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1162 self.put_expression(image, context, true)?;
1163 write!(self.out, ")")?;
1164 return Ok(());
1165 }
1166
1167 let dim = match *context.resolve_type(image) {
1170 crate::TypeInner::Image { dim, .. } => dim,
1171 ref other => unreachable!("Unexpected type {:?}", other),
1172 };
1173 let scalar = crate::Scalar { kind, width: 4 };
1174 let coordinate_type = scalar.to_msl_name();
1175 match dim {
1176 crate::ImageDimension::D1 => {
1177 if kind == crate::ScalarKind::Uint {
1181 self.put_image_query(image, "width", None, context)?;
1183 } else {
1184 write!(self.out, "int(")?;
1186 self.put_image_query(image, "width", None, context)?;
1187 write!(self.out, ")")?;
1188 }
1189 }
1190 crate::ImageDimension::D2 => {
1191 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1192 self.put_image_query(image, "width", level, context)?;
1193 write!(self.out, ", ")?;
1194 self.put_image_query(image, "height", level, context)?;
1195 write!(self.out, ")")?;
1196 }
1197 crate::ImageDimension::D3 => {
1198 write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1199 self.put_image_query(image, "width", level, context)?;
1200 write!(self.out, ", ")?;
1201 self.put_image_query(image, "height", level, context)?;
1202 write!(self.out, ", ")?;
1203 self.put_image_query(image, "depth", level, context)?;
1204 write!(self.out, ")")?;
1205 }
1206 crate::ImageDimension::Cube => {
1207 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1208 self.put_image_query(image, "width", level, context)?;
1209 write!(self.out, ")")?;
1210 }
1211 }
1212 Ok(())
1213 }
1214
1215 fn put_cast_to_uint_scalar_or_vector(
1216 &mut self,
1217 expr: Handle<crate::Expression>,
1218 context: &ExpressionContext,
1219 ) -> BackendResult {
1220 match *context.resolve_type(expr) {
1222 crate::TypeInner::Scalar(_) => {
1223 put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1224 }
1225 crate::TypeInner::Vector { size, .. } => {
1226 put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1227 }
1228 _ => {
1229 return Err(Error::GenericValidation(
1230 "Invalid type for image coordinate".into(),
1231 ))
1232 }
1233 };
1234
1235 write!(self.out, "(")?;
1236 self.put_expression(expr, context, true)?;
1237 write!(self.out, ")")?;
1238 Ok(())
1239 }
1240
1241 fn put_image_sample_level(
1242 &mut self,
1243 image: Handle<crate::Expression>,
1244 level: crate::SampleLevel,
1245 context: &ExpressionContext,
1246 ) -> BackendResult {
1247 let has_levels = context.image_needs_lod(image);
1248 match level {
1249 crate::SampleLevel::Auto => {}
1250 crate::SampleLevel::Zero => {
1251 }
1253 _ if !has_levels => {
1254 log::warn!("1D image can't be sampled with level {level:?}");
1255 }
1256 crate::SampleLevel::Exact(h) => {
1257 write!(self.out, ", {NAMESPACE}::level(")?;
1258 self.put_expression(h, context, true)?;
1259 write!(self.out, ")")?;
1260 }
1261 crate::SampleLevel::Bias(h) => {
1262 write!(self.out, ", {NAMESPACE}::bias(")?;
1263 self.put_expression(h, context, true)?;
1264 write!(self.out, ")")?;
1265 }
1266 crate::SampleLevel::Gradient { x, y } => {
1267 write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1268 self.put_expression(x, context, true)?;
1269 write!(self.out, ", ")?;
1270 self.put_expression(y, context, true)?;
1271 write!(self.out, ")")?;
1272 }
1273 }
1274 Ok(())
1275 }
1276
1277 fn put_image_coordinate_limits(
1278 &mut self,
1279 image: Handle<crate::Expression>,
1280 level: Option<LevelOfDetail>,
1281 context: &ExpressionContext,
1282 ) -> BackendResult {
1283 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1284 write!(self.out, " - 1")?;
1285 Ok(())
1286 }
1287
1288 fn put_restricted_scalar_image_index(
1306 &mut self,
1307 image: Handle<crate::Expression>,
1308 index: Handle<crate::Expression>,
1309 limit_method: &str,
1310 context: &ExpressionContext,
1311 ) -> BackendResult {
1312 write!(self.out, "{NAMESPACE}::min(uint(")?;
1313 self.put_expression(index, context, true)?;
1314 write!(self.out, "), ")?;
1315 self.put_expression(image, context, false)?;
1316 write!(self.out, ".{limit_method}() - 1)")?;
1317 Ok(())
1318 }
1319
1320 fn put_restricted_texel_address(
1321 &mut self,
1322 image: Handle<crate::Expression>,
1323 address: &TexelAddress,
1324 context: &ExpressionContext,
1325 ) -> BackendResult {
1326 write!(self.out, "{NAMESPACE}::min(")?;
1328 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1329 write!(self.out, ", ")?;
1330 self.put_image_coordinate_limits(image, address.level, context)?;
1331 write!(self.out, ")")?;
1332
1333 if let Some(array_index) = address.array_index {
1335 write!(self.out, ", ")?;
1336 self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1337 }
1338
1339 if let Some(sample) = address.sample {
1341 write!(self.out, ", ")?;
1342 self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1343 }
1344
1345 if let Some(level) = address.level {
1348 write!(self.out, ", ")?;
1349 self.put_level_of_detail(level, context)?;
1350 }
1351
1352 Ok(())
1353 }
1354
1355 fn put_image_access_bounds_check(
1357 &mut self,
1358 image: Handle<crate::Expression>,
1359 address: &TexelAddress,
1360 context: &ExpressionContext,
1361 ) -> BackendResult {
1362 let mut conjunction = "";
1363
1364 let level = if let Some(level) = address.level {
1367 write!(self.out, "uint(")?;
1368 self.put_level_of_detail(level, context)?;
1369 write!(self.out, ") < ")?;
1370 self.put_expression(image, context, true)?;
1371 write!(self.out, ".get_num_mip_levels()")?;
1372 conjunction = " && ";
1373 Some(level)
1374 } else {
1375 None
1376 };
1377
1378 if let Some(sample) = address.sample {
1380 write!(self.out, "uint(")?;
1381 self.put_expression(sample, context, true)?;
1382 write!(self.out, ") < ")?;
1383 self.put_expression(image, context, true)?;
1384 write!(self.out, ".get_num_samples()")?;
1385 conjunction = " && ";
1386 }
1387
1388 if let Some(array_index) = address.array_index {
1390 write!(self.out, "{conjunction}uint(")?;
1391 self.put_expression(array_index, context, true)?;
1392 write!(self.out, ") < ")?;
1393 self.put_expression(image, context, true)?;
1394 write!(self.out, ".get_array_size()")?;
1395 conjunction = " && ";
1396 }
1397
1398 let coord_is_vector = match *context.resolve_type(address.coordinate) {
1400 crate::TypeInner::Vector { .. } => true,
1401 _ => false,
1402 };
1403 write!(self.out, "{conjunction}")?;
1404 if coord_is_vector {
1405 write!(self.out, "{NAMESPACE}::all(")?;
1406 }
1407 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1408 write!(self.out, " < ")?;
1409 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1410 if coord_is_vector {
1411 write!(self.out, ")")?;
1412 }
1413
1414 Ok(())
1415 }
1416
1417 fn put_image_load(
1418 &mut self,
1419 load: Handle<crate::Expression>,
1420 image: Handle<crate::Expression>,
1421 mut address: TexelAddress,
1422 context: &ExpressionContext,
1423 ) -> BackendResult {
1424 if let crate::TypeInner::Image {
1425 class: crate::ImageClass::External,
1426 ..
1427 } = *context.resolve_type(image)
1428 {
1429 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1430 self.put_expression(image, context, true)?;
1431 write!(self.out, ", ")?;
1432 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1433 write!(self.out, ")")?;
1434 return Ok(());
1435 }
1436
1437 match context.policies.image_load {
1438 proc::BoundsCheckPolicy::Restrict => {
1439 if address.level.is_some() {
1442 address.level = if context.image_needs_lod(image) {
1443 Some(LevelOfDetail::Restricted(load))
1444 } else {
1445 None
1446 }
1447 }
1448
1449 self.put_expression(image, context, false)?;
1450 write!(self.out, ".read(")?;
1451 self.put_restricted_texel_address(image, &address, context)?;
1452 write!(self.out, ")")?;
1453 }
1454 proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1455 write!(self.out, "(")?;
1456 self.put_image_access_bounds_check(image, &address, context)?;
1457 write!(self.out, " ? ")?;
1458 self.put_unchecked_image_load(image, &address, context)?;
1459 write!(self.out, ": DefaultConstructible())")?;
1460 }
1461 proc::BoundsCheckPolicy::Unchecked => {
1462 self.put_unchecked_image_load(image, &address, context)?;
1463 }
1464 }
1465
1466 Ok(())
1467 }
1468
1469 fn put_unchecked_image_load(
1470 &mut self,
1471 image: Handle<crate::Expression>,
1472 address: &TexelAddress,
1473 context: &ExpressionContext,
1474 ) -> BackendResult {
1475 self.put_expression(image, context, false)?;
1476 write!(self.out, ".read(")?;
1477 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1479 if let Some(expr) = address.array_index {
1480 write!(self.out, ", ")?;
1481 self.put_expression(expr, context, true)?;
1482 }
1483 if let Some(sample) = address.sample {
1484 write!(self.out, ", ")?;
1485 self.put_expression(sample, context, true)?;
1486 }
1487 if let Some(level) = address.level {
1488 if context.image_needs_lod(image) {
1489 write!(self.out, ", ")?;
1490 self.put_level_of_detail(level, context)?;
1491 }
1492 }
1493 write!(self.out, ")")?;
1494
1495 Ok(())
1496 }
1497
1498 fn put_image_atomic(
1499 &mut self,
1500 level: back::Level,
1501 image: Handle<crate::Expression>,
1502 address: &TexelAddress,
1503 fun: crate::AtomicFunction,
1504 value: Handle<crate::Expression>,
1505 context: &StatementContext,
1506 ) -> BackendResult {
1507 write!(self.out, "{level}")?;
1508 self.put_expression(image, &context.expression, false)?;
1509 let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1510 fun.to_msl_64_bit()?
1511 } else {
1512 fun.to_msl()
1513 };
1514 write!(self.out, ".atomic_{op}(")?;
1515 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1517 write!(self.out, ", ")?;
1518 self.put_expression(value, &context.expression, true)?;
1519 writeln!(self.out, ");")?;
1520
1521 let value_ty = context.expression.resolve_type(value);
1527 let zero_value = match (value_ty.scalar_kind(), value_ty.scalar_width()) {
1528 (Some(crate::ScalarKind::Sint), _) => "int4(0)",
1529 (_, Some(8)) => "ulong4(0uL)",
1530 _ => "uint4(0u)",
1531 };
1532 let coord_ty = context.expression.resolve_type(address.coordinate);
1533 let x = if matches!(coord_ty, crate::TypeInner::Scalar(_)) {
1534 ""
1535 } else {
1536 ".x"
1537 };
1538 write!(self.out, "{level}if (")?;
1539 self.put_expression(address.coordinate, &context.expression, true)?;
1540 write!(self.out, "{x} == -99999) {{ ")?;
1541 self.put_expression(image, &context.expression, false)?;
1542 write!(self.out, ".write({zero_value}, ")?;
1543 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1544 if let Some(array_index) = address.array_index {
1545 write!(self.out, ", ")?;
1546 self.put_expression(array_index, &context.expression, true)?;
1547 }
1548 writeln!(self.out, "); }}")?;
1549
1550 Ok(())
1551 }
1552
1553 fn put_image_store(
1554 &mut self,
1555 level: back::Level,
1556 image: Handle<crate::Expression>,
1557 address: &TexelAddress,
1558 value: Handle<crate::Expression>,
1559 context: &StatementContext,
1560 ) -> BackendResult {
1561 write!(self.out, "{level}")?;
1562 self.put_expression(image, &context.expression, false)?;
1563 write!(self.out, ".write(")?;
1564 self.put_expression(value, &context.expression, true)?;
1565 write!(self.out, ", ")?;
1566 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1568 if let Some(expr) = address.array_index {
1569 write!(self.out, ", ")?;
1570 self.put_expression(expr, &context.expression, true)?;
1571 }
1572 writeln!(self.out, ");")?;
1573
1574 Ok(())
1575 }
1576
1577 fn put_dynamic_array_max_index(
1588 &mut self,
1589 handle: Handle<crate::GlobalVariable>,
1590 context: &ExpressionContext,
1591 ) -> BackendResult {
1592 let global = &context.module.global_variables[handle];
1593 let (offset, array_ty) = match context.module.types[global.ty].inner {
1594 crate::TypeInner::Struct { ref members, .. } => match members.last() {
1595 Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1596 None => return Err(Error::GenericValidation("Struct has no members".into())),
1597 },
1598 crate::TypeInner::Array {
1599 size: crate::ArraySize::Dynamic,
1600 ..
1601 } => (0, global.ty),
1602 ref ty => {
1603 return Err(Error::GenericValidation(format!(
1604 "Expected type with dynamic array, got {ty:?}"
1605 )))
1606 }
1607 };
1608
1609 let (size, stride) = match context.module.types[array_ty].inner {
1610 crate::TypeInner::Array { base, stride, .. } => (
1611 context.module.types[base]
1612 .inner
1613 .size(context.module.to_ctx()),
1614 stride,
1615 ),
1616 ref ty => {
1617 return Err(Error::GenericValidation(format!(
1618 "Expected array type, got {ty:?}"
1619 )))
1620 }
1621 };
1622
1623 write!(
1636 self.out,
1637 "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1638 member = ArraySizeMember(handle),
1639 offset = offset,
1640 size = size,
1641 stride = stride,
1642 )?;
1643 Ok(())
1644 }
1645
1646 fn put_dot_product<T: Copy>(
1651 &mut self,
1652 arg: T,
1653 arg1: T,
1654 size: usize,
1655 extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1656 ) -> BackendResult {
1657 write!(self.out, "(")?;
1660
1661 for index in 0..size {
1663 write!(self.out, " + ")?;
1666 extractor(self, arg, index)?;
1667 write!(self.out, " * ")?;
1668 extractor(self, arg1, index)?;
1669 }
1670
1671 write!(self.out, ")")?;
1672 Ok(())
1673 }
1674
1675 fn put_pack4x8(
1677 &mut self,
1678 arg: Handle<crate::Expression>,
1679 context: &ExpressionContext<'_>,
1680 was_signed: bool,
1681 clamp_bounds: Option<(&str, &str)>,
1682 ) -> Result<(), Error> {
1683 let write_arg = |this: &mut Self| -> BackendResult {
1684 if let Some((min, max)) = clamp_bounds {
1685 write!(this.out, "{NAMESPACE}::clamp(")?;
1687 this.put_expression(arg, context, true)?;
1688 write!(this.out, ", {min}, {max})")?;
1689 } else {
1690 this.put_expression(arg, context, true)?;
1691 }
1692 Ok(())
1693 };
1694
1695 if context.lang_version >= (2, 1) {
1696 let packed_type = if was_signed {
1697 "packed_char4"
1698 } else {
1699 "packed_uchar4"
1700 };
1701 write!(self.out, "as_type<uint>({packed_type}(")?;
1703 write_arg(self)?;
1704 write!(self.out, "))")?;
1705 } else {
1706 if was_signed {
1708 write!(self.out, "uint(")?;
1709 }
1710 write!(self.out, "(")?;
1711 write_arg(self)?;
1712 write!(self.out, "[0] & 0xFF) | ((")?;
1713 write_arg(self)?;
1714 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1715 write_arg(self)?;
1716 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1717 write_arg(self)?;
1718 write!(self.out, "[3] & 0xFF) << 24)")?;
1719 if was_signed {
1720 write!(self.out, ")")?;
1721 }
1722 }
1723
1724 Ok(())
1725 }
1726
1727 fn put_isign(
1730 &mut self,
1731 arg: Handle<crate::Expression>,
1732 context: &ExpressionContext,
1733 ) -> BackendResult {
1734 write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1735 let scalar = context
1736 .resolve_type(arg)
1737 .scalar()
1738 .expect("put_isign should only be called for args which have an integer scalar type")
1739 .to_msl_name();
1740 match context.resolve_type(arg) {
1741 &crate::TypeInner::Vector { size, .. } => {
1742 let size = common::vector_size_str(size);
1743 write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1744 }
1745 _ => {
1746 write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1747 }
1748 }
1749 write!(self.out, ", (")?;
1750 self.put_expression(arg, context, true)?;
1751 write!(self.out, " > 0)), {scalar}(0), (")?;
1752 self.put_expression(arg, context, true)?;
1753 write!(self.out, " == 0))")?;
1754 Ok(())
1755 }
1756
1757 pub(super) fn put_const_expression(
1758 &mut self,
1759 expr_handle: Handle<crate::Expression>,
1760 module: &crate::Module,
1761 mod_info: &valid::ModuleInfo,
1762 arena: &crate::Arena<crate::Expression>,
1763 ) -> BackendResult {
1764 self.put_possibly_const_expression(
1765 expr_handle,
1766 arena,
1767 module,
1768 mod_info,
1769 &(module, mod_info),
1770 |&(_, mod_info), expr| &mod_info[expr],
1771 |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1772 )
1773 }
1774
1775 fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1776 match literal {
1777 crate::Literal::F64(_) => {
1778 return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1779 }
1780 crate::Literal::F16(value) => {
1781 if value.is_infinite() {
1782 let sign = if value.is_sign_negative() { "-" } else { "" };
1783 write!(self.out, "{sign}INFINITY")?;
1784 } else if value.is_nan() {
1785 write!(self.out, "NAN")?;
1786 } else {
1787 let suffix = if value.fract() == f16::from_f32(0.0) {
1788 ".0h"
1789 } else {
1790 "h"
1791 };
1792 write!(self.out, "{value}{suffix}")?;
1793 }
1794 }
1795 crate::Literal::F32(value) => {
1796 if value.is_infinite() {
1797 let sign = if value.is_sign_negative() { "-" } else { "" };
1798 write!(self.out, "{sign}INFINITY")?;
1799 } else if value.is_nan() {
1800 write!(self.out, "NAN")?;
1801 } else {
1802 let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1803 write!(self.out, "{value}{suffix}")?;
1804 }
1805 }
1806 crate::Literal::U16(value) => {
1807 write!(self.out, "static_cast<ushort>({value})")?;
1808 }
1809 crate::Literal::I16(value) => {
1810 write!(self.out, "static_cast<short>({value})")?;
1811 }
1812 crate::Literal::U32(value) => {
1813 write!(self.out, "{value}u")?;
1814 }
1815 crate::Literal::I32(value) => {
1816 if value == i32::MIN {
1821 write!(self.out, "({} - 1)", value + 1)?;
1822 } else {
1823 write!(self.out, "{value}")?;
1824 }
1825 }
1826 crate::Literal::U64(value) => {
1827 write!(self.out, "{value}uL")?;
1828 }
1829 crate::Literal::I64(value) => {
1830 if value == i64::MIN {
1837 write!(self.out, "({}L - 1L)", value + 1)?;
1838 } else {
1839 write!(self.out, "{value}L")?;
1840 }
1841 }
1842 crate::Literal::Bool(value) => {
1843 write!(self.out, "{value}")?;
1844 }
1845 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1846 return Err(Error::GenericValidation(
1847 "Unsupported abstract literal".into(),
1848 ));
1849 }
1850 }
1851 Ok(())
1852 }
1853
1854 #[allow(clippy::too_many_arguments)]
1855 fn put_possibly_const_expression<C, I, E>(
1856 &mut self,
1857 expr_handle: Handle<crate::Expression>,
1858 expressions: &crate::Arena<crate::Expression>,
1859 module: &crate::Module,
1860 mod_info: &valid::ModuleInfo,
1861 ctx: &C,
1862 get_expr_ty: I,
1863 put_expression: E,
1864 ) -> BackendResult
1865 where
1866 I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1867 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1868 {
1869 match expressions[expr_handle] {
1870 crate::Expression::Literal(literal) => {
1871 self.put_literal(literal)?;
1872 }
1873 crate::Expression::Constant(handle) => {
1874 let constant = &module.constants[handle];
1875 if constant.name.is_some() {
1876 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1877 } else {
1878 self.put_const_expression(
1879 constant.init,
1880 module,
1881 mod_info,
1882 &module.global_expressions,
1883 )?;
1884 }
1885 }
1886 crate::Expression::ZeroValue(ty) => {
1887 let ty_name = TypeContext {
1888 handle: ty,
1889 gctx: module.to_ctx(),
1890 names: &self.names,
1891 access: crate::StorageAccess::empty(),
1892 first_time: false,
1893 };
1894 write!(self.out, "{ty_name} {{}}")?;
1895 }
1896 crate::Expression::Compose { ty, ref components } => {
1897 let ty_name = TypeContext {
1898 handle: ty,
1899 gctx: module.to_ctx(),
1900 names: &self.names,
1901 access: crate::StorageAccess::empty(),
1902 first_time: false,
1903 };
1904 write!(self.out, "{ty_name}")?;
1905 match module.types[ty].inner {
1906 crate::TypeInner::Scalar(_)
1907 | crate::TypeInner::Vector { .. }
1908 | crate::TypeInner::Matrix { .. } => {
1909 self.put_call_parameters_impl(
1910 components.iter().copied(),
1911 ctx,
1912 put_expression,
1913 )?;
1914 }
1915 crate::TypeInner::Array { .. } => {
1916 write!(self.out, " {{{{")?;
1919 for (index, &component) in components.iter().enumerate() {
1920 if index != 0 {
1921 write!(self.out, ", ")?;
1922 }
1923 put_expression(self, ctx, component)?;
1924 }
1925 write!(self.out, "}}}}")?;
1926 }
1927 crate::TypeInner::Struct { .. } => {
1928 write!(self.out, " {{")?;
1929 for (index, &component) in components.iter().enumerate() {
1930 if index != 0 {
1931 write!(self.out, ", ")?;
1932 }
1933 if self.struct_member_pads.contains(&(ty, index as u32)) {
1935 write!(self.out, "{{}}, ")?;
1936 }
1937 put_expression(self, ctx, component)?;
1938 }
1939 write!(self.out, "}}")?;
1940 }
1941 _ => return Err(Error::UnsupportedCompose(ty)),
1942 }
1943 }
1944 crate::Expression::Splat { size, value } => {
1945 let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1946 crate::TypeInner::Scalar(scalar) => scalar,
1947 ref ty => {
1948 return Err(Error::GenericValidation(format!(
1949 "Expected splat value type must be a scalar, got {ty:?}",
1950 )))
1951 }
1952 };
1953 put_numeric_type(&mut self.out, scalar, &[size])?;
1954 write!(self.out, "(")?;
1955 put_expression(self, ctx, value)?;
1956 write!(self.out, ")")?;
1957 }
1958 _ => {
1959 return Err(Error::Override);
1960 }
1961 }
1962
1963 Ok(())
1964 }
1965
1966 pub(super) fn put_expression(
1978 &mut self,
1979 expr_handle: Handle<crate::Expression>,
1980 context: &ExpressionContext,
1981 is_scoped: bool,
1982 ) -> BackendResult {
1983 #[cfg(test)]
1985 self.put_expression_stack_pointers
1986 .insert(ptr::from_ref(&expr_handle).cast());
1987
1988 if let Some(name) = self.named_expressions.get(&expr_handle) {
1989 write!(self.out, "{name}")?;
1990 return Ok(());
1991 }
1992
1993 let expression = &context.function.expressions[expr_handle];
1994 match *expression {
1995 crate::Expression::Literal(_)
1996 | crate::Expression::Constant(_)
1997 | crate::Expression::ZeroValue(_)
1998 | crate::Expression::Compose { .. }
1999 | crate::Expression::Splat { .. } => {
2000 self.put_possibly_const_expression(
2001 expr_handle,
2002 &context.function.expressions,
2003 context.module,
2004 context.mod_info,
2005 context,
2006 |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
2007 |writer, context, expr| writer.put_expression(expr, context, true),
2008 )?;
2009 }
2010 crate::Expression::Override(_) => return Err(Error::Override),
2011 crate::Expression::Access { base, .. }
2012 | crate::Expression::AccessIndex { base, .. } => {
2013 let policy = context.choose_bounds_check_policy(base);
2019 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
2020 && self.put_bounds_checks(
2021 expr_handle,
2022 context,
2023 back::Level(0),
2024 if is_scoped { "" } else { "(" },
2025 )?
2026 {
2027 write!(self.out, " ? ")?;
2028 self.put_access_chain(expr_handle, policy, context)?;
2029 write!(self.out, " : ")?;
2030
2031 if context.resolve_type(base).pointer_space().is_some() {
2032 let result_ty = context.info[expr_handle]
2036 .ty
2037 .inner_with(&context.module.types)
2038 .pointer_base_type();
2039 let result_ty_handle = match result_ty {
2040 Some(TypeResolution::Handle(handle)) => handle,
2041 Some(TypeResolution::Value(_)) => {
2042 unreachable!(
2050 "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
2051 );
2052 }
2053 None => {
2054 unreachable!(
2055 "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
2056 )
2057 }
2058 };
2059 let name_key =
2060 NameKey::oob_local_for_type(context.origin, result_ty_handle);
2061 self.out.write_str(&self.names[&name_key])?;
2062 } else {
2063 write!(self.out, "DefaultConstructible()")?;
2064 }
2065
2066 if !is_scoped {
2067 write!(self.out, ")")?;
2068 }
2069 } else {
2070 self.put_access_chain(expr_handle, policy, context)?;
2071 }
2072 }
2073 crate::Expression::Swizzle {
2074 size,
2075 vector,
2076 pattern,
2077 } => {
2078 self.put_wrapped_expression_for_packed_vec3_access(
2079 vector,
2080 context,
2081 false,
2082 &Self::put_expression,
2083 )?;
2084 write!(self.out, ".")?;
2085 for &sc in pattern[..size as usize].iter() {
2086 write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
2087 }
2088 }
2089 crate::Expression::FunctionArgument(index) => {
2090 let name_key = match context.origin {
2091 FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
2092 FunctionOrigin::EntryPoint(ep_index) => {
2093 NameKey::EntryPointArgument(ep_index, index)
2094 }
2095 };
2096 let name = &self.names[&name_key];
2097 write!(self.out, "{name}")?;
2098 }
2099 crate::Expression::GlobalVariable(handle) => {
2100 let name = &self.names[&NameKey::GlobalVariable(handle)];
2101 write!(self.out, "{name}")?;
2102 }
2103 crate::Expression::LocalVariable(handle) => {
2104 let name_key = NameKey::local(context.origin, handle);
2105 let name = &self.names[&name_key];
2106 write!(self.out, "{name}")?;
2107 }
2108 crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
2109 crate::Expression::ImageSample {
2110 coordinate,
2111 image,
2112 sampler,
2113 clamp_to_edge: true,
2114 gather: None,
2115 array_index: None,
2116 offset: None,
2117 level: crate::SampleLevel::Zero,
2118 depth_ref: None,
2119 } => {
2120 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2121 self.put_expression(image, context, true)?;
2122 write!(self.out, ", ")?;
2123 self.put_expression(sampler, context, true)?;
2124 write!(self.out, ", ")?;
2125 self.put_expression(coordinate, context, true)?;
2126 write!(self.out, ")")?;
2127 }
2128 crate::Expression::ImageSample {
2129 image,
2130 sampler,
2131 gather,
2132 coordinate,
2133 array_index,
2134 offset,
2135 level,
2136 depth_ref,
2137 clamp_to_edge,
2138 } => {
2139 if clamp_to_edge {
2140 return Err(Error::GenericValidation(
2141 "ImageSample::clamp_to_edge should have been validated out".to_string(),
2142 ));
2143 }
2144
2145 let main_op = match gather {
2146 Some(_) => "gather",
2147 None => "sample",
2148 };
2149 let comparison_op = match depth_ref {
2150 Some(_) => "_compare",
2151 None => "",
2152 };
2153 self.put_expression(image, context, false)?;
2154 write!(self.out, ".{main_op}{comparison_op}(")?;
2155 self.put_expression(sampler, context, true)?;
2156 write!(self.out, ", ")?;
2157 self.put_expression(coordinate, context, true)?;
2158 if let Some(expr) = array_index {
2159 write!(self.out, ", ")?;
2160 self.put_expression(expr, context, true)?;
2161 }
2162 if let Some(dref) = depth_ref {
2163 write!(self.out, ", ")?;
2164 self.put_expression(dref, context, true)?;
2165 }
2166
2167 self.put_image_sample_level(image, level, context)?;
2168
2169 if let Some(offset) = offset {
2170 write!(self.out, ", ")?;
2171 self.put_expression(offset, context, true)?;
2172 }
2173
2174 match gather {
2175 None | Some(crate::SwizzleComponent::X) => {}
2176 Some(component) => {
2177 let is_cube_map = match *context.resolve_type(image) {
2178 crate::TypeInner::Image {
2179 dim: crate::ImageDimension::Cube,
2180 ..
2181 } => true,
2182 _ => false,
2183 };
2184 if offset.is_none() && !is_cube_map {
2187 write!(self.out, ", {NAMESPACE}::int2(0)")?;
2188 }
2189 let letter = back::COMPONENTS[component as usize];
2190 write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2191 }
2192 }
2193 write!(self.out, ")")?;
2194 }
2195 crate::Expression::ImageLoad {
2196 image,
2197 coordinate,
2198 array_index,
2199 sample,
2200 level,
2201 } => {
2202 let address = TexelAddress {
2203 coordinate,
2204 array_index,
2205 sample,
2206 level: level.map(LevelOfDetail::Direct),
2207 };
2208 self.put_image_load(expr_handle, image, address, context)?;
2209 }
2210 crate::Expression::ImageQuery { image, query } => match query {
2213 crate::ImageQuery::Size { level } => {
2214 self.put_image_size_query(
2215 image,
2216 level.map(LevelOfDetail::Direct),
2217 crate::ScalarKind::Uint,
2218 context,
2219 )?;
2220 }
2221 crate::ImageQuery::NumLevels => {
2222 self.put_expression(image, context, false)?;
2223 write!(self.out, ".get_num_mip_levels()")?;
2224 }
2225 crate::ImageQuery::NumLayers => {
2226 self.put_expression(image, context, false)?;
2227 write!(self.out, ".get_array_size()")?;
2228 }
2229 crate::ImageQuery::NumSamples => {
2230 self.put_expression(image, context, false)?;
2231 write!(self.out, ".get_num_samples()")?;
2232 }
2233 },
2234 crate::Expression::Unary { op, expr } => {
2235 let op_str = match op {
2236 crate::UnaryOperator::Negate => {
2237 match context.resolve_type(expr).scalar_kind() {
2238 Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2239 _ => "-",
2240 }
2241 }
2242 crate::UnaryOperator::LogicalNot => "!",
2243 crate::UnaryOperator::BitwiseNot => "~",
2244 };
2245 write!(self.out, "{op_str}(")?;
2246 self.put_expression(expr, context, false)?;
2247 write!(self.out, ")")?;
2248 }
2249 crate::Expression::Binary { op, left, right } => {
2250 let kind = context
2251 .resolve_type(left)
2252 .scalar_kind()
2253 .ok_or(Error::UnsupportedBinaryOp(op))?;
2254
2255 if op == crate::BinaryOperator::Divide
2256 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2257 {
2258 write!(self.out, "{DIV_FUNCTION}(")?;
2259 self.put_expression(left, context, true)?;
2260 write!(self.out, ", ")?;
2261 self.put_expression(right, context, true)?;
2262 write!(self.out, ")")?;
2263 } else if op == crate::BinaryOperator::Modulo
2264 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2265 {
2266 write!(self.out, "{MOD_FUNCTION}(")?;
2267 self.put_expression(left, context, true)?;
2268 write!(self.out, ", ")?;
2269 self.put_expression(right, context, true)?;
2270 write!(self.out, ")")?;
2271 } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2272 write!(self.out, "{NAMESPACE}::fmod(")?;
2277 self.put_expression(left, context, true)?;
2278 write!(self.out, ", ")?;
2279 self.put_expression(right, context, true)?;
2280 write!(self.out, ")")?;
2281 } else if (op == crate::BinaryOperator::Add
2282 || op == crate::BinaryOperator::Subtract
2283 || op == crate::BinaryOperator::Multiply)
2284 && kind == crate::ScalarKind::Sint
2285 {
2286 let to_unsigned = |ty: &crate::TypeInner| match *ty {
2287 crate::TypeInner::Scalar(scalar) => {
2288 Ok(crate::TypeInner::Scalar(crate::Scalar {
2289 kind: crate::ScalarKind::Uint,
2290 ..scalar
2291 }))
2292 }
2293 crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2294 size,
2295 scalar: crate::Scalar {
2296 kind: crate::ScalarKind::Uint,
2297 ..scalar
2298 },
2299 }),
2300 _ => Err(Error::UnsupportedBitCast(ty.clone())),
2301 };
2302
2303 self.put_bitcasted_expression(
2308 context.resolve_type(expr_handle),
2309 expr_handle,
2310 context,
2311 &|writer, context, is_scoped| {
2312 writer.put_binop(
2313 op,
2314 left,
2315 right,
2316 context,
2317 is_scoped,
2318 &|writer, expr, context, _is_scoped| {
2319 writer.put_bitcasted_expression(
2320 &to_unsigned(context.resolve_type(expr))?,
2321 expr,
2322 context,
2323 &|writer, context, is_scoped| {
2324 writer.put_expression(expr, context, is_scoped)
2325 },
2326 )
2327 },
2328 )
2329 },
2330 )?;
2331 } else {
2332 self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2333 }
2334 }
2335 crate::Expression::Select {
2336 condition,
2337 accept,
2338 reject,
2339 } => match *context.resolve_type(condition) {
2340 crate::TypeInner::Scalar(crate::Scalar {
2341 kind: crate::ScalarKind::Bool,
2342 ..
2343 }) => {
2344 if !is_scoped {
2345 write!(self.out, "(")?;
2346 }
2347 self.put_expression(condition, context, false)?;
2348 write!(self.out, " ? ")?;
2349 self.put_expression(accept, context, false)?;
2350 write!(self.out, " : ")?;
2351 self.put_expression(reject, context, false)?;
2352 if !is_scoped {
2353 write!(self.out, ")")?;
2354 }
2355 }
2356 crate::TypeInner::Vector {
2357 scalar:
2358 crate::Scalar {
2359 kind: crate::ScalarKind::Bool,
2360 ..
2361 },
2362 ..
2363 } => {
2364 write!(self.out, "{NAMESPACE}::select(")?;
2365 self.put_expression(reject, context, true)?;
2366 write!(self.out, ", ")?;
2367 self.put_expression(accept, context, true)?;
2368 write!(self.out, ", ")?;
2369 self.put_expression(condition, context, true)?;
2370 write!(self.out, ")")?;
2371 }
2372 ref ty => {
2373 return Err(Error::GenericValidation(format!(
2374 "Expected select condition to be a non-bool type, got {ty:?}",
2375 )))
2376 }
2377 },
2378 crate::Expression::Derivative { axis, expr, .. } => {
2379 use crate::DerivativeAxis as Axis;
2380 let op = match axis {
2381 Axis::X => "dfdx",
2382 Axis::Y => "dfdy",
2383 Axis::Width => "fwidth",
2384 };
2385 write!(self.out, "{NAMESPACE}::{op}")?;
2386 self.put_call_parameters(iter::once(expr), context)?;
2387 }
2388 crate::Expression::Relational { fun, argument } => {
2389 let op = match fun {
2390 crate::RelationalFunction::Any => "any",
2391 crate::RelationalFunction::All => "all",
2392 crate::RelationalFunction::IsNan => "isnan",
2393 crate::RelationalFunction::IsInf => "isinf",
2394 };
2395 write!(self.out, "{NAMESPACE}::{op}")?;
2396 self.put_call_parameters(iter::once(argument), context)?;
2397 }
2398 crate::Expression::Math {
2399 fun,
2400 arg,
2401 arg1,
2402 arg2,
2403 arg3,
2404 } => {
2405 use crate::MathFunction as Mf;
2406
2407 let arg_type = context.resolve_type(arg);
2408 let scalar_argument = match arg_type {
2409 &crate::TypeInner::Scalar(_) => true,
2410 _ => false,
2411 };
2412
2413 let fun_name = match fun {
2414 Mf::Abs => "abs",
2416 Mf::Min => "min",
2417 Mf::Max => "max",
2418 Mf::Clamp => "clamp",
2419 Mf::Saturate => "saturate",
2420 Mf::Cos => "cos",
2422 Mf::Cosh => "cosh",
2423 Mf::Sin => "sin",
2424 Mf::Sinh => "sinh",
2425 Mf::Tan => "tan",
2426 Mf::Tanh => "tanh",
2427 Mf::Acos => "acos",
2428 Mf::Asin => "asin",
2429 Mf::Atan => "atan",
2430 Mf::Atan2 => "atan2",
2431 Mf::Asinh => "asinh",
2432 Mf::Acosh => "acosh",
2433 Mf::Atanh => "atanh",
2434 Mf::Radians => "",
2435 Mf::Degrees => "",
2436 Mf::Ceil => "ceil",
2438 Mf::Floor => "floor",
2439 Mf::Round => "rint",
2440 Mf::Fract => "fract",
2441 Mf::Trunc => "trunc",
2442 Mf::Modf => MODF_FUNCTION,
2443 Mf::Frexp => FREXP_FUNCTION,
2444 Mf::Ldexp => "ldexp",
2445 Mf::Exp => "exp",
2447 Mf::Exp2 => "exp2",
2448 Mf::Log => "log",
2449 Mf::Log2 => "log2",
2450 Mf::Pow => "pow",
2451 Mf::Dot => match *context.resolve_type(arg) {
2453 crate::TypeInner::Vector {
2454 scalar:
2455 crate::Scalar {
2456 kind: crate::ScalarKind::Float,
2458 ..
2459 },
2460 ..
2461 } => "dot",
2462 crate::TypeInner::Vector {
2463 size,
2464 scalar:
2465 scalar @ crate::Scalar {
2466 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2467 ..
2468 },
2469 } => {
2470 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2472 write!(self.out, "{fun_name}(")?;
2473 self.put_expression(arg, context, true)?;
2474 write!(self.out, ", ")?;
2475 self.put_expression(arg1.unwrap(), context, true)?;
2476 write!(self.out, ")")?;
2477 return Ok(());
2478 }
2479 _ => unreachable!(
2480 "Correct TypeInner for dot product should be already validated"
2481 ),
2482 },
2483 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2484 if context.lang_version >= (2, 1) {
2485 let packed_type = match fun {
2489 Mf::Dot4I8Packed => "packed_char4",
2490 Mf::Dot4U8Packed => "packed_uchar4",
2491 _ => unreachable!(),
2492 };
2493
2494 return self.put_dot_product(
2495 Reinterpreted::new(packed_type, arg),
2496 Reinterpreted::new(packed_type, arg1.unwrap()),
2497 4,
2498 |writer, arg, index| {
2499 write!(writer.out, "{arg}[{index}]")?;
2502 Ok(())
2503 },
2504 );
2505 } else {
2506 let conversion = match fun {
2510 Mf::Dot4I8Packed => "int",
2511 Mf::Dot4U8Packed => "",
2512 _ => unreachable!(),
2513 };
2514
2515 return self.put_dot_product(
2516 arg,
2517 arg1.unwrap(),
2518 4,
2519 |writer, arg, index| {
2520 write!(writer.out, "({conversion}(")?;
2521 writer.put_expression(arg, context, true)?;
2522 if index == 3 {
2523 write!(writer.out, ") >> 24)")?;
2524 } else {
2525 write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2526 }
2527 Ok(())
2528 },
2529 );
2530 }
2531 }
2532 Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2533 Mf::Cross => "cross",
2534 Mf::Distance => "distance",
2535 Mf::Length if scalar_argument => "abs",
2536 Mf::Length => "length",
2537 Mf::Normalize => "normalize",
2538 Mf::FaceForward => "faceforward",
2539 Mf::Reflect => "reflect",
2540 Mf::Refract => "refract",
2541 Mf::Sign => match arg_type.scalar_kind() {
2543 Some(crate::ScalarKind::Sint) => {
2544 return self.put_isign(arg, context);
2545 }
2546 _ => "sign",
2547 },
2548 Mf::Fma => "fma",
2549 Mf::Mix => "mix",
2550 Mf::Step => "step",
2551 Mf::SmoothStep => "smoothstep",
2552 Mf::Sqrt => "sqrt",
2553 Mf::InverseSqrt => "rsqrt",
2554 Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2555 Mf::Transpose => "transpose",
2556 Mf::Determinant => "determinant",
2557 Mf::QuantizeToF16 => "",
2558 Mf::CountTrailingZeros => "ctz",
2560 Mf::CountLeadingZeros => "clz",
2561 Mf::CountOneBits => "popcount",
2562 Mf::ReverseBits => "reverse_bits",
2563 Mf::ExtractBits => "",
2564 Mf::InsertBits => "",
2565 Mf::FirstTrailingBit => "",
2566 Mf::FirstLeadingBit => "",
2567 Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2569 Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2570 Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2571 Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2572 Mf::Pack2x16float => "",
2573 Mf::Pack4xI8 => "",
2574 Mf::Pack4xU8 => "",
2575 Mf::Pack4xI8Clamp => "",
2576 Mf::Pack4xU8Clamp => "",
2577 Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2579 Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2580 Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2581 Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2582 Mf::Unpack2x16float => "",
2583 Mf::Unpack4xI8 => "",
2584 Mf::Unpack4xU8 => "",
2585 };
2586
2587 match fun {
2588 Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2589 if context.lang_version < (1, 2) {
2598 return Err(Error::UnsupportedFunction(fun_name.to_string()));
2599 }
2600 }
2601 _ => {}
2602 }
2603
2604 match fun {
2605 Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2606 write!(self.out, "{ABS_FUNCTION}(")?;
2607 self.put_expression(arg, context, true)?;
2608 write!(self.out, ")")?;
2609 }
2610 Mf::Distance if scalar_argument => {
2611 write!(self.out, "{NAMESPACE}::abs(")?;
2612 self.put_expression(arg, context, false)?;
2613 write!(self.out, " - ")?;
2614 self.put_expression(arg1.unwrap(), context, false)?;
2615 write!(self.out, ")")?;
2616 }
2617 Mf::FirstTrailingBit => {
2618 let scalar = context.resolve_type(arg).scalar().unwrap();
2619 let constant = scalar.width * 8 + 1;
2620
2621 write!(self.out, "((({NAMESPACE}::ctz(")?;
2622 self.put_expression(arg, context, true)?;
2623 write!(self.out, ") + 1) % {constant}) - 1)")?;
2624 }
2625 Mf::FirstLeadingBit => {
2626 let inner = context.resolve_type(arg);
2627 let scalar = inner.scalar().unwrap();
2628 let constant = scalar.width * 8 - 1;
2629
2630 write!(
2631 self.out,
2632 "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2633 )?;
2634
2635 if scalar.kind == crate::ScalarKind::Sint {
2636 write!(self.out, "{NAMESPACE}::select(")?;
2637 self.put_expression(arg, context, true)?;
2638 write!(self.out, ", ~")?;
2639 self.put_expression(arg, context, true)?;
2640 write!(self.out, ", ")?;
2641 self.put_expression(arg, context, true)?;
2642 write!(self.out, " < 0)")?;
2643 } else {
2644 self.put_expression(arg, context, true)?;
2645 }
2646
2647 write!(self.out, "), ")?;
2648
2649 match *inner {
2651 crate::TypeInner::Vector { size, scalar } => {
2652 let size = common::vector_size_str(size);
2653 let name = scalar.to_msl_name();
2654 write!(self.out, "{name}{size}")?;
2655 }
2656 crate::TypeInner::Scalar(scalar) => {
2657 let name = scalar.to_msl_name();
2658 write!(self.out, "{name}")?;
2659 }
2660 _ => (),
2661 }
2662
2663 write!(self.out, "(-1), ")?;
2664 self.put_expression(arg, context, true)?;
2665 write!(self.out, " == 0")?;
2666 if scalar.kind == crate::ScalarKind::Sint {
2667 write!(self.out, " || ")?;
2668 self.put_expression(arg, context, true)?;
2669 write!(self.out, " == -1")?;
2670 }
2671 write!(self.out, ")")?;
2672 }
2673 Mf::Unpack2x16float => {
2674 write!(self.out, "float2(as_type<half2>(")?;
2675 self.put_expression(arg, context, false)?;
2676 write!(self.out, "))")?;
2677 }
2678 Mf::Pack2x16float => {
2679 write!(self.out, "as_type<uint>(half2(")?;
2680 self.put_expression(arg, context, false)?;
2681 write!(self.out, "))")?;
2682 }
2683 Mf::ExtractBits => {
2684 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2701
2702 write!(self.out, "{NAMESPACE}::extract_bits(")?;
2703 self.put_expression(arg, context, true)?;
2704 write!(self.out, ", {NAMESPACE}::min(")?;
2705 self.put_expression(arg1.unwrap(), context, true)?;
2706 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2707 self.put_expression(arg2.unwrap(), context, true)?;
2708 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2709 self.put_expression(arg1.unwrap(), context, true)?;
2710 write!(self.out, ", {scalar_bits}u)))")?;
2711 }
2712 Mf::InsertBits => {
2713 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2718
2719 write!(self.out, "{NAMESPACE}::insert_bits(")?;
2720 self.put_expression(arg, context, true)?;
2721 write!(self.out, ", ")?;
2722 self.put_expression(arg1.unwrap(), context, true)?;
2723 write!(self.out, ", {NAMESPACE}::min(")?;
2724 self.put_expression(arg2.unwrap(), context, true)?;
2725 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2726 self.put_expression(arg3.unwrap(), context, true)?;
2727 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2728 self.put_expression(arg2.unwrap(), context, true)?;
2729 write!(self.out, ", {scalar_bits}u)))")?;
2730 }
2731 Mf::Radians => {
2732 write!(self.out, "((")?;
2733 self.put_expression(arg, context, false)?;
2734 write!(self.out, ") * 0.017453292519943295474)")?;
2735 }
2736 Mf::Degrees => {
2737 write!(self.out, "((")?;
2738 self.put_expression(arg, context, false)?;
2739 write!(self.out, ") * 57.295779513082322865)")?;
2740 }
2741 Mf::Modf | Mf::Frexp => {
2742 write!(self.out, "{fun_name}")?;
2743 self.put_call_parameters(iter::once(arg), context)?;
2744 }
2745 Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2746 Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2747 Mf::Pack4xI8Clamp => {
2748 self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2749 }
2750 Mf::Pack4xU8Clamp => {
2751 self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2752 }
2753 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2754 let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2755 "u"
2756 } else {
2757 ""
2758 };
2759
2760 if context.lang_version >= (2, 1) {
2761 write!(
2763 self.out,
2764 "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2765 )?;
2766 self.put_expression(arg, context, true)?;
2767 write!(self.out, "))")?;
2768 } else {
2769 write!(self.out, "({sign_prefix}int4(")?;
2771 self.put_expression(arg, context, true)?;
2772 write!(self.out, ", ")?;
2773 self.put_expression(arg, context, true)?;
2774 write!(self.out, " >> 8, ")?;
2775 self.put_expression(arg, context, true)?;
2776 write!(self.out, " >> 16, ")?;
2777 self.put_expression(arg, context, true)?;
2778 write!(self.out, " >> 24) << 24 >> 24)")?;
2779 }
2780 }
2781 Mf::QuantizeToF16 => {
2782 match *context.resolve_type(arg) {
2783 crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2784 crate::TypeInner::Vector { size, .. } => write!(
2785 self.out,
2786 "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2787 size = common::vector_size_str(size),
2788 )?,
2789 _ => unreachable!(
2790 "Correct TypeInner for QuantizeToF16 should be already validated"
2791 ),
2792 };
2793
2794 self.put_expression(arg, context, true)?;
2795 write!(self.out, "))")?;
2796 }
2797 _ => {
2798 write!(self.out, "{NAMESPACE}::{fun_name}")?;
2799 self.put_call_parameters(
2800 iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2801 context,
2802 )?;
2803 }
2804 }
2805 }
2806 crate::Expression::As {
2807 expr,
2808 kind,
2809 convert,
2810 } => match *context.resolve_type(expr) {
2811 crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2812 if src.kind == crate::ScalarKind::Float
2813 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2814 && convert.is_some()
2815 {
2816 let fun_name = match (kind, convert) {
2820 (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2821 (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2822 (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2823 (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2824 _ => unreachable!(),
2825 };
2826 write!(self.out, "{fun_name}(")?;
2827 self.put_expression(expr, context, true)?;
2828 write!(self.out, ")")?;
2829 } else {
2830 let target_scalar = crate::Scalar {
2831 kind,
2832 width: convert.unwrap_or(src.width),
2833 };
2834 let op = match convert {
2835 Some(_) => "static_cast",
2836 None => "as_type",
2837 };
2838 write!(self.out, "{op}<")?;
2839 match *context.resolve_type(expr) {
2840 crate::TypeInner::Vector { size, .. } => {
2841 put_numeric_type(&mut self.out, target_scalar, &[size])?
2842 }
2843 _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2844 };
2845 write!(self.out, ">(")?;
2846 self.put_expression(expr, context, true)?;
2847 write!(self.out, ")")?;
2848 }
2849 }
2850 crate::TypeInner::Matrix {
2851 columns,
2852 rows,
2853 scalar,
2854 } => {
2855 let target_scalar = crate::Scalar {
2856 kind,
2857 width: convert.unwrap_or(scalar.width),
2858 };
2859 put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2860 write!(self.out, "(")?;
2861 self.put_expression(expr, context, true)?;
2862 write!(self.out, ")")?;
2863 }
2864 ref ty => {
2865 return Err(Error::GenericValidation(format!(
2866 "Unsupported type for As: {ty:?}"
2867 )))
2868 }
2869 },
2870 crate::Expression::CallResult(_)
2872 | crate::Expression::AtomicResult { .. }
2873 | crate::Expression::WorkGroupUniformLoadResult { .. }
2874 | crate::Expression::SubgroupBallotResult
2875 | crate::Expression::SubgroupOperationResult { .. }
2876 | crate::Expression::RayQueryProceedResult => {
2877 unreachable!()
2878 }
2879 crate::Expression::ArrayLength(expr) => {
2880 let global = match context.function.expressions[expr] {
2882 crate::Expression::AccessIndex { base, .. } => {
2883 match context.function.expressions[base] {
2884 crate::Expression::GlobalVariable(handle) => handle,
2885 ref ex => {
2886 return Err(Error::GenericValidation(format!(
2887 "Expected global variable in AccessIndex, got {ex:?}"
2888 )))
2889 }
2890 }
2891 }
2892 crate::Expression::GlobalVariable(handle) => handle,
2893 ref ex => {
2894 return Err(Error::GenericValidation(format!(
2895 "Unexpected expression in ArrayLength, got {ex:?}"
2896 )))
2897 }
2898 };
2899
2900 if !is_scoped {
2901 write!(self.out, "(")?;
2902 }
2903 write!(self.out, "1 + ")?;
2904 self.put_dynamic_array_max_index(global, context)?;
2905 if !is_scoped {
2906 write!(self.out, ")")?;
2907 }
2908 }
2909 crate::Expression::RayQueryVertexPositions { .. } => {
2910 unimplemented!()
2911 }
2912 crate::Expression::RayQueryGetIntersection { query, committed } => {
2913 if context.lang_version < (2, 4) {
2914 return Err(Error::UnsupportedRayTracing);
2915 }
2916
2917 write!(
2918 self.out,
2919 "{}_{committed}(",
2920 super::ray::INTERSECTION_FUNCTION_NAME
2921 )?;
2922 self.put_expression(query, context, true)?;
2923 write!(self.out, ")")?;
2924 }
2925 crate::Expression::CooperativeLoad { ref data, .. } => {
2926 if context.lang_version < (2, 3) {
2927 return Err(Error::UnsupportedCooperativeMatrix);
2928 }
2929 write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
2930 write!(self.out, "&")?;
2931 self.put_access_chain(data.pointer, context.policies.index, context)?;
2932 write!(self.out, ", ")?;
2933 self.put_expression(data.stride, context, true)?;
2934 write!(self.out, ", {})", !data.row_major)?;
2941 }
2942 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2943 if context.lang_version < (2, 3) {
2944 return Err(Error::UnsupportedCooperativeMatrix);
2945 }
2946 write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
2947 self.put_expression(a, context, true)?;
2948 write!(self.out, ", ")?;
2949 self.put_expression(b, context, true)?;
2950 write!(self.out, ", ")?;
2951 self.put_expression(c, context, true)?;
2952 write!(self.out, ")")?;
2953 }
2954 }
2955 Ok(())
2956 }
2957
2958 fn put_binop<F>(
2961 &mut self,
2962 op: crate::BinaryOperator,
2963 left: Handle<crate::Expression>,
2964 right: Handle<crate::Expression>,
2965 context: &ExpressionContext,
2966 is_scoped: bool,
2967 put_expression: &F,
2968 ) -> BackendResult
2969 where
2970 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2971 {
2972 let op_str = back::binary_operation_str(op);
2973
2974 if !is_scoped {
2975 write!(self.out, "(")?;
2976 }
2977
2978 if op == crate::BinaryOperator::Multiply
2981 && matches!(
2982 context.resolve_type(right),
2983 &crate::TypeInner::Matrix { .. }
2984 )
2985 {
2986 self.put_wrapped_expression_for_packed_vec3_access(
2987 left,
2988 context,
2989 false,
2990 put_expression,
2991 )?;
2992 } else {
2993 put_expression(self, left, context, false)?;
2994 }
2995
2996 write!(self.out, " {op_str} ")?;
2997
2998 if op == crate::BinaryOperator::Multiply
3000 && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
3001 {
3002 self.put_wrapped_expression_for_packed_vec3_access(
3003 right,
3004 context,
3005 false,
3006 put_expression,
3007 )?;
3008 } else {
3009 put_expression(self, right, context, false)?;
3010 }
3011
3012 if !is_scoped {
3013 write!(self.out, ")")?;
3014 }
3015
3016 Ok(())
3017 }
3018
3019 fn put_wrapped_expression_for_packed_vec3_access<F>(
3021 &mut self,
3022 expr_handle: Handle<crate::Expression>,
3023 context: &ExpressionContext,
3024 is_scoped: bool,
3025 put_expression: &F,
3026 ) -> BackendResult
3027 where
3028 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
3029 {
3030 if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
3031 write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
3032 put_expression(self, expr_handle, context, is_scoped)?;
3033 write!(self.out, ")")?;
3034 } else {
3035 put_expression(self, expr_handle, context, is_scoped)?;
3036 }
3037 Ok(())
3038 }
3039
3040 fn put_bitcasted_expression<F>(
3043 &mut self,
3044 cast_to: &crate::TypeInner,
3045 inner_expr: Handle<crate::Expression>,
3046 context: &ExpressionContext,
3047 put_expression: &F,
3048 ) -> BackendResult
3049 where
3050 F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
3051 {
3052 let needs_truncation = match *cast_to {
3057 crate::TypeInner::Scalar(scalar) => scalar.width < 4,
3058 crate::TypeInner::Vector { scalar, .. } => scalar.width < 4,
3059 _ => false,
3060 };
3061
3062 write!(self.out, "as_type<")?;
3063 match *cast_to {
3064 crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
3065 crate::TypeInner::Vector { size, scalar } => {
3066 put_numeric_type(&mut self.out, scalar, &[size])?
3067 }
3068 _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
3069 };
3070 write!(self.out, ">(")?;
3071
3072 if needs_truncation {
3073 write!(self.out, "static_cast<")?;
3074 let unsigned_scalar = match *cast_to {
3076 crate::TypeInner::Scalar(scalar) => crate::Scalar {
3077 kind: crate::ScalarKind::Uint,
3078 ..scalar
3079 },
3080 crate::TypeInner::Vector { scalar, .. } => crate::Scalar {
3081 kind: crate::ScalarKind::Uint,
3082 ..scalar
3083 },
3084 _ => unreachable!(),
3085 };
3086 match *cast_to {
3087 crate::TypeInner::Scalar(_) => {
3088 put_numeric_type(&mut self.out, unsigned_scalar, &[])?
3089 }
3090 crate::TypeInner::Vector { size, .. } => {
3091 put_numeric_type(&mut self.out, unsigned_scalar, &[size])?
3092 }
3093 _ => unreachable!(),
3094 };
3095 write!(self.out, ">(")?;
3096 }
3097
3098 if let Some(scalar) = context.get_packed_vec_kind(inner_expr) {
3100 put_numeric_type(&mut self.out, scalar, &[crate::VectorSize::Tri])?;
3101 write!(self.out, "(")?;
3102 put_expression(self, context, true)?;
3103 write!(self.out, ")")?;
3104 } else {
3105 put_expression(self, context, true)?;
3106 }
3107
3108 if needs_truncation {
3109 write!(self.out, ")")?;
3110 }
3111
3112 write!(self.out, ")")?;
3113 Ok(())
3114 }
3115
3116 fn put_index(
3118 &mut self,
3119 index: index::GuardedIndex,
3120 context: &ExpressionContext,
3121 is_scoped: bool,
3122 ) -> BackendResult {
3123 match index {
3124 index::GuardedIndex::Expression(expr) => {
3125 self.put_expression(expr, context, is_scoped)?
3126 }
3127 index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
3128 }
3129 Ok(())
3130 }
3131
3132 fn put_bounds_checks(
3162 &mut self,
3163 chain: Handle<crate::Expression>,
3164 context: &ExpressionContext,
3165 level: back::Level,
3166 prefix: &'static str,
3167 ) -> Result<bool, Error> {
3168 let mut check_written = false;
3169
3170 for item in context.bounds_check_iter(chain) {
3172 let BoundsCheck {
3173 base,
3174 index,
3175 length,
3176 } = item;
3177
3178 if check_written {
3179 write!(self.out, " && ")?;
3180 } else {
3181 write!(self.out, "{level}{prefix}")?;
3182 check_written = true;
3183 }
3184
3185 write!(self.out, "uint(")?;
3189 self.put_index(index, context, true)?;
3190 self.out.write_str(") < ")?;
3191 match length {
3192 index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3193 index::IndexableLength::Dynamic => {
3194 let global = context.function.originating_global(base).ok_or_else(|| {
3195 Error::GenericValidation("Could not find originating global".into())
3196 })?;
3197 write!(self.out, "1 + ")?;
3198 self.put_dynamic_array_max_index(global, context)?
3199 }
3200 }
3201 }
3202
3203 Ok(check_written)
3204 }
3205
3206 fn put_access_chain(
3226 &mut self,
3227 chain: Handle<crate::Expression>,
3228 policy: index::BoundsCheckPolicy,
3229 context: &ExpressionContext,
3230 ) -> BackendResult {
3231 match context.function.expressions[chain] {
3232 crate::Expression::Access { base, index } => {
3233 let mut base_ty = context.resolve_type(base);
3234
3235 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3237 base_ty = &context.module.types[base].inner;
3238 }
3239
3240 self.put_subscripted_access_chain(
3241 base,
3242 base_ty,
3243 index::GuardedIndex::Expression(index),
3244 policy,
3245 context,
3246 )?;
3247 }
3248 crate::Expression::AccessIndex { base, index } => {
3249 let base_resolution = &context.info[base].ty;
3250 let mut base_ty = base_resolution.inner_with(&context.module.types);
3251 let mut base_ty_handle = base_resolution.handle();
3252
3253 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3255 base_ty = &context.module.types[base].inner;
3256 base_ty_handle = Some(base);
3257 }
3258
3259 match *base_ty {
3263 crate::TypeInner::Struct { .. } => {
3264 let base_ty = base_ty_handle.unwrap();
3265 self.put_access_chain(base, policy, context)?;
3266 let name = &self.names[&NameKey::StructMember(base_ty, index)];
3267 write!(self.out, ".{name}")?;
3268 }
3269 crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3270 self.put_access_chain(base, policy, context)?;
3271 if context.get_packed_vec_kind(base).is_some() {
3274 write!(self.out, "[{index}]")?;
3275 } else {
3276 write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3277 }
3278 }
3279 _ => {
3280 self.put_subscripted_access_chain(
3281 base,
3282 base_ty,
3283 index::GuardedIndex::Known(index),
3284 policy,
3285 context,
3286 )?;
3287 }
3288 }
3289 }
3290 _ => self.put_expression(chain, context, false)?,
3291 }
3292
3293 Ok(())
3294 }
3295
3296 fn put_subscripted_access_chain(
3313 &mut self,
3314 base: Handle<crate::Expression>,
3315 base_ty: &crate::TypeInner,
3316 index: index::GuardedIndex,
3317 policy: index::BoundsCheckPolicy,
3318 context: &ExpressionContext,
3319 ) -> BackendResult {
3320 let accessing_wrapped_array = match *base_ty {
3321 crate::TypeInner::Array {
3322 size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3323 ..
3324 } => true,
3325 _ => false,
3326 };
3327 let accessing_wrapped_binding_array =
3328 matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3329
3330 self.put_access_chain(base, policy, context)?;
3331 if accessing_wrapped_array {
3332 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3333 }
3334 write!(self.out, "[")?;
3335
3336 let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3338 context.access_needs_check(base, index)
3339 } else {
3340 None
3341 };
3342 if let Some(limit) = restriction_needed {
3343 write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3344 self.put_index(index, context, true)?;
3345 write!(self.out, "), ")?;
3346 match limit {
3347 index::IndexableLength::Known(limit) => {
3348 write!(self.out, "{}u", limit - 1)?;
3349 }
3350 index::IndexableLength::Dynamic => {
3351 let global = context.function.originating_global(base).ok_or_else(|| {
3352 Error::GenericValidation("Could not find originating global".into())
3353 })?;
3354 self.put_dynamic_array_max_index(global, context)?;
3355 }
3356 }
3357 write!(self.out, ")")?;
3358 } else {
3359 self.put_index(index, context, true)?;
3360 }
3361
3362 write!(self.out, "]")?;
3363
3364 if accessing_wrapped_binding_array {
3365 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3366 }
3367
3368 Ok(())
3369 }
3370
3371 fn put_load(
3372 &mut self,
3373 pointer: Handle<crate::Expression>,
3374 context: &ExpressionContext,
3375 is_scoped: bool,
3376 ) -> BackendResult {
3377 let policy = context.choose_bounds_check_policy(pointer);
3380 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3381 && self.put_bounds_checks(
3382 pointer,
3383 context,
3384 back::Level(0),
3385 if is_scoped { "" } else { "(" },
3386 )?
3387 {
3388 write!(self.out, " ? ")?;
3389 self.put_unchecked_load(pointer, policy, context)?;
3390 write!(self.out, " : DefaultConstructible()")?;
3391
3392 if !is_scoped {
3393 write!(self.out, ")")?;
3394 }
3395 } else {
3396 self.put_unchecked_load(pointer, policy, context)?;
3397 }
3398
3399 Ok(())
3400 }
3401
3402 fn put_unchecked_load(
3403 &mut self,
3404 pointer: Handle<crate::Expression>,
3405 policy: index::BoundsCheckPolicy,
3406 context: &ExpressionContext,
3407 ) -> BackendResult {
3408 let is_atomic_pointer = context
3409 .resolve_type(pointer)
3410 .is_atomic_pointer(&context.module.types);
3411
3412 if is_atomic_pointer {
3413 write!(
3414 self.out,
3415 "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3416 )?;
3417 self.put_access_chain(pointer, policy, context)?;
3418 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3419 } else {
3420 self.put_access_chain(pointer, policy, context)?;
3424 }
3425
3426 Ok(())
3427 }
3428
3429 fn put_return_value(
3430 &mut self,
3431 level: back::Level,
3432 expr_handle: Handle<crate::Expression>,
3433 result_struct: Option<&str>,
3434 context: &ExpressionContext,
3435 ) -> BackendResult {
3436 match result_struct {
3437 Some(struct_name) => {
3438 let mut has_point_size = false;
3439 let result_ty = context.function.result.as_ref().unwrap().ty;
3440 match context.module.types[result_ty].inner {
3441 crate::TypeInner::Struct { ref members, .. } => {
3442 let tmp = "_tmp";
3443 write!(self.out, "{level}const auto {tmp} = ")?;
3444 self.put_expression(expr_handle, context, true)?;
3445 writeln!(self.out, ";")?;
3446 write!(self.out, "{level}return {struct_name} {{")?;
3447
3448 let mut is_first = true;
3449
3450 for (index, member) in members.iter().enumerate() {
3451 if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3452 member.binding
3453 {
3454 has_point_size = true;
3455 if !context.pipeline_options.allow_and_force_point_size {
3456 continue;
3457 }
3458 }
3459
3460 let comma = if is_first { "" } else { "," };
3461 is_first = false;
3462 let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3463 if let crate::TypeInner::Array {
3467 size: crate::ArraySize::Constant(size),
3468 ..
3469 } = context.module.types[member.ty].inner
3470 {
3471 write!(self.out, "{comma} {{")?;
3472 for j in 0..size.get() {
3473 if j != 0 {
3474 write!(self.out, ",")?;
3475 }
3476 write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3477 }
3478 write!(self.out, "}}")?;
3479 } else {
3480 write!(self.out, "{comma} {tmp}.{name}")?;
3481 }
3482 }
3483 }
3484 _ => {
3485 write!(self.out, "{level}return {struct_name} {{ ")?;
3486 self.put_expression(expr_handle, context, true)?;
3487 }
3488 }
3489
3490 if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3491 let stage = context.module.entry_points[ep_index as usize].stage;
3492 if context.pipeline_options.allow_and_force_point_size
3493 && stage == crate::ShaderStage::Vertex
3494 && !has_point_size
3495 {
3496 write!(self.out, ", 1.0")?;
3498 }
3499 }
3500 write!(self.out, " }}")?;
3501 }
3502 None => {
3503 write!(self.out, "{level}return ")?;
3504 self.put_expression(expr_handle, context, true)?;
3505 }
3506 }
3507 writeln!(self.out, ";")?;
3508 Ok(())
3509 }
3510
3511 fn update_expressions_to_bake(
3516 &mut self,
3517 func: &crate::Function,
3518 info: &valid::FunctionInfo,
3519 context: &ExpressionContext,
3520 ) {
3521 use crate::Expression;
3522 self.need_bake_expressions.clear();
3523
3524 for (expr_handle, expr) in func.expressions.iter() {
3525 let expr_info = &info[expr_handle];
3528 let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3529 if min_ref_count <= expr_info.ref_count {
3530 self.need_bake_expressions.insert(expr_handle);
3531 } else {
3532 match expr_info.ty {
3533 TypeResolution::Handle(h)
3535 if Some(h) == context.module.special_types.ray_desc =>
3536 {
3537 self.need_bake_expressions.insert(expr_handle);
3538 }
3539 _ => {}
3540 }
3541 }
3542
3543 if let Expression::Math {
3544 fun,
3545 arg,
3546 arg1,
3547 arg2,
3548 ..
3549 } = *expr
3550 {
3551 match fun {
3552 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3562 self.need_bake_expressions.insert(arg);
3563 self.need_bake_expressions.insert(arg1.unwrap());
3564 }
3565 crate::MathFunction::FirstLeadingBit => {
3566 self.need_bake_expressions.insert(arg);
3567 }
3568 crate::MathFunction::Pack4xI8
3569 | crate::MathFunction::Pack4xU8
3570 | crate::MathFunction::Pack4xI8Clamp
3571 | crate::MathFunction::Pack4xU8Clamp
3572 | crate::MathFunction::Unpack4xI8
3573 | crate::MathFunction::Unpack4xU8 => {
3574 if context.lang_version < (2, 1) {
3577 self.need_bake_expressions.insert(arg);
3578 }
3579 }
3580 crate::MathFunction::ExtractBits => {
3581 self.need_bake_expressions.insert(arg1.unwrap());
3583 }
3584 crate::MathFunction::InsertBits => {
3585 self.need_bake_expressions.insert(arg2.unwrap());
3587 }
3588 crate::MathFunction::Sign => {
3589 let inner = context.resolve_type(expr_handle);
3594 if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3595 self.need_bake_expressions.insert(arg);
3596 }
3597 }
3598 _ => {}
3599 }
3600 }
3601 }
3602 }
3603
3604 pub(super) fn start_baking_expression(
3605 &mut self,
3606 handle: Handle<crate::Expression>,
3607 context: &ExpressionContext,
3608 name: &str,
3609 ) -> BackendResult {
3610 match context.info[handle].ty {
3611 TypeResolution::Handle(ty_handle) => {
3612 let ty_name = TypeContext {
3613 handle: ty_handle,
3614 gctx: context.module.to_ctx(),
3615 names: &self.names,
3616 access: crate::StorageAccess::empty(),
3617 first_time: false,
3618 };
3619 write!(self.out, "{ty_name}")?;
3620 }
3621 TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3622 put_numeric_type(&mut self.out, scalar, &[])?;
3623 }
3624 TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3625 put_numeric_type(&mut self.out, scalar, &[size])?;
3626 }
3627 TypeResolution::Value(crate::TypeInner::Matrix {
3628 columns,
3629 rows,
3630 scalar,
3631 }) => {
3632 put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3633 }
3634 TypeResolution::Value(crate::TypeInner::CooperativeMatrix {
3635 columns,
3636 rows,
3637 scalar,
3638 role: _,
3639 }) => {
3640 write!(
3641 self.out,
3642 "{}::simdgroup_{}{}x{}",
3643 NAMESPACE,
3644 scalar.to_msl_name(),
3645 columns as u32,
3646 rows as u32,
3647 )?;
3648 }
3649 TypeResolution::Value(ref other) => {
3650 log::warn!("Type {other:?} isn't a known local");
3651 return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3652 }
3653 }
3654
3655 write!(self.out, " {name} = ")?;
3657
3658 Ok(())
3659 }
3660
3661 fn put_cache_restricted_level(
3674 &mut self,
3675 load: Handle<crate::Expression>,
3676 image: Handle<crate::Expression>,
3677 mip_level: Option<Handle<crate::Expression>>,
3678 indent: back::Level,
3679 context: &StatementContext,
3680 ) -> BackendResult {
3681 let level_of_detail = match mip_level {
3684 Some(level) => level,
3685 None => return Ok(()),
3686 };
3687
3688 if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3689 || !context.expression.image_needs_lod(image)
3690 {
3691 return Ok(());
3692 }
3693
3694 write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3695 self.put_restricted_scalar_image_index(
3696 image,
3697 level_of_detail,
3698 "get_num_mip_levels",
3699 &context.expression,
3700 )?;
3701 writeln!(self.out, ";")?;
3702
3703 Ok(())
3704 }
3705
3706 fn put_casting_to_packed_chars(
3712 &mut self,
3713 fun: crate::MathFunction,
3714 arg0: Handle<crate::Expression>,
3715 arg1: Handle<crate::Expression>,
3716 indent: back::Level,
3717 context: &StatementContext<'_>,
3718 ) -> Result<(), Error> {
3719 let packed_type = match fun {
3720 crate::MathFunction::Dot4I8Packed => "packed_char4",
3721 crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3722 _ => unreachable!(),
3723 };
3724
3725 for arg in [arg0, arg1] {
3726 write!(
3727 self.out,
3728 "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3729 Reinterpreted::new(packed_type, arg)
3730 )?;
3731 self.put_expression(arg, &context.expression, true)?;
3732 writeln!(self.out, ");")?;
3733 }
3734
3735 Ok(())
3736 }
3737
3738 fn put_block(
3739 &mut self,
3740 level: back::Level,
3741 statements: &[crate::Statement],
3742 context: &StatementContext,
3743 ) -> BackendResult {
3744 #[cfg(test)]
3746 self.put_block_stack_pointers
3747 .insert(ptr::from_ref(&level).cast());
3748
3749 for statement in statements {
3750 log::trace!("statement[{}] {:?}", level.0, statement);
3751 match *statement {
3752 crate::Statement::Emit(ref range) => {
3753 for handle in range.clone() {
3754 use crate::MathFunction as Mf;
3755
3756 match context.expression.function.expressions[handle] {
3757 crate::Expression::ImageLoad {
3760 image,
3761 level: mip_level,
3762 ..
3763 } => {
3764 self.put_cache_restricted_level(
3765 handle, image, mip_level, level, context,
3766 )?;
3767 }
3768
3769 crate::Expression::Math {
3778 fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3779 arg,
3780 arg1,
3781 ..
3782 } if context.expression.lang_version >= (2, 1) => {
3783 self.put_casting_to_packed_chars(
3784 fun,
3785 arg,
3786 arg1.unwrap(),
3787 level,
3788 context,
3789 )?;
3790 }
3791
3792 _ => (),
3793 }
3794
3795 let ptr_class = context.expression.resolve_type(handle).pointer_space();
3796 let expr_name = if ptr_class.is_some() {
3797 None } else if let Some(name) =
3799 context.expression.function.named_expressions.get(&handle)
3800 {
3801 Some(self.namer.call(name))
3811 } else {
3812 let bake = if context.expression.guarded_indices.contains(handle) {
3816 true
3817 } else {
3818 self.need_bake_expressions.contains(&handle)
3819 };
3820
3821 if bake {
3822 Some(Baked(handle).to_string())
3823 } else {
3824 None
3825 }
3826 };
3827
3828 if let Some(name) = expr_name {
3829 write!(self.out, "{level}")?;
3830 self.start_baking_expression(handle, &context.expression, &name)?;
3831 self.put_expression(handle, &context.expression, true)?;
3832 self.named_expressions.insert(handle, name);
3833 writeln!(self.out, ";")?;
3834 }
3835 }
3836 }
3837 crate::Statement::Block(ref block) => {
3838 if !block.is_empty() {
3839 writeln!(self.out, "{level}{{")?;
3840 self.put_block(level.next(), block, context)?;
3841 writeln!(self.out, "{level}}}")?;
3842 }
3843 }
3844 crate::Statement::If {
3845 condition,
3846 ref accept,
3847 ref reject,
3848 } => {
3849 write!(self.out, "{level}if (")?;
3850 self.put_expression(condition, &context.expression, true)?;
3851 writeln!(self.out, ") {{")?;
3852 self.put_block(level.next(), accept, context)?;
3853 if !reject.is_empty() {
3854 writeln!(self.out, "{level}}} else {{")?;
3855 self.put_block(level.next(), reject, context)?;
3856 }
3857 writeln!(self.out, "{level}}}")?;
3858 }
3859 crate::Statement::Switch {
3860 selector,
3861 ref cases,
3862 } => {
3863 write!(self.out, "{level}switch(")?;
3864 self.put_expression(selector, &context.expression, true)?;
3865 writeln!(self.out, ") {{")?;
3866 let lcase = level.next();
3867 for case in cases.iter() {
3868 match case.value {
3869 crate::SwitchValue::I32(value) => {
3870 write!(self.out, "{lcase}case {value}:")?;
3871 }
3872 crate::SwitchValue::U32(value) => {
3873 write!(self.out, "{lcase}case {value}u:")?;
3874 }
3875 crate::SwitchValue::Default => {
3876 write!(self.out, "{lcase}default:")?;
3877 }
3878 }
3879
3880 let write_block_braces = !(case.fall_through && case.body.is_empty());
3881 if write_block_braces {
3882 writeln!(self.out, " {{")?;
3883 } else {
3884 writeln!(self.out)?;
3885 }
3886
3887 self.put_block(lcase.next(), &case.body, context)?;
3888 if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3889 {
3890 writeln!(self.out, "{}break;", lcase.next())?;
3891 }
3892
3893 if write_block_braces {
3894 writeln!(self.out, "{lcase}}}")?;
3895 }
3896 }
3897 writeln!(self.out, "{level}}}")?;
3898 }
3899 crate::Statement::Loop {
3900 ref body,
3901 ref continuing,
3902 break_if,
3903 } => {
3904 let force_loop_bound_statements =
3905 self.gen_force_bounded_loop_statements(level, context);
3906 let gate_name = (!continuing.is_empty() || break_if.is_some())
3907 .then(|| self.namer.call("loop_init"));
3908
3909 if let Some((ref decl, _)) = force_loop_bound_statements {
3910 writeln!(self.out, "{decl}")?;
3911 }
3912 if let Some(ref gate_name) = gate_name {
3913 writeln!(self.out, "{level}bool {gate_name} = true;")?;
3914 }
3915
3916 writeln!(self.out, "{level}while(true) {{",)?;
3917 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3918 writeln!(self.out, "{break_and_inc}")?;
3919 }
3920 if let Some(ref gate_name) = gate_name {
3921 let lif = level.next();
3922 let lcontinuing = lif.next();
3923 writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3924 self.put_block(lcontinuing, continuing, context)?;
3925 if let Some(condition) = break_if {
3926 write!(self.out, "{lcontinuing}if (")?;
3927 self.put_expression(condition, &context.expression, true)?;
3928 writeln!(self.out, ") {{")?;
3929 writeln!(self.out, "{}break;", lcontinuing.next())?;
3930 writeln!(self.out, "{lcontinuing}}}")?;
3931 }
3932 writeln!(self.out, "{lif}}}")?;
3933 writeln!(self.out, "{lif}{gate_name} = false;")?;
3934 }
3935 self.put_block(level.next(), body, context)?;
3936
3937 writeln!(self.out, "{level}}}")?;
3938 }
3939 crate::Statement::Break => {
3940 writeln!(self.out, "{level}break;")?;
3941 }
3942 crate::Statement::Continue => {
3943 writeln!(self.out, "{level}continue;")?;
3944 }
3945 crate::Statement::Return {
3946 value: Some(expr_handle),
3947 } => {
3948 self.put_return_value(
3949 level,
3950 expr_handle,
3951 context.result_struct,
3952 &context.expression,
3953 )?;
3954 }
3955 crate::Statement::Return { value: None } => {
3956 writeln!(self.out, "{level}return;")?;
3957 }
3958 crate::Statement::Kill => {
3959 writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3960 }
3961 crate::Statement::ControlBarrier(flags)
3962 | crate::Statement::MemoryBarrier(flags) => {
3963 self.write_barrier(flags, level)?;
3964 }
3965 crate::Statement::Store { pointer, value } => {
3966 self.put_store(pointer, value, level, context)?
3967 }
3968 crate::Statement::ImageStore {
3969 image,
3970 coordinate,
3971 array_index,
3972 value,
3973 } => {
3974 let address = TexelAddress {
3975 coordinate,
3976 array_index,
3977 sample: None,
3978 level: None,
3979 };
3980 self.put_image_store(level, image, &address, value, context)?
3981 }
3982 crate::Statement::Call {
3983 function,
3984 ref arguments,
3985 result,
3986 } => {
3987 write!(self.out, "{level}")?;
3988 if let Some(expr) = result {
3989 let name = Baked(expr).to_string();
3990 self.start_baking_expression(expr, &context.expression, &name)?;
3991 self.named_expressions.insert(expr, name);
3992 }
3993 let fun_name = &self.names[&NameKey::Function(function)];
3994 write!(self.out, "{fun_name}(")?;
3995 for (i, &handle) in arguments.iter().enumerate() {
3997 if i != 0 {
3998 write!(self.out, ", ")?;
3999 }
4000 self.put_expression(handle, &context.expression, true)?;
4001 }
4002 let mut separate = !arguments.is_empty();
4004 let fun_info = &context.expression.mod_info[function];
4005 let mut needs_buffer_sizes = false;
4006 for (handle, var) in context.expression.module.global_variables.iter() {
4007 if fun_info[handle].is_empty() {
4008 continue;
4009 }
4010 if var.space.needs_pass_through() {
4011 let name = &self.names[&NameKey::GlobalVariable(handle)];
4012 if separate {
4013 write!(self.out, ", ")?;
4014 } else {
4015 separate = true;
4016 }
4017 write!(self.out, "{name}")?;
4018 }
4019 needs_buffer_sizes |=
4020 needs_array_length(var.ty, &context.expression.module.types);
4021 }
4022 if needs_buffer_sizes {
4023 if separate {
4024 write!(self.out, ", ")?;
4025 }
4026 write!(self.out, "_buffer_sizes")?;
4027 }
4028
4029 writeln!(self.out, ");")?;
4031 }
4032 crate::Statement::Atomic {
4033 pointer,
4034 ref fun,
4035 value,
4036 result,
4037 } => {
4038 let context = &context.expression;
4039
4040 write!(self.out, "{level}")?;
4045 let fun_key = if let Some(result) = result {
4046 let res_name = Baked(result).to_string();
4047 self.start_baking_expression(result, context, &res_name)?;
4048 self.named_expressions.insert(result, res_name);
4049 fun.to_msl()
4050 } else if context.resolve_type(value).scalar_width() == Some(8) {
4051 fun.to_msl_64_bit()?
4052 } else {
4053 fun.to_msl()
4054 };
4055
4056 let policy = context.choose_bounds_check_policy(pointer);
4060 let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4061 && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
4062
4063 if checked {
4065 write!(self.out, " ? ")?;
4066 }
4067
4068 match *fun {
4070 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4071 write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
4072 self.put_access_chain(pointer, policy, context)?;
4073 write!(self.out, ", ")?;
4074 self.put_expression(cmp, context, true)?;
4075 write!(self.out, ", ")?;
4076 self.put_expression(value, context, true)?;
4077 write!(self.out, ")")?;
4078 }
4079 _ => {
4080 write!(
4081 self.out,
4082 "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
4083 )?;
4084 self.put_access_chain(pointer, policy, context)?;
4085 write!(self.out, ", ")?;
4086 self.put_expression(value, context, true)?;
4087 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
4088 }
4089 }
4090
4091 if checked {
4093 write!(self.out, " : DefaultConstructible()")?;
4094 }
4095
4096 writeln!(self.out, ";")?;
4098 }
4099 crate::Statement::ImageAtomic {
4100 image,
4101 coordinate,
4102 array_index,
4103 fun,
4104 value,
4105 } => {
4106 let address = TexelAddress {
4107 coordinate,
4108 array_index,
4109 sample: None,
4110 level: None,
4111 };
4112 self.put_image_atomic(level, image, &address, fun, value, context)?
4113 }
4114 crate::Statement::WorkGroupUniformLoad { pointer, result } => {
4115 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4116
4117 write!(self.out, "{level}")?;
4118 let name = self.namer.call("");
4119 self.start_baking_expression(result, &context.expression, &name)?;
4120 self.put_load(pointer, &context.expression, true)?;
4121 self.named_expressions.insert(result, name);
4122
4123 writeln!(self.out, ";")?;
4124 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
4125 }
4126 crate::Statement::RayQuery { query, ref fun } => {
4127 self.write_ray_query_stmt(level, context, query, fun)?;
4128 }
4129 crate::Statement::SubgroupBallot { result, predicate } => {
4130 write!(self.out, "{level}")?;
4131 let name = self.namer.call("");
4132 self.start_baking_expression(result, &context.expression, &name)?;
4133 self.named_expressions.insert(result, name);
4134 write!(
4135 self.out,
4136 "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4137 )?;
4138 if let Some(predicate) = predicate {
4139 self.put_expression(predicate, &context.expression, true)?;
4140 } else {
4141 write!(self.out, "true")?;
4142 }
4143 writeln!(self.out, "), 0, 0, 0);")?;
4144 }
4145 crate::Statement::SubgroupCollectiveOperation {
4146 op,
4147 collective_op,
4148 argument,
4149 result,
4150 } => {
4151 write!(self.out, "{level}")?;
4152 let name = self.namer.call("");
4153 self.start_baking_expression(result, &context.expression, &name)?;
4154 self.named_expressions.insert(result, name);
4155 match (collective_op, op) {
4156 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4157 write!(self.out, "{NAMESPACE}::simd_all(")?
4158 }
4159 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4160 write!(self.out, "{NAMESPACE}::simd_any(")?
4161 }
4162 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4163 write!(self.out, "{NAMESPACE}::simd_sum(")?
4164 }
4165 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4166 write!(self.out, "{NAMESPACE}::simd_product(")?
4167 }
4168 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4169 write!(self.out, "{NAMESPACE}::simd_max(")?
4170 }
4171 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4172 write!(self.out, "{NAMESPACE}::simd_min(")?
4173 }
4174 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4175 write!(self.out, "{NAMESPACE}::simd_and(")?
4176 }
4177 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4178 write!(self.out, "{NAMESPACE}::simd_or(")?
4179 }
4180 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4181 write!(self.out, "{NAMESPACE}::simd_xor(")?
4182 }
4183 (
4184 crate::CollectiveOperation::ExclusiveScan,
4185 crate::SubgroupOperation::Add,
4186 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4187 (
4188 crate::CollectiveOperation::ExclusiveScan,
4189 crate::SubgroupOperation::Mul,
4190 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4191 (
4192 crate::CollectiveOperation::InclusiveScan,
4193 crate::SubgroupOperation::Add,
4194 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4195 (
4196 crate::CollectiveOperation::InclusiveScan,
4197 crate::SubgroupOperation::Mul,
4198 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4199 _ => unimplemented!(),
4200 }
4201 self.put_expression(argument, &context.expression, true)?;
4202 writeln!(self.out, ");")?;
4203 }
4204 crate::Statement::SubgroupGather {
4205 mode,
4206 argument,
4207 result,
4208 } => {
4209 write!(self.out, "{level}")?;
4210 let name = self.namer.call("");
4211 self.start_baking_expression(result, &context.expression, &name)?;
4212 self.named_expressions.insert(result, name);
4213 match mode {
4214 crate::GatherMode::BroadcastFirst => {
4215 write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4216 }
4217 crate::GatherMode::Broadcast(_) => {
4218 write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4219 }
4220 crate::GatherMode::Shuffle(_) => {
4221 write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4222 }
4223 crate::GatherMode::ShuffleDown(_) => {
4224 write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4225 }
4226 crate::GatherMode::ShuffleUp(_) => {
4227 write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4228 }
4229 crate::GatherMode::ShuffleXor(_) => {
4230 write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4231 }
4232 crate::GatherMode::QuadBroadcast(_) => {
4233 write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4234 }
4235 crate::GatherMode::QuadSwap(_) => {
4236 write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4237 }
4238 }
4239 self.put_expression(argument, &context.expression, true)?;
4240 match mode {
4241 crate::GatherMode::BroadcastFirst => {}
4242 crate::GatherMode::Broadcast(index)
4243 | crate::GatherMode::Shuffle(index)
4244 | crate::GatherMode::ShuffleDown(index)
4245 | crate::GatherMode::ShuffleUp(index)
4246 | crate::GatherMode::ShuffleXor(index)
4247 | crate::GatherMode::QuadBroadcast(index) => {
4248 write!(self.out, ", ")?;
4249 self.put_expression(index, &context.expression, true)?;
4250 }
4251 crate::GatherMode::QuadSwap(direction) => {
4252 write!(self.out, ", ")?;
4253 match direction {
4254 crate::Direction::X => {
4255 write!(self.out, "1u")?;
4256 }
4257 crate::Direction::Y => {
4258 write!(self.out, "2u")?;
4259 }
4260 crate::Direction::Diagonal => {
4261 write!(self.out, "3u")?;
4262 }
4263 }
4264 }
4265 }
4266 writeln!(self.out, ");")?;
4267 }
4268 crate::Statement::CooperativeStore { target, ref data } => {
4269 write!(self.out, "{level}simdgroup_store(")?;
4270 self.put_expression(target, &context.expression, true)?;
4271 write!(self.out, ", &")?;
4272 self.put_access_chain(
4273 data.pointer,
4274 context.expression.policies.index,
4275 &context.expression,
4276 )?;
4277 write!(self.out, ", ")?;
4278 self.put_expression(data.stride, &context.expression, true)?;
4279 if !data.row_major {
4284 let matrix_origin = "0";
4285 let transpose = true;
4286 write!(self.out, ", {matrix_origin}, {transpose}")?;
4287 }
4288 writeln!(self.out, ");")?;
4289 }
4290 crate::Statement::RayPipelineFunction(_) => unreachable!(),
4291 }
4292 }
4293
4294 for statement in statements {
4297 if let crate::Statement::Emit(ref range) = *statement {
4298 for handle in range.clone() {
4299 self.named_expressions.shift_remove(&handle);
4300 }
4301 }
4302 }
4303 Ok(())
4304 }
4305
4306 fn put_store(
4307 &mut self,
4308 pointer: Handle<crate::Expression>,
4309 value: Handle<crate::Expression>,
4310 level: back::Level,
4311 context: &StatementContext,
4312 ) -> BackendResult {
4313 let policy = context.expression.choose_bounds_check_policy(pointer);
4314 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4315 && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4316 {
4317 writeln!(self.out, ") {{")?;
4318 self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4319 writeln!(self.out, "{level}}}")?;
4320 } else {
4321 self.put_unchecked_store(pointer, value, policy, level, context)?;
4322 }
4323
4324 Ok(())
4325 }
4326
4327 fn put_unchecked_store(
4328 &mut self,
4329 pointer: Handle<crate::Expression>,
4330 value: Handle<crate::Expression>,
4331 policy: index::BoundsCheckPolicy,
4332 level: back::Level,
4333 context: &StatementContext,
4334 ) -> BackendResult {
4335 let is_atomic_pointer = context
4336 .expression
4337 .resolve_type(pointer)
4338 .is_atomic_pointer(&context.expression.module.types);
4339
4340 if is_atomic_pointer {
4341 write!(
4342 self.out,
4343 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4344 )?;
4345 self.put_access_chain(pointer, policy, &context.expression)?;
4346 write!(self.out, ", ")?;
4347 self.put_expression(value, &context.expression, true)?;
4348 writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4349 } else {
4350 write!(self.out, "{level}")?;
4351 self.put_access_chain(pointer, policy, &context.expression)?;
4352 write!(self.out, " = ")?;
4353 self.put_expression(value, &context.expression, true)?;
4354 writeln!(self.out, ";")?;
4355 }
4356
4357 Ok(())
4358 }
4359
4360 pub fn write(
4361 &mut self,
4362 module: &crate::Module,
4363 info: &valid::ModuleInfo,
4364 options: &Options,
4365 pipeline_options: &PipelineOptions,
4366 ) -> Result<TranslationInfo, Error> {
4367 self.names.clear();
4368 self.namer.reset(
4369 module,
4370 &super::keywords::RESERVED_SET,
4371 proc::KeywordSet::empty(),
4372 proc::CaseInsensitiveKeywordSet::empty(),
4373 &[
4374 CLAMPED_LOD_LOAD_PREFIX,
4375 super::ray::INTERSECTION_FUNCTION_NAME,
4376 ],
4377 &mut self.names,
4378 );
4379 self.wrapped_functions.clear();
4380 self.struct_member_pads.clear();
4381
4382 writeln!(
4383 self.out,
4384 "// language: metal{}.{}",
4385 options.lang_version.0, options.lang_version.1
4386 )?;
4387 writeln!(self.out, "#include <metal_stdlib>")?;
4388 writeln!(self.out, "#include <simd/simd.h>")?;
4389 writeln!(self.out)?;
4390 writeln!(self.out, "using {NAMESPACE}::uint;")?;
4392
4393 if module.uses_mesh_shaders() && options.lang_version < (3, 0) {
4394 return Err(Error::UnsupportedMeshShader);
4395 }
4396 self.needs_object_memory_barriers = module
4397 .entry_points
4398 .iter()
4399 .any(|e| e.stage == crate::ShaderStage::Task && e.task_payload.is_some());
4400
4401 if module.special_types.ray_desc.is_some()
4402 || module.special_types.ray_intersection.is_some()
4403 {
4404 if options.lang_version < (2, 4) {
4405 return Err(Error::UnsupportedRayTracing);
4406 }
4407 }
4408
4409 if options
4410 .bounds_check_policies
4411 .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4412 {
4413 self.put_default_constructible()?;
4414 }
4415 writeln!(self.out)?;
4416
4417 {
4418 let globals: Vec<Handle<crate::GlobalVariable>> = module
4421 .global_variables
4422 .iter()
4423 .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4424 .map(|(handle, _)| handle)
4425 .collect();
4426
4427 let mut buffer_indices = vec![];
4428 for vbm in &pipeline_options.vertex_buffer_mappings {
4429 buffer_indices.push(vbm.id);
4430 }
4431
4432 if !globals.is_empty() || !buffer_indices.is_empty() {
4433 writeln!(self.out, "struct _mslBufferSizes {{")?;
4434
4435 for global in globals {
4436 writeln!(
4437 self.out,
4438 "{}uint {};",
4439 back::INDENT,
4440 ArraySizeMember(global)
4441 )?;
4442 }
4443
4444 for idx in buffer_indices {
4445 writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4446 }
4447
4448 writeln!(self.out, "}};")?;
4449 writeln!(self.out)?;
4450 }
4451 };
4452
4453 self.write_type_defs(module)?;
4454 self.write_global_constants(module, info)?;
4455 self.write_functions(module, info, options, pipeline_options)
4456 }
4457
4458 fn put_default_constructible(&mut self) -> BackendResult {
4471 let tab = back::INDENT;
4472 writeln!(self.out, "struct DefaultConstructible {{")?;
4473 writeln!(self.out, "{tab}template<typename T>")?;
4474 writeln!(self.out, "{tab}operator T() && {{")?;
4475 writeln!(self.out, "{tab}{tab}return T {{}};")?;
4476 writeln!(self.out, "{tab}}}")?;
4477 writeln!(self.out, "}};")?;
4478 Ok(())
4479 }
4480
4481 fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4482 let mut generated_argument_buffer_wrapper = false;
4483 let mut generated_external_texture_wrapper = false;
4484 for (handle, ty) in module.types.iter() {
4485 match ty.inner {
4486 crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4487 writeln!(self.out, "template <typename T>")?;
4488 writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4489 writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4490 writeln!(self.out, "}};")?;
4491 generated_argument_buffer_wrapper = true;
4492 }
4493 crate::TypeInner::Image {
4494 class: crate::ImageClass::External,
4495 ..
4496 } if !generated_external_texture_wrapper => {
4497 let params_ty_name = &self.names
4498 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4499 writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4500 writeln!(
4501 self.out,
4502 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4503 back::INDENT
4504 )?;
4505 writeln!(
4506 self.out,
4507 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4508 back::INDENT
4509 )?;
4510 writeln!(
4511 self.out,
4512 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4513 back::INDENT
4514 )?;
4515 writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4516 writeln!(self.out, "}};")?;
4517 generated_external_texture_wrapper = true;
4518 }
4519 _ => {}
4520 }
4521
4522 if !ty.needs_alias() {
4523 continue;
4524 }
4525 let name = &self.names[&NameKey::Type(handle)];
4526 match ty.inner {
4527 crate::TypeInner::Array {
4541 base,
4542 size,
4543 stride: _,
4544 } => {
4545 let base_name = TypeContext {
4546 handle: base,
4547 gctx: module.to_ctx(),
4548 names: &self.names,
4549 access: crate::StorageAccess::empty(),
4550 first_time: false,
4551 };
4552
4553 match size.resolve(module.to_ctx())? {
4554 proc::IndexableLength::Known(size) => {
4555 writeln!(self.out, "struct {name} {{")?;
4556 writeln!(
4557 self.out,
4558 "{}{} {}[{}];",
4559 back::INDENT,
4560 base_name,
4561 WRAPPED_ARRAY_FIELD,
4562 size
4563 )?;
4564 writeln!(self.out, "}};")?;
4565 }
4566 proc::IndexableLength::Dynamic => {
4567 writeln!(self.out, "typedef {base_name} {name}[1];")?;
4568 }
4569 }
4570 }
4571 crate::TypeInner::Struct {
4572 ref members, span, ..
4573 } => {
4574 writeln!(self.out, "struct {name} {{")?;
4575 let mut last_offset = 0;
4576 for (index, member) in members.iter().enumerate() {
4577 if member.offset > last_offset {
4578 self.struct_member_pads.insert((handle, index as u32));
4579 let pad = member.offset - last_offset;
4580 writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4581 }
4582 let ty_inner = &module.types[member.ty].inner;
4583 last_offset = member.offset + ty_inner.size(module.to_ctx());
4584
4585 let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4586
4587 match should_pack_struct_member(members, span, index, module) {
4589 Some(scalar) => {
4590 writeln!(
4591 self.out,
4592 "{}{}::packed_{}3 {};",
4593 back::INDENT,
4594 NAMESPACE,
4595 scalar.to_msl_name(),
4596 member_name
4597 )?;
4598 }
4599 None => {
4600 let base_name = TypeContext {
4601 handle: member.ty,
4602 gctx: module.to_ctx(),
4603 names: &self.names,
4604 access: crate::StorageAccess::empty(),
4605 first_time: false,
4606 };
4607 writeln!(
4608 self.out,
4609 "{}{} {};",
4610 back::INDENT,
4611 base_name,
4612 member_name
4613 )?;
4614
4615 if let crate::TypeInner::Vector {
4617 size: crate::VectorSize::Tri,
4618 scalar,
4619 } = *ty_inner
4620 {
4621 last_offset += scalar.width as u32;
4622 }
4623 }
4624 }
4625 }
4626 if last_offset < span {
4627 let pad = span - last_offset;
4628 writeln!(
4629 self.out,
4630 "{}char _pad{}[{}];",
4631 back::INDENT,
4632 members.len(),
4633 pad
4634 )?;
4635 }
4636 writeln!(self.out, "}};")?;
4637 }
4638 _ => {
4639 let ty_name = TypeContext {
4640 handle,
4641 gctx: module.to_ctx(),
4642 names: &self.names,
4643 access: crate::StorageAccess::empty(),
4644 first_time: true,
4645 };
4646 writeln!(self.out, "typedef {ty_name} {name};")?;
4647 }
4648 }
4649 }
4650
4651 for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4653 match type_key {
4654 &crate::PredeclaredType::ModfResult { size, scalar }
4655 | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4656 let arg_type_name_owner;
4657 let arg_type_name = if let Some(size) = size {
4658 arg_type_name_owner = format!(
4659 "{NAMESPACE}::{}{}",
4660 if scalar.width == 8 { "double" } else { "float" },
4661 size as u8
4662 );
4663 &arg_type_name_owner
4664 } else if scalar.width == 8 {
4665 "double"
4666 } else {
4667 "float"
4668 };
4669
4670 let other_type_name_owner;
4671 let (defined_func_name, called_func_name, other_type_name) =
4672 if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4673 (MODF_FUNCTION, "modf", arg_type_name)
4674 } else {
4675 let other_type_name = if let Some(size) = size {
4676 other_type_name_owner = format!("int{}", size as u8);
4677 &other_type_name_owner
4678 } else {
4679 "int"
4680 };
4681 (FREXP_FUNCTION, "frexp", other_type_name)
4682 };
4683
4684 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4685
4686 writeln!(self.out)?;
4687 writeln!(
4688 self.out,
4689 "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4690 {other_type_name} other;
4691 {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4692 return {struct_name}{{ fract, other }};
4693}}"
4694 )?;
4695 }
4696 &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4697 let arg_type_name = scalar.to_msl_name();
4698 let called_func_name = "atomic_compare_exchange_weak_explicit";
4699 let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4700 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4701
4702 writeln!(self.out)?;
4703
4704 for address_space_name in ["device", "threadgroup"] {
4705 writeln!(
4706 self.out,
4707 "\
4708template <typename A>
4709{struct_name} {defined_func_name}(
4710 {address_space_name} A *atomic_ptr,
4711 {arg_type_name} cmp,
4712 {arg_type_name} v
4713) {{
4714 bool swapped = {NAMESPACE}::{called_func_name}(
4715 atomic_ptr, &cmp, v,
4716 metal::memory_order_relaxed, metal::memory_order_relaxed
4717 );
4718 return {struct_name}{{cmp, swapped}};
4719}}"
4720 )?;
4721 }
4722 }
4723 }
4724 }
4725
4726 Ok(())
4727 }
4728
4729 fn write_global_constants(
4731 &mut self,
4732 module: &crate::Module,
4733 mod_info: &valid::ModuleInfo,
4734 ) -> BackendResult {
4735 let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4736
4737 for (handle, constant) in constants {
4738 let ty_name = TypeContext {
4739 handle: constant.ty,
4740 gctx: module.to_ctx(),
4741 names: &self.names,
4742 access: crate::StorageAccess::empty(),
4743 first_time: false,
4744 };
4745 let name = &self.names[&NameKey::Constant(handle)];
4746 write!(self.out, "constant {ty_name} {name} = ")?;
4747 self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4748 writeln!(self.out, ";")?;
4749 }
4750
4751 Ok(())
4752 }
4753
4754 fn put_inline_sampler_properties(
4755 &mut self,
4756 level: back::Level,
4757 sampler: &sm::InlineSampler,
4758 ) -> BackendResult {
4759 for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4760 writeln!(
4761 self.out,
4762 "{}{}::{}_address::{},",
4763 level,
4764 NAMESPACE,
4765 letter,
4766 address.as_str(),
4767 )?;
4768 }
4769 writeln!(
4770 self.out,
4771 "{}{}::mag_filter::{},",
4772 level,
4773 NAMESPACE,
4774 sampler.mag_filter.as_str(),
4775 )?;
4776 writeln!(
4777 self.out,
4778 "{}{}::min_filter::{},",
4779 level,
4780 NAMESPACE,
4781 sampler.min_filter.as_str(),
4782 )?;
4783 if let Some(filter) = sampler.mip_filter {
4784 writeln!(
4785 self.out,
4786 "{}{}::mip_filter::{},",
4787 level,
4788 NAMESPACE,
4789 filter.as_str(),
4790 )?;
4791 }
4792 if sampler.border_color != sm::BorderColor::TransparentBlack {
4794 writeln!(
4795 self.out,
4796 "{}{}::border_color::{},",
4797 level,
4798 NAMESPACE,
4799 sampler.border_color.as_str(),
4800 )?;
4801 }
4802 if false {
4806 if let Some(ref lod) = sampler.lod_clamp {
4807 writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4808 }
4809 if let Some(aniso) = sampler.max_anisotropy {
4810 writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4811 }
4812 }
4813 if sampler.compare_func != sm::CompareFunc::Never {
4814 writeln!(
4815 self.out,
4816 "{}{}::compare_func::{},",
4817 level,
4818 NAMESPACE,
4819 sampler.compare_func.as_str(),
4820 )?;
4821 }
4822 writeln!(
4823 self.out,
4824 "{}{}::coord::{}",
4825 level,
4826 NAMESPACE,
4827 sampler.coord.as_str()
4828 )?;
4829 Ok(())
4830 }
4831
4832 fn write_unpacking_function(
4833 &mut self,
4834 format: back::msl::VertexFormat,
4835 ) -> Result<(String, u32, Option<crate::VectorSize>, crate::Scalar), Error> {
4836 use crate::{Scalar, VectorSize};
4837 use back::msl::VertexFormat::*;
4838 match format {
4839 Uint8 => {
4840 let name = self.namer.call("unpackUint8");
4841 writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4842 writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4843 writeln!(self.out, "}}")?;
4844 Ok((name, 1, None, Scalar::U32))
4845 }
4846 Uint8x2 => {
4847 let name = self.namer.call("unpackUint8x2");
4848 writeln!(
4849 self.out,
4850 "metal::uint2 {name}(metal::uchar b0, \
4851 metal::uchar b1) {{"
4852 )?;
4853 writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4854 writeln!(self.out, "}}")?;
4855 Ok((name, 2, Some(VectorSize::Bi), Scalar::U32))
4856 }
4857 Uint8x4 => {
4858 let name = self.namer.call("unpackUint8x4");
4859 writeln!(
4860 self.out,
4861 "metal::uint4 {name}(metal::uchar b0, \
4862 metal::uchar b1, \
4863 metal::uchar b2, \
4864 metal::uchar b3) {{"
4865 )?;
4866 writeln!(
4867 self.out,
4868 "{}return metal::uint4(b0, b1, b2, b3);",
4869 back::INDENT
4870 )?;
4871 writeln!(self.out, "}}")?;
4872 Ok((name, 4, Some(VectorSize::Quad), Scalar::U32))
4873 }
4874 Sint8 => {
4875 let name = self.namer.call("unpackSint8");
4876 writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4877 writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4878 writeln!(self.out, "}}")?;
4879 Ok((name, 1, None, Scalar::I32))
4880 }
4881 Sint8x2 => {
4882 let name = self.namer.call("unpackSint8x2");
4883 writeln!(
4884 self.out,
4885 "metal::int2 {name}(metal::uchar b0, \
4886 metal::uchar b1) {{"
4887 )?;
4888 writeln!(
4889 self.out,
4890 "{}return metal::int2(as_type<char>(b0), \
4891 as_type<char>(b1));",
4892 back::INDENT
4893 )?;
4894 writeln!(self.out, "}}")?;
4895 Ok((name, 2, Some(VectorSize::Bi), Scalar::I32))
4896 }
4897 Sint8x4 => {
4898 let name = self.namer.call("unpackSint8x4");
4899 writeln!(
4900 self.out,
4901 "metal::int4 {name}(metal::uchar b0, \
4902 metal::uchar b1, \
4903 metal::uchar b2, \
4904 metal::uchar b3) {{"
4905 )?;
4906 writeln!(
4907 self.out,
4908 "{}return metal::int4(as_type<char>(b0), \
4909 as_type<char>(b1), \
4910 as_type<char>(b2), \
4911 as_type<char>(b3));",
4912 back::INDENT
4913 )?;
4914 writeln!(self.out, "}}")?;
4915 Ok((name, 4, Some(VectorSize::Quad), Scalar::I32))
4916 }
4917 Unorm8 => {
4918 let name = self.namer.call("unpackUnorm8");
4919 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4920 writeln!(
4921 self.out,
4922 "{}return float(float(b0) / 255.0f);",
4923 back::INDENT
4924 )?;
4925 writeln!(self.out, "}}")?;
4926 Ok((name, 1, None, Scalar::F32))
4927 }
4928 Unorm8x2 => {
4929 let name = self.namer.call("unpackUnorm8x2");
4930 writeln!(
4931 self.out,
4932 "metal::float2 {name}(metal::uchar b0, \
4933 metal::uchar b1) {{"
4934 )?;
4935 writeln!(
4936 self.out,
4937 "{}return metal::float2(float(b0) / 255.0f, \
4938 float(b1) / 255.0f);",
4939 back::INDENT
4940 )?;
4941 writeln!(self.out, "}}")?;
4942 Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
4943 }
4944 Unorm8x4 => {
4945 let name = self.namer.call("unpackUnorm8x4");
4946 writeln!(
4947 self.out,
4948 "metal::float4 {name}(metal::uchar b0, \
4949 metal::uchar b1, \
4950 metal::uchar b2, \
4951 metal::uchar b3) {{"
4952 )?;
4953 writeln!(
4954 self.out,
4955 "{}return metal::float4(float(b0) / 255.0f, \
4956 float(b1) / 255.0f, \
4957 float(b2) / 255.0f, \
4958 float(b3) / 255.0f);",
4959 back::INDENT
4960 )?;
4961 writeln!(self.out, "}}")?;
4962 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
4963 }
4964 Snorm8 => {
4965 let name = self.namer.call("unpackSnorm8");
4966 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4967 writeln!(
4968 self.out,
4969 "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4970 back::INDENT
4971 )?;
4972 writeln!(self.out, "}}")?;
4973 Ok((name, 1, None, Scalar::F32))
4974 }
4975 Snorm8x2 => {
4976 let name = self.namer.call("unpackSnorm8x2");
4977 writeln!(
4978 self.out,
4979 "metal::float2 {name}(metal::uchar b0, \
4980 metal::uchar b1) {{"
4981 )?;
4982 writeln!(
4983 self.out,
4984 "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4985 metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
4986 back::INDENT
4987 )?;
4988 writeln!(self.out, "}}")?;
4989 Ok((name, 2, Some(VectorSize::Bi), Scalar::F32))
4990 }
4991 Snorm8x4 => {
4992 let name = self.namer.call("unpackSnorm8x4");
4993 writeln!(
4994 self.out,
4995 "metal::float4 {name}(metal::uchar b0, \
4996 metal::uchar b1, \
4997 metal::uchar b2, \
4998 metal::uchar b3) {{"
4999 )?;
5000 writeln!(
5001 self.out,
5002 "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
5003 metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
5004 metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
5005 metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
5006 back::INDENT
5007 )?;
5008 writeln!(self.out, "}}")?;
5009 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5010 }
5011 Uint16 => {
5012 let name = self.namer.call("unpackUint16");
5013 writeln!(
5014 self.out,
5015 "metal::uint {name}(metal::uint b0, \
5016 metal::uint b1) {{"
5017 )?;
5018 writeln!(
5019 self.out,
5020 "{}return metal::uint(b1 << 8 | b0);",
5021 back::INDENT
5022 )?;
5023 writeln!(self.out, "}}")?;
5024 Ok((name, 2, None, Scalar::U32))
5025 }
5026 Uint16x2 => {
5027 let name = self.namer.call("unpackUint16x2");
5028 writeln!(
5029 self.out,
5030 "metal::uint2 {name}(metal::uint b0, \
5031 metal::uint b1, \
5032 metal::uint b2, \
5033 metal::uint b3) {{"
5034 )?;
5035 writeln!(
5036 self.out,
5037 "{}return metal::uint2(b1 << 8 | b0, \
5038 b3 << 8 | b2);",
5039 back::INDENT
5040 )?;
5041 writeln!(self.out, "}}")?;
5042 Ok((name, 4, Some(VectorSize::Bi), Scalar::U32))
5043 }
5044 Uint16x4 => {
5045 let name = self.namer.call("unpackUint16x4");
5046 writeln!(
5047 self.out,
5048 "metal::uint4 {name}(metal::uint b0, \
5049 metal::uint b1, \
5050 metal::uint b2, \
5051 metal::uint b3, \
5052 metal::uint b4, \
5053 metal::uint b5, \
5054 metal::uint b6, \
5055 metal::uint b7) {{"
5056 )?;
5057 writeln!(
5058 self.out,
5059 "{}return metal::uint4(b1 << 8 | b0, \
5060 b3 << 8 | b2, \
5061 b5 << 8 | b4, \
5062 b7 << 8 | b6);",
5063 back::INDENT
5064 )?;
5065 writeln!(self.out, "}}")?;
5066 Ok((name, 8, Some(VectorSize::Quad), Scalar::U32))
5067 }
5068 Sint16 => {
5069 let name = self.namer.call("unpackSint16");
5070 writeln!(
5071 self.out,
5072 "int {name}(metal::ushort b0, \
5073 metal::ushort b1) {{"
5074 )?;
5075 writeln!(
5076 self.out,
5077 "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5078 back::INDENT
5079 )?;
5080 writeln!(self.out, "}}")?;
5081 Ok((name, 2, None, Scalar::I32))
5082 }
5083 Sint16x2 => {
5084 let name = self.namer.call("unpackSint16x2");
5085 writeln!(
5086 self.out,
5087 "metal::int2 {name}(metal::ushort b0, \
5088 metal::ushort b1, \
5089 metal::ushort b2, \
5090 metal::ushort b3) {{"
5091 )?;
5092 writeln!(
5093 self.out,
5094 "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5095 as_type<short>(metal::ushort(b3 << 8 | b2)));",
5096 back::INDENT
5097 )?;
5098 writeln!(self.out, "}}")?;
5099 Ok((name, 4, Some(VectorSize::Bi), Scalar::I32))
5100 }
5101 Sint16x4 => {
5102 let name = self.namer.call("unpackSint16x4");
5103 writeln!(
5104 self.out,
5105 "metal::int4 {name}(metal::ushort b0, \
5106 metal::ushort b1, \
5107 metal::ushort b2, \
5108 metal::ushort b3, \
5109 metal::ushort b4, \
5110 metal::ushort b5, \
5111 metal::ushort b6, \
5112 metal::ushort b7) {{"
5113 )?;
5114 writeln!(
5115 self.out,
5116 "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5117 as_type<short>(metal::ushort(b3 << 8 | b2)), \
5118 as_type<short>(metal::ushort(b5 << 8 | b4)), \
5119 as_type<short>(metal::ushort(b7 << 8 | b6)));",
5120 back::INDENT
5121 )?;
5122 writeln!(self.out, "}}")?;
5123 Ok((name, 8, Some(VectorSize::Quad), Scalar::I32))
5124 }
5125 Unorm16 => {
5126 let name = self.namer.call("unpackUnorm16");
5127 writeln!(
5128 self.out,
5129 "float {name}(metal::ushort b0, \
5130 metal::ushort b1) {{"
5131 )?;
5132 writeln!(
5133 self.out,
5134 "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5135 back::INDENT
5136 )?;
5137 writeln!(self.out, "}}")?;
5138 Ok((name, 2, None, Scalar::F32))
5139 }
5140 Unorm16x2 => {
5141 let name = self.namer.call("unpackUnorm16x2");
5142 writeln!(
5143 self.out,
5144 "metal::float2 {name}(metal::ushort b0, \
5145 metal::ushort b1, \
5146 metal::ushort b2, \
5147 metal::ushort b3) {{"
5148 )?;
5149 writeln!(
5150 self.out,
5151 "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5152 float(b3 << 8 | b2) / 65535.0f);",
5153 back::INDENT
5154 )?;
5155 writeln!(self.out, "}}")?;
5156 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5157 }
5158 Unorm16x4 => {
5159 let name = self.namer.call("unpackUnorm16x4");
5160 writeln!(
5161 self.out,
5162 "metal::float4 {name}(metal::ushort b0, \
5163 metal::ushort b1, \
5164 metal::ushort b2, \
5165 metal::ushort b3, \
5166 metal::ushort b4, \
5167 metal::ushort b5, \
5168 metal::ushort b6, \
5169 metal::ushort b7) {{"
5170 )?;
5171 writeln!(
5172 self.out,
5173 "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5174 float(b3 << 8 | b2) / 65535.0f, \
5175 float(b5 << 8 | b4) / 65535.0f, \
5176 float(b7 << 8 | b6) / 65535.0f);",
5177 back::INDENT
5178 )?;
5179 writeln!(self.out, "}}")?;
5180 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5181 }
5182 Snorm16 => {
5183 let name = self.namer.call("unpackSnorm16");
5184 writeln!(
5185 self.out,
5186 "float {name}(metal::ushort b0, \
5187 metal::ushort b1) {{"
5188 )?;
5189 writeln!(
5190 self.out,
5191 "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5192 back::INDENT
5193 )?;
5194 writeln!(self.out, "}}")?;
5195 Ok((name, 2, None, Scalar::F32))
5196 }
5197 Snorm16x2 => {
5198 let name = self.namer.call("unpackSnorm16x2");
5199 writeln!(
5200 self.out,
5201 "metal::float2 {name}(uint b0, \
5202 uint b1, \
5203 uint b2, \
5204 uint b3) {{"
5205 )?;
5206 writeln!(
5207 self.out,
5208 "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5209 back::INDENT
5210 )?;
5211 writeln!(self.out, "}}")?;
5212 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5213 }
5214 Snorm16x4 => {
5215 let name = self.namer.call("unpackSnorm16x4");
5216 writeln!(
5217 self.out,
5218 "metal::float4 {name}(uint b0, \
5219 uint b1, \
5220 uint b2, \
5221 uint b3, \
5222 uint b4, \
5223 uint b5, \
5224 uint b6, \
5225 uint b7) {{"
5226 )?;
5227 writeln!(
5228 self.out,
5229 "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5230 metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5231 back::INDENT
5232 )?;
5233 writeln!(self.out, "}}")?;
5234 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5235 }
5236 Float16 => {
5237 let name = self.namer.call("unpackFloat16");
5238 writeln!(
5239 self.out,
5240 "float {name}(metal::ushort b0, \
5241 metal::ushort b1) {{"
5242 )?;
5243 writeln!(
5244 self.out,
5245 "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5246 back::INDENT
5247 )?;
5248 writeln!(self.out, "}}")?;
5249 Ok((name, 2, None, Scalar::F32))
5250 }
5251 Float16x2 => {
5252 let name = self.namer.call("unpackFloat16x2");
5253 writeln!(
5254 self.out,
5255 "metal::float2 {name}(metal::ushort b0, \
5256 metal::ushort b1, \
5257 metal::ushort b2, \
5258 metal::ushort b3) {{"
5259 )?;
5260 writeln!(
5261 self.out,
5262 "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5263 as_type<half>(metal::ushort(b3 << 8 | b2)));",
5264 back::INDENT
5265 )?;
5266 writeln!(self.out, "}}")?;
5267 Ok((name, 4, Some(VectorSize::Bi), Scalar::F32))
5268 }
5269 Float16x4 => {
5270 let name = self.namer.call("unpackFloat16x4");
5271 writeln!(
5272 self.out,
5273 "metal::float4 {name}(metal::ushort b0, \
5274 metal::ushort b1, \
5275 metal::ushort b2, \
5276 metal::ushort b3, \
5277 metal::ushort b4, \
5278 metal::ushort b5, \
5279 metal::ushort b6, \
5280 metal::ushort b7) {{"
5281 )?;
5282 writeln!(
5283 self.out,
5284 "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5285 as_type<half>(metal::ushort(b3 << 8 | b2)), \
5286 as_type<half>(metal::ushort(b5 << 8 | b4)), \
5287 as_type<half>(metal::ushort(b7 << 8 | b6)));",
5288 back::INDENT
5289 )?;
5290 writeln!(self.out, "}}")?;
5291 Ok((name, 8, Some(VectorSize::Quad), Scalar::F32))
5292 }
5293 Float32 => {
5294 let name = self.namer.call("unpackFloat32");
5295 writeln!(
5296 self.out,
5297 "float {name}(uint b0, \
5298 uint b1, \
5299 uint b2, \
5300 uint b3) {{"
5301 )?;
5302 writeln!(
5303 self.out,
5304 "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5305 back::INDENT
5306 )?;
5307 writeln!(self.out, "}}")?;
5308 Ok((name, 4, None, Scalar::F32))
5309 }
5310 Float32x2 => {
5311 let name = self.namer.call("unpackFloat32x2");
5312 writeln!(
5313 self.out,
5314 "metal::float2 {name}(uint b0, \
5315 uint b1, \
5316 uint b2, \
5317 uint b3, \
5318 uint b4, \
5319 uint b5, \
5320 uint b6, \
5321 uint b7) {{"
5322 )?;
5323 writeln!(
5324 self.out,
5325 "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5326 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5327 back::INDENT
5328 )?;
5329 writeln!(self.out, "}}")?;
5330 Ok((name, 8, Some(VectorSize::Bi), Scalar::F32))
5331 }
5332 Float32x3 => {
5333 let name = self.namer.call("unpackFloat32x3");
5334 writeln!(
5335 self.out,
5336 "metal::float3 {name}(uint b0, \
5337 uint b1, \
5338 uint b2, \
5339 uint b3, \
5340 uint b4, \
5341 uint b5, \
5342 uint b6, \
5343 uint b7, \
5344 uint b8, \
5345 uint b9, \
5346 uint b10, \
5347 uint b11) {{"
5348 )?;
5349 writeln!(
5350 self.out,
5351 "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5352 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5353 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5354 back::INDENT
5355 )?;
5356 writeln!(self.out, "}}")?;
5357 Ok((name, 12, Some(VectorSize::Tri), Scalar::F32))
5358 }
5359 Float32x4 => {
5360 let name = self.namer.call("unpackFloat32x4");
5361 writeln!(
5362 self.out,
5363 "metal::float4 {name}(uint b0, \
5364 uint b1, \
5365 uint b2, \
5366 uint b3, \
5367 uint b4, \
5368 uint b5, \
5369 uint b6, \
5370 uint b7, \
5371 uint b8, \
5372 uint b9, \
5373 uint b10, \
5374 uint b11, \
5375 uint b12, \
5376 uint b13, \
5377 uint b14, \
5378 uint b15) {{"
5379 )?;
5380 writeln!(
5381 self.out,
5382 "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5383 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5384 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5385 as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5386 back::INDENT
5387 )?;
5388 writeln!(self.out, "}}")?;
5389 Ok((name, 16, Some(VectorSize::Quad), Scalar::F32))
5390 }
5391 Uint32 => {
5392 let name = self.namer.call("unpackUint32");
5393 writeln!(
5394 self.out,
5395 "uint {name}(uint b0, \
5396 uint b1, \
5397 uint b2, \
5398 uint b3) {{"
5399 )?;
5400 writeln!(
5401 self.out,
5402 "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5403 back::INDENT
5404 )?;
5405 writeln!(self.out, "}}")?;
5406 Ok((name, 4, None, Scalar::U32))
5407 }
5408 Uint32x2 => {
5409 let name = self.namer.call("unpackUint32x2");
5410 writeln!(
5411 self.out,
5412 "uint2 {name}(uint b0, \
5413 uint b1, \
5414 uint b2, \
5415 uint b3, \
5416 uint b4, \
5417 uint b5, \
5418 uint b6, \
5419 uint b7) {{"
5420 )?;
5421 writeln!(
5422 self.out,
5423 "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5424 (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5425 back::INDENT
5426 )?;
5427 writeln!(self.out, "}}")?;
5428 Ok((name, 8, Some(VectorSize::Bi), Scalar::U32))
5429 }
5430 Uint32x3 => {
5431 let name = self.namer.call("unpackUint32x3");
5432 writeln!(
5433 self.out,
5434 "uint3 {name}(uint b0, \
5435 uint b1, \
5436 uint b2, \
5437 uint b3, \
5438 uint b4, \
5439 uint b5, \
5440 uint b6, \
5441 uint b7, \
5442 uint b8, \
5443 uint b9, \
5444 uint b10, \
5445 uint b11) {{"
5446 )?;
5447 writeln!(
5448 self.out,
5449 "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5450 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5451 (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5452 back::INDENT
5453 )?;
5454 writeln!(self.out, "}}")?;
5455 Ok((name, 12, Some(VectorSize::Tri), Scalar::U32))
5456 }
5457 Uint32x4 => {
5458 let name = self.namer.call("unpackUint32x4");
5459 writeln!(
5460 self.out,
5461 "{NAMESPACE}::uint4 {name}(uint b0, \
5462 uint b1, \
5463 uint b2, \
5464 uint b3, \
5465 uint b4, \
5466 uint b5, \
5467 uint b6, \
5468 uint b7, \
5469 uint b8, \
5470 uint b9, \
5471 uint b10, \
5472 uint b11, \
5473 uint b12, \
5474 uint b13, \
5475 uint b14, \
5476 uint b15) {{"
5477 )?;
5478 writeln!(
5479 self.out,
5480 "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5481 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5482 (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5483 (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5484 back::INDENT
5485 )?;
5486 writeln!(self.out, "}}")?;
5487 Ok((name, 16, Some(VectorSize::Quad), Scalar::U32))
5488 }
5489 Sint32 => {
5490 let name = self.namer.call("unpackSint32");
5491 writeln!(
5492 self.out,
5493 "int {name}(uint b0, \
5494 uint b1, \
5495 uint b2, \
5496 uint b3) {{"
5497 )?;
5498 writeln!(
5499 self.out,
5500 "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5501 back::INDENT
5502 )?;
5503 writeln!(self.out, "}}")?;
5504 Ok((name, 4, None, Scalar::I32))
5505 }
5506 Sint32x2 => {
5507 let name = self.namer.call("unpackSint32x2");
5508 writeln!(
5509 self.out,
5510 "metal::int2 {name}(uint b0, \
5511 uint b1, \
5512 uint b2, \
5513 uint b3, \
5514 uint b4, \
5515 uint b5, \
5516 uint b6, \
5517 uint b7) {{"
5518 )?;
5519 writeln!(
5520 self.out,
5521 "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5522 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5523 back::INDENT
5524 )?;
5525 writeln!(self.out, "}}")?;
5526 Ok((name, 8, Some(VectorSize::Bi), Scalar::I32))
5527 }
5528 Sint32x3 => {
5529 let name = self.namer.call("unpackSint32x3");
5530 writeln!(
5531 self.out,
5532 "metal::int3 {name}(uint b0, \
5533 uint b1, \
5534 uint b2, \
5535 uint b3, \
5536 uint b4, \
5537 uint b5, \
5538 uint b6, \
5539 uint b7, \
5540 uint b8, \
5541 uint b9, \
5542 uint b10, \
5543 uint b11) {{"
5544 )?;
5545 writeln!(
5546 self.out,
5547 "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5548 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5549 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5550 back::INDENT
5551 )?;
5552 writeln!(self.out, "}}")?;
5553 Ok((name, 12, Some(VectorSize::Tri), Scalar::I32))
5554 }
5555 Sint32x4 => {
5556 let name = self.namer.call("unpackSint32x4");
5557 writeln!(
5558 self.out,
5559 "metal::int4 {name}(uint b0, \
5560 uint b1, \
5561 uint b2, \
5562 uint b3, \
5563 uint b4, \
5564 uint b5, \
5565 uint b6, \
5566 uint b7, \
5567 uint b8, \
5568 uint b9, \
5569 uint b10, \
5570 uint b11, \
5571 uint b12, \
5572 uint b13, \
5573 uint b14, \
5574 uint b15) {{"
5575 )?;
5576 writeln!(
5577 self.out,
5578 "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5579 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5580 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5581 as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5582 back::INDENT
5583 )?;
5584 writeln!(self.out, "}}")?;
5585 Ok((name, 16, Some(VectorSize::Quad), Scalar::I32))
5586 }
5587 Unorm10_10_10_2 => {
5588 let name = self.namer.call("unpackUnorm10_10_10_2");
5589 writeln!(
5590 self.out,
5591 "metal::float4 {name}(uint b0, \
5592 uint b1, \
5593 uint b2, \
5594 uint b3) {{"
5595 )?;
5596 writeln!(
5597 self.out,
5598 "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5610 back::INDENT
5611 )?;
5612 writeln!(self.out, "}}")?;
5613 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5614 }
5615 Unorm8x4Bgra => {
5616 let name = self.namer.call("unpackUnorm8x4Bgra");
5617 writeln!(
5618 self.out,
5619 "metal::float4 {name}(metal::uchar b0, \
5620 metal::uchar b1, \
5621 metal::uchar b2, \
5622 metal::uchar b3) {{"
5623 )?;
5624 writeln!(
5625 self.out,
5626 "{}return metal::float4(float(b2) / 255.0f, \
5627 float(b1) / 255.0f, \
5628 float(b0) / 255.0f, \
5629 float(b3) / 255.0f);",
5630 back::INDENT
5631 )?;
5632 writeln!(self.out, "}}")?;
5633 Ok((name, 4, Some(VectorSize::Quad), Scalar::F32))
5634 }
5635 }
5636 }
5637
5638 fn write_wrapped_unary_op(
5639 &mut self,
5640 module: &crate::Module,
5641 func_ctx: &back::FunctionCtx,
5642 op: crate::UnaryOperator,
5643 operand: Handle<crate::Expression>,
5644 ) -> BackendResult {
5645 let operand_ty = func_ctx.resolve_type(operand, &module.types);
5646 match op {
5647 crate::UnaryOperator::Negate
5654 if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5655 {
5656 let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5657 return Ok(());
5658 };
5659 let wrapped = WrappedFunction::UnaryOp {
5660 op,
5661 ty: (vector_size, scalar),
5662 };
5663 if !self.wrapped_functions.insert(wrapped) {
5664 return Ok(());
5665 }
5666
5667 let unsigned_scalar = crate::Scalar {
5668 kind: crate::ScalarKind::Uint,
5669 ..scalar
5670 };
5671 let mut type_name = String::new();
5672 let mut unsigned_type_name = String::new();
5673 match vector_size {
5674 None => {
5675 put_numeric_type(&mut type_name, scalar, &[])?;
5676 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5677 }
5678 Some(size) => {
5679 put_numeric_type(&mut type_name, scalar, &[size])?;
5680 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5681 }
5682 };
5683
5684 writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5685 let level = back::Level(1);
5686 if scalar.width < 4 {
5690 writeln!(
5691 self.out,
5692 "{level}return as_type<{type_name}>(static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val)));"
5693 )?;
5694 } else {
5695 writeln!(
5696 self.out,
5697 "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5698 )?;
5699 }
5700 writeln!(self.out, "}}")?;
5701 writeln!(self.out)?;
5702 }
5703 _ => {}
5704 }
5705 Ok(())
5706 }
5707
5708 fn write_wrapped_binary_op(
5709 &mut self,
5710 module: &crate::Module,
5711 func_ctx: &back::FunctionCtx,
5712 expr: Handle<crate::Expression>,
5713 op: crate::BinaryOperator,
5714 left: Handle<crate::Expression>,
5715 right: Handle<crate::Expression>,
5716 ) -> BackendResult {
5717 let expr_ty = func_ctx.resolve_type(expr, &module.types);
5718 let left_ty = func_ctx.resolve_type(left, &module.types);
5719 let right_ty = func_ctx.resolve_type(right, &module.types);
5720 match (op, expr_ty.scalar_kind()) {
5721 (
5728 crate::BinaryOperator::Divide,
5729 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5730 ) => {
5731 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5732 return Ok(());
5733 };
5734 let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5735 return Ok(());
5736 };
5737 let wrapped = WrappedFunction::BinaryOp {
5738 op,
5739 left_ty: left_wrapped_ty,
5740 right_ty: right_wrapped_ty,
5741 };
5742 if !self.wrapped_functions.insert(wrapped) {
5743 return Ok(());
5744 }
5745
5746 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5747 return Ok(());
5748 };
5749 let mut type_name = String::new();
5750 match vector_size {
5751 None => put_numeric_type(&mut type_name, scalar, &[])?,
5752 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5753 };
5754 writeln!(
5755 self.out,
5756 "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5757 )?;
5758 let level = back::Level(1);
5759 let (lp, rp) = if scalar.width < 4 {
5763 (format!("{type_name}("), ")".to_string())
5764 } else {
5765 (String::new(), String::new())
5766 };
5767 match scalar.kind {
5768 crate::ScalarKind::Sint => {
5769 let min_val = match scalar.width {
5770 2 => crate::Literal::I16(i16::MIN),
5771 4 => crate::Literal::I32(i32::MIN),
5772 8 => crate::Literal::I64(i64::MIN),
5773 _ => {
5774 return Err(Error::GenericValidation(format!(
5775 "Unexpected width for scalar {scalar:?}"
5776 )));
5777 }
5778 };
5779 write!(
5780 self.out,
5781 "{level}return lhs / metal::select(rhs, {lp}1{rp}, (lhs == "
5782 )?;
5783 self.put_literal(min_val)?;
5784 writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?
5785 }
5786 crate::ScalarKind::Uint => {
5787 let suffix = if scalar.width < 4 { "" } else { "u" };
5788 writeln!(
5789 self.out,
5790 "{level}return lhs / metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5791 )?
5792 }
5793 _ => unreachable!(),
5794 }
5795 writeln!(self.out, "}}")?;
5796 writeln!(self.out)?;
5797 }
5798 (
5811 crate::BinaryOperator::Modulo,
5812 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5813 ) => {
5814 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5815 return Ok(());
5816 };
5817 let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5818 else {
5819 return Ok(());
5820 };
5821 let wrapped = WrappedFunction::BinaryOp {
5822 op,
5823 left_ty: left_wrapped_ty,
5824 right_ty: (right_vector_size, right_scalar),
5825 };
5826 if !self.wrapped_functions.insert(wrapped) {
5827 return Ok(());
5828 }
5829
5830 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5831 return Ok(());
5832 };
5833 let mut type_name = String::new();
5834 match vector_size {
5835 None => put_numeric_type(&mut type_name, scalar, &[])?,
5836 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5837 };
5838 let mut rhs_type_name = String::new();
5839 match right_vector_size {
5840 None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5841 Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5842 };
5843
5844 writeln!(
5845 self.out,
5846 "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5847 )?;
5848 let level = back::Level(1);
5849 let (lp, rp) = if scalar.width < 4 {
5850 (format!("{type_name}("), ")".to_string())
5851 } else {
5852 (String::new(), String::new())
5853 };
5854 match scalar.kind {
5855 crate::ScalarKind::Sint => {
5856 let min_val = match scalar.width {
5857 2 => crate::Literal::I16(i16::MIN),
5858 4 => crate::Literal::I32(i32::MIN),
5859 8 => crate::Literal::I64(i64::MIN),
5860 _ => {
5861 return Err(Error::GenericValidation(format!(
5862 "Unexpected width for scalar {scalar:?}"
5863 )));
5864 }
5865 };
5866 write!(
5867 self.out,
5868 "{level}{rhs_type_name} divisor = metal::select(rhs, {lp}1{rp}, (lhs == "
5869 )?;
5870 self.put_literal(min_val)?;
5871 writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?;
5872 writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5873 }
5874 crate::ScalarKind::Uint => {
5875 let suffix = if scalar.width < 4 { "" } else { "u" };
5876 writeln!(
5877 self.out,
5878 "{level}return lhs % metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
5879 )?
5880 }
5881 _ => unreachable!(),
5882 }
5883 writeln!(self.out, "}}")?;
5884 writeln!(self.out)?;
5885 }
5886 _ => {}
5887 }
5888 Ok(())
5889 }
5890
5891 fn get_dot_wrapper_function_helper_name(
5897 &self,
5898 scalar: crate::Scalar,
5899 size: crate::VectorSize,
5900 ) -> String {
5901 debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5903
5904 let type_name = scalar.to_msl_name();
5905 let size_suffix = common::vector_size_str(size);
5906 format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5907 }
5908
5909 #[allow(clippy::too_many_arguments)]
5910 fn write_wrapped_math_function(
5911 &mut self,
5912 module: &crate::Module,
5913 func_ctx: &back::FunctionCtx,
5914 fun: crate::MathFunction,
5915 arg: Handle<crate::Expression>,
5916 _arg1: Option<Handle<crate::Expression>>,
5917 _arg2: Option<Handle<crate::Expression>>,
5918 _arg3: Option<Handle<crate::Expression>>,
5919 ) -> BackendResult {
5920 let arg_ty = func_ctx.resolve_type(arg, &module.types);
5921 match fun {
5922 crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5930 let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5931 return Ok(());
5932 };
5933 let wrapped = WrappedFunction::Math {
5934 fun,
5935 arg_ty: (vector_size, scalar),
5936 };
5937 if !self.wrapped_functions.insert(wrapped) {
5938 return Ok(());
5939 }
5940
5941 let unsigned_scalar = crate::Scalar {
5942 kind: crate::ScalarKind::Uint,
5943 ..scalar
5944 };
5945 let mut type_name = String::new();
5946 let mut unsigned_type_name = String::new();
5947 match vector_size {
5948 None => {
5949 put_numeric_type(&mut type_name, scalar, &[])?;
5950 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5951 }
5952 Some(size) => {
5953 put_numeric_type(&mut type_name, scalar, &[size])?;
5954 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5955 }
5956 };
5957
5958 writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5959 let level = back::Level(1);
5960 let zero = if scalar.width < 4 {
5961 format!("{type_name}(0)")
5962 } else {
5963 "0".to_string()
5964 };
5965 let neg_expr = if scalar.width < 4 {
5966 format!(
5967 "static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val))"
5968 )
5969 } else {
5970 format!("-as_type<{unsigned_type_name}>(val)")
5971 };
5972 writeln!(self.out, "{level}return metal::select(as_type<{type_name}>({neg_expr}), val, val >= {zero});")?;
5973 writeln!(self.out, "}}")?;
5974 writeln!(self.out)?;
5975 }
5976
5977 crate::MathFunction::Dot => match *arg_ty {
5978 crate::TypeInner::Vector { size, scalar }
5979 if matches!(
5980 scalar.kind,
5981 crate::ScalarKind::Sint | crate::ScalarKind::Uint
5982 ) =>
5983 {
5984 let wrapped = WrappedFunction::Math {
5986 fun,
5987 arg_ty: (Some(size), scalar),
5988 };
5989 if !self.wrapped_functions.insert(wrapped) {
5990 return Ok(());
5991 }
5992
5993 let mut vec_ty = String::new();
5994 put_numeric_type(&mut vec_ty, scalar, &[size])?;
5995 let mut ret_ty = String::new();
5996 put_numeric_type(&mut ret_ty, scalar, &[])?;
5997
5998 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
5999
6000 writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
6002 let level = back::Level(1);
6003 write!(self.out, "{level}return ")?;
6004 self.put_dot_product("a", "b", size as usize, |writer, name, index| {
6005 write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
6006 Ok(())
6007 })?;
6008 writeln!(self.out, ";")?;
6009 writeln!(self.out, "}}")?;
6010 writeln!(self.out)?;
6011 }
6012 _ => {}
6013 },
6014
6015 _ => {}
6016 }
6017 Ok(())
6018 }
6019
6020 fn write_wrapped_cast(
6021 &mut self,
6022 module: &crate::Module,
6023 func_ctx: &back::FunctionCtx,
6024 expr: Handle<crate::Expression>,
6025 kind: crate::ScalarKind,
6026 convert: Option<crate::Bytes>,
6027 ) -> BackendResult {
6028 let src_ty = func_ctx.resolve_type(expr, &module.types);
6039 let Some(width) = convert else {
6040 return Ok(());
6041 };
6042 let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
6043 return Ok(());
6044 };
6045 let dst_scalar = crate::Scalar { kind, width };
6046 if src_scalar.kind != crate::ScalarKind::Float
6047 || (dst_scalar.kind != crate::ScalarKind::Sint
6048 && dst_scalar.kind != crate::ScalarKind::Uint)
6049 {
6050 return Ok(());
6051 }
6052 let wrapped = WrappedFunction::Cast {
6053 src_scalar,
6054 vector_size,
6055 dst_scalar,
6056 };
6057 if !self.wrapped_functions.insert(wrapped) {
6058 return Ok(());
6059 }
6060 let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
6061
6062 let mut src_type_name = String::new();
6063 match vector_size {
6064 None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
6065 Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
6066 };
6067 let mut dst_type_name = String::new();
6068 match vector_size {
6069 None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
6070 Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
6071 };
6072 let fun_name = match dst_scalar {
6073 crate::Scalar::I32 => F2I32_FUNCTION,
6074 crate::Scalar::U32 => F2U32_FUNCTION,
6075 crate::Scalar::I64 => F2I64_FUNCTION,
6076 crate::Scalar::U64 => F2U64_FUNCTION,
6077 _ => unreachable!(),
6078 };
6079
6080 writeln!(
6081 self.out,
6082 "{dst_type_name} {fun_name}({src_type_name} value) {{"
6083 )?;
6084 let level = back::Level(1);
6085 write!(
6086 self.out,
6087 "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
6088 )?;
6089 self.put_literal(min)?;
6090 write!(self.out, ", ")?;
6091 self.put_literal(max)?;
6092 writeln!(self.out, "));")?;
6093 writeln!(self.out, "}}")?;
6094 writeln!(self.out)?;
6095 Ok(())
6096 }
6097
6098 fn write_convert_yuv_to_rgb_and_return(
6106 &mut self,
6107 level: back::Level,
6108 y: &str,
6109 uv: &str,
6110 params: &str,
6111 ) -> BackendResult {
6112 let l1 = level;
6113 let l2 = l1.next();
6114
6115 writeln!(
6117 self.out,
6118 "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
6119 )?;
6120
6121 writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
6124 writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
6125 writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
6126 writeln!(
6127 self.out,
6128 "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
6129 )?;
6130
6131 writeln!(
6134 self.out,
6135 "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
6136 )?;
6137
6138 writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
6141 writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
6142 writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
6143 writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
6144
6145 writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
6146 Ok(())
6147 }
6148
6149 #[allow(clippy::too_many_arguments)]
6150 fn write_wrapped_image_load(
6151 &mut self,
6152 module: &crate::Module,
6153 func_ctx: &back::FunctionCtx,
6154 image: Handle<crate::Expression>,
6155 _coordinate: Handle<crate::Expression>,
6156 _array_index: Option<Handle<crate::Expression>>,
6157 _sample: Option<Handle<crate::Expression>>,
6158 _level: Option<Handle<crate::Expression>>,
6159 ) -> BackendResult {
6160 let class = match *func_ctx.resolve_type(image, &module.types) {
6162 crate::TypeInner::Image { class, .. } => class,
6163 _ => unreachable!(),
6164 };
6165 if class != crate::ImageClass::External {
6166 return Ok(());
6167 }
6168 let wrapped = WrappedFunction::ImageLoad { class };
6169 if !self.wrapped_functions.insert(wrapped) {
6170 return Ok(());
6171 }
6172
6173 writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6174 let l1 = back::Level(1);
6175 let l2 = l1.next();
6176 let l3 = l2.next();
6177 writeln!(
6178 self.out,
6179 "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6180 )?;
6181 writeln!(
6185 self.out,
6186 "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6187 )?;
6188 writeln!(
6189 self.out,
6190 "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6191 )?;
6192
6193 writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6195 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6196 writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6198 writeln!(self.out, "{l1}}} else {{")?;
6199
6200 writeln!(
6202 self.out,
6203 "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6204 )?;
6205 writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6206
6207 writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6209
6210 writeln!(self.out, "{l2}float2 uv;")?;
6211 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6212 writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6214 writeln!(self.out, "{l2}}} else {{")?;
6215 writeln!(
6217 self.out,
6218 "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6219 )?;
6220 writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6221 writeln!(
6222 self.out,
6223 "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6224 )?;
6225 writeln!(self.out, "{l2}}}")?;
6226
6227 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6228
6229 writeln!(self.out, "{l1}}}")?;
6230 writeln!(self.out, "}}")?;
6231 writeln!(self.out)?;
6232 Ok(())
6233 }
6234
6235 #[allow(clippy::too_many_arguments)]
6236 fn write_wrapped_image_sample(
6237 &mut self,
6238 module: &crate::Module,
6239 func_ctx: &back::FunctionCtx,
6240 image: Handle<crate::Expression>,
6241 _sampler: Handle<crate::Expression>,
6242 _gather: Option<crate::SwizzleComponent>,
6243 _coordinate: Handle<crate::Expression>,
6244 _array_index: Option<Handle<crate::Expression>>,
6245 _offset: Option<Handle<crate::Expression>>,
6246 _level: crate::SampleLevel,
6247 _depth_ref: Option<Handle<crate::Expression>>,
6248 clamp_to_edge: bool,
6249 ) -> BackendResult {
6250 if !clamp_to_edge {
6253 return Ok(());
6254 }
6255 let class = match *func_ctx.resolve_type(image, &module.types) {
6256 crate::TypeInner::Image { class, .. } => class,
6257 _ => unreachable!(),
6258 };
6259 let wrapped = WrappedFunction::ImageSample {
6260 class,
6261 clamp_to_edge: true,
6262 };
6263 if !self.wrapped_functions.insert(wrapped) {
6264 return Ok(());
6265 }
6266 match class {
6267 crate::ImageClass::External => {
6268 writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6269 let l1 = back::Level(1);
6270 let l2 = l1.next();
6271 let l3 = l2.next();
6272 writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6273 writeln!(
6274 self.out,
6275 "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6276 )?;
6277
6278 writeln!(
6286 self.out,
6287 "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6288 )?;
6289 writeln!(
6290 self.out,
6291 "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6292 )?;
6293 writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6294 writeln!(
6295 self.out,
6296 "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6297 )?;
6298 writeln!(
6299 self.out,
6300 "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6301 )?;
6302 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6303 writeln!(
6305 self.out,
6306 "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6307 )?;
6308 writeln!(self.out, "{l1}}} else {{")?;
6309 writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6310 writeln!(
6311 self.out,
6312 "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6313 )?;
6314 writeln!(
6315 self.out,
6316 "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6317 )?;
6318
6319 writeln!(
6321 self.out,
6322 "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6323 )?;
6324 writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6325 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6326 writeln!(
6328 self.out,
6329 "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6330 )?;
6331 writeln!(self.out, "{l2}}} else {{")?;
6332 writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6334 writeln!(
6335 self.out,
6336 "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6337 )?;
6338 writeln!(
6339 self.out,
6340 "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6341 )?;
6342 writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6343 writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6344 writeln!(self.out, "{l2}}}")?;
6345
6346 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6347
6348 writeln!(self.out, "{l1}}}")?;
6349 writeln!(self.out, "}}")?;
6350 writeln!(self.out)?;
6351 }
6352 _ => {
6353 writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6354 let l1 = back::Level(1);
6355 writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6356 writeln!(
6357 self.out,
6358 "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6359 )?;
6360 writeln!(self.out, "}}")?;
6361 writeln!(self.out)?;
6362 }
6363 }
6364 Ok(())
6365 }
6366
6367 fn write_wrapped_image_query(
6368 &mut self,
6369 module: &crate::Module,
6370 func_ctx: &back::FunctionCtx,
6371 image: Handle<crate::Expression>,
6372 query: crate::ImageQuery,
6373 ) -> BackendResult {
6374 if !matches!(query, crate::ImageQuery::Size { .. }) {
6376 return Ok(());
6377 }
6378 let class = match *func_ctx.resolve_type(image, &module.types) {
6379 crate::TypeInner::Image { class, .. } => class,
6380 _ => unreachable!(),
6381 };
6382 if class != crate::ImageClass::External {
6383 return Ok(());
6384 }
6385 let wrapped = WrappedFunction::ImageQuerySize { class };
6386 if !self.wrapped_functions.insert(wrapped) {
6387 return Ok(());
6388 }
6389 writeln!(
6390 self.out,
6391 "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6392 )?;
6393 let l1 = back::Level(1);
6394 let l2 = l1.next();
6395 writeln!(
6396 self.out,
6397 "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6398 )?;
6399 writeln!(self.out, "{l2}return tex.params.size;")?;
6400 writeln!(self.out, "{l1}}} else {{")?;
6401 writeln!(
6403 self.out,
6404 "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6405 )?;
6406 writeln!(self.out, "{l1}}}")?;
6407 writeln!(self.out, "}}")?;
6408 writeln!(self.out)?;
6409 Ok(())
6410 }
6411
6412 fn write_wrapped_cooperative_load(
6413 &mut self,
6414 module: &crate::Module,
6415 func_ctx: &back::FunctionCtx,
6416 columns: crate::CooperativeSize,
6417 rows: crate::CooperativeSize,
6418 pointer: Handle<crate::Expression>,
6419 ) -> BackendResult {
6420 let ptr_ty = func_ctx.resolve_type(pointer, &module.types);
6421 let space = ptr_ty.pointer_space().unwrap();
6422 let space_name = space.to_msl_name().unwrap_or_default();
6423 let scalar = ptr_ty
6424 .pointer_base_type()
6425 .unwrap()
6426 .inner_with(&module.types)
6427 .scalar()
6428 .unwrap();
6429 let wrapped = WrappedFunction::CooperativeLoad {
6430 space_name,
6431 columns,
6432 rows,
6433 scalar,
6434 };
6435 if !self.wrapped_functions.insert(wrapped) {
6436 return Ok(());
6437 }
6438 let scalar_name = scalar.to_msl_name();
6439 writeln!(
6440 self.out,
6441 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{",
6442 columns as u32, rows as u32,
6443 )?;
6444 let l1 = back::Level(1);
6445 writeln!(
6446 self.out,
6447 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;",
6448 columns as u32, rows as u32
6449 )?;
6450 let matrix_origin = "0";
6451 writeln!(
6452 self.out,
6453 "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6454 )?;
6455 writeln!(self.out, "{l1}return m;")?;
6456 writeln!(self.out, "}}")?;
6457 writeln!(self.out)?;
6458 Ok(())
6459 }
6460
6461 fn write_wrapped_cooperative_multiply_add(
6462 &mut self,
6463 module: &crate::Module,
6464 func_ctx: &back::FunctionCtx,
6465 space: crate::AddressSpace,
6466 a: Handle<crate::Expression>,
6467 b: Handle<crate::Expression>,
6468 ) -> BackendResult {
6469 let space_name = space.to_msl_name().unwrap_or_default();
6470 let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6471 crate::TypeInner::CooperativeMatrix {
6472 columns,
6473 rows,
6474 scalar,
6475 ..
6476 } => (columns, rows, scalar),
6477 _ => unreachable!(),
6478 };
6479 let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6480 crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6481 _ => unreachable!(),
6482 };
6483 let wrapped = WrappedFunction::CooperativeMultiplyAdd {
6484 space_name,
6485 columns: b_c,
6486 rows: a_r,
6487 intermediate: a_c,
6488 scalar,
6489 };
6490 if !self.wrapped_functions.insert(wrapped) {
6491 return Ok(());
6492 }
6493 let scalar_name = scalar.to_msl_name();
6494 writeln!(
6495 self.out,
6496 "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
6497 b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
6498 )?;
6499 let l1 = back::Level(1);
6500 writeln!(
6501 self.out,
6502 "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
6503 b_c as u32, a_r as u32
6504 )?;
6505 writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
6506 writeln!(self.out, "{l1}return d;")?;
6507 writeln!(self.out, "}}")?;
6508 writeln!(self.out)?;
6509 Ok(())
6510 }
6511
6512 pub(super) fn write_wrapped_functions(
6513 &mut self,
6514 module: &crate::Module,
6515 func_ctx: &back::FunctionCtx,
6516 ) -> BackendResult {
6517 for (expr_handle, expr) in func_ctx.expressions.iter() {
6518 match *expr {
6519 crate::Expression::Unary { op, expr: operand } => {
6520 self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6521 }
6522 crate::Expression::Binary { op, left, right } => {
6523 self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6524 }
6525 crate::Expression::Math {
6526 fun,
6527 arg,
6528 arg1,
6529 arg2,
6530 arg3,
6531 } => {
6532 self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6533 }
6534 crate::Expression::As {
6535 expr,
6536 kind,
6537 convert,
6538 } => {
6539 self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6540 }
6541 crate::Expression::ImageLoad {
6542 image,
6543 coordinate,
6544 array_index,
6545 sample,
6546 level,
6547 } => {
6548 self.write_wrapped_image_load(
6549 module,
6550 func_ctx,
6551 image,
6552 coordinate,
6553 array_index,
6554 sample,
6555 level,
6556 )?;
6557 }
6558 crate::Expression::ImageSample {
6559 image,
6560 sampler,
6561 gather,
6562 coordinate,
6563 array_index,
6564 offset,
6565 level,
6566 depth_ref,
6567 clamp_to_edge,
6568 } => {
6569 self.write_wrapped_image_sample(
6570 module,
6571 func_ctx,
6572 image,
6573 sampler,
6574 gather,
6575 coordinate,
6576 array_index,
6577 offset,
6578 level,
6579 depth_ref,
6580 clamp_to_edge,
6581 )?;
6582 }
6583 crate::Expression::ImageQuery { image, query } => {
6584 self.write_wrapped_image_query(module, func_ctx, image, query)?;
6585 }
6586 crate::Expression::CooperativeLoad {
6587 columns,
6588 rows,
6589 role: _,
6590 ref data,
6591 } => {
6592 self.write_wrapped_cooperative_load(
6593 module,
6594 func_ctx,
6595 columns,
6596 rows,
6597 data.pointer,
6598 )?;
6599 }
6600 crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
6601 let space = crate::AddressSpace::Private;
6602 self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
6603 }
6604 crate::Expression::RayQueryGetIntersection { committed, .. } => {
6605 self.write_rq_get_intersection_function(module, committed)?;
6606 }
6607 _ => {}
6608 }
6609 }
6610
6611 Ok(())
6612 }
6613
6614 fn write_functions(
6616 &mut self,
6617 module: &crate::Module,
6618 mod_info: &valid::ModuleInfo,
6619 options: &Options,
6620 pipeline_options: &PipelineOptions,
6621 ) -> Result<TranslationInfo, Error> {
6622 use back::msl::VertexFormat;
6623
6624 struct AttributeMappingResolved {
6627 ty_name: String,
6628 dimension: Option<crate::VectorSize>,
6629 scalar: crate::Scalar,
6630 name: String,
6631 }
6632 let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6633
6634 struct VertexBufferMappingResolved<'a> {
6635 id: u32,
6636 stride: u32,
6637 step_mode: back::msl::VertexBufferStepMode,
6638 ty_name: String,
6639 param_name: String,
6640 elem_name: String,
6641 attributes: &'a Vec<back::msl::AttributeMapping>,
6642 }
6643 let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6644
6645 struct UnpackingFunction {
6647 name: String,
6648 byte_count: u32,
6649 dimension: Option<crate::VectorSize>,
6650 scalar: crate::Scalar,
6651 }
6652 let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6653
6654 let mut needs_vertex_id = false;
6660 let v_id = self.namer.call("v_id");
6661
6662 let mut needs_instance_id = false;
6663 let i_id = self.namer.call("i_id");
6664 if pipeline_options.vertex_pulling_transform {
6665 for vbm in &pipeline_options.vertex_buffer_mappings {
6666 let buffer_id = vbm.id;
6667 let buffer_stride = vbm.stride;
6668
6669 assert!(
6670 buffer_stride > 0,
6671 "Vertex pulling requires a non-zero buffer stride."
6672 );
6673
6674 match vbm.step_mode {
6675 back::msl::VertexBufferStepMode::Constant => {}
6676 back::msl::VertexBufferStepMode::ByVertex => {
6677 needs_vertex_id = true;
6678 }
6679 back::msl::VertexBufferStepMode::ByInstance => {
6680 needs_instance_id = true;
6681 }
6682 }
6683
6684 let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6685 let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6686 let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6687
6688 vbm_resolved.push(VertexBufferMappingResolved {
6689 id: buffer_id,
6690 stride: buffer_stride,
6691 step_mode: vbm.step_mode,
6692 ty_name: buffer_ty,
6693 param_name: buffer_param,
6694 elem_name: buffer_elem,
6695 attributes: &vbm.attributes,
6696 });
6697
6698 for attribute in &vbm.attributes {
6700 if unpacking_functions.contains_key(&attribute.format) {
6701 continue;
6702 }
6703 let (name, byte_count, dimension, scalar) =
6704 match self.write_unpacking_function(attribute.format) {
6705 Ok((name, byte_count, dimension, scalar)) => {
6706 (name, byte_count, dimension, scalar)
6707 }
6708 _ => {
6709 continue;
6710 }
6711 };
6712 unpacking_functions.insert(
6713 attribute.format,
6714 UnpackingFunction {
6715 name,
6716 byte_count,
6717 dimension,
6718 scalar,
6719 },
6720 );
6721 }
6722 }
6723 }
6724
6725 let mut pass_through_globals = Vec::new();
6726 for (fun_handle, fun) in module.functions.iter() {
6727 log::trace!(
6728 "function {:?}, handle {:?}",
6729 fun.name.as_deref().unwrap_or("(anonymous)"),
6730 fun_handle
6731 );
6732
6733 let ctx = back::FunctionCtx {
6734 ty: back::FunctionType::Function(fun_handle),
6735 info: &mod_info[fun_handle],
6736 expressions: &fun.expressions,
6737 named_expressions: &fun.named_expressions,
6738 };
6739
6740 writeln!(self.out)?;
6741 self.write_wrapped_functions(module, &ctx)?;
6742
6743 let fun_info = &mod_info[fun_handle];
6744 pass_through_globals.clear();
6745 let mut needs_buffer_sizes = false;
6746 for (handle, var) in module.global_variables.iter() {
6747 if !fun_info[handle].is_empty() {
6748 if var.space.needs_pass_through() {
6749 pass_through_globals.push(handle);
6750 }
6751 needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6752 }
6753 }
6754
6755 let fun_name = &self.names[&NameKey::Function(fun_handle)];
6756 match fun.result {
6757 Some(ref result) => {
6758 let ty_name = TypeContext {
6759 handle: result.ty,
6760 gctx: module.to_ctx(),
6761 names: &self.names,
6762 access: crate::StorageAccess::empty(),
6763 first_time: false,
6764 };
6765 write!(self.out, "{ty_name}")?;
6766 }
6767 None => {
6768 write!(self.out, "void")?;
6769 }
6770 }
6771 writeln!(self.out, " {fun_name}(")?;
6772
6773 for (index, arg) in fun.arguments.iter().enumerate() {
6774 let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6775 let param_type_name = TypeContext {
6776 handle: arg.ty,
6777 gctx: module.to_ctx(),
6778 names: &self.names,
6779 access: crate::StorageAccess::empty(),
6780 first_time: false,
6781 };
6782 let separator = separate(
6783 !pass_through_globals.is_empty()
6784 || index + 1 != fun.arguments.len()
6785 || needs_buffer_sizes,
6786 );
6787 writeln!(
6788 self.out,
6789 "{}{} {}{}",
6790 back::INDENT,
6791 param_type_name,
6792 name,
6793 separator
6794 )?;
6795 }
6796 for (index, &handle) in pass_through_globals.iter().enumerate() {
6797 let tyvar = TypedGlobalVariable {
6798 module,
6799 names: &self.names,
6800 handle,
6801 usage: fun_info[handle],
6802 reference: true,
6803 };
6804 let separator =
6805 separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6806 write!(self.out, "{}", back::INDENT)?;
6807 tyvar.try_fmt(&mut self.out)?;
6808 writeln!(self.out, "{separator}")?;
6809 }
6810
6811 if needs_buffer_sizes {
6812 writeln!(
6813 self.out,
6814 "{}constant _mslBufferSizes& _buffer_sizes",
6815 back::INDENT
6816 )?;
6817 }
6818
6819 writeln!(self.out, ") {{")?;
6820
6821 let guarded_indices =
6822 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6823
6824 let context = StatementContext {
6825 expression: ExpressionContext {
6826 function: fun,
6827 origin: FunctionOrigin::Handle(fun_handle),
6828 info: fun_info,
6829 lang_version: options.lang_version,
6830 policies: options.bounds_check_policies,
6831 guarded_indices,
6832 module,
6833 mod_info,
6834 pipeline_options,
6835 force_loop_bounding: options.force_loop_bounding,
6836 },
6837 result_struct: None,
6838 };
6839
6840 self.put_locals(&context.expression)?;
6841 self.update_expressions_to_bake(fun, fun_info, &context.expression);
6842 self.put_block(back::Level(1), &fun.body, &context)?;
6843 writeln!(self.out, "}}")?;
6844 self.named_expressions.clear();
6845 }
6846
6847 let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6848 .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6849
6850 let mut info = TranslationInfo {
6851 entry_point_names: Vec::with_capacity(ep_range.len()),
6852 };
6853
6854 for ep_index in ep_range {
6855 let ep = &module.entry_points[ep_index];
6856 let fun = &ep.function;
6857 let fun_info = mod_info.get_entry_point(ep_index);
6858 let mut ep_error = None;
6859
6860 let mut v_existing_id = None;
6864 let mut i_existing_id = None;
6865
6866 log::trace!(
6867 "entry point {:?}, index {:?}",
6868 fun.name.as_deref().unwrap_or("(anonymous)"),
6869 ep_index
6870 );
6871
6872 let ctx = back::FunctionCtx {
6873 ty: back::FunctionType::EntryPoint(ep_index as u16),
6874 info: fun_info,
6875 expressions: &fun.expressions,
6876 named_expressions: &fun.named_expressions,
6877 };
6878
6879 self.write_wrapped_functions(module, &ctx)?;
6880
6881 let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6882 crate::ShaderStage::Vertex => (
6883 Some("vertex"),
6884 LocationMode::VertexInput,
6885 LocationMode::VertexOutput,
6886 true,
6887 ),
6888 crate::ShaderStage::Fragment => (
6889 Some("fragment"),
6890 LocationMode::FragmentInput,
6891 LocationMode::FragmentOutput,
6892 false,
6893 ),
6894 crate::ShaderStage::Compute => (
6895 Some("kernel"),
6896 LocationMode::Uniform,
6897 LocationMode::Uniform,
6898 false,
6899 ),
6900 crate::ShaderStage::Task => {
6901 (None, LocationMode::Uniform, LocationMode::Uniform, false)
6902 }
6903 crate::ShaderStage::Mesh => {
6904 (None, LocationMode::Uniform, LocationMode::MeshOutput, false)
6905 }
6906 crate::ShaderStage::RayGeneration
6907 | crate::ShaderStage::AnyHit
6908 | crate::ShaderStage::ClosestHit
6909 | crate::ShaderStage::Miss => unimplemented!(),
6910 };
6911
6912 let do_vertex_pulling = can_vertex_pull
6914 && pipeline_options.vertex_pulling_transform
6915 && !pipeline_options.vertex_buffer_mappings.is_empty();
6916
6917 let needs_buffer_sizes = do_vertex_pulling
6919 || module
6920 .global_variables
6921 .iter()
6922 .filter(|&(handle, _)| !fun_info[handle].is_empty())
6923 .any(|(_, var)| needs_array_length(var.ty, &module.types));
6924
6925 if !options.fake_missing_bindings {
6928 for (var_handle, var) in module.global_variables.iter() {
6929 if fun_info[var_handle].is_empty() {
6930 continue;
6931 }
6932 match var.space {
6933 crate::AddressSpace::Uniform
6934 | crate::AddressSpace::Storage { .. }
6935 | crate::AddressSpace::Handle => {
6936 let br = match var.binding {
6937 Some(ref br) => br,
6938 None => {
6939 let var_name = var.name.clone().unwrap_or_default();
6940 ep_error =
6941 Some(super::EntryPointError::MissingBinding(var_name));
6942 break;
6943 }
6944 };
6945 let target = options.get_resource_binding_target(ep, br);
6946 let good = match target {
6947 Some(target) => {
6948 match module.types[var.ty].inner {
6952 crate::TypeInner::Image {
6953 class: crate::ImageClass::External,
6954 ..
6955 } => target.external_texture.is_some(),
6956 crate::TypeInner::Image { .. } => target.texture.is_some(),
6957 crate::TypeInner::Sampler { .. } => {
6958 target.sampler.is_some()
6959 }
6960 _ => target.buffer.is_some(),
6961 }
6962 }
6963 None => false,
6964 };
6965 if !good {
6966 ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6967 break;
6968 }
6969 }
6970 crate::AddressSpace::Immediate => {
6971 if let Err(e) = options.resolve_immediates(ep) {
6972 ep_error = Some(e);
6973 break;
6974 }
6975 }
6976 crate::AddressSpace::Function
6977 | crate::AddressSpace::Private
6978 | crate::AddressSpace::WorkGroup
6979 | crate::AddressSpace::TaskPayload => {}
6980 crate::AddressSpace::RayPayload
6981 | crate::AddressSpace::IncomingRayPayload => unimplemented!(),
6982 }
6983 }
6984 if needs_buffer_sizes {
6985 if let Err(err) = options.resolve_sizes_buffer(ep) {
6986 ep_error = Some(err);
6987 }
6988 }
6989 }
6990
6991 if let Some(err) = ep_error {
6992 info.entry_point_names.push(Err(err));
6993 continue;
6994 }
6995 let fun_name = self.names[&NameKey::EntryPoint(ep_index as _)].clone();
6996 info.entry_point_names.push(Ok(fun_name.clone()));
6997
6998 writeln!(self.out)?;
6999
7000 let mut flattened_member_names = FastHashMap::default();
7006 let mut varyings_namer = proc::Namer::default();
7008
7009 let mut empty_names = FastHashMap::default(); varyings_namer.reset(
7011 module,
7012 &super::keywords::RESERVED_SET,
7013 proc::KeywordSet::empty(),
7014 proc::CaseInsensitiveKeywordSet::empty(),
7015 &[CLAMPED_LOD_LOAD_PREFIX],
7016 &mut empty_names,
7017 );
7018
7019 let mut flattened_arguments = Vec::new();
7024 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7025 match module.types[arg.ty].inner {
7026 crate::TypeInner::Struct { ref members, .. } => {
7027 for (member_index, member) in members.iter().enumerate() {
7028 let member_index = member_index as u32;
7029 flattened_arguments.push((
7030 NameKey::StructMember(arg.ty, member_index),
7031 member.ty,
7032 member.binding.as_ref(),
7033 ));
7034 let name_key = NameKey::StructMember(arg.ty, member_index);
7035 let name = match member.binding {
7036 Some(crate::Binding::Location { .. }) => {
7037 if do_vertex_pulling {
7038 self.namer.call(&self.names[&name_key])
7039 } else {
7040 varyings_namer.call(&self.names[&name_key])
7041 }
7042 }
7043 _ => self.namer.call(&self.names[&name_key]),
7044 };
7045 flattened_member_names.insert(name_key, name);
7046 }
7047 }
7048 _ => flattened_arguments.push((
7049 NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
7050 arg.ty,
7051 arg.binding.as_ref(),
7052 )),
7053 }
7054 }
7055
7056 let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
7061 let varyings_member_name = self.namer.call("varyings");
7062 let mut has_varyings = false;
7063
7064 if !flattened_arguments.is_empty() {
7065 if !do_vertex_pulling {
7066 writeln!(self.out, "struct {stage_in_name} {{")?;
7067 }
7068 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7069 let Some(binding) = binding else {
7070 continue;
7071 };
7072 let name = match *name_key {
7073 NameKey::StructMember(..) => &flattened_member_names[name_key],
7074 _ => &self.names[name_key],
7075 };
7076 let ty_name = TypeContext {
7077 handle: ty,
7078 gctx: module.to_ctx(),
7079 names: &self.names,
7080 access: crate::StorageAccess::empty(),
7081 first_time: false,
7082 };
7083 let resolved = options.resolve_local_binding(binding, in_mode)?;
7084 let location = match *binding {
7085 crate::Binding::Location { location, .. } => Some(location),
7086 crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. }) => None,
7087 crate::Binding::BuiltIn(_) => continue,
7088 };
7089 if do_vertex_pulling {
7090 let Some(location) = location else {
7091 continue;
7092 };
7093 am_resolved.insert(
7095 location,
7096 AttributeMappingResolved {
7097 ty_name: ty_name.to_string(),
7098 dimension: ty_name.vector_size(),
7099 scalar: ty_name.scalar().unwrap(),
7100 name: name.to_string(),
7101 },
7102 );
7103 } else {
7104 has_varyings = true;
7105 if let super::ResolvedBinding::User {
7106 prefix,
7107 index,
7108 interpolation: Some(super::ResolvedInterpolation::PerVertex),
7109 } = resolved
7110 {
7111 if options.lang_version < (4, 0) {
7112 return Err(Error::PerVertexNotSupported);
7113 }
7114 write!(
7115 self.out,
7116 "{}{NAMESPACE}::vertex_value<{}> {name} [[user({prefix}{index})]]",
7117 back::INDENT,
7118 ty_name.unwrap_array()
7119 )?;
7120 } else {
7121 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7122 resolved.try_fmt(&mut self.out)?;
7123 }
7124 writeln!(self.out, ";")?;
7125 }
7126 }
7127 if !do_vertex_pulling {
7128 writeln!(self.out, "}};")?;
7129 }
7130 }
7131
7132 let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
7135 let result_member_name = self.namer.call("member");
7136 let result_type_name = match fun.result {
7137 Some(ref result) if ep.stage != crate::ShaderStage::Task => {
7138 let mut result_members = Vec::new();
7139 if let crate::TypeInner::Struct { ref members, .. } =
7140 module.types[result.ty].inner
7141 {
7142 for (member_index, member) in members.iter().enumerate() {
7143 result_members.push((
7144 &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
7145 member.ty,
7146 member.binding.as_ref(),
7147 ));
7148 }
7149 } else {
7150 result_members.push((
7151 &result_member_name,
7152 result.ty,
7153 result.binding.as_ref(),
7154 ));
7155 }
7156
7157 writeln!(self.out, "struct {stage_out_name} {{")?;
7158 let mut has_point_size = false;
7159 for (name, ty, binding) in result_members {
7160 let ty_name = TypeContext {
7161 handle: ty,
7162 gctx: module.to_ctx(),
7163 names: &self.names,
7164 access: crate::StorageAccess::empty(),
7165 first_time: true,
7166 };
7167 let binding = binding.ok_or_else(|| {
7168 Error::GenericValidation("Expected binding, got None".into())
7169 })?;
7170
7171 if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
7172 has_point_size = true;
7173 if !pipeline_options.allow_and_force_point_size {
7174 continue;
7175 }
7176 }
7177
7178 let array_len = match module.types[ty].inner {
7179 crate::TypeInner::Array {
7180 size: crate::ArraySize::Constant(size),
7181 ..
7182 } => Some(size),
7183 _ => None,
7184 };
7185 let resolved = options.resolve_local_binding(binding, out_mode)?;
7186 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
7187 resolved.try_fmt(&mut self.out)?;
7188 if let Some(array_len) = array_len {
7189 write!(self.out, " [{array_len}]")?;
7190 }
7191 writeln!(self.out, ";")?;
7192 }
7193
7194 if pipeline_options.allow_and_force_point_size
7195 && ep.stage == crate::ShaderStage::Vertex
7196 && !has_point_size
7197 {
7198 writeln!(
7200 self.out,
7201 "{}float _point_size [[point_size]];",
7202 back::INDENT
7203 )?;
7204 }
7205 writeln!(self.out, "}};")?;
7206 &stage_out_name
7207 }
7208 Some(ref result) if ep.stage == crate::ShaderStage::Task => {
7209 assert_eq!(
7210 module.types[result.ty].inner,
7211 crate::TypeInner::Vector {
7212 size: crate::VectorSize::Tri,
7213 scalar: crate::Scalar::U32
7214 }
7215 );
7216
7217 "metal::uint3"
7218 }
7219 _ => "void",
7220 };
7221
7222 let out_mesh_info = if let Some(ref mesh_info) = ep.mesh_info {
7223 Some(self.write_mesh_output_types(
7224 mesh_info,
7225 &fun_name,
7226 module,
7227 pipeline_options.allow_and_force_point_size,
7228 options,
7229 )?)
7230 } else {
7231 None
7232 };
7233
7234 if do_vertex_pulling {
7237 for vbm in &vbm_resolved {
7238 let buffer_stride = vbm.stride;
7239 let buffer_ty = &vbm.ty_name;
7240
7241 writeln!(
7245 self.out,
7246 "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
7247 )?;
7248 }
7249 }
7250
7251 let is_wrapped = matches!(
7252 ep.stage,
7253 crate::ShaderStage::Task | crate::ShaderStage::Mesh
7254 );
7255 let fun_name = fun_name.clone();
7256 let nested_fun_name = if is_wrapped {
7257 self.namer.call(&format!("_{fun_name}"))
7258 } else {
7259 fun_name.clone()
7260 };
7261
7262 if ep.stage == crate::ShaderStage::Compute && options.lang_version >= (2, 1) {
7264 let total_threads =
7265 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2];
7266 write!(
7267 self.out,
7268 "[[max_total_threads_per_threadgroup({total_threads})]] "
7269 )?;
7270 }
7271
7272 if let Some(em_str) = em_str {
7274 write!(self.out, "{em_str} ")?;
7275 }
7276 writeln!(self.out, "{result_type_name} {nested_fun_name}(")?;
7277
7278 let mut args = Vec::new();
7279
7280 if has_varyings {
7283 args.push(EntryPointArgument {
7284 ty_name: stage_in_name,
7285 name: varyings_member_name.clone(),
7286 binding: " [[stage_in]]".to_string(),
7287 init: None,
7288 });
7289 }
7290
7291 let mut local_invocation_index = None;
7292
7293 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
7296 let binding = match binding {
7297 Some(&crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => continue,
7298 Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
7299 _ => continue,
7300 };
7301 let name = match *name_key {
7302 NameKey::StructMember(..) => &flattened_member_names[name_key],
7303 _ => &self.names[name_key],
7304 };
7305
7306 if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) {
7307 local_invocation_index = Some(name_key);
7308 }
7309
7310 let ty_name = TypeContext {
7311 handle: ty,
7312 gctx: module.to_ctx(),
7313 names: &self.names,
7314 access: crate::StorageAccess::empty(),
7315 first_time: false,
7316 };
7317
7318 match *binding {
7319 crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7320 v_existing_id = Some(name.clone());
7321 }
7322 crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7323 i_existing_id = Some(name.clone());
7324 }
7325 _ => {}
7326 };
7327
7328 let resolved = options.resolve_local_binding(binding, in_mode)?;
7329 let mut binding = String::new();
7330 resolved.try_fmt(&mut binding)?;
7331
7332 args.push(EntryPointArgument {
7333 ty_name: format!("{ty_name}"),
7334 name: name.clone(),
7335 binding,
7336 init: None,
7337 });
7338 }
7339
7340 let need_workgroup_variables_initialization =
7341 self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7342
7343 if local_invocation_index.is_none()
7344 && (need_workgroup_variables_initialization
7345 || ep.stage == crate::ShaderStage::Task
7346 || ep.stage == crate::ShaderStage::Mesh)
7347 {
7348 args.push(EntryPointArgument {
7349 ty_name: "uint".to_string(),
7350 name: "__local_invocation_index".to_string(),
7351 binding: " [[thread_index_in_threadgroup]]".to_string(),
7352 init: None,
7353 });
7354 }
7355
7356 for (handle, var) in module.global_variables.iter() {
7361 let usage = fun_info[handle];
7362 if usage.is_empty() || var.space == crate::AddressSpace::Private {
7363 continue;
7364 }
7365
7366 if options.lang_version < (1, 2) {
7367 match var.space {
7368 crate::AddressSpace::Storage { access }
7378 if access.contains(crate::StorageAccess::STORE)
7379 && ep.stage == crate::ShaderStage::Fragment =>
7380 {
7381 return Err(Error::UnsupportedWritableStorageBuffer)
7382 }
7383 crate::AddressSpace::Handle => {
7384 match module.types[var.ty].inner {
7385 crate::TypeInner::Image {
7386 class: crate::ImageClass::Storage { access, .. },
7387 ..
7388 } => {
7389 if access.contains(crate::StorageAccess::STORE)
7399 && (ep.stage == crate::ShaderStage::Vertex
7400 || ep.stage == crate::ShaderStage::Fragment)
7401 {
7402 return Err(Error::UnsupportedWritableStorageTexture(
7403 ep.stage,
7404 ));
7405 }
7406
7407 if access.contains(
7408 crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7409 ) {
7410 return Err(Error::UnsupportedRWStorageTexture);
7411 }
7412 }
7413 _ => {}
7414 }
7415 }
7416 _ => {}
7417 }
7418 }
7419
7420 match var.space {
7422 crate::AddressSpace::Handle => match module.types[var.ty].inner {
7423 crate::TypeInner::BindingArray { base, .. } => {
7424 match module.types[base].inner {
7425 crate::TypeInner::Sampler { .. } => {
7426 if options.lang_version < (2, 0) {
7427 return Err(Error::UnsupportedArrayOf(
7428 "samplers".to_string(),
7429 ));
7430 }
7431 }
7432 crate::TypeInner::Image { class, .. } => match class {
7433 crate::ImageClass::Sampled { .. }
7434 | crate::ImageClass::Depth { .. }
7435 | crate::ImageClass::Storage {
7436 access: crate::StorageAccess::LOAD,
7437 ..
7438 } => {
7439 if options.lang_version < (2, 0) {
7444 return Err(Error::UnsupportedArrayOf(
7445 "textures".to_string(),
7446 ));
7447 }
7448 }
7449 crate::ImageClass::Storage {
7450 access: crate::StorageAccess::STORE,
7451 ..
7452 } => {
7453 if options.lang_version < (2, 0) {
7458 return Err(Error::UnsupportedArrayOf(
7459 "write-only textures".to_string(),
7460 ));
7461 }
7462 }
7463 crate::ImageClass::Storage { .. } => {
7464 if options.lang_version < (3, 0) {
7465 return Err(Error::UnsupportedArrayOf(
7466 "read-write textures".to_string(),
7467 ));
7468 }
7469 }
7470 crate::ImageClass::External => {
7471 return Err(Error::UnsupportedArrayOf(
7472 "external textures".to_string(),
7473 ));
7474 }
7475 },
7476 _ => {
7477 return Err(Error::UnsupportedArrayOfType(base));
7478 }
7479 }
7480 }
7481 _ => {}
7482 },
7483 _ => {}
7484 }
7485
7486 let resolved = match var.space {
7488 crate::AddressSpace::Immediate => options.resolve_immediates(ep).ok(),
7489 crate::AddressSpace::WorkGroup => None,
7490 crate::AddressSpace::TaskPayload => Some(back::msl::ResolvedBinding::Payload),
7491 _ => options
7492 .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7493 .ok(),
7494 };
7495 if let Some(ref resolved) = resolved {
7496 if resolved.as_inline_sampler(options).is_some() {
7498 continue;
7499 }
7500 }
7501
7502 match module.types[var.ty].inner {
7503 crate::TypeInner::Image {
7504 class: crate::ImageClass::External,
7505 ..
7506 } => {
7507 let target = match resolved {
7511 Some(back::msl::ResolvedBinding::Resource(target)) => {
7512 target.external_texture
7513 }
7514 _ => None,
7515 };
7516
7517 for i in 0..3 {
7518 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7519 handle,
7520 ExternalTextureNameKey::Plane(i),
7521 )];
7522 let ty_name = format!(
7523 "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample>"
7524 );
7525 let name = plane_name.clone();
7526 let binding = if let Some(ref target) = target {
7527 format!(" [[texture({})]]", target.planes[i])
7528 } else {
7529 String::new()
7530 };
7531 args.push(EntryPointArgument {
7532 ty_name,
7533 name,
7534 binding,
7535 init: None,
7536 });
7537 }
7538 let params_ty_name = &self.names
7539 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7540 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7541 handle,
7542 ExternalTextureNameKey::Params,
7543 )];
7544 let binding = if let Some(ref target) = target {
7545 format!(" [[buffer({})]]", target.params)
7546 } else {
7547 String::new()
7548 };
7549
7550 args.push(EntryPointArgument {
7551 ty_name: format!("constant {params_ty_name}&"),
7552 name: params_name.clone(),
7553 binding,
7554 init: None,
7555 });
7556 }
7557 _ => {
7558 if var.space == crate::AddressSpace::WorkGroup
7559 && ep.stage == crate::ShaderStage::Mesh
7560 {
7561 continue;
7562 }
7563 let tyvar = TypedGlobalVariable {
7564 module,
7565 names: &self.names,
7566 handle,
7567 usage,
7568 reference: true,
7569 };
7570 let parts = tyvar.to_parts()?;
7571 let mut binding = String::new();
7572 if let Some(resolved) = resolved {
7573 resolved.try_fmt(&mut binding)?;
7574 }
7575 args.push(EntryPointArgument {
7576 ty_name: parts.ty_name,
7577 name: parts.var_name,
7578 binding,
7579 init: var.init,
7580 });
7581 }
7582 }
7583 }
7584
7585 if do_vertex_pulling {
7586 if needs_vertex_id && v_existing_id.is_none() {
7587 args.push(EntryPointArgument {
7589 ty_name: "uint".to_string(),
7590 name: v_id.clone(),
7591 binding: " [[vertex_id]]".to_string(),
7592 init: None,
7593 });
7594 }
7595
7596 if needs_instance_id && i_existing_id.is_none() {
7597 args.push(EntryPointArgument {
7598 ty_name: "uint".to_string(),
7599 name: i_id.clone(),
7600 binding: " [[instance_id]]".to_string(),
7601 init: None,
7602 });
7603 }
7604
7605 for vbm in &vbm_resolved {
7608 let id = &vbm.id;
7609 let ty_name = &vbm.ty_name;
7610 let param_name = &vbm.param_name;
7611 args.push(EntryPointArgument {
7612 ty_name: format!("const device {ty_name}*"),
7613 name: param_name.clone(),
7614 binding: format!(" [[buffer({id})]]"),
7615 init: None,
7616 });
7617 }
7618 }
7619
7620 if needs_buffer_sizes {
7623 let resolved = options.resolve_sizes_buffer(ep).unwrap();
7625 let mut binding = String::new();
7626 resolved.try_fmt(&mut binding)?;
7627 args.push(EntryPointArgument {
7628 ty_name: "constant _mslBufferSizes&".to_string(),
7629 name: "_buffer_sizes".to_string(),
7630 binding,
7631 init: None,
7632 });
7633 }
7634
7635 let mut is_first_arg = true;
7636 for arg in &args {
7637 if is_first_arg {
7638 write!(self.out, " ")?;
7639 } else {
7640 write!(self.out, ", ")?;
7641 }
7642 is_first_arg = false;
7643 write!(self.out, "{} {}", arg.ty_name, arg.name)?;
7644 if !is_wrapped {
7645 write!(self.out, "{}", arg.binding)?;
7646 if let Some(init) = arg.init {
7647 write!(self.out, " = ")?;
7648 self.put_const_expression(
7649 init,
7650 module,
7651 mod_info,
7652 &module.global_expressions,
7653 )?;
7654 }
7655 }
7656 writeln!(self.out)?;
7657 }
7658 if ep.stage == crate::ShaderStage::Mesh {
7659 for (handle, var) in module.global_variables.iter() {
7660 if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
7661 continue;
7662 }
7663 if is_first_arg {
7664 write!(self.out, " ")?;
7665 } else {
7666 write!(self.out, ", ")?;
7667 }
7668 let ty_context = TypeContext {
7669 handle: module.global_variables[handle].ty,
7670 gctx: module.to_ctx(),
7671 names: &self.names,
7672 access: crate::StorageAccess::empty(),
7673 first_time: false,
7674 };
7675 writeln!(
7676 self.out,
7677 "threadgroup {ty_context}& {}",
7678 self.names[&NameKey::GlobalVariable(handle)]
7679 )?;
7680 }
7681 }
7682
7683 writeln!(self.out, ") {{")?;
7685
7686 if do_vertex_pulling {
7688 for vbm in &vbm_resolved {
7691 for attribute in vbm.attributes {
7692 let location = attribute.shader_location;
7693 let am_option = am_resolved.get(&location);
7694 if am_option.is_none() {
7695 continue;
7698 }
7699 let am = am_option.unwrap();
7700 let attribute_ty_name = &am.ty_name;
7701 let attribute_name = &am.name;
7702
7703 writeln!(
7704 self.out,
7705 "{}{attribute_ty_name} {attribute_name} = {{}};",
7706 back::Level(1)
7707 )?;
7708 }
7709
7710 write!(self.out, "{}if (", back::Level(1))?;
7713
7714 let idx = &vbm.id;
7715 let stride = &vbm.stride;
7716 let index_name = match vbm.step_mode {
7717 back::msl::VertexBufferStepMode::Constant => "0",
7718 back::msl::VertexBufferStepMode::ByVertex => {
7719 if let Some(ref name) = v_existing_id {
7720 name
7721 } else {
7722 &v_id
7723 }
7724 }
7725 back::msl::VertexBufferStepMode::ByInstance => {
7726 if let Some(ref name) = i_existing_id {
7727 name
7728 } else {
7729 &i_id
7730 }
7731 }
7732 };
7733 write!(
7734 self.out,
7735 "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7736 )?;
7737
7738 writeln!(self.out, ") {{")?;
7739
7740 let ty_name = &vbm.ty_name;
7742 let elem_name = &vbm.elem_name;
7743 let param_name = &vbm.param_name;
7744
7745 writeln!(
7746 self.out,
7747 "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7748 back::Level(2),
7749 )?;
7750
7751 for attribute in vbm.attributes {
7754 let location = attribute.shader_location;
7755 let Some(am) = am_resolved.get(&location) else {
7756 continue;
7760 };
7761 let attribute_name = &am.name;
7762 let attribute_ty_name = &am.ty_name;
7763
7764 let offset = attribute.offset;
7765 let func = unpacking_functions
7766 .get(&attribute.format)
7767 .expect("Should have generated this unpacking function earlier.");
7768 let func_name = &func.name;
7769
7770 let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7776
7777 let needs_conversion = am.scalar != func.scalar;
7780
7781 if needs_padding_or_truncation != Ordering::Equal {
7782 writeln!(
7785 self.out,
7786 "{}// {attribute_ty_name} <- {:?}",
7787 back::Level(2),
7788 attribute.format
7789 )?;
7790 }
7791
7792 write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7793
7794 if needs_padding_or_truncation == Ordering::Greater {
7795 write!(self.out, "{attribute_ty_name}(")?;
7797 }
7798
7799 if needs_conversion {
7801 put_numeric_type(&mut self.out, am.scalar, func.dimension.as_slice())?;
7802 write!(self.out, "(")?;
7803 }
7804 write!(self.out, "{func_name}({elem_name}.data[{offset}]")?;
7805 for i in (offset + 1)..(offset + func.byte_count) {
7806 write!(self.out, ", {elem_name}.data[{i}]")?;
7807 }
7808 write!(self.out, ")")?;
7809 if needs_conversion {
7810 write!(self.out, ")")?;
7811 }
7812
7813 match needs_padding_or_truncation {
7814 Ordering::Greater => {
7815 let ty_is_int = scalar_is_int(am.scalar);
7817 let zero_value = if ty_is_int { "0" } else { "0.0" };
7818 let one_value = if ty_is_int { "1" } else { "1.0" };
7819 for i in func.dimension.map_or(1, u8::from)
7820 ..am.dimension.map_or(1, u8::from)
7821 {
7822 write!(
7823 self.out,
7824 ", {}",
7825 if i == 3 { one_value } else { zero_value }
7826 )?;
7827 }
7828 }
7829 Ordering::Less => {
7830 write!(
7832 self.out,
7833 ".{}",
7834 &"xyzw"[0..usize::from(am.dimension.map_or(1, u8::from))]
7835 )?;
7836 }
7837 Ordering::Equal => {}
7838 }
7839
7840 if needs_padding_or_truncation == Ordering::Greater {
7841 write!(self.out, ")")?;
7842 }
7843
7844 writeln!(self.out, ";")?;
7845 }
7846
7847 writeln!(self.out, "{}}}", back::Level(1))?;
7849 }
7850 }
7851
7852 for (handle, var) in module.global_variables.iter() {
7855 let usage = fun_info[handle];
7856 if usage.is_empty() {
7857 continue;
7858 }
7859 if var.space == crate::AddressSpace::Private {
7860 let tyvar = TypedGlobalVariable {
7861 module,
7862 names: &self.names,
7863 handle,
7864 usage,
7865
7866 reference: false,
7867 };
7868 write!(self.out, "{}", back::INDENT)?;
7869 tyvar.try_fmt(&mut self.out)?;
7870 match var.init {
7871 Some(value) => {
7872 write!(self.out, " = ")?;
7873 self.put_const_expression(
7874 value,
7875 module,
7876 mod_info,
7877 &module.global_expressions,
7878 )?;
7879 writeln!(self.out, ";")?;
7880 }
7881 None => {
7882 writeln!(self.out, " = {{}};")?;
7883 }
7884 };
7885 } else if let Some(ref binding) = var.binding {
7886 let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7887 if let Some(sampler) = resolved.as_inline_sampler(options) {
7888 let name = &self.names[&NameKey::GlobalVariable(handle)];
7890 writeln!(
7891 self.out,
7892 "{}constexpr {}::sampler {}(",
7893 back::INDENT,
7894 NAMESPACE,
7895 name
7896 )?;
7897 self.put_inline_sampler_properties(back::Level(2), sampler)?;
7898 writeln!(self.out, "{});", back::INDENT)?;
7899 } else if let crate::TypeInner::Image {
7900 class: crate::ImageClass::External,
7901 ..
7902 } = module.types[var.ty].inner
7903 {
7904 let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7907 let l1 = back::Level(1);
7908 let l2 = l1.next();
7909 writeln!(
7910 self.out,
7911 "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7912 )?;
7913 for i in 0..3 {
7914 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7915 handle,
7916 ExternalTextureNameKey::Plane(i),
7917 )];
7918 writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7919 }
7920 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7921 handle,
7922 ExternalTextureNameKey::Params,
7923 )];
7924 writeln!(self.out, "{l2}.params = {params_name},")?;
7925 writeln!(self.out, "{l1}}};")?;
7926 }
7927 }
7928 }
7929
7930 if need_workgroup_variables_initialization {
7931 self.write_workgroup_variables_initialization(
7932 module,
7933 mod_info,
7934 fun_info,
7935 local_invocation_index,
7936 ep.stage,
7937 )?;
7938 }
7939
7940 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7951 let arg_name =
7952 &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7953 match module.types[arg.ty].inner {
7954 crate::TypeInner::Struct { ref members, .. } => {
7955 let struct_name = &self.names[&NameKey::Type(arg.ty)];
7956 write!(
7957 self.out,
7958 "{}const {} {} = {{ ",
7959 back::INDENT,
7960 struct_name,
7961 arg_name
7962 )?;
7963 for (member_index, member) in members.iter().enumerate() {
7964 let key = NameKey::StructMember(arg.ty, member_index as u32);
7965 let name = &flattened_member_names[&key];
7966 if member_index != 0 {
7967 write!(self.out, ", ")?;
7968 }
7969 if self
7971 .struct_member_pads
7972 .contains(&(arg.ty, member_index as u32))
7973 {
7974 write!(self.out, "{{}}, ")?;
7975 }
7976 match member.binding {
7977 Some(crate::Binding::Location {
7978 interpolation: Some(crate::Interpolation::PerVertex),
7979 ..
7980 }) => {
7981 writeln!(
7982 self.out,
7983 "{0}{{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }}",
7984 back::INDENT,
7985 varyings_member_name,
7986 arg_name,
7987 )?;
7988 continue;
7989 }
7990 Some(crate::Binding::Location { .. }) => {
7991 if has_varyings {
7992 write!(self.out, "{varyings_member_name}.")?;
7993 }
7994 }
7995 _ => (),
7996 }
7997 write!(self.out, "{name}")?;
7998 }
7999 writeln!(self.out, " }};")?;
8000 }
8001 _ => match arg.binding {
8002 Some(crate::Binding::Location {
8003 interpolation: Some(crate::Interpolation::PerVertex),
8004 ..
8005 }) => {
8006 let ty_name = TypeContext {
8007 handle: arg.ty,
8008 gctx: module.to_ctx(),
8009 names: &self.names,
8010 access: crate::StorageAccess::empty(),
8011 first_time: false,
8012 };
8013 writeln!(
8014 self.out,
8015 "{0}const {ty_name} {arg_name} = {{ {1}.{2}.get({NAMESPACE}::vertex_index::first), {1}.{2}.get({NAMESPACE}::vertex_index::second), {1}.{2}.get({NAMESPACE}::vertex_index::third) }};",
8016 back::INDENT,
8017 varyings_member_name,
8018 arg_name,
8019 )?;
8020 }
8021 Some(crate::Binding::Location { .. })
8022 | Some(crate::Binding::BuiltIn(crate::BuiltIn::Barycentric { .. })) => {
8023 if has_varyings {
8024 writeln!(
8025 self.out,
8026 "{}const auto {} = {}.{};",
8027 back::INDENT,
8028 arg_name,
8029 varyings_member_name,
8030 arg_name
8031 )?;
8032 }
8033 }
8034 _ => {}
8035 },
8036 }
8037 }
8038
8039 let guarded_indices =
8040 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
8041
8042 let context = StatementContext {
8043 expression: ExpressionContext {
8044 function: fun,
8045 origin: FunctionOrigin::EntryPoint(ep_index as _),
8046 info: fun_info,
8047 lang_version: options.lang_version,
8048 policies: options.bounds_check_policies,
8049 guarded_indices,
8050 module,
8051 mod_info,
8052 pipeline_options,
8053 force_loop_bounding: options.force_loop_bounding,
8054 },
8055 result_struct: if ep.stage == crate::ShaderStage::Task {
8056 None
8057 } else {
8058 Some(&stage_out_name)
8059 },
8060 };
8061
8062 self.put_locals(&context.expression)?;
8065 self.update_expressions_to_bake(fun, fun_info, &context.expression);
8066 self.put_block(back::Level(1), &fun.body, &context)?;
8067 writeln!(self.out, "}}")?;
8068 if ep_index + 1 != module.entry_points.len() {
8069 writeln!(self.out)?;
8070 }
8071 self.named_expressions.clear();
8072
8073 if is_wrapped {
8074 self.write_wrapper_function(NestedFunctionInfo {
8075 options,
8076 ep,
8077 module,
8078 mod_info,
8079 fun_info,
8080 args,
8081 local_invocation_index,
8082 nested_name: &nested_fun_name,
8083 outer_name: &fun_name,
8084 out_mesh_info,
8085 })?;
8086 }
8087 }
8088
8089 Ok(info)
8090 }
8091
8092 pub(super) fn write_barrier(
8093 &mut self,
8094 flags: crate::Barrier,
8095 level: back::Level,
8096 ) -> BackendResult {
8097 if flags.is_empty() {
8100 writeln!(
8101 self.out,
8102 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
8103 )?;
8104 }
8105 if flags.contains(crate::Barrier::STORAGE) {
8106 writeln!(
8107 self.out,
8108 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
8109 )?;
8110 }
8111 if flags.contains(crate::Barrier::WORK_GROUP) {
8112 writeln!(
8113 self.out,
8114 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8115 )?;
8116 if self.needs_object_memory_barriers {
8117 writeln!(
8118 self.out,
8119 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_object_data);",
8120 )?;
8121 }
8122 }
8123 if flags.contains(crate::Barrier::SUB_GROUP) {
8124 writeln!(
8125 self.out,
8126 "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
8127 )?;
8128 }
8129 if flags.contains(crate::Barrier::TEXTURE) {
8130 writeln!(
8131 self.out,
8132 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
8133 )?;
8134 }
8135 Ok(())
8136 }
8137}
8138
8139mod workgroup_mem_init {
8142 use crate::EntryPoint;
8143
8144 use super::*;
8145
    /// One step of the access path from a global variable down to the
    /// sub-object that is being initialized.
    enum Access {
        /// The global variable itself; written as the variable's name.
        GlobalVariable(Handle<crate::GlobalVariable>),
        /// A member of a struct, identified by the struct's type handle and the
        /// member index; written as `.<member name>`.
        StructMember(Handle<crate::Type>, u32),
        /// An array element selected by the loop counter `__i<depth>`, where
        /// the payload is that depth; written through the wrapped-array field.
        Array(usize),
    }
8151
8152 impl Access {
8153 fn write<W: Write>(
8154 &self,
8155 writer: &mut W,
8156 names: &FastHashMap<NameKey, String>,
8157 ) -> Result<(), core::fmt::Error> {
8158 match *self {
8159 Access::GlobalVariable(handle) => {
8160 write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
8161 }
8162 Access::StructMember(handle, index) => {
8163 write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
8164 }
8165 Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
8166 }
8167 }
8168 }
8169
    /// The access path currently being visited while generating
    /// initialization code, outermost step first.
    struct AccessStack {
        /// Steps of the path, in the order they are written out.
        stack: Vec<Access>,
        /// Number of array steps currently entered; used to give each nested
        /// array loop its own counter (`__i0`, `__i1`, ...).
        array_depth: usize,
    }
8174
8175 impl AccessStack {
8176 const fn new() -> Self {
8177 Self {
8178 stack: Vec::new(),
8179 array_depth: 0,
8180 }
8181 }
8182
8183 fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
8184 let array_depth = self.array_depth;
8185 self.stack.push(Access::Array(array_depth));
8186 self.array_depth += 1;
8187 let res = cb(self, array_depth);
8188 self.stack.pop();
8189 self.array_depth -= 1;
8190 res
8191 }
8192
8193 fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
8194 self.stack.push(new);
8195 let res = cb(self);
8196 self.stack.pop();
8197 res
8198 }
8199
8200 fn write<W: Write>(
8201 &self,
8202 writer: &mut W,
8203 names: &FastHashMap<NameKey, String>,
8204 ) -> Result<(), core::fmt::Error> {
8205 for next in self.stack.iter() {
8206 next.write(writer, names)?;
8207 }
8208 Ok(())
8209 }
8210 }
8211
8212 impl<W: Write> Writer<W> {
8213 pub(super) fn need_workgroup_variables_initialization(
8214 &mut self,
8215 options: &Options,
8216 ep: &EntryPoint,
8217 module: &crate::Module,
8218 fun_info: &valid::FunctionInfo,
8219 ) -> bool {
8220 let is_task = ep.stage == crate::ShaderStage::Task;
8221 options.zero_initialize_workgroup_memory
8222 && ep.stage.compute_like()
8223 && module.global_variables.iter().any(|(handle, var)| {
8224 let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8225 || (var.space == crate::AddressSpace::TaskPayload && is_task);
8226 !fun_info[handle].is_empty() && is_right_address_space
8227 })
8228 }
8229
8230 pub fn write_workgroup_variables_initialization(
8231 &mut self,
8232 module: &crate::Module,
8233 module_info: &valid::ModuleInfo,
8234 fun_info: &valid::FunctionInfo,
8235 local_invocation_index: Option<&NameKey>,
8236 stage: crate::ShaderStage,
8237 ) -> BackendResult {
8238 let level = back::Level(1);
8239
8240 writeln!(
8241 self.out,
8242 "{}if ({} == 0u) {{",
8243 level,
8244 local_invocation_index
8245 .map(|name_key| self.names[name_key].as_str())
8246 .unwrap_or("__local_invocation_index"),
8247 )?;
8248
8249 let mut access_stack = AccessStack::new();
8250
8251 let is_task = stage == crate::ShaderStage::Task;
8252 let vars = module.global_variables.iter().filter(|&(handle, var)| {
8253 let is_right_address_space = var.space == crate::AddressSpace::WorkGroup
8254 || (var.space == crate::AddressSpace::TaskPayload && is_task);
8255 !fun_info[handle].is_empty() && is_right_address_space
8256 });
8257
8258 for (handle, var) in vars {
8259 access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
8260 self.write_workgroup_variable_initialization(
8261 module,
8262 module_info,
8263 var.ty,
8264 access_stack,
8265 level.next(),
8266 )
8267 })?;
8268 }
8269
8270 writeln!(self.out, "{level}}}")?;
8271 self.write_barrier(crate::Barrier::WORK_GROUP, level)
8272 }
8273
8274 fn write_workgroup_variable_initialization(
8275 &mut self,
8276 module: &crate::Module,
8277 module_info: &valid::ModuleInfo,
8278 ty: Handle<crate::Type>,
8279 access_stack: &mut AccessStack,
8280 level: back::Level,
8281 ) -> BackendResult {
8282 if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
8283 write!(self.out, "{level}")?;
8284 access_stack.write(&mut self.out, &self.names)?;
8285 writeln!(self.out, " = {{}};")?;
8286 } else {
8287 match module.types[ty].inner {
8288 crate::TypeInner::Atomic { .. } => {
8289 write!(
8290 self.out,
8291 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
8292 )?;
8293 access_stack.write(&mut self.out, &self.names)?;
8294 writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
8295 }
8296 crate::TypeInner::Array { base, size, .. } => {
8297 let count = match size.resolve(module.to_ctx())? {
8298 proc::IndexableLength::Known(count) => count,
8299 proc::IndexableLength::Dynamic => unreachable!(),
8300 };
8301
8302 access_stack.enter_array(|access_stack, array_depth| {
8303 writeln!(
8304 self.out,
8305 "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
8306 )?;
8307 self.write_workgroup_variable_initialization(
8308 module,
8309 module_info,
8310 base,
8311 access_stack,
8312 level.next(),
8313 )?;
8314 writeln!(self.out, "{level}}}")?;
8315 BackendResult::Ok(())
8316 })?;
8317 }
8318 crate::TypeInner::Struct { ref members, .. } => {
8319 for (index, member) in members.iter().enumerate() {
8320 access_stack.enter(
8321 Access::StructMember(ty, index as u32),
8322 |access_stack| {
8323 self.write_workgroup_variable_initialization(
8324 module,
8325 module_info,
8326 member.ty,
8327 access_stack,
8328 level,
8329 )
8330 },
8331 )?;
8332 }
8333 }
8334 _ => unreachable!(),
8335 }
8336 }
8337
8338 Ok(())
8339 }
8340 }
8341}
8342
8343impl crate::AtomicFunction {
8344 const fn to_msl(self) -> &'static str {
8345 match self {
8346 Self::Add => "fetch_add",
8347 Self::Subtract => "fetch_sub",
8348 Self::And => "fetch_and",
8349 Self::InclusiveOr => "fetch_or",
8350 Self::ExclusiveOr => "fetch_xor",
8351 Self::Min => "fetch_min",
8352 Self::Max => "fetch_max",
8353 Self::Exchange { compare: None } => "exchange",
8354 Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
8355 }
8356 }
8357
8358 fn to_msl_64_bit(self) -> Result<&'static str, Error> {
8359 Ok(match self {
8360 Self::Min => "min",
8361 Self::Max => "max",
8362 _ => Err(Error::FeatureNotImplemented(
8363 "64-bit atomic operation other than min/max".to_string(),
8364 ))?,
8365 })
8366 }
8367}