1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::{
8 cmp::Ordering,
9 fmt::{Display, Error as FmtError, Formatter, Write},
10 iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
17use crate::{
18 arena::{Handle, HandleSet},
19 back::{self, get_entry_points, Baked},
20 common,
21 proc::{
22 self,
23 index::{self, BoundsCheck},
24 ExternalTextureNameKey, NameKey, TypeResolution,
25 },
26 valid, FastHashMap, FastHashSet,
27};
28
29#[cfg(test)]
30use core::ptr;
31
/// Result type returned by the fallible writer routines in this file.
type BackendResult = Result<(), Error>;

/// Namespace prefix written in front of MSL builtin types and functions
/// (e.g. `metal::float4`, `metal::min`, `metal::sampler`).
const NAMESPACE: &str = "metal";
// NOTE(review): not referenced in the visible part of this file; the name
// suggests the field name of a struct that wraps arrays — confirm at its use site.
const WRAPPED_ARRAY_FIELD: &str = "inner";
// NOTE(review): not referenced in the visible part of this file; presumably the
// reference sigil used when passing atomics around — confirm at its use site.
const ATOMIC_REFERENCE: &str = "&";

/// Namespace of the Metal ray-tracing API; acceleration structures are written
/// as `metal::raytracing::instance_acceleration_structure`.
const RT_NAMESPACE: &str = "metal::raytracing";
/// Type name that `RayQuery` types are written as.
const RAY_QUERY_TYPE: &str = "_RayQuery";
// Field / helper names associated with the ray-query type. They are not
// referenced in the visible part of this file — confirm at their use sites.
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
// NOTE(review): hard-coded to `false`; the name suggests a switch for a newer
// ray-query code path that is not enabled yet — confirm before relying on it.
const RAY_QUERY_MODERN_SUPPORT: bool = false;
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

// Names of helper functions and wrapper structs that appear in the generated
// MSL. They are `pub(crate)`, so other modules of the crate can use the same names.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
/// Called by `Writer::put_image_load` for external textures.
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
/// Called by `Writer::put_image_size_query` for external textures.
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// Generic wrapper struct used when writing binding-array types
/// (`constant NagaArgumentBufferWrapper<T>*`).
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Type name written for images of class `External`.
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
81
82fn put_numeric_type(
91 out: &mut impl Write,
92 scalar: crate::Scalar,
93 sizes: &[crate::VectorSize],
94) -> Result<(), FmtError> {
95 match (scalar, sizes) {
96 (scalar, &[]) => {
97 write!(out, "{}", scalar.to_msl_name())
98 }
99 (scalar, &[rows]) => {
100 write!(
101 out,
102 "{}::{}{}",
103 NAMESPACE,
104 scalar.to_msl_name(),
105 common::vector_size_str(rows)
106 )
107 }
108 (scalar, &[rows, columns]) => {
109 write!(
110 out,
111 "{}::{}{}x{}",
112 NAMESPACE,
113 scalar.to_msl_name(),
114 common::vector_size_str(columns),
115 common::vector_size_str(rows)
116 )
117 }
118 (_, _) => Ok(()), }
120}
121
122const fn scalar_is_int(scalar: crate::Scalar) -> bool {
123 use crate::ScalarKind::*;
124 match scalar.kind {
125 Sint | Uint | AbstractInt | Bool => true,
126 Float | AbstractFloat => false,
127 }
128}
129
/// Prefix of the generated identifier for a clamped level of detail. `ClampedLod`
/// writes it together with an expression handle via `write_prefixed`.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix of the generated identifier for a reinterpreted value. `Reinterpreted`
/// follows it with the target type name and the original expression handle.
const REINTERPRET_PREFIX: &str = "reinterpreted_";
135
136struct ClampedLod(Handle<crate::Expression>);
142
143impl Display for ClampedLod {
144 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
145 self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
146 }
147}
148
149struct ArraySizeMember(Handle<crate::GlobalVariable>);
164
165impl Display for ArraySizeMember {
166 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
167 self.0.write_prefixed(f, "size")
168 }
169}
170
171#[derive(Clone, Copy)]
176struct Reinterpreted<'a> {
177 target_type: &'a str,
178 orig: Handle<crate::Expression>,
179}
180
181impl<'a> Reinterpreted<'a> {
182 const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
183 Self { target_type, orig }
184 }
185}
186
187impl Display for Reinterpreted<'_> {
188 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
189 f.write_str(REINTERPRET_PREFIX)?;
190 f.write_str(self.target_type)?;
191 self.orig.write_prefixed(f, "_e")
192 }
193}
194
195struct TypeContext<'a> {
196 handle: Handle<crate::Type>,
197 gctx: proc::GlobalCtx<'a>,
198 names: &'a FastHashMap<NameKey, String>,
199 access: crate::StorageAccess,
200 first_time: bool,
201}
202
203impl TypeContext<'_> {
204 fn scalar(&self) -> Option<crate::Scalar> {
205 let ty = &self.gctx.types[self.handle];
206 ty.inner.scalar()
207 }
208
209 fn vertex_input_dimension(&self) -> u32 {
210 let ty = &self.gctx.types[self.handle];
211 match ty.inner {
212 crate::TypeInner::Scalar(_) => 1,
213 crate::TypeInner::Vector { size, .. } => size as u32,
214 _ => unreachable!(),
215 }
216 }
217}
218
/// Prints the MSL spelling of `self.handle`.
///
/// Types for which `needs_alias()` is true are printed by their name from
/// `self.names`, unless `first_time` is set. Other types are spelled out
/// structurally.
impl Display for TypeContext<'_> {
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            // `<space> <pointee>&`. Nothing is written when the address space has
            // no MSL name (`Handle`).
            crate::TypeInner::Pointer { base, space } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                write!(out, "{space_name} {sub}&")
            }
            // Same shape as `Pointer`, but the pointee is a scalar or vector
            // given inline.
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            // Only the element type is printed here; no array length appears in
            // this output (presumably written elsewhere — confirm).
            crate::TypeInner::Array { base, .. } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                write!(out, "{sub}")
            }
            // NOTE(review): structs are expected to always be referred to by
            // alias (`needs_alias` is true for them), so this is never reached
            // unless `first_time` is set for a struct.
            crate::TypeInner::Struct { .. } => unreachable!(),
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                // Pick the texture family, multisample suffix, component type,
                // and access qualifier for each image class.
                let (texture_str, msaa_str, scalar, access) = match class {
                    crate::ImageClass::Sampled { kind, multi } => {
                        // Multisampled textures are written with `read` access,
                        // the rest with `sample` access.
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Storage { format, .. } => {
                        // The access qualifier comes from `self.access`, not from
                        // the image class itself.
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    // External textures are written as the wrapper struct name.
                    crate::ImageClass::External => {
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            // Panics (rather than returning an error) when vertex return is requested.
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RAY_QUERY_TYPE}")
            }
            // Binding arrays become a `constant` pointer to the argument-buffer
            // wrapper instantiated with the element type.
            crate::TypeInner::BindingArray { base, .. } => {
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };

                write!(
                    out,
                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                )
            }
        }
    }
}
372
373struct TypedGlobalVariable<'a> {
374 module: &'a crate::Module,
375 names: &'a FastHashMap<NameKey, String>,
376 handle: Handle<crate::GlobalVariable>,
377 usage: valid::GlobalUse,
378 reference: bool,
379}
380
381impl TypedGlobalVariable<'_> {
382 fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
383 let var = &self.module.global_variables[self.handle];
384 let name = &self.names[&NameKey::GlobalVariable(self.handle)];
385
386 let storage_access = match var.space {
387 crate::AddressSpace::Storage { access } => access,
388 _ => match self.module.types[var.ty].inner {
389 crate::TypeInner::Image {
390 class: crate::ImageClass::Storage { access, .. },
391 ..
392 } => access,
393 crate::TypeInner::BindingArray { base, .. } => {
394 match self.module.types[base].inner {
395 crate::TypeInner::Image {
396 class: crate::ImageClass::Storage { access, .. },
397 ..
398 } => access,
399 _ => crate::StorageAccess::default(),
400 }
401 }
402 _ => crate::StorageAccess::default(),
403 },
404 };
405 let ty_name = TypeContext {
406 handle: var.ty,
407 gctx: self.module.to_ctx(),
408 names: self.names,
409 access: storage_access,
410 first_time: false,
411 };
412
413 let (space, access, reference) = match var.space.to_msl_name() {
414 Some(space) if self.reference => {
415 let access = if var.space.needs_access_qualifier()
416 && !self.usage.intersects(valid::GlobalUse::WRITE)
417 {
418 "const"
419 } else {
420 ""
421 };
422 (space, access, "&")
423 }
424 _ => ("", "", ""),
425 };
426
427 Ok(write!(
428 out,
429 "{}{}{}{}{}{} {}",
430 space,
431 if space.is_empty() { "" } else { " " },
432 ty_name,
433 if access.is_empty() { "" } else { " " },
434 access,
435 reference,
436 name,
437 )?)
438 }
439}
440
/// Key describing a generated helper function. It is stored in
/// `Writer::wrapped_functions`, presumably so that each helper is emitted only
/// once (confirm against the code that inserts into that set).
///
/// Operand types are recorded as `(vector size, scalar)`, where `None` means a
/// scalar operand.
#[derive(Eq, PartialEq, Hash)]
enum WrappedFunction {
    /// Helper for a unary operator applied to the given operand type.
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a binary operator with the given operand types.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a math function with the given argument type.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a scalar or vector conversion from `src_scalar` to `dst_scalar`.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Helper for loading from an image of the given class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Helper for sampling an image of the given class.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Helper for querying the size of an image of the given class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
}
472
/// Writer that emits Metal Shading Language text into `out`.
pub struct Writer<W> {
    /// Destination of the generated text; handed back by `finish`.
    out: W,
    /// Identifiers chosen for module items, keyed by `NameKey`.
    names: FastHashMap<NameKey, String>,
    named_expressions: crate::NamedExpressions,
    // NOTE(review): presumably the set of expressions to bake into temporaries
    // — confirm against `back::NeedBakeExpressions`.
    need_bake_expressions: back::NeedBakeExpressions,
    /// Generates fresh identifiers (e.g. `loop_bound`, `oob`).
    namer: proc::Namer,
    /// Helper functions recorded so far; see `WrappedFunction`.
    wrapped_functions: FastHashSet<WrappedFunction>,
    // Test-only bookkeeping of raw pointers; the names suggest they track
    // recursion of `put_expression` / `put_block` — confirm at their use sites.
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    // NOTE(review): pairs of (struct type, member index); the name suggests
    // members that need padding inserted — confirm at its use site.
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
489
490impl crate::Scalar {
491 fn to_msl_name(self) -> &'static str {
492 use crate::ScalarKind as Sk;
493 match self {
494 Self {
495 kind: Sk::Float,
496 width: 4,
497 } => "float",
498 Self {
499 kind: Sk::Float,
500 width: 2,
501 } => "half",
502 Self {
503 kind: Sk::Sint,
504 width: 4,
505 } => "int",
506 Self {
507 kind: Sk::Uint,
508 width: 4,
509 } => "uint",
510 Self {
511 kind: Sk::Sint,
512 width: 8,
513 } => "long",
514 Self {
515 kind: Sk::Uint,
516 width: 8,
517 } => "ulong",
518 Self {
519 kind: Sk::Bool,
520 width: _,
521 } => "bool",
522 Self {
523 kind: Sk::AbstractInt | Sk::AbstractFloat,
524 width: _,
525 } => unreachable!("Found Abstract scalar kind"),
526 _ => unreachable!("Unsupported scalar kind: {:?}", self),
527 }
528 }
529}
530
/// Returns the list separator `","` when one is needed, otherwise `""`.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
538
539fn should_pack_struct_member(
540 members: &[crate::StructMember],
541 span: u32,
542 index: usize,
543 module: &crate::Module,
544) -> Option<crate::Scalar> {
545 let member = &members[index];
546
547 let ty_inner = &module.types[member.ty].inner;
548 let last_offset = member.offset + ty_inner.size(module.to_ctx());
549 let next_offset = match members.get(index + 1) {
550 Some(next) => next.offset,
551 None => span,
552 };
553 let is_tight = next_offset == last_offset;
554
555 match *ty_inner {
556 crate::TypeInner::Vector {
557 size: crate::VectorSize::Tri,
558 scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
559 } if is_tight => Some(scalar),
560 _ => None,
561 }
562}
563
564fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
565 match arena[ty].inner {
566 crate::TypeInner::Struct { ref members, .. } => {
567 if let Some(member) = members.last() {
568 if let crate::TypeInner::Array {
569 size: crate::ArraySize::Dynamic,
570 ..
571 } = arena[member.ty].inner
572 {
573 return true;
574 }
575 }
576 false
577 }
578 crate::TypeInner::Array {
579 size: crate::ArraySize::Dynamic,
580 ..
581 } => true,
582 _ => false,
583 }
584}
585
586impl crate::AddressSpace {
587 const fn needs_pass_through(&self) -> bool {
591 match *self {
592 Self::Uniform
593 | Self::Storage { .. }
594 | Self::Private
595 | Self::WorkGroup
596 | Self::PushConstant
597 | Self::Handle => true,
598 Self::Function => false,
599 }
600 }
601
602 const fn needs_access_qualifier(&self) -> bool {
604 match *self {
605 Self::Storage { .. } => true,
610 Self::Private | Self::WorkGroup => false,
612 Self::Uniform | Self::PushConstant => false,
614 Self::Handle | Self::Function => false,
616 }
617 }
618
619 const fn to_msl_name(self) -> Option<&'static str> {
620 match self {
621 Self::Handle => None,
622 Self::Uniform | Self::PushConstant => Some("constant"),
623 Self::Storage { .. } => Some("device"),
624 Self::Private | Self::Function => Some("thread"),
625 Self::WorkGroup => Some("threadgroup"),
626 }
627 }
628}
629
630impl crate::Type {
631 const fn needs_alias(&self) -> bool {
633 use crate::TypeInner as Ti;
634
635 match self.inner {
636 Ti::Scalar(_)
638 | Ti::Vector { .. }
639 | Ti::Matrix { .. }
640 | Ti::Atomic(_)
641 | Ti::Pointer { .. }
642 | Ti::ValuePointer { .. } => self.name.is_some(),
643 Ti::Struct { .. } | Ti::Array { .. } => true,
645 Ti::Image { .. }
647 | Ti::Sampler { .. }
648 | Ti::AccelerationStructure { .. }
649 | Ti::RayQuery { .. }
650 | Ti::BindingArray { .. } => false,
651 }
652 }
653}
654
/// Identifies the function whose body is being written: either a regular
/// function (by handle) or an entry point (by index). `NameKeyExt` uses it to
/// pick the matching `NameKey` variant.
#[derive(Clone, Copy)]
enum FunctionOrigin {
    Handle(Handle<crate::Function>),
    EntryPoint(proc::EntryPointIndex),
}
660
661trait NameKeyExt {
662 fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
663 match origin {
664 FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
665 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
666 }
667 }
668
669 fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
674 match origin {
675 FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
676 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
677 }
678 }
679}
680
681impl NameKeyExt for NameKey {}
682
/// How a level-of-detail operand is written (see `Writer::put_level_of_detail`).
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Written as the given expression itself.
    Direct(Handle<crate::Expression>),
    /// Written as the `ClampedLod` identifier derived from the given image-load
    /// expression handle, rather than as an expression.
    Restricted(Handle<crate::Expression>),
}
697
/// The operands that locate a texel for image loads, stores, and atomics.
struct TexelAddress {
    /// Texel coordinate; it is converted to `uint` before being written.
    coordinate: Handle<crate::Expression>,
    /// Array layer. Clamped against `get_array_size()` under the `Restrict` policy.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index. Clamped against `get_num_samples()` under the `Restrict` policy.
    sample: Option<Handle<crate::Expression>>,
    /// Level of detail, if the access has one.
    level: Option<LevelOfDetail>,
}
713
/// Read-only state needed while writing the expressions of one function.
struct ExpressionContext<'a> {
    /// The function whose expressions are being written.
    function: &'a crate::Function,
    /// Whether `function` is a regular function or an entry point; used to
    /// build `NameKey`s for its locals.
    origin: FunctionOrigin,
    /// Per-expression analysis for `function`; `resolve_type` reads expression
    /// types from it.
    info: &'a valid::FunctionInfo,
    module: &'a crate::Module,
    mod_info: &'a valid::ModuleInfo,
    pipeline_options: &'a PipelineOptions,
    // NOTE(review): presumably the target MSL version as (major, minor) — confirm.
    lang_version: (u8, u8),
    /// Bounds-check policies; `image_load` selects the strategy used by
    /// `Writer::put_image_load`.
    policies: index::BoundsCheckPolicies,

    // NOTE(review): set of expressions; the name suggests indices already known
    // to be in bounds (so no check is needed) — confirm where it is filled.
    guarded_indices: HandleSet<crate::Expression>,
    /// When `true`, `Writer::gen_force_bounded_loop_statements` emits an
    /// iteration counter that forces loops to terminate.
    force_loop_bounding: bool,
}
731
732impl<'a> ExpressionContext<'a> {
733 fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
734 self.info[handle].ty.inner_with(&self.module.types)
735 }
736
737 fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
744 let image_ty = self.resolve_type(image);
745 if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
746 class.is_mipmapped() && dim != crate::ImageDimension::D1
747 } else {
748 false
749 }
750 }
751
752 fn choose_bounds_check_policy(
753 &self,
754 pointer: Handle<crate::Expression>,
755 ) -> index::BoundsCheckPolicy {
756 self.policies
757 .choose_policy(pointer, &self.module.types, self.info)
758 }
759
760 fn access_needs_check(
762 &self,
763 base: Handle<crate::Expression>,
764 index: index::GuardedIndex,
765 ) -> Option<index::IndexableLength> {
766 index::access_needs_check(
767 base,
768 index,
769 self.module,
770 &self.function.expressions,
771 self.info,
772 )
773 }
774
775 fn bounds_check_iter(
777 &self,
778 chain: Handle<crate::Expression>,
779 ) -> impl Iterator<Item = BoundsCheck> + '_ {
780 index::bounds_check_iter(chain, self.module, self.function, self.info)
781 }
782
783 fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
785 index::oob_local_types(self.module, self.function, self.info, self.policies)
786 }
787
788 fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
789 match self.function.expressions[expr_handle] {
790 crate::Expression::AccessIndex { base, index } => {
791 let ty = match *self.resolve_type(base) {
792 crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
793 ref ty => ty,
794 };
795 match *ty {
796 crate::TypeInner::Struct {
797 ref members, span, ..
798 } => should_pack_struct_member(members, span, index as usize, self.module),
799 _ => None,
800 }
801 }
802 _ => None,
803 }
804 }
805}
806
/// State needed while writing statements: the expression context plus an
/// optional result struct name.
struct StatementContext<'a> {
    expression: ExpressionContext<'a>,
    // NOTE(review): the name suggests the struct type used for an entry point's
    // return value — confirm at its use site.
    result_struct: Option<&'a str>,
}
811
812impl<W: Write> Writer<W> {
813 pub fn new(out: W) -> Self {
815 Writer {
816 out,
817 names: FastHashMap::default(),
818 named_expressions: Default::default(),
819 need_bake_expressions: Default::default(),
820 namer: proc::Namer::default(),
821 wrapped_functions: FastHashSet::default(),
822 #[cfg(test)]
823 put_expression_stack_pointers: Default::default(),
824 #[cfg(test)]
825 put_block_stack_pointers: Default::default(),
826 struct_member_pads: FastHashSet::default(),
827 }
828 }
829
830 #[allow(clippy::missing_const_for_fn)]
833 pub fn finish(self) -> W {
834 self.out
835 }
836
837 fn gen_force_bounded_loop_statements(
944 &mut self,
945 level: back::Level,
946 context: &StatementContext,
947 ) -> Option<(String, String)> {
948 if !context.expression.force_loop_bounding {
949 return None;
950 }
951
952 let loop_bound_name = self.namer.call("loop_bound");
953 let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
956 let level = level.next();
957 let break_and_inc = format!(
958 "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
959{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
960 );
961
962 Some((decl, break_and_inc))
963 }
964
965 fn put_call_parameters(
966 &mut self,
967 parameters: impl Iterator<Item = Handle<crate::Expression>>,
968 context: &ExpressionContext,
969 ) -> BackendResult {
970 self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
971 writer.put_expression(expr, context, true)
972 })
973 }
974
975 fn put_call_parameters_impl<C, E>(
976 &mut self,
977 parameters: impl Iterator<Item = Handle<crate::Expression>>,
978 ctx: &C,
979 put_expression: E,
980 ) -> BackendResult
981 where
982 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
983 {
984 write!(self.out, "(")?;
985 for (i, handle) in parameters.enumerate() {
986 if i != 0 {
987 write!(self.out, ", ")?;
988 }
989 put_expression(self, ctx, handle)?;
990 }
991 write!(self.out, ")")?;
992 Ok(())
993 }
994
995 fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1001 let oob_local_types = context.oob_local_types();
1002 for &ty in oob_local_types.iter() {
1003 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1004 self.names.insert(name_key, self.namer.call("oob"));
1005 }
1006
1007 for (name_key, ty, init) in context
1008 .function
1009 .local_variables
1010 .iter()
1011 .map(|(local_handle, local)| {
1012 let name_key = NameKey::local(context.origin, local_handle);
1013 (name_key, local.ty, local.init)
1014 })
1015 .chain(oob_local_types.iter().map(|&ty| {
1016 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1017 (name_key, ty, None)
1018 }))
1019 {
1020 let ty_name = TypeContext {
1021 handle: ty,
1022 gctx: context.module.to_ctx(),
1023 names: &self.names,
1024 access: crate::StorageAccess::empty(),
1025 first_time: false,
1026 };
1027 write!(
1028 self.out,
1029 "{}{} {}",
1030 back::INDENT,
1031 ty_name,
1032 self.names[&name_key]
1033 )?;
1034 match init {
1035 Some(value) => {
1036 write!(self.out, " = ")?;
1037 self.put_expression(value, context, true)?;
1038 }
1039 None => {
1040 write!(self.out, " = {{}}")?;
1041 }
1042 };
1043 writeln!(self.out, ";")?;
1044 }
1045 Ok(())
1046 }
1047
1048 fn put_level_of_detail(
1049 &mut self,
1050 level: LevelOfDetail,
1051 context: &ExpressionContext,
1052 ) -> BackendResult {
1053 match level {
1054 LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1055 LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1056 }
1057 Ok(())
1058 }
1059
1060 fn put_image_query(
1061 &mut self,
1062 image: Handle<crate::Expression>,
1063 query: &str,
1064 level: Option<LevelOfDetail>,
1065 context: &ExpressionContext,
1066 ) -> BackendResult {
1067 self.put_expression(image, context, false)?;
1068 write!(self.out, ".get_{query}(")?;
1069 if let Some(level) = level {
1070 self.put_level_of_detail(level, context)?;
1071 }
1072 write!(self.out, ")")?;
1073 Ok(())
1074 }
1075
1076 fn put_image_size_query(
1077 &mut self,
1078 image: Handle<crate::Expression>,
1079 level: Option<LevelOfDetail>,
1080 kind: crate::ScalarKind,
1081 context: &ExpressionContext,
1082 ) -> BackendResult {
1083 if let crate::TypeInner::Image {
1084 class: crate::ImageClass::External,
1085 ..
1086 } = *context.resolve_type(image)
1087 {
1088 write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1089 self.put_expression(image, context, true)?;
1090 write!(self.out, ")")?;
1091 return Ok(());
1092 }
1093
1094 let dim = match *context.resolve_type(image) {
1097 crate::TypeInner::Image { dim, .. } => dim,
1098 ref other => unreachable!("Unexpected type {:?}", other),
1099 };
1100 let scalar = crate::Scalar { kind, width: 4 };
1101 let coordinate_type = scalar.to_msl_name();
1102 match dim {
1103 crate::ImageDimension::D1 => {
1104 if kind == crate::ScalarKind::Uint {
1108 self.put_image_query(image, "width", None, context)?;
1110 } else {
1111 write!(self.out, "int(")?;
1113 self.put_image_query(image, "width", None, context)?;
1114 write!(self.out, ")")?;
1115 }
1116 }
1117 crate::ImageDimension::D2 => {
1118 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1119 self.put_image_query(image, "width", level, context)?;
1120 write!(self.out, ", ")?;
1121 self.put_image_query(image, "height", level, context)?;
1122 write!(self.out, ")")?;
1123 }
1124 crate::ImageDimension::D3 => {
1125 write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1126 self.put_image_query(image, "width", level, context)?;
1127 write!(self.out, ", ")?;
1128 self.put_image_query(image, "height", level, context)?;
1129 write!(self.out, ", ")?;
1130 self.put_image_query(image, "depth", level, context)?;
1131 write!(self.out, ")")?;
1132 }
1133 crate::ImageDimension::Cube => {
1134 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1135 self.put_image_query(image, "width", level, context)?;
1136 write!(self.out, ")")?;
1137 }
1138 }
1139 Ok(())
1140 }
1141
1142 fn put_cast_to_uint_scalar_or_vector(
1143 &mut self,
1144 expr: Handle<crate::Expression>,
1145 context: &ExpressionContext,
1146 ) -> BackendResult {
1147 match *context.resolve_type(expr) {
1149 crate::TypeInner::Scalar(_) => {
1150 put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1151 }
1152 crate::TypeInner::Vector { size, .. } => {
1153 put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1154 }
1155 _ => {
1156 return Err(Error::GenericValidation(
1157 "Invalid type for image coordinate".into(),
1158 ))
1159 }
1160 };
1161
1162 write!(self.out, "(")?;
1163 self.put_expression(expr, context, true)?;
1164 write!(self.out, ")")?;
1165 Ok(())
1166 }
1167
1168 fn put_image_sample_level(
1169 &mut self,
1170 image: Handle<crate::Expression>,
1171 level: crate::SampleLevel,
1172 context: &ExpressionContext,
1173 ) -> BackendResult {
1174 let has_levels = context.image_needs_lod(image);
1175 match level {
1176 crate::SampleLevel::Auto => {}
1177 crate::SampleLevel::Zero => {
1178 }
1180 _ if !has_levels => {
1181 log::warn!("1D image can't be sampled with level {level:?}");
1182 }
1183 crate::SampleLevel::Exact(h) => {
1184 write!(self.out, ", {NAMESPACE}::level(")?;
1185 self.put_expression(h, context, true)?;
1186 write!(self.out, ")")?;
1187 }
1188 crate::SampleLevel::Bias(h) => {
1189 write!(self.out, ", {NAMESPACE}::bias(")?;
1190 self.put_expression(h, context, true)?;
1191 write!(self.out, ")")?;
1192 }
1193 crate::SampleLevel::Gradient { x, y } => {
1194 write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1195 self.put_expression(x, context, true)?;
1196 write!(self.out, ", ")?;
1197 self.put_expression(y, context, true)?;
1198 write!(self.out, ")")?;
1199 }
1200 }
1201 Ok(())
1202 }
1203
1204 fn put_image_coordinate_limits(
1205 &mut self,
1206 image: Handle<crate::Expression>,
1207 level: Option<LevelOfDetail>,
1208 context: &ExpressionContext,
1209 ) -> BackendResult {
1210 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1211 write!(self.out, " - 1")?;
1212 Ok(())
1213 }
1214
1215 fn put_restricted_scalar_image_index(
1233 &mut self,
1234 image: Handle<crate::Expression>,
1235 index: Handle<crate::Expression>,
1236 limit_method: &str,
1237 context: &ExpressionContext,
1238 ) -> BackendResult {
1239 write!(self.out, "{NAMESPACE}::min(uint(")?;
1240 self.put_expression(index, context, true)?;
1241 write!(self.out, "), ")?;
1242 self.put_expression(image, context, false)?;
1243 write!(self.out, ".{limit_method}() - 1)")?;
1244 Ok(())
1245 }
1246
1247 fn put_restricted_texel_address(
1248 &mut self,
1249 image: Handle<crate::Expression>,
1250 address: &TexelAddress,
1251 context: &ExpressionContext,
1252 ) -> BackendResult {
1253 write!(self.out, "{NAMESPACE}::min(")?;
1255 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1256 write!(self.out, ", ")?;
1257 self.put_image_coordinate_limits(image, address.level, context)?;
1258 write!(self.out, ")")?;
1259
1260 if let Some(array_index) = address.array_index {
1262 write!(self.out, ", ")?;
1263 self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1264 }
1265
1266 if let Some(sample) = address.sample {
1268 write!(self.out, ", ")?;
1269 self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1270 }
1271
1272 if let Some(level) = address.level {
1275 write!(self.out, ", ")?;
1276 self.put_level_of_detail(level, context)?;
1277 }
1278
1279 Ok(())
1280 }
1281
1282 fn put_image_access_bounds_check(
1284 &mut self,
1285 image: Handle<crate::Expression>,
1286 address: &TexelAddress,
1287 context: &ExpressionContext,
1288 ) -> BackendResult {
1289 let mut conjunction = "";
1290
1291 let level = if let Some(level) = address.level {
1294 write!(self.out, "uint(")?;
1295 self.put_level_of_detail(level, context)?;
1296 write!(self.out, ") < ")?;
1297 self.put_expression(image, context, true)?;
1298 write!(self.out, ".get_num_mip_levels()")?;
1299 conjunction = " && ";
1300 Some(level)
1301 } else {
1302 None
1303 };
1304
1305 if let Some(sample) = address.sample {
1307 write!(self.out, "uint(")?;
1308 self.put_expression(sample, context, true)?;
1309 write!(self.out, ") < ")?;
1310 self.put_expression(image, context, true)?;
1311 write!(self.out, ".get_num_samples()")?;
1312 conjunction = " && ";
1313 }
1314
1315 if let Some(array_index) = address.array_index {
1317 write!(self.out, "{conjunction}uint(")?;
1318 self.put_expression(array_index, context, true)?;
1319 write!(self.out, ") < ")?;
1320 self.put_expression(image, context, true)?;
1321 write!(self.out, ".get_array_size()")?;
1322 conjunction = " && ";
1323 }
1324
1325 let coord_is_vector = match *context.resolve_type(address.coordinate) {
1327 crate::TypeInner::Vector { .. } => true,
1328 _ => false,
1329 };
1330 write!(self.out, "{conjunction}")?;
1331 if coord_is_vector {
1332 write!(self.out, "{NAMESPACE}::all(")?;
1333 }
1334 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1335 write!(self.out, " < ")?;
1336 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1337 if coord_is_vector {
1338 write!(self.out, ")")?;
1339 }
1340
1341 Ok(())
1342 }
1343
1344 fn put_image_load(
1345 &mut self,
1346 load: Handle<crate::Expression>,
1347 image: Handle<crate::Expression>,
1348 mut address: TexelAddress,
1349 context: &ExpressionContext,
1350 ) -> BackendResult {
1351 if let crate::TypeInner::Image {
1352 class: crate::ImageClass::External,
1353 ..
1354 } = *context.resolve_type(image)
1355 {
1356 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1357 self.put_expression(image, context, true)?;
1358 write!(self.out, ", ")?;
1359 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1360 write!(self.out, ")")?;
1361 return Ok(());
1362 }
1363
1364 match context.policies.image_load {
1365 proc::BoundsCheckPolicy::Restrict => {
1366 if address.level.is_some() {
1369 address.level = if context.image_needs_lod(image) {
1370 Some(LevelOfDetail::Restricted(load))
1371 } else {
1372 None
1373 }
1374 }
1375
1376 self.put_expression(image, context, false)?;
1377 write!(self.out, ".read(")?;
1378 self.put_restricted_texel_address(image, &address, context)?;
1379 write!(self.out, ")")?;
1380 }
1381 proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1382 write!(self.out, "(")?;
1383 self.put_image_access_bounds_check(image, &address, context)?;
1384 write!(self.out, " ? ")?;
1385 self.put_unchecked_image_load(image, &address, context)?;
1386 write!(self.out, ": DefaultConstructible())")?;
1387 }
1388 proc::BoundsCheckPolicy::Unchecked => {
1389 self.put_unchecked_image_load(image, &address, context)?;
1390 }
1391 }
1392
1393 Ok(())
1394 }
1395
1396 fn put_unchecked_image_load(
1397 &mut self,
1398 image: Handle<crate::Expression>,
1399 address: &TexelAddress,
1400 context: &ExpressionContext,
1401 ) -> BackendResult {
1402 self.put_expression(image, context, false)?;
1403 write!(self.out, ".read(")?;
1404 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1406 if let Some(expr) = address.array_index {
1407 write!(self.out, ", ")?;
1408 self.put_expression(expr, context, true)?;
1409 }
1410 if let Some(sample) = address.sample {
1411 write!(self.out, ", ")?;
1412 self.put_expression(sample, context, true)?;
1413 }
1414 if let Some(level) = address.level {
1415 if context.image_needs_lod(image) {
1416 write!(self.out, ", ")?;
1417 self.put_level_of_detail(level, context)?;
1418 }
1419 }
1420 write!(self.out, ")")?;
1421
1422 Ok(())
1423 }
1424
1425 fn put_image_atomic(
1426 &mut self,
1427 level: back::Level,
1428 image: Handle<crate::Expression>,
1429 address: &TexelAddress,
1430 fun: crate::AtomicFunction,
1431 value: Handle<crate::Expression>,
1432 context: &StatementContext,
1433 ) -> BackendResult {
1434 write!(self.out, "{level}")?;
1435 self.put_expression(image, &context.expression, false)?;
1436 let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1437 fun.to_msl_64_bit()?
1438 } else {
1439 fun.to_msl()
1440 };
1441 write!(self.out, ".atomic_{op}(")?;
1442 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1444 write!(self.out, ", ")?;
1445 self.put_expression(value, &context.expression, true)?;
1446 writeln!(self.out, ");")?;
1447
1448 Ok(())
1449 }
1450
1451 fn put_image_store(
1452 &mut self,
1453 level: back::Level,
1454 image: Handle<crate::Expression>,
1455 address: &TexelAddress,
1456 value: Handle<crate::Expression>,
1457 context: &StatementContext,
1458 ) -> BackendResult {
1459 write!(self.out, "{level}")?;
1460 self.put_expression(image, &context.expression, false)?;
1461 write!(self.out, ".write(")?;
1462 self.put_expression(value, &context.expression, true)?;
1463 write!(self.out, ", ")?;
1464 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1466 if let Some(expr) = address.array_index {
1467 write!(self.out, ", ")?;
1468 self.put_expression(expr, &context.expression, true)?;
1469 }
1470 writeln!(self.out, ");")?;
1471
1472 Ok(())
1473 }
1474
1475 fn put_dynamic_array_max_index(
1486 &mut self,
1487 handle: Handle<crate::GlobalVariable>,
1488 context: &ExpressionContext,
1489 ) -> BackendResult {
1490 let global = &context.module.global_variables[handle];
1491 let (offset, array_ty) = match context.module.types[global.ty].inner {
1492 crate::TypeInner::Struct { ref members, .. } => match members.last() {
1493 Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1494 None => return Err(Error::GenericValidation("Struct has no members".into())),
1495 },
1496 crate::TypeInner::Array {
1497 size: crate::ArraySize::Dynamic,
1498 ..
1499 } => (0, global.ty),
1500 ref ty => {
1501 return Err(Error::GenericValidation(format!(
1502 "Expected type with dynamic array, got {ty:?}"
1503 )))
1504 }
1505 };
1506
1507 let (size, stride) = match context.module.types[array_ty].inner {
1508 crate::TypeInner::Array { base, stride, .. } => (
1509 context.module.types[base]
1510 .inner
1511 .size(context.module.to_ctx()),
1512 stride,
1513 ),
1514 ref ty => {
1515 return Err(Error::GenericValidation(format!(
1516 "Expected array type, got {ty:?}"
1517 )))
1518 }
1519 };
1520
1521 write!(
1534 self.out,
1535 "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1536 member = ArraySizeMember(handle),
1537 offset = offset,
1538 size = size,
1539 stride = stride,
1540 )?;
1541 Ok(())
1542 }
1543
1544 fn put_dot_product<T: Copy>(
1549 &mut self,
1550 arg: T,
1551 arg1: T,
1552 size: usize,
1553 extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1554 ) -> BackendResult {
1555 write!(self.out, "(")?;
1558
1559 for index in 0..size {
1561 write!(self.out, " + ")?;
1564 extractor(self, arg, index)?;
1565 write!(self.out, " * ")?;
1566 extractor(self, arg1, index)?;
1567 }
1568
1569 write!(self.out, ")")?;
1570 Ok(())
1571 }
1572
1573 fn put_pack4x8(
1575 &mut self,
1576 arg: Handle<crate::Expression>,
1577 context: &ExpressionContext<'_>,
1578 was_signed: bool,
1579 clamp_bounds: Option<(&str, &str)>,
1580 ) -> Result<(), Error> {
1581 let write_arg = |this: &mut Self| -> BackendResult {
1582 if let Some((min, max)) = clamp_bounds {
1583 write!(this.out, "{NAMESPACE}::clamp(")?;
1585 this.put_expression(arg, context, true)?;
1586 write!(this.out, ", {min}, {max})")?;
1587 } else {
1588 this.put_expression(arg, context, true)?;
1589 }
1590 Ok(())
1591 };
1592
1593 if context.lang_version >= (2, 1) {
1594 let packed_type = if was_signed {
1595 "packed_char4"
1596 } else {
1597 "packed_uchar4"
1598 };
1599 write!(self.out, "as_type<uint>({packed_type}(")?;
1601 write_arg(self)?;
1602 write!(self.out, "))")?;
1603 } else {
1604 if was_signed {
1606 write!(self.out, "uint(")?;
1607 }
1608 write!(self.out, "(")?;
1609 write_arg(self)?;
1610 write!(self.out, "[0] & 0xFF) | ((")?;
1611 write_arg(self)?;
1612 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1613 write_arg(self)?;
1614 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1615 write_arg(self)?;
1616 write!(self.out, "[3] & 0xFF) << 24)")?;
1617 if was_signed {
1618 write!(self.out, ")")?;
1619 }
1620 }
1621
1622 Ok(())
1623 }
1624
1625 fn put_isign(
1628 &mut self,
1629 arg: Handle<crate::Expression>,
1630 context: &ExpressionContext,
1631 ) -> BackendResult {
1632 write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1633 let scalar = context
1634 .resolve_type(arg)
1635 .scalar()
1636 .expect("put_isign should only be called for args which have an integer scalar type")
1637 .to_msl_name();
1638 match context.resolve_type(arg) {
1639 &crate::TypeInner::Vector { size, .. } => {
1640 let size = common::vector_size_str(size);
1641 write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1642 }
1643 _ => {
1644 write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1645 }
1646 }
1647 write!(self.out, ", (")?;
1648 self.put_expression(arg, context, true)?;
1649 write!(self.out, " > 0)), {scalar}(0), (")?;
1650 self.put_expression(arg, context, true)?;
1651 write!(self.out, " == 0))")?;
1652 Ok(())
1653 }
1654
    /// Write the constant expression `expr_handle`, which lives in `arena`.
    ///
    /// This is a thin adapter over `put_possibly_const_expression`. The
    /// callback context is the `(module, mod_info)` pair, and operand types are
    /// looked up by indexing `mod_info`. Nested operands are written by
    /// recursing into this function with the same `arena`. The `Constant` case
    /// of `put_possibly_const_expression` calls this with
    /// `&module.global_expressions`.
    fn put_const_expression(
        &mut self,
        expr_handle: Handle<crate::Expression>,
        module: &crate::Module,
        mod_info: &valid::ModuleInfo,
        arena: &crate::Arena<crate::Expression>,
    ) -> BackendResult {
        self.put_possibly_const_expression(
            expr_handle,
            arena,
            module,
            mod_info,
            &(module, mod_info),
            // Operand type lookup: index the module info by expression handle.
            |&(_, mod_info), expr| &mod_info[expr],
            // Operand writer: recurse, keeping the same arena.
            |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
        )
    }
