1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::{
8 cmp::Ordering,
9 fmt::{Display, Error as FmtError, Formatter, Write},
10 iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
17use crate::{
18 arena::{Handle, HandleSet},
19 back::{self, get_entry_points, Baked},
20 common,
21 proc::{
22 self, concrete_int_scalars,
23 index::{self, BoundsCheck},
24 ExternalTextureNameKey, NameKey, TypeResolution,
25 },
26 valid, FastHashMap, FastHashSet,
27};
28
29#[cfg(test)]
30use core::ptr;
31
32type BackendResult = Result<(), Error>;
34
// ---------------------------------------------------------------------------
// Fixed identifiers used in generated MSL.
// ---------------------------------------------------------------------------

/// Namespace of the Metal standard library.
const NAMESPACE: &str = "metal";
/// Field name used inside generated array-wrapper structs.
const WRAPPED_ARRAY_FIELD: &str = "inner";
/// Reference sigil associated with atomics (judging by the name; confirm at
/// the use sites).
const ATOMIC_REFERENCE: &str = "&";

// ---------------------------------------------------------------------------
// Ray tracing.
// ---------------------------------------------------------------------------

/// Namespace of Metal's ray-tracing library.
const RT_NAMESPACE: &str = "metal::raytracing";
/// MSL type name written for ray-query values.
const RAY_QUERY_TYPE: &str = "_RayQuery";
/// Field names of the generated ray-query type.
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
/// Compile-time switch for ray-query code generation; currently off.
/// NOTE(review): what "modern" selects is decided by its users — confirm.
const RAY_QUERY_MODERN_SUPPORT: bool = false;
/// Name of the generated helper that maps intersection types.
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

// ---------------------------------------------------------------------------
// Names of helper functions and wrapper structs emitted into the output.
// ---------------------------------------------------------------------------

pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
/// Helper used to load a texel from an external texture.
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
/// Helper used to query the dimensions of an external texture.
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";
/// Generic wrapper struct; binding arrays are spelled as a `constant`
/// pointer to this wrapper around the element type.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Type name written for external textures.
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
82
83fn put_numeric_type(
92 out: &mut impl Write,
93 scalar: crate::Scalar,
94 sizes: &[crate::VectorSize],
95) -> Result<(), FmtError> {
96 match (scalar, sizes) {
97 (scalar, &[]) => {
98 write!(out, "{}", scalar.to_msl_name())
99 }
100 (scalar, &[rows]) => {
101 write!(
102 out,
103 "{}::{}{}",
104 NAMESPACE,
105 scalar.to_msl_name(),
106 common::vector_size_str(rows)
107 )
108 }
109 (scalar, &[rows, columns]) => {
110 write!(
111 out,
112 "{}::{}{}x{}",
113 NAMESPACE,
114 scalar.to_msl_name(),
115 common::vector_size_str(columns),
116 common::vector_size_str(rows)
117 )
118 }
119 (_, _) => Ok(()), }
121}
122
123const fn scalar_is_int(scalar: crate::Scalar) -> bool {
124 use crate::ScalarKind::*;
125 match scalar.kind {
126 Sint | Uint | AbstractInt | Bool => true,
127 Float | AbstractFloat => false,
128 }
129}
130
/// Prefix of the variable that holds the clamped level of detail for an
/// image-load expression (see `ClampedLod`).
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Leading part of every name produced by `Reinterpreted`.
const REINTERPRET_PREFIX: &str = "reinterpreted_";
136
137struct ClampedLod(Handle<crate::Expression>);
143
144impl Display for ClampedLod {
145 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
146 self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
147 }
148}
149
150struct ArraySizeMember(Handle<crate::GlobalVariable>);
165
166impl Display for ArraySizeMember {
167 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
168 self.0.write_prefixed(f, "size")
169 }
170}
171
172#[derive(Clone, Copy)]
177struct Reinterpreted<'a> {
178 target_type: &'a str,
179 orig: Handle<crate::Expression>,
180}
181
182impl<'a> Reinterpreted<'a> {
183 const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
184 Self { target_type, orig }
185 }
186}
187
188impl Display for Reinterpreted<'_> {
189 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
190 f.write_str(REINTERPRET_PREFIX)?;
191 f.write_str(self.target_type)?;
192 self.orig.write_prefixed(f, "_e")
193 }
194}
195
196struct TypeContext<'a> {
197 handle: Handle<crate::Type>,
198 gctx: proc::GlobalCtx<'a>,
199 names: &'a FastHashMap<NameKey, String>,
200 access: crate::StorageAccess,
201 first_time: bool,
202}
203
204impl TypeContext<'_> {
205 fn scalar(&self) -> Option<crate::Scalar> {
206 let ty = &self.gctx.types[self.handle];
207 ty.inner.scalar()
208 }
209
210 fn vertex_input_dimension(&self) -> u32 {
211 let ty = &self.gctx.types[self.handle];
212 match ty.inner {
213 crate::TypeInner::Scalar(_) => 1,
214 crate::TypeInner::Vector { size, .. } => size as u32,
215 _ => unreachable!(),
216 }
217 }
218}
219
impl Display for TypeContext<'_> {
    /// Writes the MSL spelling of `self.handle`.
    ///
    /// Types for which `needs_alias()` is true are written by their generated
    /// name unless `first_time` is set. Otherwise the type is written
    /// structurally.
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        // Aliased types are referred to by name, except when `first_time`
        // asks for the underlying spelling.
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            crate::TypeInner::Pointer { base, space } => {
                // The pointee is always printed by name if it is aliased.
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                // Address spaces without an MSL name (`Handle`) produce no
                // output at all.
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                write!(out, "{space_name} {sub}&")
            }
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                // `None` size means the pointee is a plain scalar.
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            crate::TypeInner::Array { base, .. } => {
                // Only the element type is written here; no length appears in
                // this output.
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                write!(out, "{sub}")
            }
            // Structs always need an alias, so they are normally printed by
            // name above. Reaching here means `first_time` was set for a
            // struct, which this formatter does not handle.
            crate::TypeInner::Struct { .. } => unreachable!(),
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                // Pick the texture family, multisample suffix, component
                // scalar, and `access::` qualifier for each image class.
                let (texture_str, msaa_str, scalar, access) = match class {
                    crate::ImageClass::Sampled { kind, multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Storage { format, .. } => {
                        // Storage images take their qualifier from `self.access`;
                        // having neither LOAD nor STORE is treated as an invalid
                        // module.
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    // External textures are represented by a generated wrapper
                    // struct rather than a Metal texture type.
                    crate::ImageClass::External => {
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RAY_QUERY_TYPE}")
            }
            crate::TypeInner::BindingArray { base, .. } => {
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };

                // Binding arrays are written as a `constant` pointer to the
                // wrapper struct instantiated with the element type.
                write!(
                    out,
                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                )
            }
        }
    }
}
373
374struct TypedGlobalVariable<'a> {
375 module: &'a crate::Module,
376 names: &'a FastHashMap<NameKey, String>,
377 handle: Handle<crate::GlobalVariable>,
378 usage: valid::GlobalUse,
379 reference: bool,
380}
381
382impl TypedGlobalVariable<'_> {
383 fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
384 let var = &self.module.global_variables[self.handle];
385 let name = &self.names[&NameKey::GlobalVariable(self.handle)];
386
387 let storage_access = match var.space {
388 crate::AddressSpace::Storage { access } => access,
389 _ => match self.module.types[var.ty].inner {
390 crate::TypeInner::Image {
391 class: crate::ImageClass::Storage { access, .. },
392 ..
393 } => access,
394 crate::TypeInner::BindingArray { base, .. } => {
395 match self.module.types[base].inner {
396 crate::TypeInner::Image {
397 class: crate::ImageClass::Storage { access, .. },
398 ..
399 } => access,
400 _ => crate::StorageAccess::default(),
401 }
402 }
403 _ => crate::StorageAccess::default(),
404 },
405 };
406 let ty_name = TypeContext {
407 handle: var.ty,
408 gctx: self.module.to_ctx(),
409 names: self.names,
410 access: storage_access,
411 first_time: false,
412 };
413
414 let (space, access, reference) = match var.space.to_msl_name() {
415 Some(space) if self.reference => {
416 let access = if var.space.needs_access_qualifier()
417 && !self.usage.intersects(valid::GlobalUse::WRITE)
418 {
419 "const"
420 } else {
421 ""
422 };
423 (space, access, "&")
424 }
425 _ => ("", "", ""),
426 };
427
428 Ok(write!(
429 out,
430 "{}{}{}{}{}{} {}",
431 space,
432 if space.is_empty() { "" } else { " " },
433 ty_name,
434 if access.is_empty() { "" } else { " " },
435 access,
436 reference,
437 name,
438 )?)
439 }
440}
441
/// Identifies a helper function generated by the writer, keyed by the
/// operation and the operand types it was specialised for.
///
/// Values are stored in `Writer::wrapped_functions`; presumably the set is
/// consulted so that each helper is emitted only once (confirm at use sites).
/// In the `(Option<VectorSize>, Scalar)` tuples, `None` presumably denotes a
/// scalar operand rather than a vector.
#[derive(Eq, PartialEq, Hash)]
enum WrappedFunction {
    /// Helper for a unary operator applied to an operand of type `ty`.
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a binary operator with the given operand types.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a math function with argument type `arg_ty`.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper converting `src_scalar` to `dst_scalar`, optionally per vector lane.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Helper for loading from an image of the given class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Helper for sampling an image of the given class.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Helper for querying the size of an image of the given class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
}
473
/// Emits Metal Shading Language source into `out`.
pub struct Writer<W> {
    /// Destination for the generated source; returned by `finish`.
    out: W,
    /// Generated identifiers, keyed by what they name.
    names: FastHashMap<NameKey, String>,
    named_expressions: crate::NamedExpressions,
    /// Expressions the writer must bake into temporaries (judging by the
    /// type name — confirm against `back::NeedBakeExpressions`).
    need_bake_expressions: back::NeedBakeExpressions,
    /// Produces fresh unique identifiers (e.g. `loop_bound`, `oob`).
    namer: proc::Namer,
    /// Helper functions already requested; presumably used to avoid emitting
    /// duplicates.
    wrapped_functions: FastHashSet<WrappedFunction>,
    /// Test-only pointer bookkeeping; presumably used by tests that inspect
    /// stack usage of the recursive emitters — confirm.
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// NOTE(review): presumably (struct type, member index) pairs that need
    /// explicit padding — confirm at use sites.
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
490
491impl crate::Scalar {
492 pub(super) fn to_msl_name(self) -> &'static str {
493 use crate::ScalarKind as Sk;
494 match self {
495 Self {
496 kind: Sk::Float,
497 width: 4,
498 } => "float",
499 Self {
500 kind: Sk::Float,
501 width: 2,
502 } => "half",
503 Self {
504 kind: Sk::Sint,
505 width: 4,
506 } => "int",
507 Self {
508 kind: Sk::Uint,
509 width: 4,
510 } => "uint",
511 Self {
512 kind: Sk::Sint,
513 width: 8,
514 } => "long",
515 Self {
516 kind: Sk::Uint,
517 width: 8,
518 } => "ulong",
519 Self {
520 kind: Sk::Bool,
521 width: _,
522 } => "bool",
523 Self {
524 kind: Sk::AbstractInt | Sk::AbstractFloat,
525 width: _,
526 } => unreachable!("Found Abstract scalar kind"),
527 _ => unreachable!("Unsupported scalar kind: {:?}", self),
528 }
529 }
530}
531
/// Separator to emit between list items: `","` when one is needed,
/// otherwise the empty string.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
539
540fn should_pack_struct_member(
541 members: &[crate::StructMember],
542 span: u32,
543 index: usize,
544 module: &crate::Module,
545) -> Option<crate::Scalar> {
546 let member = &members[index];
547
548 let ty_inner = &module.types[member.ty].inner;
549 let last_offset = member.offset + ty_inner.size(module.to_ctx());
550 let next_offset = match members.get(index + 1) {
551 Some(next) => next.offset,
552 None => span,
553 };
554 let is_tight = next_offset == last_offset;
555
556 match *ty_inner {
557 crate::TypeInner::Vector {
558 size: crate::VectorSize::Tri,
559 scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
560 } if is_tight => Some(scalar),
561 _ => None,
562 }
563}
564
565fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
566 match arena[ty].inner {
567 crate::TypeInner::Struct { ref members, .. } => {
568 if let Some(member) = members.last() {
569 if let crate::TypeInner::Array {
570 size: crate::ArraySize::Dynamic,
571 ..
572 } = arena[member.ty].inner
573 {
574 return true;
575 }
576 }
577 false
578 }
579 crate::TypeInner::Array {
580 size: crate::ArraySize::Dynamic,
581 ..
582 } => true,
583 _ => false,
584 }
585}
586
587impl crate::AddressSpace {
588 const fn needs_pass_through(&self) -> bool {
592 match *self {
593 Self::Uniform
594 | Self::Storage { .. }
595 | Self::Private
596 | Self::WorkGroup
597 | Self::PushConstant
598 | Self::Handle
599 | Self::TaskPayload => true,
600 Self::Function => false,
601 }
602 }
603
604 const fn needs_access_qualifier(&self) -> bool {
606 match *self {
607 Self::Storage { .. } => true,
612 Self::TaskPayload => unimplemented!(),
613 Self::Private | Self::WorkGroup => false,
615 Self::Uniform | Self::PushConstant => false,
617 Self::Handle | Self::Function => false,
619 }
620 }
621
622 const fn to_msl_name(self) -> Option<&'static str> {
623 match self {
624 Self::Handle => None,
625 Self::Uniform | Self::PushConstant => Some("constant"),
626 Self::Storage { .. } => Some("device"),
627 Self::Private | Self::Function => Some("thread"),
628 Self::WorkGroup => Some("threadgroup"),
629 Self::TaskPayload => Some("object_data"),
630 }
631 }
632}
633
634impl crate::Type {
635 const fn needs_alias(&self) -> bool {
637 use crate::TypeInner as Ti;
638
639 match self.inner {
640 Ti::Scalar(_)
642 | Ti::Vector { .. }
643 | Ti::Matrix { .. }
644 | Ti::Atomic(_)
645 | Ti::Pointer { .. }
646 | Ti::ValuePointer { .. } => self.name.is_some(),
647 Ti::Struct { .. } | Ti::Array { .. } => true,
649 Ti::Image { .. }
651 | Ti::Sampler { .. }
652 | Ti::AccelerationStructure { .. }
653 | Ti::RayQuery { .. }
654 | Ti::BindingArray { .. } => false,
655 }
656 }
657}
658
/// Identifies the function whose body is being written: a regular function
/// by handle, or an entry point by index. Used to build per-function
/// `NameKey`s for locals (see `NameKeyExt`).
#[derive(Clone, Copy)]
enum FunctionOrigin {
    Handle(Handle<crate::Function>),
    EntryPoint(proc::EntryPointIndex),
}
664
665trait NameKeyExt {
666 fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
667 match origin {
668 FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
669 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
670 }
671 }
672
673 fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
678 match origin {
679 FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
680 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
681 }
682 }
683}
684
685impl NameKeyExt for NameKey {}
686
/// The level-of-detail operand of an image access.
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Write the given expression as-is.
    Direct(Handle<crate::Expression>),
    /// Refer to the clamped-level variable associated with this image-load
    /// expression (written through `ClampedLod`).
    Restricted(Handle<crate::Expression>),
}
701
/// Operands that locate a texel for an image load, store, or atomic.
struct TexelAddress {
    /// Texel coordinate; written cast to an unsigned scalar or vector.
    coordinate: Handle<crate::Expression>,
    /// Array layer, checked against `get_array_size()` when bounds-checked.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index, checked against `get_num_samples()` when bounds-checked.
    sample: Option<Handle<crate::Expression>>,
    /// Mip level, if the access specifies one.
    level: Option<LevelOfDetail>,
}
717
/// Everything needed to write the expressions of one function or entry point.
struct ExpressionContext<'a> {
    /// The function whose expressions are being written.
    function: &'a crate::Function,
    /// Regular function or entry point; used to build local-variable name keys.
    origin: FunctionOrigin,
    /// Validation info for `function`; supplies expression type resolutions.
    info: &'a valid::FunctionInfo,
    module: &'a crate::Module,
    mod_info: &'a valid::ModuleInfo,
    pipeline_options: &'a PipelineOptions,
    /// Target MSL language version, presumably `(major, minor)` — confirm.
    lang_version: (u8, u8),
    /// Bounds-check policies for index and image accesses.
    policies: index::BoundsCheckPolicies,

    /// NOTE(review): judging by the name, index expressions that are already
    /// guarded by a bounds check — confirm at use sites.
    guarded_indices: HandleSet<crate::Expression>,
    /// When set, loops get an injected iteration counter
    /// (see `gen_force_bounded_loop_statements`).
    force_loop_bounding: bool,
}
735
736impl<'a> ExpressionContext<'a> {
737 fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
738 self.info[handle].ty.inner_with(&self.module.types)
739 }
740
741 fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
748 let image_ty = self.resolve_type(image);
749 if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
750 class.is_mipmapped() && dim != crate::ImageDimension::D1
751 } else {
752 false
753 }
754 }
755
756 fn choose_bounds_check_policy(
757 &self,
758 pointer: Handle<crate::Expression>,
759 ) -> index::BoundsCheckPolicy {
760 self.policies
761 .choose_policy(pointer, &self.module.types, self.info)
762 }
763
764 fn access_needs_check(
766 &self,
767 base: Handle<crate::Expression>,
768 index: index::GuardedIndex,
769 ) -> Option<index::IndexableLength> {
770 index::access_needs_check(
771 base,
772 index,
773 self.module,
774 &self.function.expressions,
775 self.info,
776 )
777 }
778
779 fn bounds_check_iter(
781 &self,
782 chain: Handle<crate::Expression>,
783 ) -> impl Iterator<Item = BoundsCheck> + '_ {
784 index::bounds_check_iter(chain, self.module, self.function, self.info)
785 }
786
787 fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
789 index::oob_local_types(self.module, self.function, self.info, self.policies)
790 }
791
792 fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
793 match self.function.expressions[expr_handle] {
794 crate::Expression::AccessIndex { base, index } => {
795 let ty = match *self.resolve_type(base) {
796 crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
797 ref ty => ty,
798 };
799 match *ty {
800 crate::TypeInner::Struct {
801 ref members, span, ..
802 } => should_pack_struct_member(members, span, index as usize, self.module),
803 _ => None,
804 }
805 }
806 _ => None,
807 }
808 }
809}
810
/// Context for writing statements: the per-function expression context plus
/// an optional struct name.
struct StatementContext<'a> {
    expression: ExpressionContext<'a>,
    /// NOTE(review): presumably the name of the entry point's result struct,
    /// when there is one — confirm at use sites.
    result_struct: Option<&'a str>,
}
815
816impl<W: Write> Writer<W> {
817 pub fn new(out: W) -> Self {
819 Writer {
820 out,
821 names: FastHashMap::default(),
822 named_expressions: Default::default(),
823 need_bake_expressions: Default::default(),
824 namer: proc::Namer::default(),
825 wrapped_functions: FastHashSet::default(),
826 #[cfg(test)]
827 put_expression_stack_pointers: Default::default(),
828 #[cfg(test)]
829 put_block_stack_pointers: Default::default(),
830 struct_member_pads: FastHashSet::default(),
831 }
832 }
833
834 #[allow(clippy::missing_const_for_fn)]
837 pub fn finish(self) -> W {
838 self.out
839 }
840
841 fn gen_force_bounded_loop_statements(
948 &mut self,
949 level: back::Level,
950 context: &StatementContext,
951 ) -> Option<(String, String)> {
952 if !context.expression.force_loop_bounding {
953 return None;
954 }
955
956 let loop_bound_name = self.namer.call("loop_bound");
957 let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
960 let level = level.next();
961 let break_and_inc = format!(
962 "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
963{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
964 );
965
966 Some((decl, break_and_inc))
967 }
968
969 fn put_call_parameters(
970 &mut self,
971 parameters: impl Iterator<Item = Handle<crate::Expression>>,
972 context: &ExpressionContext,
973 ) -> BackendResult {
974 self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
975 writer.put_expression(expr, context, true)
976 })
977 }
978
979 fn put_call_parameters_impl<C, E>(
980 &mut self,
981 parameters: impl Iterator<Item = Handle<crate::Expression>>,
982 ctx: &C,
983 put_expression: E,
984 ) -> BackendResult
985 where
986 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
987 {
988 write!(self.out, "(")?;
989 for (i, handle) in parameters.enumerate() {
990 if i != 0 {
991 write!(self.out, ", ")?;
992 }
993 put_expression(self, ctx, handle)?;
994 }
995 write!(self.out, ")")?;
996 Ok(())
997 }
998
999 fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1005 let oob_local_types = context.oob_local_types();
1006 for &ty in oob_local_types.iter() {
1007 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1008 self.names.insert(name_key, self.namer.call("oob"));
1009 }
1010
1011 for (name_key, ty, init) in context
1012 .function
1013 .local_variables
1014 .iter()
1015 .map(|(local_handle, local)| {
1016 let name_key = NameKey::local(context.origin, local_handle);
1017 (name_key, local.ty, local.init)
1018 })
1019 .chain(oob_local_types.iter().map(|&ty| {
1020 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1021 (name_key, ty, None)
1022 }))
1023 {
1024 let ty_name = TypeContext {
1025 handle: ty,
1026 gctx: context.module.to_ctx(),
1027 names: &self.names,
1028 access: crate::StorageAccess::empty(),
1029 first_time: false,
1030 };
1031 write!(
1032 self.out,
1033 "{}{} {}",
1034 back::INDENT,
1035 ty_name,
1036 self.names[&name_key]
1037 )?;
1038 match init {
1039 Some(value) => {
1040 write!(self.out, " = ")?;
1041 self.put_expression(value, context, true)?;
1042 }
1043 None => {
1044 write!(self.out, " = {{}}")?;
1045 }
1046 };
1047 writeln!(self.out, ";")?;
1048 }
1049 Ok(())
1050 }
1051
1052 fn put_level_of_detail(
1053 &mut self,
1054 level: LevelOfDetail,
1055 context: &ExpressionContext,
1056 ) -> BackendResult {
1057 match level {
1058 LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1059 LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1060 }
1061 Ok(())
1062 }
1063
1064 fn put_image_query(
1065 &mut self,
1066 image: Handle<crate::Expression>,
1067 query: &str,
1068 level: Option<LevelOfDetail>,
1069 context: &ExpressionContext,
1070 ) -> BackendResult {
1071 self.put_expression(image, context, false)?;
1072 write!(self.out, ".get_{query}(")?;
1073 if let Some(level) = level {
1074 self.put_level_of_detail(level, context)?;
1075 }
1076 write!(self.out, ")")?;
1077 Ok(())
1078 }
1079
1080 fn put_image_size_query(
1081 &mut self,
1082 image: Handle<crate::Expression>,
1083 level: Option<LevelOfDetail>,
1084 kind: crate::ScalarKind,
1085 context: &ExpressionContext,
1086 ) -> BackendResult {
1087 if let crate::TypeInner::Image {
1088 class: crate::ImageClass::External,
1089 ..
1090 } = *context.resolve_type(image)
1091 {
1092 write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1093 self.put_expression(image, context, true)?;
1094 write!(self.out, ")")?;
1095 return Ok(());
1096 }
1097
1098 let dim = match *context.resolve_type(image) {
1101 crate::TypeInner::Image { dim, .. } => dim,
1102 ref other => unreachable!("Unexpected type {:?}", other),
1103 };
1104 let scalar = crate::Scalar { kind, width: 4 };
1105 let coordinate_type = scalar.to_msl_name();
1106 match dim {
1107 crate::ImageDimension::D1 => {
1108 if kind == crate::ScalarKind::Uint {
1112 self.put_image_query(image, "width", None, context)?;
1114 } else {
1115 write!(self.out, "int(")?;
1117 self.put_image_query(image, "width", None, context)?;
1118 write!(self.out, ")")?;
1119 }
1120 }
1121 crate::ImageDimension::D2 => {
1122 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1123 self.put_image_query(image, "width", level, context)?;
1124 write!(self.out, ", ")?;
1125 self.put_image_query(image, "height", level, context)?;
1126 write!(self.out, ")")?;
1127 }
1128 crate::ImageDimension::D3 => {
1129 write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1130 self.put_image_query(image, "width", level, context)?;
1131 write!(self.out, ", ")?;
1132 self.put_image_query(image, "height", level, context)?;
1133 write!(self.out, ", ")?;
1134 self.put_image_query(image, "depth", level, context)?;
1135 write!(self.out, ")")?;
1136 }
1137 crate::ImageDimension::Cube => {
1138 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1139 self.put_image_query(image, "width", level, context)?;
1140 write!(self.out, ")")?;
1141 }
1142 }
1143 Ok(())
1144 }
1145
1146 fn put_cast_to_uint_scalar_or_vector(
1147 &mut self,
1148 expr: Handle<crate::Expression>,
1149 context: &ExpressionContext,
1150 ) -> BackendResult {
1151 match *context.resolve_type(expr) {
1153 crate::TypeInner::Scalar(_) => {
1154 put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1155 }
1156 crate::TypeInner::Vector { size, .. } => {
1157 put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1158 }
1159 _ => {
1160 return Err(Error::GenericValidation(
1161 "Invalid type for image coordinate".into(),
1162 ))
1163 }
1164 };
1165
1166 write!(self.out, "(")?;
1167 self.put_expression(expr, context, true)?;
1168 write!(self.out, ")")?;
1169 Ok(())
1170 }
1171
1172 fn put_image_sample_level(
1173 &mut self,
1174 image: Handle<crate::Expression>,
1175 level: crate::SampleLevel,
1176 context: &ExpressionContext,
1177 ) -> BackendResult {
1178 let has_levels = context.image_needs_lod(image);
1179 match level {
1180 crate::SampleLevel::Auto => {}
1181 crate::SampleLevel::Zero => {
1182 }
1184 _ if !has_levels => {
1185 log::warn!("1D image can't be sampled with level {level:?}");
1186 }
1187 crate::SampleLevel::Exact(h) => {
1188 write!(self.out, ", {NAMESPACE}::level(")?;
1189 self.put_expression(h, context, true)?;
1190 write!(self.out, ")")?;
1191 }
1192 crate::SampleLevel::Bias(h) => {
1193 write!(self.out, ", {NAMESPACE}::bias(")?;
1194 self.put_expression(h, context, true)?;
1195 write!(self.out, ")")?;
1196 }
1197 crate::SampleLevel::Gradient { x, y } => {
1198 write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1199 self.put_expression(x, context, true)?;
1200 write!(self.out, ", ")?;
1201 self.put_expression(y, context, true)?;
1202 write!(self.out, ")")?;
1203 }
1204 }
1205 Ok(())
1206 }
1207
1208 fn put_image_coordinate_limits(
1209 &mut self,
1210 image: Handle<crate::Expression>,
1211 level: Option<LevelOfDetail>,
1212 context: &ExpressionContext,
1213 ) -> BackendResult {
1214 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1215 write!(self.out, " - 1")?;
1216 Ok(())
1217 }
1218
1219 fn put_restricted_scalar_image_index(
1237 &mut self,
1238 image: Handle<crate::Expression>,
1239 index: Handle<crate::Expression>,
1240 limit_method: &str,
1241 context: &ExpressionContext,
1242 ) -> BackendResult {
1243 write!(self.out, "{NAMESPACE}::min(uint(")?;
1244 self.put_expression(index, context, true)?;
1245 write!(self.out, "), ")?;
1246 self.put_expression(image, context, false)?;
1247 write!(self.out, ".{limit_method}() - 1)")?;
1248 Ok(())
1249 }
1250
1251 fn put_restricted_texel_address(
1252 &mut self,
1253 image: Handle<crate::Expression>,
1254 address: &TexelAddress,
1255 context: &ExpressionContext,
1256 ) -> BackendResult {
1257 write!(self.out, "{NAMESPACE}::min(")?;
1259 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1260 write!(self.out, ", ")?;
1261 self.put_image_coordinate_limits(image, address.level, context)?;
1262 write!(self.out, ")")?;
1263
1264 if let Some(array_index) = address.array_index {
1266 write!(self.out, ", ")?;
1267 self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1268 }
1269
1270 if let Some(sample) = address.sample {
1272 write!(self.out, ", ")?;
1273 self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1274 }
1275
1276 if let Some(level) = address.level {
1279 write!(self.out, ", ")?;
1280 self.put_level_of_detail(level, context)?;
1281 }
1282
1283 Ok(())
1284 }
1285
1286 fn put_image_access_bounds_check(
1288 &mut self,
1289 image: Handle<crate::Expression>,
1290 address: &TexelAddress,
1291 context: &ExpressionContext,
1292 ) -> BackendResult {
1293 let mut conjunction = "";
1294
1295 let level = if let Some(level) = address.level {
1298 write!(self.out, "uint(")?;
1299 self.put_level_of_detail(level, context)?;
1300 write!(self.out, ") < ")?;
1301 self.put_expression(image, context, true)?;
1302 write!(self.out, ".get_num_mip_levels()")?;
1303 conjunction = " && ";
1304 Some(level)
1305 } else {
1306 None
1307 };
1308
1309 if let Some(sample) = address.sample {
1311 write!(self.out, "uint(")?;
1312 self.put_expression(sample, context, true)?;
1313 write!(self.out, ") < ")?;
1314 self.put_expression(image, context, true)?;
1315 write!(self.out, ".get_num_samples()")?;
1316 conjunction = " && ";
1317 }
1318
1319 if let Some(array_index) = address.array_index {
1321 write!(self.out, "{conjunction}uint(")?;
1322 self.put_expression(array_index, context, true)?;
1323 write!(self.out, ") < ")?;
1324 self.put_expression(image, context, true)?;
1325 write!(self.out, ".get_array_size()")?;
1326 conjunction = " && ";
1327 }
1328
1329 let coord_is_vector = match *context.resolve_type(address.coordinate) {
1331 crate::TypeInner::Vector { .. } => true,
1332 _ => false,
1333 };
1334 write!(self.out, "{conjunction}")?;
1335 if coord_is_vector {
1336 write!(self.out, "{NAMESPACE}::all(")?;
1337 }
1338 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1339 write!(self.out, " < ")?;
1340 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1341 if coord_is_vector {
1342 write!(self.out, ")")?;
1343 }
1344
1345 Ok(())
1346 }
1347
1348 fn put_image_load(
1349 &mut self,
1350 load: Handle<crate::Expression>,
1351 image: Handle<crate::Expression>,
1352 mut address: TexelAddress,
1353 context: &ExpressionContext,
1354 ) -> BackendResult {
1355 if let crate::TypeInner::Image {
1356 class: crate::ImageClass::External,
1357 ..
1358 } = *context.resolve_type(image)
1359 {
1360 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1361 self.put_expression(image, context, true)?;
1362 write!(self.out, ", ")?;
1363 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1364 write!(self.out, ")")?;
1365 return Ok(());
1366 }
1367
1368 match context.policies.image_load {
1369 proc::BoundsCheckPolicy::Restrict => {
1370 if address.level.is_some() {
1373 address.level = if context.image_needs_lod(image) {
1374 Some(LevelOfDetail::Restricted(load))
1375 } else {
1376 None
1377 }
1378 }
1379
1380 self.put_expression(image, context, false)?;
1381 write!(self.out, ".read(")?;
1382 self.put_restricted_texel_address(image, &address, context)?;
1383 write!(self.out, ")")?;
1384 }
1385 proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1386 write!(self.out, "(")?;
1387 self.put_image_access_bounds_check(image, &address, context)?;
1388 write!(self.out, " ? ")?;
1389 self.put_unchecked_image_load(image, &address, context)?;
1390 write!(self.out, ": DefaultConstructible())")?;
1391 }
1392 proc::BoundsCheckPolicy::Unchecked => {
1393 self.put_unchecked_image_load(image, &address, context)?;
1394 }
1395 }
1396
1397 Ok(())
1398 }
1399
1400 fn put_unchecked_image_load(
1401 &mut self,
1402 image: Handle<crate::Expression>,
1403 address: &TexelAddress,
1404 context: &ExpressionContext,
1405 ) -> BackendResult {
1406 self.put_expression(image, context, false)?;
1407 write!(self.out, ".read(")?;
1408 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1410 if let Some(expr) = address.array_index {
1411 write!(self.out, ", ")?;
1412 self.put_expression(expr, context, true)?;
1413 }
1414 if let Some(sample) = address.sample {
1415 write!(self.out, ", ")?;
1416 self.put_expression(sample, context, true)?;
1417 }
1418 if let Some(level) = address.level {
1419 if context.image_needs_lod(image) {
1420 write!(self.out, ", ")?;
1421 self.put_level_of_detail(level, context)?;
1422 }
1423 }
1424 write!(self.out, ")")?;
1425
1426 Ok(())
1427 }
1428
1429 fn put_image_atomic(
1430 &mut self,
1431 level: back::Level,
1432 image: Handle<crate::Expression>,
1433 address: &TexelAddress,
1434 fun: crate::AtomicFunction,
1435 value: Handle<crate::Expression>,
1436 context: &StatementContext,
1437 ) -> BackendResult {
1438 write!(self.out, "{level}")?;
1439 self.put_expression(image, &context.expression, false)?;
1440 let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1441 fun.to_msl_64_bit()?
1442 } else {
1443 fun.to_msl()
1444 };
1445 write!(self.out, ".atomic_{op}(")?;
1446 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1448 write!(self.out, ", ")?;
1449 self.put_expression(value, &context.expression, true)?;
1450 writeln!(self.out, ");")?;
1451
1452 Ok(())
1453 }
1454
1455 fn put_image_store(
1456 &mut self,
1457 level: back::Level,
1458 image: Handle<crate::Expression>,
1459 address: &TexelAddress,
1460 value: Handle<crate::Expression>,
1461 context: &StatementContext,
1462 ) -> BackendResult {
1463 write!(self.out, "{level}")?;
1464 self.put_expression(image, &context.expression, false)?;
1465 write!(self.out, ".write(")?;
1466 self.put_expression(value, &context.expression, true)?;
1467 write!(self.out, ", ")?;
1468 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1470 if let Some(expr) = address.array_index {
1471 write!(self.out, ", ")?;
1472 self.put_expression(expr, &context.expression, true)?;
1473 }
1474 writeln!(self.out, ");")?;
1475
1476 Ok(())
1477 }
1478
1479 fn put_dynamic_array_max_index(
1490 &mut self,
1491 handle: Handle<crate::GlobalVariable>,
1492 context: &ExpressionContext,
1493 ) -> BackendResult {
1494 let global = &context.module.global_variables[handle];
1495 let (offset, array_ty) = match context.module.types[global.ty].inner {
1496 crate::TypeInner::Struct { ref members, .. } => match members.last() {
1497 Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1498 None => return Err(Error::GenericValidation("Struct has no members".into())),
1499 },
1500 crate::TypeInner::Array {
1501 size: crate::ArraySize::Dynamic,
1502 ..
1503 } => (0, global.ty),
1504 ref ty => {
1505 return Err(Error::GenericValidation(format!(
1506 "Expected type with dynamic array, got {ty:?}"
1507 )))
1508 }
1509 };
1510
1511 let (size, stride) = match context.module.types[array_ty].inner {
1512 crate::TypeInner::Array { base, stride, .. } => (
1513 context.module.types[base]
1514 .inner
1515 .size(context.module.to_ctx()),
1516 stride,
1517 ),
1518 ref ty => {
1519 return Err(Error::GenericValidation(format!(
1520 "Expected array type, got {ty:?}"
1521 )))
1522 }
1523 };
1524
1525 write!(
1538 self.out,
1539 "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1540 member = ArraySizeMember(handle),
1541 offset = offset,
1542 size = size,
1543 stride = stride,
1544 )?;
1545 Ok(())
1546 }
1547
1548 fn put_dot_product<T: Copy>(
1553 &mut self,
1554 arg: T,
1555 arg1: T,
1556 size: usize,
1557 extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1558 ) -> BackendResult {
1559 write!(self.out, "(")?;
1562
1563 for index in 0..size {
1565 write!(self.out, " + ")?;
1568 extractor(self, arg, index)?;
1569 write!(self.out, " * ")?;
1570 extractor(self, arg1, index)?;
1571 }
1572
1573 write!(self.out, ")")?;
1574 Ok(())
1575 }
1576
1577 fn put_pack4x8(
1579 &mut self,
1580 arg: Handle<crate::Expression>,
1581 context: &ExpressionContext<'_>,
1582 was_signed: bool,
1583 clamp_bounds: Option<(&str, &str)>,
1584 ) -> Result<(), Error> {
1585 let write_arg = |this: &mut Self| -> BackendResult {
1586 if let Some((min, max)) = clamp_bounds {
1587 write!(this.out, "{NAMESPACE}::clamp(")?;
1589 this.put_expression(arg, context, true)?;
1590 write!(this.out, ", {min}, {max})")?;
1591 } else {
1592 this.put_expression(arg, context, true)?;
1593 }
1594 Ok(())
1595 };
1596
1597 if context.lang_version >= (2, 1) {
1598 let packed_type = if was_signed {
1599 "packed_char4"
1600 } else {
1601 "packed_uchar4"
1602 };
1603 write!(self.out, "as_type<uint>({packed_type}(")?;
1605 write_arg(self)?;
1606 write!(self.out, "))")?;
1607 } else {
1608 if was_signed {
1610 write!(self.out, "uint(")?;
1611 }
1612 write!(self.out, "(")?;
1613 write_arg(self)?;
1614 write!(self.out, "[0] & 0xFF) | ((")?;
1615 write_arg(self)?;
1616 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1617 write_arg(self)?;
1618 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1619 write_arg(self)?;
1620 write!(self.out, "[3] & 0xFF) << 24)")?;
1621 if was_signed {
1622 write!(self.out, ")")?;
1623 }
1624 }
1625
1626 Ok(())
1627 }
1628
1629 fn put_isign(
1632 &mut self,
1633 arg: Handle<crate::Expression>,
1634 context: &ExpressionContext,
1635 ) -> BackendResult {
1636 write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1637 let scalar = context
1638 .resolve_type(arg)
1639 .scalar()
1640 .expect("put_isign should only be called for args which have an integer scalar type")
1641 .to_msl_name();
1642 match context.resolve_type(arg) {
1643 &crate::TypeInner::Vector { size, .. } => {
1644 let size = common::vector_size_str(size);
1645 write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1646 }
1647 _ => {
1648 write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1649 }
1650 }
1651 write!(self.out, ", (")?;
1652 self.put_expression(arg, context, true)?;
1653 write!(self.out, " > 0)), {scalar}(0), (")?;
1654 self.put_expression(arg, context, true)?;
1655 write!(self.out, " == 0))")?;
1656 Ok(())
1657 }
1658
1659 fn put_const_expression(
1660 &mut self,
1661 expr_handle: Handle<crate::Expression>,
1662 module: &crate::Module,
1663 mod_info: &valid::ModuleInfo,
1664 arena: &crate::Arena<crate::Expression>,
1665 ) -> BackendResult {
1666 self.put_possibly_const_expression(
1667 expr_handle,
1668 arena,
1669 module,
1670 mod_info,
1671 &(module, mod_info),
1672 |&(_, mod_info), expr| &mod_info[expr],
1673 |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1674 )
1675 }
1676
1677 fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1678 match literal {
1679 crate::Literal::F64(_) => {
1680 return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1681 }
1682 crate::Literal::F16(value) => {
1683 if value.is_infinite() {
1684 let sign = if value.is_sign_negative() { "-" } else { "" };
1685 write!(self.out, "{sign}INFINITY")?;
1686 } else if value.is_nan() {
1687 write!(self.out, "NAN")?;
1688 } else {
1689 let suffix = if value.fract() == f16::from_f32(0.0) {
1690 ".0h"
1691 } else {
1692 "h"
1693 };
1694 write!(self.out, "{value}{suffix}")?;
1695 }
1696 }
1697 crate::Literal::F32(value) => {
1698 if value.is_infinite() {
1699 let sign = if value.is_sign_negative() { "-" } else { "" };
1700 write!(self.out, "{sign}INFINITY")?;
1701 } else if value.is_nan() {
1702 write!(self.out, "NAN")?;
1703 } else {
1704 let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1705 write!(self.out, "{value}{suffix}")?;
1706 }
1707 }
1708 crate::Literal::U32(value) => {
1709 write!(self.out, "{value}u")?;
1710 }
1711 crate::Literal::I32(value) => {
1712 if value == i32::MIN {
1717 write!(self.out, "({} - 1)", value + 1)?;
1718 } else {
1719 write!(self.out, "{value}")?;
1720 }
1721 }
1722 crate::Literal::U64(value) => {
1723 write!(self.out, "{value}uL")?;
1724 }
1725 crate::Literal::I64(value) => {
1726 if value == i64::MIN {
1733 write!(self.out, "({}L - 1L)", value + 1)?;
1734 } else {
1735 write!(self.out, "{value}L")?;
1736 }
1737 }
1738 crate::Literal::Bool(value) => {
1739 write!(self.out, "{value}")?;
1740 }
1741 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1742 return Err(Error::GenericValidation(
1743 "Unsupported abstract literal".into(),
1744 ));
1745 }
1746 }
1747 Ok(())
1748 }
1749
    /// Write an expression kind that can appear both in module-scope (constant)
    /// expressions and in function bodies.
    ///
    /// Handles `Literal`, `Constant`, `ZeroValue`, `Compose` and `Splat`; any other
    /// expression kind is reported as `Error::Override`. The caller abstracts over
    /// the two contexts through `ctx` and two callbacks:
    ///
    /// - `get_expr_ty(ctx, h)` resolves the type of operand `h` (used for `Splat`);
    /// - `put_expression(self, ctx, h)` writes operand `h` (used for `Compose`
    ///   components and the `Splat` value).
    ///
    /// Named constants are written by name. Unnamed constants are inlined by
    /// writing their initializer from `module.global_expressions`.
    #[allow(clippy::too_many_arguments)]
    fn put_possibly_const_expression<C, I, E>(
        &mut self,
        expr_handle: Handle<crate::Expression>,
        expressions: &crate::Arena<crate::Expression>,
        module: &crate::Module,
        mod_info: &valid::ModuleInfo,
        ctx: &C,
        get_expr_ty: I,
        put_expression: E,
    ) -> BackendResult
    where
        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
    {
        match expressions[expr_handle] {
            crate::Expression::Literal(literal) => {
                self.put_literal(literal)?;
            }
            crate::Expression::Constant(handle) => {
                let constant = &module.constants[handle];
                if constant.name.is_some() {
                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
                } else {
                    // Anonymous constants have no emitted name; inline the initializer.
                    self.put_const_expression(
                        constant.init,
                        module,
                        mod_info,
                        &module.global_expressions,
                    )?;
                }
            }
            crate::Expression::ZeroValue(ty) => {
                let ty_name = TypeContext {
                    handle: ty,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                // Zero values are written as empty brace initialization: `Type {}`.
                write!(self.out, "{ty_name} {{}}")?;
            }
            crate::Expression::Compose { ty, ref components } => {
                let ty_name = TypeContext {
                    handle: ty,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                write!(self.out, "{ty_name}")?;
                match module.types[ty].inner {
                    // Numeric types use constructor-call syntax: `Type(a, b, ...)`.
                    crate::TypeInner::Scalar(_)
                    | crate::TypeInner::Vector { .. }
                    | crate::TypeInner::Matrix { .. } => {
                        self.put_call_parameters_impl(
                            components.iter().copied(),
                            ctx,
                            put_expression,
                        )?;
                    }
                    // Aggregates use brace initialization: `Type {a, b, ...}`.
                    crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
                        write!(self.out, " {{")?;
                        for (index, &component) in components.iter().enumerate() {
                            if index != 0 {
                                write!(self.out, ", ")?;
                            }
                            // When `struct_member_pads` records a padding member at this
                            // position, give it an empty `{}` initializer first.
                            if self.struct_member_pads.contains(&(ty, index as u32)) {
                                write!(self.out, "{{}}, ")?;
                            }
                            put_expression(self, ctx, component)?;
                        }
                        write!(self.out, "}}")?;
                    }
                    _ => return Err(Error::UnsupportedCompose(ty)),
                }
            }
            crate::Expression::Splat { size, value } => {
                // The splatted value must be a scalar; its type picks the vector type.
                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
                    crate::TypeInner::Scalar(scalar) => scalar,
                    ref ty => {
                        return Err(Error::GenericValidation(format!(
                            "Expected splat value type must be a scalar, got {ty:?}",
                        )))
                    }
                };
                // Write the `size`-component numeric type for `scalar`, then call it
                // with the single scalar argument.
                put_numeric_type(&mut self.out, scalar, &[size])?;
                write!(self.out, "(")?;
                put_expression(self, ctx, value)?;
                write!(self.out, ")")?;
            }
            _ => {
                return Err(Error::Override);
            }
        }

        Ok(())
    }
