1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::{
8 cmp::Ordering,
9 fmt::{Display, Error as FmtError, Formatter, Write},
10 iter,
11};
12use num_traits::real::Real as _;
13
14use half::f16;
15
16use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
17use crate::{
18 arena::{Handle, HandleSet},
19 back::{self, get_entry_points, Baked},
20 common,
21 proc::{
22 self,
23 index::{self, BoundsCheck},
24 ExternalTextureNameKey, NameKey, TypeResolution,
25 },
26 valid, FastHashMap, FastHashSet,
27};
28
29#[cfg(test)]
30use core::ptr;
31
32type BackendResult = Result<(), Error>;
34
/// Namespace prefix written before Metal standard-library types and functions
/// (e.g. `metal::float3`, `metal::sampler`).
const NAMESPACE: &str = "metal";
const WRAPPED_ARRAY_FIELD: &str = "inner";
const ATOMIC_REFERENCE: &str = "&";

/// Namespace prefix for Metal ray-tracing types
/// (e.g. `metal::raytracing::instance_acceleration_structure`).
const RT_NAMESPACE: &str = "metal::raytracing";
/// Name written for ray-query typed values.
const RAY_QUERY_TYPE: &str = "_RayQuery";
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_MODERN_SUPPORT: bool = false;
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

// Names of helper functions that generated MSL calls.
// NOTE(review): presumably their definitions are emitted elsewhere in this writer.
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal";
pub(crate) const IMAGE_SIZE_EXTERNAL_FUNCTION: &str = "nagaTextureDimensionsExternal";
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
    "nagaTextureSampleBaseClampToEdge";

/// Wrapper struct written for binding arrays: `constant Wrapper<T>*`.
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";
/// Type name written for external textures.
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
81
82fn put_numeric_type(
91 out: &mut impl Write,
92 scalar: crate::Scalar,
93 sizes: &[crate::VectorSize],
94) -> Result<(), FmtError> {
95 match (scalar, sizes) {
96 (scalar, &[]) => {
97 write!(out, "{}", scalar.to_msl_name())
98 }
99 (scalar, &[rows]) => {
100 write!(
101 out,
102 "{}::{}{}",
103 NAMESPACE,
104 scalar.to_msl_name(),
105 common::vector_size_str(rows)
106 )
107 }
108 (scalar, &[rows, columns]) => {
109 write!(
110 out,
111 "{}::{}{}x{}",
112 NAMESPACE,
113 scalar.to_msl_name(),
114 common::vector_size_str(columns),
115 common::vector_size_str(rows)
116 )
117 }
118 (_, _) => Ok(()), }
120}
121
122const fn scalar_is_int(scalar: crate::Scalar) -> bool {
123 use crate::ScalarKind::*;
124 match scalar.kind {
125 Sint | Uint | AbstractInt | Bool => true,
126 Float | AbstractFloat => false,
127 }
128}
129
/// Prefix of the names written by [`ClampedLod`] for clamped level-of-detail values.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix of the names written by [`Reinterpreted`].
const REINTERPRET_PREFIX: &str = "reinterpreted_";
135
136struct ClampedLod(Handle<crate::Expression>);
142
143impl Display for ClampedLod {
144 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
145 self.0.write_prefixed(f, CLAMPED_LOD_LOAD_PREFIX)
146 }
147}
148
149struct ArraySizeMember(Handle<crate::GlobalVariable>);
164
165impl Display for ArraySizeMember {
166 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
167 self.0.write_prefixed(f, "size")
168 }
169}
170
171#[derive(Clone, Copy)]
176struct Reinterpreted<'a> {
177 target_type: &'a str,
178 orig: Handle<crate::Expression>,
179}
180
181impl<'a> Reinterpreted<'a> {
182 const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
183 Self { target_type, orig }
184 }
185}
186
187impl Display for Reinterpreted<'_> {
188 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
189 f.write_str(REINTERPRET_PREFIX)?;
190 f.write_str(self.target_type)?;
191 self.orig.write_prefixed(f, "_e")
192 }
193}
194
195struct TypeContext<'a> {
196 handle: Handle<crate::Type>,
197 gctx: proc::GlobalCtx<'a>,
198 names: &'a FastHashMap<NameKey, String>,
199 access: crate::StorageAccess,
200 first_time: bool,
201}
202
203impl TypeContext<'_> {
204 fn scalar(&self) -> Option<crate::Scalar> {
205 let ty = &self.gctx.types[self.handle];
206 ty.inner.scalar()
207 }
208
209 fn vertex_input_dimension(&self) -> u32 {
210 let ty = &self.gctx.types[self.handle];
211 match ty.inner {
212 crate::TypeInner::Scalar(_) => 1,
213 crate::TypeInner::Vector { size, .. } => size as u32,
214 _ => unreachable!(),
215 }
216 }
217}
218
/// Writes the MSL spelling of the type `self.handle`.
///
/// Types for which `needs_alias()` is true are written by their name from
/// `self.names` unless `first_time` is set. Struct types always need an alias,
/// so with `first_time` set they reach the `unreachable!()` arm.
impl Display for TypeContext<'_> {
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let ty = &self.gctx.types[self.handle];
        // Refer to aliased types by name.
        if ty.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Type(self.handle)];
            return write!(out, "{name}");
        }

        match ty.inner {
            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
            crate::TypeInner::Atomic(scalar) => {
                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
            }
            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => put_numeric_type(out, scalar, &[rows, columns]),
            // `<space> <pointee>&`. Spaces without an MSL name (Handle) write nothing.
            crate::TypeInner::Pointer { base, space } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                let space_name = match space.to_msl_name() {
                    Some(name) => name,
                    None => return Ok(()),
                };
                write!(out, "{space_name} {sub}&")
            }
            // `<space> <scalar or vector>&`, again nothing for nameless spaces.
            crate::TypeInner::ValuePointer {
                size,
                scalar,
                space,
            } => {
                match space.to_msl_name() {
                    Some(name) => write!(out, "{name} ")?,
                    None => return Ok(()),
                };
                match size {
                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
                    None => put_numeric_type(out, scalar, &[])?,
                };

                write!(out, "&")
            }
            // Only the element type is written here.
            // NOTE(review): the length is presumably emitted by the caller.
            crate::TypeInner::Array { base, .. } => {
                let sub = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };
                write!(out, "{sub}")
            }
            crate::TypeInner::Struct { .. } => unreachable!(),
            // `metal::<texture|depth><dim>[_ms][_array]<scalar, metal::access::<access>>`
            crate::TypeInner::Image {
                dim,
                arrayed,
                class,
            } => {
                let dim_str = match dim {
                    crate::ImageDimension::D1 => "1d",
                    crate::ImageDimension::D2 => "2d",
                    crate::ImageDimension::D3 => "3d",
                    crate::ImageDimension::Cube => "cube",
                };
                let (texture_str, msaa_str, scalar, access) = match class {
                    // Multisampled textures are read; the others are sampled.
                    crate::ImageClass::Sampled { kind, multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar { kind, width: 4 };
                        ("texture", msaa_str, scalar, access)
                    }
                    crate::ImageClass::Depth { multi } => {
                        let (msaa_str, access) = if multi {
                            ("_ms", "read")
                        } else {
                            ("", "sample")
                        };
                        let scalar = crate::Scalar {
                            kind: crate::ScalarKind::Float,
                            width: 4,
                        };
                        ("depth", msaa_str, scalar, access)
                    }
                    // The access qualifier comes from `self.access`. An empty
                    // access set is logged and treated as an invalid module.
                    crate::ImageClass::Storage { format, .. } => {
                        let access = if self
                            .access
                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
                        {
                            "read_write"
                        } else if self.access.contains(crate::StorageAccess::STORE) {
                            "write"
                        } else if self.access.contains(crate::StorageAccess::LOAD) {
                            "read"
                        } else {
                            log::warn!(
                                "Storage access for {:?} (name '{}'): {:?}",
                                self.handle,
                                ty.name.as_deref().unwrap_or_default(),
                                self.access
                            );
                            unreachable!("module is not valid");
                        };
                        ("texture", "", format.into(), access)
                    }
                    // External textures are written as the wrapper struct name.
                    crate::ImageClass::External => {
                        return write!(out, "{EXTERNAL_TEXTURE_WRAPPER_STRUCT}");
                    }
                };
                let base_name = scalar.to_msl_name();
                let array_str = if arrayed { "_array" } else { "" };
                write!(
                    out,
                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
                )
            }
            crate::TypeInner::Sampler { comparison: _ } => {
                write!(out, "{NAMESPACE}::sampler")
            }
            crate::TypeInner::AccelerationStructure { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
            }
            crate::TypeInner::RayQuery { vertex_return } => {
                if vertex_return {
                    unimplemented!("metal does not support vertex ray hit return")
                }
                write!(out, "{RAY_QUERY_TYPE}")
            }
            // A binding array is a `constant` pointer to the argument-buffer wrapper of its element type.
            crate::TypeInner::BindingArray { base, .. } => {
                let base_tyname = Self {
                    handle: base,
                    first_time: false,
                    ..*self
                };

                write!(
                    out,
                    "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
                )
            }
        }
    }
}
372
373struct TypedGlobalVariable<'a> {
374 module: &'a crate::Module,
375 names: &'a FastHashMap<NameKey, String>,
376 handle: Handle<crate::GlobalVariable>,
377 usage: valid::GlobalUse,
378 reference: bool,
379}
380
381impl TypedGlobalVariable<'_> {
382 fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
383 let var = &self.module.global_variables[self.handle];
384 let name = &self.names[&NameKey::GlobalVariable(self.handle)];
385
386 let storage_access = match var.space {
387 crate::AddressSpace::Storage { access } => access,
388 _ => match self.module.types[var.ty].inner {
389 crate::TypeInner::Image {
390 class: crate::ImageClass::Storage { access, .. },
391 ..
392 } => access,
393 crate::TypeInner::BindingArray { base, .. } => {
394 match self.module.types[base].inner {
395 crate::TypeInner::Image {
396 class: crate::ImageClass::Storage { access, .. },
397 ..
398 } => access,
399 _ => crate::StorageAccess::default(),
400 }
401 }
402 _ => crate::StorageAccess::default(),
403 },
404 };
405 let ty_name = TypeContext {
406 handle: var.ty,
407 gctx: self.module.to_ctx(),
408 names: self.names,
409 access: storage_access,
410 first_time: false,
411 };
412
413 let (space, access, reference) = match var.space.to_msl_name() {
414 Some(space) if self.reference => {
415 let access = if var.space.needs_access_qualifier()
416 && !self.usage.intersects(valid::GlobalUse::WRITE)
417 {
418 "const"
419 } else {
420 ""
421 };
422 (space, access, "&")
423 }
424 _ => ("", "", ""),
425 };
426
427 Ok(write!(
428 out,
429 "{}{}{}{}{}{} {}",
430 space,
431 if space.is_empty() { "" } else { " " },
432 ty_name,
433 if access.is_empty() { "" } else { " " },
434 access,
435 reference,
436 name,
437 )?)
438 }
439}
440
/// Identifies a helper function the writer may emit, keyed by the operation
/// and the operand types it specializes on.
///
/// NOTE(review): presumably collected in `Writer::wrapped_functions` so each
/// helper is emitted only once — confirm against the emitting code.
#[derive(Eq, PartialEq, Hash)]
enum WrappedFunction {
    /// Helper for a unary operator; `ty` is (optional vector size, scalar).
    UnaryOp {
        op: crate::UnaryOperator,
        ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a binary operator with the given operand types.
    BinaryOp {
        op: crate::BinaryOperator,
        left_ty: (Option<crate::VectorSize>, crate::Scalar),
        right_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper for a math function applied to an argument of type `arg_ty`.
    Math {
        fun: crate::MathFunction,
        arg_ty: (Option<crate::VectorSize>, crate::Scalar),
    },
    /// Helper converting `src_scalar` (or a vector of it) to `dst_scalar`.
    Cast {
        src_scalar: crate::Scalar,
        vector_size: Option<crate::VectorSize>,
        dst_scalar: crate::Scalar,
    },
    /// Helper for loading from an image of the given class.
    ImageLoad {
        class: crate::ImageClass,
    },
    /// Helper for sampling an image of the given class.
    ImageSample {
        class: crate::ImageClass,
        clamp_to_edge: bool,
    },
    /// Helper for querying the size of an image of the given class.
    ImageQuerySize {
        class: crate::ImageClass,
    },
}
472
/// Writes Metal Shading Language source text into `out`.
pub struct Writer<W> {
    /// Destination for the generated MSL text.
    out: W,
    /// Names assigned to module items, keyed by [`NameKey`].
    names: FastHashMap<NameKey, String>,
    named_expressions: crate::NamedExpressions,
    need_bake_expressions: back::NeedBakeExpressions,
    /// Generator for fresh identifiers (e.g. `loop_bound`, `oob`).
    namer: proc::Namer,
    /// Helper functions recorded by the writer.
    /// NOTE(review): presumably used to emit each helper only once.
    wrapped_functions: FastHashSet<WrappedFunction>,
    /// NOTE(review): test-only bookkeeping, presumably for recursion/stack checks.
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    /// NOTE(review): test-only bookkeeping, presumably for recursion/stack checks.
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// NOTE(review): presumably the (struct type, member index) pairs that
    /// need padding emitted — confirm.
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
489
490impl crate::Scalar {
491 fn to_msl_name(self) -> &'static str {
492 use crate::ScalarKind as Sk;
493 match self {
494 Self {
495 kind: Sk::Float,
496 width: 4,
497 } => "float",
498 Self {
499 kind: Sk::Float,
500 width: 2,
501 } => "half",
502 Self {
503 kind: Sk::Sint,
504 width: 4,
505 } => "int",
506 Self {
507 kind: Sk::Uint,
508 width: 4,
509 } => "uint",
510 Self {
511 kind: Sk::Sint,
512 width: 8,
513 } => "long",
514 Self {
515 kind: Sk::Uint,
516 width: 8,
517 } => "ulong",
518 Self {
519 kind: Sk::Bool,
520 width: _,
521 } => "bool",
522 Self {
523 kind: Sk::AbstractInt | Sk::AbstractFloat,
524 width: _,
525 } => unreachable!("Found Abstract scalar kind"),
526 _ => unreachable!("Unsupported scalar kind: {:?}", self),
527 }
528 }
529}
530
/// The list separator to write after an item: `","` when another item
/// follows, otherwise the empty string.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
538
539fn should_pack_struct_member(
540 members: &[crate::StructMember],
541 span: u32,
542 index: usize,
543 module: &crate::Module,
544) -> Option<crate::Scalar> {
545 let member = &members[index];
546
547 let ty_inner = &module.types[member.ty].inner;
548 let last_offset = member.offset + ty_inner.size(module.to_ctx());
549 let next_offset = match members.get(index + 1) {
550 Some(next) => next.offset,
551 None => span,
552 };
553 let is_tight = next_offset == last_offset;
554
555 match *ty_inner {
556 crate::TypeInner::Vector {
557 size: crate::VectorSize::Tri,
558 scalar: scalar @ crate::Scalar { width: 4 | 2, .. },
559 } if is_tight => Some(scalar),
560 _ => None,
561 }
562}
563
564fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
565 match arena[ty].inner {
566 crate::TypeInner::Struct { ref members, .. } => {
567 if let Some(member) = members.last() {
568 if let crate::TypeInner::Array {
569 size: crate::ArraySize::Dynamic,
570 ..
571 } = arena[member.ty].inner
572 {
573 return true;
574 }
575 }
576 false
577 }
578 crate::TypeInner::Array {
579 size: crate::ArraySize::Dynamic,
580 ..
581 } => true,
582 _ => false,
583 }
584}
585
586impl crate::AddressSpace {
587 const fn needs_pass_through(&self) -> bool {
591 match *self {
592 Self::Uniform
593 | Self::Storage { .. }
594 | Self::Private
595 | Self::WorkGroup
596 | Self::PushConstant
597 | Self::Handle
598 | Self::TaskPayload => true,
599 Self::Function => false,
600 }
601 }
602
603 const fn needs_access_qualifier(&self) -> bool {
605 match *self {
606 Self::Storage { .. } => true,
611 Self::TaskPayload => unimplemented!(),
612 Self::Private | Self::WorkGroup => false,
614 Self::Uniform | Self::PushConstant => false,
616 Self::Handle | Self::Function => false,
618 }
619 }
620
621 const fn to_msl_name(self) -> Option<&'static str> {
622 match self {
623 Self::Handle => None,
624 Self::Uniform | Self::PushConstant => Some("constant"),
625 Self::Storage { .. } => Some("device"),
626 Self::Private | Self::Function => Some("thread"),
627 Self::WorkGroup => Some("threadgroup"),
628 Self::TaskPayload => Some("object_data"),
629 }
630 }
631}
632
633impl crate::Type {
634 const fn needs_alias(&self) -> bool {
636 use crate::TypeInner as Ti;
637
638 match self.inner {
639 Ti::Scalar(_)
641 | Ti::Vector { .. }
642 | Ti::Matrix { .. }
643 | Ti::Atomic(_)
644 | Ti::Pointer { .. }
645 | Ti::ValuePointer { .. } => self.name.is_some(),
646 Ti::Struct { .. } | Ti::Array { .. } => true,
648 Ti::Image { .. }
650 | Ti::Sampler { .. }
651 | Ti::AccelerationStructure { .. }
652 | Ti::RayQuery { .. }
653 | Ti::BindingArray { .. } => false,
654 }
655 }
656}
657
/// The function whose body is being written: either a regular function or an
/// entry point (identified by its index).
#[derive(Clone, Copy)]
enum FunctionOrigin {
    Handle(Handle<crate::Function>),
    EntryPoint(proc::EntryPointIndex),
}
663
664trait NameKeyExt {
665 fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
666 match origin {
667 FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
668 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
669 }
670 }
671
672 fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
677 match origin {
678 FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
679 FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
680 }
681 }
682}
683
684impl NameKeyExt for NameKey {}
685
/// How to write the level-of-detail argument of an image access.
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Write the given expression as-is.
    Direct(Handle<crate::Expression>),
    /// Write the `ClampedLod` name tied to the given image load expression,
    /// as used by the `Restrict` bounds-check policy.
    Restricted(Handle<crate::Expression>),
}
700
/// The operands that locate a texel in an image access.
struct TexelAddress {
    /// Texel coordinate; written with a cast to unsigned.
    coordinate: Handle<crate::Expression>,
    /// Array layer, checked against `get_array_size()`.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index, checked against `get_num_samples()`.
    sample: Option<Handle<crate::Expression>>,
    /// Mip level, if the access has one.
    level: Option<LevelOfDetail>,
}
716
/// Everything needed to write the expressions of one function or entry point.
struct ExpressionContext<'a> {
    /// The function whose expressions are being written.
    function: &'a crate::Function,
    /// Which function or entry point this is; used to build local-variable name keys.
    origin: FunctionOrigin,
    /// Validation info for `function`, including each expression's resolved type.
    info: &'a valid::FunctionInfo,
    /// The module that contains `function`.
    module: &'a crate::Module,
    mod_info: &'a valid::ModuleInfo,
    pipeline_options: &'a PipelineOptions,
    /// NOTE(review): presumably the target MSL version as (major, minor) — confirm.
    lang_version: (u8, u8),
    /// Bounds-check policies for indexing and image loads.
    policies: index::BoundsCheckPolicies,

    /// NOTE(review): not used in the visible code; presumably expressions whose
    /// bounds are already guarded — confirm.
    guarded_indices: HandleSet<crate::Expression>,
    /// When set, loops get an extra counter that forces them to terminate
    /// (see `gen_force_bounded_loop_statements`).
    force_loop_bounding: bool,
}
734
735impl<'a> ExpressionContext<'a> {
736 fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
737 self.info[handle].ty.inner_with(&self.module.types)
738 }
739
740 fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
747 let image_ty = self.resolve_type(image);
748 if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
749 class.is_mipmapped() && dim != crate::ImageDimension::D1
750 } else {
751 false
752 }
753 }
754
755 fn choose_bounds_check_policy(
756 &self,
757 pointer: Handle<crate::Expression>,
758 ) -> index::BoundsCheckPolicy {
759 self.policies
760 .choose_policy(pointer, &self.module.types, self.info)
761 }
762
763 fn access_needs_check(
765 &self,
766 base: Handle<crate::Expression>,
767 index: index::GuardedIndex,
768 ) -> Option<index::IndexableLength> {
769 index::access_needs_check(
770 base,
771 index,
772 self.module,
773 &self.function.expressions,
774 self.info,
775 )
776 }
777
778 fn bounds_check_iter(
780 &self,
781 chain: Handle<crate::Expression>,
782 ) -> impl Iterator<Item = BoundsCheck> + '_ {
783 index::bounds_check_iter(chain, self.module, self.function, self.info)
784 }
785
786 fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
788 index::oob_local_types(self.module, self.function, self.info, self.policies)
789 }
790
791 fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
792 match self.function.expressions[expr_handle] {
793 crate::Expression::AccessIndex { base, index } => {
794 let ty = match *self.resolve_type(base) {
795 crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
796 ref ty => ty,
797 };
798 match *ty {
799 crate::TypeInner::Struct {
800 ref members, span, ..
801 } => should_pack_struct_member(members, span, index as usize, self.module),
802 _ => None,
803 }
804 }
805 _ => None,
806 }
807 }
808}
809
/// Context for writing statements: the expression context plus statement-level extras.
struct StatementContext<'a> {
    /// Context for the expressions that statements reference.
    expression: ExpressionContext<'a>,
    /// NOTE(review): presumably the name of the struct that entry-point results
    /// are returned in — confirm.
    result_struct: Option<&'a str>,
}
814
815impl<W: Write> Writer<W> {
816 pub fn new(out: W) -> Self {
818 Writer {
819 out,
820 names: FastHashMap::default(),
821 named_expressions: Default::default(),
822 need_bake_expressions: Default::default(),
823 namer: proc::Namer::default(),
824 wrapped_functions: FastHashSet::default(),
825 #[cfg(test)]
826 put_expression_stack_pointers: Default::default(),
827 #[cfg(test)]
828 put_block_stack_pointers: Default::default(),
829 struct_member_pads: FastHashSet::default(),
830 }
831 }
832
833 #[allow(clippy::missing_const_for_fn)]
836 pub fn finish(self) -> W {
837 self.out
838 }
839
840 fn gen_force_bounded_loop_statements(
947 &mut self,
948 level: back::Level,
949 context: &StatementContext,
950 ) -> Option<(String, String)> {
951 if !context.expression.force_loop_bounding {
952 return None;
953 }
954
955 let loop_bound_name = self.namer.call("loop_bound");
956 let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
959 let level = level.next();
960 let break_and_inc = format!(
961 "{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
962{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
963 );
964
965 Some((decl, break_and_inc))
966 }
967
968 fn put_call_parameters(
969 &mut self,
970 parameters: impl Iterator<Item = Handle<crate::Expression>>,
971 context: &ExpressionContext,
972 ) -> BackendResult {
973 self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
974 writer.put_expression(expr, context, true)
975 })
976 }
977
978 fn put_call_parameters_impl<C, E>(
979 &mut self,
980 parameters: impl Iterator<Item = Handle<crate::Expression>>,
981 ctx: &C,
982 put_expression: E,
983 ) -> BackendResult
984 where
985 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
986 {
987 write!(self.out, "(")?;
988 for (i, handle) in parameters.enumerate() {
989 if i != 0 {
990 write!(self.out, ", ")?;
991 }
992 put_expression(self, ctx, handle)?;
993 }
994 write!(self.out, ")")?;
995 Ok(())
996 }
997
998 fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
1004 let oob_local_types = context.oob_local_types();
1005 for &ty in oob_local_types.iter() {
1006 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1007 self.names.insert(name_key, self.namer.call("oob"));
1008 }
1009
1010 for (name_key, ty, init) in context
1011 .function
1012 .local_variables
1013 .iter()
1014 .map(|(local_handle, local)| {
1015 let name_key = NameKey::local(context.origin, local_handle);
1016 (name_key, local.ty, local.init)
1017 })
1018 .chain(oob_local_types.iter().map(|&ty| {
1019 let name_key = NameKey::oob_local_for_type(context.origin, ty);
1020 (name_key, ty, None)
1021 }))
1022 {
1023 let ty_name = TypeContext {
1024 handle: ty,
1025 gctx: context.module.to_ctx(),
1026 names: &self.names,
1027 access: crate::StorageAccess::empty(),
1028 first_time: false,
1029 };
1030 write!(
1031 self.out,
1032 "{}{} {}",
1033 back::INDENT,
1034 ty_name,
1035 self.names[&name_key]
1036 )?;
1037 match init {
1038 Some(value) => {
1039 write!(self.out, " = ")?;
1040 self.put_expression(value, context, true)?;
1041 }
1042 None => {
1043 write!(self.out, " = {{}}")?;
1044 }
1045 };
1046 writeln!(self.out, ";")?;
1047 }
1048 Ok(())
1049 }
1050
1051 fn put_level_of_detail(
1052 &mut self,
1053 level: LevelOfDetail,
1054 context: &ExpressionContext,
1055 ) -> BackendResult {
1056 match level {
1057 LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
1058 LevelOfDetail::Restricted(load) => write!(self.out, "{}", ClampedLod(load))?,
1059 }
1060 Ok(())
1061 }
1062
1063 fn put_image_query(
1064 &mut self,
1065 image: Handle<crate::Expression>,
1066 query: &str,
1067 level: Option<LevelOfDetail>,
1068 context: &ExpressionContext,
1069 ) -> BackendResult {
1070 self.put_expression(image, context, false)?;
1071 write!(self.out, ".get_{query}(")?;
1072 if let Some(level) = level {
1073 self.put_level_of_detail(level, context)?;
1074 }
1075 write!(self.out, ")")?;
1076 Ok(())
1077 }
1078
1079 fn put_image_size_query(
1080 &mut self,
1081 image: Handle<crate::Expression>,
1082 level: Option<LevelOfDetail>,
1083 kind: crate::ScalarKind,
1084 context: &ExpressionContext,
1085 ) -> BackendResult {
1086 if let crate::TypeInner::Image {
1087 class: crate::ImageClass::External,
1088 ..
1089 } = *context.resolve_type(image)
1090 {
1091 write!(self.out, "{IMAGE_SIZE_EXTERNAL_FUNCTION}(")?;
1092 self.put_expression(image, context, true)?;
1093 write!(self.out, ")")?;
1094 return Ok(());
1095 }
1096
1097 let dim = match *context.resolve_type(image) {
1100 crate::TypeInner::Image { dim, .. } => dim,
1101 ref other => unreachable!("Unexpected type {:?}", other),
1102 };
1103 let scalar = crate::Scalar { kind, width: 4 };
1104 let coordinate_type = scalar.to_msl_name();
1105 match dim {
1106 crate::ImageDimension::D1 => {
1107 if kind == crate::ScalarKind::Uint {
1111 self.put_image_query(image, "width", None, context)?;
1113 } else {
1114 write!(self.out, "int(")?;
1116 self.put_image_query(image, "width", None, context)?;
1117 write!(self.out, ")")?;
1118 }
1119 }
1120 crate::ImageDimension::D2 => {
1121 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1122 self.put_image_query(image, "width", level, context)?;
1123 write!(self.out, ", ")?;
1124 self.put_image_query(image, "height", level, context)?;
1125 write!(self.out, ")")?;
1126 }
1127 crate::ImageDimension::D3 => {
1128 write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
1129 self.put_image_query(image, "width", level, context)?;
1130 write!(self.out, ", ")?;
1131 self.put_image_query(image, "height", level, context)?;
1132 write!(self.out, ", ")?;
1133 self.put_image_query(image, "depth", level, context)?;
1134 write!(self.out, ")")?;
1135 }
1136 crate::ImageDimension::Cube => {
1137 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
1138 self.put_image_query(image, "width", level, context)?;
1139 write!(self.out, ")")?;
1140 }
1141 }
1142 Ok(())
1143 }
1144
1145 fn put_cast_to_uint_scalar_or_vector(
1146 &mut self,
1147 expr: Handle<crate::Expression>,
1148 context: &ExpressionContext,
1149 ) -> BackendResult {
1150 match *context.resolve_type(expr) {
1152 crate::TypeInner::Scalar(_) => {
1153 put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
1154 }
1155 crate::TypeInner::Vector { size, .. } => {
1156 put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
1157 }
1158 _ => {
1159 return Err(Error::GenericValidation(
1160 "Invalid type for image coordinate".into(),
1161 ))
1162 }
1163 };
1164
1165 write!(self.out, "(")?;
1166 self.put_expression(expr, context, true)?;
1167 write!(self.out, ")")?;
1168 Ok(())
1169 }
1170
1171 fn put_image_sample_level(
1172 &mut self,
1173 image: Handle<crate::Expression>,
1174 level: crate::SampleLevel,
1175 context: &ExpressionContext,
1176 ) -> BackendResult {
1177 let has_levels = context.image_needs_lod(image);
1178 match level {
1179 crate::SampleLevel::Auto => {}
1180 crate::SampleLevel::Zero => {
1181 }
1183 _ if !has_levels => {
1184 log::warn!("1D image can't be sampled with level {level:?}");
1185 }
1186 crate::SampleLevel::Exact(h) => {
1187 write!(self.out, ", {NAMESPACE}::level(")?;
1188 self.put_expression(h, context, true)?;
1189 write!(self.out, ")")?;
1190 }
1191 crate::SampleLevel::Bias(h) => {
1192 write!(self.out, ", {NAMESPACE}::bias(")?;
1193 self.put_expression(h, context, true)?;
1194 write!(self.out, ")")?;
1195 }
1196 crate::SampleLevel::Gradient { x, y } => {
1197 write!(self.out, ", {NAMESPACE}::gradient2d(")?;
1198 self.put_expression(x, context, true)?;
1199 write!(self.out, ", ")?;
1200 self.put_expression(y, context, true)?;
1201 write!(self.out, ")")?;
1202 }
1203 }
1204 Ok(())
1205 }
1206
1207 fn put_image_coordinate_limits(
1208 &mut self,
1209 image: Handle<crate::Expression>,
1210 level: Option<LevelOfDetail>,
1211 context: &ExpressionContext,
1212 ) -> BackendResult {
1213 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1214 write!(self.out, " - 1")?;
1215 Ok(())
1216 }
1217
1218 fn put_restricted_scalar_image_index(
1236 &mut self,
1237 image: Handle<crate::Expression>,
1238 index: Handle<crate::Expression>,
1239 limit_method: &str,
1240 context: &ExpressionContext,
1241 ) -> BackendResult {
1242 write!(self.out, "{NAMESPACE}::min(uint(")?;
1243 self.put_expression(index, context, true)?;
1244 write!(self.out, "), ")?;
1245 self.put_expression(image, context, false)?;
1246 write!(self.out, ".{limit_method}() - 1)")?;
1247 Ok(())
1248 }
1249
1250 fn put_restricted_texel_address(
1251 &mut self,
1252 image: Handle<crate::Expression>,
1253 address: &TexelAddress,
1254 context: &ExpressionContext,
1255 ) -> BackendResult {
1256 write!(self.out, "{NAMESPACE}::min(")?;
1258 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1259 write!(self.out, ", ")?;
1260 self.put_image_coordinate_limits(image, address.level, context)?;
1261 write!(self.out, ")")?;
1262
1263 if let Some(array_index) = address.array_index {
1265 write!(self.out, ", ")?;
1266 self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
1267 }
1268
1269 if let Some(sample) = address.sample {
1271 write!(self.out, ", ")?;
1272 self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
1273 }
1274
1275 if let Some(level) = address.level {
1278 write!(self.out, ", ")?;
1279 self.put_level_of_detail(level, context)?;
1280 }
1281
1282 Ok(())
1283 }
1284
1285 fn put_image_access_bounds_check(
1287 &mut self,
1288 image: Handle<crate::Expression>,
1289 address: &TexelAddress,
1290 context: &ExpressionContext,
1291 ) -> BackendResult {
1292 let mut conjunction = "";
1293
1294 let level = if let Some(level) = address.level {
1297 write!(self.out, "uint(")?;
1298 self.put_level_of_detail(level, context)?;
1299 write!(self.out, ") < ")?;
1300 self.put_expression(image, context, true)?;
1301 write!(self.out, ".get_num_mip_levels()")?;
1302 conjunction = " && ";
1303 Some(level)
1304 } else {
1305 None
1306 };
1307
1308 if let Some(sample) = address.sample {
1310 write!(self.out, "uint(")?;
1311 self.put_expression(sample, context, true)?;
1312 write!(self.out, ") < ")?;
1313 self.put_expression(image, context, true)?;
1314 write!(self.out, ".get_num_samples()")?;
1315 conjunction = " && ";
1316 }
1317
1318 if let Some(array_index) = address.array_index {
1320 write!(self.out, "{conjunction}uint(")?;
1321 self.put_expression(array_index, context, true)?;
1322 write!(self.out, ") < ")?;
1323 self.put_expression(image, context, true)?;
1324 write!(self.out, ".get_array_size()")?;
1325 conjunction = " && ";
1326 }
1327
1328 let coord_is_vector = match *context.resolve_type(address.coordinate) {
1330 crate::TypeInner::Vector { .. } => true,
1331 _ => false,
1332 };
1333 write!(self.out, "{conjunction}")?;
1334 if coord_is_vector {
1335 write!(self.out, "{NAMESPACE}::all(")?;
1336 }
1337 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1338 write!(self.out, " < ")?;
1339 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
1340 if coord_is_vector {
1341 write!(self.out, ")")?;
1342 }
1343
1344 Ok(())
1345 }
1346
1347 fn put_image_load(
1348 &mut self,
1349 load: Handle<crate::Expression>,
1350 image: Handle<crate::Expression>,
1351 mut address: TexelAddress,
1352 context: &ExpressionContext,
1353 ) -> BackendResult {
1354 if let crate::TypeInner::Image {
1355 class: crate::ImageClass::External,
1356 ..
1357 } = *context.resolve_type(image)
1358 {
1359 write!(self.out, "{IMAGE_LOAD_EXTERNAL_FUNCTION}(")?;
1360 self.put_expression(image, context, true)?;
1361 write!(self.out, ", ")?;
1362 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1363 write!(self.out, ")")?;
1364 return Ok(());
1365 }
1366
1367 match context.policies.image_load {
1368 proc::BoundsCheckPolicy::Restrict => {
1369 if address.level.is_some() {
1372 address.level = if context.image_needs_lod(image) {
1373 Some(LevelOfDetail::Restricted(load))
1374 } else {
1375 None
1376 }
1377 }
1378
1379 self.put_expression(image, context, false)?;
1380 write!(self.out, ".read(")?;
1381 self.put_restricted_texel_address(image, &address, context)?;
1382 write!(self.out, ")")?;
1383 }
1384 proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1385 write!(self.out, "(")?;
1386 self.put_image_access_bounds_check(image, &address, context)?;
1387 write!(self.out, " ? ")?;
1388 self.put_unchecked_image_load(image, &address, context)?;
1389 write!(self.out, ": DefaultConstructible())")?;
1390 }
1391 proc::BoundsCheckPolicy::Unchecked => {
1392 self.put_unchecked_image_load(image, &address, context)?;
1393 }
1394 }
1395
1396 Ok(())
1397 }
1398
1399 fn put_unchecked_image_load(
1400 &mut self,
1401 image: Handle<crate::Expression>,
1402 address: &TexelAddress,
1403 context: &ExpressionContext,
1404 ) -> BackendResult {
1405 self.put_expression(image, context, false)?;
1406 write!(self.out, ".read(")?;
1407 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
1409 if let Some(expr) = address.array_index {
1410 write!(self.out, ", ")?;
1411 self.put_expression(expr, context, true)?;
1412 }
1413 if let Some(sample) = address.sample {
1414 write!(self.out, ", ")?;
1415 self.put_expression(sample, context, true)?;
1416 }
1417 if let Some(level) = address.level {
1418 if context.image_needs_lod(image) {
1419 write!(self.out, ", ")?;
1420 self.put_level_of_detail(level, context)?;
1421 }
1422 }
1423 write!(self.out, ")")?;
1424
1425 Ok(())
1426 }
1427
1428 fn put_image_atomic(
1429 &mut self,
1430 level: back::Level,
1431 image: Handle<crate::Expression>,
1432 address: &TexelAddress,
1433 fun: crate::AtomicFunction,
1434 value: Handle<crate::Expression>,
1435 context: &StatementContext,
1436 ) -> BackendResult {
1437 write!(self.out, "{level}")?;
1438 self.put_expression(image, &context.expression, false)?;
1439 let op = if context.expression.resolve_type(value).scalar_width() == Some(8) {
1440 fun.to_msl_64_bit()?
1441 } else {
1442 fun.to_msl()
1443 };
1444 write!(self.out, ".atomic_{op}(")?;
1445 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1447 write!(self.out, ", ")?;
1448 self.put_expression(value, &context.expression, true)?;
1449 writeln!(self.out, ");")?;
1450
1451 Ok(())
1452 }
1453
1454 fn put_image_store(
1455 &mut self,
1456 level: back::Level,
1457 image: Handle<crate::Expression>,
1458 address: &TexelAddress,
1459 value: Handle<crate::Expression>,
1460 context: &StatementContext,
1461 ) -> BackendResult {
1462 write!(self.out, "{level}")?;
1463 self.put_expression(image, &context.expression, false)?;
1464 write!(self.out, ".write(")?;
1465 self.put_expression(value, &context.expression, true)?;
1466 write!(self.out, ", ")?;
1467 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1469 if let Some(expr) = address.array_index {
1470 write!(self.out, ", ")?;
1471 self.put_expression(expr, &context.expression, true)?;
1472 }
1473 writeln!(self.out, ");")?;
1474
1475 Ok(())
1476 }
1477
1478 fn put_dynamic_array_max_index(
1489 &mut self,
1490 handle: Handle<crate::GlobalVariable>,
1491 context: &ExpressionContext,
1492 ) -> BackendResult {
1493 let global = &context.module.global_variables[handle];
1494 let (offset, array_ty) = match context.module.types[global.ty].inner {
1495 crate::TypeInner::Struct { ref members, .. } => match members.last() {
1496 Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1497 None => return Err(Error::GenericValidation("Struct has no members".into())),
1498 },
1499 crate::TypeInner::Array {
1500 size: crate::ArraySize::Dynamic,
1501 ..
1502 } => (0, global.ty),
1503 ref ty => {
1504 return Err(Error::GenericValidation(format!(
1505 "Expected type with dynamic array, got {ty:?}"
1506 )))
1507 }
1508 };
1509
1510 let (size, stride) = match context.module.types[array_ty].inner {
1511 crate::TypeInner::Array { base, stride, .. } => (
1512 context.module.types[base]
1513 .inner
1514 .size(context.module.to_ctx()),
1515 stride,
1516 ),
1517 ref ty => {
1518 return Err(Error::GenericValidation(format!(
1519 "Expected array type, got {ty:?}"
1520 )))
1521 }
1522 };
1523
1524 write!(
1537 self.out,
1538 "(_buffer_sizes.{member} - {offset} - {size}) / {stride}",
1539 member = ArraySizeMember(handle),
1540 offset = offset,
1541 size = size,
1542 stride = stride,
1543 )?;
1544 Ok(())
1545 }
1546
1547 fn put_dot_product<T: Copy>(
1552 &mut self,
1553 arg: T,
1554 arg1: T,
1555 size: usize,
1556 extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
1557 ) -> BackendResult {
1558 write!(self.out, "(")?;
1561
1562 for index in 0..size {
1564 write!(self.out, " + ")?;
1567 extractor(self, arg, index)?;
1568 write!(self.out, " * ")?;
1569 extractor(self, arg1, index)?;
1570 }
1571
1572 write!(self.out, ")")?;
1573 Ok(())
1574 }
1575
1576 fn put_pack4x8(
1578 &mut self,
1579 arg: Handle<crate::Expression>,
1580 context: &ExpressionContext<'_>,
1581 was_signed: bool,
1582 clamp_bounds: Option<(&str, &str)>,
1583 ) -> Result<(), Error> {
1584 let write_arg = |this: &mut Self| -> BackendResult {
1585 if let Some((min, max)) = clamp_bounds {
1586 write!(this.out, "{NAMESPACE}::clamp(")?;
1588 this.put_expression(arg, context, true)?;
1589 write!(this.out, ", {min}, {max})")?;
1590 } else {
1591 this.put_expression(arg, context, true)?;
1592 }
1593 Ok(())
1594 };
1595
1596 if context.lang_version >= (2, 1) {
1597 let packed_type = if was_signed {
1598 "packed_char4"
1599 } else {
1600 "packed_uchar4"
1601 };
1602 write!(self.out, "as_type<uint>({packed_type}(")?;
1604 write_arg(self)?;
1605 write!(self.out, "))")?;
1606 } else {
1607 if was_signed {
1609 write!(self.out, "uint(")?;
1610 }
1611 write!(self.out, "(")?;
1612 write_arg(self)?;
1613 write!(self.out, "[0] & 0xFF) | ((")?;
1614 write_arg(self)?;
1615 write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1616 write_arg(self)?;
1617 write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1618 write_arg(self)?;
1619 write!(self.out, "[3] & 0xFF) << 24)")?;
1620 if was_signed {
1621 write!(self.out, ")")?;
1622 }
1623 }
1624
1625 Ok(())
1626 }
1627
1628 fn put_isign(
1631 &mut self,
1632 arg: Handle<crate::Expression>,
1633 context: &ExpressionContext,
1634 ) -> BackendResult {
1635 write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1636 let scalar = context
1637 .resolve_type(arg)
1638 .scalar()
1639 .expect("put_isign should only be called for args which have an integer scalar type")
1640 .to_msl_name();
1641 match context.resolve_type(arg) {
1642 &crate::TypeInner::Vector { size, .. } => {
1643 let size = common::vector_size_str(size);
1644 write!(self.out, "{scalar}{size}(-1), {scalar}{size}(1)")?;
1645 }
1646 _ => {
1647 write!(self.out, "{scalar}(-1), {scalar}(1)")?;
1648 }
1649 }
1650 write!(self.out, ", (")?;
1651 self.put_expression(arg, context, true)?;
1652 write!(self.out, " > 0)), {scalar}(0), (")?;
1653 self.put_expression(arg, context, true)?;
1654 write!(self.out, " == 0))")?;
1655 Ok(())
1656 }
1657
1658 fn put_const_expression(
1659 &mut self,
1660 expr_handle: Handle<crate::Expression>,
1661 module: &crate::Module,
1662 mod_info: &valid::ModuleInfo,
1663 arena: &crate::Arena<crate::Expression>,
1664 ) -> BackendResult {
1665 self.put_possibly_const_expression(
1666 expr_handle,
1667 arena,
1668 module,
1669 mod_info,
1670 &(module, mod_info),
1671 |&(_, mod_info), expr| &mod_info[expr],
1672 |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info, arena),
1673 )
1674 }
1675
1676 fn put_literal(&mut self, literal: crate::Literal) -> BackendResult {
1677 match literal {
1678 crate::Literal::F64(_) => {
1679 return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1680 }
1681 crate::Literal::F16(value) => {
1682 if value.is_infinite() {
1683 let sign = if value.is_sign_negative() { "-" } else { "" };
1684 write!(self.out, "{sign}INFINITY")?;
1685 } else if value.is_nan() {
1686 write!(self.out, "NAN")?;
1687 } else {
1688 let suffix = if value.fract() == f16::from_f32(0.0) {
1689 ".0h"
1690 } else {
1691 "h"
1692 };
1693 write!(self.out, "{value}{suffix}")?;
1694 }
1695 }
1696 crate::Literal::F32(value) => {
1697 if value.is_infinite() {
1698 let sign = if value.is_sign_negative() { "-" } else { "" };
1699 write!(self.out, "{sign}INFINITY")?;
1700 } else if value.is_nan() {
1701 write!(self.out, "NAN")?;
1702 } else {
1703 let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1704 write!(self.out, "{value}{suffix}")?;
1705 }
1706 }
1707 crate::Literal::U32(value) => {
1708 write!(self.out, "{value}u")?;
1709 }
1710 crate::Literal::I32(value) => {
1711 if value == i32::MIN {
1716 write!(self.out, "({} - 1)", value + 1)?;
1717 } else {
1718 write!(self.out, "{value}")?;
1719 }
1720 }
1721 crate::Literal::U64(value) => {
1722 write!(self.out, "{value}uL")?;
1723 }
1724 crate::Literal::I64(value) => {
1725 if value == i64::MIN {
1732 write!(self.out, "({}L - 1L)", value + 1)?;
1733 } else {
1734 write!(self.out, "{value}L")?;
1735 }
1736 }
1737 crate::Literal::Bool(value) => {
1738 write!(self.out, "{value}")?;
1739 }
1740 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1741 return Err(Error::GenericValidation(
1742 "Unsupported abstract literal".into(),
1743 ));
1744 }
1745 }
1746 Ok(())
1747 }
1748
    /// Writes an expression kind that can appear both in a function body and in the
    /// module's constant-expression arena.
    ///
    /// Only `Literal`, `Constant`, `ZeroValue`, `Compose` and `Splat` are handled.
    /// Any other expression kind yields `Error::Override`.
    ///
    /// - `expr_handle` is looked up in `expressions`.
    /// - `ctx` is opaque here; it is only forwarded to the two callbacks.
    /// - `get_expr_ty` resolves the type of a subexpression. It is used to find the
    ///   scalar type of a `Splat` operand.
    /// - `put_expression` writes a subexpression (`Compose` components, `Splat`
    ///   operand). This lets callers choose how nested expressions are written.
    #[allow(clippy::too_many_arguments)]
    fn put_possibly_const_expression<C, I, E>(
        &mut self,
        expr_handle: Handle<crate::Expression>,
        expressions: &crate::Arena<crate::Expression>,
        module: &crate::Module,
        mod_info: &valid::ModuleInfo,
        ctx: &C,
        get_expr_ty: I,
        put_expression: E,
    ) -> BackendResult
    where
        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
    {
        match expressions[expr_handle] {
            crate::Expression::Literal(literal) => {
                self.put_literal(literal)?;
            }
            crate::Expression::Constant(handle) => {
                let constant = &module.constants[handle];
                // Named constants are referenced by name; anonymous ones are
                // inlined by writing their initializer from the global arena.
                if constant.name.is_some() {
                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
                } else {
                    self.put_const_expression(
                        constant.init,
                        module,
                        mod_info,
                        &module.global_expressions,
                    )?;
                }
            }
            crate::Expression::ZeroValue(ty) => {
                let ty_name = TypeContext {
                    handle: ty,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                // Emitted as `Type {}`.
                write!(self.out, "{ty_name} {{}}")?;
            }
            crate::Expression::Compose { ty, ref components } => {
                let ty_name = TypeContext {
                    handle: ty,
                    gctx: module.to_ctx(),
                    names: &self.names,
                    access: crate::StorageAccess::empty(),
                    first_time: false,
                };
                write!(self.out, "{ty_name}")?;
                match module.types[ty].inner {
                    // Numeric types use constructor-call syntax: `Type(a, b, ...)`.
                    crate::TypeInner::Scalar(_)
                    | crate::TypeInner::Vector { .. }
                    | crate::TypeInner::Matrix { .. } => {
                        self.put_call_parameters_impl(
                            components.iter().copied(),
                            ctx,
                            put_expression,
                        )?;
                    }
                    // Aggregates use brace initialization: `Type {a, b, ...}`.
                    crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
                        write!(self.out, " {{")?;
                        for (index, &component) in components.iter().enumerate() {
                            if index != 0 {
                                write!(self.out, ", ")?;
                            }
                            // When `(ty, index)` is recorded in `struct_member_pads`, an
                            // extra `{}` initializer is emitted before the component.
                            // Presumably this fills a padding field the writer inserted
                            // in front of that member (verify against where
                            // `struct_member_pads` is populated).
                            if self.struct_member_pads.contains(&(ty, index as u32)) {
                                write!(self.out, "{{}}, ")?;
                            }
                            put_expression(self, ctx, component)?;
                        }
                        write!(self.out, "}}")?;
                    }
                    _ => return Err(Error::UnsupportedCompose(ty)),
                }
            }
            crate::Expression::Splat { size, value } => {
                // A splat is written as a vector constructor taking the single scalar.
                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
                    crate::TypeInner::Scalar(scalar) => scalar,
                    ref ty => {
                        return Err(Error::GenericValidation(format!(
                            "Expected splat value type must be a scalar, got {ty:?}",
                        )))
                    }
                };
                put_numeric_type(&mut self.out, scalar, &[size])?;
                write!(self.out, "(")?;
                put_expression(self, ctx, value)?;
                write!(self.out, ")")?;
            }
            _ => {
                return Err(Error::Override);
            }
        }

        Ok(())
    }