1672
1673 fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1674 match literal {
1675 crate::Literal::F64(_) => {
1676 return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1677 }
1678 crate::Literal::F16(value) => {
1679 if value.is_infinite() {
1680 let sign = if value.is_sign_negative() { "-" } else { "" };
1681 write!(self.out, "{sign}INFINITY")?;
1682 } else if value.is_nan() {
1683 write!(self.out, "NAN")?;
1684 } else {
1685 let suffix = if value.fract() == f16::from_f32(0.0) {
1686 ".0h"
1687 } else {
1688 "h"
1689 };
1690 write!(self.out, "{value}{suffix}")?;
1691 }
1692 }
1693 crate::Literal::F32(value) => {
1694 if value.is_infinite() {
1695 let sign = if value.is_sign_negative() { "-" } else { "" };
1696 write!(self.out, "{sign}INFINITY")?;
1697 } else if value.is_nan() {
1698 write!(self.out, "NAN")?;
1699 } else {
1700 let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1701 write!(self.out, "{value}{suffix}")?;
1702 }
1703 }
1704 crate::Literal::U32(value) => {
1705 write!(self.out, "{value}u")?;
1706 }
1707 crate::Literal::I32(value) => {
1708 if value == i32::MIN {
1713 write!(self.out, "({} - 1)", value + 1)?;
1714 } else {
1715 write!(self.out, "{value}")?;
1716 }
1717 }
1718 crate::Literal::U64(value) => {
1719 write!(self.out, "{value}uL")?;
1720 }
1721 crate::Literal::I64(value) => {
1722 if value == i64::MIN {
1729 write!(self.out, "({}L - 1L)", value + 1)?;
1730 } else {
1731 write!(self.out, "{value}L")?;
1732 }
1733 }
1734 crate::Literal::Bool(value) => {
1735 write!(self.out, "{value}")?;
1736 }
1737 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1738 return Err(Error::GenericValidation(
1739 "Unsupported abstract literal".into(),
1740 ));
1741 }
1742 }
1743 Ok(())
1744 }
1745
    /// Write an expression kind that can appear both in constant and in
    /// function contexts: literals, constants, zero values, composes, and splats.
    ///
    /// The caller abstracts over where the expression lives:
    /// - `expressions` is the arena that holds `expr_handle`;
    /// - `ctx` is handed back to the two callbacks unchanged;
    /// - `get_expr_ty` resolves an operand's type (used by `Splat`);
    /// - `put_expression` writes an operand (used by `Compose` and `Splat`).
    ///
    /// Named constants are written by their assigned name. Unnamed ones are
    /// expanded from their initializer in `module.global_expressions`.
    ///
    /// # Errors
    /// - `UnsupportedCompose` for a compose whose type is not a scalar, vector,
    ///   matrix, array, or struct.
    /// - `GenericValidation` when a splat's value is not a scalar.
    /// - `Override` for any other expression kind.
    // Every argument is a distinct piece of caller-provided context.
    #[allow(clippy::too_many_arguments)]
    fn put_possibly_const_expression<C, I, E>(
        &mut self,
        expr_handle: Handle<crate::Expression>,
        expressions: &crate::Arena<crate::Expression>,
        module: &crate::Module,
        mod_info: &valid::ModuleInfo,
        ctx: &C,
        get_expr_ty: I,
        put_expression: E,
    ) -> BackendResult
    where
        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
    {
        match expressions[expr_handle] {
            crate::Expression::Literal(literal) => {
                self.put_literal(literal)?;
            }
            crate::Expression::Constant(handle) => {
                let constant = &module.constants[handle];
                if constant.name.is_some() {
                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
                } else {
                    // Anonymous constant: inline its initializer expression.
                    self.put_const_expression(
                        constant.init,
                        module,
                        mod_info,
                        &module.global_expressions,
                    )?;
                }
            }
            crate::Expression::ZeroValue(ty) => {
                let ty_name = TypeContext {
                    handle: ty,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                // Empty-brace initialization: `Type {}`.
                write!(self.out, "{ty_name} {{}}")?;
            }
            crate::Expression::Compose { ty, ref components } => {
                let ty_name = TypeContext {
                    handle: ty,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                write!(self.out, "{ty_name}")?;
                match module.types[ty].inner {
                    // Constructor-call syntax; the parameter list is written by
                    // `put_call_parameters_impl`.
                    crate::TypeInner::Scalar(_)
                    | crate::TypeInner::Vector { .. }
                    | crate::TypeInner::Matrix { .. } => {
                        self.put_call_parameters_impl(
                            components.iter().copied(),
                            ctx,
                            put_expression,
                        )?;
                    }
                    // Aggregate brace-initialization: `Type {a, b, ...}`.
                    crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
                        write!(self.out, " {{")?;
                        for (index, &component) in components.iter().enumerate() {
                            if index != 0 {
                                write!(self.out, ", ")?;
                            }
                            // Emit an empty initializer first when `(ty, index)`
                            // is recorded in `struct_member_pads`. Presumably this
                            // fills a padding member the writer declared before
                            // this one; TODO confirm.
                            if self.struct_member_pads.contains(&(ty, index as u32)) {
                                write!(self.out, "{{}}, ")?;
                            }
                            put_expression(self, ctx, component)?;
                        }
                        write!(self.out, "}}")?;
                    }
                    _ => return Err(Error::UnsupportedCompose(ty)),
                }
            }
            crate::Expression::Splat { size, value } => {
                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
                    crate::TypeInner::Scalar(scalar) => scalar,
                    ref ty => {
                        return Err(Error::GenericValidation(format!(
                            "Expected splat value type must be a scalar, got {ty:?}",
                        )))
                    }
                };
                // `vecN(value)` broadcasts the scalar to every component.
                put_numeric_type(&mut self.out, scalar, &[size])?;
                write!(self.out, "(")?;
                put_expression(self, ctx, value)?;
                write!(self.out, ")")?;
            }
            _ => {
                return Err(Error::Override);
            }
        }

        Ok(())
    }