1849
1850 fn put_expression(
1862 &mut self,
1863 expr_handle: Handle<crate::Expression>,
1864 context: &ExpressionContext,
1865 is_scoped: bool,
1866 ) -> BackendResult {
1867 #[cfg(test)]
1869 self.put_expression_stack_pointers
1870 .insert(ptr::from_ref(&expr_handle).cast());
1871
1872 if let Some(name) = self.named_expressions.get(&expr_handle) {
1873 write!(self.out, "{name}")?;
1874 return Ok(());
1875 }
1876
1877 let expression = &context.function.expressions[expr_handle];
1878 match *expression {
1879 crate::Expression::Literal(_)
1880 | crate::Expression::Constant(_)
1881 | crate::Expression::ZeroValue(_)
1882 | crate::Expression::Compose { .. }
1883 | crate::Expression::Splat { .. } => {
1884 self.put_possibly_const_expression(
1885 expr_handle,
1886 &context.function.expressions,
1887 context.module,
1888 context.mod_info,
1889 context,
1890 |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1891 |writer, context, expr| writer.put_expression(expr, context, true),
1892 )?;
1893 }
1894 crate::Expression::Override(_) => return Err(Error::Override),
1895 crate::Expression::Access { base, .. }
1896 | crate::Expression::AccessIndex { base, .. } => {
1897 let policy = context.choose_bounds_check_policy(base);
1903 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1904 && self.put_bounds_checks(
1905 expr_handle,
1906 context,
1907 back::Level(0),
1908 if is_scoped { "" } else { "(" },
1909 )?
1910 {
1911 write!(self.out, " ? ")?;
1912 self.put_access_chain(expr_handle, policy, context)?;
1913 write!(self.out, " : ")?;
1914
1915 if context.resolve_type(base).pointer_space().is_some() {
1916 let result_ty = context.info[expr_handle]
1920 .ty
1921 .inner_with(&context.module.types)
1922 .pointer_base_type();
1923 let result_ty_handle = match result_ty {
1924 Some(TypeResolution::Handle(handle)) => handle,
1925 Some(TypeResolution::Value(_)) => {
1926 unreachable!(
1934 "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
1935 );
1936 }
1937 None => {
1938 unreachable!(
1939 "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
1940 )
1941 }
1942 };
1943 let name_key =
1944 NameKey::oob_local_for_type(context.origin, result_ty_handle);
1945 self.out.write_str(&self.names[&name_key])?;
1946 } else {
1947 write!(self.out, "DefaultConstructible()")?;
1948 }
1949
1950 if !is_scoped {
1951 write!(self.out, ")")?;
1952 }
1953 } else {
1954 self.put_access_chain(expr_handle, policy, context)?;
1955 }
1956 }
1957 crate::Expression::Swizzle {
1958 size,
1959 vector,
1960 pattern,
1961 } => {
1962 self.put_wrapped_expression_for_packed_vec3_access(
1963 vector,
1964 context,
1965 false,
1966 &Self::put_expression,
1967 )?;
1968 write!(self.out, ".")?;
1969 for &sc in pattern[..size as usize].iter() {
1970 write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
1971 }
1972 }
1973 crate::Expression::FunctionArgument(index) => {
1974 let name_key = match context.origin {
1975 FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
1976 FunctionOrigin::EntryPoint(ep_index) => {
1977 NameKey::EntryPointArgument(ep_index, index)
1978 }
1979 };
1980 let name = &self.names[&name_key];
1981 write!(self.out, "{name}")?;
1982 }
1983 crate::Expression::GlobalVariable(handle) => {
1984 let name = &self.names[&NameKey::GlobalVariable(handle)];
1985 write!(self.out, "{name}")?;
1986 }
1987 crate::Expression::LocalVariable(handle) => {
1988 let name_key = NameKey::local(context.origin, handle);
1989 let name = &self.names[&name_key];
1990 write!(self.out, "{name}")?;
1991 }
1992 crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
1993 crate::Expression::ImageSample {
1994 coordinate,
1995 image,
1996 sampler,
1997 clamp_to_edge: true,
1998 gather: None,
1999 array_index: None,
2000 offset: None,
2001 level: crate::SampleLevel::Zero,
2002 depth_ref: None,
2003 } => {
2004 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2005 self.put_expression(image, context, true)?;
2006 write!(self.out, ", ")?;
2007 self.put_expression(sampler, context, true)?;
2008 write!(self.out, ", ")?;
2009 self.put_expression(coordinate, context, true)?;
2010 write!(self.out, ")")?;
2011 }
2012 crate::Expression::ImageSample {
2013 image,
2014 sampler,
2015 gather,
2016 coordinate,
2017 array_index,
2018 offset,
2019 level,
2020 depth_ref,
2021 clamp_to_edge,
2022 } => {
2023 if clamp_to_edge {
2024 return Err(Error::GenericValidation(
2025 "ImageSample::clamp_to_edge should have been validated out".to_string(),
2026 ));
2027 }
2028
2029 let main_op = match gather {
2030 Some(_) => "gather",
2031 None => "sample",
2032 };
2033 let comparison_op = match depth_ref {
2034 Some(_) => "_compare",
2035 None => "",
2036 };
2037 self.put_expression(image, context, false)?;
2038 write!(self.out, ".{main_op}{comparison_op}(")?;
2039 self.put_expression(sampler, context, true)?;
2040 write!(self.out, ", ")?;
2041 self.put_expression(coordinate, context, true)?;
2042 if let Some(expr) = array_index {
2043 write!(self.out, ", ")?;
2044 self.put_expression(expr, context, true)?;
2045 }
2046 if let Some(dref) = depth_ref {
2047 write!(self.out, ", ")?;
2048 self.put_expression(dref, context, true)?;
2049 }
2050
2051 self.put_image_sample_level(image, level, context)?;
2052
2053 if let Some(offset) = offset {
2054 write!(self.out, ", ")?;
2055 self.put_expression(offset, context, true)?;
2056 }
2057
2058 match gather {
2059 None | Some(crate::SwizzleComponent::X) => {}
2060 Some(component) => {
2061 let is_cube_map = match *context.resolve_type(image) {
2062 crate::TypeInner::Image {
2063 dim: crate::ImageDimension::Cube,
2064 ..
2065 } => true,
2066 _ => false,
2067 };
2068 if offset.is_none() && !is_cube_map {
2071 write!(self.out, ", {NAMESPACE}::int2(0)")?;
2072 }
2073 let letter = back::COMPONENTS[component as usize];
2074 write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2075 }
2076 }
2077 write!(self.out, ")")?;
2078 }
2079 crate::Expression::ImageLoad {
2080 image,
2081 coordinate,
2082 array_index,
2083 sample,
2084 level,
2085 } => {
2086 let address = TexelAddress {
2087 coordinate,
2088 array_index,
2089 sample,
2090 level: level.map(LevelOfDetail::Direct),
2091 };
2092 self.put_image_load(expr_handle, image, address, context)?;
2093 }
2094 crate::Expression::ImageQuery { image, query } => match query {
2097 crate::ImageQuery::Size { level } => {
2098 self.put_image_size_query(
2099 image,
2100 level.map(LevelOfDetail::Direct),
2101 crate::ScalarKind::Uint,
2102 context,
2103 )?;
2104 }
2105 crate::ImageQuery::NumLevels => {
2106 self.put_expression(image, context, false)?;
2107 write!(self.out, ".get_num_mip_levels()")?;
2108 }
2109 crate::ImageQuery::NumLayers => {
2110 self.put_expression(image, context, false)?;
2111 write!(self.out, ".get_array_size()")?;
2112 }
2113 crate::ImageQuery::NumSamples => {
2114 self.put_expression(image, context, false)?;
2115 write!(self.out, ".get_num_samples()")?;
2116 }
2117 },
2118 crate::Expression::Unary { op, expr } => {
2119 let op_str = match op {
2120 crate::UnaryOperator::Negate => {
2121 match context.resolve_type(expr).scalar_kind() {
2122 Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2123 _ => "-",
2124 }
2125 }
2126 crate::UnaryOperator::LogicalNot => "!",
2127 crate::UnaryOperator::BitwiseNot => "~",
2128 };
2129 write!(self.out, "{op_str}(")?;
2130 self.put_expression(expr, context, false)?;
2131 write!(self.out, ")")?;
2132 }
2133 crate::Expression::Binary { op, left, right } => {
2134 let kind = context
2135 .resolve_type(left)
2136 .scalar_kind()
2137 .ok_or(Error::UnsupportedBinaryOp(op))?;
2138
2139 if op == crate::BinaryOperator::Divide
2140 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2141 {
2142 write!(self.out, "{DIV_FUNCTION}(")?;
2143 self.put_expression(left, context, true)?;
2144 write!(self.out, ", ")?;
2145 self.put_expression(right, context, true)?;
2146 write!(self.out, ")")?;
2147 } else if op == crate::BinaryOperator::Modulo
2148 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2149 {
2150 write!(self.out, "{MOD_FUNCTION}(")?;
2151 self.put_expression(left, context, true)?;
2152 write!(self.out, ", ")?;
2153 self.put_expression(right, context, true)?;
2154 write!(self.out, ")")?;
2155 } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2156 write!(self.out, "{NAMESPACE}::fmod(")?;
2161 self.put_expression(left, context, true)?;
2162 write!(self.out, ", ")?;
2163 self.put_expression(right, context, true)?;
2164 write!(self.out, ")")?;
2165 } else if (op == crate::BinaryOperator::Add
2166 || op == crate::BinaryOperator::Subtract
2167 || op == crate::BinaryOperator::Multiply)
2168 && kind == crate::ScalarKind::Sint
2169 {
2170 let to_unsigned = |ty: &crate::TypeInner| match *ty {
2171 crate::TypeInner::Scalar(scalar) => {
2172 Ok(crate::TypeInner::Scalar(crate::Scalar {
2173 kind: crate::ScalarKind::Uint,
2174 ..scalar
2175 }))
2176 }
2177 crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2178 size,
2179 scalar: crate::Scalar {
2180 kind: crate::ScalarKind::Uint,
2181 ..scalar
2182 },
2183 }),
2184 _ => Err(Error::UnsupportedBitCast(ty.clone())),
2185 };
2186
2187 self.put_bitcasted_expression(
2192 context.resolve_type(expr_handle),
2193 context,
2194 &|writer, context, is_scoped| {
2195 writer.put_binop(
2196 op,
2197 left,
2198 right,
2199 context,
2200 is_scoped,
2201 &|writer, expr, context, _is_scoped| {
2202 writer.put_bitcasted_expression(
2203 &to_unsigned(context.resolve_type(expr))?,
2204 context,
2205 &|writer, context, is_scoped| {
2206 writer.put_expression(expr, context, is_scoped)
2207 },
2208 )
2209 },
2210 )
2211 },
2212 )?;
2213 } else {
2214 self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2215 }
2216 }
2217 crate::Expression::Select {
2218 condition,
2219 accept,
2220 reject,
2221 } => match *context.resolve_type(condition) {
2222 crate::TypeInner::Scalar(crate::Scalar {
2223 kind: crate::ScalarKind::Bool,
2224 ..
2225 }) => {
2226 if !is_scoped {
2227 write!(self.out, "(")?;
2228 }
2229 self.put_expression(condition, context, false)?;
2230 write!(self.out, " ? ")?;
2231 self.put_expression(accept, context, false)?;
2232 write!(self.out, " : ")?;
2233 self.put_expression(reject, context, false)?;
2234 if !is_scoped {
2235 write!(self.out, ")")?;
2236 }
2237 }
2238 crate::TypeInner::Vector {
2239 scalar:
2240 crate::Scalar {
2241 kind: crate::ScalarKind::Bool,
2242 ..
2243 },
2244 ..
2245 } => {
2246 write!(self.out, "{NAMESPACE}::select(")?;
2247 self.put_expression(reject, context, true)?;
2248 write!(self.out, ", ")?;
2249 self.put_expression(accept, context, true)?;
2250 write!(self.out, ", ")?;
2251 self.put_expression(condition, context, true)?;
2252 write!(self.out, ")")?;
2253 }
2254 ref ty => {
2255 return Err(Error::GenericValidation(format!(
2256 "Expected select condition to be a non-bool type, got {ty:?}",
2257 )))
2258 }
2259 },
2260 crate::Expression::Derivative { axis, expr, .. } => {
2261 use crate::DerivativeAxis as Axis;
2262 let op = match axis {
2263 Axis::X => "dfdx",
2264 Axis::Y => "dfdy",
2265 Axis::Width => "fwidth",
2266 };
2267 write!(self.out, "{NAMESPACE}::{op}")?;
2268 self.put_call_parameters(iter::once(expr), context)?;
2269 }
2270 crate::Expression::Relational { fun, argument } => {
2271 let op = match fun {
2272 crate::RelationalFunction::Any => "any",
2273 crate::RelationalFunction::All => "all",
2274 crate::RelationalFunction::IsNan => "isnan",
2275 crate::RelationalFunction::IsInf => "isinf",
2276 };
2277 write!(self.out, "{NAMESPACE}::{op}")?;
2278 self.put_call_parameters(iter::once(argument), context)?;
2279 }
2280 crate::Expression::Math {
2281 fun,
2282 arg,
2283 arg1,
2284 arg2,
2285 arg3,
2286 } => {
2287 use crate::MathFunction as Mf;
2288
2289 let arg_type = context.resolve_type(arg);
2290 let scalar_argument = match arg_type {
2291 &crate::TypeInner::Scalar(_) => true,
2292 _ => false,
2293 };
2294
2295 let fun_name = match fun {
2296 Mf::Abs => "abs",
2298 Mf::Min => "min",
2299 Mf::Max => "max",
2300 Mf::Clamp => "clamp",
2301 Mf::Saturate => "saturate",
2302 Mf::Cos => "cos",
2304 Mf::Cosh => "cosh",
2305 Mf::Sin => "sin",
2306 Mf::Sinh => "sinh",
2307 Mf::Tan => "tan",
2308 Mf::Tanh => "tanh",
2309 Mf::Acos => "acos",
2310 Mf::Asin => "asin",
2311 Mf::Atan => "atan",
2312 Mf::Atan2 => "atan2",
2313 Mf::Asinh => "asinh",
2314 Mf::Acosh => "acosh",
2315 Mf::Atanh => "atanh",
2316 Mf::Radians => "",
2317 Mf::Degrees => "",
2318 Mf::Ceil => "ceil",
2320 Mf::Floor => "floor",
2321 Mf::Round => "rint",
2322 Mf::Fract => "fract",
2323 Mf::Trunc => "trunc",
2324 Mf::Modf => MODF_FUNCTION,
2325 Mf::Frexp => FREXP_FUNCTION,
2326 Mf::Ldexp => "ldexp",
2327 Mf::Exp => "exp",
2329 Mf::Exp2 => "exp2",
2330 Mf::Log => "log",
2331 Mf::Log2 => "log2",
2332 Mf::Pow => "pow",
2333 Mf::Dot => match *context.resolve_type(arg) {
2335 crate::TypeInner::Vector {
2336 scalar:
2337 crate::Scalar {
2338 kind: crate::ScalarKind::Float,
2340 ..
2341 },
2342 ..
2343 } => "dot",
2344 crate::TypeInner::Vector {
2345 size,
2346 scalar:
2347 scalar @ crate::Scalar {
2348 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2349 ..
2350 },
2351 } => {
2352 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2354 write!(self.out, "{fun_name}(")?;
2355 self.put_expression(arg, context, true)?;
2356 write!(self.out, ", ")?;
2357 self.put_expression(arg1.unwrap(), context, true)?;
2358 write!(self.out, ")")?;
2359 return Ok(());
2360 }
2361 _ => unreachable!(
2362 "Correct TypeInner for dot product should be already validated"
2363 ),
2364 },
2365 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2366 if context.lang_version >= (2, 1) {
2367 let packed_type = match fun {
2371 Mf::Dot4I8Packed => "packed_char4",
2372 Mf::Dot4U8Packed => "packed_uchar4",
2373 _ => unreachable!(),
2374 };
2375
2376 return self.put_dot_product(
2377 Reinterpreted::new(packed_type, arg),
2378 Reinterpreted::new(packed_type, arg1.unwrap()),
2379 4,
2380 |writer, arg, index| {
2381 write!(writer.out, "{arg}[{index}]")?;
2384 Ok(())
2385 },
2386 );
2387 } else {
2388 let conversion = match fun {
2392 Mf::Dot4I8Packed => "int",
2393 Mf::Dot4U8Packed => "",
2394 _ => unreachable!(),
2395 };
2396
2397 return self.put_dot_product(
2398 arg,
2399 arg1.unwrap(),
2400 4,
2401 |writer, arg, index| {
2402 write!(writer.out, "({conversion}(")?;
2403 writer.put_expression(arg, context, true)?;
2404 if index == 3 {
2405 write!(writer.out, ") >> 24)")?;
2406 } else {
2407 write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2408 }
2409 Ok(())
2410 },
2411 );
2412 }
2413 }
2414 Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2415 Mf::Cross => "cross",
2416 Mf::Distance => "distance",
2417 Mf::Length if scalar_argument => "abs",
2418 Mf::Length => "length",
2419 Mf::Normalize => "normalize",
2420 Mf::FaceForward => "faceforward",
2421 Mf::Reflect => "reflect",
2422 Mf::Refract => "refract",
2423 Mf::Sign => match arg_type.scalar_kind() {
2425 Some(crate::ScalarKind::Sint) => {
2426 return self.put_isign(arg, context);
2427 }
2428 _ => "sign",
2429 },
2430 Mf::Fma => "fma",
2431 Mf::Mix => "mix",
2432 Mf::Step => "step",
2433 Mf::SmoothStep => "smoothstep",
2434 Mf::Sqrt => "sqrt",
2435 Mf::InverseSqrt => "rsqrt",
2436 Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2437 Mf::Transpose => "transpose",
2438 Mf::Determinant => "determinant",
2439 Mf::QuantizeToF16 => "",
2440 Mf::CountTrailingZeros => "ctz",
2442 Mf::CountLeadingZeros => "clz",
2443 Mf::CountOneBits => "popcount",
2444 Mf::ReverseBits => "reverse_bits",
2445 Mf::ExtractBits => "",
2446 Mf::InsertBits => "",
2447 Mf::FirstTrailingBit => "",
2448 Mf::FirstLeadingBit => "",
2449 Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2451 Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2452 Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2453 Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2454 Mf::Pack2x16float => "",
2455 Mf::Pack4xI8 => "",
2456 Mf::Pack4xU8 => "",
2457 Mf::Pack4xI8Clamp => "",
2458 Mf::Pack4xU8Clamp => "",
2459 Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2461 Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2462 Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2463 Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2464 Mf::Unpack2x16float => "",
2465 Mf::Unpack4xI8 => "",
2466 Mf::Unpack4xU8 => "",
2467 };
2468
2469 match fun {
2470 Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2471 if context.lang_version < (1, 2) {
2480 return Err(Error::UnsupportedFunction(fun_name.to_string()));
2481 }
2482 }
2483 _ => {}
2484 }
2485
2486 match fun {
2487 Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2488 write!(self.out, "{ABS_FUNCTION}(")?;
2489 self.put_expression(arg, context, true)?;
2490 write!(self.out, ")")?;
2491 }
2492 Mf::Distance if scalar_argument => {
2493 write!(self.out, "{NAMESPACE}::abs(")?;
2494 self.put_expression(arg, context, false)?;
2495 write!(self.out, " - ")?;
2496 self.put_expression(arg1.unwrap(), context, false)?;
2497 write!(self.out, ")")?;
2498 }
2499 Mf::FirstTrailingBit => {
2500 let scalar = context.resolve_type(arg).scalar().unwrap();
2501 let constant = scalar.width * 8 + 1;
2502
2503 write!(self.out, "((({NAMESPACE}::ctz(")?;
2504 self.put_expression(arg, context, true)?;
2505 write!(self.out, ") + 1) % {constant}) - 1)")?;
2506 }
2507 Mf::FirstLeadingBit => {
2508 let inner = context.resolve_type(arg);
2509 let scalar = inner.scalar().unwrap();
2510 let constant = scalar.width * 8 - 1;
2511
2512 write!(
2513 self.out,
2514 "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2515 )?;
2516
2517 if scalar.kind == crate::ScalarKind::Sint {
2518 write!(self.out, "{NAMESPACE}::select(")?;
2519 self.put_expression(arg, context, true)?;
2520 write!(self.out, ", ~")?;
2521 self.put_expression(arg, context, true)?;
2522 write!(self.out, ", ")?;
2523 self.put_expression(arg, context, true)?;
2524 write!(self.out, " < 0)")?;
2525 } else {
2526 self.put_expression(arg, context, true)?;
2527 }
2528
2529 write!(self.out, "), ")?;
2530
2531 match *inner {
2533 crate::TypeInner::Vector { size, scalar } => {
2534 let size = common::vector_size_str(size);
2535 let name = scalar.to_msl_name();
2536 write!(self.out, "{name}{size}")?;
2537 }
2538 crate::TypeInner::Scalar(scalar) => {
2539 let name = scalar.to_msl_name();
2540 write!(self.out, "{name}")?;
2541 }
2542 _ => (),
2543 }
2544
2545 write!(self.out, "(-1), ")?;
2546 self.put_expression(arg, context, true)?;
2547 write!(self.out, " == 0 || ")?;
2548 self.put_expression(arg, context, true)?;
2549 write!(self.out, " == -1)")?;
2550 }
2551 Mf::Unpack2x16float => {
2552 write!(self.out, "float2(as_type<half2>(")?;
2553 self.put_expression(arg, context, false)?;
2554 write!(self.out, "))")?;
2555 }
2556 Mf::Pack2x16float => {
2557 write!(self.out, "as_type<uint>(half2(")?;
2558 self.put_expression(arg, context, false)?;
2559 write!(self.out, "))")?;
2560 }
2561 Mf::ExtractBits => {
2562 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2579
2580 write!(self.out, "{NAMESPACE}::extract_bits(")?;
2581 self.put_expression(arg, context, true)?;
2582 write!(self.out, ", {NAMESPACE}::min(")?;
2583 self.put_expression(arg1.unwrap(), context, true)?;
2584 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2585 self.put_expression(arg2.unwrap(), context, true)?;
2586 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2587 self.put_expression(arg1.unwrap(), context, true)?;
2588 write!(self.out, ", {scalar_bits}u)))")?;
2589 }
2590 Mf::InsertBits => {
2591 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2596
2597 write!(self.out, "{NAMESPACE}::insert_bits(")?;
2598 self.put_expression(arg, context, true)?;
2599 write!(self.out, ", ")?;
2600 self.put_expression(arg1.unwrap(), context, true)?;
2601 write!(self.out, ", {NAMESPACE}::min(")?;
2602 self.put_expression(arg2.unwrap(), context, true)?;
2603 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2604 self.put_expression(arg3.unwrap(), context, true)?;
2605 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2606 self.put_expression(arg2.unwrap(), context, true)?;
2607 write!(self.out, ", {scalar_bits}u)))")?;
2608 }
2609 Mf::Radians => {
2610 write!(self.out, "((")?;
2611 self.put_expression(arg, context, false)?;
2612 write!(self.out, ") * 0.017453292519943295474)")?;
2613 }
2614 Mf::Degrees => {
2615 write!(self.out, "((")?;
2616 self.put_expression(arg, context, false)?;
2617 write!(self.out, ") * 57.295779513082322865)")?;
2618 }
2619 Mf::Modf | Mf::Frexp => {
2620 write!(self.out, "{fun_name}")?;
2621 self.put_call_parameters(iter::once(arg), context)?;
2622 }
2623 Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2624 Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2625 Mf::Pack4xI8Clamp => {
2626 self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2627 }
2628 Mf::Pack4xU8Clamp => {
2629 self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2630 }
2631 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2632 let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2633 "u"
2634 } else {
2635 ""
2636 };
2637
2638 if context.lang_version >= (2, 1) {
2639 write!(
2641 self.out,
2642 "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2643 )?;
2644 self.put_expression(arg, context, true)?;
2645 write!(self.out, "))")?;
2646 } else {
2647 write!(self.out, "({sign_prefix}int4(")?;
2649 self.put_expression(arg, context, true)?;
2650 write!(self.out, ", ")?;
2651 self.put_expression(arg, context, true)?;
2652 write!(self.out, " >> 8, ")?;
2653 self.put_expression(arg, context, true)?;
2654 write!(self.out, " >> 16, ")?;
2655 self.put_expression(arg, context, true)?;
2656 write!(self.out, " >> 24) << 24 >> 24)")?;
2657 }
2658 }
2659 Mf::QuantizeToF16 => {
2660 match *context.resolve_type(arg) {
2661 crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2662 crate::TypeInner::Vector { size, .. } => write!(
2663 self.out,
2664 "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2665 size = common::vector_size_str(size),
2666 )?,
2667 _ => unreachable!(
2668 "Correct TypeInner for QuantizeToF16 should be already validated"
2669 ),
2670 };
2671
2672 self.put_expression(arg, context, true)?;
2673 write!(self.out, "))")?;
2674 }
2675 _ => {
2676 write!(self.out, "{NAMESPACE}::{fun_name}")?;
2677 self.put_call_parameters(
2678 iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2679 context,
2680 )?;
2681 }
2682 }
2683 }
2684 crate::Expression::As {
2685 expr,
2686 kind,
2687 convert,
2688 } => match *context.resolve_type(expr) {
2689 crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2690 if src.kind == crate::ScalarKind::Float
2691 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2692 && convert.is_some()
2693 {
2694 let fun_name = match (kind, convert) {
2698 (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2699 (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2700 (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2701 (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2702 _ => unreachable!(),
2703 };
2704 write!(self.out, "{fun_name}(")?;
2705 self.put_expression(expr, context, true)?;
2706 write!(self.out, ")")?;
2707 } else {
2708 let target_scalar = crate::Scalar {
2709 kind,
2710 width: convert.unwrap_or(src.width),
2711 };
2712 let op = match convert {
2713 Some(_) => "static_cast",
2714 None => "as_type",
2715 };
2716 write!(self.out, "{op}<")?;
2717 match *context.resolve_type(expr) {
2718 crate::TypeInner::Vector { size, .. } => {
2719 put_numeric_type(&mut self.out, target_scalar, &[size])?
2720 }
2721 _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2722 };
2723 write!(self.out, ">(")?;
2724 self.put_expression(expr, context, true)?;
2725 write!(self.out, ")")?;
2726 }
2727 }
2728 crate::TypeInner::Matrix {
2729 columns,
2730 rows,
2731 scalar,
2732 } => {
2733 let target_scalar = crate::Scalar {
2734 kind,
2735 width: convert.unwrap_or(scalar.width),
2736 };
2737 put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2738 write!(self.out, "(")?;
2739 self.put_expression(expr, context, true)?;
2740 write!(self.out, ")")?;
2741 }
2742 ref ty => {
2743 return Err(Error::GenericValidation(format!(
2744 "Unsupported type for As: {ty:?}"
2745 )))
2746 }
2747 },
2748 crate::Expression::CallResult(_)
2750 | crate::Expression::AtomicResult { .. }
2751 | crate::Expression::WorkGroupUniformLoadResult { .. }
2752 | crate::Expression::SubgroupBallotResult
2753 | crate::Expression::SubgroupOperationResult { .. }
2754 | crate::Expression::RayQueryProceedResult => {
2755 unreachable!()
2756 }
2757 crate::Expression::ArrayLength(expr) => {
2758 let global = match context.function.expressions[expr] {
2760 crate::Expression::AccessIndex { base, .. } => {
2761 match context.function.expressions[base] {
2762 crate::Expression::GlobalVariable(handle) => handle,
2763 ref ex => {
2764 return Err(Error::GenericValidation(format!(
2765 "Expected global variable in AccessIndex, got {ex:?}"
2766 )))
2767 }
2768 }
2769 }
2770 crate::Expression::GlobalVariable(handle) => handle,
2771 ref ex => {
2772 return Err(Error::GenericValidation(format!(
2773 "Unexpected expression in ArrayLength, got {ex:?}"
2774 )))
2775 }
2776 };
2777
2778 if !is_scoped {
2779 write!(self.out, "(")?;
2780 }
2781 write!(self.out, "1 + ")?;
2782 self.put_dynamic_array_max_index(global, context)?;
2783 if !is_scoped {
2784 write!(self.out, ")")?;
2785 }
2786 }
2787 crate::Expression::RayQueryVertexPositions { .. } => {
2788 unimplemented!()
2789 }
2790 crate::Expression::RayQueryGetIntersection {
2791 query,
2792 committed: _,
2793 } => {
2794 if context.lang_version < (2, 4) {
2795 return Err(Error::UnsupportedRayTracing);
2796 }
2797
2798 let ty = context.module.special_types.ray_intersection.unwrap();
2799 let type_name = &self.names[&NameKey::Type(ty)];
2800 write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2801 self.put_expression(query, context, true)?;
2802 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2803 let fields = [
2804 "distance",
2805 "user_instance_id", "instance_id",
2807 "", "geometry_id",
2809 "primitive_id",
2810 "triangle_barycentric_coord",
2811 "triangle_front_facing",
2812 "", "object_to_world_transform", "world_to_object_transform", ];
2816 for field in fields {
2817 write!(self.out, ", ")?;
2818 if field.is_empty() {
2819 write!(self.out, "{{}}")?;
2820 } else {
2821 self.put_expression(query, context, true)?;
2822 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2823 }
2824 }
2825 write!(self.out, "}}")?;
2826 }
2827 }
2828 Ok(())
2829 }
2830
2831 fn put_binop<F>(
2834 &mut self,
2835 op: crate::BinaryOperator,
2836 left: Handle<crate::Expression>,
2837 right: Handle<crate::Expression>,
2838 context: &ExpressionContext,
2839 is_scoped: bool,
2840 put_expression: &F,
2841 ) -> BackendResult
2842 where
2843 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2844 {
2845 let op_str = back::binary_operation_str(op);
2846
2847 if !is_scoped {
2848 write!(self.out, "(")?;
2849 }
2850
2851 if op == crate::BinaryOperator::Multiply
2854 && matches!(
2855 context.resolve_type(right),
2856 &crate::TypeInner::Matrix { .. }
2857 )
2858 {
2859 self.put_wrapped_expression_for_packed_vec3_access(
2860 left,
2861 context,
2862 false,
2863 put_expression,
2864 )?;
2865 } else {
2866 put_expression(self, left, context, false)?;
2867 }
2868
2869 write!(self.out, " {op_str} ")?;
2870
2871 if op == crate::BinaryOperator::Multiply
2873 && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
2874 {
2875 self.put_wrapped_expression_for_packed_vec3_access(
2876 right,
2877 context,
2878 false,
2879 put_expression,
2880 )?;
2881 } else {
2882 put_expression(self, right, context, false)?;
2883 }
2884
2885 if !is_scoped {
2886 write!(self.out, ")")?;
2887 }
2888
2889 Ok(())
2890 }
2891
2892 fn put_wrapped_expression_for_packed_vec3_access<F>(
2894 &mut self,
2895 expr_handle: Handle<crate::Expression>,
2896 context: &ExpressionContext,
2897 is_scoped: bool,
2898 put_expression: &F,
2899 ) -> BackendResult
2900 where
2901 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2902 {
2903 if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
2904 write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
2905 put_expression(self, expr_handle, context, is_scoped)?;
2906 write!(self.out, ")")?;
2907 } else {
2908 put_expression(self, expr_handle, context, is_scoped)?;
2909 }
2910 Ok(())
2911 }
2912
2913 fn put_bitcasted_expression<F>(
2916 &mut self,
2917 cast_to: &crate::TypeInner,
2918 context: &ExpressionContext,
2919 put_expression: &F,
2920 ) -> BackendResult
2921 where
2922 F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
2923 {
2924 write!(self.out, "as_type<")?;
2925 match *cast_to {
2926 crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
2927 crate::TypeInner::Vector { size, scalar } => {
2928 put_numeric_type(&mut self.out, scalar, &[size])?
2929 }
2930 _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
2931 };
2932 write!(self.out, ">(")?;
2933 put_expression(self, context, true)?;
2934 write!(self.out, ")")?;
2935
2936 Ok(())
2937 }
2938
2939 fn put_index(
2941 &mut self,
2942 index: index::GuardedIndex,
2943 context: &ExpressionContext,
2944 is_scoped: bool,
2945 ) -> BackendResult {
2946 match index {
2947 index::GuardedIndex::Expression(expr) => {
2948 self.put_expression(expr, context, is_scoped)?
2949 }
2950 index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
2951 }
2952 Ok(())
2953 }
2954
2955 #[allow(unused_variables)]
2985 fn put_bounds_checks(
2986 &mut self,
2987 chain: Handle<crate::Expression>,
2988 context: &ExpressionContext,
2989 level: back::Level,
2990 prefix: &'static str,
2991 ) -> Result<bool, Error> {
2992 let mut check_written = false;
2993
2994 for item in context.bounds_check_iter(chain) {
2996 let BoundsCheck {
2997 base,
2998 index,
2999 length,
3000 } = item;
3001
3002 if check_written {
3003 write!(self.out, " && ")?;
3004 } else {
3005 write!(self.out, "{level}{prefix}")?;
3006 check_written = true;
3007 }
3008
3009 write!(self.out, "uint(")?;
3013 self.put_index(index, context, true)?;
3014 self.out.write_str(") < ")?;
3015 match length {
3016 index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3017 index::IndexableLength::Dynamic => {
3018 let global = context.function.originating_global(base).ok_or_else(|| {
3019 Error::GenericValidation("Could not find originating global".into())
3020 })?;
3021 write!(self.out, "1 + ")?;
3022 self.put_dynamic_array_max_index(global, context)?
3023 }
3024 }
3025 }
3026
3027 Ok(check_written)
3028 }
3029
3030 fn put_access_chain(
3050 &mut self,
3051 chain: Handle<crate::Expression>,
3052 policy: index::BoundsCheckPolicy,
3053 context: &ExpressionContext,
3054 ) -> BackendResult {
3055 match context.function.expressions[chain] {
3056 crate::Expression::Access { base, index } => {
3057 let mut base_ty = context.resolve_type(base);
3058
3059 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3061 base_ty = &context.module.types[base].inner;
3062 }
3063
3064 self.put_subscripted_access_chain(
3065 base,
3066 base_ty,
3067 index::GuardedIndex::Expression(index),
3068 policy,
3069 context,
3070 )?;
3071 }
3072 crate::Expression::AccessIndex { base, index } => {
3073 let base_resolution = &context.info[base].ty;
3074 let mut base_ty = base_resolution.inner_with(&context.module.types);
3075 let mut base_ty_handle = base_resolution.handle();
3076
3077 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3079 base_ty = &context.module.types[base].inner;
3080 base_ty_handle = Some(base);
3081 }
3082
3083 match *base_ty {
3087 crate::TypeInner::Struct { .. } => {
3088 let base_ty = base_ty_handle.unwrap();
3089 self.put_access_chain(base, policy, context)?;
3090 let name = &self.names[&NameKey::StructMember(base_ty, index)];
3091 write!(self.out, ".{name}")?;
3092 }
3093 crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3094 self.put_access_chain(base, policy, context)?;
3095 if context.get_packed_vec_kind(base).is_some() {
3098 write!(self.out, "[{index}]")?;
3099 } else {
3100 write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3101 }
3102 }
3103 _ => {
3104 self.put_subscripted_access_chain(
3105 base,
3106 base_ty,
3107 index::GuardedIndex::Known(index),
3108 policy,
3109 context,
3110 )?;
3111 }
3112 }
3113 }
3114 _ => self.put_expression(chain, context, false)?,
3115 }
3116
3117 Ok(())
3118 }
3119
3120 fn put_subscripted_access_chain(
3137 &mut self,
3138 base: Handle<crate::Expression>,
3139 base_ty: &crate::TypeInner,
3140 index: index::GuardedIndex,
3141 policy: index::BoundsCheckPolicy,
3142 context: &ExpressionContext,
3143 ) -> BackendResult {
3144 let accessing_wrapped_array = match *base_ty {
3145 crate::TypeInner::Array {
3146 size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3147 ..
3148 } => true,
3149 _ => false,
3150 };
3151 let accessing_wrapped_binding_array =
3152 matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3153
3154 self.put_access_chain(base, policy, context)?;
3155 if accessing_wrapped_array {
3156 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3157 }
3158 write!(self.out, "[")?;
3159
3160 let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3162 context.access_needs_check(base, index)
3163 } else {
3164 None
3165 };
3166 if let Some(limit) = restriction_needed {
3167 write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3168 self.put_index(index, context, true)?;
3169 write!(self.out, "), ")?;
3170 match limit {
3171 index::IndexableLength::Known(limit) => {
3172 write!(self.out, "{}u", limit - 1)?;
3173 }
3174 index::IndexableLength::Dynamic => {
3175 let global = context.function.originating_global(base).ok_or_else(|| {
3176 Error::GenericValidation("Could not find originating global".into())
3177 })?;
3178 self.put_dynamic_array_max_index(global, context)?;
3179 }
3180 }
3181 write!(self.out, ")")?;
3182 } else {
3183 self.put_index(index, context, true)?;
3184 }
3185
3186 write!(self.out, "]")?;
3187
3188 if accessing_wrapped_binding_array {
3189 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3190 }
3191
3192 Ok(())
3193 }
3194
3195 fn put_load(
3196 &mut self,
3197 pointer: Handle<crate::Expression>,
3198 context: &ExpressionContext,
3199 is_scoped: bool,
3200 ) -> BackendResult {
3201 let policy = context.choose_bounds_check_policy(pointer);
3204 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3205 && self.put_bounds_checks(
3206 pointer,
3207 context,
3208 back::Level(0),
3209 if is_scoped { "" } else { "(" },
3210 )?
3211 {
3212 write!(self.out, " ? ")?;
3213 self.put_unchecked_load(pointer, policy, context)?;
3214 write!(self.out, " : DefaultConstructible()")?;
3215
3216 if !is_scoped {
3217 write!(self.out, ")")?;
3218 }
3219 } else {
3220 self.put_unchecked_load(pointer, policy, context)?;
3221 }
3222
3223 Ok(())
3224 }
3225
3226 fn put_unchecked_load(
3227 &mut self,
3228 pointer: Handle<crate::Expression>,
3229 policy: index::BoundsCheckPolicy,
3230 context: &ExpressionContext,
3231 ) -> BackendResult {
3232 let is_atomic_pointer = context
3233 .resolve_type(pointer)
3234 .is_atomic_pointer(&context.module.types);
3235
3236 if is_atomic_pointer {
3237 write!(
3238 self.out,
3239 "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3240 )?;
3241 self.put_access_chain(pointer, policy, context)?;
3242 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3243 } else {
3244 self.put_access_chain(pointer, policy, context)?;
3248 }
3249
3250 Ok(())
3251 }
3252
3253 fn put_return_value(
3254 &mut self,
3255 level: back::Level,
3256 expr_handle: Handle<crate::Expression>,
3257 result_struct: Option<&str>,
3258 context: &ExpressionContext,
3259 ) -> BackendResult {
3260 match result_struct {
3261 Some(struct_name) => {
3262 let mut has_point_size = false;
3263 let result_ty = context.function.result.as_ref().unwrap().ty;
3264 match context.module.types[result_ty].inner {
3265 crate::TypeInner::Struct { ref members, .. } => {
3266 let tmp = "_tmp";
3267 write!(self.out, "{level}const auto {tmp} = ")?;
3268 self.put_expression(expr_handle, context, true)?;
3269 writeln!(self.out, ";")?;
3270 write!(self.out, "{level}return {struct_name} {{")?;
3271
3272 let mut is_first = true;
3273
3274 for (index, member) in members.iter().enumerate() {
3275 if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3276 member.binding
3277 {
3278 has_point_size = true;
3279 if !context.pipeline_options.allow_and_force_point_size {
3280 continue;
3281 }
3282 }
3283
3284 let comma = if is_first { "" } else { "," };
3285 is_first = false;
3286 let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3287 if let crate::TypeInner::Array {
3291 size: crate::ArraySize::Constant(size),
3292 ..
3293 } = context.module.types[member.ty].inner
3294 {
3295 write!(self.out, "{comma} {{")?;
3296 for j in 0..size.get() {
3297 if j != 0 {
3298 write!(self.out, ",")?;
3299 }
3300 write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3301 }
3302 write!(self.out, "}}")?;
3303 } else {
3304 write!(self.out, "{comma} {tmp}.{name}")?;
3305 }
3306 }
3307 }
3308 _ => {
3309 write!(self.out, "{level}return {struct_name} {{ ")?;
3310 self.put_expression(expr_handle, context, true)?;
3311 }
3312 }
3313
3314 if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3315 let stage = context.module.entry_points[ep_index as usize].stage;
3316 if context.pipeline_options.allow_and_force_point_size
3317 && stage == crate::ShaderStage::Vertex
3318 && !has_point_size
3319 {
3320 write!(self.out, ", 1.0")?;
3322 }
3323 }
3324 write!(self.out, " }}")?;
3325 }
3326 None => {
3327 write!(self.out, "{level}return ")?;
3328 self.put_expression(expr_handle, context, true)?;
3329 }
3330 }
3331 writeln!(self.out, ";")?;
3332 Ok(())
3333 }
3334
3335 fn update_expressions_to_bake(
3340 &mut self,
3341 func: &crate::Function,
3342 info: &valid::FunctionInfo,
3343 context: &ExpressionContext,
3344 ) {
3345 use crate::Expression;
3346 self.need_bake_expressions.clear();
3347
3348 for (expr_handle, expr) in func.expressions.iter() {
3349 let expr_info = &info[expr_handle];
3352 let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3353 if min_ref_count <= expr_info.ref_count {
3354 self.need_bake_expressions.insert(expr_handle);
3355 } else {
3356 match expr_info.ty {
3357 TypeResolution::Handle(h)
3359 if Some(h) == context.module.special_types.ray_desc =>
3360 {
3361 self.need_bake_expressions.insert(expr_handle);
3362 }
3363 _ => {}
3364 }
3365 }
3366
3367 if let Expression::Math {
3368 fun,
3369 arg,
3370 arg1,
3371 arg2,
3372 ..
3373 } = *expr
3374 {
3375 match fun {
3376 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3386 self.need_bake_expressions.insert(arg);
3387 self.need_bake_expressions.insert(arg1.unwrap());
3388 }
3389 crate::MathFunction::FirstLeadingBit => {
3390 self.need_bake_expressions.insert(arg);
3391 }
3392 crate::MathFunction::Pack4xI8
3393 | crate::MathFunction::Pack4xU8
3394 | crate::MathFunction::Pack4xI8Clamp
3395 | crate::MathFunction::Pack4xU8Clamp
3396 | crate::MathFunction::Unpack4xI8
3397 | crate::MathFunction::Unpack4xU8 => {
3398 if context.lang_version < (2, 1) {
3401 self.need_bake_expressions.insert(arg);
3402 }
3403 }
3404 crate::MathFunction::ExtractBits => {
3405 self.need_bake_expressions.insert(arg1.unwrap());
3407 }
3408 crate::MathFunction::InsertBits => {
3409 self.need_bake_expressions.insert(arg2.unwrap());
3411 }
3412 crate::MathFunction::Sign => {
3413 let inner = context.resolve_type(expr_handle);
3418 if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3419 self.need_bake_expressions.insert(arg);
3420 }
3421 }
3422 _ => {}
3423 }
3424 }
3425 }
3426 }
3427
3428 fn start_baking_expression(
3429 &mut self,
3430 handle: Handle<crate::Expression>,
3431 context: &ExpressionContext,
3432 name: &str,
3433 ) -> BackendResult {
3434 match context.info[handle].ty {
3435 TypeResolution::Handle(ty_handle) => {
3436 let ty_name = TypeContext {
3437 handle: ty_handle,
3438 gctx: context.module.to_ctx(),
3439 names: &self.names,
3440 access: crate::StorageAccess::empty(),
3441 first_time: false,
3442 };
3443 write!(self.out, "{ty_name}")?;
3444 }
3445 TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3446 put_numeric_type(&mut self.out, scalar, &[])?;
3447 }
3448 TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3449 put_numeric_type(&mut self.out, scalar, &[size])?;
3450 }
3451 TypeResolution::Value(crate::TypeInner::Matrix {
3452 columns,
3453 rows,
3454 scalar,
3455 }) => {
3456 put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3457 }
3458 TypeResolution::Value(ref other) => {
3459 log::warn!("Type {other:?} isn't a known local"); return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3461 }
3462 }
3463
3464 write!(self.out, " {name} = ")?;
3466
3467 Ok(())
3468 }
3469
3470 fn put_cache_restricted_level(
3483 &mut self,
3484 load: Handle<crate::Expression>,
3485 image: Handle<crate::Expression>,
3486 mip_level: Option<Handle<crate::Expression>>,
3487 indent: back::Level,
3488 context: &StatementContext,
3489 ) -> BackendResult {
3490 let level_of_detail = match mip_level {
3493 Some(level) => level,
3494 None => return Ok(()),
3495 };
3496
3497 if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3498 || !context.expression.image_needs_lod(image)
3499 {
3500 return Ok(());
3501 }
3502
3503 write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3504 self.put_restricted_scalar_image_index(
3505 image,
3506 level_of_detail,
3507 "get_num_mip_levels",
3508 &context.expression,
3509 )?;
3510 writeln!(self.out, ";")?;
3511
3512 Ok(())
3513 }
3514
3515 fn put_casting_to_packed_chars(
3521 &mut self,
3522 fun: crate::MathFunction,
3523 arg0: Handle<crate::Expression>,
3524 arg1: Handle<crate::Expression>,
3525 indent: back::Level,
3526 context: &StatementContext<'_>,
3527 ) -> Result<(), Error> {
3528 let packed_type = match fun {
3529 crate::MathFunction::Dot4I8Packed => "packed_char4",
3530 crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3531 _ => unreachable!(),
3532 };
3533
3534 for arg in [arg0, arg1] {
3535 write!(
3536 self.out,
3537 "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3538 Reinterpreted::new(packed_type, arg)
3539 )?;
3540 self.put_expression(arg, &context.expression, true)?;
3541 writeln!(self.out, ");")?;
3542 }
3543
3544 Ok(())
3545 }
3546
3547 fn put_block(
3548 &mut self,
3549 level: back::Level,
3550 statements: &[crate::Statement],
3551 context: &StatementContext,
3552 ) -> BackendResult {
3553 #[cfg(test)]
3555 self.put_block_stack_pointers
3556 .insert(ptr::from_ref(&level).cast());
3557
3558 for statement in statements {
3559 log::trace!("statement[{}] {:?}", level.0, statement);
3560 match *statement {
3561 crate::Statement::Emit(ref range) => {
3562 for handle in range.clone() {
3563 use crate::MathFunction as Mf;
3564
3565 match context.expression.function.expressions[handle] {
3566 crate::Expression::ImageLoad {
3569 image,
3570 level: mip_level,
3571 ..
3572 } => {
3573 self.put_cache_restricted_level(
3574 handle, image, mip_level, level, context,
3575 )?;
3576 }
3577
3578 crate::Expression::Math {
3587 fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3588 arg,
3589 arg1,
3590 ..
3591 } if context.expression.lang_version >= (2, 1) => {
3592 self.put_casting_to_packed_chars(
3593 fun,
3594 arg,
3595 arg1.unwrap(),
3596 level,
3597 context,
3598 )?;
3599 }
3600
3601 _ => (),
3602 }
3603
3604 let ptr_class = context.expression.resolve_type(handle).pointer_space();
3605 let expr_name = if ptr_class.is_some() {
3606 None } else if let Some(name) =
3608 context.expression.function.named_expressions.get(&handle)
3609 {
3610 Some(self.namer.call(name))
3620 } else {
3621 let bake = if context.expression.guarded_indices.contains(handle) {
3625 true
3626 } else {
3627 self.need_bake_expressions.contains(&handle)
3628 };
3629
3630 if bake {
3631 Some(Baked(handle).to_string())
3632 } else {
3633 None
3634 }
3635 };
3636
3637 if let Some(name) = expr_name {
3638 write!(self.out, "{level}")?;
3639 self.start_baking_expression(handle, &context.expression, &name)?;
3640 self.put_expression(handle, &context.expression, true)?;
3641 self.named_expressions.insert(handle, name);
3642 writeln!(self.out, ";")?;
3643 }
3644 }
3645 }
3646 crate::Statement::Block(ref block) => {
3647 if !block.is_empty() {
3648 writeln!(self.out, "{level}{{")?;
3649 self.put_block(level.next(), block, context)?;
3650 writeln!(self.out, "{level}}}")?;
3651 }
3652 }
3653 crate::Statement::If {
3654 condition,
3655 ref accept,
3656 ref reject,
3657 } => {
3658 write!(self.out, "{level}if (")?;
3659 self.put_expression(condition, &context.expression, true)?;
3660 writeln!(self.out, ") {{")?;
3661 self.put_block(level.next(), accept, context)?;
3662 if !reject.is_empty() {
3663 writeln!(self.out, "{level}}} else {{")?;
3664 self.put_block(level.next(), reject, context)?;
3665 }
3666 writeln!(self.out, "{level}}}")?;
3667 }
3668 crate::Statement::Switch {
3669 selector,
3670 ref cases,
3671 } => {
3672 write!(self.out, "{level}switch(")?;
3673 self.put_expression(selector, &context.expression, true)?;
3674 writeln!(self.out, ") {{")?;
3675 let lcase = level.next();
3676 for case in cases.iter() {
3677 match case.value {
3678 crate::SwitchValue::I32(value) => {
3679 write!(self.out, "{lcase}case {value}:")?;
3680 }
3681 crate::SwitchValue::U32(value) => {
3682 write!(self.out, "{lcase}case {value}u:")?;
3683 }
3684 crate::SwitchValue::Default => {
3685 write!(self.out, "{lcase}default:")?;
3686 }
3687 }
3688
3689 let write_block_braces = !(case.fall_through && case.body.is_empty());
3690 if write_block_braces {
3691 writeln!(self.out, " {{")?;
3692 } else {
3693 writeln!(self.out)?;
3694 }
3695
3696 self.put_block(lcase.next(), &case.body, context)?;
3697 if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3698 {
3699 writeln!(self.out, "{}break;", lcase.next())?;
3700 }
3701
3702 if write_block_braces {
3703 writeln!(self.out, "{lcase}}}")?;
3704 }
3705 }
3706 writeln!(self.out, "{level}}}")?;
3707 }
3708 crate::Statement::Loop {
3709 ref body,
3710 ref continuing,
3711 break_if,
3712 } => {
3713 let force_loop_bound_statements =
3714 self.gen_force_bounded_loop_statements(level, context);
3715 let gate_name = (!continuing.is_empty() || break_if.is_some())
3716 .then(|| self.namer.call("loop_init"));
3717
3718 if let Some((ref decl, _)) = force_loop_bound_statements {
3719 writeln!(self.out, "{decl}")?;
3720 }
3721 if let Some(ref gate_name) = gate_name {
3722 writeln!(self.out, "{level}bool {gate_name} = true;")?;
3723 }
3724
3725 writeln!(self.out, "{level}while(true) {{",)?;
3726 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3727 writeln!(self.out, "{break_and_inc}")?;
3728 }
3729 if let Some(ref gate_name) = gate_name {
3730 let lif = level.next();
3731 let lcontinuing = lif.next();
3732 writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3733 self.put_block(lcontinuing, continuing, context)?;
3734 if let Some(condition) = break_if {
3735 write!(self.out, "{lcontinuing}if (")?;
3736 self.put_expression(condition, &context.expression, true)?;
3737 writeln!(self.out, ") {{")?;
3738 writeln!(self.out, "{}break;", lcontinuing.next())?;
3739 writeln!(self.out, "{lcontinuing}}}")?;
3740 }
3741 writeln!(self.out, "{lif}}}")?;
3742 writeln!(self.out, "{lif}{gate_name} = false;")?;
3743 }
3744 self.put_block(level.next(), body, context)?;
3745
3746 writeln!(self.out, "{level}}}")?;
3747 }
3748 crate::Statement::Break => {
3749 writeln!(self.out, "{level}break;")?;
3750 }
3751 crate::Statement::Continue => {
3752 writeln!(self.out, "{level}continue;")?;
3753 }
3754 crate::Statement::Return {
3755 value: Some(expr_handle),
3756 } => {
3757 self.put_return_value(
3758 level,
3759 expr_handle,
3760 context.result_struct,
3761 &context.expression,
3762 )?;
3763 }
3764 crate::Statement::Return { value: None } => {
3765 writeln!(self.out, "{level}return;")?;
3766 }
3767 crate::Statement::Kill => {
3768 writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3769 }
3770 crate::Statement::ControlBarrier(flags)
3771 | crate::Statement::MemoryBarrier(flags) => {
3772 self.write_barrier(flags, level)?;
3773 }
3774 crate::Statement::Store { pointer, value } => {
3775 self.put_store(pointer, value, level, context)?
3776 }
3777 crate::Statement::ImageStore {
3778 image,
3779 coordinate,
3780 array_index,
3781 value,
3782 } => {
3783 let address = TexelAddress {
3784 coordinate,
3785 array_index,
3786 sample: None,
3787 level: None,
3788 };
3789 self.put_image_store(level, image, &address, value, context)?
3790 }
3791 crate::Statement::Call {
3792 function,
3793 ref arguments,
3794 result,
3795 } => {
3796 write!(self.out, "{level}")?;
3797 if let Some(expr) = result {
3798 let name = Baked(expr).to_string();
3799 self.start_baking_expression(expr, &context.expression, &name)?;
3800 self.named_expressions.insert(expr, name);
3801 }
3802 let fun_name = &self.names[&NameKey::Function(function)];
3803 write!(self.out, "{fun_name}(")?;
3804 for (i, &handle) in arguments.iter().enumerate() {
3806 if i != 0 {
3807 write!(self.out, ", ")?;
3808 }
3809 self.put_expression(handle, &context.expression, true)?;
3810 }
3811 let mut separate = !arguments.is_empty();
3813 let fun_info = &context.expression.mod_info[function];
3814 let mut needs_buffer_sizes = false;
3815 for (handle, var) in context.expression.module.global_variables.iter() {
3816 if fun_info[handle].is_empty() {
3817 continue;
3818 }
3819 if var.space.needs_pass_through() {
3820 let name = &self.names[&NameKey::GlobalVariable(handle)];
3821 if separate {
3822 write!(self.out, ", ")?;
3823 } else {
3824 separate = true;
3825 }
3826 write!(self.out, "{name}")?;
3827 }
3828 needs_buffer_sizes |=
3829 needs_array_length(var.ty, &context.expression.module.types);
3830 }
3831 if needs_buffer_sizes {
3832 if separate {
3833 write!(self.out, ", ")?;
3834 }
3835 write!(self.out, "_buffer_sizes")?;
3836 }
3837
3838 writeln!(self.out, ");")?;
3840 }
3841 crate::Statement::Atomic {
3842 pointer,
3843 ref fun,
3844 value,
3845 result,
3846 } => {
3847 let context = &context.expression;
3848
3849 write!(self.out, "{level}")?;
3854 let fun_key = if let Some(result) = result {
3855 let res_name = Baked(result).to_string();
3856 self.start_baking_expression(result, context, &res_name)?;
3857 self.named_expressions.insert(result, res_name);
3858 fun.to_msl()
3859 } else if context.resolve_type(value).scalar_width() == Some(8) {
3860 fun.to_msl_64_bit()?
3861 } else {
3862 fun.to_msl()
3863 };
3864
3865 let policy = context.choose_bounds_check_policy(pointer);
3869 let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3870 && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
3871
3872 if checked {
3874 write!(self.out, " ? ")?;
3875 }
3876
3877 match *fun {
3879 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
3880 write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
3881 self.put_access_chain(pointer, policy, context)?;
3882 write!(self.out, ", ")?;
3883 self.put_expression(cmp, context, true)?;
3884 write!(self.out, ", ")?;
3885 self.put_expression(value, context, true)?;
3886 write!(self.out, ")")?;
3887 }
3888 _ => {
3889 write!(
3890 self.out,
3891 "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
3892 )?;
3893 self.put_access_chain(pointer, policy, context)?;
3894 write!(self.out, ", ")?;
3895 self.put_expression(value, context, true)?;
3896 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3897 }
3898 }
3899
3900 if checked {
3902 write!(self.out, " : DefaultConstructible()")?;
3903 }
3904
3905 writeln!(self.out, ";")?;
3907 }
3908 crate::Statement::ImageAtomic {
3909 image,
3910 coordinate,
3911 array_index,
3912 fun,
3913 value,
3914 } => {
3915 let address = TexelAddress {
3916 coordinate,
3917 array_index,
3918 sample: None,
3919 level: None,
3920 };
3921 self.put_image_atomic(level, image, &address, fun, value, context)?
3922 }
3923 crate::Statement::WorkGroupUniformLoad { pointer, result } => {
3924 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3925
3926 write!(self.out, "{level}")?;
3927 let name = self.namer.call("");
3928 self.start_baking_expression(result, &context.expression, &name)?;
3929 self.put_load(pointer, &context.expression, true)?;
3930 self.named_expressions.insert(result, name);
3931
3932 writeln!(self.out, ";")?;
3933 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3934 }
3935 crate::Statement::RayQuery { query, ref fun } => {
3936 if context.expression.lang_version < (2, 4) {
3937 return Err(Error::UnsupportedRayTracing);
3938 }
3939
3940 match *fun {
3941 crate::RayQueryFunction::Initialize {
3942 acceleration_structure,
3943 descriptor,
3944 } => {
3945 write!(self.out, "{level}")?;
3947 self.put_expression(query, &context.expression, true)?;
3948 writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
3949 {
3950 let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
3951 let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
3952 write!(self.out, "{level}")?;
3953 self.put_expression(query, &context.expression, true)?;
3954 write!(
3955 self.out,
3956 ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
3957 )?;
3958 self.put_expression(descriptor, &context.expression, true)?;
3959 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
3960 self.put_expression(descriptor, &context.expression, true)?;
3961 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
3962 writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
3963 }
3964 {
3965 let f_opaque = back::RayFlag::OPAQUE.bits();
3966 let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
3967 write!(self.out, "{level}")?;
3968 self.put_expression(query, &context.expression, true)?;
3969 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
3970 self.put_expression(descriptor, &context.expression, true)?;
3971 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
3972 self.put_expression(descriptor, &context.expression, true)?;
3973 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
3974 writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
3975 }
3976 {
3977 let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
3978 write!(self.out, "{level}")?;
3979 self.put_expression(query, &context.expression, true)?;
3980 write!(
3981 self.out,
3982 ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
3983 )?;
3984 self.put_expression(descriptor, &context.expression, true)?;
3985 writeln!(self.out, ".flags & {flag}) != 0);")?;
3986 }
3987
3988 write!(self.out, "{level}")?;
3989 self.put_expression(query, &context.expression, true)?;
3990 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
3991 self.put_expression(query, &context.expression, true)?;
3992 write!(
3993 self.out,
3994 ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
3995 )?;
3996 self.put_expression(descriptor, &context.expression, true)?;
3997 write!(self.out, ".origin, ")?;
3998 self.put_expression(descriptor, &context.expression, true)?;
3999 write!(self.out, ".dir, ")?;
4000 self.put_expression(descriptor, &context.expression, true)?;
4001 write!(self.out, ".tmin, ")?;
4002 self.put_expression(descriptor, &context.expression, true)?;
4003 write!(self.out, ".tmax), ")?;
4004 self.put_expression(acceleration_structure, &context.expression, true)?;
4005 write!(self.out, ", ")?;
4006 self.put_expression(descriptor, &context.expression, true)?;
4007 write!(self.out, ".cull_mask);")?;
4008
4009 write!(self.out, "{level}")?;
4010 self.put_expression(query, &context.expression, true)?;
4011 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
4012 }
4013 crate::RayQueryFunction::Proceed { result } => {
4014 write!(self.out, "{level}")?;
4015 let name = Baked(result).to_string();
4016 self.start_baking_expression(result, &context.expression, &name)?;
4017 self.named_expressions.insert(result, name);
4018 self.put_expression(query, &context.expression, true)?;
4019 writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
4020 if RAY_QUERY_MODERN_SUPPORT {
4021 write!(self.out, "{level}")?;
4022 self.put_expression(query, &context.expression, true)?;
4023 writeln!(self.out, ".?.next();")?;
4024 }
4025 }
4026 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
4027 if RAY_QUERY_MODERN_SUPPORT {
4028 write!(self.out, "{level}")?;
4029 self.put_expression(query, &context.expression, true)?;
4030 write!(self.out, ".?.commit_bounding_box_intersection(")?;
4031 self.put_expression(hit_t, &context.expression, true)?;
4032 writeln!(self.out, ");")?;
4033 } else {
4034 log::warn!("Ray Query GenerateIntersection is not yet supported");
4035 }
4036 }
4037 crate::RayQueryFunction::ConfirmIntersection => {
4038 if RAY_QUERY_MODERN_SUPPORT {
4039 write!(self.out, "{level}")?;
4040 self.put_expression(query, &context.expression, true)?;
4041 writeln!(self.out, ".?.commit_triangle_intersection();")?;
4042 } else {
4043 log::warn!("Ray Query ConfirmIntersection is not yet supported");
4044 }
4045 }
4046 crate::RayQueryFunction::Terminate => {
4047 if RAY_QUERY_MODERN_SUPPORT {
4048 write!(self.out, "{level}")?;
4049 self.put_expression(query, &context.expression, true)?;
4050 writeln!(self.out, ".?.abort();")?;
4051 }
4052 write!(self.out, "{level}")?;
4053 self.put_expression(query, &context.expression, true)?;
4054 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
4055 }
4056 }
4057 }
4058 crate::Statement::SubgroupBallot { result, predicate } => {
4059 write!(self.out, "{level}")?;
4060 let name = self.namer.call("");
4061 self.start_baking_expression(result, &context.expression, &name)?;
4062 self.named_expressions.insert(result, name);
4063 write!(
4064 self.out,
4065 "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4066 )?;
4067 if let Some(predicate) = predicate {
4068 self.put_expression(predicate, &context.expression, true)?;
4069 } else {
4070 write!(self.out, "true")?;
4071 }
4072 writeln!(self.out, "), 0, 0, 0);")?;
4073 }
4074 crate::Statement::SubgroupCollectiveOperation {
4075 op,
4076 collective_op,
4077 argument,
4078 result,
4079 } => {
4080 write!(self.out, "{level}")?;
4081 let name = self.namer.call("");
4082 self.start_baking_expression(result, &context.expression, &name)?;
4083 self.named_expressions.insert(result, name);
4084 match (collective_op, op) {
4085 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4086 write!(self.out, "{NAMESPACE}::simd_all(")?
4087 }
4088 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4089 write!(self.out, "{NAMESPACE}::simd_any(")?
4090 }
4091 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4092 write!(self.out, "{NAMESPACE}::simd_sum(")?
4093 }
4094 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4095 write!(self.out, "{NAMESPACE}::simd_product(")?
4096 }
4097 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4098 write!(self.out, "{NAMESPACE}::simd_max(")?
4099 }
4100 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4101 write!(self.out, "{NAMESPACE}::simd_min(")?
4102 }
4103 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4104 write!(self.out, "{NAMESPACE}::simd_and(")?
4105 }
4106 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4107 write!(self.out, "{NAMESPACE}::simd_or(")?
4108 }
4109 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4110 write!(self.out, "{NAMESPACE}::simd_xor(")?
4111 }
4112 (
4113 crate::CollectiveOperation::ExclusiveScan,
4114 crate::SubgroupOperation::Add,
4115 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4116 (
4117 crate::CollectiveOperation::ExclusiveScan,
4118 crate::SubgroupOperation::Mul,
4119 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4120 (
4121 crate::CollectiveOperation::InclusiveScan,
4122 crate::SubgroupOperation::Add,
4123 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4124 (
4125 crate::CollectiveOperation::InclusiveScan,
4126 crate::SubgroupOperation::Mul,
4127 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4128 _ => unimplemented!(),
4129 }
4130 self.put_expression(argument, &context.expression, true)?;
4131 writeln!(self.out, ");")?;
4132 }
4133 crate::Statement::SubgroupGather {
4134 mode,
4135 argument,
4136 result,
4137 } => {
4138 write!(self.out, "{level}")?;
4139 let name = self.namer.call("");
4140 self.start_baking_expression(result, &context.expression, &name)?;
4141 self.named_expressions.insert(result, name);
4142 match mode {
4143 crate::GatherMode::BroadcastFirst => {
4144 write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4145 }
4146 crate::GatherMode::Broadcast(_) => {
4147 write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4148 }
4149 crate::GatherMode::Shuffle(_) => {
4150 write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4151 }
4152 crate::GatherMode::ShuffleDown(_) => {
4153 write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4154 }
4155 crate::GatherMode::ShuffleUp(_) => {
4156 write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4157 }
4158 crate::GatherMode::ShuffleXor(_) => {
4159 write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4160 }
4161 crate::GatherMode::QuadBroadcast(_) => {
4162 write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4163 }
4164 crate::GatherMode::QuadSwap(_) => {
4165 write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4166 }
4167 }
4168 self.put_expression(argument, &context.expression, true)?;
4169 match mode {
4170 crate::GatherMode::BroadcastFirst => {}
4171 crate::GatherMode::Broadcast(index)
4172 | crate::GatherMode::Shuffle(index)
4173 | crate::GatherMode::ShuffleDown(index)
4174 | crate::GatherMode::ShuffleUp(index)
4175 | crate::GatherMode::ShuffleXor(index)
4176 | crate::GatherMode::QuadBroadcast(index) => {
4177 write!(self.out, ", ")?;
4178 self.put_expression(index, &context.expression, true)?;
4179 }
4180 crate::GatherMode::QuadSwap(direction) => {
4181 write!(self.out, ", ")?;
4182 match direction {
4183 crate::Direction::X => {
4184 write!(self.out, "1u")?;
4185 }
4186 crate::Direction::Y => {
4187 write!(self.out, "2u")?;
4188 }
4189 crate::Direction::Diagonal => {
4190 write!(self.out, "3u")?;
4191 }
4192 }
4193 }
4194 }
4195 writeln!(self.out, ");")?;
4196 }
4197 }
4198 }
4199
4200 for statement in statements {
4203 if let crate::Statement::Emit(ref range) = *statement {
4204 for handle in range.clone() {
4205 self.named_expressions.shift_remove(&handle);
4206 }
4207 }
4208 }
4209 Ok(())
4210 }
4211
4212 fn put_store(
4213 &mut self,
4214 pointer: Handle<crate::Expression>,
4215 value: Handle<crate::Expression>,
4216 level: back::Level,
4217 context: &StatementContext,
4218 ) -> BackendResult {
4219 let policy = context.expression.choose_bounds_check_policy(pointer);
4220 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4221 && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4222 {
4223 writeln!(self.out, ") {{")?;
4224 self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4225 writeln!(self.out, "{level}}}")?;
4226 } else {
4227 self.put_unchecked_store(pointer, value, policy, level, context)?;
4228 }
4229
4230 Ok(())
4231 }
4232
4233 fn put_unchecked_store(
4234 &mut self,
4235 pointer: Handle<crate::Expression>,
4236 value: Handle<crate::Expression>,
4237 policy: index::BoundsCheckPolicy,
4238 level: back::Level,
4239 context: &StatementContext,
4240 ) -> BackendResult {
4241 let is_atomic_pointer = context
4242 .expression
4243 .resolve_type(pointer)
4244 .is_atomic_pointer(&context.expression.module.types);
4245
4246 if is_atomic_pointer {
4247 write!(
4248 self.out,
4249 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4250 )?;
4251 self.put_access_chain(pointer, policy, &context.expression)?;
4252 write!(self.out, ", ")?;
4253 self.put_expression(value, &context.expression, true)?;
4254 writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4255 } else {
4256 write!(self.out, "{level}")?;
4257 self.put_access_chain(pointer, policy, &context.expression)?;
4258 write!(self.out, " = ")?;
4259 self.put_expression(value, &context.expression, true)?;
4260 writeln!(self.out, ";")?;
4261 }
4262
4263 Ok(())
4264 }
4265
4266 pub fn write(
4267 &mut self,
4268 module: &crate::Module,
4269 info: &valid::ModuleInfo,
4270 options: &Options,
4271 pipeline_options: &PipelineOptions,
4272 ) -> Result<TranslationInfo, Error> {
4273 self.names.clear();
4274 self.namer.reset(
4275 module,
4276 &super::keywords::RESERVED_SET,
4277 proc::CaseInsensitiveKeywordSet::empty(),
4278 &[CLAMPED_LOD_LOAD_PREFIX],
4279 &mut self.names,
4280 );
4281 self.wrapped_functions.clear();
4282 self.struct_member_pads.clear();
4283
4284 writeln!(
4285 self.out,
4286 "// language: metal{}.{}",
4287 options.lang_version.0, options.lang_version.1
4288 )?;
4289 writeln!(self.out, "#include <metal_stdlib>")?;
4290 writeln!(self.out, "#include <simd/simd.h>")?;
4291 writeln!(self.out)?;
4292 writeln!(self.out, "using {NAMESPACE}::uint;")?;
4294
4295 let mut uses_ray_query = false;
4296 for (_, ty) in module.types.iter() {
4297 match ty.inner {
4298 crate::TypeInner::AccelerationStructure { .. } => {
4299 if options.lang_version < (2, 4) {
4300 return Err(Error::UnsupportedRayTracing);
4301 }
4302 }
4303 crate::TypeInner::RayQuery { .. } => {
4304 if options.lang_version < (2, 4) {
4305 return Err(Error::UnsupportedRayTracing);
4306 }
4307 uses_ray_query = true;
4308 }
4309 _ => (),
4310 }
4311 }
4312
4313 if module.special_types.ray_desc.is_some()
4314 || module.special_types.ray_intersection.is_some()
4315 {
4316 if options.lang_version < (2, 4) {
4317 return Err(Error::UnsupportedRayTracing);
4318 }
4319 }
4320
4321 if uses_ray_query {
4322 self.put_ray_query_type()?;
4323 }
4324
4325 if options
4326 .bounds_check_policies
4327 .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4328 {
4329 self.put_default_constructible()?;
4330 }
4331 writeln!(self.out)?;
4332
4333 {
4334 let globals: Vec<Handle<crate::GlobalVariable>> = module
4337 .global_variables
4338 .iter()
4339 .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4340 .map(|(handle, _)| handle)
4341 .collect();
4342
4343 let mut buffer_indices = vec![];
4344 for vbm in &pipeline_options.vertex_buffer_mappings {
4345 buffer_indices.push(vbm.id);
4346 }
4347
4348 if !globals.is_empty() || !buffer_indices.is_empty() {
4349 writeln!(self.out, "struct _mslBufferSizes {{")?;
4350
4351 for global in globals {
4352 writeln!(
4353 self.out,
4354 "{}uint {};",
4355 back::INDENT,
4356 ArraySizeMember(global)
4357 )?;
4358 }
4359
4360 for idx in buffer_indices {
4361 writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4362 }
4363
4364 writeln!(self.out, "}};")?;
4365 writeln!(self.out)?;
4366 }
4367 };
4368
4369 self.write_type_defs(module)?;
4370 self.write_global_constants(module, info)?;
4371 self.write_functions(module, info, options, pipeline_options)
4372 }
4373
4374 fn put_default_constructible(&mut self) -> BackendResult {
4387 let tab = back::INDENT;
4388 writeln!(self.out, "struct DefaultConstructible {{")?;
4389 writeln!(self.out, "{tab}template<typename T>")?;
4390 writeln!(self.out, "{tab}operator T() && {{")?;
4391 writeln!(self.out, "{tab}{tab}return T {{}};")?;
4392 writeln!(self.out, "{tab}}}")?;
4393 writeln!(self.out, "}};")?;
4394 Ok(())
4395 }
4396
4397 fn put_ray_query_type(&mut self) -> BackendResult {
4398 let tab = back::INDENT;
4399 writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
4400 let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
4401 writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
4402 writeln!(
4403 self.out,
4404 "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
4405 )?;
4406 writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
4407 writeln!(self.out, "}};")?;
4408 writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
4409 let v_triangle = back::RayIntersectionType::Triangle as u32;
4410 let v_bbox = back::RayIntersectionType::BoundingBox as u32;
4411 writeln!(
4412 self.out,
4413 "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
4414 )?;
4415 writeln!(
4416 self.out,
4417 "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
4418 )?;
4419 writeln!(self.out, "}}")?;
4420 Ok(())
4421 }
4422
/// Emits MSL declarations for every type in `module.types`, followed by the
/// helper functions required by the module's predeclared result types.
///
/// Per type, in order:
/// - The argument-buffer wrapper template (first `BindingArray` only) and the
///   external-texture wrapper struct (first external `Image` only) are emitted
///   once, before the type's own declaration.
/// - Types for which `needs_alias()` is false get no declaration of their own.
/// - Arrays become a wrapper struct (known size) or a one-element typedef
///   (dynamic size).
/// - Structs are emitted member by member with explicit `char _padN[...]`
///   padding. Padded member indices are recorded in `self.struct_member_pads`.
/// - Everything else becomes a `typedef`.
///
/// Afterwards, `naga_modf` / `naga_frexp` / `naga_atomic_compare_exchange_weak_explicit`
/// helpers are emitted for each entry in `special_types.predeclared_types`.
fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
    // Each shared wrapper below is emitted at most once, the first time a
    // type that needs it is encountered.
    let mut generated_argument_buffer_wrapper = false;
    let mut generated_external_texture_wrapper = false;
    for (handle, ty) in module.types.iter() {
        match ty.inner {
            crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
                writeln!(self.out, "template <typename T>")?;
                writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
                writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
                writeln!(self.out, "}};")?;
                generated_argument_buffer_wrapper = true;
            }
            crate::TypeInner::Image {
                class: crate::ImageClass::External,
                ..
            } if !generated_external_texture_wrapper => {
                // NOTE(review): the `unwrap` assumes `external_texture_params` is
                // always registered whenever an external image type exists —
                // confirm this invariant is established before writing.
                let params_ty_name = &self.names
                    [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
                // Three sampled 2D planes plus the params struct.
                writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
                writeln!(
                    self.out,
                    "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
                    back::INDENT
                )?;
                writeln!(
                    self.out,
                    "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
                    back::INDENT
                )?;
                writeln!(
                    self.out,
                    "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
                    back::INDENT
                )?;
                writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
                writeln!(self.out, "}};")?;
                generated_external_texture_wrapper = true;
            }
            _ => {}
        }

        if !ty.needs_alias() {
            continue;
        }
        let name = &self.names[&NameKey::Type(handle)];
        match ty.inner {
            // Arrays are not typedef'd directly.
            // - Fixed-size arrays are wrapped in a struct whose single member is
            //   `WRAPPED_ARRAY_FIELD` ("inner").
            // - Runtime-sized arrays become a one-element typedef.
            // Presumably the wrapper lets arrays be copied by value, which
            // C++-style arrays cannot be — confirm.
            crate::TypeInner::Array {
                base,
                size,
                stride: _,
            } => {
                let base_name = TypeContext {
                    handle: base,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };

                match size.resolve(module.to_ctx())? {
                    proc::IndexableLength::Known(size) => {
                        writeln!(self.out, "struct {name} {{")?;
                        writeln!(
                            self.out,
                            "{}{} {}[{}];",
                            back::INDENT,
                            base_name,
                            WRAPPED_ARRAY_FIELD,
                            size
                        )?;
                        writeln!(self.out, "}};")?;
                    }
                    proc::IndexableLength::Dynamic => {
                        writeln!(self.out, "typedef {base_name} {name}[1];")?;
                    }
                }
            }
            crate::TypeInner::Struct {
                ref members, span, ..
            } => {
                writeln!(self.out, "struct {name} {{")?;
                // Byte offset just past the previously emitted member; any gap
                // before the next member's declared offset becomes explicit padding.
                let mut last_offset = 0;
                for (index, member) in members.iter().enumerate() {
                    if member.offset > last_offset {
                        // Remember which members were preceded by a pad field.
                        self.struct_member_pads.insert((handle, index as u32));
                        let pad = member.offset - last_offset;
                        writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
                    }
                    let ty_inner = &module.types[member.ty].inner;
                    last_offset = member.offset + ty_inner.size(module.to_ctx());

                    let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];

                    // When `should_pack_struct_member` yields a scalar, the member
                    // is declared as `metal::packed_<scalar>3`.
                    match should_pack_struct_member(members, span, index, module) {
                        Some(scalar) => {
                            writeln!(
                                self.out,
                                "{}{}::packed_{}3 {};",
                                back::INDENT,
                                NAMESPACE,
                                scalar.to_msl_name(),
                                member_name
                            )?;
                        }
                        None => {
                            let base_name = TypeContext {
                                handle: member.ty,
                                gctx: module.to_ctx(),
                                names: &self.names,
                                access: crate::StorageAccess::empty(),
                                first_time: false,
                            };
                            writeln!(
                                self.out,
                                "{}{} {};",
                                back::INDENT,
                                base_name,
                                member_name
                            )?;

                            // An unpacked 3-component vector is credited with one
                            // extra scalar's width. Presumably this is because
                            // MSL's vec3 occupies four components' worth of
                            // storage — confirm.
                            if let crate::TypeInner::Vector {
                                size: crate::VectorSize::Tri,
                                scalar,
                            } = *ty_inner
                            {
                                last_offset += scalar.width as u32;
                            }
                        }
                    }
                }
                // Trailing padding up to the struct's declared span.
                if last_offset < span {
                    let pad = span - last_offset;
                    writeln!(
                        self.out,
                        "{}char _pad{}[{}];",
                        back::INDENT,
                        members.len(),
                        pad
                    )?;
                }
                writeln!(self.out, "}};")?;
            }
            _ => {
                let ty_name = TypeContext {
                    handle,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: true,
                };
                writeln!(self.out, "typedef {ty_name} {name};")?;
            }
        }
    }

    // Helper functions that construct the predeclared result structs.
    for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
        match type_key {
            &crate::PredeclaredType::ModfResult { size, scalar }
            | &crate::PredeclaredType::FrexpResult { size, scalar } => {
                // Argument type: `double` for 8-byte scalars, `float` otherwise,
                // with a `metal::` vector form when `size` is set.
                // NOTE(review): any non-8-byte width (e.g. a 2-byte scalar) also
                // maps to `float` here — confirm that is intended.
                let arg_type_name_owner;
                let arg_type_name = if let Some(size) = size {
                    arg_type_name_owner = format!(
                        "{NAMESPACE}::{}{}",
                        if scalar.width == 8 { "double" } else { "float" },
                        size as u8
                    );
                    &arg_type_name_owner
                } else if scalar.width == 8 {
                    "double"
                } else {
                    "float"
                };

                // The second output has the argument's type for modf, and
                // `int`/`intN` for frexp.
                let other_type_name_owner;
                let (defined_func_name, called_func_name, other_type_name) =
                    if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
                        (MODF_FUNCTION, "modf", arg_type_name)
                    } else {
                        let other_type_name = if let Some(size) = size {
                            other_type_name_owner = format!("int{}", size as u8);
                            &other_type_name_owner
                        } else {
                            "int"
                        };
                        (FREXP_FUNCTION, "frexp", other_type_name)
                    };

                let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                writeln!(self.out)?;
                writeln!(
                    self.out,
                    "{struct_name} {defined_func_name}({arg_type_name} arg) {{
    {other_type_name} other;
    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
    return {struct_name}{{ fract, other }};
}}"
                )?;
            }
            &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
                let arg_type_name = scalar.to_msl_name();
                let called_func_name = "atomic_compare_exchange_weak_explicit";
                let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
                let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                writeln!(self.out)?;

                // One overload per address space the atomic may live in. Each
                // returns the (possibly updated) `cmp` and the `swapped` flag.
                for address_space_name in ["device", "threadgroup"] {
                    writeln!(
                        self.out,
                        "\
template <typename A>
{struct_name} {defined_func_name}(
    {address_space_name} A *atomic_ptr,
    {arg_type_name} cmp,
    {arg_type_name} v
) {{
    bool swapped = {NAMESPACE}::{called_func_name}(
        atomic_ptr, &cmp, v,
        metal::memory_order_relaxed, metal::memory_order_relaxed
    );
    return {struct_name}{{cmp, swapped}};
}}"
                    )?;
                }
            }
        }
    }

    Ok(())
}
4670
4671 fn write_global_constants(
4673 &mut self,
4674 module: &crate::Module,
4675 mod_info: &valid::ModuleInfo,
4676 ) -> BackendResult {
4677 let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4678
4679 for (handle, constant) in constants {
4680 let ty_name = TypeContext {
4681 handle: constant.ty,
4682 gctx: module.to_ctx(),
4683 names: &self.names,
4684 access: crate::StorageAccess::empty(),
4685 first_time: false,
4686 };
4687 let name = &self.names[&NameKey::Constant(handle)];
4688 write!(self.out, "constant {ty_name} {name} = ")?;
4689 self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4690 writeln!(self.out, ";")?;
4691 }
4692
4693 Ok(())
4694 }
4695
4696 fn put_inline_sampler_properties(
4697 &mut self,
4698 level: back::Level,
4699 sampler: &sm::InlineSampler,
4700 ) -> BackendResult {
4701 for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4702 writeln!(
4703 self.out,
4704 "{}{}::{}_address::{},",
4705 level,
4706 NAMESPACE,
4707 letter,
4708 address.as_str(),
4709 )?;
4710 }
4711 writeln!(
4712 self.out,
4713 "{}{}::mag_filter::{},",
4714 level,
4715 NAMESPACE,
4716 sampler.mag_filter.as_str(),
4717 )?;
4718 writeln!(
4719 self.out,
4720 "{}{}::min_filter::{},",
4721 level,
4722 NAMESPACE,
4723 sampler.min_filter.as_str(),
4724 )?;
4725 if let Some(filter) = sampler.mip_filter {
4726 writeln!(
4727 self.out,
4728 "{}{}::mip_filter::{},",
4729 level,
4730 NAMESPACE,
4731 filter.as_str(),
4732 )?;
4733 }
4734 if sampler.border_color != sm::BorderColor::TransparentBlack {
4736 writeln!(
4737 self.out,
4738 "{}{}::border_color::{},",
4739 level,
4740 NAMESPACE,
4741 sampler.border_color.as_str(),
4742 )?;
4743 }
4744 if false {
4748 if let Some(ref lod) = sampler.lod_clamp {
4749 writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4750 }
4751 if let Some(aniso) = sampler.max_anisotropy {
4752 writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4753 }
4754 }
4755 if sampler.compare_func != sm::CompareFunc::Never {
4756 writeln!(
4757 self.out,
4758 "{}{}::compare_func::{},",
4759 level,
4760 NAMESPACE,
4761 sampler.compare_func.as_str(),
4762 )?;
4763 }
4764 writeln!(
4765 self.out,
4766 "{}{}::coord::{}",
4767 level,
4768 NAMESPACE,
4769 sampler.coord.as_str()
4770 )?;
4771 Ok(())
4772 }
4773
4774 fn write_unpacking_function(
4775 &mut self,
4776 format: back::msl::VertexFormat,
4777 ) -> Result<(String, u32, u32), Error> {
4778 use back::msl::VertexFormat::*;
4779 match format {
4780 Uint8 => {
4781 let name = self.namer.call("unpackUint8");
4782 writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4783 writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4784 writeln!(self.out, "}}")?;
4785 Ok((name, 1, 1))
4786 }
4787 Uint8x2 => {
4788 let name = self.namer.call("unpackUint8x2");
4789 writeln!(
4790 self.out,
4791 "metal::uint2 {name}(metal::uchar b0, \
4792 metal::uchar b1) {{"
4793 )?;
4794 writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4795 writeln!(self.out, "}}")?;
4796 Ok((name, 2, 2))
4797 }
4798 Uint8x4 => {
4799 let name = self.namer.call("unpackUint8x4");
4800 writeln!(
4801 self.out,
4802 "metal::uint4 {name}(metal::uchar b0, \
4803 metal::uchar b1, \
4804 metal::uchar b2, \
4805 metal::uchar b3) {{"
4806 )?;
4807 writeln!(
4808 self.out,
4809 "{}return metal::uint4(b0, b1, b2, b3);",
4810 back::INDENT
4811 )?;
4812 writeln!(self.out, "}}")?;
4813 Ok((name, 4, 4))
4814 }
4815 Sint8 => {
4816 let name = self.namer.call("unpackSint8");
4817 writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4818 writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4819 writeln!(self.out, "}}")?;
4820 Ok((name, 1, 1))
4821 }
4822 Sint8x2 => {
4823 let name = self.namer.call("unpackSint8x2");
4824 writeln!(
4825 self.out,
4826 "metal::int2 {name}(metal::uchar b0, \
4827 metal::uchar b1) {{"
4828 )?;
4829 writeln!(
4830 self.out,
4831 "{}return metal::int2(as_type<char>(b0), \
4832 as_type<char>(b1));",
4833 back::INDENT
4834 )?;
4835 writeln!(self.out, "}}")?;
4836 Ok((name, 2, 2))
4837 }
4838 Sint8x4 => {
4839 let name = self.namer.call("unpackSint8x4");
4840 writeln!(
4841 self.out,
4842 "metal::int4 {name}(metal::uchar b0, \
4843 metal::uchar b1, \
4844 metal::uchar b2, \
4845 metal::uchar b3) {{"
4846 )?;
4847 writeln!(
4848 self.out,
4849 "{}return metal::int4(as_type<char>(b0), \
4850 as_type<char>(b1), \
4851 as_type<char>(b2), \
4852 as_type<char>(b3));",
4853 back::INDENT
4854 )?;
4855 writeln!(self.out, "}}")?;
4856 Ok((name, 4, 4))
4857 }
4858 Unorm8 => {
4859 let name = self.namer.call("unpackUnorm8");
4860 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4861 writeln!(
4862 self.out,
4863 "{}return float(float(b0) / 255.0f);",
4864 back::INDENT
4865 )?;
4866 writeln!(self.out, "}}")?;
4867 Ok((name, 1, 1))
4868 }
4869 Unorm8x2 => {
4870 let name = self.namer.call("unpackUnorm8x2");
4871 writeln!(
4872 self.out,
4873 "metal::float2 {name}(metal::uchar b0, \
4874 metal::uchar b1) {{"
4875 )?;
4876 writeln!(
4877 self.out,
4878 "{}return metal::float2(float(b0) / 255.0f, \
4879 float(b1) / 255.0f);",
4880 back::INDENT
4881 )?;
4882 writeln!(self.out, "}}")?;
4883 Ok((name, 2, 2))
4884 }
4885 Unorm8x4 => {
4886 let name = self.namer.call("unpackUnorm8x4");
4887 writeln!(
4888 self.out,
4889 "metal::float4 {name}(metal::uchar b0, \
4890 metal::uchar b1, \
4891 metal::uchar b2, \
4892 metal::uchar b3) {{"
4893 )?;
4894 writeln!(
4895 self.out,
4896 "{}return metal::float4(float(b0) / 255.0f, \
4897 float(b1) / 255.0f, \
4898 float(b2) / 255.0f, \
4899 float(b3) / 255.0f);",
4900 back::INDENT
4901 )?;
4902 writeln!(self.out, "}}")?;
4903 Ok((name, 4, 4))
4904 }
4905 Snorm8 => {
4906 let name = self.namer.call("unpackSnorm8");
4907 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4908 writeln!(
4909 self.out,
4910 "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4911 back::INDENT
4912 )?;
4913 writeln!(self.out, "}}")?;
4914 Ok((name, 1, 1))
4915 }
4916 Snorm8x2 => {
4917 let name = self.namer.call("unpackSnorm8x2");
4918 writeln!(
4919 self.out,
4920 "metal::float2 {name}(metal::uchar b0, \
4921 metal::uchar b1) {{"
4922 )?;
4923 writeln!(
4924 self.out,
4925 "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4926 metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
4927 back::INDENT
4928 )?;
4929 writeln!(self.out, "}}")?;
4930 Ok((name, 2, 2))
4931 }
4932 Snorm8x4 => {
4933 let name = self.namer.call("unpackSnorm8x4");
4934 writeln!(
4935 self.out,
4936 "metal::float4 {name}(metal::uchar b0, \
4937 metal::uchar b1, \
4938 metal::uchar b2, \
4939 metal::uchar b3) {{"
4940 )?;
4941 writeln!(
4942 self.out,
4943 "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4944 metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
4945 metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
4946 metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
4947 back::INDENT
4948 )?;
4949 writeln!(self.out, "}}")?;
4950 Ok((name, 4, 4))
4951 }
4952 Uint16 => {
4953 let name = self.namer.call("unpackUint16");
4954 writeln!(
4955 self.out,
4956 "metal::uint {name}(metal::uint b0, \
4957 metal::uint b1) {{"
4958 )?;
4959 writeln!(
4960 self.out,
4961 "{}return metal::uint(b1 << 8 | b0);",
4962 back::INDENT
4963 )?;
4964 writeln!(self.out, "}}")?;
4965 Ok((name, 2, 1))
4966 }
4967 Uint16x2 => {
4968 let name = self.namer.call("unpackUint16x2");
4969 writeln!(
4970 self.out,
4971 "metal::uint2 {name}(metal::uint b0, \
4972 metal::uint b1, \
4973 metal::uint b2, \
4974 metal::uint b3) {{"
4975 )?;
4976 writeln!(
4977 self.out,
4978 "{}return metal::uint2(b1 << 8 | b0, \
4979 b3 << 8 | b2);",
4980 back::INDENT
4981 )?;
4982 writeln!(self.out, "}}")?;
4983 Ok((name, 4, 2))
4984 }
4985 Uint16x4 => {
4986 let name = self.namer.call("unpackUint16x4");
4987 writeln!(
4988 self.out,
4989 "metal::uint4 {name}(metal::uint b0, \
4990 metal::uint b1, \
4991 metal::uint b2, \
4992 metal::uint b3, \
4993 metal::uint b4, \
4994 metal::uint b5, \
4995 metal::uint b6, \
4996 metal::uint b7) {{"
4997 )?;
4998 writeln!(
4999 self.out,
5000 "{}return metal::uint4(b1 << 8 | b0, \
5001 b3 << 8 | b2, \
5002 b5 << 8 | b4, \
5003 b7 << 8 | b6);",
5004 back::INDENT
5005 )?;
5006 writeln!(self.out, "}}")?;
5007 Ok((name, 8, 4))
5008 }
5009 Sint16 => {
5010 let name = self.namer.call("unpackSint16");
5011 writeln!(
5012 self.out,
5013 "int {name}(metal::ushort b0, \
5014 metal::ushort b1) {{"
5015 )?;
5016 writeln!(
5017 self.out,
5018 "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5019 back::INDENT
5020 )?;
5021 writeln!(self.out, "}}")?;
5022 Ok((name, 2, 1))
5023 }
5024 Sint16x2 => {
5025 let name = self.namer.call("unpackSint16x2");
5026 writeln!(
5027 self.out,
5028 "metal::int2 {name}(metal::ushort b0, \
5029 metal::ushort b1, \
5030 metal::ushort b2, \
5031 metal::ushort b3) {{"
5032 )?;
5033 writeln!(
5034 self.out,
5035 "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5036 as_type<short>(metal::ushort(b3 << 8 | b2)));",
5037 back::INDENT
5038 )?;
5039 writeln!(self.out, "}}")?;
5040 Ok((name, 4, 2))
5041 }
5042 Sint16x4 => {
5043 let name = self.namer.call("unpackSint16x4");
5044 writeln!(
5045 self.out,
5046 "metal::int4 {name}(metal::ushort b0, \
5047 metal::ushort b1, \
5048 metal::ushort b2, \
5049 metal::ushort b3, \
5050 metal::ushort b4, \
5051 metal::ushort b5, \
5052 metal::ushort b6, \
5053 metal::ushort b7) {{"
5054 )?;
5055 writeln!(
5056 self.out,
5057 "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5058 as_type<short>(metal::ushort(b3 << 8 | b2)), \
5059 as_type<short>(metal::ushort(b5 << 8 | b4)), \
5060 as_type<short>(metal::ushort(b7 << 8 | b6)));",
5061 back::INDENT
5062 )?;
5063 writeln!(self.out, "}}")?;
5064 Ok((name, 8, 4))
5065 }
5066 Unorm16 => {
5067 let name = self.namer.call("unpackUnorm16");
5068 writeln!(
5069 self.out,
5070 "float {name}(metal::ushort b0, \
5071 metal::ushort b1) {{"
5072 )?;
5073 writeln!(
5074 self.out,
5075 "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5076 back::INDENT
5077 )?;
5078 writeln!(self.out, "}}")?;
5079 Ok((name, 2, 1))
5080 }
5081 Unorm16x2 => {
5082 let name = self.namer.call("unpackUnorm16x2");
5083 writeln!(
5084 self.out,
5085 "metal::float2 {name}(metal::ushort b0, \
5086 metal::ushort b1, \
5087 metal::ushort b2, \
5088 metal::ushort b3) {{"
5089 )?;
5090 writeln!(
5091 self.out,
5092 "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5093 float(b3 << 8 | b2) / 65535.0f);",
5094 back::INDENT
5095 )?;
5096 writeln!(self.out, "}}")?;
5097 Ok((name, 4, 2))
5098 }
5099 Unorm16x4 => {
5100 let name = self.namer.call("unpackUnorm16x4");
5101 writeln!(
5102 self.out,
5103 "metal::float4 {name}(metal::ushort b0, \
5104 metal::ushort b1, \
5105 metal::ushort b2, \
5106 metal::ushort b3, \
5107 metal::ushort b4, \
5108 metal::ushort b5, \
5109 metal::ushort b6, \
5110 metal::ushort b7) {{"
5111 )?;
5112 writeln!(
5113 self.out,
5114 "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5115 float(b3 << 8 | b2) / 65535.0f, \
5116 float(b5 << 8 | b4) / 65535.0f, \
5117 float(b7 << 8 | b6) / 65535.0f);",
5118 back::INDENT
5119 )?;
5120 writeln!(self.out, "}}")?;
5121 Ok((name, 8, 4))
5122 }
5123 Snorm16 => {
5124 let name = self.namer.call("unpackSnorm16");
5125 writeln!(
5126 self.out,
5127 "float {name}(metal::ushort b0, \
5128 metal::ushort b1) {{"
5129 )?;
5130 writeln!(
5131 self.out,
5132 "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5133 back::INDENT
5134 )?;
5135 writeln!(self.out, "}}")?;
5136 Ok((name, 2, 1))
5137 }
5138 Snorm16x2 => {
5139 let name = self.namer.call("unpackSnorm16x2");
5140 writeln!(
5141 self.out,
5142 "metal::float2 {name}(uint b0, \
5143 uint b1, \
5144 uint b2, \
5145 uint b3) {{"
5146 )?;
5147 writeln!(
5148 self.out,
5149 "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5150 back::INDENT
5151 )?;
5152 writeln!(self.out, "}}")?;
5153 Ok((name, 4, 2))
5154 }
5155 Snorm16x4 => {
5156 let name = self.namer.call("unpackSnorm16x4");
5157 writeln!(
5158 self.out,
5159 "metal::float4 {name}(uint b0, \
5160 uint b1, \
5161 uint b2, \
5162 uint b3, \
5163 uint b4, \
5164 uint b5, \
5165 uint b6, \
5166 uint b7) {{"
5167 )?;
5168 writeln!(
5169 self.out,
5170 "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5171 metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5172 back::INDENT
5173 )?;
5174 writeln!(self.out, "}}")?;
5175 Ok((name, 8, 4))
5176 }
5177 Float16 => {
5178 let name = self.namer.call("unpackFloat16");
5179 writeln!(
5180 self.out,
5181 "float {name}(metal::ushort b0, \
5182 metal::ushort b1) {{"
5183 )?;
5184 writeln!(
5185 self.out,
5186 "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5187 back::INDENT
5188 )?;
5189 writeln!(self.out, "}}")?;
5190 Ok((name, 2, 1))
5191 }
5192 Float16x2 => {
5193 let name = self.namer.call("unpackFloat16x2");
5194 writeln!(
5195 self.out,
5196 "metal::float2 {name}(metal::ushort b0, \
5197 metal::ushort b1, \
5198 metal::ushort b2, \
5199 metal::ushort b3) {{"
5200 )?;
5201 writeln!(
5202 self.out,
5203 "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5204 as_type<half>(metal::ushort(b3 << 8 | b2)));",
5205 back::INDENT
5206 )?;
5207 writeln!(self.out, "}}")?;
5208 Ok((name, 4, 2))
5209 }
5210 Float16x4 => {
5211 let name = self.namer.call("unpackFloat16x4");
5212 writeln!(
5213 self.out,
5214 "metal::float4 {name}(metal::ushort b0, \
5215 metal::ushort b1, \
5216 metal::ushort b2, \
5217 metal::ushort b3, \
5218 metal::ushort b4, \
5219 metal::ushort b5, \
5220 metal::ushort b6, \
5221 metal::ushort b7) {{"
5222 )?;
5223 writeln!(
5224 self.out,
5225 "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5226 as_type<half>(metal::ushort(b3 << 8 | b2)), \
5227 as_type<half>(metal::ushort(b5 << 8 | b4)), \
5228 as_type<half>(metal::ushort(b7 << 8 | b6)));",
5229 back::INDENT
5230 )?;
5231 writeln!(self.out, "}}")?;
5232 Ok((name, 8, 4))
5233 }
5234 Float32 => {
5235 let name = self.namer.call("unpackFloat32");
5236 writeln!(
5237 self.out,
5238 "float {name}(uint b0, \
5239 uint b1, \
5240 uint b2, \
5241 uint b3) {{"
5242 )?;
5243 writeln!(
5244 self.out,
5245 "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5246 back::INDENT
5247 )?;
5248 writeln!(self.out, "}}")?;
5249 Ok((name, 4, 1))
5250 }
5251 Float32x2 => {
5252 let name = self.namer.call("unpackFloat32x2");
5253 writeln!(
5254 self.out,
5255 "metal::float2 {name}(uint b0, \
5256 uint b1, \
5257 uint b2, \
5258 uint b3, \
5259 uint b4, \
5260 uint b5, \
5261 uint b6, \
5262 uint b7) {{"
5263 )?;
5264 writeln!(
5265 self.out,
5266 "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5267 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5268 back::INDENT
5269 )?;
5270 writeln!(self.out, "}}")?;
5271 Ok((name, 8, 2))
5272 }
5273 Float32x3 => {
5274 let name = self.namer.call("unpackFloat32x3");
5275 writeln!(
5276 self.out,
5277 "metal::float3 {name}(uint b0, \
5278 uint b1, \
5279 uint b2, \
5280 uint b3, \
5281 uint b4, \
5282 uint b5, \
5283 uint b6, \
5284 uint b7, \
5285 uint b8, \
5286 uint b9, \
5287 uint b10, \
5288 uint b11) {{"
5289 )?;
5290 writeln!(
5291 self.out,
5292 "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5293 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5294 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5295 back::INDENT
5296 )?;
5297 writeln!(self.out, "}}")?;
5298 Ok((name, 12, 3))
5299 }
5300 Float32x4 => {
5301 let name = self.namer.call("unpackFloat32x4");
5302 writeln!(
5303 self.out,
5304 "metal::float4 {name}(uint b0, \
5305 uint b1, \
5306 uint b2, \
5307 uint b3, \
5308 uint b4, \
5309 uint b5, \
5310 uint b6, \
5311 uint b7, \
5312 uint b8, \
5313 uint b9, \
5314 uint b10, \
5315 uint b11, \
5316 uint b12, \
5317 uint b13, \
5318 uint b14, \
5319 uint b15) {{"
5320 )?;
5321 writeln!(
5322 self.out,
5323 "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5324 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5325 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5326 as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5327 back::INDENT
5328 )?;
5329 writeln!(self.out, "}}")?;
5330 Ok((name, 16, 4))
5331 }
5332 Uint32 => {
5333 let name = self.namer.call("unpackUint32");
5334 writeln!(
5335 self.out,
5336 "uint {name}(uint b0, \
5337 uint b1, \
5338 uint b2, \
5339 uint b3) {{"
5340 )?;
5341 writeln!(
5342 self.out,
5343 "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5344 back::INDENT
5345 )?;
5346 writeln!(self.out, "}}")?;
5347 Ok((name, 4, 1))
5348 }
5349 Uint32x2 => {
5350 let name = self.namer.call("unpackUint32x2");
5351 writeln!(
5352 self.out,
5353 "uint2 {name}(uint b0, \
5354 uint b1, \
5355 uint b2, \
5356 uint b3, \
5357 uint b4, \
5358 uint b5, \
5359 uint b6, \
5360 uint b7) {{"
5361 )?;
5362 writeln!(
5363 self.out,
5364 "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5365 (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5366 back::INDENT
5367 )?;
5368 writeln!(self.out, "}}")?;
5369 Ok((name, 8, 2))
5370 }
5371 Uint32x3 => {
5372 let name = self.namer.call("unpackUint32x3");
5373 writeln!(
5374 self.out,
5375 "uint3 {name}(uint b0, \
5376 uint b1, \
5377 uint b2, \
5378 uint b3, \
5379 uint b4, \
5380 uint b5, \
5381 uint b6, \
5382 uint b7, \
5383 uint b8, \
5384 uint b9, \
5385 uint b10, \
5386 uint b11) {{"
5387 )?;
5388 writeln!(
5389 self.out,
5390 "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5391 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5392 (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5393 back::INDENT
5394 )?;
5395 writeln!(self.out, "}}")?;
5396 Ok((name, 12, 3))
5397 }
5398 Uint32x4 => {
5399 let name = self.namer.call("unpackUint32x4");
5400 writeln!(
5401 self.out,
5402 "{NAMESPACE}::uint4 {name}(uint b0, \
5403 uint b1, \
5404 uint b2, \
5405 uint b3, \
5406 uint b4, \
5407 uint b5, \
5408 uint b6, \
5409 uint b7, \
5410 uint b8, \
5411 uint b9, \
5412 uint b10, \
5413 uint b11, \
5414 uint b12, \
5415 uint b13, \
5416 uint b14, \
5417 uint b15) {{"
5418 )?;
5419 writeln!(
5420 self.out,
5421 "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5422 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5423 (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5424 (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5425 back::INDENT
5426 )?;
5427 writeln!(self.out, "}}")?;
5428 Ok((name, 16, 4))
5429 }
5430 Sint32 => {
5431 let name = self.namer.call("unpackSint32");
5432 writeln!(
5433 self.out,
5434 "int {name}(uint b0, \
5435 uint b1, \
5436 uint b2, \
5437 uint b3) {{"
5438 )?;
5439 writeln!(
5440 self.out,
5441 "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5442 back::INDENT
5443 )?;
5444 writeln!(self.out, "}}")?;
5445 Ok((name, 4, 1))
5446 }
5447 Sint32x2 => {
5448 let name = self.namer.call("unpackSint32x2");
5449 writeln!(
5450 self.out,
5451 "metal::int2 {name}(uint b0, \
5452 uint b1, \
5453 uint b2, \
5454 uint b3, \
5455 uint b4, \
5456 uint b5, \
5457 uint b6, \
5458 uint b7) {{"
5459 )?;
5460 writeln!(
5461 self.out,
5462 "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5463 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5464 back::INDENT
5465 )?;
5466 writeln!(self.out, "}}")?;
5467 Ok((name, 8, 2))
5468 }
5469 Sint32x3 => {
5470 let name = self.namer.call("unpackSint32x3");
5471 writeln!(
5472 self.out,
5473 "metal::int3 {name}(uint b0, \
5474 uint b1, \
5475 uint b2, \
5476 uint b3, \
5477 uint b4, \
5478 uint b5, \
5479 uint b6, \
5480 uint b7, \
5481 uint b8, \
5482 uint b9, \
5483 uint b10, \
5484 uint b11) {{"
5485 )?;
5486 writeln!(
5487 self.out,
5488 "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5489 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5490 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5491 back::INDENT
5492 )?;
5493 writeln!(self.out, "}}")?;
5494 Ok((name, 12, 3))
5495 }
5496 Sint32x4 => {
5497 let name = self.namer.call("unpackSint32x4");
5498 writeln!(
5499 self.out,
5500 "metal::int4 {name}(uint b0, \
5501 uint b1, \
5502 uint b2, \
5503 uint b3, \
5504 uint b4, \
5505 uint b5, \
5506 uint b6, \
5507 uint b7, \
5508 uint b8, \
5509 uint b9, \
5510 uint b10, \
5511 uint b11, \
5512 uint b12, \
5513 uint b13, \
5514 uint b14, \
5515 uint b15) {{"
5516 )?;
5517 writeln!(
5518 self.out,
5519 "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5520 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5521 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5522 as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5523 back::INDENT
5524 )?;
5525 writeln!(self.out, "}}")?;
5526 Ok((name, 16, 4))
5527 }
5528 Unorm10_10_10_2 => {
5529 let name = self.namer.call("unpackUnorm10_10_10_2");
5530 writeln!(
5531 self.out,
5532 "metal::float4 {name}(uint b0, \
5533 uint b1, \
5534 uint b2, \
5535 uint b3) {{"
5536 )?;
5537 writeln!(
5538 self.out,
5539 "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5551 back::INDENT
5552 )?;
5553 writeln!(self.out, "}}")?;
5554 Ok((name, 4, 4))
5555 }
5556 Unorm8x4Bgra => {
5557 let name = self.namer.call("unpackUnorm8x4Bgra");
5558 writeln!(
5559 self.out,
5560 "metal::float4 {name}(metal::uchar b0, \
5561 metal::uchar b1, \
5562 metal::uchar b2, \
5563 metal::uchar b3) {{"
5564 )?;
5565 writeln!(
5566 self.out,
5567 "{}return metal::float4(float(b2) / 255.0f, \
5568 float(b1) / 255.0f, \
5569 float(b0) / 255.0f, \
5570 float(b3) / 255.0f);",
5571 back::INDENT
5572 )?;
5573 writeln!(self.out, "}}")?;
5574 Ok((name, 4, 4))
5575 }
5576 }
5577 }
5578
5579 fn write_wrapped_unary_op(
5580 &mut self,
5581 module: &crate::Module,
5582 func_ctx: &back::FunctionCtx,
5583 op: crate::UnaryOperator,
5584 operand: Handle<crate::Expression>,
5585 ) -> BackendResult {
5586 let operand_ty = func_ctx.resolve_type(operand, &module.types);
5587 match op {
5588 crate::UnaryOperator::Negate
5595 if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5596 {
5597 let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5598 return Ok(());
5599 };
5600 let wrapped = WrappedFunction::UnaryOp {
5601 op,
5602 ty: (vector_size, scalar),
5603 };
5604 if !self.wrapped_functions.insert(wrapped) {
5605 return Ok(());
5606 }
5607
5608 let unsigned_scalar = crate::Scalar {
5609 kind: crate::ScalarKind::Uint,
5610 ..scalar
5611 };
5612 let mut type_name = String::new();
5613 let mut unsigned_type_name = String::new();
5614 match vector_size {
5615 None => {
5616 put_numeric_type(&mut type_name, scalar, &[])?;
5617 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5618 }
5619 Some(size) => {
5620 put_numeric_type(&mut type_name, scalar, &[size])?;
5621 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5622 }
5623 };
5624
5625 writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5626 let level = back::Level(1);
5627 writeln!(
5628 self.out,
5629 "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5630 )?;
5631 writeln!(self.out, "}}")?;
5632 writeln!(self.out)?;
5633 }
5634 _ => {}
5635 }
5636 Ok(())
5637 }
5638
5639 fn write_wrapped_binary_op(
5640 &mut self,
5641 module: &crate::Module,
5642 func_ctx: &back::FunctionCtx,
5643 expr: Handle<crate::Expression>,
5644 op: crate::BinaryOperator,
5645 left: Handle<crate::Expression>,
5646 right: Handle<crate::Expression>,
5647 ) -> BackendResult {
5648 let expr_ty = func_ctx.resolve_type(expr, &module.types);
5649 let left_ty = func_ctx.resolve_type(left, &module.types);
5650 let right_ty = func_ctx.resolve_type(right, &module.types);
5651 match (op, expr_ty.scalar_kind()) {
5652 (
5659 crate::BinaryOperator::Divide,
5660 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5661 ) => {
5662 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5663 return Ok(());
5664 };
5665 let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5666 return Ok(());
5667 };
5668 let wrapped = WrappedFunction::BinaryOp {
5669 op,
5670 left_ty: left_wrapped_ty,
5671 right_ty: right_wrapped_ty,
5672 };
5673 if !self.wrapped_functions.insert(wrapped) {
5674 return Ok(());
5675 }
5676
5677 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5678 return Ok(());
5679 };
5680 let mut type_name = String::new();
5681 match vector_size {
5682 None => put_numeric_type(&mut type_name, scalar, &[])?,
5683 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5684 };
5685 writeln!(
5686 self.out,
5687 "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5688 )?;
5689 let level = back::Level(1);
5690 match scalar.kind {
5691 crate::ScalarKind::Sint => {
5692 let min_val = match scalar.width {
5693 4 => crate::Literal::I32(i32::MIN),
5694 8 => crate::Literal::I64(i64::MIN),
5695 _ => {
5696 return Err(Error::GenericValidation(format!(
5697 "Unexpected width for scalar {scalar:?}"
5698 )));
5699 }
5700 };
5701 write!(
5702 self.out,
5703 "{level}return lhs / metal::select(rhs, 1, (lhs == "
5704 )?;
5705 self.put_literal(min_val)?;
5706 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
5707 }
5708 crate::ScalarKind::Uint => writeln!(
5709 self.out,
5710 "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
5711 )?,
5712 _ => unreachable!(),
5713 }
5714 writeln!(self.out, "}}")?;
5715 writeln!(self.out)?;
5716 }
5717 (
5730 crate::BinaryOperator::Modulo,
5731 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5732 ) => {
5733 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5734 return Ok(());
5735 };
5736 let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5737 else {
5738 return Ok(());
5739 };
5740 let wrapped = WrappedFunction::BinaryOp {
5741 op,
5742 left_ty: left_wrapped_ty,
5743 right_ty: (right_vector_size, right_scalar),
5744 };
5745 if !self.wrapped_functions.insert(wrapped) {
5746 return Ok(());
5747 }
5748
5749 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5750 return Ok(());
5751 };
5752 let mut type_name = String::new();
5753 match vector_size {
5754 None => put_numeric_type(&mut type_name, scalar, &[])?,
5755 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5756 };
5757 let mut rhs_type_name = String::new();
5758 match right_vector_size {
5759 None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5760 Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5761 };
5762
5763 writeln!(
5764 self.out,
5765 "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5766 )?;
5767 let level = back::Level(1);
5768 match scalar.kind {
5769 crate::ScalarKind::Sint => {
5770 let min_val = match scalar.width {
5771 4 => crate::Literal::I32(i32::MIN),
5772 8 => crate::Literal::I64(i64::MIN),
5773 _ => {
5774 return Err(Error::GenericValidation(format!(
5775 "Unexpected width for scalar {scalar:?}"
5776 )));
5777 }
5778 };
5779 write!(
5780 self.out,
5781 "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
5782 )?;
5783 self.put_literal(min_val)?;
5784 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
5785 writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5786 }
5787 crate::ScalarKind::Uint => writeln!(
5788 self.out,
5789 "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
5790 )?,
5791 _ => unreachable!(),
5792 }
5793 writeln!(self.out, "}}")?;
5794 writeln!(self.out)?;
5795 }
5796 _ => {}
5797 }
5798 Ok(())
5799 }
5800
5801 fn get_dot_wrapper_function_helper_name(
5807 &self,
5808 scalar: crate::Scalar,
5809 size: crate::VectorSize,
5810 ) -> String {
5811 debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5813
5814 let type_name = scalar.to_msl_name();
5815 let size_suffix = common::vector_size_str(size);
5816 format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5817 }
5818
5819 #[allow(clippy::too_many_arguments)]
5820 fn write_wrapped_math_function(
5821 &mut self,
5822 module: &crate::Module,
5823 func_ctx: &back::FunctionCtx,
5824 fun: crate::MathFunction,
5825 arg: Handle<crate::Expression>,
5826 _arg1: Option<Handle<crate::Expression>>,
5827 _arg2: Option<Handle<crate::Expression>>,
5828 _arg3: Option<Handle<crate::Expression>>,
5829 ) -> BackendResult {
5830 let arg_ty = func_ctx.resolve_type(arg, &module.types);
5831 match fun {
5832 crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5840 let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5841 return Ok(());
5842 };
5843 let wrapped = WrappedFunction::Math {
5844 fun,
5845 arg_ty: (vector_size, scalar),
5846 };
5847 if !self.wrapped_functions.insert(wrapped) {
5848 return Ok(());
5849 }
5850
5851 let unsigned_scalar = crate::Scalar {
5852 kind: crate::ScalarKind::Uint,
5853 ..scalar
5854 };
5855 let mut type_name = String::new();
5856 let mut unsigned_type_name = String::new();
5857 match vector_size {
5858 None => {
5859 put_numeric_type(&mut type_name, scalar, &[])?;
5860 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5861 }
5862 Some(size) => {
5863 put_numeric_type(&mut type_name, scalar, &[size])?;
5864 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5865 }
5866 };
5867
5868 writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5869 let level = back::Level(1);
5870 writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
5871 writeln!(self.out, "}}")?;
5872 writeln!(self.out)?;
5873 }
5874
5875 crate::MathFunction::Dot => match *arg_ty {
5876 crate::TypeInner::Vector { size, scalar }
5877 if matches!(
5878 scalar.kind,
5879 crate::ScalarKind::Sint | crate::ScalarKind::Uint
5880 ) =>
5881 {
5882 let wrapped = WrappedFunction::Math {
5884 fun,
5885 arg_ty: (Some(size), scalar),
5886 };
5887 if !self.wrapped_functions.insert(wrapped) {
5888 return Ok(());
5889 }
5890
5891 let mut vec_ty = String::new();
5892 put_numeric_type(&mut vec_ty, scalar, &[size])?;
5893 let mut ret_ty = String::new();
5894 put_numeric_type(&mut ret_ty, scalar, &[])?;
5895
5896 let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
5897
5898 writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
5900 let level = back::Level(1);
5901 write!(self.out, "{level}return ")?;
5902 self.put_dot_product("a", "b", size as usize, |writer, name, index| {
5903 write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
5904 Ok(())
5905 })?;
5906 writeln!(self.out, ";")?;
5907 writeln!(self.out, "}}")?;
5908 writeln!(self.out)?;
5909 }
5910 _ => {}
5911 },
5912
5913 _ => {}
5914 }
5915 Ok(())
5916 }
5917
5918 fn write_wrapped_cast(
5919 &mut self,
5920 module: &crate::Module,
5921 func_ctx: &back::FunctionCtx,
5922 expr: Handle<crate::Expression>,
5923 kind: crate::ScalarKind,
5924 convert: Option<crate::Bytes>,
5925 ) -> BackendResult {
5926 let src_ty = func_ctx.resolve_type(expr, &module.types);
5937 let Some(width) = convert else {
5938 return Ok(());
5939 };
5940 let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
5941 return Ok(());
5942 };
5943 let dst_scalar = crate::Scalar { kind, width };
5944 if src_scalar.kind != crate::ScalarKind::Float
5945 || (dst_scalar.kind != crate::ScalarKind::Sint
5946 && dst_scalar.kind != crate::ScalarKind::Uint)
5947 {
5948 return Ok(());
5949 }
5950 let wrapped = WrappedFunction::Cast {
5951 src_scalar,
5952 vector_size,
5953 dst_scalar,
5954 };
5955 if !self.wrapped_functions.insert(wrapped) {
5956 return Ok(());
5957 }
5958 let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
5959
5960 let mut src_type_name = String::new();
5961 match vector_size {
5962 None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
5963 Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
5964 };
5965 let mut dst_type_name = String::new();
5966 match vector_size {
5967 None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
5968 Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
5969 };
5970 let fun_name = match dst_scalar {
5971 crate::Scalar::I32 => F2I32_FUNCTION,
5972 crate::Scalar::U32 => F2U32_FUNCTION,
5973 crate::Scalar::I64 => F2I64_FUNCTION,
5974 crate::Scalar::U64 => F2U64_FUNCTION,
5975 _ => unreachable!(),
5976 };
5977
5978 writeln!(
5979 self.out,
5980 "{dst_type_name} {fun_name}({src_type_name} value) {{"
5981 )?;
5982 let level = back::Level(1);
5983 write!(
5984 self.out,
5985 "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
5986 )?;
5987 self.put_literal(min)?;
5988 write!(self.out, ", ")?;
5989 self.put_literal(max)?;
5990 writeln!(self.out, "));")?;
5991 writeln!(self.out, "}}")?;
5992 writeln!(self.out)?;
5993 Ok(())
5994 }
5995
/// Writes MSL statements that convert a Y value and a UV pair to RGB and
/// `return` it as `float4(rgb, 1.0)`.
///
/// - `level`: indentation of the emitted statements; continuation lines of
///   the multi-line `select` calls use `level.next()`.
/// - `y`, `uv`: MSL expressions spliced into `float4({y}, {uv}, 1.0)`. The
///   callers pass a `float` and a `float2`.
/// - `params`: MSL expression naming the external-texture params. The
///   emitted code reads its `yuv_conversion_matrix`,
///   `gamut_conversion_matrix`, `src_tf` and `dst_tf` members.
///
/// The emitted code ends with a `return`, so callers must not emit further
/// statements in the same MSL scope.
fn write_convert_yuv_to_rgb_and_return(
    &mut self,
    level: back::Level,
    y: &str,
    uv: &str,
    params: &str,
) -> BackendResult {
    let l1 = level;
    let l2 = l1.next();

    // YUV -> non-linear ("gamma") RGB in the source color space.
    writeln!(
        self.out,
        "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
    )?;

    // Undo the source transfer function.
    // `metal::select(a, b, cond)` yields `b` where `cond` holds. So values
    // below `src_tf.k * src_tf.b` use the linear segment `x / k`; the rest use
    // `pow((x + a - 1) / a, g)`.
    writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
    writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
    writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
    writeln!(
        self.out,
        "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
    )?;

    // Linear RGB in the source gamut -> linear RGB in the destination gamut.
    writeln!(
        self.out,
        "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
    )?;

    // Apply the destination transfer function.
    // Values below `dst_tf.b` use `k * x`; the rest use
    // `a * pow(x, 1 / g) - (a - 1)`.
    writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
    writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
    writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
    writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;

    // Alpha is always 1.0 for converted YUV content.
    writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
    Ok(())
}
6046
/// Emits the `IMAGE_LOAD_EXTERNAL_FUNCTION` helper used to load texels from
/// external textures.
///
/// Only images of class `ImageClass::External` are wrapped; other classes
/// are a no-op. The helper is written at most once, as recorded in
/// `self.wrapped_functions`. Only `image`'s type is inspected; the remaining
/// arguments are unused.
///
/// The generated helper proceeds in four steps:
/// 1. Clamp the integer `coords` to `tex.params.size`, or to plane 0's size
///    when `params.size` is all zeros.
/// 2. Map the clamped coords through `tex.params.load_transform` and round
///    the result to get plane-0 coordinates.
/// 3. If there is one plane, return plane 0's texel directly.
/// 4. Otherwise read Y from plane 0 and UV from plane 1 (2 planes) or from
///    planes 1 and 2 (3 planes), scaling coordinates by each plane's size
///    relative to plane 0. Then convert YUV to RGB.
///
/// # Panics
///
/// Hits `unreachable!` if `image` does not resolve to an image type.
#[allow(clippy::too_many_arguments)]
fn write_wrapped_image_load(
    &mut self,
    module: &crate::Module,
    func_ctx: &back::FunctionCtx,
    image: Handle<crate::Expression>,
    _coordinate: Handle<crate::Expression>,
    _array_index: Option<Handle<crate::Expression>>,
    _sample: Option<Handle<crate::Expression>>,
    _level: Option<Handle<crate::Expression>>,
) -> BackendResult {
    // Only external textures need a wrapped load.
    let class = match *func_ctx.resolve_type(image, &module.types) {
        crate::TypeInner::Image { class, .. } => class,
        _ => unreachable!(),
    };
    if class != crate::ImageClass::External {
        return Ok(());
    }
    let wrapped = WrappedFunction::ImageLoad { class };
    if !self.wrapped_functions.insert(wrapped) {
        return Ok(());
    }

    writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
    let l1 = back::Level(1);
    let l2 = l1.next();
    let l3 = l2.next();
    writeln!(
        self.out,
        "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
    )?;
    // Clamp to the cropped size so the read stays in bounds; a zero
    // `params.size` means "use the whole of plane 0".
    writeln!(
        self.out,
        "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
    )?;
    writeln!(
        self.out,
        "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
    )?;

    // Apply the load transform to obtain plane-0 texel coordinates.
    writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
    writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
    // Single plane: the texel is returned as-is, with no YUV conversion.
    writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
    writeln!(self.out, "{l1}}} else {{")?;

    // Chroma planes may be smaller than plane 0, so scale the coordinates.
    writeln!(
        self.out,
        "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
    )?;
    writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;

    // Luma comes from plane 0's first channel.
    writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;

    writeln!(self.out, "{l2}float2 uv;")?;
    writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
    // Two planes: interleaved UV in plane 1.
    writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
    writeln!(self.out, "{l2}}} else {{")?;
    // Three planes: U from plane 1, V from plane 2.
    // NOTE(review): the next two lines are emitted at `l2` although they sit
    // inside the `l3` scope. This only affects the generated code's indentation.
    writeln!(
        self.out,
        "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
    )?;
    writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
    writeln!(
        self.out,
        "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
    )?;
    writeln!(self.out, "{l2}}}")?;

    // Emits the colour conversion and the final `return`.
    self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;

    writeln!(self.out, "{l1}}}")?;
    writeln!(self.out, "}}")?;
    writeln!(self.out)?;
    Ok(())
}
6132
6133 #[allow(clippy::too_many_arguments)]
6134 fn write_wrapped_image_sample(
6135 &mut self,
6136 module: &crate::Module,
6137 func_ctx: &back::FunctionCtx,
6138 image: Handle<crate::Expression>,
6139 _sampler: Handle<crate::Expression>,
6140 _gather: Option<crate::SwizzleComponent>,
6141 _coordinate: Handle<crate::Expression>,
6142 _array_index: Option<Handle<crate::Expression>>,
6143 _offset: Option<Handle<crate::Expression>>,
6144 _level: crate::SampleLevel,
6145 _depth_ref: Option<Handle<crate::Expression>>,
6146 clamp_to_edge: bool,
6147 ) -> BackendResult {
6148 if !clamp_to_edge {
6151 return Ok(());
6152 }
6153 let class = match *func_ctx.resolve_type(image, &module.types) {
6154 crate::TypeInner::Image { class, .. } => class,
6155 _ => unreachable!(),
6156 };
6157 let wrapped = WrappedFunction::ImageSample {
6158 class,
6159 clamp_to_edge: true,
6160 };
6161 if !self.wrapped_functions.insert(wrapped) {
6162 return Ok(());
6163 }
6164 match class {
6165 crate::ImageClass::External => {
6166 writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6167 let l1 = back::Level(1);
6168 let l2 = l1.next();
6169 let l3 = l2.next();
6170 writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6171 writeln!(
6172 self.out,
6173 "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6174 )?;
6175
6176 writeln!(
6184 self.out,
6185 "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6186 )?;
6187 writeln!(
6188 self.out,
6189 "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6190 )?;
6191 writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6192 writeln!(
6193 self.out,
6194 "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6195 )?;
6196 writeln!(
6197 self.out,
6198 "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6199 )?;
6200 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6201 writeln!(
6203 self.out,
6204 "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6205 )?;
6206 writeln!(self.out, "{l1}}} else {{")?;
6207 writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6208 writeln!(
6209 self.out,
6210 "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6211 )?;
6212 writeln!(
6213 self.out,
6214 "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6215 )?;
6216
6217 writeln!(
6219 self.out,
6220 "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6221 )?;
6222 writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6223 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6224 writeln!(
6226 self.out,
6227 "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6228 )?;
6229 writeln!(self.out, "{l2}}} else {{")?;
6230 writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6232 writeln!(
6233 self.out,
6234 "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6235 )?;
6236 writeln!(
6237 self.out,
6238 "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6239 )?;
6240 writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6241 writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6242 writeln!(self.out, "{l2}}}")?;
6243
6244 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6245
6246 writeln!(self.out, "{l1}}}")?;
6247 writeln!(self.out, "}}")?;
6248 writeln!(self.out)?;
6249 }
6250 _ => {
6251 writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6252 let l1 = back::Level(1);
6253 writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6254 writeln!(
6255 self.out,
6256 "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6257 )?;
6258 writeln!(self.out, "}}")?;
6259 writeln!(self.out)?;
6260 }
6261 }
6262 Ok(())
6263 }
6264
6265 fn write_wrapped_image_query(
6266 &mut self,
6267 module: &crate::Module,
6268 func_ctx: &back::FunctionCtx,
6269 image: Handle<crate::Expression>,
6270 query: crate::ImageQuery,
6271 ) -> BackendResult {
6272 if !matches!(query, crate::ImageQuery::Size { .. }) {
6274 return Ok(());
6275 }
6276 let class = match *func_ctx.resolve_type(image, &module.types) {
6277 crate::TypeInner::Image { class, .. } => class,
6278 _ => unreachable!(),
6279 };
6280 if class != crate::ImageClass::External {
6281 return Ok(());
6282 }
6283 let wrapped = WrappedFunction::ImageQuerySize { class };
6284 if !self.wrapped_functions.insert(wrapped) {
6285 return Ok(());
6286 }
6287 writeln!(
6288 self.out,
6289 "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6290 )?;
6291 let l1 = back::Level(1);
6292 let l2 = l1.next();
6293 writeln!(
6294 self.out,
6295 "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6296 )?;
6297 writeln!(self.out, "{l2}return tex.params.size;")?;
6298 writeln!(self.out, "{l1}}} else {{")?;
6299 writeln!(
6301 self.out,
6302 "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6303 )?;
6304 writeln!(self.out, "{l1}}}")?;
6305 writeln!(self.out, "}}")?;
6306 writeln!(self.out)?;
6307 Ok(())
6308 }
6309
6310 pub(super) fn write_wrapped_functions(
6311 &mut self,
6312 module: &crate::Module,
6313 func_ctx: &back::FunctionCtx,
6314 ) -> BackendResult {
6315 for (expr_handle, expr) in func_ctx.expressions.iter() {
6316 match *expr {
6317 crate::Expression::Unary { op, expr: operand } => {
6318 self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6319 }
6320 crate::Expression::Binary { op, left, right } => {
6321 self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6322 }
6323 crate::Expression::Math {
6324 fun,
6325 arg,
6326 arg1,
6327 arg2,
6328 arg3,
6329 } => {
6330 self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6331 }
6332 crate::Expression::As {
6333 expr,
6334 kind,
6335 convert,
6336 } => {
6337 self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6338 }
6339 crate::Expression::ImageLoad {
6340 image,
6341 coordinate,
6342 array_index,
6343 sample,
6344 level,
6345 } => {
6346 self.write_wrapped_image_load(
6347 module,
6348 func_ctx,
6349 image,
6350 coordinate,
6351 array_index,
6352 sample,
6353 level,
6354 )?;
6355 }
6356 crate::Expression::ImageSample {
6357 image,
6358 sampler,
6359 gather,
6360 coordinate,
6361 array_index,
6362 offset,
6363 level,
6364 depth_ref,
6365 clamp_to_edge,
6366 } => {
6367 self.write_wrapped_image_sample(
6368 module,
6369 func_ctx,
6370 image,
6371 sampler,
6372 gather,
6373 coordinate,
6374 array_index,
6375 offset,
6376 level,
6377 depth_ref,
6378 clamp_to_edge,
6379 )?;
6380 }
6381 crate::Expression::ImageQuery { image, query } => {
6382 self.write_wrapped_image_query(module, func_ctx, image, query)?;
6383 }
6384 _ => {}
6385 }
6386 }
6387
6388 Ok(())
6389 }
6390
6391 fn write_functions(
6393 &mut self,
6394 module: &crate::Module,
6395 mod_info: &valid::ModuleInfo,
6396 options: &Options,
6397 pipeline_options: &PipelineOptions,
6398 ) -> Result<TranslationInfo, Error> {
6399 use back::msl::VertexFormat;
6400
6401 struct AttributeMappingResolved {
6404 ty_name: String,
6405 dimension: u32,
6406 ty_is_int: bool,
6407 name: String,
6408 }
6409 let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6410
6411 struct VertexBufferMappingResolved<'a> {
6412 id: u32,
6413 stride: u32,
6414 step_mode: back::msl::VertexBufferStepMode,
6415 ty_name: String,
6416 param_name: String,
6417 elem_name: String,
6418 attributes: &'a Vec<back::msl::AttributeMapping>,
6419 }
6420 let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6421
6422 struct UnpackingFunction {
6424 name: String,
6425 byte_count: u32,
6426 dimension: u32,
6427 }
6428 let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6429
6430 let mut needs_vertex_id = false;
6436 let v_id = self.namer.call("v_id");
6437
6438 let mut needs_instance_id = false;
6439 let i_id = self.namer.call("i_id");
6440 if pipeline_options.vertex_pulling_transform {
6441 for vbm in &pipeline_options.vertex_buffer_mappings {
6442 let buffer_id = vbm.id;
6443 let buffer_stride = vbm.stride;
6444
6445 assert!(
6446 buffer_stride > 0,
6447 "Vertex pulling requires a non-zero buffer stride."
6448 );
6449
6450 match vbm.step_mode {
6451 back::msl::VertexBufferStepMode::Constant => {}
6452 back::msl::VertexBufferStepMode::ByVertex => {
6453 needs_vertex_id = true;
6454 }
6455 back::msl::VertexBufferStepMode::ByInstance => {
6456 needs_instance_id = true;
6457 }
6458 }
6459
6460 let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6461 let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6462 let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6463
6464 vbm_resolved.push(VertexBufferMappingResolved {
6465 id: buffer_id,
6466 stride: buffer_stride,
6467 step_mode: vbm.step_mode,
6468 ty_name: buffer_ty,
6469 param_name: buffer_param,
6470 elem_name: buffer_elem,
6471 attributes: &vbm.attributes,
6472 });
6473
6474 for attribute in &vbm.attributes {
6476 if unpacking_functions.contains_key(&attribute.format) {
6477 continue;
6478 }
6479 let (name, byte_count, dimension) =
6480 match self.write_unpacking_function(attribute.format) {
6481 Ok((name, byte_count, dimension)) => (name, byte_count, dimension),
6482 _ => {
6483 continue;
6484 }
6485 };
6486 unpacking_functions.insert(
6487 attribute.format,
6488 UnpackingFunction {
6489 name,
6490 byte_count,
6491 dimension,
6492 },
6493 );
6494 }
6495 }
6496 }
6497
6498 let mut pass_through_globals = Vec::new();
6499 for (fun_handle, fun) in module.functions.iter() {
6500 log::trace!(
6501 "function {:?}, handle {:?}",
6502 fun.name.as_deref().unwrap_or("(anonymous)"),
6503 fun_handle
6504 );
6505
6506 let ctx = back::FunctionCtx {
6507 ty: back::FunctionType::Function(fun_handle),
6508 info: &mod_info[fun_handle],
6509 expressions: &fun.expressions,
6510 named_expressions: &fun.named_expressions,
6511 };
6512
6513 writeln!(self.out)?;
6514 self.write_wrapped_functions(module, &ctx)?;
6515
6516 let fun_info = &mod_info[fun_handle];
6517 pass_through_globals.clear();
6518 let mut needs_buffer_sizes = false;
6519 for (handle, var) in module.global_variables.iter() {
6520 if !fun_info[handle].is_empty() {
6521 if var.space.needs_pass_through() {
6522 pass_through_globals.push(handle);
6523 }
6524 needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6525 }
6526 }
6527
6528 let fun_name = &self.names[&NameKey::Function(fun_handle)];
6529 match fun.result {
6530 Some(ref result) => {
6531 let ty_name = TypeContext {
6532 handle: result.ty,
6533 gctx: module.to_ctx(),
6534 names: &self.names,
6535 access: crate::StorageAccess::empty(),
6536 first_time: false,
6537 };
6538 write!(self.out, "{ty_name}")?;
6539 }
6540 None => {
6541 write!(self.out, "void")?;
6542 }
6543 }
6544 writeln!(self.out, " {fun_name}(")?;
6545
6546 for (index, arg) in fun.arguments.iter().enumerate() {
6547 let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6548 let param_type_name = TypeContext {
6549 handle: arg.ty,
6550 gctx: module.to_ctx(),
6551 names: &self.names,
6552 access: crate::StorageAccess::empty(),
6553 first_time: false,
6554 };
6555 let separator = separate(
6556 !pass_through_globals.is_empty()
6557 || index + 1 != fun.arguments.len()
6558 || needs_buffer_sizes,
6559 );
6560 writeln!(
6561 self.out,
6562 "{}{} {}{}",
6563 back::INDENT,
6564 param_type_name,
6565 name,
6566 separator
6567 )?;
6568 }
6569 for (index, &handle) in pass_through_globals.iter().enumerate() {
6570 let tyvar = TypedGlobalVariable {
6571 module,
6572 names: &self.names,
6573 handle,
6574 usage: fun_info[handle],
6575
6576 reference: true,
6577 };
6578 let separator =
6579 separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6580 write!(self.out, "{}", back::INDENT)?;
6581 tyvar.try_fmt(&mut self.out)?;
6582 writeln!(self.out, "{separator}")?;
6583 }
6584
6585 if needs_buffer_sizes {
6586 writeln!(
6587 self.out,
6588 "{}constant _mslBufferSizes& _buffer_sizes",
6589 back::INDENT
6590 )?;
6591 }
6592
6593 writeln!(self.out, ") {{")?;
6594
6595 let guarded_indices =
6596 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6597
6598 let context = StatementContext {
6599 expression: ExpressionContext {
6600 function: fun,
6601 origin: FunctionOrigin::Handle(fun_handle),
6602 info: fun_info,
6603 lang_version: options.lang_version,
6604 policies: options.bounds_check_policies,
6605 guarded_indices,
6606 module,
6607 mod_info,
6608 pipeline_options,
6609 force_loop_bounding: options.force_loop_bounding,
6610 },
6611 result_struct: None,
6612 };
6613
6614 self.put_locals(&context.expression)?;
6615 self.update_expressions_to_bake(fun, fun_info, &context.expression);
6616 self.put_block(back::Level(1), &fun.body, &context)?;
6617 writeln!(self.out, "}}")?;
6618 self.named_expressions.clear();
6619 }
6620
6621 let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6622 .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6623
6624 let mut info = TranslationInfo {
6625 entry_point_names: Vec::with_capacity(ep_range.len()),
6626 };
6627
6628 for ep_index in ep_range {
6629 let ep = &module.entry_points[ep_index];
6630 let fun = &ep.function;
6631 let fun_info = mod_info.get_entry_point(ep_index);
6632 let mut ep_error = None;
6633
6634 let mut v_existing_id = None;
6638 let mut i_existing_id = None;
6639
6640 log::trace!(
6641 "entry point {:?}, index {:?}",
6642 fun.name.as_deref().unwrap_or("(anonymous)"),
6643 ep_index
6644 );
6645
6646 let ctx = back::FunctionCtx {
6647 ty: back::FunctionType::EntryPoint(ep_index as u16),
6648 info: fun_info,
6649 expressions: &fun.expressions,
6650 named_expressions: &fun.named_expressions,
6651 };
6652
6653 self.write_wrapped_functions(module, &ctx)?;
6654
6655 let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6656 crate::ShaderStage::Vertex => (
6657 "vertex",
6658 LocationMode::VertexInput,
6659 LocationMode::VertexOutput,
6660 true,
6661 ),
6662 crate::ShaderStage::Fragment => (
6663 "fragment",
6664 LocationMode::FragmentInput,
6665 LocationMode::FragmentOutput,
6666 false,
6667 ),
6668 crate::ShaderStage::Compute => (
6669 "kernel",
6670 LocationMode::Uniform,
6671 LocationMode::Uniform,
6672 false,
6673 ),
6674 crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(),
6675 };
6676
6677 let do_vertex_pulling = can_vertex_pull
6679 && pipeline_options.vertex_pulling_transform
6680 && !pipeline_options.vertex_buffer_mappings.is_empty();
6681
6682 let needs_buffer_sizes = do_vertex_pulling
6684 || module
6685 .global_variables
6686 .iter()
6687 .filter(|&(handle, _)| !fun_info[handle].is_empty())
6688 .any(|(_, var)| needs_array_length(var.ty, &module.types));
6689
6690 if !options.fake_missing_bindings {
6693 for (var_handle, var) in module.global_variables.iter() {
6694 if fun_info[var_handle].is_empty() {
6695 continue;
6696 }
6697 match var.space {
6698 crate::AddressSpace::Uniform
6699 | crate::AddressSpace::Storage { .. }
6700 | crate::AddressSpace::Handle => {
6701 let br = match var.binding {
6702 Some(ref br) => br,
6703 None => {
6704 let var_name = var.name.clone().unwrap_or_default();
6705 ep_error =
6706 Some(super::EntryPointError::MissingBinding(var_name));
6707 break;
6708 }
6709 };
6710 let target = options.get_resource_binding_target(ep, br);
6711 let good = match target {
6712 Some(target) => {
6713 match module.types[var.ty].inner {
6717 crate::TypeInner::Image {
6718 class: crate::ImageClass::External,
6719 ..
6720 } => target.external_texture.is_some(),
6721 crate::TypeInner::Image { .. } => target.texture.is_some(),
6722 crate::TypeInner::Sampler { .. } => {
6723 target.sampler.is_some()
6724 }
6725 _ => target.buffer.is_some(),
6726 }
6727 }
6728 None => false,
6729 };
6730 if !good {
6731 ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6732 break;
6733 }
6734 }
6735 crate::AddressSpace::PushConstant => {
6736 if let Err(e) = options.resolve_push_constants(ep) {
6737 ep_error = Some(e);
6738 break;
6739 }
6740 }
6741 crate::AddressSpace::TaskPayload => {
6742 unimplemented!()
6743 }
6744 crate::AddressSpace::Function
6745 | crate::AddressSpace::Private
6746 | crate::AddressSpace::WorkGroup => {}
6747 }
6748 }
6749 if needs_buffer_sizes {
6750 if let Err(err) = options.resolve_sizes_buffer(ep) {
6751 ep_error = Some(err);
6752 }
6753 }
6754 }
6755
6756 if let Some(err) = ep_error {
6757 info.entry_point_names.push(Err(err));
6758 continue;
6759 }
6760 let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
6761 info.entry_point_names.push(Ok(fun_name.clone()));
6762
6763 writeln!(self.out)?;
6764
6765 let mut flattened_member_names = FastHashMap::default();
6771 let mut varyings_namer = proc::Namer::default();
6773
6774 let mut flattened_arguments = Vec::new();
6779 for (arg_index, arg) in fun.arguments.iter().enumerate() {
6780 match module.types[arg.ty].inner {
6781 crate::TypeInner::Struct { ref members, .. } => {
6782 for (member_index, member) in members.iter().enumerate() {
6783 let member_index = member_index as u32;
6784 flattened_arguments.push((
6785 NameKey::StructMember(arg.ty, member_index),
6786 member.ty,
6787 member.binding.as_ref(),
6788 ));
6789 let name_key = NameKey::StructMember(arg.ty, member_index);
6790 let name = match member.binding {
6791 Some(crate::Binding::Location { .. }) => {
6792 if do_vertex_pulling {
6793 self.namer.call(&self.names[&name_key])
6794 } else {
6795 varyings_namer.call(&self.names[&name_key])
6796 }
6797 }
6798 _ => self.namer.call(&self.names[&name_key]),
6799 };
6800 flattened_member_names.insert(name_key, name);
6801 }
6802 }
6803 _ => flattened_arguments.push((
6804 NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
6805 arg.ty,
6806 arg.binding.as_ref(),
6807 )),
6808 }
6809 }
6810
6811 let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
6816 let varyings_member_name = self.namer.call("varyings");
6817 let mut has_varyings = false;
6818 if !flattened_arguments.is_empty() {
6819 if !do_vertex_pulling {
6820 writeln!(self.out, "struct {stage_in_name} {{")?;
6821 }
6822 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
6823 let (binding, location) = match binding {
6824 Some(ref binding @ &crate::Binding::Location { location, .. }) => {
6825 (binding, location)
6826 }
6827 _ => continue,
6828 };
6829 let name = match *name_key {
6830 NameKey::StructMember(..) => &flattened_member_names[name_key],
6831 _ => &self.names[name_key],
6832 };
6833 let ty_name = TypeContext {
6834 handle: ty,
6835 gctx: module.to_ctx(),
6836 names: &self.names,
6837 access: crate::StorageAccess::empty(),
6838 first_time: false,
6839 };
6840 let resolved = options.resolve_local_binding(binding, in_mode)?;
6841 if do_vertex_pulling {
6842 am_resolved.insert(
6844 location,
6845 AttributeMappingResolved {
6846 ty_name: ty_name.to_string(),
6847 dimension: ty_name.vertex_input_dimension(),
6848 ty_is_int: ty_name.scalar().is_some_and(scalar_is_int),
6849 name: name.to_string(),
6850 },
6851 );
6852 } else {
6853 has_varyings = true;
6854 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
6855 resolved.try_fmt(&mut self.out)?;
6856 writeln!(self.out, ";")?;
6857 }
6858 }
6859 if !do_vertex_pulling {
6860 writeln!(self.out, "}};")?;
6861 }
6862 }
6863
6864 let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
6867 let result_member_name = self.namer.call("member");
6868 let result_type_name = match fun.result {
6869 Some(ref result) => {
6870 let mut result_members = Vec::new();
6871 if let crate::TypeInner::Struct { ref members, .. } =
6872 module.types[result.ty].inner
6873 {
6874 for (member_index, member) in members.iter().enumerate() {
6875 result_members.push((
6876 &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
6877 member.ty,
6878 member.binding.as_ref(),
6879 ));
6880 }
6881 } else {
6882 result_members.push((
6883 &result_member_name,
6884 result.ty,
6885 result.binding.as_ref(),
6886 ));
6887 }
6888
6889 writeln!(self.out, "struct {stage_out_name} {{")?;
6890 let mut has_point_size = false;
6891 for (name, ty, binding) in result_members {
6892 let ty_name = TypeContext {
6893 handle: ty,
6894 gctx: module.to_ctx(),
6895 names: &self.names,
6896 access: crate::StorageAccess::empty(),
6897 first_time: true,
6898 };
6899 let binding = binding.ok_or_else(|| {
6900 Error::GenericValidation("Expected binding, got None".into())
6901 })?;
6902
6903 if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
6904 has_point_size = true;
6905 if !pipeline_options.allow_and_force_point_size {
6906 continue;
6907 }
6908 }
6909
6910 let array_len = match module.types[ty].inner {
6911 crate::TypeInner::Array {
6912 size: crate::ArraySize::Constant(size),
6913 ..
6914 } => Some(size),
6915 _ => None,
6916 };
6917 let resolved = options.resolve_local_binding(binding, out_mode)?;
6918 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
6919 if let Some(array_len) = array_len {
6920 write!(self.out, " [{array_len}]")?;
6921 }
6922 resolved.try_fmt(&mut self.out)?;
6923 writeln!(self.out, ";")?;
6924 }
6925
6926 if pipeline_options.allow_and_force_point_size
6927 && ep.stage == crate::ShaderStage::Vertex
6928 && !has_point_size
6929 {
6930 writeln!(
6932 self.out,
6933 "{}float _point_size [[point_size]];",
6934 back::INDENT
6935 )?;
6936 }
6937 writeln!(self.out, "}};")?;
6938 &stage_out_name
6939 }
6940 None => "void",
6941 };
6942
6943 if do_vertex_pulling {
6946 for vbm in &vbm_resolved {
6947 let buffer_stride = vbm.stride;
6948 let buffer_ty = &vbm.ty_name;
6949
6950 writeln!(
6954 self.out,
6955 "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
6956 )?;
6957 }
6958 }
6959
6960 writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
6962
6963 let mut is_first_argument = true;
6964 let mut separator = || {
6965 if is_first_argument {
6966 is_first_argument = false;
6967 ' '
6968 } else {
6969 ','
6970 }
6971 };
6972
6973 if has_varyings {
6976 writeln!(
6977 self.out,
6978 "{} {stage_in_name} {varyings_member_name} [[stage_in]]",
6979 separator()
6980 )?;
6981 }
6982
6983 let mut local_invocation_id = None;
6984
6985 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
6988 let binding = match binding {
6989 Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
6990 _ => continue,
6991 };
6992 let name = match *name_key {
6993 NameKey::StructMember(..) => &flattened_member_names[name_key],
6994 _ => &self.names[name_key],
6995 };
6996
6997 if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
6998 local_invocation_id = Some(name_key);
6999 }
7000
7001 let ty_name = TypeContext {
7002 handle: ty,
7003 gctx: module.to_ctx(),
7004 names: &self.names,
7005 access: crate::StorageAccess::empty(),
7006 first_time: false,
7007 };
7008
7009 match *binding {
7010 crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
7011 v_existing_id = Some(name.clone());
7012 }
7013 crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
7014 i_existing_id = Some(name.clone());
7015 }
7016 _ => {}
7017 };
7018
7019 let resolved = options.resolve_local_binding(binding, in_mode)?;
7020 write!(self.out, "{} {ty_name} {name}", separator())?;
7021 resolved.try_fmt(&mut self.out)?;
7022 writeln!(self.out)?;
7023 }
7024
7025 let need_workgroup_variables_initialization =
7026 self.need_workgroup_variables_initialization(options, ep, module, fun_info);
7027
7028 if need_workgroup_variables_initialization && local_invocation_id.is_none() {
7029 writeln!(
7030 self.out,
7031 "{} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]", separator()
7032 )?;
7033 }
7034
7035 for (handle, var) in module.global_variables.iter() {
7040 let usage = fun_info[handle];
7041 if usage.is_empty() || var.space == crate::AddressSpace::Private {
7042 continue;
7043 }
7044
7045 if options.lang_version < (1, 2) {
7046 match var.space {
7047 crate::AddressSpace::Storage { access }
7057 if access.contains(crate::StorageAccess::STORE)
7058 && ep.stage == crate::ShaderStage::Fragment =>
7059 {
7060 return Err(Error::UnsupportedWriteableStorageBuffer)
7061 }
7062 crate::AddressSpace::Handle => {
7063 match module.types[var.ty].inner {
7064 crate::TypeInner::Image {
7065 class: crate::ImageClass::Storage { access, .. },
7066 ..
7067 } => {
7068 if access.contains(crate::StorageAccess::STORE)
7078 && (ep.stage == crate::ShaderStage::Vertex
7079 || ep.stage == crate::ShaderStage::Fragment)
7080 {
7081 return Err(Error::UnsupportedWriteableStorageTexture(
7082 ep.stage,
7083 ));
7084 }
7085
7086 if access.contains(
7087 crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7088 ) {
7089 return Err(Error::UnsupportedRWStorageTexture);
7090 }
7091 }
7092 _ => {}
7093 }
7094 }
7095 _ => {}
7096 }
7097 }
7098
7099 match var.space {
7101 crate::AddressSpace::Handle => match module.types[var.ty].inner {
7102 crate::TypeInner::BindingArray { base, .. } => {
7103 match module.types[base].inner {
7104 crate::TypeInner::Sampler { .. } => {
7105 if options.lang_version < (2, 0) {
7106 return Err(Error::UnsupportedArrayOf(
7107 "samplers".to_string(),
7108 ));
7109 }
7110 }
7111 crate::TypeInner::Image { class, .. } => match class {
7112 crate::ImageClass::Sampled { .. }
7113 | crate::ImageClass::Depth { .. }
7114 | crate::ImageClass::Storage {
7115 access: crate::StorageAccess::LOAD,
7116 ..
7117 } => {
7118 if options.lang_version < (2, 0) {
7123 return Err(Error::UnsupportedArrayOf(
7124 "textures".to_string(),
7125 ));
7126 }
7127 }
7128 crate::ImageClass::Storage {
7129 access: crate::StorageAccess::STORE,
7130 ..
7131 } => {
7132 if options.lang_version < (2, 0) {
7137 return Err(Error::UnsupportedArrayOf(
7138 "write-only textures".to_string(),
7139 ));
7140 }
7141 }
7142 crate::ImageClass::Storage { .. } => {
7143 return Err(Error::UnsupportedArrayOf(
7144 "read-write textures".to_string(),
7145 ));
7146 }
7147 crate::ImageClass::External => {
7148 return Err(Error::UnsupportedArrayOf(
7149 "external textures".to_string(),
7150 ));
7151 }
7152 },
7153 _ => {
7154 return Err(Error::UnsupportedArrayOfType(base));
7155 }
7156 }
7157 }
7158 _ => {}
7159 },
7160 _ => {}
7161 }
7162
7163 let resolved = match var.space {
7165 crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
7166 crate::AddressSpace::WorkGroup => None,
7167 _ => options
7168 .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7169 .ok(),
7170 };
7171 if let Some(ref resolved) = resolved {
7172 if resolved.as_inline_sampler(options).is_some() {
7174 continue;
7175 }
7176 }
7177
7178 match module.types[var.ty].inner {
7179 crate::TypeInner::Image {
7180 class: crate::ImageClass::External,
7181 ..
7182 } => {
7183 let target = match resolved {
7187 Some(back::msl::ResolvedBinding::Resource(target)) => {
7188 target.external_texture
7189 }
7190 _ => None,
7191 };
7192
7193 for i in 0..3 {
7194 write!(self.out, "{} ", separator())?;
7195
7196 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7197 handle,
7198 ExternalTextureNameKey::Plane(i),
7199 )];
7200 write!(
7201 self.out,
7202 "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> {plane_name}"
7203 )?;
7204 if let Some(ref target) = target {
7205 write!(self.out, " [[texture({})]]", target.planes[i])?;
7206 }
7207 writeln!(self.out)?;
7208 }
7209 let params_ty_name = &self.names
7210 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7211 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7212 handle,
7213 ExternalTextureNameKey::Params,
7214 )];
7215 write!(self.out, "{} ", separator())?;
7216 write!(self.out, "constant {params_ty_name}& {params_name}")?;
7217 if let Some(ref target) = target {
7218 write!(self.out, " [[buffer({})]]", target.params)?;
7219 }
7220 }
7221 _ => {
7222 let tyvar = TypedGlobalVariable {
7223 module,
7224 names: &self.names,
7225 handle,
7226 usage,
7227 reference: true,
7228 };
7229 write!(self.out, "{} ", separator())?;
7230 tyvar.try_fmt(&mut self.out)?;
7231 if let Some(resolved) = resolved {
7232 resolved.try_fmt(&mut self.out)?;
7233 }
7234 if let Some(value) = var.init {
7235 write!(self.out, " = ")?;
7236 self.put_const_expression(
7237 value,
7238 module,
7239 mod_info,
7240 &module.global_expressions,
7241 )?;
7242 }
7243 }
7244 }
7245 writeln!(self.out)?;
7246 }
7247
7248 if do_vertex_pulling {
7249 if needs_vertex_id && v_existing_id.is_none() {
7250 writeln!(self.out, "{} uint {v_id} [[vertex_id]]", separator())?;
7252 }
7253
7254 if needs_instance_id && i_existing_id.is_none() {
7255 writeln!(self.out, "{} uint {i_id} [[instance_id]]", separator())?;
7256 }
7257
7258 for vbm in &vbm_resolved {
7261 let id = &vbm.id;
7262 let ty_name = &vbm.ty_name;
7263 let param_name = &vbm.param_name;
7264 writeln!(
7265 self.out,
7266 "{} const device {ty_name}* {param_name} [[buffer({id})]]",
7267 separator()
7268 )?;
7269 }
7270 }
7271
7272 if needs_buffer_sizes {
7275 let resolved = options.resolve_sizes_buffer(ep).unwrap();
7277 write!(
7278 self.out,
7279 "{} constant _mslBufferSizes& _buffer_sizes",
7280 separator()
7281 )?;
7282 resolved.try_fmt(&mut self.out)?;
7283 writeln!(self.out)?;
7284 }
7285
7286 writeln!(self.out, ") {{")?;
7288
7289 if do_vertex_pulling {
7291 for vbm in &vbm_resolved {
7294 for attribute in vbm.attributes {
7295 let location = attribute.shader_location;
7296 let am_option = am_resolved.get(&location);
7297 if am_option.is_none() {
7298 continue;
7301 }
7302 let am = am_option.unwrap();
7303 let attribute_ty_name = &am.ty_name;
7304 let attribute_name = &am.name;
7305
7306 writeln!(
7307 self.out,
7308 "{}{attribute_ty_name} {attribute_name} = {{}};",
7309 back::Level(1)
7310 )?;
7311 }
7312
7313 write!(self.out, "{}if (", back::Level(1))?;
7316
7317 let idx = &vbm.id;
7318 let stride = &vbm.stride;
7319 let index_name = match vbm.step_mode {
7320 back::msl::VertexBufferStepMode::Constant => "0",
7321 back::msl::VertexBufferStepMode::ByVertex => {
7322 if let Some(ref name) = v_existing_id {
7323 name
7324 } else {
7325 &v_id
7326 }
7327 }
7328 back::msl::VertexBufferStepMode::ByInstance => {
7329 if let Some(ref name) = i_existing_id {
7330 name
7331 } else {
7332 &i_id
7333 }
7334 }
7335 };
7336 write!(
7337 self.out,
7338 "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7339 )?;
7340
7341 writeln!(self.out, ") {{")?;
7342
7343 let ty_name = &vbm.ty_name;
7345 let elem_name = &vbm.elem_name;
7346 let param_name = &vbm.param_name;
7347
7348 writeln!(
7349 self.out,
7350 "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7351 back::Level(2),
7352 )?;
7353
7354 for attribute in vbm.attributes {
7357 let location = attribute.shader_location;
7358 let Some(am) = am_resolved.get(&location) else {
7359 continue;
7363 };
7364 let attribute_name = &am.name;
7365 let attribute_ty_name = &am.ty_name;
7366
7367 let offset = attribute.offset;
7368 let func = unpacking_functions
7369 .get(&attribute.format)
7370 .expect("Should have generated this unpacking function earlier.");
7371 let func_name = &func.name;
7372
7373 let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7380
7381 if needs_padding_or_truncation != Ordering::Equal {
7382 writeln!(
7385 self.out,
7386 "{}// {attribute_ty_name} <- {:?}",
7387 back::Level(2),
7388 attribute.format
7389 )?;
7390 }
7391
7392 write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7393
7394 if needs_padding_or_truncation == Ordering::Greater {
7395 write!(self.out, "{attribute_ty_name}(")?;
7397 }
7398
7399 write!(self.out, "{func_name}({elem_name}.data[{offset}]",)?;
7401 for i in (offset + 1)..(offset + func.byte_count) {
7402 write!(self.out, ", {elem_name}.data[{i}]")?;
7403 }
7404 write!(self.out, ")")?;
7405
7406 match needs_padding_or_truncation {
7407 Ordering::Greater => {
7408 let zero_value = if am.ty_is_int { "0" } else { "0.0" };
7410 let one_value = if am.ty_is_int { "1" } else { "1.0" };
7411 for i in func.dimension..am.dimension {
7412 write!(
7413 self.out,
7414 ", {}",
7415 if i == 3 { one_value } else { zero_value }
7416 )?;
7417 }
7418 write!(self.out, ")")?;
7419 }
7420 Ordering::Less => {
7421 write!(
7423 self.out,
7424 ".{}",
7425 &"xyzw"[0..usize::try_from(am.dimension).unwrap()]
7426 )?;
7427 }
7428 Ordering::Equal => {}
7429 }
7430
7431 writeln!(self.out, ";")?;
7432 }
7433
7434 writeln!(self.out, "{}}}", back::Level(1))?;
7436 }
7437 }
7438
7439 if need_workgroup_variables_initialization {
7440 self.write_workgroup_variables_initialization(
7441 module,
7442 mod_info,
7443 fun_info,
7444 local_invocation_id,
7445 )?;
7446 }
7447
7448 for (handle, var) in module.global_variables.iter() {
7451 let usage = fun_info[handle];
7452 if usage.is_empty() {
7453 continue;
7454 }
7455 if var.space == crate::AddressSpace::Private {
7456 let tyvar = TypedGlobalVariable {
7457 module,
7458 names: &self.names,
7459 handle,
7460 usage,
7461
7462 reference: false,
7463 };
7464 write!(self.out, "{}", back::INDENT)?;
7465 tyvar.try_fmt(&mut self.out)?;
7466 match var.init {
7467 Some(value) => {
7468 write!(self.out, " = ")?;
7469 self.put_const_expression(
7470 value,
7471 module,
7472 mod_info,
7473 &module.global_expressions,
7474 )?;
7475 writeln!(self.out, ";")?;
7476 }
7477 None => {
7478 writeln!(self.out, " = {{}};")?;
7479 }
7480 };
7481 } else if let Some(ref binding) = var.binding {
7482 let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7483 if let Some(sampler) = resolved.as_inline_sampler(options) {
7484 let name = &self.names[&NameKey::GlobalVariable(handle)];
7486 writeln!(
7487 self.out,
7488 "{}constexpr {}::sampler {}(",
7489 back::INDENT,
7490 NAMESPACE,
7491 name
7492 )?;
7493 self.put_inline_sampler_properties(back::Level(2), sampler)?;
7494 writeln!(self.out, "{});", back::INDENT)?;
7495 } else if let crate::TypeInner::Image {
7496 class: crate::ImageClass::External,
7497 ..
7498 } = module.types[var.ty].inner
7499 {
7500 let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7503 let l1 = back::Level(1);
7504 let l2 = l1.next();
7505 writeln!(
7506 self.out,
7507 "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7508 )?;
7509 for i in 0..3 {
7510 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7511 handle,
7512 ExternalTextureNameKey::Plane(i),
7513 )];
7514 writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7515 }
7516 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7517 handle,
7518 ExternalTextureNameKey::Params,
7519 )];
7520 writeln!(self.out, "{l2}.params = {params_name},")?;
7521 writeln!(self.out, "{l1}}};")?;
7522 }
7523 }
7524 }
7525
7526 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7537 let arg_name =
7538 &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7539 match module.types[arg.ty].inner {
7540 crate::TypeInner::Struct { ref members, .. } => {
7541 let struct_name = &self.names[&NameKey::Type(arg.ty)];
7542 write!(
7543 self.out,
7544 "{}const {} {} = {{ ",
7545 back::INDENT,
7546 struct_name,
7547 arg_name
7548 )?;
7549 for (member_index, member) in members.iter().enumerate() {
7550 let key = NameKey::StructMember(arg.ty, member_index as u32);
7551 let name = &flattened_member_names[&key];
7552 if member_index != 0 {
7553 write!(self.out, ", ")?;
7554 }
7555 if self
7557 .struct_member_pads
7558 .contains(&(arg.ty, member_index as u32))
7559 {
7560 write!(self.out, "{{}}, ")?;
7561 }
7562 if let Some(crate::Binding::Location { .. }) = member.binding {
7563 if has_varyings {
7564 write!(self.out, "{varyings_member_name}.")?;
7565 }
7566 }
7567 write!(self.out, "{name}")?;
7568 }
7569 writeln!(self.out, " }};")?;
7570 }
7571 _ => {
7572 if let Some(crate::Binding::Location { .. }) = arg.binding {
7573 if has_varyings {
7574 writeln!(
7575 self.out,
7576 "{}const auto {} = {}.{};",
7577 back::INDENT,
7578 arg_name,
7579 varyings_member_name,
7580 arg_name
7581 )?;
7582 }
7583 }
7584 }
7585 }
7586 }
7587
7588 let guarded_indices =
7589 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
7590
7591 let context = StatementContext {
7592 expression: ExpressionContext {
7593 function: fun,
7594 origin: FunctionOrigin::EntryPoint(ep_index as _),
7595 info: fun_info,
7596 lang_version: options.lang_version,
7597 policies: options.bounds_check_policies,
7598 guarded_indices,
7599 module,
7600 mod_info,
7601 pipeline_options,
7602 force_loop_bounding: options.force_loop_bounding,
7603 },
7604 result_struct: Some(&stage_out_name),
7605 };
7606
7607 self.put_locals(&context.expression)?;
7610 self.update_expressions_to_bake(fun, fun_info, &context.expression);
7611 self.put_block(back::Level(1), &fun.body, &context)?;
7612 writeln!(self.out, "}}")?;
7613 if ep_index + 1 != module.entry_points.len() {
7614 writeln!(self.out)?;
7615 }
7616 self.named_expressions.clear();
7617 }
7618
7619 Ok(info)
7620 }
7621
7622 fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
7623 if flags.is_empty() {
7626 writeln!(
7627 self.out,
7628 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
7629 )?;
7630 }
7631 if flags.contains(crate::Barrier::STORAGE) {
7632 writeln!(
7633 self.out,
7634 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
7635 )?;
7636 }
7637 if flags.contains(crate::Barrier::WORK_GROUP) {
7638 writeln!(
7639 self.out,
7640 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7641 )?;
7642 }
7643 if flags.contains(crate::Barrier::SUB_GROUP) {
7644 writeln!(
7645 self.out,
7646 "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7647 )?;
7648 }
7649 if flags.contains(crate::Barrier::TEXTURE) {
7650 writeln!(
7651 self.out,
7652 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
7653 )?;
7654 }
7655 Ok(())
7656 }
7657}
7658
7659mod workgroup_mem_init {
7662 use crate::EntryPoint;
7663
7664 use super::*;
7665
7666 enum Access {
7667 GlobalVariable(Handle<crate::GlobalVariable>),
7668 StructMember(Handle<crate::Type>, u32),
7669 Array(usize),
7670 }
7671
7672 impl Access {
7673 fn write<W: Write>(
7674 &self,
7675 writer: &mut W,
7676 names: &FastHashMap<NameKey, String>,
7677 ) -> Result<(), core::fmt::Error> {
7678 match *self {
7679 Access::GlobalVariable(handle) => {
7680 write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
7681 }
7682 Access::StructMember(handle, index) => {
7683 write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
7684 }
7685 Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
7686 }
7687 }
7688 }
7689
7690 struct AccessStack {
7691 stack: Vec<Access>,
7692 array_depth: usize,
7693 }
7694
7695 impl AccessStack {
7696 const fn new() -> Self {
7697 Self {
7698 stack: Vec::new(),
7699 array_depth: 0,
7700 }
7701 }
7702
7703 fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
7704 let array_depth = self.array_depth;
7705 self.stack.push(Access::Array(array_depth));
7706 self.array_depth += 1;
7707 let res = cb(self, array_depth);
7708 self.stack.pop();
7709 self.array_depth -= 1;
7710 res
7711 }
7712
7713 fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
7714 self.stack.push(new);
7715 let res = cb(self);
7716 self.stack.pop();
7717 res
7718 }
7719
7720 fn write<W: Write>(
7721 &self,
7722 writer: &mut W,
7723 names: &FastHashMap<NameKey, String>,
7724 ) -> Result<(), core::fmt::Error> {
7725 for next in self.stack.iter() {
7726 next.write(writer, names)?;
7727 }
7728 Ok(())
7729 }
7730 }
7731
7732 impl<W: Write> Writer<W> {
7733 pub(super) fn need_workgroup_variables_initialization(
7734 &mut self,
7735 options: &Options,
7736 ep: &EntryPoint,
7737 module: &crate::Module,
7738 fun_info: &valid::FunctionInfo,
7739 ) -> bool {
7740 options.zero_initialize_workgroup_memory
7741 && ep.stage.compute_like()
7742 && module.global_variables.iter().any(|(handle, var)| {
7743 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7744 })
7745 }
7746
7747 pub(super) fn write_workgroup_variables_initialization(
7748 &mut self,
7749 module: &crate::Module,
7750 module_info: &valid::ModuleInfo,
7751 fun_info: &valid::FunctionInfo,
7752 local_invocation_id: Option<&NameKey>,
7753 ) -> BackendResult {
7754 let level = back::Level(1);
7755
7756 writeln!(
7757 self.out,
7758 "{}if ({}::all({} == {}::uint3(0u))) {{",
7759 level,
7760 NAMESPACE,
7761 local_invocation_id
7762 .map(|name_key| self.names[name_key].as_str())
7763 .unwrap_or("__local_invocation_id"),
7764 NAMESPACE,
7765 )?;
7766
7767 let mut access_stack = AccessStack::new();
7768
7769 let vars = module.global_variables.iter().filter(|&(handle, var)| {
7770 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7771 });
7772
7773 for (handle, var) in vars {
7774 access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
7775 self.write_workgroup_variable_initialization(
7776 module,
7777 module_info,
7778 var.ty,
7779 access_stack,
7780 level.next(),
7781 )
7782 })?;
7783 }
7784
7785 writeln!(self.out, "{level}}}")?;
7786 self.write_barrier(crate::Barrier::WORK_GROUP, level)
7787 }
7788
7789 fn write_workgroup_variable_initialization(
7790 &mut self,
7791 module: &crate::Module,
7792 module_info: &valid::ModuleInfo,
7793 ty: Handle<crate::Type>,
7794 access_stack: &mut AccessStack,
7795 level: back::Level,
7796 ) -> BackendResult {
7797 if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
7798 write!(self.out, "{level}")?;
7799 access_stack.write(&mut self.out, &self.names)?;
7800 writeln!(self.out, " = {{}};")?;
7801 } else {
7802 match module.types[ty].inner {
7803 crate::TypeInner::Atomic { .. } => {
7804 write!(
7805 self.out,
7806 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
7807 )?;
7808 access_stack.write(&mut self.out, &self.names)?;
7809 writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
7810 }
7811 crate::TypeInner::Array { base, size, .. } => {
7812 let count = match size.resolve(module.to_ctx())? {
7813 proc::IndexableLength::Known(count) => count,
7814 proc::IndexableLength::Dynamic => unreachable!(),
7815 };
7816
7817 access_stack.enter_array(|access_stack, array_depth| {
7818 writeln!(
7819 self.out,
7820 "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
7821 )?;
7822 self.write_workgroup_variable_initialization(
7823 module,
7824 module_info,
7825 base,
7826 access_stack,
7827 level.next(),
7828 )?;
7829 writeln!(self.out, "{level}}}")?;
7830 BackendResult::Ok(())
7831 })?;
7832 }
7833 crate::TypeInner::Struct { ref members, .. } => {
7834 for (index, member) in members.iter().enumerate() {
7835 access_stack.enter(
7836 Access::StructMember(ty, index as u32),
7837 |access_stack| {
7838 self.write_workgroup_variable_initialization(
7839 module,
7840 module_info,
7841 member.ty,
7842 access_stack,
7843 level,
7844 )
7845 },
7846 )?;
7847 }
7848 }
7849 _ => unreachable!(),
7850 }
7851 }
7852
7853 Ok(())
7854 }
7855 }
7856}
7857
7858impl crate::AtomicFunction {
7859 const fn to_msl(self) -> &'static str {
7860 match self {
7861 Self::Add => "fetch_add",
7862 Self::Subtract => "fetch_sub",
7863 Self::And => "fetch_and",
7864 Self::InclusiveOr => "fetch_or",
7865 Self::ExclusiveOr => "fetch_xor",
7866 Self::Min => "fetch_min",
7867 Self::Max => "fetch_max",
7868 Self::Exchange { compare: None } => "exchange",
7869 Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
7870 }
7871 }
7872
7873 fn to_msl_64_bit(self) -> Result<&'static str, Error> {
7874 Ok(match self {
7875 Self::Min => "min",
7876 Self::Max => "max",
7877 _ => Err(Error::FeatureNotImplemented(
7878 "64-bit atomic operation other than min/max".to_string(),
7879 ))?,
7880 })
7881 }
7882}