1848
1849 fn put_expression(
1861 &mut self,
1862 expr_handle: Handle<crate::Expression>,
1863 context: &ExpressionContext,
1864 is_scoped: bool,
1865 ) -> BackendResult {
1866 #[cfg(test)]
1868 self.put_expression_stack_pointers
1869 .insert(ptr::from_ref(&expr_handle).cast());
1870
1871 if let Some(name) = self.named_expressions.get(&expr_handle) {
1872 write!(self.out, "{name}")?;
1873 return Ok(());
1874 }
1875
1876 let expression = &context.function.expressions[expr_handle];
1877 match *expression {
1878 crate::Expression::Literal(_)
1879 | crate::Expression::Constant(_)
1880 | crate::Expression::ZeroValue(_)
1881 | crate::Expression::Compose { .. }
1882 | crate::Expression::Splat { .. } => {
1883 self.put_possibly_const_expression(
1884 expr_handle,
1885 &context.function.expressions,
1886 context.module,
1887 context.mod_info,
1888 context,
1889 |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1890 |writer, context, expr| writer.put_expression(expr, context, true),
1891 )?;
1892 }
1893 crate::Expression::Override(_) => return Err(Error::Override),
1894 crate::Expression::Access { base, .. }
1895 | crate::Expression::AccessIndex { base, .. } => {
1896 let policy = context.choose_bounds_check_policy(base);
1902 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1903 && self.put_bounds_checks(
1904 expr_handle,
1905 context,
1906 back::Level(0),
1907 if is_scoped { "" } else { "(" },
1908 )?
1909 {
1910 write!(self.out, " ? ")?;
1911 self.put_access_chain(expr_handle, policy, context)?;
1912 write!(self.out, " : ")?;
1913
1914 if context.resolve_type(base).pointer_space().is_some() {
1915 let result_ty = context.info[expr_handle]
1919 .ty
1920 .inner_with(&context.module.types)
1921 .pointer_base_type();
1922 let result_ty_handle = match result_ty {
1923 Some(TypeResolution::Handle(handle)) => handle,
1924 Some(TypeResolution::Value(_)) => {
1925 unreachable!(
1933 "Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
1934 );
1935 }
1936 None => {
1937 unreachable!(
1938 "Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
1939 )
1940 }
1941 };
1942 let name_key =
1943 NameKey::oob_local_for_type(context.origin, result_ty_handle);
1944 self.out.write_str(&self.names[&name_key])?;
1945 } else {
1946 write!(self.out, "DefaultConstructible()")?;
1947 }
1948
1949 if !is_scoped {
1950 write!(self.out, ")")?;
1951 }
1952 } else {
1953 self.put_access_chain(expr_handle, policy, context)?;
1954 }
1955 }
1956 crate::Expression::Swizzle {
1957 size,
1958 vector,
1959 pattern,
1960 } => {
1961 self.put_wrapped_expression_for_packed_vec3_access(
1962 vector,
1963 context,
1964 false,
1965 &Self::put_expression,
1966 )?;
1967 write!(self.out, ".")?;
1968 for &sc in pattern[..size as usize].iter() {
1969 write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
1970 }
1971 }
1972 crate::Expression::FunctionArgument(index) => {
1973 let name_key = match context.origin {
1974 FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
1975 FunctionOrigin::EntryPoint(ep_index) => {
1976 NameKey::EntryPointArgument(ep_index, index)
1977 }
1978 };
1979 let name = &self.names[&name_key];
1980 write!(self.out, "{name}")?;
1981 }
1982 crate::Expression::GlobalVariable(handle) => {
1983 let name = &self.names[&NameKey::GlobalVariable(handle)];
1984 write!(self.out, "{name}")?;
1985 }
1986 crate::Expression::LocalVariable(handle) => {
1987 let name_key = NameKey::local(context.origin, handle);
1988 let name = &self.names[&name_key];
1989 write!(self.out, "{name}")?;
1990 }
1991 crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
1992 crate::Expression::ImageSample {
1993 coordinate,
1994 image,
1995 sampler,
1996 clamp_to_edge: true,
1997 gather: None,
1998 array_index: None,
1999 offset: None,
2000 level: crate::SampleLevel::Zero,
2001 depth_ref: None,
2002 } => {
2003 write!(self.out, "{IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}(")?;
2004 self.put_expression(image, context, true)?;
2005 write!(self.out, ", ")?;
2006 self.put_expression(sampler, context, true)?;
2007 write!(self.out, ", ")?;
2008 self.put_expression(coordinate, context, true)?;
2009 write!(self.out, ")")?;
2010 }
2011 crate::Expression::ImageSample {
2012 image,
2013 sampler,
2014 gather,
2015 coordinate,
2016 array_index,
2017 offset,
2018 level,
2019 depth_ref,
2020 clamp_to_edge,
2021 } => {
2022 if clamp_to_edge {
2023 return Err(Error::GenericValidation(
2024 "ImageSample::clamp_to_edge should have been validated out".to_string(),
2025 ));
2026 }
2027
2028 let main_op = match gather {
2029 Some(_) => "gather",
2030 None => "sample",
2031 };
2032 let comparison_op = match depth_ref {
2033 Some(_) => "_compare",
2034 None => "",
2035 };
2036 self.put_expression(image, context, false)?;
2037 write!(self.out, ".{main_op}{comparison_op}(")?;
2038 self.put_expression(sampler, context, true)?;
2039 write!(self.out, ", ")?;
2040 self.put_expression(coordinate, context, true)?;
2041 if let Some(expr) = array_index {
2042 write!(self.out, ", ")?;
2043 self.put_expression(expr, context, true)?;
2044 }
2045 if let Some(dref) = depth_ref {
2046 write!(self.out, ", ")?;
2047 self.put_expression(dref, context, true)?;
2048 }
2049
2050 self.put_image_sample_level(image, level, context)?;
2051
2052 if let Some(offset) = offset {
2053 write!(self.out, ", ")?;
2054 self.put_expression(offset, context, true)?;
2055 }
2056
2057 match gather {
2058 None | Some(crate::SwizzleComponent::X) => {}
2059 Some(component) => {
2060 let is_cube_map = match *context.resolve_type(image) {
2061 crate::TypeInner::Image {
2062 dim: crate::ImageDimension::Cube,
2063 ..
2064 } => true,
2065 _ => false,
2066 };
2067 if offset.is_none() && !is_cube_map {
2070 write!(self.out, ", {NAMESPACE}::int2(0)")?;
2071 }
2072 let letter = back::COMPONENTS[component as usize];
2073 write!(self.out, ", {NAMESPACE}::component::{letter}")?;
2074 }
2075 }
2076 write!(self.out, ")")?;
2077 }
2078 crate::Expression::ImageLoad {
2079 image,
2080 coordinate,
2081 array_index,
2082 sample,
2083 level,
2084 } => {
2085 let address = TexelAddress {
2086 coordinate,
2087 array_index,
2088 sample,
2089 level: level.map(LevelOfDetail::Direct),
2090 };
2091 self.put_image_load(expr_handle, image, address, context)?;
2092 }
2093 crate::Expression::ImageQuery { image, query } => match query {
2096 crate::ImageQuery::Size { level } => {
2097 self.put_image_size_query(
2098 image,
2099 level.map(LevelOfDetail::Direct),
2100 crate::ScalarKind::Uint,
2101 context,
2102 )?;
2103 }
2104 crate::ImageQuery::NumLevels => {
2105 self.put_expression(image, context, false)?;
2106 write!(self.out, ".get_num_mip_levels()")?;
2107 }
2108 crate::ImageQuery::NumLayers => {
2109 self.put_expression(image, context, false)?;
2110 write!(self.out, ".get_array_size()")?;
2111 }
2112 crate::ImageQuery::NumSamples => {
2113 self.put_expression(image, context, false)?;
2114 write!(self.out, ".get_num_samples()")?;
2115 }
2116 },
2117 crate::Expression::Unary { op, expr } => {
2118 let op_str = match op {
2119 crate::UnaryOperator::Negate => {
2120 match context.resolve_type(expr).scalar_kind() {
2121 Some(crate::ScalarKind::Sint) => NEG_FUNCTION,
2122 _ => "-",
2123 }
2124 }
2125 crate::UnaryOperator::LogicalNot => "!",
2126 crate::UnaryOperator::BitwiseNot => "~",
2127 };
2128 write!(self.out, "{op_str}(")?;
2129 self.put_expression(expr, context, false)?;
2130 write!(self.out, ")")?;
2131 }
2132 crate::Expression::Binary { op, left, right } => {
2133 let kind = context
2134 .resolve_type(left)
2135 .scalar_kind()
2136 .ok_or(Error::UnsupportedBinaryOp(op))?;
2137
2138 if op == crate::BinaryOperator::Divide
2139 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2140 {
2141 write!(self.out, "{DIV_FUNCTION}(")?;
2142 self.put_expression(left, context, true)?;
2143 write!(self.out, ", ")?;
2144 self.put_expression(right, context, true)?;
2145 write!(self.out, ")")?;
2146 } else if op == crate::BinaryOperator::Modulo
2147 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2148 {
2149 write!(self.out, "{MOD_FUNCTION}(")?;
2150 self.put_expression(left, context, true)?;
2151 write!(self.out, ", ")?;
2152 self.put_expression(right, context, true)?;
2153 write!(self.out, ")")?;
2154 } else if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
2155 write!(self.out, "{NAMESPACE}::fmod(")?;
2160 self.put_expression(left, context, true)?;
2161 write!(self.out, ", ")?;
2162 self.put_expression(right, context, true)?;
2163 write!(self.out, ")")?;
2164 } else if (op == crate::BinaryOperator::Add
2165 || op == crate::BinaryOperator::Subtract
2166 || op == crate::BinaryOperator::Multiply)
2167 && kind == crate::ScalarKind::Sint
2168 {
2169 let to_unsigned = |ty: &crate::TypeInner| match *ty {
2170 crate::TypeInner::Scalar(scalar) => {
2171 Ok(crate::TypeInner::Scalar(crate::Scalar {
2172 kind: crate::ScalarKind::Uint,
2173 ..scalar
2174 }))
2175 }
2176 crate::TypeInner::Vector { size, scalar } => Ok(crate::TypeInner::Vector {
2177 size,
2178 scalar: crate::Scalar {
2179 kind: crate::ScalarKind::Uint,
2180 ..scalar
2181 },
2182 }),
2183 _ => Err(Error::UnsupportedBitCast(ty.clone())),
2184 };
2185
2186 self.put_bitcasted_expression(
2191 context.resolve_type(expr_handle),
2192 context,
2193 &|writer, context, is_scoped| {
2194 writer.put_binop(
2195 op,
2196 left,
2197 right,
2198 context,
2199 is_scoped,
2200 &|writer, expr, context, _is_scoped| {
2201 writer.put_bitcasted_expression(
2202 &to_unsigned(context.resolve_type(expr))?,
2203 context,
2204 &|writer, context, is_scoped| {
2205 writer.put_expression(expr, context, is_scoped)
2206 },
2207 )
2208 },
2209 )
2210 },
2211 )?;
2212 } else {
2213 self.put_binop(op, left, right, context, is_scoped, &Self::put_expression)?;
2214 }
2215 }
2216 crate::Expression::Select {
2217 condition,
2218 accept,
2219 reject,
2220 } => match *context.resolve_type(condition) {
2221 crate::TypeInner::Scalar(crate::Scalar {
2222 kind: crate::ScalarKind::Bool,
2223 ..
2224 }) => {
2225 if !is_scoped {
2226 write!(self.out, "(")?;
2227 }
2228 self.put_expression(condition, context, false)?;
2229 write!(self.out, " ? ")?;
2230 self.put_expression(accept, context, false)?;
2231 write!(self.out, " : ")?;
2232 self.put_expression(reject, context, false)?;
2233 if !is_scoped {
2234 write!(self.out, ")")?;
2235 }
2236 }
2237 crate::TypeInner::Vector {
2238 scalar:
2239 crate::Scalar {
2240 kind: crate::ScalarKind::Bool,
2241 ..
2242 },
2243 ..
2244 } => {
2245 write!(self.out, "{NAMESPACE}::select(")?;
2246 self.put_expression(reject, context, true)?;
2247 write!(self.out, ", ")?;
2248 self.put_expression(accept, context, true)?;
2249 write!(self.out, ", ")?;
2250 self.put_expression(condition, context, true)?;
2251 write!(self.out, ")")?;
2252 }
2253 ref ty => {
2254 return Err(Error::GenericValidation(format!(
2255 "Expected select condition to be a non-bool type, got {ty:?}",
2256 )))
2257 }
2258 },
2259 crate::Expression::Derivative { axis, expr, .. } => {
2260 use crate::DerivativeAxis as Axis;
2261 let op = match axis {
2262 Axis::X => "dfdx",
2263 Axis::Y => "dfdy",
2264 Axis::Width => "fwidth",
2265 };
2266 write!(self.out, "{NAMESPACE}::{op}")?;
2267 self.put_call_parameters(iter::once(expr), context)?;
2268 }
2269 crate::Expression::Relational { fun, argument } => {
2270 let op = match fun {
2271 crate::RelationalFunction::Any => "any",
2272 crate::RelationalFunction::All => "all",
2273 crate::RelationalFunction::IsNan => "isnan",
2274 crate::RelationalFunction::IsInf => "isinf",
2275 };
2276 write!(self.out, "{NAMESPACE}::{op}")?;
2277 self.put_call_parameters(iter::once(argument), context)?;
2278 }
2279 crate::Expression::Math {
2280 fun,
2281 arg,
2282 arg1,
2283 arg2,
2284 arg3,
2285 } => {
2286 use crate::MathFunction as Mf;
2287
2288 let arg_type = context.resolve_type(arg);
2289 let scalar_argument = match arg_type {
2290 &crate::TypeInner::Scalar(_) => true,
2291 _ => false,
2292 };
2293
2294 let fun_name = match fun {
2295 Mf::Abs => "abs",
2297 Mf::Min => "min",
2298 Mf::Max => "max",
2299 Mf::Clamp => "clamp",
2300 Mf::Saturate => "saturate",
2301 Mf::Cos => "cos",
2303 Mf::Cosh => "cosh",
2304 Mf::Sin => "sin",
2305 Mf::Sinh => "sinh",
2306 Mf::Tan => "tan",
2307 Mf::Tanh => "tanh",
2308 Mf::Acos => "acos",
2309 Mf::Asin => "asin",
2310 Mf::Atan => "atan",
2311 Mf::Atan2 => "atan2",
2312 Mf::Asinh => "asinh",
2313 Mf::Acosh => "acosh",
2314 Mf::Atanh => "atanh",
2315 Mf::Radians => "",
2316 Mf::Degrees => "",
2317 Mf::Ceil => "ceil",
2319 Mf::Floor => "floor",
2320 Mf::Round => "rint",
2321 Mf::Fract => "fract",
2322 Mf::Trunc => "trunc",
2323 Mf::Modf => MODF_FUNCTION,
2324 Mf::Frexp => FREXP_FUNCTION,
2325 Mf::Ldexp => "ldexp",
2326 Mf::Exp => "exp",
2328 Mf::Exp2 => "exp2",
2329 Mf::Log => "log",
2330 Mf::Log2 => "log2",
2331 Mf::Pow => "pow",
2332 Mf::Dot => match *context.resolve_type(arg) {
2334 crate::TypeInner::Vector {
2335 scalar:
2336 crate::Scalar {
2337 kind: crate::ScalarKind::Float,
2338 ..
2339 },
2340 ..
2341 } => "dot",
2342 crate::TypeInner::Vector { size, .. } => {
2343 return self.put_dot_product(
2344 arg,
2345 arg1.unwrap(),
2346 size as usize,
2347 |writer, arg, index| {
2348 writer.put_expression(arg, context, true)?;
2352 write!(writer.out, ".{}", back::COMPONENTS[index])?;
2354 Ok(())
2355 },
2356 );
2357 }
2358 _ => unreachable!(
2359 "Correct TypeInner for dot product should be already validated"
2360 ),
2361 },
2362 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2363 if context.lang_version >= (2, 1) {
2364 let packed_type = match fun {
2368 Mf::Dot4I8Packed => "packed_char4",
2369 Mf::Dot4U8Packed => "packed_uchar4",
2370 _ => unreachable!(),
2371 };
2372
2373 return self.put_dot_product(
2374 Reinterpreted::new(packed_type, arg),
2375 Reinterpreted::new(packed_type, arg1.unwrap()),
2376 4,
2377 |writer, arg, index| {
2378 write!(writer.out, "{arg}[{index}]")?;
2381 Ok(())
2382 },
2383 );
2384 } else {
2385 let conversion = match fun {
2389 Mf::Dot4I8Packed => "int",
2390 Mf::Dot4U8Packed => "",
2391 _ => unreachable!(),
2392 };
2393
2394 return self.put_dot_product(
2395 arg,
2396 arg1.unwrap(),
2397 4,
2398 |writer, arg, index| {
2399 write!(writer.out, "({conversion}(")?;
2400 writer.put_expression(arg, context, true)?;
2401 if index == 3 {
2402 write!(writer.out, ") >> 24)")?;
2403 } else {
2404 write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2405 }
2406 Ok(())
2407 },
2408 );
2409 }
2410 }
2411 Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2412 Mf::Cross => "cross",
2413 Mf::Distance => "distance",
2414 Mf::Length if scalar_argument => "abs",
2415 Mf::Length => "length",
2416 Mf::Normalize => "normalize",
2417 Mf::FaceForward => "faceforward",
2418 Mf::Reflect => "reflect",
2419 Mf::Refract => "refract",
2420 Mf::Sign => match arg_type.scalar_kind() {
2422 Some(crate::ScalarKind::Sint) => {
2423 return self.put_isign(arg, context);
2424 }
2425 _ => "sign",
2426 },
2427 Mf::Fma => "fma",
2428 Mf::Mix => "mix",
2429 Mf::Step => "step",
2430 Mf::SmoothStep => "smoothstep",
2431 Mf::Sqrt => "sqrt",
2432 Mf::InverseSqrt => "rsqrt",
2433 Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
2434 Mf::Transpose => "transpose",
2435 Mf::Determinant => "determinant",
2436 Mf::QuantizeToF16 => "",
2437 Mf::CountTrailingZeros => "ctz",
2439 Mf::CountLeadingZeros => "clz",
2440 Mf::CountOneBits => "popcount",
2441 Mf::ReverseBits => "reverse_bits",
2442 Mf::ExtractBits => "",
2443 Mf::InsertBits => "",
2444 Mf::FirstTrailingBit => "",
2445 Mf::FirstLeadingBit => "",
2446 Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
2448 Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
2449 Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
2450 Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
2451 Mf::Pack2x16float => "",
2452 Mf::Pack4xI8 => "",
2453 Mf::Pack4xU8 => "",
2454 Mf::Pack4xI8Clamp => "",
2455 Mf::Pack4xU8Clamp => "",
2456 Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
2458 Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
2459 Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
2460 Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
2461 Mf::Unpack2x16float => "",
2462 Mf::Unpack4xI8 => "",
2463 Mf::Unpack4xU8 => "",
2464 };
2465
2466 match fun {
2467 Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
2468 if context.lang_version < (1, 2) {
2477 return Err(Error::UnsupportedFunction(fun_name.to_string()));
2478 }
2479 }
2480 _ => {}
2481 }
2482
2483 match fun {
2484 Mf::Abs if arg_type.scalar_kind() == Some(crate::ScalarKind::Sint) => {
2485 write!(self.out, "{ABS_FUNCTION}(")?;
2486 self.put_expression(arg, context, true)?;
2487 write!(self.out, ")")?;
2488 }
2489 Mf::Distance if scalar_argument => {
2490 write!(self.out, "{NAMESPACE}::abs(")?;
2491 self.put_expression(arg, context, false)?;
2492 write!(self.out, " - ")?;
2493 self.put_expression(arg1.unwrap(), context, false)?;
2494 write!(self.out, ")")?;
2495 }
2496 Mf::FirstTrailingBit => {
2497 let scalar = context.resolve_type(arg).scalar().unwrap();
2498 let constant = scalar.width * 8 + 1;
2499
2500 write!(self.out, "((({NAMESPACE}::ctz(")?;
2501 self.put_expression(arg, context, true)?;
2502 write!(self.out, ") + 1) % {constant}) - 1)")?;
2503 }
2504 Mf::FirstLeadingBit => {
2505 let inner = context.resolve_type(arg);
2506 let scalar = inner.scalar().unwrap();
2507 let constant = scalar.width * 8 - 1;
2508
2509 write!(
2510 self.out,
2511 "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
2512 )?;
2513
2514 if scalar.kind == crate::ScalarKind::Sint {
2515 write!(self.out, "{NAMESPACE}::select(")?;
2516 self.put_expression(arg, context, true)?;
2517 write!(self.out, ", ~")?;
2518 self.put_expression(arg, context, true)?;
2519 write!(self.out, ", ")?;
2520 self.put_expression(arg, context, true)?;
2521 write!(self.out, " < 0)")?;
2522 } else {
2523 self.put_expression(arg, context, true)?;
2524 }
2525
2526 write!(self.out, "), ")?;
2527
2528 match *inner {
2530 crate::TypeInner::Vector { size, scalar } => {
2531 let size = common::vector_size_str(size);
2532 let name = scalar.to_msl_name();
2533 write!(self.out, "{name}{size}")?;
2534 }
2535 crate::TypeInner::Scalar(scalar) => {
2536 let name = scalar.to_msl_name();
2537 write!(self.out, "{name}")?;
2538 }
2539 _ => (),
2540 }
2541
2542 write!(self.out, "(-1), ")?;
2543 self.put_expression(arg, context, true)?;
2544 write!(self.out, " == 0 || ")?;
2545 self.put_expression(arg, context, true)?;
2546 write!(self.out, " == -1)")?;
2547 }
2548 Mf::Unpack2x16float => {
2549 write!(self.out, "float2(as_type<half2>(")?;
2550 self.put_expression(arg, context, false)?;
2551 write!(self.out, "))")?;
2552 }
2553 Mf::Pack2x16float => {
2554 write!(self.out, "as_type<uint>(half2(")?;
2555 self.put_expression(arg, context, false)?;
2556 write!(self.out, "))")?;
2557 }
2558 Mf::ExtractBits => {
2559 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2576
2577 write!(self.out, "{NAMESPACE}::extract_bits(")?;
2578 self.put_expression(arg, context, true)?;
2579 write!(self.out, ", {NAMESPACE}::min(")?;
2580 self.put_expression(arg1.unwrap(), context, true)?;
2581 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2582 self.put_expression(arg2.unwrap(), context, true)?;
2583 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2584 self.put_expression(arg1.unwrap(), context, true)?;
2585 write!(self.out, ", {scalar_bits}u)))")?;
2586 }
2587 Mf::InsertBits => {
2588 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
2593
2594 write!(self.out, "{NAMESPACE}::insert_bits(")?;
2595 self.put_expression(arg, context, true)?;
2596 write!(self.out, ", ")?;
2597 self.put_expression(arg1.unwrap(), context, true)?;
2598 write!(self.out, ", {NAMESPACE}::min(")?;
2599 self.put_expression(arg2.unwrap(), context, true)?;
2600 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
2601 self.put_expression(arg3.unwrap(), context, true)?;
2602 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
2603 self.put_expression(arg2.unwrap(), context, true)?;
2604 write!(self.out, ", {scalar_bits}u)))")?;
2605 }
2606 Mf::Radians => {
2607 write!(self.out, "((")?;
2608 self.put_expression(arg, context, false)?;
2609 write!(self.out, ") * 0.017453292519943295474)")?;
2610 }
2611 Mf::Degrees => {
2612 write!(self.out, "((")?;
2613 self.put_expression(arg, context, false)?;
2614 write!(self.out, ") * 57.295779513082322865)")?;
2615 }
2616 Mf::Modf | Mf::Frexp => {
2617 write!(self.out, "{fun_name}")?;
2618 self.put_call_parameters(iter::once(arg), context)?;
2619 }
2620 Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2621 Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2622 Mf::Pack4xI8Clamp => {
2623 self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2624 }
2625 Mf::Pack4xU8Clamp => {
2626 self.put_pack4x8(arg, context, false, Some(("0", "255")))?
2627 }
2628 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2629 let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2630 "u"
2631 } else {
2632 ""
2633 };
2634
2635 if context.lang_version >= (2, 1) {
2636 write!(
2638 self.out,
2639 "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2640 )?;
2641 self.put_expression(arg, context, true)?;
2642 write!(self.out, "))")?;
2643 } else {
2644 write!(self.out, "({sign_prefix}int4(")?;
2646 self.put_expression(arg, context, true)?;
2647 write!(self.out, ", ")?;
2648 self.put_expression(arg, context, true)?;
2649 write!(self.out, " >> 8, ")?;
2650 self.put_expression(arg, context, true)?;
2651 write!(self.out, " >> 16, ")?;
2652 self.put_expression(arg, context, true)?;
2653 write!(self.out, " >> 24) << 24 >> 24)")?;
2654 }
2655 }
2656 Mf::QuantizeToF16 => {
2657 match *context.resolve_type(arg) {
2658 crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
2659 crate::TypeInner::Vector { size, .. } => write!(
2660 self.out,
2661 "{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
2662 size = common::vector_size_str(size),
2663 )?,
2664 _ => unreachable!(
2665 "Correct TypeInner for QuantizeToF16 should be already validated"
2666 ),
2667 };
2668
2669 self.put_expression(arg, context, true)?;
2670 write!(self.out, "))")?;
2671 }
2672 _ => {
2673 write!(self.out, "{NAMESPACE}::{fun_name}")?;
2674 self.put_call_parameters(
2675 iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
2676 context,
2677 )?;
2678 }
2679 }
2680 }
2681 crate::Expression::As {
2682 expr,
2683 kind,
2684 convert,
2685 } => match *context.resolve_type(expr) {
2686 crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
2687 if src.kind == crate::ScalarKind::Float
2688 && (kind == crate::ScalarKind::Sint || kind == crate::ScalarKind::Uint)
2689 && convert.is_some()
2690 {
2691 let fun_name = match (kind, convert) {
2695 (crate::ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
2696 (crate::ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
2697 (crate::ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
2698 (crate::ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
2699 _ => unreachable!(),
2700 };
2701 write!(self.out, "{fun_name}(")?;
2702 self.put_expression(expr, context, true)?;
2703 write!(self.out, ")")?;
2704 } else {
2705 let target_scalar = crate::Scalar {
2706 kind,
2707 width: convert.unwrap_or(src.width),
2708 };
2709 let op = match convert {
2710 Some(_) => "static_cast",
2711 None => "as_type",
2712 };
2713 write!(self.out, "{op}<")?;
2714 match *context.resolve_type(expr) {
2715 crate::TypeInner::Vector { size, .. } => {
2716 put_numeric_type(&mut self.out, target_scalar, &[size])?
2717 }
2718 _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2719 };
2720 write!(self.out, ">(")?;
2721 self.put_expression(expr, context, true)?;
2722 write!(self.out, ")")?;
2723 }
2724 }
2725 crate::TypeInner::Matrix {
2726 columns,
2727 rows,
2728 scalar,
2729 } => {
2730 let target_scalar = crate::Scalar {
2731 kind,
2732 width: convert.unwrap_or(scalar.width),
2733 };
2734 put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2735 write!(self.out, "(")?;
2736 self.put_expression(expr, context, true)?;
2737 write!(self.out, ")")?;
2738 }
2739 ref ty => {
2740 return Err(Error::GenericValidation(format!(
2741 "Unsupported type for As: {ty:?}"
2742 )))
2743 }
2744 },
2745 crate::Expression::CallResult(_)
2747 | crate::Expression::AtomicResult { .. }
2748 | crate::Expression::WorkGroupUniformLoadResult { .. }
2749 | crate::Expression::SubgroupBallotResult
2750 | crate::Expression::SubgroupOperationResult { .. }
2751 | crate::Expression::RayQueryProceedResult => {
2752 unreachable!()
2753 }
2754 crate::Expression::ArrayLength(expr) => {
2755 let global = match context.function.expressions[expr] {
2757 crate::Expression::AccessIndex { base, .. } => {
2758 match context.function.expressions[base] {
2759 crate::Expression::GlobalVariable(handle) => handle,
2760 ref ex => {
2761 return Err(Error::GenericValidation(format!(
2762 "Expected global variable in AccessIndex, got {ex:?}"
2763 )))
2764 }
2765 }
2766 }
2767 crate::Expression::GlobalVariable(handle) => handle,
2768 ref ex => {
2769 return Err(Error::GenericValidation(format!(
2770 "Unexpected expression in ArrayLength, got {ex:?}"
2771 )))
2772 }
2773 };
2774
2775 if !is_scoped {
2776 write!(self.out, "(")?;
2777 }
2778 write!(self.out, "1 + ")?;
2779 self.put_dynamic_array_max_index(global, context)?;
2780 if !is_scoped {
2781 write!(self.out, ")")?;
2782 }
2783 }
2784 crate::Expression::RayQueryVertexPositions { .. } => {
2785 unimplemented!()
2786 }
2787 crate::Expression::RayQueryGetIntersection {
2788 query,
2789 committed: _,
2790 } => {
2791 if context.lang_version < (2, 4) {
2792 return Err(Error::UnsupportedRayTracing);
2793 }
2794
2795 let ty = context.module.special_types.ray_intersection.unwrap();
2796 let type_name = &self.names[&NameKey::Type(ty)];
2797 write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2798 self.put_expression(query, context, true)?;
2799 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2800 let fields = [
2801 "distance",
2802 "user_instance_id", "instance_id",
2804 "", "geometry_id",
2806 "primitive_id",
2807 "triangle_barycentric_coord",
2808 "triangle_front_facing",
2809 "", "object_to_world_transform", "world_to_object_transform", ];
2813 for field in fields {
2814 write!(self.out, ", ")?;
2815 if field.is_empty() {
2816 write!(self.out, "{{}}")?;
2817 } else {
2818 self.put_expression(query, context, true)?;
2819 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2820 }
2821 }
2822 write!(self.out, "}}")?;
2823 }
2824 }
2825 Ok(())
2826 }
2827
2828 fn put_binop<F>(
2831 &mut self,
2832 op: crate::BinaryOperator,
2833 left: Handle<crate::Expression>,
2834 right: Handle<crate::Expression>,
2835 context: &ExpressionContext,
2836 is_scoped: bool,
2837 put_expression: &F,
2838 ) -> BackendResult
2839 where
2840 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2841 {
2842 let op_str = back::binary_operation_str(op);
2843
2844 if !is_scoped {
2845 write!(self.out, "(")?;
2846 }
2847
2848 if op == crate::BinaryOperator::Multiply
2851 && matches!(
2852 context.resolve_type(right),
2853 &crate::TypeInner::Matrix { .. }
2854 )
2855 {
2856 self.put_wrapped_expression_for_packed_vec3_access(
2857 left,
2858 context,
2859 false,
2860 put_expression,
2861 )?;
2862 } else {
2863 put_expression(self, left, context, false)?;
2864 }
2865
2866 write!(self.out, " {op_str} ")?;
2867
2868 if op == crate::BinaryOperator::Multiply
2870 && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
2871 {
2872 self.put_wrapped_expression_for_packed_vec3_access(
2873 right,
2874 context,
2875 false,
2876 put_expression,
2877 )?;
2878 } else {
2879 put_expression(self, right, context, false)?;
2880 }
2881
2882 if !is_scoped {
2883 write!(self.out, ")")?;
2884 }
2885
2886 Ok(())
2887 }
2888
2889 fn put_wrapped_expression_for_packed_vec3_access<F>(
2891 &mut self,
2892 expr_handle: Handle<crate::Expression>,
2893 context: &ExpressionContext,
2894 is_scoped: bool,
2895 put_expression: &F,
2896 ) -> BackendResult
2897 where
2898 F: Fn(&mut Self, Handle<crate::Expression>, &ExpressionContext, bool) -> BackendResult,
2899 {
2900 if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
2901 write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
2902 put_expression(self, expr_handle, context, is_scoped)?;
2903 write!(self.out, ")")?;
2904 } else {
2905 put_expression(self, expr_handle, context, is_scoped)?;
2906 }
2907 Ok(())
2908 }
2909
2910 fn put_bitcasted_expression<F>(
2913 &mut self,
2914 cast_to: &crate::TypeInner,
2915 context: &ExpressionContext,
2916 put_expression: &F,
2917 ) -> BackendResult
2918 where
2919 F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
2920 {
2921 write!(self.out, "as_type<")?;
2922 match *cast_to {
2923 crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
2924 crate::TypeInner::Vector { size, scalar } => {
2925 put_numeric_type(&mut self.out, scalar, &[size])?
2926 }
2927 _ => return Err(Error::UnsupportedBitCast(cast_to.clone())),
2928 };
2929 write!(self.out, ">(")?;
2930 put_expression(self, context, true)?;
2931 write!(self.out, ")")?;
2932
2933 Ok(())
2934 }
2935
2936 fn put_index(
2938 &mut self,
2939 index: index::GuardedIndex,
2940 context: &ExpressionContext,
2941 is_scoped: bool,
2942 ) -> BackendResult {
2943 match index {
2944 index::GuardedIndex::Expression(expr) => {
2945 self.put_expression(expr, context, is_scoped)?
2946 }
2947 index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
2948 }
2949 Ok(())
2950 }
2951
2952 #[allow(unused_variables)]
2982 fn put_bounds_checks(
2983 &mut self,
2984 chain: Handle<crate::Expression>,
2985 context: &ExpressionContext,
2986 level: back::Level,
2987 prefix: &'static str,
2988 ) -> Result<bool, Error> {
2989 let mut check_written = false;
2990
2991 for item in context.bounds_check_iter(chain) {
2993 let BoundsCheck {
2994 base,
2995 index,
2996 length,
2997 } = item;
2998
2999 if check_written {
3000 write!(self.out, " && ")?;
3001 } else {
3002 write!(self.out, "{level}{prefix}")?;
3003 check_written = true;
3004 }
3005
3006 write!(self.out, "uint(")?;
3010 self.put_index(index, context, true)?;
3011 self.out.write_str(") < ")?;
3012 match length {
3013 index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
3014 index::IndexableLength::Dynamic => {
3015 let global = context.function.originating_global(base).ok_or_else(|| {
3016 Error::GenericValidation("Could not find originating global".into())
3017 })?;
3018 write!(self.out, "1 + ")?;
3019 self.put_dynamic_array_max_index(global, context)?
3020 }
3021 }
3022 }
3023
3024 Ok(check_written)
3025 }
3026
3027 fn put_access_chain(
3047 &mut self,
3048 chain: Handle<crate::Expression>,
3049 policy: index::BoundsCheckPolicy,
3050 context: &ExpressionContext,
3051 ) -> BackendResult {
3052 match context.function.expressions[chain] {
3053 crate::Expression::Access { base, index } => {
3054 let mut base_ty = context.resolve_type(base);
3055
3056 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3058 base_ty = &context.module.types[base].inner;
3059 }
3060
3061 self.put_subscripted_access_chain(
3062 base,
3063 base_ty,
3064 index::GuardedIndex::Expression(index),
3065 policy,
3066 context,
3067 )?;
3068 }
3069 crate::Expression::AccessIndex { base, index } => {
3070 let base_resolution = &context.info[base].ty;
3071 let mut base_ty = base_resolution.inner_with(&context.module.types);
3072 let mut base_ty_handle = base_resolution.handle();
3073
3074 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
3076 base_ty = &context.module.types[base].inner;
3077 base_ty_handle = Some(base);
3078 }
3079
3080 match *base_ty {
3084 crate::TypeInner::Struct { .. } => {
3085 let base_ty = base_ty_handle.unwrap();
3086 self.put_access_chain(base, policy, context)?;
3087 let name = &self.names[&NameKey::StructMember(base_ty, index)];
3088 write!(self.out, ".{name}")?;
3089 }
3090 crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
3091 self.put_access_chain(base, policy, context)?;
3092 if context.get_packed_vec_kind(base).is_some() {
3095 write!(self.out, "[{index}]")?;
3096 } else {
3097 write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
3098 }
3099 }
3100 _ => {
3101 self.put_subscripted_access_chain(
3102 base,
3103 base_ty,
3104 index::GuardedIndex::Known(index),
3105 policy,
3106 context,
3107 )?;
3108 }
3109 }
3110 }
3111 _ => self.put_expression(chain, context, false)?,
3112 }
3113
3114 Ok(())
3115 }
3116
3117 fn put_subscripted_access_chain(
3134 &mut self,
3135 base: Handle<crate::Expression>,
3136 base_ty: &crate::TypeInner,
3137 index: index::GuardedIndex,
3138 policy: index::BoundsCheckPolicy,
3139 context: &ExpressionContext,
3140 ) -> BackendResult {
3141 let accessing_wrapped_array = match *base_ty {
3142 crate::TypeInner::Array {
3143 size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
3144 ..
3145 } => true,
3146 _ => false,
3147 };
3148 let accessing_wrapped_binding_array =
3149 matches!(*base_ty, crate::TypeInner::BindingArray { .. });
3150
3151 self.put_access_chain(base, policy, context)?;
3152 if accessing_wrapped_array {
3153 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3154 }
3155 write!(self.out, "[")?;
3156
3157 let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
3159 context.access_needs_check(base, index)
3160 } else {
3161 None
3162 };
3163 if let Some(limit) = restriction_needed {
3164 write!(self.out, "{NAMESPACE}::min(unsigned(")?;
3165 self.put_index(index, context, true)?;
3166 write!(self.out, "), ")?;
3167 match limit {
3168 index::IndexableLength::Known(limit) => {
3169 write!(self.out, "{}u", limit - 1)?;
3170 }
3171 index::IndexableLength::Dynamic => {
3172 let global = context.function.originating_global(base).ok_or_else(|| {
3173 Error::GenericValidation("Could not find originating global".into())
3174 })?;
3175 self.put_dynamic_array_max_index(global, context)?;
3176 }
3177 }
3178 write!(self.out, ")")?;
3179 } else {
3180 self.put_index(index, context, true)?;
3181 }
3182
3183 write!(self.out, "]")?;
3184
3185 if accessing_wrapped_binding_array {
3186 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
3187 }
3188
3189 Ok(())
3190 }
3191
3192 fn put_load(
3193 &mut self,
3194 pointer: Handle<crate::Expression>,
3195 context: &ExpressionContext,
3196 is_scoped: bool,
3197 ) -> BackendResult {
3198 let policy = context.choose_bounds_check_policy(pointer);
3201 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3202 && self.put_bounds_checks(
3203 pointer,
3204 context,
3205 back::Level(0),
3206 if is_scoped { "" } else { "(" },
3207 )?
3208 {
3209 write!(self.out, " ? ")?;
3210 self.put_unchecked_load(pointer, policy, context)?;
3211 write!(self.out, " : DefaultConstructible()")?;
3212
3213 if !is_scoped {
3214 write!(self.out, ")")?;
3215 }
3216 } else {
3217 self.put_unchecked_load(pointer, policy, context)?;
3218 }
3219
3220 Ok(())
3221 }
3222
3223 fn put_unchecked_load(
3224 &mut self,
3225 pointer: Handle<crate::Expression>,
3226 policy: index::BoundsCheckPolicy,
3227 context: &ExpressionContext,
3228 ) -> BackendResult {
3229 let is_atomic_pointer = context
3230 .resolve_type(pointer)
3231 .is_atomic_pointer(&context.module.types);
3232
3233 if is_atomic_pointer {
3234 write!(
3235 self.out,
3236 "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
3237 )?;
3238 self.put_access_chain(pointer, policy, context)?;
3239 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3240 } else {
3241 self.put_access_chain(pointer, policy, context)?;
3245 }
3246
3247 Ok(())
3248 }
3249
3250 fn put_return_value(
3251 &mut self,
3252 level: back::Level,
3253 expr_handle: Handle<crate::Expression>,
3254 result_struct: Option<&str>,
3255 context: &ExpressionContext,
3256 ) -> BackendResult {
3257 match result_struct {
3258 Some(struct_name) => {
3259 let mut has_point_size = false;
3260 let result_ty = context.function.result.as_ref().unwrap().ty;
3261 match context.module.types[result_ty].inner {
3262 crate::TypeInner::Struct { ref members, .. } => {
3263 let tmp = "_tmp";
3264 write!(self.out, "{level}const auto {tmp} = ")?;
3265 self.put_expression(expr_handle, context, true)?;
3266 writeln!(self.out, ";")?;
3267 write!(self.out, "{level}return {struct_name} {{")?;
3268
3269 let mut is_first = true;
3270
3271 for (index, member) in members.iter().enumerate() {
3272 if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
3273 member.binding
3274 {
3275 has_point_size = true;
3276 if !context.pipeline_options.allow_and_force_point_size {
3277 continue;
3278 }
3279 }
3280
3281 let comma = if is_first { "" } else { "," };
3282 is_first = false;
3283 let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
3284 if let crate::TypeInner::Array {
3288 size: crate::ArraySize::Constant(size),
3289 ..
3290 } = context.module.types[member.ty].inner
3291 {
3292 write!(self.out, "{comma} {{")?;
3293 for j in 0..size.get() {
3294 if j != 0 {
3295 write!(self.out, ",")?;
3296 }
3297 write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
3298 }
3299 write!(self.out, "}}")?;
3300 } else {
3301 write!(self.out, "{comma} {tmp}.{name}")?;
3302 }
3303 }
3304 }
3305 _ => {
3306 write!(self.out, "{level}return {struct_name} {{ ")?;
3307 self.put_expression(expr_handle, context, true)?;
3308 }
3309 }
3310
3311 if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
3312 let stage = context.module.entry_points[ep_index as usize].stage;
3313 if context.pipeline_options.allow_and_force_point_size
3314 && stage == crate::ShaderStage::Vertex
3315 && !has_point_size
3316 {
3317 write!(self.out, ", 1.0")?;
3319 }
3320 }
3321 write!(self.out, " }}")?;
3322 }
3323 None => {
3324 write!(self.out, "{level}return ")?;
3325 self.put_expression(expr_handle, context, true)?;
3326 }
3327 }
3328 writeln!(self.out, ";")?;
3329 Ok(())
3330 }
3331
3332 fn update_expressions_to_bake(
3337 &mut self,
3338 func: &crate::Function,
3339 info: &valid::FunctionInfo,
3340 context: &ExpressionContext,
3341 ) {
3342 use crate::Expression;
3343 self.need_bake_expressions.clear();
3344
3345 for (expr_handle, expr) in func.expressions.iter() {
3346 let expr_info = &info[expr_handle];
3349 let min_ref_count = func.expressions[expr_handle].bake_ref_count();
3350 if min_ref_count <= expr_info.ref_count {
3351 self.need_bake_expressions.insert(expr_handle);
3352 } else {
3353 match expr_info.ty {
3354 TypeResolution::Handle(h)
3356 if Some(h) == context.module.special_types.ray_desc =>
3357 {
3358 self.need_bake_expressions.insert(expr_handle);
3359 }
3360 _ => {}
3361 }
3362 }
3363
3364 if let Expression::Math {
3365 fun,
3366 arg,
3367 arg1,
3368 arg2,
3369 ..
3370 } = *expr
3371 {
3372 match fun {
3373 crate::MathFunction::Dot => {
3374 let inner = context.resolve_type(expr_handle);
3383 if let crate::TypeInner::Scalar(scalar) = *inner {
3384 match scalar.kind {
3385 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3386 self.need_bake_expressions.insert(arg);
3387 self.need_bake_expressions.insert(arg1.unwrap());
3388 }
3389 _ => {}
3390 }
3391 }
3392 }
3393 crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
3394 self.need_bake_expressions.insert(arg);
3395 self.need_bake_expressions.insert(arg1.unwrap());
3396 }
3397 crate::MathFunction::FirstLeadingBit => {
3398 self.need_bake_expressions.insert(arg);
3399 }
3400 crate::MathFunction::Pack4xI8
3401 | crate::MathFunction::Pack4xU8
3402 | crate::MathFunction::Pack4xI8Clamp
3403 | crate::MathFunction::Pack4xU8Clamp
3404 | crate::MathFunction::Unpack4xI8
3405 | crate::MathFunction::Unpack4xU8 => {
3406 if context.lang_version < (2, 1) {
3409 self.need_bake_expressions.insert(arg);
3410 }
3411 }
3412 crate::MathFunction::ExtractBits => {
3413 self.need_bake_expressions.insert(arg1.unwrap());
3415 }
3416 crate::MathFunction::InsertBits => {
3417 self.need_bake_expressions.insert(arg2.unwrap());
3419 }
3420 crate::MathFunction::Sign => {
3421 let inner = context.resolve_type(expr_handle);
3426 if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
3427 self.need_bake_expressions.insert(arg);
3428 }
3429 }
3430 _ => {}
3431 }
3432 }
3433 }
3434 }
3435
3436 fn start_baking_expression(
3437 &mut self,
3438 handle: Handle<crate::Expression>,
3439 context: &ExpressionContext,
3440 name: &str,
3441 ) -> BackendResult {
3442 match context.info[handle].ty {
3443 TypeResolution::Handle(ty_handle) => {
3444 let ty_name = TypeContext {
3445 handle: ty_handle,
3446 gctx: context.module.to_ctx(),
3447 names: &self.names,
3448 access: crate::StorageAccess::empty(),
3449 first_time: false,
3450 };
3451 write!(self.out, "{ty_name}")?;
3452 }
3453 TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
3454 put_numeric_type(&mut self.out, scalar, &[])?;
3455 }
3456 TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
3457 put_numeric_type(&mut self.out, scalar, &[size])?;
3458 }
3459 TypeResolution::Value(crate::TypeInner::Matrix {
3460 columns,
3461 rows,
3462 scalar,
3463 }) => {
3464 put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
3465 }
3466 TypeResolution::Value(ref other) => {
3467 log::warn!("Type {other:?} isn't a known local"); return Err(Error::FeatureNotImplemented("weird local type".to_string()));
3469 }
3470 }
3471
3472 write!(self.out, " {name} = ")?;
3474
3475 Ok(())
3476 }
3477
3478 fn put_cache_restricted_level(
3491 &mut self,
3492 load: Handle<crate::Expression>,
3493 image: Handle<crate::Expression>,
3494 mip_level: Option<Handle<crate::Expression>>,
3495 indent: back::Level,
3496 context: &StatementContext,
3497 ) -> BackendResult {
3498 let level_of_detail = match mip_level {
3501 Some(level) => level,
3502 None => return Ok(()),
3503 };
3504
3505 if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
3506 || !context.expression.image_needs_lod(image)
3507 {
3508 return Ok(());
3509 }
3510
3511 write!(self.out, "{}uint {} = ", indent, ClampedLod(load),)?;
3512 self.put_restricted_scalar_image_index(
3513 image,
3514 level_of_detail,
3515 "get_num_mip_levels",
3516 &context.expression,
3517 )?;
3518 writeln!(self.out, ";")?;
3519
3520 Ok(())
3521 }
3522
3523 fn put_casting_to_packed_chars(
3529 &mut self,
3530 fun: crate::MathFunction,
3531 arg0: Handle<crate::Expression>,
3532 arg1: Handle<crate::Expression>,
3533 indent: back::Level,
3534 context: &StatementContext<'_>,
3535 ) -> Result<(), Error> {
3536 let packed_type = match fun {
3537 crate::MathFunction::Dot4I8Packed => "packed_char4",
3538 crate::MathFunction::Dot4U8Packed => "packed_uchar4",
3539 _ => unreachable!(),
3540 };
3541
3542 for arg in [arg0, arg1] {
3543 write!(
3544 self.out,
3545 "{indent}{packed_type} {0} = as_type<{packed_type}>(",
3546 Reinterpreted::new(packed_type, arg)
3547 )?;
3548 self.put_expression(arg, &context.expression, true)?;
3549 writeln!(self.out, ");")?;
3550 }
3551
3552 Ok(())
3553 }
3554
3555 fn put_block(
3556 &mut self,
3557 level: back::Level,
3558 statements: &[crate::Statement],
3559 context: &StatementContext,
3560 ) -> BackendResult {
3561 #[cfg(test)]
3563 self.put_block_stack_pointers
3564 .insert(ptr::from_ref(&level).cast());
3565
3566 for statement in statements {
3567 log::trace!("statement[{}] {:?}", level.0, statement);
3568 match *statement {
3569 crate::Statement::Emit(ref range) => {
3570 for handle in range.clone() {
3571 use crate::MathFunction as Mf;
3572
3573 match context.expression.function.expressions[handle] {
3574 crate::Expression::ImageLoad {
3577 image,
3578 level: mip_level,
3579 ..
3580 } => {
3581 self.put_cache_restricted_level(
3582 handle, image, mip_level, level, context,
3583 )?;
3584 }
3585
3586 crate::Expression::Math {
3595 fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3596 arg,
3597 arg1,
3598 ..
3599 } if context.expression.lang_version >= (2, 1) => {
3600 self.put_casting_to_packed_chars(
3601 fun,
3602 arg,
3603 arg1.unwrap(),
3604 level,
3605 context,
3606 )?;
3607 }
3608
3609 _ => (),
3610 }
3611
3612 let ptr_class = context.expression.resolve_type(handle).pointer_space();
3613 let expr_name = if ptr_class.is_some() {
3614 None } else if let Some(name) =
3616 context.expression.function.named_expressions.get(&handle)
3617 {
3618 Some(self.namer.call(name))
3628 } else {
3629 let bake = if context.expression.guarded_indices.contains(handle) {
3633 true
3634 } else {
3635 self.need_bake_expressions.contains(&handle)
3636 };
3637
3638 if bake {
3639 Some(Baked(handle).to_string())
3640 } else {
3641 None
3642 }
3643 };
3644
3645 if let Some(name) = expr_name {
3646 write!(self.out, "{level}")?;
3647 self.start_baking_expression(handle, &context.expression, &name)?;
3648 self.put_expression(handle, &context.expression, true)?;
3649 self.named_expressions.insert(handle, name);
3650 writeln!(self.out, ";")?;
3651 }
3652 }
3653 }
3654 crate::Statement::Block(ref block) => {
3655 if !block.is_empty() {
3656 writeln!(self.out, "{level}{{")?;
3657 self.put_block(level.next(), block, context)?;
3658 writeln!(self.out, "{level}}}")?;
3659 }
3660 }
3661 crate::Statement::If {
3662 condition,
3663 ref accept,
3664 ref reject,
3665 } => {
3666 write!(self.out, "{level}if (")?;
3667 self.put_expression(condition, &context.expression, true)?;
3668 writeln!(self.out, ") {{")?;
3669 self.put_block(level.next(), accept, context)?;
3670 if !reject.is_empty() {
3671 writeln!(self.out, "{level}}} else {{")?;
3672 self.put_block(level.next(), reject, context)?;
3673 }
3674 writeln!(self.out, "{level}}}")?;
3675 }
3676 crate::Statement::Switch {
3677 selector,
3678 ref cases,
3679 } => {
3680 write!(self.out, "{level}switch(")?;
3681 self.put_expression(selector, &context.expression, true)?;
3682 writeln!(self.out, ") {{")?;
3683 let lcase = level.next();
3684 for case in cases.iter() {
3685 match case.value {
3686 crate::SwitchValue::I32(value) => {
3687 write!(self.out, "{lcase}case {value}:")?;
3688 }
3689 crate::SwitchValue::U32(value) => {
3690 write!(self.out, "{lcase}case {value}u:")?;
3691 }
3692 crate::SwitchValue::Default => {
3693 write!(self.out, "{lcase}default:")?;
3694 }
3695 }
3696
3697 let write_block_braces = !(case.fall_through && case.body.is_empty());
3698 if write_block_braces {
3699 writeln!(self.out, " {{")?;
3700 } else {
3701 writeln!(self.out)?;
3702 }
3703
3704 self.put_block(lcase.next(), &case.body, context)?;
3705 if !case.fall_through && case.body.last().is_none_or(|s| !s.is_terminator())
3706 {
3707 writeln!(self.out, "{}break;", lcase.next())?;
3708 }
3709
3710 if write_block_braces {
3711 writeln!(self.out, "{lcase}}}")?;
3712 }
3713 }
3714 writeln!(self.out, "{level}}}")?;
3715 }
3716 crate::Statement::Loop {
3717 ref body,
3718 ref continuing,
3719 break_if,
3720 } => {
3721 let force_loop_bound_statements =
3722 self.gen_force_bounded_loop_statements(level, context);
3723 let gate_name = (!continuing.is_empty() || break_if.is_some())
3724 .then(|| self.namer.call("loop_init"));
3725
3726 if let Some((ref decl, _)) = force_loop_bound_statements {
3727 writeln!(self.out, "{decl}")?;
3728 }
3729 if let Some(ref gate_name) = gate_name {
3730 writeln!(self.out, "{level}bool {gate_name} = true;")?;
3731 }
3732
3733 writeln!(self.out, "{level}while(true) {{",)?;
3734 if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3735 writeln!(self.out, "{break_and_inc}")?;
3736 }
3737 if let Some(ref gate_name) = gate_name {
3738 let lif = level.next();
3739 let lcontinuing = lif.next();
3740 writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
3741 self.put_block(lcontinuing, continuing, context)?;
3742 if let Some(condition) = break_if {
3743 write!(self.out, "{lcontinuing}if (")?;
3744 self.put_expression(condition, &context.expression, true)?;
3745 writeln!(self.out, ") {{")?;
3746 writeln!(self.out, "{}break;", lcontinuing.next())?;
3747 writeln!(self.out, "{lcontinuing}}}")?;
3748 }
3749 writeln!(self.out, "{lif}}}")?;
3750 writeln!(self.out, "{lif}{gate_name} = false;")?;
3751 }
3752 self.put_block(level.next(), body, context)?;
3753
3754 writeln!(self.out, "{level}}}")?;
3755 }
3756 crate::Statement::Break => {
3757 writeln!(self.out, "{level}break;")?;
3758 }
3759 crate::Statement::Continue => {
3760 writeln!(self.out, "{level}continue;")?;
3761 }
3762 crate::Statement::Return {
3763 value: Some(expr_handle),
3764 } => {
3765 self.put_return_value(
3766 level,
3767 expr_handle,
3768 context.result_struct,
3769 &context.expression,
3770 )?;
3771 }
3772 crate::Statement::Return { value: None } => {
3773 writeln!(self.out, "{level}return;")?;
3774 }
3775 crate::Statement::Kill => {
3776 writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
3777 }
3778 crate::Statement::ControlBarrier(flags)
3779 | crate::Statement::MemoryBarrier(flags) => {
3780 self.write_barrier(flags, level)?;
3781 }
3782 crate::Statement::Store { pointer, value } => {
3783 self.put_store(pointer, value, level, context)?
3784 }
3785 crate::Statement::ImageStore {
3786 image,
3787 coordinate,
3788 array_index,
3789 value,
3790 } => {
3791 let address = TexelAddress {
3792 coordinate,
3793 array_index,
3794 sample: None,
3795 level: None,
3796 };
3797 self.put_image_store(level, image, &address, value, context)?
3798 }
3799 crate::Statement::Call {
3800 function,
3801 ref arguments,
3802 result,
3803 } => {
3804 write!(self.out, "{level}")?;
3805 if let Some(expr) = result {
3806 let name = Baked(expr).to_string();
3807 self.start_baking_expression(expr, &context.expression, &name)?;
3808 self.named_expressions.insert(expr, name);
3809 }
3810 let fun_name = &self.names[&NameKey::Function(function)];
3811 write!(self.out, "{fun_name}(")?;
3812 for (i, &handle) in arguments.iter().enumerate() {
3814 if i != 0 {
3815 write!(self.out, ", ")?;
3816 }
3817 self.put_expression(handle, &context.expression, true)?;
3818 }
3819 let mut separate = !arguments.is_empty();
3821 let fun_info = &context.expression.mod_info[function];
3822 let mut needs_buffer_sizes = false;
3823 for (handle, var) in context.expression.module.global_variables.iter() {
3824 if fun_info[handle].is_empty() {
3825 continue;
3826 }
3827 if var.space.needs_pass_through() {
3828 let name = &self.names[&NameKey::GlobalVariable(handle)];
3829 if separate {
3830 write!(self.out, ", ")?;
3831 } else {
3832 separate = true;
3833 }
3834 write!(self.out, "{name}")?;
3835 }
3836 needs_buffer_sizes |=
3837 needs_array_length(var.ty, &context.expression.module.types);
3838 }
3839 if needs_buffer_sizes {
3840 if separate {
3841 write!(self.out, ", ")?;
3842 }
3843 write!(self.out, "_buffer_sizes")?;
3844 }
3845
3846 writeln!(self.out, ");")?;
3848 }
3849 crate::Statement::Atomic {
3850 pointer,
3851 ref fun,
3852 value,
3853 result,
3854 } => {
3855 let context = &context.expression;
3856
3857 write!(self.out, "{level}")?;
3862 let fun_key = if let Some(result) = result {
3863 let res_name = Baked(result).to_string();
3864 self.start_baking_expression(result, context, &res_name)?;
3865 self.named_expressions.insert(result, res_name);
3866 fun.to_msl()
3867 } else if context.resolve_type(value).scalar_width() == Some(8) {
3868 fun.to_msl_64_bit()?
3869 } else {
3870 fun.to_msl()
3871 };
3872
3873 let policy = context.choose_bounds_check_policy(pointer);
3877 let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3878 && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
3879
3880 if checked {
3882 write!(self.out, " ? ")?;
3883 }
3884
3885 match *fun {
3887 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
3888 write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
3889 self.put_access_chain(pointer, policy, context)?;
3890 write!(self.out, ", ")?;
3891 self.put_expression(cmp, context, true)?;
3892 write!(self.out, ", ")?;
3893 self.put_expression(value, context, true)?;
3894 write!(self.out, ")")?;
3895 }
3896 _ => {
3897 write!(
3898 self.out,
3899 "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
3900 )?;
3901 self.put_access_chain(pointer, policy, context)?;
3902 write!(self.out, ", ")?;
3903 self.put_expression(value, context, true)?;
3904 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
3905 }
3906 }
3907
3908 if checked {
3910 write!(self.out, " : DefaultConstructible()")?;
3911 }
3912
3913 writeln!(self.out, ";")?;
3915 }
3916 crate::Statement::ImageAtomic {
3917 image,
3918 coordinate,
3919 array_index,
3920 fun,
3921 value,
3922 } => {
3923 let address = TexelAddress {
3924 coordinate,
3925 array_index,
3926 sample: None,
3927 level: None,
3928 };
3929 self.put_image_atomic(level, image, &address, fun, value, context)?
3930 }
3931 crate::Statement::WorkGroupUniformLoad { pointer, result } => {
3932 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3933
3934 write!(self.out, "{level}")?;
3935 let name = self.namer.call("");
3936 self.start_baking_expression(result, &context.expression, &name)?;
3937 self.put_load(pointer, &context.expression, true)?;
3938 self.named_expressions.insert(result, name);
3939
3940 writeln!(self.out, ";")?;
3941 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3942 }
3943 crate::Statement::RayQuery { query, ref fun } => {
3944 if context.expression.lang_version < (2, 4) {
3945 return Err(Error::UnsupportedRayTracing);
3946 }
3947
3948 match *fun {
3949 crate::RayQueryFunction::Initialize {
3950 acceleration_structure,
3951 descriptor,
3952 } => {
3953 write!(self.out, "{level}")?;
3955 self.put_expression(query, &context.expression, true)?;
3956 writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
3957 {
3958 let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
3959 let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
3960 write!(self.out, "{level}")?;
3961 self.put_expression(query, &context.expression, true)?;
3962 write!(
3963 self.out,
3964 ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
3965 )?;
3966 self.put_expression(descriptor, &context.expression, true)?;
3967 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
3968 self.put_expression(descriptor, &context.expression, true)?;
3969 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
3970 writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
3971 }
3972 {
3973 let f_opaque = back::RayFlag::OPAQUE.bits();
3974 let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
3975 write!(self.out, "{level}")?;
3976 self.put_expression(query, &context.expression, true)?;
3977 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
3978 self.put_expression(descriptor, &context.expression, true)?;
3979 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
3980 self.put_expression(descriptor, &context.expression, true)?;
3981 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
3982 writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
3983 }
3984 {
3985 let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
3986 write!(self.out, "{level}")?;
3987 self.put_expression(query, &context.expression, true)?;
3988 write!(
3989 self.out,
3990 ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
3991 )?;
3992 self.put_expression(descriptor, &context.expression, true)?;
3993 writeln!(self.out, ".flags & {flag}) != 0);")?;
3994 }
3995
3996 write!(self.out, "{level}")?;
3997 self.put_expression(query, &context.expression, true)?;
3998 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
3999 self.put_expression(query, &context.expression, true)?;
4000 write!(
4001 self.out,
4002 ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
4003 )?;
4004 self.put_expression(descriptor, &context.expression, true)?;
4005 write!(self.out, ".origin, ")?;
4006 self.put_expression(descriptor, &context.expression, true)?;
4007 write!(self.out, ".dir, ")?;
4008 self.put_expression(descriptor, &context.expression, true)?;
4009 write!(self.out, ".tmin, ")?;
4010 self.put_expression(descriptor, &context.expression, true)?;
4011 write!(self.out, ".tmax), ")?;
4012 self.put_expression(acceleration_structure, &context.expression, true)?;
4013 write!(self.out, ", ")?;
4014 self.put_expression(descriptor, &context.expression, true)?;
4015 write!(self.out, ".cull_mask);")?;
4016
4017 write!(self.out, "{level}")?;
4018 self.put_expression(query, &context.expression, true)?;
4019 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
4020 }
4021 crate::RayQueryFunction::Proceed { result } => {
4022 write!(self.out, "{level}")?;
4023 let name = Baked(result).to_string();
4024 self.start_baking_expression(result, &context.expression, &name)?;
4025 self.named_expressions.insert(result, name);
4026 self.put_expression(query, &context.expression, true)?;
4027 writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
4028 if RAY_QUERY_MODERN_SUPPORT {
4029 write!(self.out, "{level}")?;
4030 self.put_expression(query, &context.expression, true)?;
4031 writeln!(self.out, ".?.next();")?;
4032 }
4033 }
4034 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
4035 if RAY_QUERY_MODERN_SUPPORT {
4036 write!(self.out, "{level}")?;
4037 self.put_expression(query, &context.expression, true)?;
4038 write!(self.out, ".?.commit_bounding_box_intersection(")?;
4039 self.put_expression(hit_t, &context.expression, true)?;
4040 writeln!(self.out, ");")?;
4041 } else {
4042 log::warn!("Ray Query GenerateIntersection is not yet supported");
4043 }
4044 }
4045 crate::RayQueryFunction::ConfirmIntersection => {
4046 if RAY_QUERY_MODERN_SUPPORT {
4047 write!(self.out, "{level}")?;
4048 self.put_expression(query, &context.expression, true)?;
4049 writeln!(self.out, ".?.commit_triangle_intersection();")?;
4050 } else {
4051 log::warn!("Ray Query ConfirmIntersection is not yet supported");
4052 }
4053 }
4054 crate::RayQueryFunction::Terminate => {
4055 if RAY_QUERY_MODERN_SUPPORT {
4056 write!(self.out, "{level}")?;
4057 self.put_expression(query, &context.expression, true)?;
4058 writeln!(self.out, ".?.abort();")?;
4059 }
4060 write!(self.out, "{level}")?;
4061 self.put_expression(query, &context.expression, true)?;
4062 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
4063 }
4064 }
4065 }
4066 crate::Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { .. }) => {
4068 unimplemented!()
4069 }
4070 crate::Statement::MeshFunction(
4071 crate::MeshFunction::SetVertex { .. }
4072 | crate::MeshFunction::SetPrimitive { .. },
4073 ) => unimplemented!(),
4074 crate::Statement::SubgroupBallot { result, predicate } => {
4075 write!(self.out, "{level}")?;
4076 let name = self.namer.call("");
4077 self.start_baking_expression(result, &context.expression, &name)?;
4078 self.named_expressions.insert(result, name);
4079 write!(
4080 self.out,
4081 "{NAMESPACE}::uint4((uint64_t){NAMESPACE}::simd_ballot("
4082 )?;
4083 if let Some(predicate) = predicate {
4084 self.put_expression(predicate, &context.expression, true)?;
4085 } else {
4086 write!(self.out, "true")?;
4087 }
4088 writeln!(self.out, "), 0, 0, 0);")?;
4089 }
4090 crate::Statement::SubgroupCollectiveOperation {
4091 op,
4092 collective_op,
4093 argument,
4094 result,
4095 } => {
4096 write!(self.out, "{level}")?;
4097 let name = self.namer.call("");
4098 self.start_baking_expression(result, &context.expression, &name)?;
4099 self.named_expressions.insert(result, name);
4100 match (collective_op, op) {
4101 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
4102 write!(self.out, "{NAMESPACE}::simd_all(")?
4103 }
4104 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
4105 write!(self.out, "{NAMESPACE}::simd_any(")?
4106 }
4107 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
4108 write!(self.out, "{NAMESPACE}::simd_sum(")?
4109 }
4110 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
4111 write!(self.out, "{NAMESPACE}::simd_product(")?
4112 }
4113 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
4114 write!(self.out, "{NAMESPACE}::simd_max(")?
4115 }
4116 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
4117 write!(self.out, "{NAMESPACE}::simd_min(")?
4118 }
4119 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
4120 write!(self.out, "{NAMESPACE}::simd_and(")?
4121 }
4122 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
4123 write!(self.out, "{NAMESPACE}::simd_or(")?
4124 }
4125 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
4126 write!(self.out, "{NAMESPACE}::simd_xor(")?
4127 }
4128 (
4129 crate::CollectiveOperation::ExclusiveScan,
4130 crate::SubgroupOperation::Add,
4131 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
4132 (
4133 crate::CollectiveOperation::ExclusiveScan,
4134 crate::SubgroupOperation::Mul,
4135 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
4136 (
4137 crate::CollectiveOperation::InclusiveScan,
4138 crate::SubgroupOperation::Add,
4139 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
4140 (
4141 crate::CollectiveOperation::InclusiveScan,
4142 crate::SubgroupOperation::Mul,
4143 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
4144 _ => unimplemented!(),
4145 }
4146 self.put_expression(argument, &context.expression, true)?;
4147 writeln!(self.out, ");")?;
4148 }
4149 crate::Statement::SubgroupGather {
4150 mode,
4151 argument,
4152 result,
4153 } => {
4154 write!(self.out, "{level}")?;
4155 let name = self.namer.call("");
4156 self.start_baking_expression(result, &context.expression, &name)?;
4157 self.named_expressions.insert(result, name);
4158 match mode {
4159 crate::GatherMode::BroadcastFirst => {
4160 write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
4161 }
4162 crate::GatherMode::Broadcast(_) => {
4163 write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
4164 }
4165 crate::GatherMode::Shuffle(_) => {
4166 write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
4167 }
4168 crate::GatherMode::ShuffleDown(_) => {
4169 write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
4170 }
4171 crate::GatherMode::ShuffleUp(_) => {
4172 write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
4173 }
4174 crate::GatherMode::ShuffleXor(_) => {
4175 write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
4176 }
4177 crate::GatherMode::QuadBroadcast(_) => {
4178 write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
4179 }
4180 crate::GatherMode::QuadSwap(_) => {
4181 write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?;
4182 }
4183 }
4184 self.put_expression(argument, &context.expression, true)?;
4185 match mode {
4186 crate::GatherMode::BroadcastFirst => {}
4187 crate::GatherMode::Broadcast(index)
4188 | crate::GatherMode::Shuffle(index)
4189 | crate::GatherMode::ShuffleDown(index)
4190 | crate::GatherMode::ShuffleUp(index)
4191 | crate::GatherMode::ShuffleXor(index)
4192 | crate::GatherMode::QuadBroadcast(index) => {
4193 write!(self.out, ", ")?;
4194 self.put_expression(index, &context.expression, true)?;
4195 }
4196 crate::GatherMode::QuadSwap(direction) => {
4197 write!(self.out, ", ")?;
4198 match direction {
4199 crate::Direction::X => {
4200 write!(self.out, "1u")?;
4201 }
4202 crate::Direction::Y => {
4203 write!(self.out, "2u")?;
4204 }
4205 crate::Direction::Diagonal => {
4206 write!(self.out, "3u")?;
4207 }
4208 }
4209 }
4210 }
4211 writeln!(self.out, ");")?;
4212 }
4213 }
4214 }
4215
4216 for statement in statements {
4219 if let crate::Statement::Emit(ref range) = *statement {
4220 for handle in range.clone() {
4221 self.named_expressions.shift_remove(&handle);
4222 }
4223 }
4224 }
4225 Ok(())
4226 }
4227
4228 fn put_store(
4229 &mut self,
4230 pointer: Handle<crate::Expression>,
4231 value: Handle<crate::Expression>,
4232 level: back::Level,
4233 context: &StatementContext,
4234 ) -> BackendResult {
4235 let policy = context.expression.choose_bounds_check_policy(pointer);
4236 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
4237 && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
4238 {
4239 writeln!(self.out, ") {{")?;
4240 self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
4241 writeln!(self.out, "{level}}}")?;
4242 } else {
4243 self.put_unchecked_store(pointer, value, policy, level, context)?;
4244 }
4245
4246 Ok(())
4247 }
4248
4249 fn put_unchecked_store(
4250 &mut self,
4251 pointer: Handle<crate::Expression>,
4252 value: Handle<crate::Expression>,
4253 policy: index::BoundsCheckPolicy,
4254 level: back::Level,
4255 context: &StatementContext,
4256 ) -> BackendResult {
4257 let is_atomic_pointer = context
4258 .expression
4259 .resolve_type(pointer)
4260 .is_atomic_pointer(&context.expression.module.types);
4261
4262 if is_atomic_pointer {
4263 write!(
4264 self.out,
4265 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4266 )?;
4267 self.put_access_chain(pointer, policy, &context.expression)?;
4268 write!(self.out, ", ")?;
4269 self.put_expression(value, &context.expression, true)?;
4270 writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
4271 } else {
4272 write!(self.out, "{level}")?;
4273 self.put_access_chain(pointer, policy, &context.expression)?;
4274 write!(self.out, " = ")?;
4275 self.put_expression(value, &context.expression, true)?;
4276 writeln!(self.out, ";")?;
4277 }
4278
4279 Ok(())
4280 }
4281
4282 pub fn write(
4283 &mut self,
4284 module: &crate::Module,
4285 info: &valid::ModuleInfo,
4286 options: &Options,
4287 pipeline_options: &PipelineOptions,
4288 ) -> Result<TranslationInfo, Error> {
4289 self.names.clear();
4290 self.namer.reset(
4291 module,
4292 &super::keywords::RESERVED_SET,
4293 proc::CaseInsensitiveKeywordSet::empty(),
4294 &[CLAMPED_LOD_LOAD_PREFIX],
4295 &mut self.names,
4296 );
4297 self.wrapped_functions.clear();
4298 self.struct_member_pads.clear();
4299
4300 writeln!(
4301 self.out,
4302 "// language: metal{}.{}",
4303 options.lang_version.0, options.lang_version.1
4304 )?;
4305 writeln!(self.out, "#include <metal_stdlib>")?;
4306 writeln!(self.out, "#include <simd/simd.h>")?;
4307 writeln!(self.out)?;
4308 writeln!(self.out, "using {NAMESPACE}::uint;")?;
4310
4311 let mut uses_ray_query = false;
4312 for (_, ty) in module.types.iter() {
4313 match ty.inner {
4314 crate::TypeInner::AccelerationStructure { .. } => {
4315 if options.lang_version < (2, 4) {
4316 return Err(Error::UnsupportedRayTracing);
4317 }
4318 }
4319 crate::TypeInner::RayQuery { .. } => {
4320 if options.lang_version < (2, 4) {
4321 return Err(Error::UnsupportedRayTracing);
4322 }
4323 uses_ray_query = true;
4324 }
4325 _ => (),
4326 }
4327 }
4328
4329 if module.special_types.ray_desc.is_some()
4330 || module.special_types.ray_intersection.is_some()
4331 {
4332 if options.lang_version < (2, 4) {
4333 return Err(Error::UnsupportedRayTracing);
4334 }
4335 }
4336
4337 if uses_ray_query {
4338 self.put_ray_query_type()?;
4339 }
4340
4341 if options
4342 .bounds_check_policies
4343 .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
4344 {
4345 self.put_default_constructible()?;
4346 }
4347 writeln!(self.out)?;
4348
4349 {
4350 let globals: Vec<Handle<crate::GlobalVariable>> = module
4353 .global_variables
4354 .iter()
4355 .filter(|&(_, var)| needs_array_length(var.ty, &module.types))
4356 .map(|(handle, _)| handle)
4357 .collect();
4358
4359 let mut buffer_indices = vec![];
4360 for vbm in &pipeline_options.vertex_buffer_mappings {
4361 buffer_indices.push(vbm.id);
4362 }
4363
4364 if !globals.is_empty() || !buffer_indices.is_empty() {
4365 writeln!(self.out, "struct _mslBufferSizes {{")?;
4366
4367 for global in globals {
4368 writeln!(
4369 self.out,
4370 "{}uint {};",
4371 back::INDENT,
4372 ArraySizeMember(global)
4373 )?;
4374 }
4375
4376 for idx in buffer_indices {
4377 writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?;
4378 }
4379
4380 writeln!(self.out, "}};")?;
4381 writeln!(self.out)?;
4382 }
4383 };
4384
4385 self.write_type_defs(module)?;
4386 self.write_global_constants(module, info)?;
4387 self.write_functions(module, info, options, pipeline_options)
4388 }
4389
4390 fn put_default_constructible(&mut self) -> BackendResult {
4403 let tab = back::INDENT;
4404 writeln!(self.out, "struct DefaultConstructible {{")?;
4405 writeln!(self.out, "{tab}template<typename T>")?;
4406 writeln!(self.out, "{tab}operator T() && {{")?;
4407 writeln!(self.out, "{tab}{tab}return T {{}};")?;
4408 writeln!(self.out, "{tab}}}")?;
4409 writeln!(self.out, "}};")?;
4410 Ok(())
4411 }
4412
4413 fn put_ray_query_type(&mut self) -> BackendResult {
4414 let tab = back::INDENT;
4415 writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
4416 let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
4417 writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
4418 writeln!(
4419 self.out,
4420 "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
4421 )?;
4422 writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
4423 writeln!(self.out, "}};")?;
4424 writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
4425 let v_triangle = back::RayIntersectionType::Triangle as u32;
4426 let v_bbox = back::RayIntersectionType::BoundingBox as u32;
4427 writeln!(
4428 self.out,
4429 "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
4430 )?;
4431 writeln!(
4432 self.out,
4433 "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
4434 )?;
4435 writeln!(self.out, "}}")?;
4436 Ok(())
4437 }
4438
4439 fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
4440 let mut generated_argument_buffer_wrapper = false;
4441 let mut generated_external_texture_wrapper = false;
4442 for (handle, ty) in module.types.iter() {
4443 match ty.inner {
4444 crate::TypeInner::BindingArray { .. } if !generated_argument_buffer_wrapper => {
4445 writeln!(self.out, "template <typename T>")?;
4446 writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
4447 writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
4448 writeln!(self.out, "}};")?;
4449 generated_argument_buffer_wrapper = true;
4450 }
4451 crate::TypeInner::Image {
4452 class: crate::ImageClass::External,
4453 ..
4454 } if !generated_external_texture_wrapper => {
4455 let params_ty_name = &self.names
4456 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
4457 writeln!(self.out, "struct {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {{")?;
4458 writeln!(
4459 self.out,
4460 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane0;",
4461 back::INDENT
4462 )?;
4463 writeln!(
4464 self.out,
4465 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane1;",
4466 back::INDENT
4467 )?;
4468 writeln!(
4469 self.out,
4470 "{}{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> plane2;",
4471 back::INDENT
4472 )?;
4473 writeln!(self.out, "{}{params_ty_name} params;", back::INDENT)?;
4474 writeln!(self.out, "}};")?;
4475 generated_external_texture_wrapper = true;
4476 }
4477 _ => {}
4478 }
4479
4480 if !ty.needs_alias() {
4481 continue;
4482 }
4483 let name = &self.names[&NameKey::Type(handle)];
4484 match ty.inner {
4485 crate::TypeInner::Array {
4499 base,
4500 size,
4501 stride: _,
4502 } => {
4503 let base_name = TypeContext {
4504 handle: base,
4505 gctx: module.to_ctx(),
4506 names: &self.names,
4507 access: crate::StorageAccess::empty(),
4508 first_time: false,
4509 };
4510
4511 match size.resolve(module.to_ctx())? {
4512 proc::IndexableLength::Known(size) => {
4513 writeln!(self.out, "struct {name} {{")?;
4514 writeln!(
4515 self.out,
4516 "{}{} {}[{}];",
4517 back::INDENT,
4518 base_name,
4519 WRAPPED_ARRAY_FIELD,
4520 size
4521 )?;
4522 writeln!(self.out, "}};")?;
4523 }
4524 proc::IndexableLength::Dynamic => {
4525 writeln!(self.out, "typedef {base_name} {name}[1];")?;
4526 }
4527 }
4528 }
4529 crate::TypeInner::Struct {
4530 ref members, span, ..
4531 } => {
4532 writeln!(self.out, "struct {name} {{")?;
4533 let mut last_offset = 0;
4534 for (index, member) in members.iter().enumerate() {
4535 if member.offset > last_offset {
4536 self.struct_member_pads.insert((handle, index as u32));
4537 let pad = member.offset - last_offset;
4538 writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
4539 }
4540 let ty_inner = &module.types[member.ty].inner;
4541 last_offset = member.offset + ty_inner.size(module.to_ctx());
4542
4543 let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
4544
4545 match should_pack_struct_member(members, span, index, module) {
4547 Some(scalar) => {
4548 writeln!(
4549 self.out,
4550 "{}{}::packed_{}3 {};",
4551 back::INDENT,
4552 NAMESPACE,
4553 scalar.to_msl_name(),
4554 member_name
4555 )?;
4556 }
4557 None => {
4558 let base_name = TypeContext {
4559 handle: member.ty,
4560 gctx: module.to_ctx(),
4561 names: &self.names,
4562 access: crate::StorageAccess::empty(),
4563 first_time: false,
4564 };
4565 writeln!(
4566 self.out,
4567 "{}{} {};",
4568 back::INDENT,
4569 base_name,
4570 member_name
4571 )?;
4572
4573 if let crate::TypeInner::Vector {
4575 size: crate::VectorSize::Tri,
4576 scalar,
4577 } = *ty_inner
4578 {
4579 last_offset += scalar.width as u32;
4580 }
4581 }
4582 }
4583 }
4584 if last_offset < span {
4585 let pad = span - last_offset;
4586 writeln!(
4587 self.out,
4588 "{}char _pad{}[{}];",
4589 back::INDENT,
4590 members.len(),
4591 pad
4592 )?;
4593 }
4594 writeln!(self.out, "}};")?;
4595 }
4596 _ => {
4597 let ty_name = TypeContext {
4598 handle,
4599 gctx: module.to_ctx(),
4600 names: &self.names,
4601 access: crate::StorageAccess::empty(),
4602 first_time: true,
4603 };
4604 writeln!(self.out, "typedef {ty_name} {name};")?;
4605 }
4606 }
4607 }
4608
4609 for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
4611 match type_key {
4612 &crate::PredeclaredType::ModfResult { size, scalar }
4613 | &crate::PredeclaredType::FrexpResult { size, scalar } => {
4614 let arg_type_name_owner;
4615 let arg_type_name = if let Some(size) = size {
4616 arg_type_name_owner = format!(
4617 "{NAMESPACE}::{}{}",
4618 if scalar.width == 8 { "double" } else { "float" },
4619 size as u8
4620 );
4621 &arg_type_name_owner
4622 } else if scalar.width == 8 {
4623 "double"
4624 } else {
4625 "float"
4626 };
4627
4628 let other_type_name_owner;
4629 let (defined_func_name, called_func_name, other_type_name) =
4630 if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
4631 (MODF_FUNCTION, "modf", arg_type_name)
4632 } else {
4633 let other_type_name = if let Some(size) = size {
4634 other_type_name_owner = format!("int{}", size as u8);
4635 &other_type_name_owner
4636 } else {
4637 "int"
4638 };
4639 (FREXP_FUNCTION, "frexp", other_type_name)
4640 };
4641
4642 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4643
4644 writeln!(self.out)?;
4645 writeln!(
4646 self.out,
4647 "{struct_name} {defined_func_name}({arg_type_name} arg) {{
4648 {other_type_name} other;
4649 {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
4650 return {struct_name}{{ fract, other }};
4651}}"
4652 )?;
4653 }
4654 &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
4655 let arg_type_name = scalar.to_msl_name();
4656 let called_func_name = "atomic_compare_exchange_weak_explicit";
4657 let defined_func_name = ATOMIC_COMP_EXCH_FUNCTION;
4658 let struct_name = &self.names[&NameKey::Type(*struct_ty)];
4659
4660 writeln!(self.out)?;
4661
4662 for address_space_name in ["device", "threadgroup"] {
4663 writeln!(
4664 self.out,
4665 "\
4666template <typename A>
4667{struct_name} {defined_func_name}(
4668 {address_space_name} A *atomic_ptr,
4669 {arg_type_name} cmp,
4670 {arg_type_name} v
4671) {{
4672 bool swapped = {NAMESPACE}::{called_func_name}(
4673 atomic_ptr, &cmp, v,
4674 metal::memory_order_relaxed, metal::memory_order_relaxed
4675 );
4676 return {struct_name}{{cmp, swapped}};
4677}}"
4678 )?;
4679 }
4680 }
4681 }
4682 }
4683
4684 Ok(())
4685 }
4686
4687 fn write_global_constants(
4689 &mut self,
4690 module: &crate::Module,
4691 mod_info: &valid::ModuleInfo,
4692 ) -> BackendResult {
4693 let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
4694
4695 for (handle, constant) in constants {
4696 let ty_name = TypeContext {
4697 handle: constant.ty,
4698 gctx: module.to_ctx(),
4699 names: &self.names,
4700 access: crate::StorageAccess::empty(),
4701 first_time: false,
4702 };
4703 let name = &self.names[&NameKey::Constant(handle)];
4704 write!(self.out, "constant {ty_name} {name} = ")?;
4705 self.put_const_expression(constant.init, module, mod_info, &module.global_expressions)?;
4706 writeln!(self.out, ";")?;
4707 }
4708
4709 Ok(())
4710 }
4711
4712 fn put_inline_sampler_properties(
4713 &mut self,
4714 level: back::Level,
4715 sampler: &sm::InlineSampler,
4716 ) -> BackendResult {
4717 for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
4718 writeln!(
4719 self.out,
4720 "{}{}::{}_address::{},",
4721 level,
4722 NAMESPACE,
4723 letter,
4724 address.as_str(),
4725 )?;
4726 }
4727 writeln!(
4728 self.out,
4729 "{}{}::mag_filter::{},",
4730 level,
4731 NAMESPACE,
4732 sampler.mag_filter.as_str(),
4733 )?;
4734 writeln!(
4735 self.out,
4736 "{}{}::min_filter::{},",
4737 level,
4738 NAMESPACE,
4739 sampler.min_filter.as_str(),
4740 )?;
4741 if let Some(filter) = sampler.mip_filter {
4742 writeln!(
4743 self.out,
4744 "{}{}::mip_filter::{},",
4745 level,
4746 NAMESPACE,
4747 filter.as_str(),
4748 )?;
4749 }
4750 if sampler.border_color != sm::BorderColor::TransparentBlack {
4752 writeln!(
4753 self.out,
4754 "{}{}::border_color::{},",
4755 level,
4756 NAMESPACE,
4757 sampler.border_color.as_str(),
4758 )?;
4759 }
4760 if false {
4764 if let Some(ref lod) = sampler.lod_clamp {
4765 writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
4766 }
4767 if let Some(aniso) = sampler.max_anisotropy {
4768 writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
4769 }
4770 }
4771 if sampler.compare_func != sm::CompareFunc::Never {
4772 writeln!(
4773 self.out,
4774 "{}{}::compare_func::{},",
4775 level,
4776 NAMESPACE,
4777 sampler.compare_func.as_str(),
4778 )?;
4779 }
4780 writeln!(
4781 self.out,
4782 "{}{}::coord::{}",
4783 level,
4784 NAMESPACE,
4785 sampler.coord.as_str()
4786 )?;
4787 Ok(())
4788 }
4789
4790 fn write_unpacking_function(
4791 &mut self,
4792 format: back::msl::VertexFormat,
4793 ) -> Result<(String, u32, u32), Error> {
4794 use back::msl::VertexFormat::*;
4795 match format {
4796 Uint8 => {
4797 let name = self.namer.call("unpackUint8");
4798 writeln!(self.out, "uint {name}(metal::uchar b0) {{")?;
4799 writeln!(self.out, "{}return uint(b0);", back::INDENT)?;
4800 writeln!(self.out, "}}")?;
4801 Ok((name, 1, 1))
4802 }
4803 Uint8x2 => {
4804 let name = self.namer.call("unpackUint8x2");
4805 writeln!(
4806 self.out,
4807 "metal::uint2 {name}(metal::uchar b0, \
4808 metal::uchar b1) {{"
4809 )?;
4810 writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?;
4811 writeln!(self.out, "}}")?;
4812 Ok((name, 2, 2))
4813 }
4814 Uint8x4 => {
4815 let name = self.namer.call("unpackUint8x4");
4816 writeln!(
4817 self.out,
4818 "metal::uint4 {name}(metal::uchar b0, \
4819 metal::uchar b1, \
4820 metal::uchar b2, \
4821 metal::uchar b3) {{"
4822 )?;
4823 writeln!(
4824 self.out,
4825 "{}return metal::uint4(b0, b1, b2, b3);",
4826 back::INDENT
4827 )?;
4828 writeln!(self.out, "}}")?;
4829 Ok((name, 4, 4))
4830 }
4831 Sint8 => {
4832 let name = self.namer.call("unpackSint8");
4833 writeln!(self.out, "int {name}(metal::uchar b0) {{")?;
4834 writeln!(self.out, "{}return int(as_type<char>(b0));", back::INDENT)?;
4835 writeln!(self.out, "}}")?;
4836 Ok((name, 1, 1))
4837 }
4838 Sint8x2 => {
4839 let name = self.namer.call("unpackSint8x2");
4840 writeln!(
4841 self.out,
4842 "metal::int2 {name}(metal::uchar b0, \
4843 metal::uchar b1) {{"
4844 )?;
4845 writeln!(
4846 self.out,
4847 "{}return metal::int2(as_type<char>(b0), \
4848 as_type<char>(b1));",
4849 back::INDENT
4850 )?;
4851 writeln!(self.out, "}}")?;
4852 Ok((name, 2, 2))
4853 }
4854 Sint8x4 => {
4855 let name = self.namer.call("unpackSint8x4");
4856 writeln!(
4857 self.out,
4858 "metal::int4 {name}(metal::uchar b0, \
4859 metal::uchar b1, \
4860 metal::uchar b2, \
4861 metal::uchar b3) {{"
4862 )?;
4863 writeln!(
4864 self.out,
4865 "{}return metal::int4(as_type<char>(b0), \
4866 as_type<char>(b1), \
4867 as_type<char>(b2), \
4868 as_type<char>(b3));",
4869 back::INDENT
4870 )?;
4871 writeln!(self.out, "}}")?;
4872 Ok((name, 4, 4))
4873 }
4874 Unorm8 => {
4875 let name = self.namer.call("unpackUnorm8");
4876 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4877 writeln!(
4878 self.out,
4879 "{}return float(float(b0) / 255.0f);",
4880 back::INDENT
4881 )?;
4882 writeln!(self.out, "}}")?;
4883 Ok((name, 1, 1))
4884 }
4885 Unorm8x2 => {
4886 let name = self.namer.call("unpackUnorm8x2");
4887 writeln!(
4888 self.out,
4889 "metal::float2 {name}(metal::uchar b0, \
4890 metal::uchar b1) {{"
4891 )?;
4892 writeln!(
4893 self.out,
4894 "{}return metal::float2(float(b0) / 255.0f, \
4895 float(b1) / 255.0f);",
4896 back::INDENT
4897 )?;
4898 writeln!(self.out, "}}")?;
4899 Ok((name, 2, 2))
4900 }
4901 Unorm8x4 => {
4902 let name = self.namer.call("unpackUnorm8x4");
4903 writeln!(
4904 self.out,
4905 "metal::float4 {name}(metal::uchar b0, \
4906 metal::uchar b1, \
4907 metal::uchar b2, \
4908 metal::uchar b3) {{"
4909 )?;
4910 writeln!(
4911 self.out,
4912 "{}return metal::float4(float(b0) / 255.0f, \
4913 float(b1) / 255.0f, \
4914 float(b2) / 255.0f, \
4915 float(b3) / 255.0f);",
4916 back::INDENT
4917 )?;
4918 writeln!(self.out, "}}")?;
4919 Ok((name, 4, 4))
4920 }
4921 Snorm8 => {
4922 let name = self.namer.call("unpackSnorm8");
4923 writeln!(self.out, "float {name}(metal::uchar b0) {{")?;
4924 writeln!(
4925 self.out,
4926 "{}return float(metal::max(-1.0f, as_type<char>(b0) / 127.0f));",
4927 back::INDENT
4928 )?;
4929 writeln!(self.out, "}}")?;
4930 Ok((name, 1, 1))
4931 }
4932 Snorm8x2 => {
4933 let name = self.namer.call("unpackSnorm8x2");
4934 writeln!(
4935 self.out,
4936 "metal::float2 {name}(metal::uchar b0, \
4937 metal::uchar b1) {{"
4938 )?;
4939 writeln!(
4940 self.out,
4941 "{}return metal::float2(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4942 metal::max(-1.0f, as_type<char>(b1) / 127.0f));",
4943 back::INDENT
4944 )?;
4945 writeln!(self.out, "}}")?;
4946 Ok((name, 2, 2))
4947 }
4948 Snorm8x4 => {
4949 let name = self.namer.call("unpackSnorm8x4");
4950 writeln!(
4951 self.out,
4952 "metal::float4 {name}(metal::uchar b0, \
4953 metal::uchar b1, \
4954 metal::uchar b2, \
4955 metal::uchar b3) {{"
4956 )?;
4957 writeln!(
4958 self.out,
4959 "{}return metal::float4(metal::max(-1.0f, as_type<char>(b0) / 127.0f), \
4960 metal::max(-1.0f, as_type<char>(b1) / 127.0f), \
4961 metal::max(-1.0f, as_type<char>(b2) / 127.0f), \
4962 metal::max(-1.0f, as_type<char>(b3) / 127.0f));",
4963 back::INDENT
4964 )?;
4965 writeln!(self.out, "}}")?;
4966 Ok((name, 4, 4))
4967 }
4968 Uint16 => {
4969 let name = self.namer.call("unpackUint16");
4970 writeln!(
4971 self.out,
4972 "metal::uint {name}(metal::uint b0, \
4973 metal::uint b1) {{"
4974 )?;
4975 writeln!(
4976 self.out,
4977 "{}return metal::uint(b1 << 8 | b0);",
4978 back::INDENT
4979 )?;
4980 writeln!(self.out, "}}")?;
4981 Ok((name, 2, 1))
4982 }
4983 Uint16x2 => {
4984 let name = self.namer.call("unpackUint16x2");
4985 writeln!(
4986 self.out,
4987 "metal::uint2 {name}(metal::uint b0, \
4988 metal::uint b1, \
4989 metal::uint b2, \
4990 metal::uint b3) {{"
4991 )?;
4992 writeln!(
4993 self.out,
4994 "{}return metal::uint2(b1 << 8 | b0, \
4995 b3 << 8 | b2);",
4996 back::INDENT
4997 )?;
4998 writeln!(self.out, "}}")?;
4999 Ok((name, 4, 2))
5000 }
5001 Uint16x4 => {
5002 let name = self.namer.call("unpackUint16x4");
5003 writeln!(
5004 self.out,
5005 "metal::uint4 {name}(metal::uint b0, \
5006 metal::uint b1, \
5007 metal::uint b2, \
5008 metal::uint b3, \
5009 metal::uint b4, \
5010 metal::uint b5, \
5011 metal::uint b6, \
5012 metal::uint b7) {{"
5013 )?;
5014 writeln!(
5015 self.out,
5016 "{}return metal::uint4(b1 << 8 | b0, \
5017 b3 << 8 | b2, \
5018 b5 << 8 | b4, \
5019 b7 << 8 | b6);",
5020 back::INDENT
5021 )?;
5022 writeln!(self.out, "}}")?;
5023 Ok((name, 8, 4))
5024 }
5025 Sint16 => {
5026 let name = self.namer.call("unpackSint16");
5027 writeln!(
5028 self.out,
5029 "int {name}(metal::ushort b0, \
5030 metal::ushort b1) {{"
5031 )?;
5032 writeln!(
5033 self.out,
5034 "{}return int(as_type<short>(metal::ushort(b1 << 8 | b0)));",
5035 back::INDENT
5036 )?;
5037 writeln!(self.out, "}}")?;
5038 Ok((name, 2, 1))
5039 }
5040 Sint16x2 => {
5041 let name = self.namer.call("unpackSint16x2");
5042 writeln!(
5043 self.out,
5044 "metal::int2 {name}(metal::ushort b0, \
5045 metal::ushort b1, \
5046 metal::ushort b2, \
5047 metal::ushort b3) {{"
5048 )?;
5049 writeln!(
5050 self.out,
5051 "{}return metal::int2(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5052 as_type<short>(metal::ushort(b3 << 8 | b2)));",
5053 back::INDENT
5054 )?;
5055 writeln!(self.out, "}}")?;
5056 Ok((name, 4, 2))
5057 }
5058 Sint16x4 => {
5059 let name = self.namer.call("unpackSint16x4");
5060 writeln!(
5061 self.out,
5062 "metal::int4 {name}(metal::ushort b0, \
5063 metal::ushort b1, \
5064 metal::ushort b2, \
5065 metal::ushort b3, \
5066 metal::ushort b4, \
5067 metal::ushort b5, \
5068 metal::ushort b6, \
5069 metal::ushort b7) {{"
5070 )?;
5071 writeln!(
5072 self.out,
5073 "{}return metal::int4(as_type<short>(metal::ushort(b1 << 8 | b0)), \
5074 as_type<short>(metal::ushort(b3 << 8 | b2)), \
5075 as_type<short>(metal::ushort(b5 << 8 | b4)), \
5076 as_type<short>(metal::ushort(b7 << 8 | b6)));",
5077 back::INDENT
5078 )?;
5079 writeln!(self.out, "}}")?;
5080 Ok((name, 8, 4))
5081 }
5082 Unorm16 => {
5083 let name = self.namer.call("unpackUnorm16");
5084 writeln!(
5085 self.out,
5086 "float {name}(metal::ushort b0, \
5087 metal::ushort b1) {{"
5088 )?;
5089 writeln!(
5090 self.out,
5091 "{}return float(float(b1 << 8 | b0) / 65535.0f);",
5092 back::INDENT
5093 )?;
5094 writeln!(self.out, "}}")?;
5095 Ok((name, 2, 1))
5096 }
5097 Unorm16x2 => {
5098 let name = self.namer.call("unpackUnorm16x2");
5099 writeln!(
5100 self.out,
5101 "metal::float2 {name}(metal::ushort b0, \
5102 metal::ushort b1, \
5103 metal::ushort b2, \
5104 metal::ushort b3) {{"
5105 )?;
5106 writeln!(
5107 self.out,
5108 "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \
5109 float(b3 << 8 | b2) / 65535.0f);",
5110 back::INDENT
5111 )?;
5112 writeln!(self.out, "}}")?;
5113 Ok((name, 4, 2))
5114 }
5115 Unorm16x4 => {
5116 let name = self.namer.call("unpackUnorm16x4");
5117 writeln!(
5118 self.out,
5119 "metal::float4 {name}(metal::ushort b0, \
5120 metal::ushort b1, \
5121 metal::ushort b2, \
5122 metal::ushort b3, \
5123 metal::ushort b4, \
5124 metal::ushort b5, \
5125 metal::ushort b6, \
5126 metal::ushort b7) {{"
5127 )?;
5128 writeln!(
5129 self.out,
5130 "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \
5131 float(b3 << 8 | b2) / 65535.0f, \
5132 float(b5 << 8 | b4) / 65535.0f, \
5133 float(b7 << 8 | b6) / 65535.0f);",
5134 back::INDENT
5135 )?;
5136 writeln!(self.out, "}}")?;
5137 Ok((name, 8, 4))
5138 }
5139 Snorm16 => {
5140 let name = self.namer.call("unpackSnorm16");
5141 writeln!(
5142 self.out,
5143 "float {name}(metal::ushort b0, \
5144 metal::ushort b1) {{"
5145 )?;
5146 writeln!(
5147 self.out,
5148 "{}return metal::unpack_snorm2x16_to_float(b1 << 8 | b0).x;",
5149 back::INDENT
5150 )?;
5151 writeln!(self.out, "}}")?;
5152 Ok((name, 2, 1))
5153 }
5154 Snorm16x2 => {
5155 let name = self.namer.call("unpackSnorm16x2");
5156 writeln!(
5157 self.out,
5158 "metal::float2 {name}(uint b0, \
5159 uint b1, \
5160 uint b2, \
5161 uint b3) {{"
5162 )?;
5163 writeln!(
5164 self.out,
5165 "{}return metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5166 back::INDENT
5167 )?;
5168 writeln!(self.out, "}}")?;
5169 Ok((name, 4, 2))
5170 }
5171 Snorm16x4 => {
5172 let name = self.namer.call("unpackSnorm16x4");
5173 writeln!(
5174 self.out,
5175 "metal::float4 {name}(uint b0, \
5176 uint b1, \
5177 uint b2, \
5178 uint b3, \
5179 uint b4, \
5180 uint b5, \
5181 uint b6, \
5182 uint b7) {{"
5183 )?;
5184 writeln!(
5185 self.out,
5186 "{}return metal::float4(metal::unpack_snorm2x16_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5187 metal::unpack_snorm2x16_to_float(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5188 back::INDENT
5189 )?;
5190 writeln!(self.out, "}}")?;
5191 Ok((name, 8, 4))
5192 }
5193 Float16 => {
5194 let name = self.namer.call("unpackFloat16");
5195 writeln!(
5196 self.out,
5197 "float {name}(metal::ushort b0, \
5198 metal::ushort b1) {{"
5199 )?;
5200 writeln!(
5201 self.out,
5202 "{}return float(as_type<half>(metal::ushort(b1 << 8 | b0)));",
5203 back::INDENT
5204 )?;
5205 writeln!(self.out, "}}")?;
5206 Ok((name, 2, 1))
5207 }
5208 Float16x2 => {
5209 let name = self.namer.call("unpackFloat16x2");
5210 writeln!(
5211 self.out,
5212 "metal::float2 {name}(metal::ushort b0, \
5213 metal::ushort b1, \
5214 metal::ushort b2, \
5215 metal::ushort b3) {{"
5216 )?;
5217 writeln!(
5218 self.out,
5219 "{}return metal::float2(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5220 as_type<half>(metal::ushort(b3 << 8 | b2)));",
5221 back::INDENT
5222 )?;
5223 writeln!(self.out, "}}")?;
5224 Ok((name, 4, 2))
5225 }
5226 Float16x4 => {
5227 let name = self.namer.call("unpackFloat16x4");
5228 writeln!(
5229 self.out,
5230 "metal::float4 {name}(metal::ushort b0, \
5231 metal::ushort b1, \
5232 metal::ushort b2, \
5233 metal::ushort b3, \
5234 metal::ushort b4, \
5235 metal::ushort b5, \
5236 metal::ushort b6, \
5237 metal::ushort b7) {{"
5238 )?;
5239 writeln!(
5240 self.out,
5241 "{}return metal::float4(as_type<half>(metal::ushort(b1 << 8 | b0)), \
5242 as_type<half>(metal::ushort(b3 << 8 | b2)), \
5243 as_type<half>(metal::ushort(b5 << 8 | b4)), \
5244 as_type<half>(metal::ushort(b7 << 8 | b6)));",
5245 back::INDENT
5246 )?;
5247 writeln!(self.out, "}}")?;
5248 Ok((name, 8, 4))
5249 }
5250 Float32 => {
5251 let name = self.namer.call("unpackFloat32");
5252 writeln!(
5253 self.out,
5254 "float {name}(uint b0, \
5255 uint b1, \
5256 uint b2, \
5257 uint b3) {{"
5258 )?;
5259 writeln!(
5260 self.out,
5261 "{}return as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5262 back::INDENT
5263 )?;
5264 writeln!(self.out, "}}")?;
5265 Ok((name, 4, 1))
5266 }
5267 Float32x2 => {
5268 let name = self.namer.call("unpackFloat32x2");
5269 writeln!(
5270 self.out,
5271 "metal::float2 {name}(uint b0, \
5272 uint b1, \
5273 uint b2, \
5274 uint b3, \
5275 uint b4, \
5276 uint b5, \
5277 uint b6, \
5278 uint b7) {{"
5279 )?;
5280 writeln!(
5281 self.out,
5282 "{}return metal::float2(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5283 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5284 back::INDENT
5285 )?;
5286 writeln!(self.out, "}}")?;
5287 Ok((name, 8, 2))
5288 }
5289 Float32x3 => {
5290 let name = self.namer.call("unpackFloat32x3");
5291 writeln!(
5292 self.out,
5293 "metal::float3 {name}(uint b0, \
5294 uint b1, \
5295 uint b2, \
5296 uint b3, \
5297 uint b4, \
5298 uint b5, \
5299 uint b6, \
5300 uint b7, \
5301 uint b8, \
5302 uint b9, \
5303 uint b10, \
5304 uint b11) {{"
5305 )?;
5306 writeln!(
5307 self.out,
5308 "{}return metal::float3(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5309 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5310 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5311 back::INDENT
5312 )?;
5313 writeln!(self.out, "}}")?;
5314 Ok((name, 12, 3))
5315 }
5316 Float32x4 => {
5317 let name = self.namer.call("unpackFloat32x4");
5318 writeln!(
5319 self.out,
5320 "metal::float4 {name}(uint b0, \
5321 uint b1, \
5322 uint b2, \
5323 uint b3, \
5324 uint b4, \
5325 uint b5, \
5326 uint b6, \
5327 uint b7, \
5328 uint b8, \
5329 uint b9, \
5330 uint b10, \
5331 uint b11, \
5332 uint b12, \
5333 uint b13, \
5334 uint b14, \
5335 uint b15) {{"
5336 )?;
5337 writeln!(
5338 self.out,
5339 "{}return metal::float4(as_type<float>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5340 as_type<float>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5341 as_type<float>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5342 as_type<float>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5343 back::INDENT
5344 )?;
5345 writeln!(self.out, "}}")?;
5346 Ok((name, 16, 4))
5347 }
5348 Uint32 => {
5349 let name = self.namer.call("unpackUint32");
5350 writeln!(
5351 self.out,
5352 "uint {name}(uint b0, \
5353 uint b1, \
5354 uint b2, \
5355 uint b3) {{"
5356 )?;
5357 writeln!(
5358 self.out,
5359 "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5360 back::INDENT
5361 )?;
5362 writeln!(self.out, "}}")?;
5363 Ok((name, 4, 1))
5364 }
5365 Uint32x2 => {
5366 let name = self.namer.call("unpackUint32x2");
5367 writeln!(
5368 self.out,
5369 "uint2 {name}(uint b0, \
5370 uint b1, \
5371 uint b2, \
5372 uint b3, \
5373 uint b4, \
5374 uint b5, \
5375 uint b6, \
5376 uint b7) {{"
5377 )?;
5378 writeln!(
5379 self.out,
5380 "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5381 (b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5382 back::INDENT
5383 )?;
5384 writeln!(self.out, "}}")?;
5385 Ok((name, 8, 2))
5386 }
5387 Uint32x3 => {
5388 let name = self.namer.call("unpackUint32x3");
5389 writeln!(
5390 self.out,
5391 "uint3 {name}(uint b0, \
5392 uint b1, \
5393 uint b2, \
5394 uint b3, \
5395 uint b4, \
5396 uint b5, \
5397 uint b6, \
5398 uint b7, \
5399 uint b8, \
5400 uint b9, \
5401 uint b10, \
5402 uint b11) {{"
5403 )?;
5404 writeln!(
5405 self.out,
5406 "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5407 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5408 (b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5409 back::INDENT
5410 )?;
5411 writeln!(self.out, "}}")?;
5412 Ok((name, 12, 3))
5413 }
5414 Uint32x4 => {
5415 let name = self.namer.call("unpackUint32x4");
5416 writeln!(
5417 self.out,
5418 "{NAMESPACE}::uint4 {name}(uint b0, \
5419 uint b1, \
5420 uint b2, \
5421 uint b3, \
5422 uint b4, \
5423 uint b5, \
5424 uint b6, \
5425 uint b7, \
5426 uint b8, \
5427 uint b9, \
5428 uint b10, \
5429 uint b11, \
5430 uint b12, \
5431 uint b13, \
5432 uint b14, \
5433 uint b15) {{"
5434 )?;
5435 writeln!(
5436 self.out,
5437 "{}return {NAMESPACE}::uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5438 (b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5439 (b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5440 (b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5441 back::INDENT
5442 )?;
5443 writeln!(self.out, "}}")?;
5444 Ok((name, 16, 4))
5445 }
5446 Sint32 => {
5447 let name = self.namer.call("unpackSint32");
5448 writeln!(
5449 self.out,
5450 "int {name}(uint b0, \
5451 uint b1, \
5452 uint b2, \
5453 uint b3) {{"
5454 )?;
5455 writeln!(
5456 self.out,
5457 "{}return as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5458 back::INDENT
5459 )?;
5460 writeln!(self.out, "}}")?;
5461 Ok((name, 4, 1))
5462 }
5463 Sint32x2 => {
5464 let name = self.namer.call("unpackSint32x2");
5465 writeln!(
5466 self.out,
5467 "metal::int2 {name}(uint b0, \
5468 uint b1, \
5469 uint b2, \
5470 uint b3, \
5471 uint b4, \
5472 uint b5, \
5473 uint b6, \
5474 uint b7) {{"
5475 )?;
5476 writeln!(
5477 self.out,
5478 "{}return metal::int2(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5479 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4));",
5480 back::INDENT
5481 )?;
5482 writeln!(self.out, "}}")?;
5483 Ok((name, 8, 2))
5484 }
5485 Sint32x3 => {
5486 let name = self.namer.call("unpackSint32x3");
5487 writeln!(
5488 self.out,
5489 "metal::int3 {name}(uint b0, \
5490 uint b1, \
5491 uint b2, \
5492 uint b3, \
5493 uint b4, \
5494 uint b5, \
5495 uint b6, \
5496 uint b7, \
5497 uint b8, \
5498 uint b9, \
5499 uint b10, \
5500 uint b11) {{"
5501 )?;
5502 writeln!(
5503 self.out,
5504 "{}return metal::int3(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5505 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5506 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8));",
5507 back::INDENT
5508 )?;
5509 writeln!(self.out, "}}")?;
5510 Ok((name, 12, 3))
5511 }
5512 Sint32x4 => {
5513 let name = self.namer.call("unpackSint32x4");
5514 writeln!(
5515 self.out,
5516 "metal::int4 {name}(uint b0, \
5517 uint b1, \
5518 uint b2, \
5519 uint b3, \
5520 uint b4, \
5521 uint b5, \
5522 uint b6, \
5523 uint b7, \
5524 uint b8, \
5525 uint b9, \
5526 uint b10, \
5527 uint b11, \
5528 uint b12, \
5529 uint b13, \
5530 uint b14, \
5531 uint b15) {{"
5532 )?;
5533 writeln!(
5534 self.out,
5535 "{}return metal::int4(as_type<int>(b3 << 24 | b2 << 16 | b1 << 8 | b0), \
5536 as_type<int>(b7 << 24 | b6 << 16 | b5 << 8 | b4), \
5537 as_type<int>(b11 << 24 | b10 << 16 | b9 << 8 | b8), \
5538 as_type<int>(b15 << 24 | b14 << 16 | b13 << 8 | b12));",
5539 back::INDENT
5540 )?;
5541 writeln!(self.out, "}}")?;
5542 Ok((name, 16, 4))
5543 }
5544 Unorm10_10_10_2 => {
5545 let name = self.namer.call("unpackUnorm10_10_10_2");
5546 writeln!(
5547 self.out,
5548 "metal::float4 {name}(uint b0, \
5549 uint b1, \
5550 uint b2, \
5551 uint b3) {{"
5552 )?;
5553 writeln!(
5554 self.out,
5555 "{}return metal::unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);",
5567 back::INDENT
5568 )?;
5569 writeln!(self.out, "}}")?;
5570 Ok((name, 4, 4))
5571 }
5572 Unorm8x4Bgra => {
5573 let name = self.namer.call("unpackUnorm8x4Bgra");
5574 writeln!(
5575 self.out,
5576 "metal::float4 {name}(metal::uchar b0, \
5577 metal::uchar b1, \
5578 metal::uchar b2, \
5579 metal::uchar b3) {{"
5580 )?;
5581 writeln!(
5582 self.out,
5583 "{}return metal::float4(float(b2) / 255.0f, \
5584 float(b1) / 255.0f, \
5585 float(b0) / 255.0f, \
5586 float(b3) / 255.0f);",
5587 back::INDENT
5588 )?;
5589 writeln!(self.out, "}}")?;
5590 Ok((name, 4, 4))
5591 }
5592 }
5593 }
5594
5595 fn write_wrapped_unary_op(
5596 &mut self,
5597 module: &crate::Module,
5598 func_ctx: &back::FunctionCtx,
5599 op: crate::UnaryOperator,
5600 operand: Handle<crate::Expression>,
5601 ) -> BackendResult {
5602 let operand_ty = func_ctx.resolve_type(operand, &module.types);
5603 match op {
5604 crate::UnaryOperator::Negate
5611 if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
5612 {
5613 let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
5614 return Ok(());
5615 };
5616 let wrapped = WrappedFunction::UnaryOp {
5617 op,
5618 ty: (vector_size, scalar),
5619 };
5620 if !self.wrapped_functions.insert(wrapped) {
5621 return Ok(());
5622 }
5623
5624 let unsigned_scalar = crate::Scalar {
5625 kind: crate::ScalarKind::Uint,
5626 ..scalar
5627 };
5628 let mut type_name = String::new();
5629 let mut unsigned_type_name = String::new();
5630 match vector_size {
5631 None => {
5632 put_numeric_type(&mut type_name, scalar, &[])?;
5633 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5634 }
5635 Some(size) => {
5636 put_numeric_type(&mut type_name, scalar, &[size])?;
5637 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5638 }
5639 };
5640
5641 writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
5642 let level = back::Level(1);
5643 writeln!(
5644 self.out,
5645 "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
5646 )?;
5647 writeln!(self.out, "}}")?;
5648 writeln!(self.out)?;
5649 }
5650 _ => {}
5651 }
5652 Ok(())
5653 }
5654
5655 fn write_wrapped_binary_op(
5656 &mut self,
5657 module: &crate::Module,
5658 func_ctx: &back::FunctionCtx,
5659 expr: Handle<crate::Expression>,
5660 op: crate::BinaryOperator,
5661 left: Handle<crate::Expression>,
5662 right: Handle<crate::Expression>,
5663 ) -> BackendResult {
5664 let expr_ty = func_ctx.resolve_type(expr, &module.types);
5665 let left_ty = func_ctx.resolve_type(left, &module.types);
5666 let right_ty = func_ctx.resolve_type(right, &module.types);
5667 match (op, expr_ty.scalar_kind()) {
5668 (
5675 crate::BinaryOperator::Divide,
5676 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5677 ) => {
5678 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5679 return Ok(());
5680 };
5681 let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
5682 return Ok(());
5683 };
5684 let wrapped = WrappedFunction::BinaryOp {
5685 op,
5686 left_ty: left_wrapped_ty,
5687 right_ty: right_wrapped_ty,
5688 };
5689 if !self.wrapped_functions.insert(wrapped) {
5690 return Ok(());
5691 }
5692
5693 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5694 return Ok(());
5695 };
5696 let mut type_name = String::new();
5697 match vector_size {
5698 None => put_numeric_type(&mut type_name, scalar, &[])?,
5699 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5700 };
5701 writeln!(
5702 self.out,
5703 "{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5704 )?;
5705 let level = back::Level(1);
5706 match scalar.kind {
5707 crate::ScalarKind::Sint => {
5708 let min_val = match scalar.width {
5709 4 => crate::Literal::I32(i32::MIN),
5710 8 => crate::Literal::I64(i64::MIN),
5711 _ => {
5712 return Err(Error::GenericValidation(format!(
5713 "Unexpected width for scalar {scalar:?}"
5714 )));
5715 }
5716 };
5717 write!(
5718 self.out,
5719 "{level}return lhs / metal::select(rhs, 1, (lhs == "
5720 )?;
5721 self.put_literal(min_val)?;
5722 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
5723 }
5724 crate::ScalarKind::Uint => writeln!(
5725 self.out,
5726 "{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
5727 )?,
5728 _ => unreachable!(),
5729 }
5730 writeln!(self.out, "}}")?;
5731 writeln!(self.out)?;
5732 }
5733 (
5746 crate::BinaryOperator::Modulo,
5747 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
5748 ) => {
5749 let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
5750 return Ok(());
5751 };
5752 let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
5753 else {
5754 return Ok(());
5755 };
5756 let wrapped = WrappedFunction::BinaryOp {
5757 op,
5758 left_ty: left_wrapped_ty,
5759 right_ty: (right_vector_size, right_scalar),
5760 };
5761 if !self.wrapped_functions.insert(wrapped) {
5762 return Ok(());
5763 }
5764
5765 let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
5766 return Ok(());
5767 };
5768 let mut type_name = String::new();
5769 match vector_size {
5770 None => put_numeric_type(&mut type_name, scalar, &[])?,
5771 Some(size) => put_numeric_type(&mut type_name, scalar, &[size])?,
5772 };
5773 let mut rhs_type_name = String::new();
5774 match right_vector_size {
5775 None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
5776 Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
5777 };
5778
5779 writeln!(
5780 self.out,
5781 "{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
5782 )?;
5783 let level = back::Level(1);
5784 match scalar.kind {
5785 crate::ScalarKind::Sint => {
5786 let min_val = match scalar.width {
5787 4 => crate::Literal::I32(i32::MIN),
5788 8 => crate::Literal::I64(i64::MIN),
5789 _ => {
5790 return Err(Error::GenericValidation(format!(
5791 "Unexpected width for scalar {scalar:?}"
5792 )));
5793 }
5794 };
5795 write!(
5796 self.out,
5797 "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
5798 )?;
5799 self.put_literal(min_val)?;
5800 writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
5801 writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
5802 }
5803 crate::ScalarKind::Uint => writeln!(
5804 self.out,
5805 "{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
5806 )?,
5807 _ => unreachable!(),
5808 }
5809 writeln!(self.out, "}}")?;
5810 writeln!(self.out)?;
5811 }
5812 _ => {}
5813 }
5814 Ok(())
5815 }
5816
5817 #[allow(clippy::too_many_arguments)]
5818 fn write_wrapped_math_function(
5819 &mut self,
5820 module: &crate::Module,
5821 func_ctx: &back::FunctionCtx,
5822 fun: crate::MathFunction,
5823 arg: Handle<crate::Expression>,
5824 _arg1: Option<Handle<crate::Expression>>,
5825 _arg2: Option<Handle<crate::Expression>>,
5826 _arg3: Option<Handle<crate::Expression>>,
5827 ) -> BackendResult {
5828 let arg_ty = func_ctx.resolve_type(arg, &module.types);
5829 match fun {
5830 crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
5838 let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
5839 return Ok(());
5840 };
5841 let wrapped = WrappedFunction::Math {
5842 fun,
5843 arg_ty: (vector_size, scalar),
5844 };
5845 if !self.wrapped_functions.insert(wrapped) {
5846 return Ok(());
5847 }
5848
5849 let unsigned_scalar = crate::Scalar {
5850 kind: crate::ScalarKind::Uint,
5851 ..scalar
5852 };
5853 let mut type_name = String::new();
5854 let mut unsigned_type_name = String::new();
5855 match vector_size {
5856 None => {
5857 put_numeric_type(&mut type_name, scalar, &[])?;
5858 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[])?
5859 }
5860 Some(size) => {
5861 put_numeric_type(&mut type_name, scalar, &[size])?;
5862 put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
5863 }
5864 };
5865
5866 writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
5867 let level = back::Level(1);
5868 writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
5869 writeln!(self.out, "}}")?;
5870 writeln!(self.out)?;
5871 }
5872 _ => {}
5873 }
5874 Ok(())
5875 }
5876
5877 fn write_wrapped_cast(
5878 &mut self,
5879 module: &crate::Module,
5880 func_ctx: &back::FunctionCtx,
5881 expr: Handle<crate::Expression>,
5882 kind: crate::ScalarKind,
5883 convert: Option<crate::Bytes>,
5884 ) -> BackendResult {
5885 let src_ty = func_ctx.resolve_type(expr, &module.types);
5896 let Some(width) = convert else {
5897 return Ok(());
5898 };
5899 let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
5900 return Ok(());
5901 };
5902 let dst_scalar = crate::Scalar { kind, width };
5903 if src_scalar.kind != crate::ScalarKind::Float
5904 || (dst_scalar.kind != crate::ScalarKind::Sint
5905 && dst_scalar.kind != crate::ScalarKind::Uint)
5906 {
5907 return Ok(());
5908 }
5909 let wrapped = WrappedFunction::Cast {
5910 src_scalar,
5911 vector_size,
5912 dst_scalar,
5913 };
5914 if !self.wrapped_functions.insert(wrapped) {
5915 return Ok(());
5916 }
5917 let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
5918
5919 let mut src_type_name = String::new();
5920 match vector_size {
5921 None => put_numeric_type(&mut src_type_name, src_scalar, &[])?,
5922 Some(size) => put_numeric_type(&mut src_type_name, src_scalar, &[size])?,
5923 };
5924 let mut dst_type_name = String::new();
5925 match vector_size {
5926 None => put_numeric_type(&mut dst_type_name, dst_scalar, &[])?,
5927 Some(size) => put_numeric_type(&mut dst_type_name, dst_scalar, &[size])?,
5928 };
5929 let fun_name = match dst_scalar {
5930 crate::Scalar::I32 => F2I32_FUNCTION,
5931 crate::Scalar::U32 => F2U32_FUNCTION,
5932 crate::Scalar::I64 => F2I64_FUNCTION,
5933 crate::Scalar::U64 => F2U64_FUNCTION,
5934 _ => unreachable!(),
5935 };
5936
5937 writeln!(
5938 self.out,
5939 "{dst_type_name} {fun_name}({src_type_name} value) {{"
5940 )?;
5941 let level = back::Level(1);
5942 write!(
5943 self.out,
5944 "{level}return static_cast<{dst_type_name}>({NAMESPACE}::clamp(value, "
5945 )?;
5946 self.put_literal(min)?;
5947 write!(self.out, ", ")?;
5948 self.put_literal(max)?;
5949 writeln!(self.out, "));")?;
5950 writeln!(self.out, "}}")?;
5951 writeln!(self.out)?;
5952 Ok(())
5953 }
5954
    /// Writes MSL statements that convert a Y value and a UV pair to RGB in
    /// the destination color space and `return` it as `float4(rgb, 1.0)`.
    ///
    /// `y`, `uv` and `params` are MSL expressions spliced verbatim into the
    /// output. `params` must refer to a value that provides
    /// `yuv_conversion_matrix`, `src_tf` (fields `a`, `b`, `g`, `k`),
    /// `gamut_conversion_matrix` and `dst_tf` (fields `a`, `b`, `g`, `k`).
    /// Statements are emitted at `level`. The argument lines of the multi-line
    /// `select` calls are emitted one level deeper.
    fn write_convert_yuv_to_rgb_and_return(
        &mut self,
        level: back::Level,
        y: &str,
        uv: &str,
        params: &str,
    ) -> BackendResult {
        let l1 = level;
        let l2 = l1.next();

        // YUV -> gamma-encoded RGB in the source color space.
        writeln!(
            self.out,
            "{l1}float3 srcGammaRgb = ({params}.yuv_conversion_matrix * float4({y}, {uv}, 1.0)).rgb;"
        )?;

        // Undo the source transfer function. `select(a, b, c)` picks `b` where
        // `c` holds, so values below `k * b` take the linear segment
        // (divide by `k`) and the rest take the power curve.
        writeln!(self.out, "{l1}float3 srcLinearRgb = {NAMESPACE}::select(")?;
        writeln!(self.out, "{l2}{NAMESPACE}::pow((srcGammaRgb + {params}.src_tf.a - 1.0) / {params}.src_tf.a, {params}.src_tf.g),")?;
        writeln!(self.out, "{l2}srcGammaRgb / {params}.src_tf.k,")?;
        writeln!(
            self.out,
            "{l2}srcGammaRgb < {params}.src_tf.k * {params}.src_tf.b);"
        )?;

        // Linear RGB: move from the source gamut to the destination gamut.
        writeln!(
            self.out,
            "{l1}float3 dstLinearRgb = {params}.gamut_conversion_matrix * srcLinearRgb;"
        )?;

        // Apply the destination transfer function: linear segment (`k * x`)
        // below `b`, power curve above.
        writeln!(self.out, "{l1}float3 dstGammaRgb = {NAMESPACE}::select(")?;
        writeln!(self.out, "{l2}{params}.dst_tf.a * {NAMESPACE}::pow(dstLinearRgb, 1.0 / {params}.dst_tf.g) - ({params}.dst_tf.a - 1),")?;
        writeln!(self.out, "{l2}{params}.dst_tf.k * dstLinearRgb,")?;
        writeln!(self.out, "{l2}dstLinearRgb < {params}.dst_tf.b);")?;

        // Alpha is always opaque.
        writeln!(self.out, "{l1}return float4(dstGammaRgb, 1.0);")?;
        Ok(())
    }