1845
1846 fn put_expression(
1858 &mut self,
1859 expr_handle: Handle<crate::Expression>,
1860 context: &ExpressionContext,
1861 is_scoped: bool,
1862 ) -> BackendResult {
1863 #[cfg(test)]
1865 self.put_expression_stack_pointers
1866 .insert(ptr::from_ref(&expr_handle).cast());
1867
1868 if let Some(name) = self.named_expressions.get(&expr_handle) {
1869 write!(self.out, "{name}")?;
1870 return Ok(());
1871 }
1872
1873 let expression = &context.function.expressions[expr_handle];
1874 match *expression {
1875 crate::Expression::Literal(_)
1876 | crate::Expression::Constant(_)
1877 | crate::Expression::ZeroValue(_)
1878 | crate::Expression::Compose { .. }
1879 | crate::Expression::Splat { .. } => {
1880 self.put_possibly_const_expression(
1881 expr_handle,
1882 &context.function.expressions,
1883 context.module,
1884 context.mod_info,
1885 context,
1886 |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1887 |writer, context, expr| writer.put_expression(expr, context, true),
1888 )?;
1889 }
1890 crate::Expression::Override(_) => return Err(Error::Override),
1891 crate::Expression::Access { base, .. }
1892 | crate::Expression::AccessIndex { base, .. } => {
1893 let policy = context.choose_bounds_check_policy(base);
1899 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1900 && self.put_bounds_checks(
1901 expr_handle,
1902 context,
1903 back::Level(0),
1904 if is_scoped { "" } else { "(" },
1905 )?
1906 {
1907 write!(self.out, " ? ")?;
1908 self.put_access_chain(expr_handle, policy, context)?;
1909 write!(self.out, " : ")?;
1910
1911 if context.resolve_type(base).pointer_space().is_some() {
1912 let result_ty = context.info[expr_handle]
1916 .ty
1917 .inner_with(&context.module.types)
1918 .pointer_base_type();
1919 let result_ty_handle = match result_ty {
1920 Some(TypeResolution::Handle(handle)) => handle,
1921 Some(TypeResolution::Value(_)) => {
1922 unreachable!(
1930 "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
1931 );
1932 }
1933 None => {
1934 unreachable!(
1935 "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
1936 )
1937 }
1938 };
1939 let name_key =
1940 NameKey::oob_local_for_type(context.origin, result_ty_handle);
1941 self.out.write_str(&self.names[&name_key])?;
1942 } else {
1943 write!(self.out, "DefaultConstructible()")?;
1944 }
1945
1946 if !is_scoped {
1947 write!(self.out, ")")?;
1948 }
1949 } else {
1950 self.put_access_chain(expr_handle, policy, context)?;
1951 }
1952 }
1953 crate::Expression::Swizzle {
1954 size,
1955 vector,
1956 pattern,
1957 } => {
1958 self.put_wrapped_expression_for_packed_vec3_access(
1959 vector,
1960 context,
1961 false,
1962 &Self::put_expression,
1963 )?;
1964 write!(self.out, ".")?;
1965 for &sc in pattern[..size as usize].iter() {
1966 write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
1967 }
1968 }
1969 crate::Expression::FunctionArgument(index) => {
1970 let name_key = match context.origin {
1971 FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
1972 FunctionOrigin::EntryPoint(ep_index) => {
1973 NameKey::EntryPointArgument(ep_index, index)
1974 }
1975 };
1976 let name = &self.names[&name_key];
1977 write!(self.out, "{name}")?;
1978 }
1979 crate::Expression::GlobalVariable(handle) => {
1980 let name = &self.names[&NameKey::GlobalVariable(handle)];
1981 write!(self.out, "{name}")?;
1982 }
1983 crate::Expression::LocalVariable(handle) => {
1984 let name_key = NameKey::local(context.origin, handle);
1985 let name = &self.names[&name_key];
1986 write!(self.out, "{name}")?;
1987 }
1988 crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
1989 crate::Expression::ImageSample {
1990 coordinate,
1991 image,
1992 sampler,
1993 clamp_to_edge: true,
1994 gather: None,
1995 array_index: None,
1996 offset: None,
1997 level: crate::SampleLevel::Zero,
1998 depth_ref: None,
1999 } => {
2000 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2001 self.put_expression(image, context, true)?;
2002 write!(self.out, ", ")?;
2003 self.put_expression(sampler, context, true)?;
2004 write!(self.out, ", ")?;
2005 self.put_expression(coordinate, context, true)?;
2006 write!(self.out, ")")?;
2007 }
2008 crate::Expression::ImageSample {
2009 image,
2010 sampler,
2011 gather,
2012 coordinate,
2013 array_index,
2014 offset,
2015 level,
2016 depth_ref,
2017 clamp_to_edge,
2018 } => {
2019 if clamp_to_edge {
2020 return Err(Error::GenericValidation(
2021 "ImageSample::clamp_to_edge should have been validated out".to_string(),
2022 ));
2023 }
2024
2025 let main_op = match gather {
2026 Some(_) => "gather",
2027 None => "sample",
2028 };
2029 let comparison_op = match depth_ref {
2030 Some(_) => "_compare",
2031 None => "",
2032 };
2033 self.put_expression(image, context, false)?;
2034 write!(self.out, ".{main_op}{comparison_op}(")?;
2035 self.put_expression(sampler, context, true)?;
2036 write!(self.out, ", ")?;
2037 self.put_expression(coordinate, context, true)?;
2038 if let Some(expr) = array_index {
2039 write!(self.out, ", ")?;
2040 self.put_expression(expr, context, true)?;
2041 }
2042 if let Some(dref) = depth_ref {
2043 write!(self.out, ", ")?;
2044 self.put_expression(dref, context, true)?;
2045 }
2046
2047 self.put_image_sample_level(image, level, context)?;
2048
2049 if let Some(offset) = offset {
2050 write!(self.out, ", ")?;
2051 self.put_expression(offset, context, true)?;
2052 }
2053
2054 match gather {
2055 None | Some(crate::SwizzleComponent::X) => {}
2056 Some(component) => {
2057 let is_cube_map = match *context.resolve_type(image) {
2058 crate::TypeInner::Image {
2059 dim: crate::ImageDimension::Cube,
2060 ..
2061 } => true,
2062 _ => false,
2063 };
2064 if offset.is_none() && !is_cube_map {
2067 write!(self.out, ", {NAMESPACE}::int2(0)")?;
2068 }
2069 let letter = back::COMPONENTS[component as usize];
2070 write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2071 }
2072 }
2073 write!(self.out, ")")?;
2074 }
2075 crate::Expression::ImageLoad {
2076 image,
2077 coordinate,
2078 array_index,
2079 sample,
2080 level,
2081 } => {
2082 let address = TexelAddress {
2083 coordinate,
2084 array_index,
2085 sample,
2086 level: level.map(LevelOfDetail::Direct),
2087 };
2088 self.put_image_load(expr_handle, image, address, context)?;
2089 }
2090 crate::Expression::ImageQuery { image, query } => match query {
2093 crate::ImageQuery::Size { level } => {
2094 self.put_image_size_query(
2095 image,
2096 level.map(LevelOfDetail::Direct),
2097 crate::ScalarKind::Uint,
2098 context,
2099 )?;
2100 }
2101 crate::ImageQuery::NumLevels => {
2102 self.put_expression(image, context, false)?;
2103 write!(self.out, ".get_num_mip_levels()")?;
2104 }
2105 crate::ImageQuery::NumLayers => {
2106 self.put_expression(image, context, false)?;
2107 write!(self.out, ".get_array_size()")?;
2108 }
2109 crate::ImageQuery::NumSamples => {
2110 self.put_expression(image, context, false)?;
2111 write!(self.out, ".get_num_samples()")?;
2112 }
2113 },
2114 crate::Expression::Unary { op, expr } => {
2115 let op_str = match op {
2116 crate::UnaryOperator::Negate => {
2117 match context.resolve_type(expr).scalar_kind() {
2118 Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2119 _ => "-",
2120 }
2121 }
2122 crate::UnaryOperator::LogicalNot => "!",
2123 crate::UnaryOperator::BitwiseNot => "~",
2124 };
2125 write!(self.out, "{op_str}(")?;
2126 self.put_expression(expr, context, false)?;
2127 write!(self.out, ")")?;
2128 }
2129 crate::Expression::Binary { op, left, right } => {
2130 let kind = context
2131 .resolve_type(left)
2132 .scalar_kind()
2133 .ok_or(Error::UnsupportedBinaryOp(op))?;
2134
2135 if op == crate::BinaryOperator::Divide
2136 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2137 {
2138 write!(self.out, "{DIV_FUNCTION}(")?;
2139 self.put_expression(left, context, true)?;
2140 write!(self.out, ", ")?;
2141 self.put_expression(right, context, true)?;
2142 write!(self.out, ")")?;
2143 } else if op == crate::BinaryOperator::Modulo
2144 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2145 {
2146 write!(self.out, "{MOD_FUNCTION}(")?;
2147 self.put_expression(left, context, true)?;
2148 write!(self.out, ", ")?;
2149 self.put_expression(right, context, true)?;
2150 write!(self.out, ")")?;
2151 } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2152 write!(self.out, "{NAMESPACE}::fmod(")?;
2157 self.put_expression(left, context, true)?;
2158 write!(self.out, ", ")?;
2159 self.put_expression(right, context, true)?;
2160 write!(self.out, ")")?;
2161 } else if (op == crate::BinaryOperator::Add
2162 || op == crate::BinaryOperator::Subtract
2163 || op == crate::BinaryOperator::Multiply)
2164 && kind == crate::ScalarKind::Sint
2165 {
2166 let to_unsigned = |ty: &crate::TypeInner| match *ty {
2167 crate::TypeInner::Scalar(scalar) => {
2168 Ok(crate::TypeInner::Scalar(crate::Scalar {
2169 kind: crate::ScalarKind::Uint,
2170 ..scalar
2171 }))
2172 }
2173 crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2174 size,
2175 scalar: crate::Scalar {
2176 kind: crate::ScalarKind::Uint,
2177 ..scalar
2178 },
2179 }),
2180 _ => Err(Error::UnsupportedBitCast(ty.clone())),
2181 };
2182
2183 self.put_bitcasted_expression(
2188 context.resolve_type(expr_handle),
2189 context,
2190 &|writer, context, is_scoped| {
2191 writer.put_binop(
2192 op,
2193 left,
2194 right,
2195 context,
2196 is_scoped,
2197 &|writer, expr, context, _is_scoped| {
2198 writer.put_bitcasted_expression(
2199 &to_unsigned(context.resolve_type(expr))?,
2200 context,
2201 &|writer, context, is_scoped| {
2202 writer.put_expression(expr, context, is_scoped)
2203 },
2204 )
2205 },
2206 )
2207 },
2208 )?;
2209 } else {
2210 self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2211 }
2212 }
2213 crate::Expression::Select {
2214 condition,
2215 accept,
2216 reject,
2217 } => match *context.resolve_type(condition) {
2218 crate::TypeInner::Scalar(crate::Scalar {
2219 kind: crate::ScalarKind::Bool,
2220 ..
2221 }) => {
2222 if !is_scoped {
2223 write!(self.out, "(")?;
2224 }
2225 self.put_expression(condition, context, false)?;
2226 write!(self.out, " ? ")?;
2227 self.put_expression(accept, context, false)?;
2228 write!(self.out, " : ")?;
2229 self.put_expression(reject, context, false)?;
2230 if !is_scoped {
2231 write!(self.out, ")")?;
2232 }
2233 }
2234 crate::TypeInner::Vector {
2235 scalar:
2236 crate::Scalar {
2237 kind: crate::ScalarKind::Bool,
2238 ..
2239 },
2240 ..
2241 } => {
2242 write!(self.out, "{NAMESPACE}::select(")?;
2243 self.put_expression(reject, context, true)?;
2244 write!(self.out, ", ")?;
2245 self.put_expression(accept, context, true)?;
2246 write!(self.out, ", ")?;
2247 self.put_expression(condition, context, true)?;
2248 write!(self.out, ")")?;
2249 }
2250 ref ty => {
2251 return Err(Error::GenericValidation(format!(
2252 "Expected select condition to be a non-bool type, got {ty:?}",
2253 )))
2254 }
2255 },
2256 crate::Expression::Derivative { axis, expr, .. } => {
2257 use crate::DerivativeAxis as Axis;
2258 let op = match axis {
2259 Axis::X => "dfdx",
2260 Axis::Y => "dfdy",
2261 Axis::Width => "fwidth",
2262 };
2263 write!(self.out, "{NAMESPACE}::{op}")?;
2264 self.put_call_parameters(iter::once(expr), context)?;
2265 }
2266 crate::Expression::Relational { fun, argument } => {
2267 let op = match fun {
2268 crate::RelationalFunction::Any => "any",
2269 crate::RelationalFunction::All => "all",
2270 crate::RelationalFunction::IsNan => "isnan",
2271 crate::RelationalFunction::IsInf => "isinf",
2272 };
2273 write!(self.out, "{NAMESPACE}::{op}")?;
2274 self.put_call_parameters(iter::once(argument), context)?;
2275 }
2276 crate::Expression::Math {
2277 fun,
2278 arg,
2279 arg1,
2280 arg2,
2281 arg3,
2282 } => {
2283 use crate::MathFunction as Mf;
2284
2285 let arg_type = context.resolve_type(arg);
2286 let scalar_argument = match arg_type {
2287 &crate::TypeInner::Scalar(_) => true,
2288 _ => false,
2289 };
2290
2291 let fun_name = match fun {
2292 Mf::Abs => "abs",
2294 Mf::Min => "min",
2295 Mf::Max => "max",
2296 Mf::Clamp => "clamp",
2297 Mf::Saturate => "saturate",
2298 Mf::Cos => "cos",
2300 Mf::Cosh => "cosh",
2301 Mf::Sin => "sin",
2302 Mf::Sinh => "sinh",
2303 Mf::Tan => "tan",
2304 Mf::Tanh => "tanh",
2305 Mf::Acos => "acos",
2306 Mf::Asin => "asin",
2307 Mf::Atan => "atan",
2308 Mf::Atan2 => "atan2",
2309 Mf::Asinh => "asinh",
2310 Mf::Acosh => "acosh",
2311 Mf::Atanh => "atanh",
2312 Mf::Radians => "",
2313 Mf::Degrees => "",
2314 Mf::Ceil => "ceil",
2316 Mf::Floor => "floor",
2317 Mf::Round => "rint",
2318 Mf::Fract => "fract",
2319 Mf::Trunc => "trunc",
2320 Mf::Modf => MODF_FUNCTION,
2321 Mf::Frexp => FREXP_FUNCTION,
2322 Mf::Ldexp => "ldexp",
2323 Mf::Exp => "exp",
2325 Mf::Exp2 => "exp2",
2326 Mf::Log => "log",
2327 Mf::Log2 => "log2",
2328 Mf::Pow => "pow",
2329 Mf::Dot => match *context.resolve_type(arg) {
2331 crate::TypeInner::Vector {
2332 scalar:
2333 crate::Scalar {
2334 kind: crate::ScalarKind::Float,
2335 ..
2336 },
2337 ..
2338 } => "dot",
2339 crate::TypeInner::Vector { size, .. } => {
2340 return self.put_dot_product(
2341 arg,
2342 arg1.unwrap(),
2343 size as usize,
2344 |writer, arg, index| {
2345 writer.put_expression(arg, context, true)?;
2349 write!(writer.out, ".{}", back::COMPONENTS[index])?;
2351 Ok(())
2352 },
2353 );
2354 }
2355 _ => unreachable!(
2356 "Correct TypeInner for dot product should be already validated"
2357 ),
2358 },
2359 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2360 if context.lang_version >= (2, 1) {
2361 let packed_type = match fun {
2365 Mf::Dot4I8Packed => "packed_char4",
2366 Mf::Dot4U8Packed => "packed_uchar4",
2367 _ => unreachable!(),
2368 };
2369
2370 return self.put_dot_product(
2371 Reinterpreted::new(packed_type, arg),
2372 Reinterpreted::new(packed_type, arg1.unwrap()),
2373 4,
2374 |writer, arg, index| {
2375 write!(writer.out, "{arg}[{index}]")?;
2378 Ok(())
2379 },
2380 );
2381 } else {
2382 let conversion = match fun {
2386 Mf::Dot4I8Packed => "int",
2387 Mf::Dot4U8Packed => "",
2388 _ => unreachable!(),
2389 };
2390
2391 return self.put_dot_product(
2392 arg,
2393 arg1.unwrap(),
2394 4,
2395 |writer, arg, index| {
2396 write!(writer.out, "({conversion}(")?;
2397 writer.put_expression(arg, context, true)?;
2398 if index == 3 {
2399 write!(writer.out, ") >> 24)")?;
2400 } else {
2401 write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2402 }
2403 Ok(())
2404 },
2405 );
2406 }
2407 }
2408 Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2409 Mf::Cross => "cross",
2410 Mf::Distance => "distance",
2411 Mf::Length if scalar_argument => "abs",
2412 Mf::Length => "length",
2413 Mf::Normalize => "normalize",
2414 Mf::FaceForward => "faceforward",
2415 Mf::Reflect => "reflect",
2416 Mf::Refract => "refract",
2417 Mf::Sign => match arg_type.scalar_kind() {
2419 Some(crate::ScalarKind::Sint) => {
2420 return self.put_isign(arg, context);
2421 }
2422 _ => "sign",
2423 },
2424 Mf::Fma => "fma",
2425 Mf::Mix => "mix",
2426 Mf::Step => "step",
2427 Mf::SmoothStep => "smoothstep",
2428 Mf::Sqrt => "sqrt",
2429 Mf::InverseSqrt => "rsqrt",
2430 Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2431 Mf::Transpose => "transpose",
2432 Mf::Determinant => "determinant",
2433 Mf::QuantizeToF16 => "",
2434 Mf::CountTrailingZeros => "ctz",
2436 Mf::CountLeadingZeros => "clz",
2437 Mf::CountOneBits => "popcount",
2438 Mf::ReverseBits => "reverse_bits",
2439 Mf::ExtractBits => "",
2440 Mf::InsertBits => "",
2441 Mf::FirstTrailingBit => "",
2442 Mf::FirstLeadingBit => "",
2443 Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2445 Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2446 Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2447 Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2448 Mf::Pack2x16float => "",
2449 Mf::Pack4xI8 => "",
2450 Mf::Pack4xU8 => "",
2451 Mf::Pack4xI8Clamp => "",
2452 Mf::Pack4xU8Clamp => "",
2453 Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2455 Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2456 Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2457 Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2458 Mf::Unpack2x16float => "",
2459 Mf::Unpack4xI8 => "",
2460 Mf::Unpack4xU8 => "",
2461 };
2462
2463 match fun {
2464 Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2465 if context.lang_version < (1, 2) {
2474 return Err(Error::UnsupportedFunction(fun_name.to_string()));
2475 }
2476 }
2477 _ => {}
2478 }
2479
2480 match fun {
2481 Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2482 write!(self.out, "{ABS_FUNCTION}(")?;
2483 self.put_expression(arg, context, true)?;
2484 write!(self.out, ")")?;
2485 }
2486 Mf::Distance if scalar_argument => {
2487 write!(self.out, "{NAMESPACE}::abs(")?;
2488 self.put_expression(arg, context, false)?;
2489 write!(self.out, " - ")?;
2490 self.put_expression(arg1.unwrap(), context, false)?;
2491 write!(self.out, ")")?;
2492 }
2493 Mf::FirstTrailingBit => {
2494 let scalar = context.resolve_type(arg).scalar().unwrap();
2495 let constant = scalar.width * 8 + 1;
2496
2497 write!(self.out, "((({NAMESPACE}::ctz(")?;
2498 self.put_expression(arg, context, true)?;
2499 write!(self.out, ") + 1) % {constant}) - 1)")?;
2500 }
2501 Mf::FirstLeadingBit => {
2502 let inner = context.resolve_type(arg);
2503 let scalar = inner.scalar().unwrap();
2504 let constant = scalar.width * 8 - 1;
2505
2506 write!(
2507 self.out,
2508 "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2509 )?;
2510
2511 if scalar.kind == crate::ScalarKind::Sint {
2512 write!(self.out, "{NAMESPACE}::select(")?;
2513 self.put_expression(arg, context, true)?;
2514 write!(self.out, ", ~")?;
2515 self.put_expression(arg, context, true)?;
2516 write!(self.out, ", ")?;
2517 self.put_expression(arg, context, true)?;
2518 write!(self.out, " < 0)")?;
2519 } else {
2520 self.put_expression(arg, context, true)?;
2521 }
2522
2523 write!(self.out, "), ")?;
2524
2525 match *inner {
2527 crate::TypeInner::Vector { size, scalar } => {
2528 let size = common::vector_size_str(size);
2529 let name = scalar.to_msl_name();
2530 write!(self.out, "{name}{size}")?;
2531 }
2532 crate::TypeInner::Scalar(scalar) => {
2533 let name = scalar.to_msl_name();
2534 write!(self.out, "{name}")?;
2535 }
2536 _ => (),
2537 }
2538
2539 write!(self.out, "(-1), ")?;
2540 self.put_expression(arg, context, true)?;
2541 write!(self.out, " == 0 || ")?;
2542 self.put_expression(arg, context, true)?;
2543 write!(self.out, " == -1)")?;
2544 }
2545 Mf::Unpack2x16float => {
2546 write!(self.out, "float2(as_type<half2>(")?;
2547 self.put_expression(arg, context, false)?;
2548 write!(self.out, "))")?;
2549 }
2550 Mf::Pack2x16float => {
2551 write!(self.out, "as_type<uint>(half2(")?;
2552 self.put_expression(arg, context, false)?;
2553 write!(self.out, "))")?;
2554 }
2555 Mf::ExtractBits => {
2556 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2573
2574 write!(self.out, "{NAMESPACE}::extract_bits(")?;
2575 self.put_expression(arg, context, true)?;
2576 write!(self.out, ", {NAMESPACE}::min(")?;
2577 self.put_expression(arg1.unwrap(), context, true)?;
2578 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2579 self.put_expression(arg2.unwrap(), context, true)?;
2580 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2581 self.put_expression(arg1.unwrap(), context, true)?;
2582 write!(self.out, ", {scalar_bits}u)))")?;
2583 }
2584 Mf::InsertBits => {
2585 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2590
2591 write!(self.out, "{NAMESPACE}::insert_bits(")?;
2592 self.put_expression(arg, context, true)?;
2593 write!(self.out, ", ")?;
2594 self.put_expression(arg1.unwrap(), context, true)?;
2595 write!(self.out, ", {NAMESPACE}::min(")?;
2596 self.put_expression(arg2.unwrap(), context, true)?;
2597 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2598 self.put_expression(arg3.unwrap(), context, true)?;
2599 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2600 self.put_expression(arg2.unwrap(), context, true)?;
2601 write!(self.out, ", {scalar_bits}u)))")?;
2602 }
2603 Mf::Radians => {
2604 write!(self.out, "((")?;
2605 self.put_expression(arg, context, false)?;
2606 write!(self.out, ") * 0.017453292519943295474)")?;
2607 }
2608 Mf::Degrees => {
2609 write!(self.out, "((")?;
2610 self.put_expression(arg, context, false)?;
2611 write!(self.out, ") * 57.295779513082322865)")?;
2612 }
2613 Mf::Modf | Mf::Frexp => {
2614 write!(self.out, "{fun_name}")?;
2615 self.put_call_parameters(iter::once(arg), context)?;
2616 }
2617 Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2618 Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2619 Mf::Pack4xI8Clamp => {
2620 self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2621 }
2622 Mf::Pack4xU8Clamp => {
2623 self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2624 }
2625 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2626 let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2627 "u"
2628 } else {
2629 ""
2630 };
2631
2632 if context.lang_version >= (2, 1) {
2633 write!(
2635 self.out,
2636 "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2637 )?;
2638 self.put_expression(arg, context, true)?;
2639 write!(self.out, "))")?;
2640 } else {
2641 write!(self.out, "({sign_prefix}int4(")?;
2643 self.put_expression(arg, context, true)?;
2644 write!(self.out, ", ")?;
2645 self.put_expression(arg, context, true)?;
2646 write!(self.out, " >> 8, ")?;
2647 self.put_expression(arg, context, true)?;
2648 write!(self.out, " >> 16, ")?;
2649 self.put_expression(arg, context, true)?;
2650 write!(self.out, " >> 24) << 24 >> 24)")?;
2651 }
2652 }
2653 Mf::QuantizeToF16 => {
2654 match *context.resolve_type(arg) {
2655 crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2656 crate::TypeInner::Vector { size, .. } => write!(
2657 self.out,
2658 "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2659 size = common::vector_size_str(size),
2660 )?,
2661 _ => unreachable!(
2662 "Correct TypeInner for QuantizeToF16 should be already validated"
2663 ),
2664 };
2665
2666 self.put_expression(arg, context, true)?;
2667 write!(self.out, "))")?;
2668 }
2669 _ => {
2670 write!(self.out, "{NAMESPACE}::{fun_name}")?;
2671 self.put_call_parameters(
2672 iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2673 context,
2674 )?;
2675 }
2676 }
2677 }
2678 crate::Expression::As {
2679 expr,
2680 kind,
2681 convert,
2682 } => match *context.resolve_type(expr) {
2683 crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2684 if src.kind == crate::ScalarKind::Float
2685 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2686 && convert.is_some()
2687 {
2688 let fun_name = match (kind, convert) {
2692 (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2693 (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2694 (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2695 (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2696 _ => unreachable!(),
2697 };
2698 write!(self.out, "{fun_name}(")?;
2699 self.put_expression(expr, context, true)?;
2700 write!(self.out, ")")?;
2701 } else {
2702 let target_scalar = crate::Scalar {
2703 kind,
2704 width: convert.unwrap_or(src.width),
2705 };
2706 let op = match convert {
2707 Some(_) => "static_cast",
2708 None => "as_type",
2709 };
2710 write!(self.out, "{op}<")?;
2711 match *context.resolve_type(expr) {
2712 crate::TypeInner::Vector { size, .. } => {
2713 put_numeric_type(&mut self.out, target_scalar, &[size])?
2714 }
2715 _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2716 };
2717 write!(self.out, ">(")?;
2718 self.put_expression(expr, context, true)?;
2719 write!(self.out, ")")?;
2720 }
2721 }
2722 crate::TypeInner::Matrix {
2723 columns,
2724 rows,
2725 scalar,
2726 } => {
2727 let target_scalar = crate::Scalar {
2728 kind,
2729 width: convert.unwrap_or(scalar.width),
2730 };
2731 put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2732 write!(self.out, "(")?;
2733 self.put_expression(expr, context, true)?;
2734 write!(self.out, ")")?;
2735 }
2736 ref ty => {
2737 return Err(Error::GenericValidation(format!(
2738 "Unsupported type for As: {ty:?}"
2739 )))
2740 }
2741 },
2742 crate::Expression::CallResult(_)
2744 | crate::Expression::AtomicResult { .. }
2745 | crate::Expression::WorkGroupUniformLoadResult { .. }
2746 | crate::Expression::SubgroupBallotResult
2747 | crate::Expression::SubgroupOperationResult { .. }
2748 | crate::Expression::RayQueryProceedResult => {
2749 unreachable!()
2750 }
2751 crate::Expression::ArrayLength(expr) => {
2752 let global = match context.function.expressions[expr] {
2754 crate::Expression::AccessIndex { base, .. } => {
2755 match context.function.expressions[base] {
2756 crate::Expression::GlobalVariable(handle) => handle,
2757 ref ex => {
2758 return Err(Error::GenericValidation(format!(
2759 "Expected global variable in AccessIndex, got {ex:?}"
2760 )))
2761 }
2762 }
2763 }
2764 crate::Expression::GlobalVariable(handle) => handle,
2765 ref ex => {
2766 return Err(Error::GenericValidation(format!(
2767 "Unexpected expression in ArrayLength, got {ex:?}"
2768 )))
2769 }
2770 };
2771
2772 if !is_scoped {
2773 write!(self.out, "(")?;
2774 }
2775 write!(self.out, "1 + ")?;
2776 self.put_dynamic_array_max_index(global, context)?;
2777 if !is_scoped {
2778 write!(self.out, ")")?;
2779 }
2780 }
2781 crate::Expression::RayQueryVertexPositions { .. } => {
2782 unimplemented!()
2783 }
2784 crate::Expression::RayQueryGetIntersection {
2785 query,
2786 committed: _,
2787 } => {
2788 if context.lang_version < (2, 4) {
2789 return Err(Error::UnsupportedRayTracing);
2790 }
2791
2792 let ty = context.module.special_types.ray_intersection.unwrap();
2793 let type_name = &self.names[&NameKey::Type(ty)];
2794 write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2795 self.put_expression(query, context, true)?;
2796 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2797 let fields = [
2798 "distance",
2799 "user_instance_id", "instance_id",
2801 "", "geometry_id",
2803 "primitive_id",
2804 "triangle_barycentric_coord",
2805 "triangle_front_facing",
2806 "", "object_to_world_transform", "world_to_object_transform", ];
2810 for field in fields {
2811 write!(self.out, ", ")?;
2812 if field.is_empty() {
2813 write!(self.out, "{{}}")?;
2814 } else {
2815 self.put_expression(query, context, true)?;
2816 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2817 }
2818 }
2819 write!(self.out, "}}")?;
2820 }
2821 }
2822 Ok(())
2823 }
2824
2825 fn put_binop<F>(
2828 &mut self,
2829 op: crate::BinaryOperator,
2830 left: Handle<crate::Expression>,
2831 right: Handle<crate::Expression>,
2832 context: &ExpressionContext,
2833 is_scoped: bool,
2834 put_expression: &F,
2835 ) -> BackendResult
2836 where
2837 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2838 {
2839 let op_str = back::binary_operation_str(op);
2840
2841 if !is_scoped {
2842 write!(self.out, "(")?;
2843 }
2844
2845 if op == crate::BinaryOperator::Multiply
2848 && matches!(
2849 context.resolve_type(right),
2850 &crate::TypeInner::Matrix { .. }
2851 )
2852 {
2853 self.put_wrapped_expression_for_packed_vec3_access(
2854 left,
2855 context,
2856 false,
2857 put_expression,
2858 )?;
2859 } else {
2860 put_expression(self, left, context, false)?;
2861 }
2862
2863 write!(self.out, " {op_str} ")?;
2864
2865 if op == crate::BinaryOperator::Multiply
2867 && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
2868 {
2869 self.put_wrapped_expression_for_packed_vec3_access(
2870 right,
2871 context,
2872 false,
2873 put_expression,
2874 )?;
2875 } else {
2876 put_expression(self, right, context, false)?;
2877 }
2878
2879 if !is_scoped {
2880 write!(self.out, ")")?;
2881 }
2882
2883 Ok(())
2884 }
2885
2886 fn put_wrapped_expression_for_packed_vec3_access<F>(
2888 &mut self,
2889 expr_handle: Handle<crate::Expression>,
2890 context: &ExpressionContext,
2891 is_scoped: bool,
2892 put_expression: &F,
2893 ) -> BackendResult
2894 where
2895 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2896 {
2897 if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
2898 write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
2899 put_expression(self, expr_handle, context, is_scoped)?;
2900 write!(self.out, ")")?;
2901 } else {
2902 put_expression(self, expr_handle, context, is_scoped)?;
2903 }
2904 Ok(())
2905 }
2906
2907 fn put_bitcasted_expression<F>(
2910 &mut self,
2911 cast_to: &crate::TypeInner,
2912 context: &ExpressionContext,
2913 put_expression: &F,
2914 ) -> BackendResult
2915 where
2916 F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
2917 {
2918 write!(self.out, "as_type<")?;
2919 match *cast_to {
2920 crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
2921 crate::TypeInner::Vector { size, scalar } => {
2922 put_numeric_type(&mut self.out, scalar, &[size])?
2923 }
2924 _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
2925 };
2926 write!(self.out, ">(")?;
2927 put_expression(self, context, true)?;
2928 write!(self.out, ")")?;
2929
2930 Ok(())
2931 }
2932
2933 fn put_index(
2935 &mut self,
2936 index: index::GuardedIndex,
2937 context: &ExpressionContext,
2938 is_scoped: bool,
2939 ) -> BackendResult {
2940 match index {
2941 index::GuardedIndex::Expression(expr) => {
2942 self.put_expression(expr, context, is_scoped)?
2943 }
2944 index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
2945 }
2946 Ok(())
2947 }
2948
2949 #[allow(unused_variables)]
2979 fn put_bounds_checks(
2980 &mut self,
2981 chain: Handle<crate::Expression>,
2982 context: &ExpressionContext,
2983 level: back::Level,
2984 prefix: &'static str,
2985 ) -> Result<bool, Error> {
2986 let mut check_written = false;
2987
2988 for item in context.bounds_check_iter(chain) {
2990 let BoundsCheck {
2991 base,
2992 index,
2993 length,
2994 } = item;
2995
2996 if check_written {
2997 write!(self.out, " && ")?;
2998 } else {
2999 write!(self.out, "{level}{prefix}")?;
3000 check_written = true;
3001 }
3002
3003 write!(self.out, "uint(")?;
3007 self.put_index(index, context, true)?;
3008 self.out.write_str(") < ")?;
3009 match length {
3010 index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3011 index::IndexableLength::Dynamic => {
3012 let global = context.function.originating_global(base).ok_or_else(|| {
3013 Error::GenericValidation("Could not find originating global".into())
3014 })?;
3015 write!(self.out, "1 + ")?;
3016 self.put_dynamic_array_max_index(global, context)?
3017 }
3018 }
3019 }
3020
3021 Ok(check_written)
3022 }
3023
3024 fn put_access_chain(
3044 &mut self,
3045 chain: Handle<crate::Expression>,
3046 policy: index::BoundsCheckPolicy,
3047 context: &ExpressionContext,
3048 ) -> BackendResult {
3049 match context.function.expressions[chain] {
3050 crate::Expression::Access { base, index } => {
3051 let mut base_ty = context.resolve_type(base);
3052
3053 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3055 base_ty = &context.module.types[base].inner;
3056 }
3057
3058 self.put_subscripted_access_chain(
3059 base,
3060 base_ty,
3061 index::GuardedIndex::Expression(index),
3062 policy,
3063 context,
3064 )?;
3065 }
3066 crate::Expression::AccessIndex { base, index } => {
3067 let base_resolution = &context.info[base].ty;
3068 let mut base_ty = base_resolution.inner_with(&context.module.types);
3069 let mut base_ty_handle = base_resolution.handle();
3070
3071 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3073 base_ty = &context.module.types[base].inner;
3074 base_ty_handle = Some(base);
3075 }
3076
3077 match *base_ty {
3081 crate::TypeInner::Struct { .. } => {
3082 let base_ty = base_ty_handle.unwrap();
3083 self.put_access_chain(base, policy, context)?;
3084 let name = &self.names[&NameKey::StructMember(base_ty, index)];
3085 write!(self.out, ".{name}")?;
3086 }
3087 crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3088 self.put_access_chain(base, policy, context)?;
3089 if context.get_packed_vec_kind(base).is_some() {
3092 write!(self.out, "[{index}]")?;
3093 } else {
3094 write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3095 }
3096 }
3097 _ => {
3098 self.put_subscripted_access_chain(
3099 base,
3100 base_ty,
3101 index::GuardedIndex::Known(index),
3102 policy,
3103 context,
3104 )?;
3105 }
3106 }
3107 }
3108 _ => self.put_expression(chain, context, false)?,
3109 }
3110
3111 Ok(())
3112 }
3113
3114 fn put_subscripted_access_chain(
3131 &mut self,
3132 base: Handle<crate::Expression>,
3133 base_ty: &crate::TypeInner,
3134 index: index::GuardedIndex,
3135 policy: index::BoundsCheckPolicy,
3136 context: &ExpressionContext,
3137 ) -> BackendResult {
3138 let accessing_wrapped_array = match *base_ty {
3139 crate::TypeInner::Array {
3140 size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3141 ..
3142 } => true,
3143 _ => false,
3144 };
3145 let accessing_wrapped_binding_array =
3146 matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3147
3148 self.put_access_chain(base, policy, context)?;
3149 if accessing_wrapped_array {
3150 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3151 }
3152 write!(self.out, "[")?;
3153
3154 let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3156 context.access_needs_check(base, index)
3157 } else {
3158 None
3159 };
3160 if let Some(limit) = restriction_needed {
3161 write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3162 self.put_index(index, context, true)?;
3163 write!(self.out, "), ")?;
3164 match limit {
3165 index::IndexableLength::Known(limit) => {
3166 write!(self.out, "{}u", limit - 1)?;
3167 }
3168 index::IndexableLength::Dynamic => {
3169 let global = context.function.originating_global(base).ok_or_else(|| {
3170 Error::GenericValidation("Could not find originating global".into())
3171 })?;
3172 self.put_dynamic_array_max_index(global, context)?;
3173 }
3174 }
3175 write!(self.out, ")")?;
3176 } else {
3177 self.put_index(index, context, true)?;
3178 }
3179
3180 write!(self.out, "]")?;
3181
3182 if accessing_wrapped_binding_array {
3183 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3184 }
3185
3186 Ok(())
3187 }
3188
3189 fn put_load(
3190 &mut self,
3191 pointer: Handle<crate::Expression>,
3192 context: &ExpressionContext,
3193 is_scoped: bool,
3194 ) -> BackendResult {
3195 let policy = context.choose_bounds_check_policy(pointer);
3198 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3199 && self.put_bounds_checks(
3200 pointer,
3201 context,
3202 back::Level(0),
3203 if is_scoped { "" } else { "(" },
3204 )?
3205 {
3206 write!(self.out, " ? ")?;
3207 self.put_unchecked_load(pointer, policy, context)?;
3208 write!(self.out, " : DefaultConstructible()")?;
3209
3210 if !is_scoped {
3211 write!(self.out, ")")?;
3212 }
3213 } else {
3214 self.put_unchecked_load(pointer, policy, context)?;
3215 }
3216
3217 Ok(())
3218 }
3219
3220 fn put_unchecked_load(
3221 &mut self,
3222 pointer: Handle<crate::Expression>,
3223 policy: index::BoundsCheckPolicy,
3224 context: &ExpressionContext,
3225 ) -> BackendResult {
3226 let is_atomic_pointer = context
3227 .resolve_type(pointer)
3228 .is_atomic_pointer(&context.module.types);
3229
3230 if is_atomic_pointer {
3231 write!(
3232 self.out,
3233 "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3234 )?;
3235 self.put_access_chain(pointer, policy, context)?;
3236 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3237 } else {
3238 self.put_access_chain(pointer, policy, context)?;
3242 }
3243
3244 Ok(())
3245 }
3246
3247 fn put_return_value(
3248 &mut self,
3249 level: back::Level,
3250 expr_handle: Handle<crate::Expression>,
3251 result_struct: Option<&str>,
3252 context: &ExpressionContext,
3253 ) -> BackendResult {
3254 match result_struct {
3255 Some(struct_name) => {
3256 let mut has_point_size = false;
3257 let result_ty = context.function.result.as_ref().unwrap().ty;
3258 match context.module.types[result_ty].inner {
3259 crate::TypeInner::Struct { ref members, .. } => {
3260 let tmp = "_tmp";
3261 write!(self.out, "{level}const auto {tmp} = ")?;
3262 self.put_expression(expr_handle, context, true)?;
3263 writeln!(self.out, ";")?;
3264 write!(self.out, "{level}return {struct_name} {{")?;
3265
3266 let mut is_first = true;
3267
3268 for (index, member) in members.iter().enumerate() {
3269 if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3270 member.binding
3271 {
3272 has_point_size = true;
3273 if !context.pipeline_options.allow_and_force_point_size {
3274 continue;
3275 }
3276 }
3277
3278 let comma = if is_first { "" } else { "," };
3279 is_first = false;
3280 let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3281 if let crate::TypeInner::Array {
3285 size: crate::ArraySize::Constant(size),
3286 ..
3287 } = context.module.types[member.ty].inner
3288 {
3289 write!(self.out, "{comma} {{")?;
3290 for j in 0..size.get() {
3291 if j != 0 {
3292 write!(self.out, ",")?;
3293 }
3294 write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3295 }
3296 write!(self.out, "}}")?;
3297 } else {
3298 write!(self.out, "{comma} {tmp}.{name}")?;
3299 }
3300 }
3301 }
3302 _ => {
3303 write!(self.out, "{level}return {struct_name} {{ ")?;
3304 self.put_expression(expr_handle, context, true)?;
3305 }
3306 }
3307
3308 if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3309 let stage = context.module.entry_points[ep_index as usize].stage;
3310 if context.pipeline_options.allow_and_force_point_size
3311 && stage == crate::ShaderStage::Vertex
3312 && !has_point_size
3313 {
3314 write!(self.out, ", 1.0")?;
3316 }
3317 }
3318 write!(self.out, " }}")?;
3319 }
3320 None => {
3321 write!(self.out, "{level}return ")?;
3322 self.put_expression(expr_handle, context, true)?;
3323 }
3324 }
3325 writeln!(self.out, ";")?;
3326 Ok(())
3327 }
3328
3329 fn update_expressions_to_bake(
3334 &mut self,
3335 func: &crate::Function,
3336 info: &valid::FunctionInfo,
3337 context: &ExpressionContext,
3338 ) {
3339 use crate::Expression;
3340 self.need_bake_expressions.clear();
3341
3342 for (expr_handle, expr) in func.expressions.iter() {
3343 let expr_info = &info[expr_handle];
3346 let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3347 if min_ref_count <= expr_info.ref_count {
3348 self.need_bake_expressions.insert(expr_handle);
3349 } else {
3350 match expr_info.ty {
3351 TypeResolution::Handle(h)
3353 if Some(h) == context.module.special_types.ray_desc =>
3354 {
3355 self.need_bake_expressions.insert(expr_handle);
3356 }
3357 _ => {}
3358 }
3359 }
3360
3361 if let Expression::Math {
3362 fun,
3363 arg,
3364 arg1,
3365 arg2,
3366 ..
3367 } = *expr
3368 {
3369 match fun {
3370 crate::MathFunction::Dot => {
3371 let inner = context.resolve_type(expr_handle);
3380 if let crate::TypeInner::Scalar(scalar) = *inner {
3381 match scalar.kind {
3382 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3383 self.need_bake_expressions.insert(arg);
3384 self.need_bake_expressions.insert(arg1.unwrap());
3385 }
3386 _ => {}
3387 }
3388 }
3389 }
3390 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3391 self.need_bake_expressions.insert(arg);
3392 self.need_bake_expressions.insert(arg1.unwrap());
3393 }
3394 crate::MathFunction::FirstLeadingBit => {
3395 self.need_bake_expressions.insert(arg);
3396 }
3397 crate::MathFunction::Pack4xI8
3398 | crate::MathFunction::Pack4xU8
3399 | crate::MathFunction::Pack4xI8Clamp
3400 | crate::MathFunction::Pack4xU8Clamp
3401 | crate::MathFunction::Unpack4xI8
3402 | crate::MathFunction::Unpack4xU8 => {
3403 if context.lang_version < (2, 1) {
3406 self.need_bake_expressions.insert(arg);
3407 }
3408 }
3409 crate::MathFunction::ExtractBits => {
3410 self.need_bake_expressions.insert(arg1.unwrap());
3412 }
3413 crate::MathFunction::InsertBits => {
3414 self.need_bake_expressions.insert(arg2.unwrap());
3416 }
3417 crate::MathFunction::Sign => {
3418 let inner = context.resolve_type(expr_handle);
3423 if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3424 self.need_bake_expressions.insert(arg);
3425 }
3426 }
3427 _ => {}
3428 }
3429 }
3430 }
3431 }
3432
3433 fn start_baking_expression(
3434 &mut self,
3435 handle: Handle<crate::Expression>,
3436 context: &ExpressionContext,
3437 name: &str,
3438 ) -> BackendResult {
3439 match context.info[handle].ty {
3440 TypeResolution::Handle(ty_handle) => {
3441 let ty_name = TypeContext {
3442 handle: ty_handle,
3443 gctx: context.module.to_ctx(),
3444 names: &self.names,
3445 access: crate::StorageAccess::empty(),
3446 first_time: false,
3447 };
3448 write!(self.out, "{ty_name}")?;
3449 }
3450 TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3451 put_numeric_type(&mut self.out, scalar, &[])?;
3452 }
3453 TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3454 put_numeric_type(&mut self.out, scalar, &[size])?;
3455 }
3456 TypeResolution::Value(crate::TypeInner::Matrix {
3457 columns,
3458 rows,
3459 scalar,
3460 }) => {
3461 put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3462 }
3463 TypeResolution::Value(ref other) => {
3464 log::warn!("Type {other:?} isn't a known local"); return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3466 }
3467 }
3468
3469 write!(self.out, " {name} = ")?;
3471
3472 Ok(())
3473 }
3474
3475 fn put_cache_restricted_level(
3488 &mut self,
3489 load: Handle<crate::Expression>,
3490 image: Handle<crate::Expression>,
3491 mip_level: Option<Handle<crate::Expression>>,
3492 indent: back::Level,
3493 context: &StatementContext,
3494 ) -> BackendResult {
3495 let level_of_detail = match mip_level {
3498 Some(level) => level,
3499 None => return Ok(()),
3500 };
3501
3502 if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3503 || !context.expression.image_needs_lod(image)
3504 {
3505 return Ok(());
3506 }
3507
3508 write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3509 self.put_restricted_scalar_image_index(
3510 image,
3511 level_of_detail,
3512 "get_num_mip_levels",
3513 &context.expression,
3514 )?;
3515 writeln!(self.out, ";")?;
3516
3517 Ok(())
3518 }
3519
3520 fn put_casting_to_packed_chars(
3526 &mut self,
3527 fun: crate::MathFunction,
3528 arg0: Handle<crate::Expression>,
3529 arg1: Handle<crate::Expression>,
3530 indent: back::Level,
3531 context: &StatementContext<'_>,
3532 ) -> Result<(), Error> {
3533 let packed_type = match fun {
3534 crate::MathFunction::Dot4I8Packed => "packed_char4",
3535 crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3536 _ => unreachable!(),
3537 };
3538
3539 for arg in [arg0, arg1] {
3540 write!(
3541 self.out,
3542 "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3543 Reinterpreted::new(packed_type, arg)
3544 )?;
3545 self.put_expression(arg, &context.expression, true)?;
3546 writeln!(self.out, ");")?;
3547 }
3548
3549 Ok(())
3550 }
3551
3552 fn put_block(
3553 &mut self,
3554 level: back::Level,
3555 statements: &[crate::Statement],
3556 context: &StatementContext,
3557 ) -> BackendResult {
3558 #[cfg(test)]
3560 self.put_block_stack_pointers
3561 .insert(ptr::from_ref(&level).cast());
3562
3563 for statement in statements {
3564 log::trace!("statement[{}] {:?}", level.0, statement);
3565 match *statement {
3566 crate::Statement::Emit(ref range) => {
3567 for handle in range.clone() {
3568 use crate::MathFunction as Mf;
3569
3570 match context.expression.function.expressions[handle] {
3571 crate::Expression::ImageLoad {
3574 image,
3575 level: mip_level,
3576 ..
3577 } => {
3578 self.put_cache_restricted_level(
3579 handle, image, mip_level, level, context,
3580 )?;
3581 }
3582
3583 crate::Expression::Math {
3592 fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3593 arg,
3594 arg1,
3595 ..
3596 } if context.expression.lang_version >= (2, 1) => {
3597 self.put_casting_to_packed_chars(
3598 fun,
3599 arg,
3600 arg1.unwrap(),
3601 level,
3602 context,
3603 )?;
3604 }
3605
3606 _ => (),
3607 }
3608
3609 let ptr_class = context.expression.resolve_type(handle).pointer_space();
3610 let expr_name = if ptr_class.is_some() {
3611 None } else if let Some(name) =
3613 context.expression.function.named_expressions.get(&handle)
3614 {
3615 Some(self.namer.call(name))
3625 } else {
3626 let bake = if context.expression.guarded_indices.contains(handle) {
3630 true
3631 } else {
3632 self.need_bake_expressions.contains(&handle)
3633 };
3634
3635 if bake {
3636 Some(Baked(handle).to_string())
3637 } else {
3638 None
3639 }
3640 };
3641
3642 if let Some(name) = expr_name {
3643 write!(self.out, "{level}")?;
3644 self.start_baking_expression(handle, &context.expression, &name)?;
3645 self.put_expression(handle, &context.expression, true)?;
3646 self.named_expressions.insert(handle, name);
3647 writeln!(self.out, ";")?;
3648 }
3649 }
3650 }
3651 crate::Statement::Block(ref block) => {
3652 if !block.is_empty() {
3653 writeln!(self.out, "{level}{{")?;
3654 self.put_block(level.next(), block, context)?;
3655 writeln!(self.out, "{level}}}")?;
3656 }
3657 }
3658 crate::Statement::If {
3659 condition,
3660 ref accept,
3661 ref reject,
3662 } => {
3663 write!(self.out, "{level}if (")?;
3664 self.put_expression(condition, &context.expression, true)?;
3665 writeln!(self.out, ") {{")?;
3666 self.put_block(level.next(), accept, context)?;
3667 if !reject.is_empty() {
3668 writeln!(self.out, "{level}}} else {{")?;
3669 self.put_block(level.next(), reject, context)?;
3670 }
3671 writeln!(self.out, "{level}}}")?;
3672 }
3673 crate::Statement::Switch {
3674 selector,
3675 ref cases,
3676 } => {
3677 write!(self.out, "{level}switch(")?;
3678 self.put_expression(selector, &context.expression, true)?;
3679 writeln!(self.out, ") {{")?;
3680 let lcase = level.next();
3681 for case in cases.iter() {
3682 match case.value {
3683 crate::SwitchValue::I32(value) => {
3684 write!(self.out, "{lcase}case {value}:")?;
3685 }
3686 crate::SwitchValue::U32(value) => {
3687 write!(self.out, "{lcase}case {value}u:")?;
3688 }
3689 crate::SwitchValue::Default => {
3690 write!(self.out, "{lcase}default:")?;
3691 }
3692 }
3693
3694 let write_block_braces = !(case.fall_through && case.body.is_empty());
3695 if write_block_braces {
3696 writeln!(self.out, " {{")?;
3697 } else {
3698 writeln!(self.out)?;
3699 }
3700
3701 self.put_block(lcase.next(), &case.body, context)?;
3702 if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3703 {
3704 writeln!(self.out, "{}break;", lcase.next())?;
3705 }
3706
3707 if write_block_braces {
3708 writeln!(self.out, "{lcase}}}")?;
3709 }
3710 }
3711 writeln!(self.out, "{level}}}")?;
3712 }
3713 crate::Statement::Loop {
3714 ref body,
3715 ref continuing,
3716 break_if,
3717 } => {
3718 let force_loop_bound_statements =
3719 self.gen_force_bounded_loop_statements(level, context);
3720 let gate_name = (!continuing.is_empty() || break_if.is_some())
3721 .then(|| self.namer.call("loop_init"));
3722
3723 if let Some((ref decl, _)) = force_loop_bound_statements {
3724 writeln!(self.out, "{decl}")?;
3725 }
3726 if let Some(ref gate_name) = gate_name {
3727 writeln!(self.out, "{level}bool {gate_name} = true;")?;
3728 }
3729
3730 writeln!(self.out, "{level}while(true) {{",)?;
3731 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3732 writeln!(self.out, "{break_and_inc}")?;
3733 }
3734 if let Some(ref gate_name) = gate_name {
3735 let lif = level.next();
3736 let lcontinuing = lif.next();
3737 writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3738 self.put_block(lcontinuing, continuing, context)?;
3739 if let Some(condition) = break_if {
3740 write!(self.out, "{lcontinuing}if (")?;
3741 self.put_expression(condition, &context.expression, true)?;
3742 writeln!(self.out, ") {{")?;
3743 writeln!(self.out, "{}break;", lcontinuing.next())?;
3744 writeln!(self.out, "{lcontinuing}}}")?;
3745 }
3746 writeln!(self.out, "{lif}}}")?;
3747 writeln!(self.out, "{lif}{gate_name} = false;")?;
3748 }
3749 self.put_block(level.next(), body, context)?;
3750
3751 writeln!(self.out, "{level}}}")?;
3752 }
3753 crate::Statement::Break => {
3754 writeln!(self.out, "{level}break;")?;
3755 }
3756 crate::Statement::Continue => {
3757 writeln!(self.out, "{level}continue;")?;
3758 }
3759 crate::Statement::Return {
3760 value: Some(expr_handle),
3761 } => {
3762 self.put_return_value(
3763 level,
3764 expr_handle,
3765 context.result_struct,
3766 &context.expression,
3767 )?;
3768 }
3769 crate::Statement::Return { value: None } => {
3770 writeln!(self.out, "{level}return;")?;
3771 }
3772 crate::Statement::Kill => {
3773 writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3774 }
3775 crate::Statement::ControlBarrier(flags)
3776 | crate::Statement::MemoryBarrier(flags) => {
3777 self.write_barrier(flags, level)?;
3778 }
3779 crate::Statement::Store { pointer, value } => {
3780 self.put_store(pointer, value, level, context)?
3781 }
3782 crate::Statement::ImageStore {
3783 image,
3784 coordinate,
3785 array_index,
3786 value,
3787 } => {
3788 let address = TexelAddress {
3789 coordinate,
3790 array_index,
3791 sample: None,
3792 level: None,
3793 };
3794 self.put_image_store(level, image, &address, value, context)?
3795 }
3796 crate::Statement::Call {
3797 function,
3798 ref arguments,
3799 result,
3800 } => {
3801 write!(self.out, "{level}")?;
3802 if let Some(expr) = result {
3803 let name = Baked(expr).to_string();
3804 self.start_baking_expression(expr, &context.expression, &name)?;
3805 self.named_expressions.insert(expr, name);
3806 }
3807 let fun_name = &self.names[&NameKey::Function(function)];
3808 write!(self.out, "{fun_name}(")?;
3809 for (i, &handle) in arguments.iter().enumerate() {
3811 if i != 0 {
3812 write!(self.out, ", ")?;
3813 }
3814 self.put_expression(handle, &context.expression, true)?;
3815 }
3816 let mut separate = !arguments.is_empty();
3818 let fun_info = &context.expression.mod_info[function];
3819 let mut needs_buffer_sizes = false;
3820 for (handle, var) in context.expression.module.global_variables.iter() {
3821 if fun_info[handle].is_empty() {
3822 continue;
3823 }
3824 if var.space.needs_pass_through() {
3825 let name = &self.names[&NameKey::GlobalVariable(handle)];
3826 if separate {
3827 write!(self.out, ", ")?;
3828 } else {
3829 separate = true;
3830 }
3831 write!(self.out, "{name}")?;
3832 }
3833 needs_buffer_sizes |=
3834 needs_array_length(var.ty, &context.expression.module.types);
3835 }
3836 if needs_buffer_sizes {
3837 if separate {
3838 write!(self.out, ", ")?;
3839 }
3840 write!(self.out, "_buffer_sizes")?;
3841 }
3842
3843 writeln!(self.out, ");")?;
3845 }
3846 crate::Statement::Atomic {
3847 pointer,
3848 ref fun,
3849 value,
3850 result,
3851 } => {
3852 let context = &context.expression;
3853
3854 write!(self.out, "{level}")?;
3859 let fun_key = if let Some(result) = result {
3860 let res_name = Baked(result).to_string();
3861 self.start_baking_expression(result, context, &res_name)?;
3862 self.named_expressions.insert(result, res_name);
3863 fun.to_msl()
3864 } else if context.resolve_type(value).scalar_width() == Some(8) {
3865 fun.to_msl_64_bit()?
3866 } else {
3867 fun.to_msl()
3868 };
3869
3870 let policy = context.choose_bounds_check_policy(pointer);
3874 let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3875 && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
3876
3877 if checked {
3879 write!(self.out, " ? ")?;
3880 }
3881
3882 match *fun {
3884 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
3885 write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
3886 self.put_access_chain(pointer, policy, context)?;
3887 write!(self.out, ", ")?;
3888 self.put_expression(cmp, context, true)?;
3889 write!(self.out, ", ")?;
3890 self.put_expression(value, context, true)?;
3891 write!(self.out, ")")?;
3892 }
3893 _ => {
3894 write!(
3895 self.out,
3896 "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
3897 )?;
3898 self.put_access_chain(pointer, policy, context)?;
3899 write!(self.out, ", ")?;
3900 self.put_expression(value, context, true)?;
3901 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3902 }
3903 }
3904
3905 if checked {
3907 write!(self.out, " : DefaultConstructible()")?;
3908 }
3909
3910 writeln!(self.out, ";")?;
3912 }
3913 crate::Statement::ImageAtomic {
3914 image,
3915 coordinate,
3916 array_index,
3917 fun,
3918 value,
3919 } => {
3920 let address = TexelAddress {
3921 coordinate,
3922 array_index,
3923 sample: None,
3924 level: None,
3925 };
3926 self.put_image_atomic(level, image, &address, fun, value, context)?
3927 }
3928 crate::Statement::WorkGroupUniformLoad { pointer, result } => {
3929 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3930
3931 write!(self.out, "{level}")?;
3932 let name = self.namer.call("");
3933 self.start_baking_expression(result, &context.expression, &name)?;
3934 self.put_load(pointer, &context.expression, true)?;
3935 self.named_expressions.insert(result, name);
3936
3937 writeln!(self.out, ";")?;
3938 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3939 }
3940 crate::Statement::RayQuery { query, ref fun } => {
3941 if context.expression.lang_version < (2, 4) {
3942 return Err(Error::UnsupportedRayTracing);
3943 }
3944
3945 match *fun {
3946 crate::RayQueryFunction::Initialize {
3947 acceleration_structure,
3948 descriptor,
3949 } => {
3950 write!(self.out, "{level}")?;
3952 self.put_expression(query, &context.expression, true)?;
3953 writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
3954 {
3955 let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
3956 let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
3957 write!(self.out, "{level}")?;
3958 self.put_expression(query, &context.expression, true)?;
3959 write!(
3960 self.out,
3961 ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
3962 )?;
3963 self.put_expression(descriptor, &context.expression, true)?;
3964 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
3965 self.put_expression(descriptor, &context.expression, true)?;
3966 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
3967 writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
3968 }
3969 {
3970 let f_opaque = back::RayFlag::OPAQUE.bits();
3971 let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
3972 write!(self.out, "{level}")?;
3973 self.put_expression(query, &context.expression, true)?;
3974 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
3975 self.put_expression(descriptor, &context.expression, true)?;
3976 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
3977 self.put_expression(descriptor, &context.expression, true)?;
3978 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
3979 writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
3980 }
3981 {
3982 let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
3983 write!(self.out, "{level}")?;
3984 self.put_expression(query, &context.expression, true)?;
3985 write!(
3986 self.out,
3987 ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
3988 )?;
3989 self.put_expression(descriptor, &context.expression, true)?;
3990 writeln!(self.out, ".flags & {flag}) != 0);")?;
3991 }
3992
3993 write!(self.out, "{level}")?;
3994 self.put_expression(query, &context.expression, true)?;
3995 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
3996 self.put_expression(query, &context.expression, true)?;
3997 write!(
3998 self.out,
3999 ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
4000 )?;
4001 self.put_expression(descriptor, &context.expression, true)?;
4002 write!(self.out, ".origin, ")?;
4003 self.put_expression(descriptor, &context.expression, true)?;
4004 write!(self.out, ".dir, ")?;
4005 self.put_expression(descriptor, &context.expression, true)?;
4006 write!(self.out, ".tmin, ")?;
4007 self.put_expression(descriptor, &context.expression, true)?;
4008 write!(self.out, ".tmax), ")?;
4009 self.put_expression(acceleration_structure, &context.expression, true)?;
4010 write!(self.out, ", ")?;
4011 self.put_expression(descriptor, &context.expression, true)?;
4012 write!(self.out, ".cull_mask);")?;
4013
4014 write!(self.out, "{level}")?;
4015 self.put_expression(query, &context.expression, true)?;
4016 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
4017 }
4018 crate::RayQueryFunction::Proceed { result } => {
4019 write!(self.out, "{level}")?;
4020 let name = Baked(result).to_string();
4021 self.start_baking_expression(result, &context.expression, &name)?;
4022 self.named_expressions.insert(result, name);
4023 self.put_expression(query, &context.expression, true)?;
4024 writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
4025 if RAY_QUERY_MODERN_SUPPORT {
4026 write!(self.out, "{level}")?;
4027 self.put_expression(query, &context.expression, true)?;
4028 writeln!(self.out, ".?.next();")?;
4029 }
4030 }
4031 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
4032 if RAY_QUERY_MODERN_SUPPORT {
4033 write!(self.out, "{level}")?;
4034 self.put_expression(query, &context.expression, true)?;
4035 write!(self.out, ".?.commit_bounding_box_intersection(")?;
4036 self.put_expression(hit_t, &context.expression, true)?;
4037 writeln!(self.out, ");")?;
4038 } else {
4039 log::warn!("Ray Query GenerateIntersection is not yet supported");
4040 }
4041 }
4042 crate::RayQueryFunction::ConfirmIntersection => {
4043 if RAY_QUERY_MODERN_SUPPORT {
4044 write!(self.out, "{level}")?;
4045 self.put_expression(query, &context.expression, true)?;
4046 writeln!(self.out, ".?.commit_triangle_intersection();")?;
4047 } else {
4048 log::warn!("Ray Query ConfirmIntersection is not yet supported");
4049 }
4050 }
4051 crate::RayQueryFunction::Terminate => {
4052 if RAY_QUERY_MODERN_SUPPORT {
4053 write!(self.out, "{level}")?;
4054 self.put_expression(query, &context.expression, true)?;
4055 writeln!(self.out, ".?.abort();")?;
4056 }
4057 write!(self.out, "{level}")?;
4058 self.put_expression(query, &context.expression, true)?;
4059 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
4060 }
4061 }
4062 }
4063 crate::Statement::SubgroupBallot { result, predicate } => {
4064 write!(self.out, "{level}")?;
4065 let name = self.namer.call("");
4066 self.start_baking_expression(result, &context.expression, &name)?;
4067 self.named_expressions.insert(result, name);
4068 write!(
4069 self.out,
4070 "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4071 )?;
4072 if let Some(predicate) = predicate {
4073 self.put_expression(predicate, &context.expression, true)?;
4074 } else {
4075 write!(self.out, "true")?;
4076 }
4077 writeln!(self.out, "), 0, 0, 0);")?;
4078 }
4079 crate::Statement::SubgroupCollectiveOperation {
4080 op,
4081 collective_op,
4082 argument,
4083 result,
4084 } => {
4085 write!(self.out, "{level}")?;
4086 let name = self.namer.call("");
4087 self.start_baking_expression(result, &context.expression, &name)?;
4088 self.named_expressions.insert(result, name);
4089 match (collective_op, op) {
4090 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4091 write!(self.out, "{NAMESPACE}::simd_all(")?
4092 }
4093 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4094 write!(self.out, "{NAMESPACE}::simd_any(")?
4095 }
4096 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4097 write!(self.out, "{NAMESPACE}::simd_sum(")?
4098 }
4099 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4100 write!(self.out, "{NAMESPACE}::simd_product(")?
4101 }
4102 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4103 write!(self.out, "{NAMESPACE}::simd_max(")?
4104 }
4105 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4106 write!(self.out, "{NAMESPACE}::simd_min(")?
4107 }
4108 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4109 write!(self.out, "{NAMESPACE}::simd_and(")?
4110 }
4111 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4112 write!(self.out, "{NAMESPACE}::simd_or(")?
4113 }
4114 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4115 write!(self.out, "{NAMESPACE}::simd_xor(")?
4116 }
4117 (
4118 crate::CollectiveOperation::ExclusiveScan,
4119 crate::SubgroupOperation::Add,
4120 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4121 (
4122 crate::CollectiveOperation::ExclusiveScan,
4123 crate::SubgroupOperation::Mul,
4124 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4125 (
4126 crate::CollectiveOperation::InclusiveScan,
4127 crate::SubgroupOperation::Add,
4128 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4129 (
4130 crate::CollectiveOperation::InclusiveScan,
4131 crate::SubgroupOperation::Mul,
4132 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4133 _ => unimplemented!(),
4134 }
4135 self.put_expression(argument, &context.expression, true)?;
4136 writeln!(self.out, ");")?;
4137 }
4138 crate::Statement::SubgroupGather {
4139 mode,
4140 argument,
4141 result,
4142 } => {
4143 write!(self.out, "{level}")?;
4144 let name = self.namer.call("");
4145 self.start_baking_expression(result, &context.expression, &name)?;
4146 self.named_expressions.insert(result, name);
4147 match mode {
4148 crate::GatherMode::BroadcastFirst => {
4149 write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4150 }
4151 crate::GatherMode::Broadcast(_) => {
4152 write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4153 }
4154 crate::GatherMode::Shuffle(_) => {
4155 write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4156 }
4157 crate::GatherMode::ShuffleDown(_) => {
4158 write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4159 }
4160 crate::GatherMode::ShuffleUp(_) => {
4161 write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4162 }
4163 crate::GatherMode::ShuffleXor(_) => {
4164 write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4165 }
4166 crate::GatherMode::QuadBroadcast(_) => {
4167 write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4168 }
4169 crate::GatherMode::QuadSwap(_) => {
4170 write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4171 }
4172 }
4173 self.put_expression(argument, &context.expression, true)?;
4174 match mode {
4175 crate::GatherMode::BroadcastFirst => {}
4176 crate::GatherMode::Broadcast(index)
4177 | crate::GatherMode::Shuffle(index)
4178 | crate::GatherMode::ShuffleDown(index)
4179 | crate::GatherMode::ShuffleUp(index)
4180 | crate::GatherMode::ShuffleXor(index)
4181 | crate::GatherMode::QuadBroadcast(index) => {
4182 write!(self.out, ", ")?;
4183 self.put_expression(index, &context.expression, true)?;
4184 }
4185 crate::GatherMode::QuadSwap(direction) => {
4186 write!(self.out, ", ")?;
4187 match direction {
4188 crate::Direction::X => {
4189 write!(self.out, "1u")?;
4190 }
4191 crate::Direction::Y => {
4192 write!(self.out, "2u")?;
4193 }
4194 crate::Direction::Diagonal => {
4195 write!(self.out, "3u")?;
4196 }
4197 }
4198 }
4199 }
4200 writeln!(self.out, ");")?;
4201 }
4202 }
4203 }
4204
4205 for statement in statements {
4208 if let crate::Statement::Emit(ref range) = *statement {
4209 for handle in range.clone() {
4210 self.named_expressions.shift_remove(&handle);
4211 }
4212 }
4213 }
4214 Ok(())
4215 }
4216
4217 fn put_store(
4218 &mut self,
4219 pointer: Handle<crate::Expression>,
4220 value: Handle<crate::Expression>,
4221 level: back::Level,
4222 context: &StatementContext,
4223 ) -> BackendResult {
4224 let policy = context.expression.choose_bounds_check_policy(pointer);
4225 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4226 && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4227 {
4228 writeln!(self.out, ") {{")?;
4229 self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4230 writeln!(self.out, "{level}}}")?;
4231 } else {
4232 self.put_unchecked_store(pointer, value, policy, level, context)?;
4233 }
4234
4235 Ok(())
4236 }
4237
4238 fn put_unchecked_store(
4239 &mut self,
4240 pointer: Handle<crate::Expression>,
4241 value: Handle<crate::Expression>,
4242 policy: index::BoundsCheckPolicy,
4243 level: back::Level,
4244 context: &StatementContext,
4245 ) -> BackendResult {
4246 let is_atomic_pointer = context
4247 .expression
4248 .resolve_type(pointer)
4249 .is_atomic_pointer(&context.expression.module.types);
4250
4251 if is_atomic_pointer {
4252 write!(
4253 self.out,
4254 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4255 )?;
4256 self.put_access_chain(pointer, policy, &context.expression)?;
4257 write!(self.out, ", ")?;
4258 self.put_expression(value, &context.expression, true)?;
4259 writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4260 } else {
4261 write!(self.out, "{level}")?;
4262 self.put_access_chain(pointer, policy, &context.expression)?;
4263 write!(self.out, " = ")?;
4264 self.put_expression(value, &context.expression, true)?;
4265 writeln!(self.out, ";")?;
4266 }
4267
4268 Ok(())
4269 }
4270
4271 pub fn write(
4272 &mut self,
4273 module: &crate::Module,
4274 info: &valid::ModuleInfo,
4275 options: &Options,
4276 pipeline_options: &PipelineOptions,
4277 ) -> Result<TranslationInfo, Error> {
4278 self.names.clear();
4279 self.namer.reset(
4280 module,
4281 &super::keywords::RESERVED_SET,
4282 proc::CaseInsensitiveKeywordSet::empty(),
4283 &[CLAMPED_LOD_LOAD_PREFIX],
4284 &mut self.names,
4285 );
4286 self.wrapped_functions.clear();
4287 self.struct_member_pads.clear();
4288
4289 writeln!(
4290 self.out,
4291 "// language: metal{}.{}",
4292 options.lang_version.0, options.lang_version.1
4293 )?;
4294 writeln!(self.out, "#include <metal_stdlib>")?;
4295 writeln!(self.out, "#include <simd/simd.h>")?;
4296 writeln!(self.out)?;
4297 writeln!(self.out, "using {NAMESPACE}::uint;")?;
4299
4300 let mut uses_ray_query = false;
4301 for (_, ty) in module.types.iter() {
4302 match ty.inner {
4303 crate::TypeInner::AccelerationStructure { .. } => {
4304 if options.lang_version < (2, 4) {
4305 return Err(Error::UnsupportedRayTracing);
4306 }
4307 }
4308 crate::TypeInner::RayQuery { .. } => {
4309 if options.lang_version < (2, 4) {
4310 return Err(Error::UnsupportedRayTracing);
4311 }
4312 uses_ray_query = true;
4313 }
4314 _ => (),
4315 }
4316 }
4317
4318 if module.special_types.ray_desc.is_some()
4319 || module.special_types.ray_intersection.is_some()
4320 {
4321 if options.lang_version < (2, 4) {
4322 return Err(Error::UnsupportedRayTracing);
4323 }
4324 }
4325
4326 if uses_ray_query {
4327 self.put_ray_query_type()?;
4328 }
4329
4330 if options
4331 .bounds_check_policies
4332 .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4333 {
4334 self.put_default_constructible()?;
4335 }
4336 writeln!(self.out)?;
4337
4338 {
4339 let globals: Vec<Handle<crate::GlobalVariable>> = module
4342 .global_variables
4343 .iter()
4344 .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4345 .map(|(handle, _)| handle)
4346 .collect();
4347
4348 let mut buffer_indices = vec![];
4349 for vbm in &pipeline_options.vertex_buffer_mappings {
4350 buffer_indices.push(vbm.id);
4351 }
4352
4353 if !globals.is_empty() || !buffer_indices.is_empty() {
4354 writeln!(self.out, "struct _mslBufferSizes {{")?;
4355
4356 for global in globals {
4357 writeln!(
4358 self.out,
4359 "{}uint {};",
4360 back::INDENT,
4361 ArraySizeMember(global)
4362 )?;
4363 }
4364
4365 for idx in buffer_indices {
4366 writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4367 }
4368
4369 writeln!(self.out, "}};")?;
4370 writeln!(self.out)?;
4371 }
4372 };
4373
4374 self.write_type_defs(module)?;
4375 self.write_global_constants(module, info)?;
4376 self.write_functions(module, info, options, pipeline_options)
4377 }
4378
4379 fn put_default_constructible(&mut self) -> BackendResult {
4392 let tab = back::INDENT;
4393 writeln!(self.out, "struct DefaultConstructible {{")?;
4394 writeln!(self.out, "{tab}template<typename T>")?;
4395 writeln!(self.out, "{tab}operator T() && {{")?;
4396 writeln!(self.out, "{tab}{tab}return T {{}};")?;
4397 writeln!(self.out, "{tab}}}")?;
4398 writeln!(self.out, "}};")?;
4399 Ok(())
4400 }
4401
4402 fn put_ray_query_type(&mut self) -> BackendResult {
4403 let tab = back::INDENT;
4404 writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
4405 let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
4406 writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
4407 writeln!(
4408 self.out,
4409 "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
4410 )?;
4411 writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
4412 writeln!(self.out, "}};")?;
4413 writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
4414 let v_triangle = back::RayIntersectionType::Triangle as u32;
4415 let v_bbox = back::RayIntersectionType::BoundingBox as u32;
4416 writeln!(
4417 self.out,
4418 "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
4419 )?;
4420 writeln!(
4421 self.out,
4422 "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
4423 )?;
4424 writeln!(self.out, "}}")?;
4425 Ok(())
4426 }
4427
    /// Write MSL declarations for the module's types, then helper functions
    /// for the module's predeclared result types.
    ///
    /// For each type in `module.types`:
    /// - The first `BindingArray` seen emits the generic
    ///   `ARGUMENT_BUFFER_WRAPPER_STRUCT` template, whose single field is of type `T`.
    /// - The first external image seen emits the `EXTERNAL_TEXTURE_WRAPPER_STRUCT`
    ///   struct: three sampled `texture2d<float>` planes plus a `params` member.
    /// - Types whose `needs_alias()` is false get no further declaration.
    /// - Arrays of known length become a struct wrapping a fixed-size array.
    ///   Runtime-sized arrays become a `typedef` of a one-element array.
    /// - Structs are written member by member with explicit `char _padN[...]`
    ///   padding, so member offsets and the total `span` follow the IR layout.
    /// - Every other aliased type becomes a plain `typedef`.
    ///
    /// Afterwards, for each predeclared type it emits either:
    /// - a `naga_modf` / `naga_frexp` wrapper, or
    /// - the templated atomic compare-exchange helper (one overload for
    ///   `device` and one for `threadgroup` pointers).
    ///
    /// # Panics
    ///
    /// Panics if an external image type is present but
    /// `module.special_types.external_texture_params` is `None`.
    ///
    /// # Errors
    ///
    /// Propagates formatting errors and errors from resolving array lengths.
    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
        // Each wrapper template/struct is emitted at most once per module.
        let mut generated_argument_buffer_wrapper = false;
        let mut generated_external_texture_wrapper = false;
        for (handle, ty) in module.types.iter() {
            match ty.inner {
                crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
                    writeln!(self.out, "template <typename T>")?;
                    writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
                    writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
                    writeln!(self.out, "}};")?;
                    generated_argument_buffer_wrapper = true;
                }
                crate::TypeInner::Image {
                    class: crate::ImageClass::External,
                    ..
                } if !generated_external_texture_wrapper => {
                    // The wrapper embeds the external-texture params struct by
                    // name; the special type must exist if an external image does.
                    let params_ty_name = &self.names
                        [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
                    writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
                        back::INDENT
                    )?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
                        back::INDENT
                    )?;
                    writeln!(
                        self.out,
                        "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
                        back::INDENT
                    )?;
                    writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
                    writeln!(self.out, "}};")?;
                    generated_external_texture_wrapper = true;
                }
                _ => {}
            }

            // Only types that need a named MSL declaration go past this point.
            if !ty.needs_alias() {
                continue;
            }
            let name = &self.names[&NameKey::Type(handle)];
            match ty.inner {
                // Like C++, MSL cannot assign raw arrays by value, but it can
                // assign structs, so known-length arrays are wrapped in a struct
                // with a single `WRAPPED_ARRAY_FIELD` member.
                crate::TypeInner::Array {
                    base,
                    size,
                    stride: _,
                } => {
                    let base_name = TypeContext {
                        handle: base,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        first_time: false,
                    };

                    match size.resolve(module.to_ctx())? {
                        proc::IndexableLength::Known(size) => {
                            writeln!(self.out, "struct {name} {{")?;
                            writeln!(
                                self.out,
                                "{}{} {}[{}];",
                                back::INDENT,
                                base_name,
                                WRAPPED_ARRAY_FIELD,
                                size
                            )?;
                            writeln!(self.out, "}};")?;
                        }
                        // Runtime-sized arrays are declared as a one-element
                        // array type.
                        proc::IndexableLength::Dynamic => {
                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
                        }
                    }
                }
                crate::TypeInner::Struct {
                    ref members, span, ..
                } => {
                    writeln!(self.out, "struct {name} {{")?;
                    // Byte offset just past the last member written so far.
                    let mut last_offset = 0;
                    for (index, member) in members.iter().enumerate() {
                        // Insert explicit padding when the IR offset skips
                        // ahead, and record which member was padded.
                        if member.offset > last_offset {
                            self.struct_member_pads.insert((handle, index as u32));
                            let pad = member.offset - last_offset;
                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
                        }
                        let ty_inner = &module.types[member.ty].inner;
                        last_offset = member.offset + ty_inner.size(module.to_ctx());

                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];

                        // Members selected by `should_pack_struct_member` are
                        // declared as `packed_<scalar>3` vectors.
                        match should_pack_struct_member(members, span, index, module) {
                            Some(scalar) => {
                                writeln!(
                                    self.out,
                                    "{}{}::packed_{}3 {};",
                                    back::INDENT,
                                    NAMESPACE,
                                    scalar.to_msl_name(),
                                    member_name
                                )?;
                            }
                            None => {
                                let base_name = TypeContext {
                                    handle: member.ty,
                                    gctx: module.to_ctx(),
                                    names: &self.names,
                                    access: crate::StorageAccess::empty(),
                                    first_time: false,
                                };
                                writeln!(
                                    self.out,
                                    "{}{} {};",
                                    back::INDENT,
                                    base_name,
                                    member_name
                                )?;

                                // An unpacked MSL 3-component vector takes the
                                // space of 4 components, one scalar more than
                                // `size()` counted above; advance past it so
                                // later padding is computed from the real end.
                                if let crate::TypeInner::Vector {
                                    size: crate::VectorSize::Tri,
                                    scalar,
                                } = *ty_inner
                                {
                                    last_offset += scalar.width as u32;
                                }
                            }
                        }
                    }
                    // Trailing padding so the MSL struct spans `span` bytes.
                    if last_offset < span {
                        let pad = span - last_offset;
                        writeln!(
                            self.out,
                            "{}char _pad{}[{}];",
                            back::INDENT,
                            members.len(),
                            pad
                        )?;
                    }
                    writeln!(self.out, "}};")?;
                }
                _ => {
                    let ty_name = TypeContext {
                        handle,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        first_time: true,
                    };
                    writeln!(self.out, "typedef {ty_name} {name};")?;
                }
            }
        }

        // Helper functions that construct the predeclared result structs.
        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
            match type_key {
                &crate::PredeclaredType::ModfResult { size, scalar }
                | &crate::PredeclaredType::FrexpResult { size, scalar } => {
                    // Argument type: `double`/`float`, vectorized as
                    // `metal::doubleN`/`metal::floatN` when `size` is set.
                    let arg_type_name_owner;
                    let arg_type_name = if let Some(size) = size {
                        arg_type_name_owner = format!(
                            "{NAMESPACE}::{}{}",
                            if scalar.width == 8 { "double" } else { "float" },
                            size as u8
                        );
                        &arg_type_name_owner
                    } else if scalar.width == 8 {
                        "double"
                    } else {
                        "float"
                    };

                    // Type of the second output: same as the argument for
                    // `modf`, `int`/`intN` for `frexp`'s exponent.
                    let other_type_name_owner;
                    let (defined_func_name, called_func_name, other_type_name) =
                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
                            (MODF_FUNCTION, "modf", arg_type_name)
                        } else {
                            let other_type_name = if let Some(size) = size {
                                other_type_name_owner = format!("int{}", size as u8);
                                &other_type_name_owner
                            } else {
                                "int"
                            };
                            (FREXP_FUNCTION, "frexp", other_type_name)
                        };

                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                    writeln!(self.out)?;
                    writeln!(
                        self.out,
                        "{struct_name} {defined_func_name}({arg_type_name} arg) {{
    {other_type_name} other;
    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
    return {struct_name}{{ fract, other }};
}}"
                    )?;
                }
                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
                    let arg_type_name = scalar.to_msl_name();
                    let called_func_name = "atomic_compare_exchange_weak_explicit";
                    let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                    writeln!(self.out)?;

                    // One template overload per pointer address space.
                    for address_space_name in ["device", "threadgroup"] {
                        writeln!(
                            self.out,
                            "\
template <typename A>
{struct_name} {defined_func_name}(
    {address_space_name} A *atomic_ptr,
    {arg_type_name} cmp,
    {arg_type_name} v
) {{
    bool swapped = {NAMESPACE}::{called_func_name}(
        atomic_ptr, &cmp, v,
        metal::memory_order_relaxed, metal::memory_order_relaxed
    );
    return {struct_name}{{cmp, swapped}};
}}"
                        )?;
                    }
                }
            }
        }

        Ok(())
    }
4675
4676 fn write_global_constants(
4678 &mut self,
4679 module: &crate::Module,
4680 mod_info: &valid::ModuleInfo,
4681 ) -> BackendResult {
4682 let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4683
4684 for (handle, constant) in constants {
4685 let ty_name = TypeContext {
4686 handle: constant.ty,
4687 gctx: module.to_ctx(),
4688 names: &self.names,
4689 access: crate::StorageAccess::empty(),
4690 first_time: false,
4691 };
4692 let name = &self.names[&NameKey::Constant(handle)];
4693 write!(self.out, "constant {ty_name} {name} = ")?;
4694 self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4695 writeln!(self.out, ";")?;
4696 }
4697
4698 Ok(())
4699 }
4700
4701 fn put_inline_sampler_properties(
4702 &mut self,
4703 level: back::Level,
4704 sampler: &sm::InlineSampler,
4705 ) -> BackendResult {
4706 for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4707 writeln!(
4708 self.out,
4709 "{}{}::{}_address::{},",
4710 level,
4711 NAMESPACE,
4712 letter,
4713 address.as_str(),
4714 )?;
4715 }
4716 writeln!(
4717 self.out,
4718 "{}{}::mag_filter::{},",
4719 level,
4720 NAMESPACE,
4721 sampler.mag_filter.as_str(),
4722 )?;
4723 writeln!(
4724 self.out,
4725 "{}{}::min_filter::{},",
4726 level,
4727 NAMESPACE,
4728 sampler.min_filter.as_str(),
4729 )?;
4730 if let Some(filter) = sampler.mip_filter {
4731 writeln!(
4732 self.out,
4733 "{}{}::mip_filter::{},",
4734 level,
4735 NAMESPACE,
4736 filter.as_str(),
4737 )?;
4738 }
4739 if sampler.border_color != sm::BorderColor::TransparentBlack {
4741 writeln!(
4742 self.out,
4743 "{}{}::border_color::{},",
4744 level,
4745 NAMESPACE,
4746 sampler.border_color.as_str(),
4747 )?;
4748 }
4749 if false {
4753 if let Some(ref lod) = sampler.lod_clamp {
4754 writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4755 }
4756 if let Some(aniso) = sampler.max_anisotropy {
4757 writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4758 }
4759 }
4760 if sampler.compare_func != sm::CompareFunc::Never {
4761 writeln!(
4762 self.out,
4763 "{}{}::compare_func::{},",
4764 level,
4765 NAMESPACE,
4766 sampler.compare_func.as_str(),
4767 )?;
4768 }
4769 writeln!(
4770 self.out,
4771 "{}{}::coord::{}",
4772 level,
4773 NAMESPACE,
4774 sampler.coord.as_str()
4775 )?;
4776 Ok(())
4777 }
4778
4779 fn write_unpacking_function(
4780 &mut self,
4781 format: back::msl::VertexFormat,
4782 ) -> Result<(String, u32, u32), Error> {
4783 use back::msl::VertexFormat::*;
4784 match format {
4785 Uint8 => {
4786 let name = self.namer.call("unpackUint8");
4787 writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4788 writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4789 writeln!(self.out, "}}")?;
4790 Ok((name, 1, 1))
4791 }
4792 Uint8x2 => {
4793 let name = self.namer.call("unpackUint8x2");
4794 writeln!(
4795 self.out,
4796 "metal::uint2 {name}(metal::uchar b0, \
4797 metal::uchar b1) {{"
4798 )?;
4799 writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4800 writeln!(self.out, "}}")?;
4801 Ok((name, 2, 2))
4802 }
4803 Uint8x4 => {
4804 let name = self.namer.call("unpackUint8x4");
4805 writeln!(
4806 self.out,
4807 "metal::uint4 {name}(metal::uchar b0, \
4808 metal::uchar b1, \
4809 metal::uchar b2, \
4810 metal::uchar b3) {{"
4811 )?;
4812 writeln!(
4813 self.out,
4814 "{}return metal::uint4(b0, b1, b2, b3);",
4815 back::INDENT
4816 )?;
4817 writeln!(self.out, "}}")?;
4818 Ok((name, 4, 4))
4819 }
4820 Sint8 => {
4821 let name = self.namer.call("unpackSint8");
4822 writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4823 writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4824 writeln!(self.out, "}}")?;
4825 Ok((name, 1, 1))
4826 }
4827 Sint8x2 => {
4828 let name = self.namer.call("unpackSint8x2");
4829 writeln!(
4830 self.out,
4831 "metal::int2 {name}(metal::uchar b0, \
4832 metal::uchar b1) {{"
4833 )?;
4834 writeln!(
4835 self.out,
4836 "{}return metal::int2(as_type<char>(b0), \
4837 as_type<char>(b1));",
4838 back::INDENT
4839 )?;
4840 writeln!(self.out, "}}")?;
4841 Ok((name, 2, 2))
4842 }
4843 Sint8x4 => {
4844 let name = self.namer.call("unpackSint8x4");
4845 writeln!(
4846 self.out,
4847 "metal::int4 {name}(metal::uchar b0, \
4848 metal::uchar b1, \
4849 metal::uchar b2, \
4850 metal::uchar b3) {{"
4851 )?;
4852 writeln!(
4853 self.out,
4854 "{}return metal::int4(as_type<char>(b0), \
4855 as_type<char>(b1), \
4856 as_type<char>(b2), \
4857 as_type<char>(b3));",
4858 back::INDENT
4859 )?;
4860 writeln!(self.out, "}}")?;
4861 Ok((name, 4, 4))
4862 }
4863 Unorm8 => {
4864 let name = self.namer.call("unpackUnorm8");
4865 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4866 writeln!(
4867 self.out,
4868 "{}return float(float(b0) / 255.0f);",
4869 back::INDENT
4870 )?;
4871 writeln!(self.out, "}}")?;
4872 Ok((name, 1, 1))
4873 }
4874 Unorm8x2 => {
4875 let name = self.namer.call("unpackUnorm8x2");
4876 writeln!(
4877 self.out,
4878 "metal::float2 {name}(metal::uchar b0, \
4879 metal::uchar b1) {{"
4880 )?;
4881 writeln!(
4882 self.out,
4883 "{}return metal::float2(float(b0) / 255.0f, \
4884 float(b1) / 255.0f);",
4885 back::INDENT
4886 )?;
4887 writeln!(self.out, "}}")?;
4888 Ok((name, 2, 2))
4889 }
4890 Unorm8x4 => {
4891 let name = self.namer.call("unpackUnorm8x4");
4892 writeln!(
4893 self.out,
4894 "metal::float4 {name}(metal::uchar b0, \
4895 metal::uchar b1, \
4896 metal::uchar b2, \
4897 metal::uchar b3) {{"
4898 )?;
4899 writeln!(
4900 self.out,
4901 "{}return metal::float4(float(b0) / 255.0f, \
4902 float(b1) / 255.0f, \
4903 float(b2) / 255.0f, \
4904 float(b3) / 255.0f);",
4905 back::INDENT
4906 )?;
4907 writeln!(self.out, "}}")?;
4908 Ok((name, 4, 4))
4909 }
4910 Snorm8 => {
4911 let name = self.namer.call("unpackSnorm8");
4912 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4913 writeln!(
4914 self.out,
4915 "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4916 back::INDENT
4917 )?;
4918 writeln!(self.out, "}}")?;
4919 Ok((name, 1, 1))
4920 }
4921 Snorm8x2 => {
4922 let name = self.namer.call("unpackSnorm8x2");
4923 writeln!(
4924 self.out,
4925 "metal::float2 {name}(metal::uchar b0, \
4926 metal::uchar b1) {{"
4927 )?;
4928 writeln!(
4929 self.out,
4930 "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4931 metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
4932 back::INDENT
4933 )?;
4934 writeln!(self.out, "}}")?;
4935 Ok((name, 2, 2))
4936 }
4937 Snorm8x4 => {
4938 let name = self.namer.call("unpackSnorm8x4");
4939 writeln!(
4940 self.out,
4941 "metal::float4 {name}(metal::uchar b0, \
4942 metal::uchar b1, \
4943 metal::uchar b2, \
4944 metal::uchar b3) {{"
4945 )?;
4946 writeln!(
4947 self.out,
4948 "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4949 metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
4950 metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
4951 metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
4952 back::INDENT
4953 )?;
4954 writeln!(self.out, "}}")?;
4955 Ok((name, 4, 4))
4956 }
4957 Uint16 => {
4958 let name = self.namer.call("unpackUint16");
4959 writeln!(
4960 self.out,
4961 "metal::uint {name}(metal::uint b0, \
4962 metal::uint b1) {{"
4963 )?;
4964 writeln!(
4965 self.out,
4966 "{}return metal::uint(b1 << 8 | b0);",
4967 back::INDENT
4968 )?;
4969 writeln!(self.out, "}}")?;
4970 Ok((name, 2, 1))
4971 }
4972 Uint16x2 => {
4973 let name = self.namer.call("unpackUint16x2");
4974 writeln!(
4975 self.out,
4976 "metal::uint2 {name}(metal::uint b0, \
4977 metal::uint b1, \
4978 metal::uint b2, \
4979 metal::uint b3) {{"
4980 )?;
4981 writeln!(
4982 self.out,
4983 "{}return metal::uint2(b1 << 8 | b0, \
4984 b3 << 8 | b2);",
4985 back::INDENT
4986 )?;
4987 writeln!(self.out, "}}")?;
4988 Ok((name, 4, 2))
4989 }
4990 Uint16x4 => {
4991 let name = self.namer.call("unpackUint16x4");
4992 writeln!(
4993 self.out,
4994 "metal::uint4 {name}(metal::uint b0, \
4995 metal::uint b1, \
4996 metal::uint b2, \
4997 metal::uint b3, \
4998 metal::uint b4, \
4999 metal::uint b5, \
5000 metal::uint b6, \
5001 metal::uint b7) {{"
5002 )?;
5003 writeln!(
5004 self.out,
5005 "{}return metal::uint4(b1 << 8 | b0, \
5006 b3 << 8 | b2, \
5007 b5 << 8 | b4, \
5008 b7 << 8 | b6);",
5009 back::INDENT
5010 )?;
5011 writeln!(self.out, "}}")?;
5012 Ok((name, 8, 4))
5013 }
5014 Sint16 => {
5015 let name = self.namer.call("unpackSint16");
5016 writeln!(
5017 self.out,
5018 "int {name}(metal::ushort b0, \
5019 metal::ushort b1) {{"
5020 )?;
5021 writeln!(
5022 self.out,
5023 "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5024 back::INDENT
5025 )?;
5026 writeln!(self.out, "}}")?;
5027 Ok((name, 2, 1))
5028 }
5029 Sint16x2 => {
5030 let name = self.namer.call("unpackSint16x2");
5031 writeln!(
5032 self.out,
5033 "metal::int2 {name}(metal::ushort b0, \
5034 metal::ushort b1, \
5035 metal::ushort b2, \
5036 metal::ushort b3) {{"
5037 )?;
5038 writeln!(
5039 self.out,
5040 "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5041 as_type<short>(metal::ushort(b3 << 8 | b2)));",
5042 back::INDENT
5043 )?;
5044 writeln!(self.out, "}}")?;
5045 Ok((name, 4, 2))
5046 }
5047 Sint16x4 => {
5048 let name = self.namer.call("unpackSint16x4");
5049 writeln!(
5050 self.out,
5051 "metal::int4 {name}(metal::ushort b0, \
5052 metal::ushort b1, \
5053 metal::ushort b2, \
5054 metal::ushort b3, \
5055 metal::ushort b4, \
5056 metal::ushort b5, \
5057 metal::ushort b6, \
5058 metal::ushort b7) {{"
5059 )?;
5060 writeln!(
5061 self.out,
5062 "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5063 as_type<short>(metal::ushort(b3 << 8 | b2)), \
5064 as_type<short>(metal::ushort(b5 << 8 | b4)), \
5065 as_type<short>(metal::ushort(b7 << 8 | b6)));",
5066 back::INDENT
5067 )?;
5068 writeln!(self.out, "}}")?;
5069 Ok((name, 8, 4))
5070 }
5071 Unorm16 => {
5072 let name = self.namer.call("unpackUnorm16");
5073 writeln!(
5074 self.out,
5075 "float {name}(metal::ushort b0, \
5076 metal::ushort b1) {{"
5077 )?;
5078 writeln!(
5079 self.out,
5080 "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5081 back::INDENT
5082 )?;
5083 writeln!(self.out, "}}")?;
5084 Ok((name, 2, 1))
5085 }
5086 Unorm16x2 => {
5087 let name = self.namer.call("unpackUnorm16x2");
5088 writeln!(
5089 self.out,
5090 "metal::float2 {name}(metal::ushort b0, \
5091 metal::ushort b1, \
5092 metal::ushort b2, \
5093 metal::ushort b3) {{"
5094 )?;
5095 writeln!(
5096 self.out,
5097 "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5098 float(b3 << 8 | b2) / 65535.0f);",
5099 back::INDENT
5100 )?;
5101 writeln!(self.out, "}}")?;
5102 Ok((name, 4, 2))
5103 }
5104 Unorm16x4 => {
5105 let name = self.namer.call("unpackUnorm16x4");
5106 writeln!(
5107 self.out,
5108 "metal::float4 {name}(metal::ushort b0, \
5109 metal::ushort b1, \
5110 metal::ushort b2, \
5111 metal::ushort b3, \
5112 metal::ushort b4, \
5113 metal::ushort b5, \
5114 metal::ushort b6, \
5115 metal::ushort b7) {{"
5116 )?;
5117 writeln!(
5118 self.out,
5119 "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5120 float(b3 << 8 | b2) / 65535.0f, \
5121 float(b5 << 8 | b4) / 65535.0f, \
5122 float(b7 << 8 | b6) / 65535.0f);",
5123 back::INDENT
5124 )?;
5125 writeln!(self.out, "}}")?;
5126 Ok((name, 8, 4))
5127 }
5128 Snorm16 => {
5129 let name = self.namer.call("unpackSnorm16");
5130 writeln!(
5131 self.out,
5132 "float {name}(metal::ushort b0, \
5133 metal::ushort b1) {{"
5134 )?;
5135 writeln!(
5136 self.out,
5137 "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5138 back::INDENT
5139 )?;
5140 writeln!(self.out, "}}")?;
5141 Ok((name, 2, 1))
5142 }
5143 Snorm16x2 => {
5144 let name = self.namer.call("unpackSnorm16x2");
5145 writeln!(
5146 self.out,
5147 "metal::float2 {name}(metal::ushort b0, \
5148 metal::ushort b1, \
5149 metal::ushort b2, \
5150 metal::ushort b3) {{"
5151 )?;
5152 writeln!(
5153 self.out,
5154 "{}return metal::unpack_snorm2x16_to_float(b1 << 24 | b0 << 16 | b3 << 8 | b2);",
5155 back::INDENT
5156 )?;
5157 writeln!(self.out, "}}")?;
5158 Ok((name, 4, 2))
5159 }
5160 Snorm16x4 => {
5161 let name = self.namer.call("unpackSnorm16x4");
5162 writeln!(
5163 self.out,
5164 "metal::float4 {name}(metal::ushort b0, \
5165 metal::ushort b1, \
5166 metal::ushort b2, \
5167 metal::ushort b3, \
5168 metal::ushort b4, \
5169 metal::ushort b5, \
5170 metal::ushort b6, \
5171 metal::ushort b7) {{"
5172 )?;
5173 writeln!(
5174 self.out,
5175 "{}return metal::float4(metal::unpack_snorm2x16_to_float(b1 << 24 | b0 << 16 | b3 << 8 | b2), \
5176 metal::unpack_snorm2x16_to_float(b5 << 24 | b4 << 16 | b7 << 8 | b6));",
5177 back::INDENT
5178 )?;
5179 writeln!(self.out, "}}")?;
5180 Ok((name, 8, 4))
5181 }
5182 Float16 => {
5183 let name = self.namer.call("unpackFloat16");
5184 writeln!(
5185 self.out,
5186 "float {name}(metal::ushort b0, \
5187 metal::ushort b1) {{"
5188 )?;
5189 writeln!(
5190 self.out,
5191 "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5192 back::INDENT
5193 )?;
5194 writeln!(self.out, "}}")?;
5195 Ok((name, 2, 1))
5196 }
5197 Float16x2 => {
5198 let name = self.namer.call("unpackFloat16x2");
5199 writeln!(
5200 self.out,
5201 "metal::float2 {name}(metal::ushort b0, \
5202 metal::ushort b1, \
5203 metal::ushort b2, \
5204 metal::ushort b3) {{"
5205 )?;
5206 writeln!(
5207 self.out,
5208 "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5209 as_type<half>(metal::ushort(b3 << 8 | b2)));",
5210 back::INDENT
5211 )?;
5212 writeln!(self.out, "}}")?;
5213 Ok((name, 4, 2))
5214 }
5215 Float16x4 => {
5216 let name = self.namer.call("unpackFloat16x4");
5217 writeln!(
5218 self.out,
5219 "metal::float4 {name}(metal::ushort b0, \
5220 metal::ushort b1, \
5221 metal::ushort b2, \
5222 metal::ushort b3, \
5223 metal::ushort b4, \
5224 metal::ushort b5, \
5225 metal::ushort b6, \
5226 metal::ushort b7) {{"
5227 )?;
5228 writeln!(
5229 self.out,
5230 "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5231 as_type<half>(metal::ushort(b3 << 8 | b2)), \
5232 as_type<half>(metal::ushort(b5 << 8 | b4)), \
5233 as_type<half>(metal::ushort(b7 << 8 | b6)));",
5234 back::INDENT
5235 )?;
5236 writeln!(self.out, "}}")?;
5237 Ok((name, 8, 4))
5238 }
5239 Float32 => {
5240 let name = self.namer.call("unpackFloat32");
5241 writeln!(
5242 self.out,
5243 "float {name}(uint b0, \
5244 uint b1, \
5245 uint b2, \
5246 uint b3) {{"
5247 )?;
5248 writeln!(
5249 self.out,
5250 "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5251 back::INDENT
5252 )?;
5253 writeln!(self.out, "}}")?;
5254 Ok((name, 4, 1))
5255 }
5256 Float32x2 => {
5257 let name = self.namer.call("unpackFloat32x2");
5258 writeln!(
5259 self.out,
5260 "metal::float2 {name}(uint b0, \
5261 uint b1, \
5262 uint b2, \
5263 uint b3, \
5264 uint b4, \
5265 uint b5, \
5266 uint b6, \
5267 uint b7) {{"
5268 )?;
5269 writeln!(
5270 self.out,
5271 "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5272 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5273 back::INDENT
5274 )?;
5275 writeln!(self.out, "}}")?;
5276 Ok((name, 8, 2))
5277 }
5278 Float32x3 => {
5279 let name = self.namer.call("unpackFloat32x3");
5280 writeln!(
5281 self.out,
5282 "metal::float3 {name}(uint b0, \
5283 uint b1, \
5284 uint b2, \
5285 uint b3, \
5286 uint b4, \
5287 uint b5, \
5288 uint b6, \
5289 uint b7, \
5290 uint b8, \
5291 uint b9, \
5292 uint b10, \
5293 uint b11) {{"
5294 )?;
5295 writeln!(
5296 self.out,
5297 "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5298 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5299 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5300 back::INDENT
5301 )?;
5302 writeln!(self.out, "}}")?;
5303 Ok((name, 12, 3))
5304 }
5305 Float32x4 => {
5306 let name = self.namer.call("unpackFloat32x4");
5307 writeln!(
5308 self.out,
5309 "metal::float4 {name}(uint b0, \
5310 uint b1, \
5311 uint b2, \
5312 uint b3, \
5313 uint b4, \
5314 uint b5, \
5315 uint b6, \
5316 uint b7, \
5317 uint b8, \
5318 uint b9, \
5319 uint b10, \
5320 uint b11, \
5321 uint b12, \
5322 uint b13, \
5323 uint b14, \
5324 uint b15) {{"
5325 )?;
5326 writeln!(
5327 self.out,
5328 "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5329 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5330 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5331 as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5332 back::INDENT
5333 )?;
5334 writeln!(self.out, "}}")?;
5335 Ok((name, 16, 4))
5336 }
5337 Uint32 => {
5338 let name = self.namer.call("unpackUint32");
5339 writeln!(
5340 self.out,
5341 "uint {name}(uint b0, \
5342 uint b1, \
5343 uint b2, \
5344 uint b3) {{"
5345 )?;
5346 writeln!(
5347 self.out,
5348 "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5349 back::INDENT
5350 )?;
5351 writeln!(self.out, "}}")?;
5352 Ok((name, 4, 1))
5353 }
5354 Uint32x2 => {
5355 let name = self.namer.call("unpackUint32x2");
5356 writeln!(
5357 self.out,
5358 "uint2 {name}(uint b0, \
5359 uint b1, \
5360 uint b2, \
5361 uint b3, \
5362 uint b4, \
5363 uint b5, \
5364 uint b6, \
5365 uint b7) {{"
5366 )?;
5367 writeln!(
5368 self.out,
5369 "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5370 (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5371 back::INDENT
5372 )?;
5373 writeln!(self.out, "}}")?;
5374 Ok((name, 8, 2))
5375 }
5376 Uint32x3 => {
5377 let name = self.namer.call("unpackUint32x3");
5378 writeln!(
5379 self.out,
5380 "uint3 {name}(uint b0, \
5381 uint b1, \
5382 uint b2, \
5383 uint b3, \
5384 uint b4, \
5385 uint b5, \
5386 uint b6, \
5387 uint b7, \
5388 uint b8, \
5389 uint b9, \
5390 uint b10, \
5391 uint b11) {{"
5392 )?;
5393 writeln!(
5394 self.out,
5395 "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5396 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5397 (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5398 back::INDENT
5399 )?;
5400 writeln!(self.out, "}}")?;
5401 Ok((name, 12, 3))
5402 }
5403 Uint32x4 => {
5404 let name = self.namer.call("unpackUint32x4");
5405 writeln!(
5406 self.out,
5407 "{NAMESPACE}::uint4 {name}(uint b0, \
5408 uint b1, \
5409 uint b2, \
5410 uint b3, \
5411 uint b4, \
5412 uint b5, \
5413 uint b6, \
5414 uint b7, \
5415 uint b8, \
5416 uint b9, \
5417 uint b10, \
5418 uint b11, \
5419 uint b12, \
5420 uint b13, \
5421 uint b14, \
5422 uint b15) {{"
5423 )?;
5424 writeln!(
5425 self.out,
5426 "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5427 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5428 (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5429 (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5430 back::INDENT
5431 )?;
5432 writeln!(self.out, "}}")?;
5433 Ok((name, 16, 4))
5434 }
5435 Sint32 => {
5436 let name = self.namer.call("unpackSint32");
5437 writeln!(
5438 self.out,
5439 "int {name}(uint b0, \
5440 uint b1, \
5441 uint b2, \
5442 uint b3) {{"
5443 )?;
5444 writeln!(
5445 self.out,
5446 "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5447 back::INDENT
5448 )?;
5449 writeln!(self.out, "}}")?;
5450 Ok((name, 4, 1))
5451 }
5452 Sint32x2 => {
5453 let name = self.namer.call("unpackSint32x2");
5454 writeln!(
5455 self.out,
5456 "metal::int2 {name}(uint b0, \
5457 uint b1, \
5458 uint b2, \
5459 uint b3, \
5460 uint b4, \
5461 uint b5, \
5462 uint b6, \
5463 uint b7) {{"
5464 )?;
5465 writeln!(
5466 self.out,
5467 "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5468 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5469 back::INDENT
5470 )?;
5471 writeln!(self.out, "}}")?;
5472 Ok((name, 8, 2))
5473 }
5474 Sint32x3 => {
5475 let name = self.namer.call("unpackSint32x3");
5476 writeln!(
5477 self.out,
5478 "metal::int3 {name}(uint b0, \
5479 uint b1, \
5480 uint b2, \
5481 uint b3, \
5482 uint b4, \
5483 uint b5, \
5484 uint b6, \
5485 uint b7, \
5486 uint b8, \
5487 uint b9, \
5488 uint b10, \
5489 uint b11) {{"
5490 )?;
5491 writeln!(
5492 self.out,
5493 "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5494 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5495 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5496 back::INDENT
5497 )?;
5498 writeln!(self.out, "}}")?;
5499 Ok((name, 12, 3))
5500 }
5501 Sint32x4 => {
5502 let name = self.namer.call("unpackSint32x4");
5503 writeln!(
5504 self.out,
5505 "metal::int4 {name}(uint b0, \
5506 uint b1, \
5507 uint b2, \
5508 uint b3, \
5509 uint b4, \
5510 uint b5, \
5511 uint b6, \
5512 uint b7, \
5513 uint b8, \
5514 uint b9, \
5515 uint b10, \
5516 uint b11, \
5517 uint b12, \
5518 uint b13, \
5519 uint b14, \
5520 uint b15) {{"
5521 )?;
5522 writeln!(
5523 self.out,
5524 "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5525 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5526 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5527 as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5528 back::INDENT
5529 )?;
5530 writeln!(self.out, "}}")?;
5531 Ok((name, 16, 4))
5532 }
5533 Unorm10_10_10_2 => {
5534 let name = self.namer.call("unpackUnorm10_10_10_2");
5535 writeln!(
5536 self.out,
5537 "metal::float4 {name}(uint b0, \
5538 uint b1, \
5539 uint b2, \
5540 uint b3) {{"
5541 )?;
5542 writeln!(
5543 self.out,
5544 "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5556 back::INDENT
5557 )?;
5558 writeln!(self.out, "}}")?;
5559 Ok((name, 4, 4))
5560 }
5561 Unorm8x4Bgra => {
5562 let name = self.namer.call("unpackUnorm8x4Bgra");
5563 writeln!(
5564 self.out,
5565 "metal::float4 {name}(metal::uchar b0, \
5566 metal::uchar b1, \
5567 metal::uchar b2, \
5568 metal::uchar b3) {{"
5569 )?;
5570 writeln!(
5571 self.out,
5572 "{}return metal::float4(float(b2) / 255.0f, \
5573 float(b1) / 255.0f, \
5574 float(b0) / 255.0f, \
5575 float(b3) / 255.0f);",
5576 back::INDENT
5577 )?;
5578 writeln!(self.out, "}}")?;
5579 Ok((name, 4, 4))
5580 }
5581 }
5582 }
5583
5584 fn write_wrapped_unary_op(
5585 &mut self,
5586 module: &crate::Module,
5587 func_ctx: &back::FunctionCtx,
5588 op: crate::UnaryOperator,
5589 operand: Handle<crate::Expression>,
5590 ) -> BackendResult {
5591 let operand_ty = func_ctx.resolve_type(operand, &module.types);
5592 match op {
5593 crate::UnaryOperator::Negate
5600 if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5601 {
5602 let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5603 return Ok(());
5604 };
5605 let wrapped = WrappedFunction::UnaryOp {
5606 op,
5607 ty: (vector_size, scalar),
5608 };
5609 if !self.wrapped_functions.insert(wrapped) {
5610 return Ok(());
5611 }
5612
5613 let unsigned_scalar = crate::Scalar {
5614 kind: crate::ScalarKind::Uint,
5615 ..scalar
5616 };
5617 let mut type_name = String::new();
5618 let mut unsigned_type_name = String::new();
5619 match vector_size {
5620 None => {
5621 put_numeric_type(&mut type_name, scalar, &[])?;
5622 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5623 }
5624 Some(size) => {
5625 put_numeric_type(&mut type_name, scalar, &[size])?;
5626 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5627 }
5628 };
5629
5630 writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5631 let level = back::Level(1);
5632 writeln!(
5633 self.out,
5634 "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5635 )?;
5636 writeln!(self.out, "}}")?;
5637 writeln!(self.out)?;
5638 }
5639 _ => {}
5640 }
5641 Ok(())
5642 }
5643
5644 fn write_wrapped_binary_op(
5645 &mut self,
5646 module: &crate::Module,
5647 func_ctx: &back::FunctionCtx,
5648 expr: Handle<crate::Expression>,
5649 op: crate::BinaryOperator,
5650 left: Handle<crate::Expression>,
5651 right: Handle<crate::Expression>,
5652 ) -> BackendResult {
5653 let expr_ty = func_ctx.resolve_type(expr, &module.types);
5654 let left_ty = func_ctx.resolve_type(left, &module.types);
5655 let right_ty = func_ctx.resolve_type(right, &module.types);
5656 match (op, expr_ty.scalar_kind()) {
5657 (
5664 crate::BinaryOperator::Divide,
5665 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5666 ) => {
5667 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5668 return Ok(());
5669 };
5670 let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5671 return Ok(());
5672 };
5673 let wrapped = WrappedFunction::BinaryOp {
5674 op,
5675 left_ty: left_wrapped_ty,
5676 right_ty: right_wrapped_ty,
5677 };
5678 if !self.wrapped_functions.insert(wrapped) {
5679 return Ok(());
5680 }
5681
5682 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5683 return Ok(());
5684 };
5685 let mut type_name = String::new();
5686 match vector_size {
5687 None => put_numeric_type(&mut type_name, scalar, &[])?,
5688 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5689 };
5690 writeln!(
5691 self.out,
5692 "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5693 )?;
5694 let level = back::Level(1);
5695 match scalar.kind {
5696 crate::ScalarKind::Sint => {
5697 let min_val = match scalar.width {
5698 4 => crate::Literal::I32(i32::MIN),
5699 8 => crate::Literal::I64(i64::MIN),
5700 _ => {
5701 return Err(Error::GenericValidation(format!(
5702 "Unexpected width for scalar {scalar:?}"
5703 )));
5704 }
5705 };
5706 write!(
5707 self.out,
5708 "{level}return lhs / metal::select(rhs, 1, (lhs == "
5709 )?;
5710 self.put_literal(min_val)?;
5711 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
5712 }
5713 crate::ScalarKind::Uint => writeln!(
5714 self.out,
5715 "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
5716 )?,
5717 _ => unreachable!(),
5718 }
5719 writeln!(self.out, "}}")?;
5720 writeln!(self.out)?;
5721 }
5722 (
5735 crate::BinaryOperator::Modulo,
5736 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5737 ) => {
5738 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5739 return Ok(());
5740 };
5741 let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5742 else {
5743 return Ok(());
5744 };
5745 let wrapped = WrappedFunction::BinaryOp {
5746 op,
5747 left_ty: left_wrapped_ty,
5748 right_ty: (right_vector_size, right_scalar),
5749 };
5750 if !self.wrapped_functions.insert(wrapped) {
5751 return Ok(());
5752 }
5753
5754 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5755 return Ok(());
5756 };
5757 let mut type_name = String::new();
5758 match vector_size {
5759 None => put_numeric_type(&mut type_name, scalar, &[])?,
5760 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5761 };
5762 let mut rhs_type_name = String::new();
5763 match right_vector_size {
5764 None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5765 Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5766 };
5767
5768 writeln!(
5769 self.out,
5770 "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5771 )?;
5772 let level = back::Level(1);
5773 match scalar.kind {
5774 crate::ScalarKind::Sint => {
5775 let min_val = match scalar.width {
5776 4 => crate::Literal::I32(i32::MIN),
5777 8 => crate::Literal::I64(i64::MIN),
5778 _ => {
5779 return Err(Error::GenericValidation(format!(
5780 "Unexpected width for scalar {scalar:?}"
5781 )));
5782 }
5783 };
5784 write!(
5785 self.out,
5786 "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
5787 )?;
5788 self.put_literal(min_val)?;
5789 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
5790 writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5791 }
5792 crate::ScalarKind::Uint => writeln!(
5793 self.out,
5794 "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
5795 )?,
5796 _ => unreachable!(),
5797 }
5798 writeln!(self.out, "}}")?;
5799 writeln!(self.out)?;
5800 }
5801 _ => {}
5802 }
5803 Ok(())
5804 }
5805
5806 #[allow(clippy::too_many_arguments)]
5807 fn write_wrapped_math_function(
5808 &mut self,
5809 module: &crate::Module,
5810 func_ctx: &back::FunctionCtx,
5811 fun: crate::MathFunction,
5812 arg: Handle<crate::Expression>,
5813 _arg1: Option<Handle<crate::Expression>>,
5814 _arg2: Option<Handle<crate::Expression>>,
5815 _arg3: Option<Handle<crate::Expression>>,
5816 ) -> BackendResult {
5817 let arg_ty = func_ctx.resolve_type(arg, &module.types);
5818 match fun {
5819 crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5827 let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5828 return Ok(());
5829 };
5830 let wrapped = WrappedFunction::Math {
5831 fun,
5832 arg_ty: (vector_size, scalar),
5833 };
5834 if !self.wrapped_functions.insert(wrapped) {
5835 return Ok(());
5836 }
5837
5838 let unsigned_scalar = crate::Scalar {
5839 kind: crate::ScalarKind::Uint,
5840 ..scalar
5841 };
5842 let mut type_name = String::new();
5843 let mut unsigned_type_name = String::new();
5844 match vector_size {
5845 None => {
5846 put_numeric_type(&mut type_name, scalar, &[])?;
5847 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5848 }
5849 Some(size) => {
5850 put_numeric_type(&mut type_name, scalar, &[size])?;
5851 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5852 }
5853 };
5854
5855 writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5856 let level = back::Level(1);
5857 writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
5858 writeln!(self.out, "}}")?;
5859 writeln!(self.out)?;
5860 }
5861 _ => {}
5862 }
5863 Ok(())
5864 }
5865
5866 fn write_wrapped_cast(
5867 &mut self,
5868 module: &crate::Module,
5869 func_ctx: &back::FunctionCtx,
5870 expr: Handle<crate::Expression>,
5871 kind: crate::ScalarKind,
5872 convert: Option<crate::Bytes>,
5873 ) -> BackendResult {
5874 let src_ty = func_ctx.resolve_type(expr, &module.types);
5885 let Some(width) = convert else {
5886 return Ok(());
5887 };
5888 let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
5889 return Ok(());
5890 };
5891 let dst_scalar = crate::Scalar { kind, width };
5892 if src_scalar.kind != crate::ScalarKind::Float
5893 || (dst_scalar.kind != crate::ScalarKind::Sint
5894 && dst_scalar.kind != crate::ScalarKind::Uint)
5895 {
5896 return Ok(());
5897 }
5898 let wrapped = WrappedFunction::Cast {
5899 src_scalar,
5900 vector_size,
5901 dst_scalar,
5902 };
5903 if !self.wrapped_functions.insert(wrapped) {
5904 return Ok(());
5905 }
5906 let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
5907
5908 let mut src_type_name = String::new();
5909 match vector_size {
5910 None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
5911 Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
5912 };
5913 let mut dst_type_name = String::new();
5914 match vector_size {
5915 None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
5916 Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
5917 };
5918 let fun_name = match dst_scalar {
5919 crate::Scalar::I32 => F2I32_FUNCTION,
5920 crate::Scalar::U32 => F2U32_FUNCTION,
5921 crate::Scalar::I64 => F2I64_FUNCTION,
5922 crate::Scalar::U64 => F2U64_FUNCTION,
5923 _ => unreachable!(),
5924 };
5925
5926 writeln!(
5927 self.out,
5928 "{dst_type_name} {fun_name}({src_type_name} value) {{"
5929 )?;
5930 let level = back::Level(1);
5931 write!(
5932 self.out,
5933 "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
5934 )?;
5935 self.put_literal(min)?;
5936 write!(self.out, ", ")?;
5937 self.put_literal(max)?;
5938 writeln!(self.out, "));")?;
5939 writeln!(self.out, "}}")?;
5940 writeln!(self.out)?;
5941 Ok(())
5942 }
5943
5944 fn write_convert_yuv_to_rgb_and_return(
5952 &mut self,
5953 level: back::Level,
5954 y: &str,
5955 uv: &str,
5956 params: &str,
5957 ) -> BackendResult {
5958 let l1 = level;
5959 let l2 = l1.next();
5960
5961 writeln!(
5963 self.out,
5964 "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
5965 )?;
5966
5967 writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
5970 writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
5971 writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
5972 writeln!(
5973 self.out,
5974 "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
5975 )?;
5976
5977 writeln!(
5980 self.out,
5981 "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
5982 )?;
5983
5984 writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
5987 writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
5988 writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
5989 writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;
5990
5991 writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
5992 Ok(())
5993 }
5994
5995 #[allow(clippy::too_many_arguments)]
5996 fn write_wrapped_image_load(
5997 &mut self,
5998 module: &crate::Module,
5999 func_ctx: &back::FunctionCtx,
6000 image: Handle<crate::Expression>,
6001 _coordinate: Handle<crate::Expression>,
6002 _array_index: Option<Handle<crate::Expression>>,
6003 _sample: Option<Handle<crate::Expression>>,
6004 _level: Option<Handle<crate::Expression>>,
6005 ) -> BackendResult {
6006 let class = match *func_ctx.resolve_type(image, &module.types) {
6008 crate::TypeInner::Image { class, .. } => class,
6009 _ => unreachable!(),
6010 };
6011 if class != crate::ImageClass::External {
6012 return Ok(());
6013 }
6014 let wrapped = WrappedFunction::ImageLoad { class };
6015 if !self.wrapped_functions.insert(wrapped) {
6016 return Ok(());
6017 }
6018
6019 writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6020 let l1 = back::Level(1);
6021 let l2 = l1.next();
6022 let l3 = l2.next();
6023 writeln!(
6024 self.out,
6025 "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6026 )?;
6027 writeln!(
6031 self.out,
6032 "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6033 )?;
6034 writeln!(
6035 self.out,
6036 "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6037 )?;
6038
6039 writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6041 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6042 writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6044 writeln!(self.out, "{l1}}} else {{")?;
6045
6046 writeln!(
6048 self.out,
6049 "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6050 )?;
6051 writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6052
6053 writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6055
6056 writeln!(self.out, "{l2}float2 uv;")?;
6057 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6058 writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6060 writeln!(self.out, "{l2}}} else {{")?;
6061 writeln!(
6063 self.out,
6064 "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6065 )?;
6066 writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6067 writeln!(
6068 self.out,
6069 "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6070 )?;
6071 writeln!(self.out, "{l2}}}")?;
6072
6073 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6074
6075 writeln!(self.out, "{l1}}}")?;
6076 writeln!(self.out, "}}")?;
6077 writeln!(self.out)?;
6078 Ok(())
6079 }
6080
6081 #[allow(clippy::too_many_arguments)]
6082 fn write_wrapped_image_sample(
6083 &mut self,
6084 module: &crate::Module,
6085 func_ctx: &back::FunctionCtx,
6086 image: Handle<crate::Expression>,
6087 _sampler: Handle<crate::Expression>,
6088 _gather: Option<crate::SwizzleComponent>,
6089 _coordinate: Handle<crate::Expression>,
6090 _array_index: Option<Handle<crate::Expression>>,
6091 _offset: Option<Handle<crate::Expression>>,
6092 _level: crate::SampleLevel,
6093 _depth_ref: Option<Handle<crate::Expression>>,
6094 clamp_to_edge: bool,
6095 ) -> BackendResult {
6096 if !clamp_to_edge {
6099 return Ok(());
6100 }
6101 let class = match *func_ctx.resolve_type(image, &module.types) {
6102 crate::TypeInner::Image { class, .. } => class,
6103 _ => unreachable!(),
6104 };
6105 let wrapped = WrappedFunction::ImageSample {
6106 class,
6107 clamp_to_edge: true,
6108 };
6109 if !self.wrapped_functions.insert(wrapped) {
6110 return Ok(());
6111 }
6112 match class {
6113 crate::ImageClass::External => {
6114 writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6115 let l1 = back::Level(1);
6116 let l2 = l1.next();
6117 let l3 = l2.next();
6118 writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6119 writeln!(
6120 self.out,
6121 "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6122 )?;
6123
6124 writeln!(
6132 self.out,
6133 "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6134 )?;
6135 writeln!(
6136 self.out,
6137 "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6138 )?;
6139 writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6140 writeln!(
6141 self.out,
6142 "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6143 )?;
6144 writeln!(
6145 self.out,
6146 "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6147 )?;
6148 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6149 writeln!(
6151 self.out,
6152 "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6153 )?;
6154 writeln!(self.out, "{l1}}} else {{")?;
6155 writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6156 writeln!(
6157 self.out,
6158 "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6159 )?;
6160 writeln!(
6161 self.out,
6162 "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6163 )?;
6164
6165 writeln!(
6167 self.out,
6168 "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6169 )?;
6170 writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6171 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6172 writeln!(
6174 self.out,
6175 "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6176 )?;
6177 writeln!(self.out, "{l2}}} else {{")?;
6178 writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6180 writeln!(
6181 self.out,
6182 "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6183 )?;
6184 writeln!(
6185 self.out,
6186 "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6187 )?;
6188 writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6189 writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6190 writeln!(self.out, "{l2}}}")?;
6191
6192 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6193
6194 writeln!(self.out, "{l1}}}")?;
6195 writeln!(self.out, "}}")?;
6196 writeln!(self.out)?;
6197 }
6198 _ => {
6199 writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6200 let l1 = back::Level(1);
6201 writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6202 writeln!(
6203 self.out,
6204 "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6205 )?;
6206 writeln!(self.out, "}}")?;
6207 writeln!(self.out)?;
6208 }
6209 }
6210 Ok(())
6211 }
6212
6213 fn write_wrapped_image_query(
6214 &mut self,
6215 module: &crate::Module,
6216 func_ctx: &back::FunctionCtx,
6217 image: Handle<crate::Expression>,
6218 query: crate::ImageQuery,
6219 ) -> BackendResult {
6220 if !matches!(query, crate::ImageQuery::Size { .. }) {
6222 return Ok(());
6223 }
6224 let class = match *func_ctx.resolve_type(image, &module.types) {
6225 crate::TypeInner::Image { class, .. } => class,
6226 _ => unreachable!(),
6227 };
6228 if class != crate::ImageClass::External {
6229 return Ok(());
6230 }
6231 let wrapped = WrappedFunction::ImageQuerySize { class };
6232 if !self.wrapped_functions.insert(wrapped) {
6233 return Ok(());
6234 }
6235 writeln!(
6236 self.out,
6237 "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6238 )?;
6239 let l1 = back::Level(1);
6240 let l2 = l1.next();
6241 writeln!(
6242 self.out,
6243 "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6244 )?;
6245 writeln!(self.out, "{l2}return tex.params.size;")?;
6246 writeln!(self.out, "{l1}}} else {{")?;
6247 writeln!(
6249 self.out,
6250 "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6251 )?;
6252 writeln!(self.out, "{l1}}}")?;
6253 writeln!(self.out, "}}")?;
6254 writeln!(self.out)?;
6255 Ok(())
6256 }
6257
6258 pub(super) fn write_wrapped_functions(
6259 &mut self,
6260 module: &crate::Module,
6261 func_ctx: &back::FunctionCtx,
6262 ) -> BackendResult {
6263 for (expr_handle, expr) in func_ctx.expressions.iter() {
6264 match *expr {
6265 crate::Expression::Unary { op, expr: operand } => {
6266 self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6267 }
6268 crate::Expression::Binary { op, left, right } => {
6269 self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6270 }
6271 crate::Expression::Math {
6272 fun,
6273 arg,
6274 arg1,
6275 arg2,
6276 arg3,
6277 } => {
6278 self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6279 }
6280 crate::Expression::As {
6281 expr,
6282 kind,
6283 convert,
6284 } => {
6285 self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6286 }
6287 crate::Expression::ImageLoad {
6288 image,
6289 coordinate,
6290 array_index,
6291 sample,
6292 level,
6293 } => {
6294 self.write_wrapped_image_load(
6295 module,
6296 func_ctx,
6297 image,
6298 coordinate,
6299 array_index,
6300 sample,
6301 level,
6302 )?;
6303 }
6304 crate::Expression::ImageSample {
6305 image,
6306 sampler,
6307 gather,
6308 coordinate,
6309 array_index,
6310 offset,
6311 level,
6312 depth_ref,
6313 clamp_to_edge,
6314 } => {
6315 self.write_wrapped_image_sample(
6316 module,
6317 func_ctx,
6318 image,
6319 sampler,
6320 gather,
6321 coordinate,
6322 array_index,
6323 offset,
6324 level,
6325 depth_ref,
6326 clamp_to_edge,
6327 )?;
6328 }
6329 crate::Expression::ImageQuery { image, query } => {
6330 self.write_wrapped_image_query(module, func_ctx, image, query)?;
6331 }
6332 _ => {}
6333 }
6334 }
6335
6336 Ok(())
6337 }
6338
    /// Writes all regular functions, then every selected entry point, of
    /// `module` as MSL source into `self.out`.
    ///
    /// Regular functions get any pass-through globals they use (and, when some
    /// used global needs a runtime array length, a trailing `_buffer_sizes`
    /// parameter) appended to their parameter list.
    ///
    /// Each entry point gets:
    /// - a `[[stage_in]]` input struct for location-bound arguments;
    /// - an output struct when it has a result;
    /// - builtin and resource parameters;
    /// - a body prologue that rebuilds the original arguments.
    ///
    /// When vertex pulling applies, location-bound vertex inputs are instead
    /// fetched manually from per-buffer byte arrays.
    ///
    /// Returns a `TranslationInfo` with one entry per emitted entry point.
    /// An entry point whose bindings fail validation is skipped, and an `Err`
    /// is recorded for it instead of its name. Formatting errors and MSL
    /// version restrictions abort the whole call with `Err`.
    fn write_functions(
        &mut self,
        module: &crate::Module,
        mod_info: &valid::ModuleInfo,
        options: &Options,
        pipeline_options: &PipelineOptions,
    ) -> Result<TranslationInfo, Error> {
        use back::msl::VertexFormat;

        // Per-attribute info for vertex pulling, keyed by shader location:
        // the shader-side type name, component count, whether the scalar is
        // an integer, and the variable name used in the entry point body.
        struct AttributeMappingResolved {
            ty_name: String,
            dimension: u32,
            ty_is_int: bool,
            name: String,
        }
        let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();

        // Per-vertex-buffer names generated up front for vertex pulling.
        struct VertexBufferMappingResolved<'a> {
            id: u32,
            stride: u32,
            indexed_by_vertex: bool,
            ty_name: String,
            param_name: String,
            elem_name: String,
            attributes: &'a Vec<back::msl::AttributeMapping>,
        }
        let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();

        // A generated unpacking helper for one vertex format: its name, the
        // number of bytes it consumes and the number of components it yields.
        struct UnpackingFunction {
            name: String,
            byte_count: u32,
            dimension: u32,
        }
        let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();

        // Names for the [[vertex_id]] / [[instance_id]] parameters that vertex
        // pulling may add; they are reserved with the namer unconditionally.
        let mut needs_vertex_id = false;
        let v_id = self.namer.call("v_id");

        let mut needs_instance_id = false;
        let i_id = self.namer.call("i_id");
        if pipeline_options.vertex_pulling_transform {
            for vbm in &pipeline_options.vertex_buffer_mappings {
                let buffer_id = vbm.id;
                let buffer_stride = vbm.stride;

                assert!(
                    buffer_stride > 0,
                    "Vertex pulling requires a non-zero buffer stride."
                );

                // Remember which index builtin(s) the pulling code will need.
                if vbm.indexed_by_vertex {
                    needs_vertex_id = true;
                } else {
                    needs_instance_id = true;
                }

                let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
                let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
                let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());

                vbm_resolved.push(VertexBufferMappingResolved {
                    id: buffer_id,
                    stride: buffer_stride,
                    indexed_by_vertex: vbm.indexed_by_vertex,
                    ty_name: buffer_ty,
                    param_name: buffer_param,
                    elem_name: buffer_elem,
                    attributes: &vbm.attributes,
                });

                // Write each distinct unpacking function only once.
                for attribute in &vbm.attributes {
                    if unpacking_functions.contains_key(&attribute.format) {
                        continue;
                    }
                    // Formats whose unpacking function could not be written are
                    // skipped here; the error value is discarded.
                    let (name, byte_count, dimension) =
                        match self.write_unpacking_function(attribute.format) {
                            Ok((name, byte_count, dimension)) => (name, byte_count, dimension),
                            _ => {
                                continue;
                            }
                        };
                    unpacking_functions.insert(
                        attribute.format,
                        UnpackingFunction {
                            name,
                            byte_count,
                            dimension,
                        },
                    );
                }
            }
        }

        // Reused across iterations to avoid reallocating per function.
        let mut pass_through_globals = Vec::new();
        for (fun_handle, fun) in module.functions.iter() {
            log::trace!(
                "function {:?}, handle {:?}",
                fun.name.as_deref().unwrap_or("(anonymous)"),
                fun_handle
            );

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::Function(fun_handle),
                info: &mod_info[fun_handle],
                expressions: &fun.expressions,
                named_expressions: &fun.named_expressions,
            };

            writeln!(self.out)?;
            // Helper functions used by this function are written before it.
            self.write_wrapped_functions(module, &ctx)?;

            let fun_info = &mod_info[fun_handle];
            // Collect the used globals that must be passed in as parameters,
            // and note whether any used global needs a runtime array length.
            pass_through_globals.clear();
            let mut needs_buffer_sizes = false;
            for (handle, var) in module.global_variables.iter() {
                if !fun_info[handle].is_empty() {
                    if var.space.needs_pass_through() {
                        pass_through_globals.push(handle);
                    }
                    needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
                }
            }

            let fun_name = &self.names[&NameKey::Function(fun_handle)];
            match fun.result {
                Some(ref result) => {
                    let ty_name = TypeContext {
                        handle: result.ty,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        first_time: false,
                    };
                    write!(self.out, "{ty_name}")?;
                }
                None => {
                    write!(self.out, "void")?;
                }
            }
            writeln!(self.out, " {fun_name}(")?;

            for (index, arg) in fun.arguments.iter().enumerate() {
                let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
                let param_type_name = TypeContext {
                    handle: arg.ty,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                // A separator follows unless this is the final parameter overall
                // (pass-through globals and `_buffer_sizes` come after).
                let separator = separate(
                    !pass_through_globals.is_empty()
                        || index + 1 != fun.arguments.len()
                        || needs_buffer_sizes,
                );
                writeln!(
                    self.out,
                    "{}{} {}{}",
                    back::INDENT,
                    param_type_name,
                    name,
                    separator
                )?;
            }
            for (index, &handle) in pass_through_globals.iter().enumerate() {
                let tyvar = TypedGlobalVariable {
                    module,
                    names: &self.names,
                    handle,
                    usage: fun_info[handle],

                    reference: true,
                };
                let separator =
                    separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
                write!(self.out, "{}", back::INDENT)?;
                tyvar.try_fmt(&mut self.out)?;
                writeln!(self.out, "{separator}")?;
            }

            if needs_buffer_sizes {
                writeln!(
                    self.out,
                    "{}constant _mslBufferSizes& _buffer_sizes",
                    back::INDENT
                )?;
            }

            writeln!(self.out, ") {{")?;

            let guarded_indices =
                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);

            let context = StatementContext {
                expression: ExpressionContext {
                    function: fun,
                    origin: FunctionOrigin::Handle(fun_handle),
                    info: fun_info,
                    lang_version: options.lang_version,
                    policies: options.bounds_check_policies,
                    guarded_indices,
                    module,
                    mod_info,
                    pipeline_options,
                    force_loop_bounding: options.force_loop_bounding,
                },
                result_struct: None,
            };

            self.put_locals(&context.expression)?;
            self.update_expressions_to_bake(fun, fun_info, &context.expression);
            self.put_block(back::Level(1), &fun.body, &context)?;
            writeln!(self.out, "}}")?;
            self.named_expressions.clear();
        }

        let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
            .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

        let mut info = TranslationInfo {
            entry_point_names: Vec::with_capacity(ep_range.len()),
        };

        for ep_index in ep_range {
            let ep = &module.entry_points[ep_index];
            let fun = &ep.function;
            let fun_info = mod_info.get_entry_point(ep_index);
            let mut ep_error = None;

            // Names of user-declared vertex_index / instance_index builtin
            // arguments, if any. Vertex pulling indexes with these instead of
            // declaring `v_id` / `i_id`.
            let mut v_existing_id = None;
            let mut i_existing_id = None;

            log::trace!(
                "entry point {:?}, index {:?}",
                fun.name.as_deref().unwrap_or("(anonymous)"),
                ep_index
            );

            let ctx = back::FunctionCtx {
                ty: back::FunctionType::EntryPoint(ep_index as u16),
                info: fun_info,
                expressions: &fun.expressions,
                named_expressions: &fun.named_expressions,
            };

            self.write_wrapped_functions(module, &ctx)?;

            let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
                crate::ShaderStage::Vertex => (
                    "vertex",
                    LocationMode::VertexInput,
                    LocationMode::VertexOutput,
                    true,
                ),
                crate::ShaderStage::Fragment => (
                    "fragment",
                    LocationMode::FragmentInput,
                    LocationMode::FragmentOutput,
                    false,
                ),
                crate::ShaderStage::Compute => (
                    "kernel",
                    LocationMode::Uniform,
                    LocationMode::Uniform,
                    false,
                ),
                crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(),
            };

            // Vertex pulling applies only to vertex stages, and only when the
            // transform is enabled and at least one buffer is mapped.
            let do_vertex_pulling = can_vertex_pull
                && pipeline_options.vertex_pulling_transform
                && !pipeline_options.vertex_buffer_mappings.is_empty();

            // The vertex-pulling bounds checks read `_buffer_sizes`, so pulling
            // always requires the sizes buffer.
            let needs_buffer_sizes = do_vertex_pulling
                || module
                    .global_variables
                    .iter()
                    .filter(|&(handle, _)| !fun_info[handle].is_empty())
                    .any(|(_, var)| needs_array_length(var.ty, &module.types));

            // Validate that every used resource has a binding and a bind target
            // of the right kind, unless missing bindings are to be faked. The
            // first failure is recorded in `ep_error`.
            if !options.fake_missing_bindings {
                for (var_handle, var) in module.global_variables.iter() {
                    if fun_info[var_handle].is_empty() {
                        continue;
                    }
                    match var.space {
                        crate::AddressSpace::Uniform
                        | crate::AddressSpace::Storage { .. }
                        | crate::AddressSpace::Handle => {
                            let br = match var.binding {
                                Some(ref br) => br,
                                None => {
                                    let var_name = var.name.clone().unwrap_or_default();
                                    ep_error =
                                        Some(super::EntryPointError::MissingBinding(var_name));
                                    break;
                                }
                            };
                            let target = options.get_resource_binding_target(ep, br);
                            let good = match target {
                                Some(target) => {
                                    // Binding arrays are not dereferenced here, so
                                    // they fall through to the buffer check.
                                    match module.types[var.ty].inner {
                                        crate::TypeInner::Image {
                                            class: crate::ImageClass::External,
                                            ..
                                        } => target.external_texture.is_some(),
                                        crate::TypeInner::Image { .. } => target.texture.is_some(),
                                        crate::TypeInner::Sampler { .. } => {
                                            target.sampler.is_some()
                                        }
                                        _ => target.buffer.is_some(),
                                    }
                                }
                                None => false,
                            };
                            if !good {
                                ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
                                break;
                            }
                        }
                        crate::AddressSpace::PushConstant => {
                            if let Err(e) = options.resolve_push_constants(ep) {
                                ep_error = Some(e);
                                break;
                            }
                        }
                        crate::AddressSpace::Function
                        | crate::AddressSpace::Private
                        | crate::AddressSpace::WorkGroup => {}
                    }
                }
                if needs_buffer_sizes {
                    if let Err(err) = options.resolve_sizes_buffer(ep) {
                        ep_error = Some(err);
                    }
                }
            }

            // A failing entry point is reported and skipped, not written.
            if let Some(err) = ep_error {
                info.entry_point_names.push(Err(err));
                continue;
            }
            let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
            info.entry_point_names.push(Ok(fun_name.clone()));

            writeln!(self.out)?;

            // Struct-typed arguments are flattened into their members. When
            // vertex pulling is off, location-bound members are named with
            // `varyings_namer`, since they become members of the stage-in
            // struct. All other members get unique names from `self.namer`.
            let mut flattened_member_names = FastHashMap::default();
            let mut varyings_namer = proc::Namer::default();

            // (name key, type, binding) for every argument after flattening.
            let mut flattened_arguments = Vec::new();
            for (arg_index, arg) in fun.arguments.iter().enumerate() {
                match module.types[arg.ty].inner {
                    crate::TypeInner::Struct { ref members, .. } => {
                        for (member_index, member) in members.iter().enumerate() {
                            let member_index = member_index as u32;
                            flattened_arguments.push((
                                NameKey::StructMember(arg.ty, member_index),
                                member.ty,
                                member.binding.as_ref(),
                            ));
                            let name_key = NameKey::StructMember(arg.ty, member_index);
                            let name = match member.binding {
                                Some(crate::Binding::Location { .. }) => {
                                    if do_vertex_pulling {
                                        self.namer.call(&self.names[&name_key])
                                    } else {
                                        varyings_namer.call(&self.names[&name_key])
                                    }
                                }
                                _ => self.namer.call(&self.names[&name_key]),
                            };
                            flattened_member_names.insert(name_key, name);
                        }
                    }
                    _ => flattened_arguments.push((
                        NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
                        arg.ty,
                        arg.binding.as_ref(),
                    )),
                }
            }

            let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
            let varyings_member_name = self.namer.call("varyings");
            let mut has_varyings = false;
            if !flattened_arguments.is_empty() {
                // Without vertex pulling, location-bound arguments are written as
                // members of the `[[stage_in]]` struct. With vertex pulling they
                // are recorded in `am_resolved` and fetched manually in the body.
                if !do_vertex_pulling {
                    writeln!(self.out, "struct {stage_in_name} {{")?;
                }
                for &(ref name_key, ty, binding) in flattened_arguments.iter() {
                    let (binding, location) = match binding {
                        Some(ref binding @ &crate::Binding::Location { location, .. }) => {
                            (binding, location)
                        }
                        _ => continue,
                    };
                    let name = match *name_key {
                        NameKey::StructMember(..) => &flattened_member_names[name_key],
                        _ => &self.names[name_key],
                    };
                    let ty_name = TypeContext {
                        handle: ty,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        first_time: false,
                    };
                    let resolved = options.resolve_local_binding(binding, in_mode)?;
                    if do_vertex_pulling {
                        am_resolved.insert(
                            location,
                            AttributeMappingResolved {
                                ty_name: ty_name.to_string(),
                                dimension: ty_name.vertex_input_dimension(),
                                ty_is_int: ty_name.scalar().is_some_and(scalar_is_int),
                                name: name.to_string(),
                            },
                        );
                    } else {
                        has_varyings = true;
                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
                        resolved.try_fmt(&mut self.out)?;
                        writeln!(self.out, ";")?;
                    }
                }
                if !do_vertex_pulling {
                    writeln!(self.out, "}};")?;
                }
            }

            let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
            let result_member_name = self.namer.call("member");
            let result_type_name = match fun.result {
                Some(ref result) => {
                    // A struct result yields one output member per field; any
                    // other result becomes a single member named
                    // `result_member_name`.
                    let mut result_members = Vec::new();
                    if let crate::TypeInner::Struct { ref members, .. } =
                        module.types[result.ty].inner
                    {
                        for (member_index, member) in members.iter().enumerate() {
                            result_members.push((
                                &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
                                member.ty,
                                member.binding.as_ref(),
                            ));
                        }
                    } else {
                        result_members.push((
                            &result_member_name,
                            result.ty,
                            result.binding.as_ref(),
                        ));
                    }

                    writeln!(self.out, "struct {stage_out_name} {{")?;
                    let mut has_point_size = false;
                    for (name, ty, binding) in result_members {
                        let ty_name = TypeContext {
                            handle: ty,
                            gctx: module.to_ctx(),
                            names: &self.names,
                            access: crate::StorageAccess::empty(),
                            first_time: true,
                        };
                        let binding = binding.ok_or_else(|| {
                            Error::GenericValidation("Expected binding, got None".into())
                        })?;

                        // A user-declared point size is omitted from the output
                        // struct unless `allow_and_force_point_size` is set.
                        if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
                            has_point_size = true;
                            if !pipeline_options.allow_and_force_point_size {
                                continue;
                            }
                        }

                        // Fixed-size array members get a `[N]` suffix after the
                        // member name.
                        let array_len = match module.types[ty].inner {
                            crate::TypeInner::Array {
                                size: crate::ArraySize::Constant(size),
                                ..
                            } => Some(size),
                            _ => None,
                        };
                        let resolved = options.resolve_local_binding(binding, out_mode)?;
                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
                        if let Some(array_len) = array_len {
                            write!(self.out, " [{array_len}]")?;
                        }
                        resolved.try_fmt(&mut self.out)?;
                        writeln!(self.out, ";")?;
                    }

                    // When point size is forced, vertex stages that don't write
                    // one themselves get a synthetic `_point_size` member.
                    if pipeline_options.allow_and_force_point_size
                        && ep.stage == crate::ShaderStage::Vertex
                        && !has_point_size
                    {
                        writeln!(
                            self.out,
                            "{}float _point_size [[point_size]];",
                            back::INDENT
                        )?;
                    }
                    writeln!(self.out, "}};")?;
                    &stage_out_name
                }
                None => "void",
            };

            if do_vertex_pulling {
                // One opaque byte-array struct per vertex buffer, sized by its
                // stride, so the buffer can be indexed element-wise.
                for vbm in &vbm_resolved {
                    let buffer_stride = vbm.stride;
                    let buffer_ty = &vbm.ty_name;

                    writeln!(
                        self.out,
                        "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
                    )?;
                }
            }

            writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
            // Tracks whether the next parameter is the first one, and therefore
            // gets a ' ' instead of a ',' separator.
            let mut is_first_argument = true;

            if has_varyings {
                writeln!(
                    self.out,
                    " {stage_in_name} {varyings_member_name} [[stage_in]]"
                )?;
                is_first_argument = false;
            }

            let mut local_invocation_id = None;

            // Builtin-bound arguments are passed directly as parameters with
            // their resolved attribute.
            for &(ref name_key, ty, binding) in flattened_arguments.iter() {
                let binding = match binding {
                    Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
                    _ => continue,
                };
                let name = match *name_key {
                    NameKey::StructMember(..) => &flattened_member_names[name_key],
                    _ => &self.names[name_key],
                };

                if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
                    local_invocation_id = Some(name_key);
                }

                let ty_name = TypeContext {
                    handle: ty,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };

                match *binding {
                    crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
                        v_existing_id = Some(name.clone());
                    }
                    crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
                        i_existing_id = Some(name.clone());
                    }
                    _ => {}
                };

                let resolved = options.resolve_local_binding(binding, in_mode)?;
                let separator = if is_first_argument {
                    is_first_argument = false;
                    ' '
                } else {
                    ','
                };
                write!(self.out, "{separator} {ty_name} {name}")?;
                resolved.try_fmt(&mut self.out)?;
                writeln!(self.out)?;
            }

            // Workgroup variable initialization is handed the local invocation
            // id. If the entry point doesn't already take one, declare
            // `__local_invocation_id`.
            let need_workgroup_variables_initialization =
                self.need_workgroup_variables_initialization(options, ep, module, fun_info);

            if need_workgroup_variables_initialization && local_invocation_id.is_none() {
                let separator = if is_first_argument {
                    is_first_argument = false;
                    ' '
                } else {
                    ','
                };
                writeln!(
                    self.out,
                    "{separator} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]"
                )?;
            }

            for (handle, var) in module.global_variables.iter() {
                // Private globals are declared inside the body below, not as
                // parameters.
                let usage = fun_info[handle];
                if usage.is_empty() || var.space == crate::AddressSpace::Private {
                    continue;
                }

                // Reject writable storage buffers/textures that MSL versions
                // before 1.2 cannot support in this stage.
                if options.lang_version < (1, 2) {
                    match var.space {
                        crate::AddressSpace::Storage { access }
                            if access.contains(crate::StorageAccess::STORE)
                                && ep.stage == crate::ShaderStage::Fragment =>
                        {
                            return Err(Error::UnsupportedWriteableStorageBuffer)
                        }
                        crate::AddressSpace::Handle => {
                            match module.types[var.ty].inner {
                                crate::TypeInner::Image {
                                    class: crate::ImageClass::Storage { access, .. },
                                    ..
                                } => {
                                    if access.contains(crate::StorageAccess::STORE)
                                        && (ep.stage == crate::ShaderStage::Vertex
                                            || ep.stage == crate::ShaderStage::Fragment)
                                    {
                                        return Err(Error::UnsupportedWriteableStorageTexture(
                                            ep.stage,
                                        ));
                                    }

                                    if access.contains(
                                        crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
                                    ) {
                                        return Err(Error::UnsupportedRWStorageTexture);
                                    }
                                }
                                _ => {}
                            }
                        }
                        _ => {}
                    }
                }

                // Check that binding arrays of this element type are supported by
                // the target MSL version. Read-write storage and external textures
                // are never supported as arrays.
                match var.space {
                    crate::AddressSpace::Handle => match module.types[var.ty].inner {
                        crate::TypeInner::BindingArray { base, .. } => {
                            match module.types[base].inner {
                                crate::TypeInner::Sampler { .. } => {
                                    if options.lang_version < (2, 0) {
                                        return Err(Error::UnsupportedArrayOf(
                                            "samplers".to_string(),
                                        ));
                                    }
                                }
                                crate::TypeInner::Image { class, .. } => match class {
                                    crate::ImageClass::Sampled { .. }
                                    | crate::ImageClass::Depth { .. }
                                    | crate::ImageClass::Storage {
                                        access: crate::StorageAccess::LOAD,
                                        ..
                                    } => {
                                        if options.lang_version < (2, 0) {
                                            return Err(Error::UnsupportedArrayOf(
                                                "textures".to_string(),
                                            ));
                                        }
                                    }
                                    crate::ImageClass::Storage {
                                        access: crate::StorageAccess::STORE,
                                        ..
                                    } => {
                                        if options.lang_version < (2, 0) {
                                            return Err(Error::UnsupportedArrayOf(
                                                "write-only textures".to_string(),
                                            ));
                                        }
                                    }
                                    crate::ImageClass::Storage { .. } => {
                                        return Err(Error::UnsupportedArrayOf(
                                            "read-write textures".to_string(),
                                        ));
                                    }
                                    crate::ImageClass::External => {
                                        return Err(Error::UnsupportedArrayOf(
                                            "external textures".to_string(),
                                        ));
                                    }
                                },
                                _ => {
                                    return Err(Error::UnsupportedArrayOfType(base));
                                }
                            }
                        }
                        _ => {}
                    },
                    _ => {}
                }

                // Bindings were validated above when `!fake_missing_bindings`.
                // Here a failed resolution just means no binding attribute is
                // written.
                let resolved = match var.space {
                    crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
                    crate::AddressSpace::WorkGroup => None,
                    _ => options
                        .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
                        .ok(),
                };
                if let Some(ref resolved) = resolved {
                    // Inline samplers are defined as `constexpr` in the entry point
                    // body instead of being passed as parameters.
                    if resolved.as_inline_sampler(options).is_some() {
                        continue;
                    }
                }

                let mut separator = || {
                    if is_first_argument {
                        is_first_argument = false;
                        ' '
                    } else {
                        ','
                    }
                };

                match module.types[var.ty].inner {
                    crate::TypeInner::Image {
                        class: crate::ImageClass::External,
                        ..
                    } => {
                        // An external texture is passed as three plane textures
                        // plus a constant params buffer. They are reassembled into
                        // the wrapper struct in the body.
                        let target = match resolved {
                            Some(back::msl::ResolvedBinding::Resource(target)) => {
                                target.external_texture
                            }
                            _ => None,
                        };

                        for i in 0..3 {
                            write!(self.out, "{} ", separator())?;

                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
                                handle,
                                ExternalTextureNameKey::Plane(i),
                            )];
                            write!(
                                self.out,
                                "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> {plane_name}"
                            )?;
                            if let Some(ref target) = target {
                                write!(self.out, " [[texture({})]]", target.planes[i])?;
                            }
                            writeln!(self.out)?;
                        }
                        let params_ty_name = &self.names
                            [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
                            handle,
                            ExternalTextureNameKey::Params,
                        )];
                        write!(self.out, "{} ", separator())?;
                        write!(self.out, "constant {params_ty_name}& {params_name}")?;
                        if let Some(ref target) = target {
                            write!(self.out, " [[buffer({})]]", target.params)?;
                        }
                    }
                    _ => {
                        let tyvar = TypedGlobalVariable {
                            module,
                            names: &self.names,
                            handle,
                            usage,
                            reference: true,
                        };
                        write!(self.out, "{} ", separator())?;
                        tyvar.try_fmt(&mut self.out)?;
                        if let Some(resolved) = resolved {
                            resolved.try_fmt(&mut self.out)?;
                        }
                        if let Some(value) = var.init {
                            write!(self.out, " = ")?;
                            self.put_const_expression(
                                value,
                                module,
                                mod_info,
                                &module.global_expressions,
                            )?;
                        }
                    }
                }
                writeln!(self.out)?;
            }

            if do_vertex_pulling {
                assert!(needs_vertex_id || needs_instance_id);

                let mut separator = if is_first_argument {
                    is_first_argument = false;
                    ' '
                } else {
                    ','
                };

                // Declare [[vertex_id]] / [[instance_id]] only if the user did not
                // already take the corresponding builtin.
                if needs_vertex_id && v_existing_id.is_none() {
                    writeln!(self.out, "{separator} uint {v_id} [[vertex_id]]")?;
                    separator = ',';
                }

                if needs_instance_id && i_existing_id.is_none() {
                    writeln!(self.out, "{separator} uint {i_id} [[instance_id]]")?;
                }

                // One device pointer parameter per vertex buffer. A ',' is always
                // correct here: at least one parameter has been written by now.
                for vbm in &vbm_resolved {
                    let id = &vbm.id;
                    let ty_name = &vbm.ty_name;
                    let param_name = &vbm.param_name;
                    writeln!(
                        self.out,
                        ", const device {ty_name}* {param_name} [[buffer({id})]]"
                    )?;
                }
            }

            if needs_buffer_sizes {
                // Resolution was validated above when `!fake_missing_bindings`.
                // NOTE(review): with `fake_missing_bindings` set, this `unwrap` is
                // not pre-checked; confirm `resolve_sizes_buffer` cannot fail then.
                let resolved = options.resolve_sizes_buffer(ep).unwrap();
                let separator = if is_first_argument { ' ' } else { ',' };
                write!(
                    self.out,
                    "{separator} constant _mslBufferSizes& _buffer_sizes",
                )?;
                resolved.try_fmt(&mut self.out)?;
                writeln!(self.out)?;
            }

            writeln!(self.out, ") {{")?;

            if do_vertex_pulling {
                // For each vertex buffer, first zero-initialize every attribute
                // this entry point uses. Then, if the index is within the buffer
                // (checked against `_buffer_sizes`), fetch the element and unpack
                // each attribute from its bytes.
                for vbm in &vbm_resolved {
                    for attribute in vbm.attributes {
                        let location = attribute.shader_location;
                        let am_option = am_resolved.get(&location);
                        if am_option.is_none() {
                            // This attribute isn't consumed by this entry point.
                            continue;
                        }
                        let am = am_option.unwrap();
                        let attribute_ty_name = &am.ty_name;
                        let attribute_name = &am.name;

                        writeln!(
                            self.out,
                            "{}{attribute_ty_name} {attribute_name} = {{}};",
                            back::Level(1)
                        )?;
                    }

                    write!(self.out, "{}if (", back::Level(1))?;

                    let idx = &vbm.id;
                    let stride = &vbm.stride;
                    // Prefer the user's own vertex/instance index argument, if any.
                    let index_name = if vbm.indexed_by_vertex {
                        if let Some(ref name) = v_existing_id {
                            name
                        } else {
                            &v_id
                        }
                    } else if let Some(ref name) = i_existing_id {
                        name
                    } else {
                        &i_id
                    };
                    write!(
                        self.out,
                        "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
                    )?;

                    writeln!(self.out, ") {{")?;

                    let ty_name = &vbm.ty_name;
                    let elem_name = &vbm.elem_name;
                    let param_name = &vbm.param_name;

                    writeln!(
                        self.out,
                        "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
                        back::Level(2),
                    )?;

                    for attribute in vbm.attributes {
                        let location = attribute.shader_location;
                        let Some(am) = am_resolved.get(&location) else {
                            // Not consumed by this entry point.
                            continue;
                        };
                        let attribute_name = &am.name;
                        let attribute_ty_name = &am.ty_name;

                        let offset = attribute.offset;
                        // NOTE(review): formats whose unpacking function failed to
                        // generate were skipped earlier, so this `expect` would
                        // panic for such a format.
                        let func = unpacking_functions
                            .get(&attribute.format)
                            .expect("Should have generated this unpacking function earlier.");
                        let func_name = &func.name;

                        // Compare the shader-side component count with what the
                        // unpacking function yields. `Greater`: wrap the call in a
                        // constructor and pad (component index 3 gets 1, others 0).
                        // `Less`: truncate with a swizzle.
                        let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);

                        if needs_padding_or_truncation != Ordering::Equal {
                            // Flag the conversion in a comment in the output.
                            writeln!(
                                self.out,
                                "{}// {attribute_ty_name} <- {:?}",
                                back::Level(2),
                                attribute.format
                            )?;
                        }

                        write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;

                        if needs_padding_or_truncation == Ordering::Greater {
                            write!(self.out, "{attribute_ty_name}(")?;
                        }

                        // Pass the attribute's `byte_count` bytes, starting at
                        // `offset`, to the unpacking function.
                        write!(self.out, "{func_name}({elem_name}.data[{offset}]",)?;
                        for i in (offset + 1)..(offset + func.byte_count) {
                            write!(self.out, ", {elem_name}.data[{i}]")?;
                        }
                        write!(self.out, ")")?;

                        match needs_padding_or_truncation {
                            Ordering::Greater => {
                                let zero_value = if am.ty_is_int { "0" } else { "0.0" };
                                let one_value = if am.ty_is_int { "1" } else { "1.0" };
                                for i in func.dimension..am.dimension {
                                    write!(
                                        self.out,
                                        ", {}",
                                        if i == 3 { one_value } else { zero_value }
                                    )?;
                                }
                                write!(self.out, ")")?;
                            }
                            Ordering::Less => {
                                write!(
                                    self.out,
                                    ".{}",
                                    &"xyzw"[0..usize::try_from(am.dimension).unwrap()]
                                )?;
                            }
                            Ordering::Equal => {}
                        }

                        writeln!(self.out, ";")?;
                    }

                    writeln!(self.out, "{}}}", back::Level(1))?;
                }
            }

            if need_workgroup_variables_initialization {
                self.write_workgroup_variables_initialization(
                    module,
                    mod_info,
                    fun_info,
                    local_invocation_id,
                )?;
            }

            // Body prologue for globals: private globals become initialized
            // locals, inline samplers become `constexpr` samplers, and external
            // textures are reassembled into their wrapper struct.
            for (handle, var) in module.global_variables.iter() {
                let usage = fun_info[handle];
                if usage.is_empty() {
                    continue;
                }
                if var.space == crate::AddressSpace::Private {
                    let tyvar = TypedGlobalVariable {
                        module,
                        names: &self.names,
                        handle,
                        usage,

                        reference: false,
                    };
                    write!(self.out, "{}", back::INDENT)?;
                    tyvar.try_fmt(&mut self.out)?;
                    match var.init {
                        Some(value) => {
                            write!(self.out, " = ")?;
                            self.put_const_expression(
                                value,
                                module,
                                mod_info,
                                &module.global_expressions,
                            )?;
                            writeln!(self.out, ";")?;
                        }
                        None => {
                            writeln!(self.out, " = {{}};")?;
                        }
                    };
                } else if let Some(ref binding) = var.binding {
                    let resolved = options.resolve_resource_binding(ep, binding).unwrap();
                    if let Some(sampler) = resolved.as_inline_sampler(options) {
                        let name = &self.names[&NameKey::GlobalVariable(handle)];
                        writeln!(
                            self.out,
                            "{}constexpr {}::sampler {}(",
                            back::INDENT,
                            NAMESPACE,
                            name
                        )?;
                        self.put_inline_sampler_properties(back::Level(2), sampler)?;
                        writeln!(self.out, "{});", back::INDENT)?;
                    } else if let crate::TypeInner::Image {
                        class: crate::ImageClass::External,
                        ..
                    } = module.types[var.ty].inner
                    {
                        // Rebuild the external texture wrapper from its three plane
                        // parameters and its params buffer.
                        let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
                        let l1 = back::Level(1);
                        let l2 = l1.next();
                        writeln!(
                            self.out,
                            "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
                        )?;
                        for i in 0..3 {
                            let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
                                handle,
                                ExternalTextureNameKey::Plane(i),
                            )];
                            writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
                        }
                        let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
                            handle,
                            ExternalTextureNameKey::Params,
                        )];
                        writeln!(self.out, "{l2}.params = {params_name},")?;
                        writeln!(self.out, "{l1}}};")?;
                    }
                }
            }

            // Rebuild each original argument from the flattened parameters.
            // Struct arguments are reassembled from their members. Non-struct
            // location arguments are read out of the stage-in struct.
            for (arg_index, arg) in fun.arguments.iter().enumerate() {
                let arg_name =
                    &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
                match module.types[arg.ty].inner {
                    crate::TypeInner::Struct { ref members, .. } => {
                        let struct_name = &self.names[&NameKey::Type(arg.ty)];
                        write!(
                            self.out,
                            "{}const {} {} = {{ ",
                            back::INDENT,
                            struct_name,
                            arg_name
                        )?;
                        for (member_index, member) in members.iter().enumerate() {
                            let key = NameKey::StructMember(arg.ty, member_index as u32);
                            let name = &flattened_member_names[&key];
                            if member_index != 0 {
                                write!(self.out, ", ")?;
                            }
                            // Emit a `{}` initializer for a padding member that
                            // precedes this one, if the struct has one.
                            if self
                                .struct_member_pads
                                .contains(&(arg.ty, member_index as u32))
                            {
                                write!(self.out, "{{}}, ")?;
                            }
                            if let Some(crate::Binding::Location { .. }) = member.binding {
                                if has_varyings {
                                    write!(self.out, "{varyings_member_name}.")?;
                                }
                            }
                            write!(self.out, "{name}")?;
                        }
                        writeln!(self.out, " }};")?;
                    }
                    _ => {
                        if let Some(crate::Binding::Location { .. }) = arg.binding {
                            if has_varyings {
                                writeln!(
                                    self.out,
                                    "{}const auto {} = {}.{};",
                                    back::INDENT,
                                    arg_name,
                                    varyings_member_name,
                                    arg_name
                                )?;
                            }
                        }
                    }
                }
            }

            let guarded_indices =
                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);

            let context = StatementContext {
                expression: ExpressionContext {
                    function: fun,
                    origin: FunctionOrigin::EntryPoint(ep_index as _),
                    info: fun_info,
                    lang_version: options.lang_version,
                    policies: options.bounds_check_policies,
                    guarded_indices,
                    module,
                    mod_info,
                    pipeline_options,
                    force_loop_bounding: options.force_loop_bounding,
                },
                result_struct: Some(&stage_out_name),
            };

            self.put_locals(&context.expression)?;
            self.update_expressions_to_bake(fun, fun_info, &context.expression);
            self.put_block(back::Level(1), &fun.body, &context)?;
            writeln!(self.out, "}}")?;
            // Separate entry points with a blank line, except after the last.
            if ep_index + 1 != module.entry_points.len() {
                writeln!(self.out)?;
            }
            self.named_expressions.clear();
        }

        Ok(info)
    }
7577
7578 fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
7579 if flags.is_empty() {
7582 writeln!(
7583 self.out,
7584 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
7585 )?;
7586 }
7587 if flags.contains(crate::Barrier::STORAGE) {
7588 writeln!(
7589 self.out,
7590 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
7591 )?;
7592 }
7593 if flags.contains(crate::Barrier::WORK_GROUP) {
7594 writeln!(
7595 self.out,
7596 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7597 )?;
7598 }
7599 if flags.contains(crate::Barrier::SUB_GROUP) {
7600 writeln!(
7601 self.out,
7602 "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7603 )?;
7604 }
7605 if flags.contains(crate::Barrier::TEXTURE) {
7606 writeln!(
7607 self.out,
7608 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
7609 )?;
7610 }
7611 Ok(())
7612 }
7613}
7614
7615mod workgroup_mem_init {
7618 use crate::EntryPoint;
7619
7620 use super::*;
7621
7622 enum Access {
7623 GlobalVariable(Handle<crate::GlobalVariable>),
7624 StructMember(Handle<crate::Type>, u32),
7625 Array(usize),
7626 }
7627
7628 impl Access {
7629 fn write<W: Write>(
7630 &self,
7631 writer: &mut W,
7632 names: &FastHashMap<NameKey, String>,
7633 ) -> Result<(), core::fmt::Error> {
7634 match *self {
7635 Access::GlobalVariable(handle) => {
7636 write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
7637 }
7638 Access::StructMember(handle, index) => {
7639 write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
7640 }
7641 Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
7642 }
7643 }
7644 }
7645
7646 struct AccessStack {
7647 stack: Vec<Access>,
7648 array_depth: usize,
7649 }
7650
7651 impl AccessStack {
7652 const fn new() -> Self {
7653 Self {
7654 stack: Vec::new(),
7655 array_depth: 0,
7656 }
7657 }
7658
7659 fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
7660 let array_depth = self.array_depth;
7661 self.stack.push(Access::Array(array_depth));
7662 self.array_depth += 1;
7663 let res = cb(self, array_depth);
7664 self.stack.pop();
7665 self.array_depth -= 1;
7666 res
7667 }
7668
7669 fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
7670 self.stack.push(new);
7671 let res = cb(self);
7672 self.stack.pop();
7673 res
7674 }
7675
7676 fn write<W: Write>(
7677 &self,
7678 writer: &mut W,
7679 names: &FastHashMap<NameKey, String>,
7680 ) -> Result<(), core::fmt::Error> {
7681 for next in self.stack.iter() {
7682 next.write(writer, names)?;
7683 }
7684 Ok(())
7685 }
7686 }
7687
7688 impl<W: Write> Writer<W> {
7689 pub(super) fn need_workgroup_variables_initialization(
7690 &mut self,
7691 options: &Options,
7692 ep: &EntryPoint,
7693 module: &crate::Module,
7694 fun_info: &valid::FunctionInfo,
7695 ) -> bool {
7696 options.zero_initialize_workgroup_memory
7697 && ep.stage == crate::ShaderStage::Compute
7698 && module.global_variables.iter().any(|(handle, var)| {
7699 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7700 })
7701 }
7702
7703 pub(super) fn write_workgroup_variables_initialization(
7704 &mut self,
7705 module: &crate::Module,
7706 module_info: &valid::ModuleInfo,
7707 fun_info: &valid::FunctionInfo,
7708 local_invocation_id: Option<&NameKey>,
7709 ) -> BackendResult {
7710 let level = back::Level(1);
7711
7712 writeln!(
7713 self.out,
7714 "{}if ({}::all({} == {}::uint3(0u))) {{",
7715 level,
7716 NAMESPACE,
7717 local_invocation_id
7718 .map(|name_key| self.names[name_key].as_str())
7719 .unwrap_or("__local_invocation_id"),
7720 NAMESPACE,
7721 )?;
7722
7723 let mut access_stack = AccessStack::new();
7724
7725 let vars = module.global_variables.iter().filter(|&(handle, var)| {
7726 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7727 });
7728
7729 for (handle, var) in vars {
7730 access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
7731 self.write_workgroup_variable_initialization(
7732 module,
7733 module_info,
7734 var.ty,
7735 access_stack,
7736 level.next(),
7737 )
7738 })?;
7739 }
7740
7741 writeln!(self.out, "{level}}}")?;
7742 self.write_barrier(crate::Barrier::WORK_GROUP, level)
7743 }
7744
7745 fn write_workgroup_variable_initialization(
7746 &mut self,
7747 module: &crate::Module,
7748 module_info: &valid::ModuleInfo,
7749 ty: Handle<crate::Type>,
7750 access_stack: &mut AccessStack,
7751 level: back::Level,
7752 ) -> BackendResult {
7753 if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
7754 write!(self.out, "{level}")?;
7755 access_stack.write(&mut self.out, &self.names)?;
7756 writeln!(self.out, " = {{}};")?;
7757 } else {
7758 match module.types[ty].inner {
7759 crate::TypeInner::Atomic { .. } => {
7760 write!(
7761 self.out,
7762 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
7763 )?;
7764 access_stack.write(&mut self.out, &self.names)?;
7765 writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
7766 }
7767 crate::TypeInner::Array { base, size, .. } => {
7768 let count = match size.resolve(module.to_ctx())? {
7769 proc::IndexableLength::Known(count) => count,
7770 proc::IndexableLength::Dynamic => unreachable!(),
7771 };
7772
7773 access_stack.enter_array(|access_stack, array_depth| {
7774 writeln!(
7775 self.out,
7776 "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
7777 )?;
7778 self.write_workgroup_variable_initialization(
7779 module,
7780 module_info,
7781 base,
7782 access_stack,
7783 level.next(),
7784 )?;
7785 writeln!(self.out, "{level}}}")?;
7786 BackendResult::Ok(())
7787 })?;
7788 }
7789 crate::TypeInner::Struct { ref members, .. } => {
7790 for (index, member) in members.iter().enumerate() {
7791 access_stack.enter(
7792 Access::StructMember(ty, index as u32),
7793 |access_stack| {
7794 self.write_workgroup_variable_initialization(
7795 module,
7796 module_info,
7797 member.ty,
7798 access_stack,
7799 level,
7800 )
7801 },
7802 )?;
7803 }
7804 }
7805 _ => unreachable!(),
7806 }
7807 }
7808
7809 Ok(())
7810 }
7811 }
7812}
7813
7814impl crate::AtomicFunction {
7815 const fn to_msl(self) -> &'static str {
7816 match self {
7817 Self::Add => "fetch_add",
7818 Self::Subtract => "fetch_sub",
7819 Self::And => "fetch_and",
7820 Self::InclusiveOr => "fetch_or",
7821 Self::ExclusiveOr => "fetch_xor",
7822 Self::Min => "fetch_min",
7823 Self::Max => "fetch_max",
7824 Self::Exchange { compare: None } => "exchange",
7825 Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
7826 }
7827 }
7828
7829 fn to_msl_64_bit(self) -> Result<&'static str, Error> {
7830 Ok(match self {
7831 Self::Min => "min",
7832 Self::Max => "max",
7833 _ => Err(Error::FeatureNotImplemented(
7834 "64-bit atomic operation other than min/max".to_string(),
7835 ))?,
7836 })
7837 }
7838}