6005
6006 #[allow(clippy::too_many_arguments)]
6007 fn write_wrapped_image_load(
6008 &mut self,
6009 module: &crate::Module,
6010 func_ctx: &back::FunctionCtx,
6011 image: Handle<crate::Expression>,
6012 _coordinate: Handle<crate::Expression>,
6013 _array_index: Option<Handle<crate::Expression>>,
6014 _sample: Option<Handle<crate::Expression>>,
6015 _level: Option<Handle<crate::Expression>>,
6016 ) -> BackendResult {
6017 let class = match *func_ctx.resolve_type(image, &module.types) {
6019 crate::TypeInner::Image { class, .. } => class,
6020 _ => unreachable!(),
6021 };
6022 if class != crate::ImageClass::External {
6023 return Ok(());
6024 }
6025 let wrapped = WrappedFunction::ImageLoad { class };
6026 if !self.wrapped_functions.insert(wrapped) {
6027 return Ok(());
6028 }
6029
6030 writeln!(self.out, "float4 {IMAGE_LOAD_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, uint2 coords) {{")?;
6031 let l1 = back::Level(1);
6032 let l2 = l1.next();
6033 let l3 = l2.next();
6034 writeln!(
6035 self.out,
6036 "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6037 )?;
6038 writeln!(
6042 self.out,
6043 "{l1}uint2 cropped_size = {NAMESPACE}::any(tex.params.size != 0) ? tex.params.size : plane0_size;"
6044 )?;
6045 writeln!(
6046 self.out,
6047 "{l1}coords = {NAMESPACE}::min(coords, cropped_size - 1);"
6048 )?;
6049
6050 writeln!(self.out, "{l1}uint2 plane0_coords = uint2({NAMESPACE}::round(tex.params.load_transform * float3(float2(coords), 1.0)));")?;
6052 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6053 writeln!(self.out, "{l2}return tex.plane0.read(plane0_coords);")?;
6055 writeln!(self.out, "{l1}}} else {{")?;
6056
6057 writeln!(
6059 self.out,
6060 "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());"
6061 )?;
6062 writeln!(self.out, "{l2}uint2 plane1_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane1_size) / float2(plane0_size)));")?;
6063
6064 writeln!(self.out, "{l2}float y = tex.plane0.read(plane0_coords).x;")?;
6066
6067 writeln!(self.out, "{l2}float2 uv;")?;
6068 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6069 writeln!(self.out, "{l3}uv = tex.plane1.read(plane1_coords).xy;")?;
6071 writeln!(self.out, "{l2}}} else {{")?;
6072 writeln!(
6074 self.out,
6075 "{l2}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());"
6076 )?;
6077 writeln!(self.out, "{l2}uint2 plane2_coords = uint2({NAMESPACE}::floor(float2(plane0_coords) * float2(plane2_size) / float2(plane0_size)));")?;
6078 writeln!(
6079 self.out,
6080 "{l3}uv = float2(tex.plane1.read(plane1_coords).x, tex.plane2.read(plane2_coords).x);"
6081 )?;
6082 writeln!(self.out, "{l2}}}")?;
6083
6084 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6085
6086 writeln!(self.out, "{l1}}}")?;
6087 writeln!(self.out, "}}")?;
6088 writeln!(self.out)?;
6089 Ok(())
6090 }
6091
6092 #[allow(clippy::too_many_arguments)]
6093 fn write_wrapped_image_sample(
6094 &mut self,
6095 module: &crate::Module,
6096 func_ctx: &back::FunctionCtx,
6097 image: Handle<crate::Expression>,
6098 _sampler: Handle<crate::Expression>,
6099 _gather: Option<crate::SwizzleComponent>,
6100 _coordinate: Handle<crate::Expression>,
6101 _array_index: Option<Handle<crate::Expression>>,
6102 _offset: Option<Handle<crate::Expression>>,
6103 _level: crate::SampleLevel,
6104 _depth_ref: Option<Handle<crate::Expression>>,
6105 clamp_to_edge: bool,
6106 ) -> BackendResult {
6107 if !clamp_to_edge {
6110 return Ok(());
6111 }
6112 let class = match *func_ctx.resolve_type(image, &module.types) {
6113 crate::TypeInner::Image { class, .. } => class,
6114 _ => unreachable!(),
6115 };
6116 let wrapped = WrappedFunction::ImageSample {
6117 class,
6118 clamp_to_edge: true,
6119 };
6120 if !self.wrapped_functions.insert(wrapped) {
6121 return Ok(());
6122 }
6123 match class {
6124 crate::ImageClass::External => {
6125 writeln!(self.out, "float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex, {NAMESPACE}::sampler samp, float2 coords) {{")?;
6126 let l1 = back::Level(1);
6127 let l2 = l1.next();
6128 let l3 = l2.next();
6129 writeln!(self.out, "{l1}uint2 plane0_size = uint2(tex.plane0.get_width(), tex.plane0.get_height());")?;
6130 writeln!(
6131 self.out,
6132 "{l1}coords = tex.params.sample_transform * float3(coords, 1.0);"
6133 )?;
6134
6135 writeln!(
6143 self.out,
6144 "{l1}float2 bounds_min = tex.params.sample_transform * float3(0.0, 0.0, 1.0);"
6145 )?;
6146 writeln!(
6147 self.out,
6148 "{l1}float2 bounds_max = tex.params.sample_transform * float3(1.0, 1.0, 1.0);"
6149 )?;
6150 writeln!(self.out, "{l1}float4 bounds = float4({NAMESPACE}::min(bounds_min, bounds_max), {NAMESPACE}::max(bounds_min, bounds_max));")?;
6151 writeln!(
6152 self.out,
6153 "{l1}float2 plane0_half_texel = float2(0.5, 0.5) / float2(plane0_size);"
6154 )?;
6155 writeln!(
6156 self.out,
6157 "{l1}float2 plane0_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane0_half_texel, bounds.zw - plane0_half_texel);"
6158 )?;
6159 writeln!(self.out, "{l1}if (tex.params.num_planes == 1u) {{")?;
6160 writeln!(
6162 self.out,
6163 "{l2}return tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f));"
6164 )?;
6165 writeln!(self.out, "{l1}}} else {{")?;
6166 writeln!(self.out, "{l2}uint2 plane1_size = uint2(tex.plane1.get_width(), tex.plane1.get_height());")?;
6167 writeln!(
6168 self.out,
6169 "{l2}float2 plane1_half_texel = float2(0.5, 0.5) / float2(plane1_size);"
6170 )?;
6171 writeln!(
6172 self.out,
6173 "{l2}float2 plane1_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane1_half_texel, bounds.zw - plane1_half_texel);"
6174 )?;
6175
6176 writeln!(
6178 self.out,
6179 "{l2}float y = tex.plane0.sample(samp, plane0_coords, {NAMESPACE}::level(0.0f)).r;"
6180 )?;
6181 writeln!(self.out, "{l2}float2 uv = float2(0.0, 0.0);")?;
6182 writeln!(self.out, "{l2}if (tex.params.num_planes == 2u) {{")?;
6183 writeln!(
6185 self.out,
6186 "{l3}uv = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).xy;"
6187 )?;
6188 writeln!(self.out, "{l2}}} else {{")?;
6189 writeln!(self.out, "{l3}uint2 plane2_size = uint2(tex.plane2.get_width(), tex.plane2.get_height());")?;
6191 writeln!(
6192 self.out,
6193 "{l3}float2 plane2_half_texel = float2(0.5, 0.5) / float2(plane2_size);"
6194 )?;
6195 writeln!(
6196 self.out,
6197 "{l3}float2 plane2_coords = {NAMESPACE}::clamp(coords, bounds.xy + plane2_half_texel, bounds.zw - plane1_half_texel);"
6198 )?;
6199 writeln!(self.out, "{l3}uv.x = tex.plane1.sample(samp, plane1_coords, {NAMESPACE}::level(0.0f)).x;")?;
6200 writeln!(self.out, "{l3}uv.y = tex.plane2.sample(samp, plane2_coords, {NAMESPACE}::level(0.0f)).x;")?;
6201 writeln!(self.out, "{l2}}}")?;
6202
6203 self.write_convert_yuv_to_rgb_and_return(l2, "y", "uv", "tex.params")?;
6204
6205 writeln!(self.out, "{l1}}}")?;
6206 writeln!(self.out, "}}")?;
6207 writeln!(self.out)?;
6208 }
6209 _ => {
6210 writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
6211 let l1 = back::Level(1);
6212 writeln!(self.out, "{l1}{NAMESPACE}::float2 half_texel = 0.5 / {NAMESPACE}::float2(tex.get_width(0u), tex.get_height(0u));")?;
6213 writeln!(
6214 self.out,
6215 "{l1}return tex.sample(samp, {NAMESPACE}::clamp(coords, half_texel, 1.0 - half_texel), {NAMESPACE}::level(0.0));"
6216 )?;
6217 writeln!(self.out, "}}")?;
6218 writeln!(self.out)?;
6219 }
6220 }
6221 Ok(())
6222 }
6223
6224 fn write_wrapped_image_query(
6225 &mut self,
6226 module: &crate::Module,
6227 func_ctx: &back::FunctionCtx,
6228 image: Handle<crate::Expression>,
6229 query: crate::ImageQuery,
6230 ) -> BackendResult {
6231 if !matches!(query, crate::ImageQuery::Size { .. }) {
6233 return Ok(());
6234 }
6235 let class = match *func_ctx.resolve_type(image, &module.types) {
6236 crate::TypeInner::Image { class, .. } => class,
6237 _ => unreachable!(),
6238 };
6239 if class != crate::ImageClass::External {
6240 return Ok(());
6241 }
6242 let wrapped = WrappedFunction::ImageQuerySize { class };
6243 if !self.wrapped_functions.insert(wrapped) {
6244 return Ok(());
6245 }
6246 writeln!(
6247 self.out,
6248 "uint2 {IMAGE_SIZE_EXTERNAL_FUNCTION}({EXTERNAL_TEXTURE_WRAPPER_STRUCT} tex) {{"
6249 )?;
6250 let l1 = back::Level(1);
6251 let l2 = l1.next();
6252 writeln!(
6253 self.out,
6254 "{l1}if ({NAMESPACE}::any(tex.params.size != uint2(0u))) {{"
6255 )?;
6256 writeln!(self.out, "{l2}return tex.params.size;")?;
6257 writeln!(self.out, "{l1}}} else {{")?;
6258 writeln!(
6260 self.out,
6261 "{l2}return uint2(tex.plane0.get_width(), tex.plane0.get_height());"
6262 )?;
6263 writeln!(self.out, "{l1}}}")?;
6264 writeln!(self.out, "}}")?;
6265 writeln!(self.out)?;
6266 Ok(())
6267 }
6268
6269 pub(super) fn write_wrapped_functions(
6270 &mut self,
6271 module: &crate::Module,
6272 func_ctx: &back::FunctionCtx,
6273 ) -> BackendResult {
6274 for (expr_handle, expr) in func_ctx.expressions.iter() {
6275 match *expr {
6276 crate::Expression::Unary { op, expr: operand } => {
6277 self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
6278 }
6279 crate::Expression::Binary { op, left, right } => {
6280 self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
6281 }
6282 crate::Expression::Math {
6283 fun,
6284 arg,
6285 arg1,
6286 arg2,
6287 arg3,
6288 } => {
6289 self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
6290 }
6291 crate::Expression::As {
6292 expr,
6293 kind,
6294 convert,
6295 } => {
6296 self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
6297 }
6298 crate::Expression::ImageLoad {
6299 image,
6300 coordinate,
6301 array_index,
6302 sample,
6303 level,
6304 } => {
6305 self.write_wrapped_image_load(
6306 module,
6307 func_ctx,
6308 image,
6309 coordinate,
6310 array_index,
6311 sample,
6312 level,
6313 )?;
6314 }
6315 crate::Expression::ImageSample {
6316 image,
6317 sampler,
6318 gather,
6319 coordinate,
6320 array_index,
6321 offset,
6322 level,
6323 depth_ref,
6324 clamp_to_edge,
6325 } => {
6326 self.write_wrapped_image_sample(
6327 module,
6328 func_ctx,
6329 image,
6330 sampler,
6331 gather,
6332 coordinate,
6333 array_index,
6334 offset,
6335 level,
6336 depth_ref,
6337 clamp_to_edge,
6338 )?;
6339 }
6340 crate::Expression::ImageQuery { image, query } => {
6341 self.write_wrapped_image_query(module, func_ctx, image, query)?;
6342 }
6343 _ => {}
6344 }
6345 }
6346
6347 Ok(())
6348 }
6349
6350 fn write_functions(
6352 &mut self,
6353 module: &crate::Module,
6354 mod_info: &valid::ModuleInfo,
6355 options: &Options,
6356 pipeline_options: &PipelineOptions,
6357 ) -> Result<TranslationInfo, Error> {
6358 use back::msl::VertexFormat;
6359
6360 struct AttributeMappingResolved {
6363 ty_name: String,
6364 dimension: u32,
6365 ty_is_int: bool,
6366 name: String,
6367 }
6368 let mut am_resolved = FastHashMap::<u32, AttributeMappingResolved>::default();
6369
6370 struct VertexBufferMappingResolved<'a> {
6371 id: u32,
6372 stride: u32,
6373 step_mode: back::msl::VertexBufferStepMode,
6374 ty_name: String,
6375 param_name: String,
6376 elem_name: String,
6377 attributes: &'a Vec<back::msl::AttributeMapping>,
6378 }
6379 let mut vbm_resolved = Vec::<VertexBufferMappingResolved>::new();
6380
6381 struct UnpackingFunction {
6383 name: String,
6384 byte_count: u32,
6385 dimension: u32,
6386 }
6387 let mut unpacking_functions = FastHashMap::<VertexFormat, UnpackingFunction>::default();
6388
6389 let mut needs_vertex_id = false;
6395 let v_id = self.namer.call("v_id");
6396
6397 let mut needs_instance_id = false;
6398 let i_id = self.namer.call("i_id");
6399 if pipeline_options.vertex_pulling_transform {
6400 for vbm in &pipeline_options.vertex_buffer_mappings {
6401 let buffer_id = vbm.id;
6402 let buffer_stride = vbm.stride;
6403
6404 assert!(
6405 buffer_stride > 0,
6406 "Vertex pulling requires a non-zero buffer stride."
6407 );
6408
6409 match vbm.step_mode {
6410 back::msl::VertexBufferStepMode::Constant => {}
6411 back::msl::VertexBufferStepMode::ByVertex => {
6412 needs_vertex_id = true;
6413 }
6414 back::msl::VertexBufferStepMode::ByInstance => {
6415 needs_instance_id = true;
6416 }
6417 }
6418
6419 let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str());
6420 let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str());
6421 let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str());
6422
6423 vbm_resolved.push(VertexBufferMappingResolved {
6424 id: buffer_id,
6425 stride: buffer_stride,
6426 step_mode: vbm.step_mode,
6427 ty_name: buffer_ty,
6428 param_name: buffer_param,
6429 elem_name: buffer_elem,
6430 attributes: &vbm.attributes,
6431 });
6432
6433 for attribute in &vbm.attributes {
6435 if unpacking_functions.contains_key(&attribute.format) {
6436 continue;
6437 }
6438 let (name, byte_count, dimension) =
6439 match self.write_unpacking_function(attribute.format) {
6440 Ok((name, byte_count, dimension)) => (name, byte_count, dimension),
6441 _ => {
6442 continue;
6443 }
6444 };
6445 unpacking_functions.insert(
6446 attribute.format,
6447 UnpackingFunction {
6448 name,
6449 byte_count,
6450 dimension,
6451 },
6452 );
6453 }
6454 }
6455 }
6456
6457 let mut pass_through_globals = Vec::new();
6458 for (fun_handle, fun) in module.functions.iter() {
6459 log::trace!(
6460 "function {:?}, handle {:?}",
6461 fun.name.as_deref().unwrap_or("(anonymous)"),
6462 fun_handle
6463 );
6464
6465 let ctx = back::FunctionCtx {
6466 ty: back::FunctionType::Function(fun_handle),
6467 info: &mod_info[fun_handle],
6468 expressions: &fun.expressions,
6469 named_expressions: &fun.named_expressions,
6470 };
6471
6472 writeln!(self.out)?;
6473 self.write_wrapped_functions(module, &ctx)?;
6474
6475 let fun_info = &mod_info[fun_handle];
6476 pass_through_globals.clear();
6477 let mut needs_buffer_sizes = false;
6478 for (handle, var) in module.global_variables.iter() {
6479 if !fun_info[handle].is_empty() {
6480 if var.space.needs_pass_through() {
6481 pass_through_globals.push(handle);
6482 }
6483 needs_buffer_sizes |= needs_array_length(var.ty, &module.types);
6484 }
6485 }
6486
6487 let fun_name = &self.names[&NameKey::Function(fun_handle)];
6488 match fun.result {
6489 Some(ref result) => {
6490 let ty_name = TypeContext {
6491 handle: result.ty,
6492 gctx: module.to_ctx(),
6493 names: &self.names,
6494 access: crate::StorageAccess::empty(),
6495 first_time: false,
6496 };
6497 write!(self.out, "{ty_name}")?;
6498 }
6499 None => {
6500 write!(self.out, "void")?;
6501 }
6502 }
6503 writeln!(self.out, " {fun_name}(")?;
6504
6505 for (index, arg) in fun.arguments.iter().enumerate() {
6506 let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
6507 let param_type_name = TypeContext {
6508 handle: arg.ty,
6509 gctx: module.to_ctx(),
6510 names: &self.names,
6511 access: crate::StorageAccess::empty(),
6512 first_time: false,
6513 };
6514 let separator = separate(
6515 !pass_through_globals.is_empty()
6516 || index + 1 != fun.arguments.len()
6517 || needs_buffer_sizes,
6518 );
6519 writeln!(
6520 self.out,
6521 "{}{} {}{}",
6522 back::INDENT,
6523 param_type_name,
6524 name,
6525 separator
6526 )?;
6527 }
6528 for (index, &handle) in pass_through_globals.iter().enumerate() {
6529 let tyvar = TypedGlobalVariable {
6530 module,
6531 names: &self.names,
6532 handle,
6533 usage: fun_info[handle],
6534
6535 reference: true,
6536 };
6537 let separator =
6538 separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes);
6539 write!(self.out, "{}", back::INDENT)?;
6540 tyvar.try_fmt(&mut self.out)?;
6541 writeln!(self.out, "{separator}")?;
6542 }
6543
6544 if needs_buffer_sizes {
6545 writeln!(
6546 self.out,
6547 "{}constant _mslBufferSizes& _buffer_sizes",
6548 back::INDENT
6549 )?;
6550 }
6551
6552 writeln!(self.out, ") {{")?;
6553
6554 let guarded_indices =
6555 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
6556
6557 let context = StatementContext {
6558 expression: ExpressionContext {
6559 function: fun,
6560 origin: FunctionOrigin::Handle(fun_handle),
6561 info: fun_info,
6562 lang_version: options.lang_version,
6563 policies: options.bounds_check_policies,
6564 guarded_indices,
6565 module,
6566 mod_info,
6567 pipeline_options,
6568 force_loop_bounding: options.force_loop_bounding,
6569 },
6570 result_struct: None,
6571 };
6572
6573 self.put_locals(&context.expression)?;
6574 self.update_expressions_to_bake(fun, fun_info, &context.expression);
6575 self.put_block(back::Level(1), &fun.body, &context)?;
6576 writeln!(self.out, "}}")?;
6577 self.named_expressions.clear();
6578 }
6579
6580 let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
6581 .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;
6582
6583 let mut info = TranslationInfo {
6584 entry_point_names: Vec::with_capacity(ep_range.len()),
6585 };
6586
6587 for ep_index in ep_range {
6588 let ep = &module.entry_points[ep_index];
6589 let fun = &ep.function;
6590 let fun_info = mod_info.get_entry_point(ep_index);
6591 let mut ep_error = None;
6592
6593 let mut v_existing_id = None;
6597 let mut i_existing_id = None;
6598
6599 log::trace!(
6600 "entry point {:?}, index {:?}",
6601 fun.name.as_deref().unwrap_or("(anonymous)"),
6602 ep_index
6603 );
6604
6605 let ctx = back::FunctionCtx {
6606 ty: back::FunctionType::EntryPoint(ep_index as u16),
6607 info: fun_info,
6608 expressions: &fun.expressions,
6609 named_expressions: &fun.named_expressions,
6610 };
6611
6612 self.write_wrapped_functions(module, &ctx)?;
6613
6614 let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage {
6615 crate::ShaderStage::Vertex => (
6616 "vertex",
6617 LocationMode::VertexInput,
6618 LocationMode::VertexOutput,
6619 true,
6620 ),
6621 crate::ShaderStage::Fragment => (
6622 "fragment",
6623 LocationMode::FragmentInput,
6624 LocationMode::FragmentOutput,
6625 false,
6626 ),
6627 crate::ShaderStage::Compute => (
6628 "kernel",
6629 LocationMode::Uniform,
6630 LocationMode::Uniform,
6631 false,
6632 ),
6633 crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(),
6634 };
6635
6636 let do_vertex_pulling = can_vertex_pull
6638 && pipeline_options.vertex_pulling_transform
6639 && !pipeline_options.vertex_buffer_mappings.is_empty();
6640
6641 let needs_buffer_sizes = do_vertex_pulling
6643 || module
6644 .global_variables
6645 .iter()
6646 .filter(|&(handle, _)| !fun_info[handle].is_empty())
6647 .any(|(_, var)| needs_array_length(var.ty, &module.types));
6648
6649 if !options.fake_missing_bindings {
6652 for (var_handle, var) in module.global_variables.iter() {
6653 if fun_info[var_handle].is_empty() {
6654 continue;
6655 }
6656 match var.space {
6657 crate::AddressSpace::Uniform
6658 | crate::AddressSpace::Storage { .. }
6659 | crate::AddressSpace::Handle => {
6660 let br = match var.binding {
6661 Some(ref br) => br,
6662 None => {
6663 let var_name = var.name.clone().unwrap_or_default();
6664 ep_error =
6665 Some(super::EntryPointError::MissingBinding(var_name));
6666 break;
6667 }
6668 };
6669 let target = options.get_resource_binding_target(ep, br);
6670 let good = match target {
6671 Some(target) => {
6672 match module.types[var.ty].inner {
6676 crate::TypeInner::Image {
6677 class: crate::ImageClass::External,
6678 ..
6679 } => target.external_texture.is_some(),
6680 crate::TypeInner::Image { .. } => target.texture.is_some(),
6681 crate::TypeInner::Sampler { .. } => {
6682 target.sampler.is_some()
6683 }
6684 _ => target.buffer.is_some(),
6685 }
6686 }
6687 None => false,
6688 };
6689 if !good {
6690 ep_error = Some(super::EntryPointError::MissingBindTarget(*br));
6691 break;
6692 }
6693 }
6694 crate::AddressSpace::PushConstant => {
6695 if let Err(e) = options.resolve_push_constants(ep) {
6696 ep_error = Some(e);
6697 break;
6698 }
6699 }
6700 crate::AddressSpace::TaskPayload => {
6701 unimplemented!()
6702 }
6703 crate::AddressSpace::Function
6704 | crate::AddressSpace::Private
6705 | crate::AddressSpace::WorkGroup => {}
6706 }
6707 }
6708 if needs_buffer_sizes {
6709 if let Err(err) = options.resolve_sizes_buffer(ep) {
6710 ep_error = Some(err);
6711 }
6712 }
6713 }
6714
6715 if let Some(err) = ep_error {
6716 info.entry_point_names.push(Err(err));
6717 continue;
6718 }
6719 let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
6720 info.entry_point_names.push(Ok(fun_name.clone()));
6721
6722 writeln!(self.out)?;
6723
6724 let mut flattened_member_names = FastHashMap::default();
6730 let mut varyings_namer = proc::Namer::default();
6732
6733 let mut flattened_arguments = Vec::new();
6738 for (arg_index, arg) in fun.arguments.iter().enumerate() {
6739 match module.types[arg.ty].inner {
6740 crate::TypeInner::Struct { ref members, .. } => {
6741 for (member_index, member) in members.iter().enumerate() {
6742 let member_index = member_index as u32;
6743 flattened_arguments.push((
6744 NameKey::StructMember(arg.ty, member_index),
6745 member.ty,
6746 member.binding.as_ref(),
6747 ));
6748 let name_key = NameKey::StructMember(arg.ty, member_index);
6749 let name = match member.binding {
6750 Some(crate::Binding::Location { .. }) => {
6751 if do_vertex_pulling {
6752 self.namer.call(&self.names[&name_key])
6753 } else {
6754 varyings_namer.call(&self.names[&name_key])
6755 }
6756 }
6757 _ => self.namer.call(&self.names[&name_key]),
6758 };
6759 flattened_member_names.insert(name_key, name);
6760 }
6761 }
6762 _ => flattened_arguments.push((
6763 NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
6764 arg.ty,
6765 arg.binding.as_ref(),
6766 )),
6767 }
6768 }
6769
6770 let stage_in_name = self.namer.call(&format!("{fun_name}Input"));
6775 let varyings_member_name = self.namer.call("varyings");
6776 let mut has_varyings = false;
6777 if !flattened_arguments.is_empty() {
6778 if !do_vertex_pulling {
6779 writeln!(self.out, "struct {stage_in_name} {{")?;
6780 }
6781 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
6782 let (binding, location) = match binding {
6783 Some(ref binding @ &crate::Binding::Location { location, .. }) => {
6784 (binding, location)
6785 }
6786 _ => continue,
6787 };
6788 let name = match *name_key {
6789 NameKey::StructMember(..) => &flattened_member_names[name_key],
6790 _ => &self.names[name_key],
6791 };
6792 let ty_name = TypeContext {
6793 handle: ty,
6794 gctx: module.to_ctx(),
6795 names: &self.names,
6796 access: crate::StorageAccess::empty(),
6797 first_time: false,
6798 };
6799 let resolved = options.resolve_local_binding(binding, in_mode)?;
6800 if do_vertex_pulling {
6801 am_resolved.insert(
6803 location,
6804 AttributeMappingResolved {
6805 ty_name: ty_name.to_string(),
6806 dimension: ty_name.vertex_input_dimension(),
6807 ty_is_int: ty_name.scalar().is_some_and(scalar_is_int),
6808 name: name.to_string(),
6809 },
6810 );
6811 } else {
6812 has_varyings = true;
6813 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
6814 resolved.try_fmt(&mut self.out)?;
6815 writeln!(self.out, ";")?;
6816 }
6817 }
6818 if !do_vertex_pulling {
6819 writeln!(self.out, "}};")?;
6820 }
6821 }
6822
6823 let stage_out_name = self.namer.call(&format!("{fun_name}Output"));
6826 let result_member_name = self.namer.call("member");
6827 let result_type_name = match fun.result {
6828 Some(ref result) => {
6829 let mut result_members = Vec::new();
6830 if let crate::TypeInner::Struct { ref members, .. } =
6831 module.types[result.ty].inner
6832 {
6833 for (member_index, member) in members.iter().enumerate() {
6834 result_members.push((
6835 &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
6836 member.ty,
6837 member.binding.as_ref(),
6838 ));
6839 }
6840 } else {
6841 result_members.push((
6842 &result_member_name,
6843 result.ty,
6844 result.binding.as_ref(),
6845 ));
6846 }
6847
6848 writeln!(self.out, "struct {stage_out_name} {{")?;
6849 let mut has_point_size = false;
6850 for (name, ty, binding) in result_members {
6851 let ty_name = TypeContext {
6852 handle: ty,
6853 gctx: module.to_ctx(),
6854 names: &self.names,
6855 access: crate::StorageAccess::empty(),
6856 first_time: true,
6857 };
6858 let binding = binding.ok_or_else(|| {
6859 Error::GenericValidation("Expected binding, got None".into())
6860 })?;
6861
6862 if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
6863 has_point_size = true;
6864 if !pipeline_options.allow_and_force_point_size {
6865 continue;
6866 }
6867 }
6868
6869 let array_len = match module.types[ty].inner {
6870 crate::TypeInner::Array {
6871 size: crate::ArraySize::Constant(size),
6872 ..
6873 } => Some(size),
6874 _ => None,
6875 };
6876 let resolved = options.resolve_local_binding(binding, out_mode)?;
6877 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
6878 if let Some(array_len) = array_len {
6879 write!(self.out, " [{array_len}]")?;
6880 }
6881 resolved.try_fmt(&mut self.out)?;
6882 writeln!(self.out, ";")?;
6883 }
6884
6885 if pipeline_options.allow_and_force_point_size
6886 && ep.stage == crate::ShaderStage::Vertex
6887 && !has_point_size
6888 {
6889 writeln!(
6891 self.out,
6892 "{}float _point_size [[point_size]];",
6893 back::INDENT
6894 )?;
6895 }
6896 writeln!(self.out, "}};")?;
6897 &stage_out_name
6898 }
6899 None => "void",
6900 };
6901
6902 if do_vertex_pulling {
6905 for vbm in &vbm_resolved {
6906 let buffer_stride = vbm.stride;
6907 let buffer_ty = &vbm.ty_name;
6908
6909 writeln!(
6913 self.out,
6914 "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};"
6915 )?;
6916 }
6917 }
6918
6919 writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
6921
6922 let mut is_first_argument = true;
6923 let mut separator = || {
6924 if is_first_argument {
6925 is_first_argument = false;
6926 ' '
6927 } else {
6928 ','
6929 }
6930 };
6931
6932 if has_varyings {
6935 writeln!(
6936 self.out,
6937 "{} {stage_in_name} {varyings_member_name} [[stage_in]]",
6938 separator()
6939 )?;
6940 }
6941
6942 let mut local_invocation_id = None;
6943
6944 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
6947 let binding = match binding {
6948 Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
6949 _ => continue,
6950 };
6951 let name = match *name_key {
6952 NameKey::StructMember(..) => &flattened_member_names[name_key],
6953 _ => &self.names[name_key],
6954 };
6955
6956 if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
6957 local_invocation_id = Some(name_key);
6958 }
6959
6960 let ty_name = TypeContext {
6961 handle: ty,
6962 gctx: module.to_ctx(),
6963 names: &self.names,
6964 access: crate::StorageAccess::empty(),
6965 first_time: false,
6966 };
6967
6968 match *binding {
6969 crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => {
6970 v_existing_id = Some(name.clone());
6971 }
6972 crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => {
6973 i_existing_id = Some(name.clone());
6974 }
6975 _ => {}
6976 };
6977
6978 let resolved = options.resolve_local_binding(binding, in_mode)?;
6979 write!(self.out, "{} {ty_name} {name}", separator())?;
6980 resolved.try_fmt(&mut self.out)?;
6981 writeln!(self.out)?;
6982 }
6983
6984 let need_workgroup_variables_initialization =
6985 self.need_workgroup_variables_initialization(options, ep, module, fun_info);
6986
6987 if need_workgroup_variables_initialization && local_invocation_id.is_none() {
6988 writeln!(
6989 self.out,
6990 "{} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]", separator()
6991 )?;
6992 }
6993
6994 for (handle, var) in module.global_variables.iter() {
6999 let usage = fun_info[handle];
7000 if usage.is_empty() || var.space == crate::AddressSpace::Private {
7001 continue;
7002 }
7003
7004 if options.lang_version < (1, 2) {
7005 match var.space {
7006 crate::AddressSpace::Storage { access }
7016 if access.contains(crate::StorageAccess::STORE)
7017 && ep.stage == crate::ShaderStage::Fragment =>
7018 {
7019 return Err(Error::UnsupportedWriteableStorageBuffer)
7020 }
7021 crate::AddressSpace::Handle => {
7022 match module.types[var.ty].inner {
7023 crate::TypeInner::Image {
7024 class: crate::ImageClass::Storage { access, .. },
7025 ..
7026 } => {
7027 if access.contains(crate::StorageAccess::STORE)
7037 && (ep.stage == crate::ShaderStage::Vertex
7038 || ep.stage == crate::ShaderStage::Fragment)
7039 {
7040 return Err(Error::UnsupportedWriteableStorageTexture(
7041 ep.stage,
7042 ));
7043 }
7044
7045 if access.contains(
7046 crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
7047 ) {
7048 return Err(Error::UnsupportedRWStorageTexture);
7049 }
7050 }
7051 _ => {}
7052 }
7053 }
7054 _ => {}
7055 }
7056 }
7057
7058 match var.space {
7060 crate::AddressSpace::Handle => match module.types[var.ty].inner {
7061 crate::TypeInner::BindingArray { base, .. } => {
7062 match module.types[base].inner {
7063 crate::TypeInner::Sampler { .. } => {
7064 if options.lang_version < (2, 0) {
7065 return Err(Error::UnsupportedArrayOf(
7066 "samplers".to_string(),
7067 ));
7068 }
7069 }
7070 crate::TypeInner::Image { class, .. } => match class {
7071 crate::ImageClass::Sampled { .. }
7072 | crate::ImageClass::Depth { .. }
7073 | crate::ImageClass::Storage {
7074 access: crate::StorageAccess::LOAD,
7075 ..
7076 } => {
7077 if options.lang_version < (2, 0) {
7082 return Err(Error::UnsupportedArrayOf(
7083 "textures".to_string(),
7084 ));
7085 }
7086 }
7087 crate::ImageClass::Storage {
7088 access: crate::StorageAccess::STORE,
7089 ..
7090 } => {
7091 if options.lang_version < (2, 0) {
7096 return Err(Error::UnsupportedArrayOf(
7097 "write-only textures".to_string(),
7098 ));
7099 }
7100 }
7101 crate::ImageClass::Storage { .. } => {
7102 return Err(Error::UnsupportedArrayOf(
7103 "read-write textures".to_string(),
7104 ));
7105 }
7106 crate::ImageClass::External => {
7107 return Err(Error::UnsupportedArrayOf(
7108 "external textures".to_string(),
7109 ));
7110 }
7111 },
7112 _ => {
7113 return Err(Error::UnsupportedArrayOfType(base));
7114 }
7115 }
7116 }
7117 _ => {}
7118 },
7119 _ => {}
7120 }
7121
7122 let resolved = match var.space {
7124 crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
7125 crate::AddressSpace::WorkGroup => None,
7126 _ => options
7127 .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
7128 .ok(),
7129 };
7130 if let Some(ref resolved) = resolved {
7131 if resolved.as_inline_sampler(options).is_some() {
7133 continue;
7134 }
7135 }
7136
7137 match module.types[var.ty].inner {
7138 crate::TypeInner::Image {
7139 class: crate::ImageClass::External,
7140 ..
7141 } => {
7142 let target = match resolved {
7146 Some(back::msl::ResolvedBinding::Resource(target)) => {
7147 target.external_texture
7148 }
7149 _ => None,
7150 };
7151
7152 for i in 0..3 {
7153 write!(self.out, "{} ", separator())?;
7154
7155 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7156 handle,
7157 ExternalTextureNameKey::Plane(i),
7158 )];
7159 write!(
7160 self.out,
7161 "{NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> {plane_name}"
7162 )?;
7163 if let Some(ref target) = target {
7164 write!(self.out, " [[texture({})]]", target.planes[i])?;
7165 }
7166 writeln!(self.out)?;
7167 }
7168 let params_ty_name = &self.names
7169 [&NameKey::Type(module.special_types.external_texture_params.unwrap())];
7170 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7171 handle,
7172 ExternalTextureNameKey::Params,
7173 )];
7174 write!(self.out, "{} ", separator())?;
7175 write!(self.out, "constant {params_ty_name}& {params_name}")?;
7176 if let Some(ref target) = target {
7177 write!(self.out, " [[buffer({})]]", target.params)?;
7178 }
7179 }
7180 _ => {
7181 let tyvar = TypedGlobalVariable {
7182 module,
7183 names: &self.names,
7184 handle,
7185 usage,
7186 reference: true,
7187 };
7188 write!(self.out, "{} ", separator())?;
7189 tyvar.try_fmt(&mut self.out)?;
7190 if let Some(resolved) = resolved {
7191 resolved.try_fmt(&mut self.out)?;
7192 }
7193 if let Some(value) = var.init {
7194 write!(self.out, " = ")?;
7195 self.put_const_expression(
7196 value,
7197 module,
7198 mod_info,
7199 &module.global_expressions,
7200 )?;
7201 }
7202 }
7203 }
7204 writeln!(self.out)?;
7205 }
7206
7207 if do_vertex_pulling {
7208 if needs_vertex_id && v_existing_id.is_none() {
7209 writeln!(self.out, "{} uint {v_id} [[vertex_id]]", separator())?;
7211 }
7212
7213 if needs_instance_id && i_existing_id.is_none() {
7214 writeln!(self.out, "{} uint {i_id} [[instance_id]]", separator())?;
7215 }
7216
7217 for vbm in &vbm_resolved {
7220 let id = &vbm.id;
7221 let ty_name = &vbm.ty_name;
7222 let param_name = &vbm.param_name;
7223 writeln!(
7224 self.out,
7225 "{} const device {ty_name}* {param_name} [[buffer({id})]]",
7226 separator()
7227 )?;
7228 }
7229 }
7230
7231 if needs_buffer_sizes {
7234 let resolved = options.resolve_sizes_buffer(ep).unwrap();
7236 write!(
7237 self.out,
7238 "{} constant _mslBufferSizes& _buffer_sizes",
7239 separator()
7240 )?;
7241 resolved.try_fmt(&mut self.out)?;
7242 writeln!(self.out)?;
7243 }
7244
7245 writeln!(self.out, ") {{")?;
7247
7248 if do_vertex_pulling {
7250 for vbm in &vbm_resolved {
7253 for attribute in vbm.attributes {
7254 let location = attribute.shader_location;
7255 let am_option = am_resolved.get(&location);
7256 if am_option.is_none() {
7257 continue;
7260 }
7261 let am = am_option.unwrap();
7262 let attribute_ty_name = &am.ty_name;
7263 let attribute_name = &am.name;
7264
7265 writeln!(
7266 self.out,
7267 "{}{attribute_ty_name} {attribute_name} = {{}};",
7268 back::Level(1)
7269 )?;
7270 }
7271
7272 write!(self.out, "{}if (", back::Level(1))?;
7275
7276 let idx = &vbm.id;
7277 let stride = &vbm.stride;
7278 let index_name = match vbm.step_mode {
7279 back::msl::VertexBufferStepMode::Constant => "0",
7280 back::msl::VertexBufferStepMode::ByVertex => {
7281 if let Some(ref name) = v_existing_id {
7282 name
7283 } else {
7284 &v_id
7285 }
7286 }
7287 back::msl::VertexBufferStepMode::ByInstance => {
7288 if let Some(ref name) = i_existing_id {
7289 name
7290 } else {
7291 &i_id
7292 }
7293 }
7294 };
7295 write!(
7296 self.out,
7297 "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})"
7298 )?;
7299
7300 writeln!(self.out, ") {{")?;
7301
7302 let ty_name = &vbm.ty_name;
7304 let elem_name = &vbm.elem_name;
7305 let param_name = &vbm.param_name;
7306
7307 writeln!(
7308 self.out,
7309 "{}const {ty_name} {elem_name} = {param_name}[{index_name}];",
7310 back::Level(2),
7311 )?;
7312
7313 for attribute in vbm.attributes {
7316 let location = attribute.shader_location;
7317 let Some(am) = am_resolved.get(&location) else {
7318 continue;
7322 };
7323 let attribute_name = &am.name;
7324 let attribute_ty_name = &am.ty_name;
7325
7326 let offset = attribute.offset;
7327 let func = unpacking_functions
7328 .get(&attribute.format)
7329 .expect("Should have generated this unpacking function earlier.");
7330 let func_name = &func.name;
7331
7332 let needs_padding_or_truncation = am.dimension.cmp(&func.dimension);
7339
7340 if needs_padding_or_truncation != Ordering::Equal {
7341 writeln!(
7344 self.out,
7345 "{}// {attribute_ty_name} <- {:?}",
7346 back::Level(2),
7347 attribute.format
7348 )?;
7349 }
7350
7351 write!(self.out, "{}{attribute_name} = ", back::Level(2),)?;
7352
7353 if needs_padding_or_truncation == Ordering::Greater {
7354 write!(self.out, "{attribute_ty_name}(")?;
7356 }
7357
7358 write!(self.out, "{func_name}({elem_name}.data[{offset}]",)?;
7360 for i in (offset + 1)..(offset + func.byte_count) {
7361 write!(self.out, ", {elem_name}.data[{i}]")?;
7362 }
7363 write!(self.out, ")")?;
7364
7365 match needs_padding_or_truncation {
7366 Ordering::Greater => {
7367 let zero_value = if am.ty_is_int { "0" } else { "0.0" };
7369 let one_value = if am.ty_is_int { "1" } else { "1.0" };
7370 for i in func.dimension..am.dimension {
7371 write!(
7372 self.out,
7373 ", {}",
7374 if i == 3 { one_value } else { zero_value }
7375 )?;
7376 }
7377 write!(self.out, ")")?;
7378 }
7379 Ordering::Less => {
7380 write!(
7382 self.out,
7383 ".{}",
7384 &"xyzw"[0..usize::try_from(am.dimension).unwrap()]
7385 )?;
7386 }
7387 Ordering::Equal => {}
7388 }
7389
7390 writeln!(self.out, ";")?;
7391 }
7392
7393 writeln!(self.out, "{}}}", back::Level(1))?;
7395 }
7396 }
7397
7398 if need_workgroup_variables_initialization {
7399 self.write_workgroup_variables_initialization(
7400 module,
7401 mod_info,
7402 fun_info,
7403 local_invocation_id,
7404 )?;
7405 }
7406
7407 for (handle, var) in module.global_variables.iter() {
7410 let usage = fun_info[handle];
7411 if usage.is_empty() {
7412 continue;
7413 }
7414 if var.space == crate::AddressSpace::Private {
7415 let tyvar = TypedGlobalVariable {
7416 module,
7417 names: &self.names,
7418 handle,
7419 usage,
7420
7421 reference: false,
7422 };
7423 write!(self.out, "{}", back::INDENT)?;
7424 tyvar.try_fmt(&mut self.out)?;
7425 match var.init {
7426 Some(value) => {
7427 write!(self.out, " = ")?;
7428 self.put_const_expression(
7429 value,
7430 module,
7431 mod_info,
7432 &module.global_expressions,
7433 )?;
7434 writeln!(self.out, ";")?;
7435 }
7436 None => {
7437 writeln!(self.out, " = {{}};")?;
7438 }
7439 };
7440 } else if let Some(ref binding) = var.binding {
7441 let resolved = options.resolve_resource_binding(ep, binding).unwrap();
7442 if let Some(sampler) = resolved.as_inline_sampler(options) {
7443 let name = &self.names[&NameKey::GlobalVariable(handle)];
7445 writeln!(
7446 self.out,
7447 "{}constexpr {}::sampler {}(",
7448 back::INDENT,
7449 NAMESPACE,
7450 name
7451 )?;
7452 self.put_inline_sampler_properties(back::Level(2), sampler)?;
7453 writeln!(self.out, "{});", back::INDENT)?;
7454 } else if let crate::TypeInner::Image {
7455 class: crate::ImageClass::External,
7456 ..
7457 } = module.types[var.ty].inner
7458 {
7459 let wrapper_name = &self.names[&NameKey::GlobalVariable(handle)];
7462 let l1 = back::Level(1);
7463 let l2 = l1.next();
7464 writeln!(
7465 self.out,
7466 "{l1}const {EXTERNAL_TEXTURE_WRAPPER_STRUCT} {wrapper_name} {{"
7467 )?;
7468 for i in 0..3 {
7469 let plane_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7470 handle,
7471 ExternalTextureNameKey::Plane(i),
7472 )];
7473 writeln!(self.out, "{l2}.plane{i} = {plane_name},")?;
7474 }
7475 let params_name = &self.names[&NameKey::ExternalTextureGlobalVariable(
7476 handle,
7477 ExternalTextureNameKey::Params,
7478 )];
7479 writeln!(self.out, "{l2}.params = {params_name},")?;
7480 writeln!(self.out, "{l1}}};")?;
7481 }
7482 }
7483 }
7484
7485 for (arg_index, arg) in fun.arguments.iter().enumerate() {
7496 let arg_name =
7497 &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
7498 match module.types[arg.ty].inner {
7499 crate::TypeInner::Struct { ref members, .. } => {
7500 let struct_name = &self.names[&NameKey::Type(arg.ty)];
7501 write!(
7502 self.out,
7503 "{}const {} {} = {{ ",
7504 back::INDENT,
7505 struct_name,
7506 arg_name
7507 )?;
7508 for (member_index, member) in members.iter().enumerate() {
7509 let key = NameKey::StructMember(arg.ty, member_index as u32);
7510 let name = &flattened_member_names[&key];
7511 if member_index != 0 {
7512 write!(self.out, ", ")?;
7513 }
7514 if self
7516 .struct_member_pads
7517 .contains(&(arg.ty, member_index as u32))
7518 {
7519 write!(self.out, "{{}}, ")?;
7520 }
7521 if let Some(crate::Binding::Location { .. }) = member.binding {
7522 if has_varyings {
7523 write!(self.out, "{varyings_member_name}.")?;
7524 }
7525 }
7526 write!(self.out, "{name}")?;
7527 }
7528 writeln!(self.out, " }};")?;
7529 }
7530 _ => {
7531 if let Some(crate::Binding::Location { .. }) = arg.binding {
7532 if has_varyings {
7533 writeln!(
7534 self.out,
7535 "{}const auto {} = {}.{};",
7536 back::INDENT,
7537 arg_name,
7538 varyings_member_name,
7539 arg_name
7540 )?;
7541 }
7542 }
7543 }
7544 }
7545 }
7546
7547 let guarded_indices =
7548 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
7549
7550 let context = StatementContext {
7551 expression: ExpressionContext {
7552 function: fun,
7553 origin: FunctionOrigin::EntryPoint(ep_index as _),
7554 info: fun_info,
7555 lang_version: options.lang_version,
7556 policies: options.bounds_check_policies,
7557 guarded_indices,
7558 module,
7559 mod_info,
7560 pipeline_options,
7561 force_loop_bounding: options.force_loop_bounding,
7562 },
7563 result_struct: Some(&stage_out_name),
7564 };
7565
7566 self.put_locals(&context.expression)?;
7569 self.update_expressions_to_bake(fun, fun_info, &context.expression);
7570 self.put_block(back::Level(1), &fun.body, &context)?;
7571 writeln!(self.out, "}}")?;
7572 if ep_index + 1 != module.entry_points.len() {
7573 writeln!(self.out)?;
7574 }
7575 self.named_expressions.clear();
7576 }
7577
7578 Ok(info)
7579 }
7580
7581 fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
7582 if flags.is_empty() {
7585 writeln!(
7586 self.out,
7587 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
7588 )?;
7589 }
7590 if flags.contains(crate::Barrier::STORAGE) {
7591 writeln!(
7592 self.out,
7593 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
7594 )?;
7595 }
7596 if flags.contains(crate::Barrier::WORK_GROUP) {
7597 writeln!(
7598 self.out,
7599 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7600 )?;
7601 }
7602 if flags.contains(crate::Barrier::SUB_GROUP) {
7603 writeln!(
7604 self.out,
7605 "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
7606 )?;
7607 }
7608 if flags.contains(crate::Barrier::TEXTURE) {
7609 writeln!(
7610 self.out,
7611 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_texture);",
7612 )?;
7613 }
7614 Ok(())
7615 }
7616}
7617
7618mod workgroup_mem_init {
7621 use crate::EntryPoint;
7622
7623 use super::*;
7624
7625 enum Access {
7626 GlobalVariable(Handle<crate::GlobalVariable>),
7627 StructMember(Handle<crate::Type>, u32),
7628 Array(usize),
7629 }
7630
7631 impl Access {
7632 fn write<W: Write>(
7633 &self,
7634 writer: &mut W,
7635 names: &FastHashMap<NameKey, String>,
7636 ) -> Result<(), core::fmt::Error> {
7637 match *self {
7638 Access::GlobalVariable(handle) => {
7639 write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
7640 }
7641 Access::StructMember(handle, index) => {
7642 write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
7643 }
7644 Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
7645 }
7646 }
7647 }
7648
7649 struct AccessStack {
7650 stack: Vec<Access>,
7651 array_depth: usize,
7652 }
7653
7654 impl AccessStack {
7655 const fn new() -> Self {
7656 Self {
7657 stack: Vec::new(),
7658 array_depth: 0,
7659 }
7660 }
7661
7662 fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
7663 let array_depth = self.array_depth;
7664 self.stack.push(Access::Array(array_depth));
7665 self.array_depth += 1;
7666 let res = cb(self, array_depth);
7667 self.stack.pop();
7668 self.array_depth -= 1;
7669 res
7670 }
7671
7672 fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
7673 self.stack.push(new);
7674 let res = cb(self);
7675 self.stack.pop();
7676 res
7677 }
7678
7679 fn write<W: Write>(
7680 &self,
7681 writer: &mut W,
7682 names: &FastHashMap<NameKey, String>,
7683 ) -> Result<(), core::fmt::Error> {
7684 for next in self.stack.iter() {
7685 next.write(writer, names)?;
7686 }
7687 Ok(())
7688 }
7689 }
7690
7691 impl<W: Write> Writer<W> {
7692 pub(super) fn need_workgroup_variables_initialization(
7693 &mut self,
7694 options: &Options,
7695 ep: &EntryPoint,
7696 module: &crate::Module,
7697 fun_info: &valid::FunctionInfo,
7698 ) -> bool {
7699 options.zero_initialize_workgroup_memory
7700 && ep.stage.compute_like()
7701 && module.global_variables.iter().any(|(handle, var)| {
7702 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7703 })
7704 }
7705
7706 pub(super) fn write_workgroup_variables_initialization(
7707 &mut self,
7708 module: &crate::Module,
7709 module_info: &valid::ModuleInfo,
7710 fun_info: &valid::FunctionInfo,
7711 local_invocation_id: Option<&NameKey>,
7712 ) -> BackendResult {
7713 let level = back::Level(1);
7714
7715 writeln!(
7716 self.out,
7717 "{}if ({}::all({} == {}::uint3(0u))) {{",
7718 level,
7719 NAMESPACE,
7720 local_invocation_id
7721 .map(|name_key| self.names[name_key].as_str())
7722 .unwrap_or("__local_invocation_id"),
7723 NAMESPACE,
7724 )?;
7725
7726 let mut access_stack = AccessStack::new();
7727
7728 let vars = module.global_variables.iter().filter(|&(handle, var)| {
7729 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
7730 });
7731
7732 for (handle, var) in vars {
7733 access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
7734 self.write_workgroup_variable_initialization(
7735 module,
7736 module_info,
7737 var.ty,
7738 access_stack,
7739 level.next(),
7740 )
7741 })?;
7742 }
7743
7744 writeln!(self.out, "{level}}}")?;
7745 self.write_barrier(crate::Barrier::WORK_GROUP, level)
7746 }
7747
7748 fn write_workgroup_variable_initialization(
7749 &mut self,
7750 module: &crate::Module,
7751 module_info: &valid::ModuleInfo,
7752 ty: Handle<crate::Type>,
7753 access_stack: &mut AccessStack,
7754 level: back::Level,
7755 ) -> BackendResult {
7756 if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
7757 write!(self.out, "{level}")?;
7758 access_stack.write(&mut self.out, &self.names)?;
7759 writeln!(self.out, " = {{}};")?;
7760 } else {
7761 match module.types[ty].inner {
7762 crate::TypeInner::Atomic { .. } => {
7763 write!(
7764 self.out,
7765 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
7766 )?;
7767 access_stack.write(&mut self.out, &self.names)?;
7768 writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
7769 }
7770 crate::TypeInner::Array { base, size, .. } => {
7771 let count = match size.resolve(module.to_ctx())? {
7772 proc::IndexableLength::Known(count) => count,
7773 proc::IndexableLength::Dynamic => unreachable!(),
7774 };
7775
7776 access_stack.enter_array(|access_stack, array_depth| {
7777 writeln!(
7778 self.out,
7779 "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
7780 )?;
7781 self.write_workgroup_variable_initialization(
7782 module,
7783 module_info,
7784 base,
7785 access_stack,
7786 level.next(),
7787 )?;
7788 writeln!(self.out, "{level}}}")?;
7789 BackendResult::Ok(())
7790 })?;
7791 }
7792 crate::TypeInner::Struct { ref members, .. } => {
7793 for (index, member) in members.iter().enumerate() {
7794 access_stack.enter(
7795 Access::StructMember(ty, index as u32),
7796 |access_stack| {
7797 self.write_workgroup_variable_initialization(
7798 module,
7799 module_info,
7800 member.ty,
7801 access_stack,
7802 level,
7803 )
7804 },
7805 )?;
7806 }
7807 }
7808 _ => unreachable!(),
7809 }
7810 }
7811
7812 Ok(())
7813 }
7814 }
7815}
7816
7817impl crate::AtomicFunction {
7818 const fn to_msl(self) -> &'static str {
7819 match self {
7820 Self::Add => "fetch_add",
7821 Self::Subtract => "fetch_sub",
7822 Self::And => "fetch_and",
7823 Self::InclusiveOr => "fetch_or",
7824 Self::ExclusiveOr => "fetch_xor",
7825 Self::Min => "fetch_min",
7826 Self::Max => "fetch_max",
7827 Self::Exchange { compare: None } => "exchange",
7828 Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION,
7829 }
7830 }
7831
7832 fn to_msl_64_bit(self) -> Result<&'static str, Error> {
7833 Ok(match self {
7834 Self::Min => "min",
7835 Self::Max => "max",
7836 _ => Err(Error::FeatureNotImplemented(
7837 "64-bit atomic operation other than min/max".to_string(),
7838 ))?,
7839 })
7840 }
7841}