1use alloc::vec::Vec;
6
7use arrayvec::ArrayVec;
8use spirv::Word;
9
10use super::{
11 helpers::map_storage_class, index::BoundsCheckResult, selection::Selection, Block,
12 BlockContext, Dimension, Error, IdGenerator, Instruction, LocalType, LookupType, NumericType,
13 ResultMember, WrappedFunction, Writer, WriterFlags,
14};
15use crate::{
16 arena::Handle, back::spv::helpers::is_uniform_matcx2_struct_member_access,
17 proc::index::GuardedIndex, Statement,
18};
19
20fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
21 match *type_inner {
22 crate::TypeInner::Scalar(_) => Dimension::Scalar,
23 crate::TypeInner::Vector { .. } => Dimension::Vector,
24 crate::TypeInner::Matrix { .. } => Dimension::Matrix,
25 crate::TypeInner::CooperativeMatrix { .. } => Dimension::CooperativeMatrix,
26 _ => unreachable!(),
27 }
28}
29
/// Selects how `BlockContext::write_access_chain` picks the SPIR-V result
/// type of the `OpAccessChain` it emits for a Naga IR access expression.
///
/// NOTE(review): the type-selection logic itself lives in
/// `write_access_chain`, which is outside this excerpt; the variant notes
/// below describe how callers in this file use each variant.
#[derive(Copy, Clone)]
enum AccessTypeAdjustment {
    /// No adjustment: use the type derived directly from the IR expression.
    ///
    /// Passed, for example, when loading a whole uniform matrix through
    /// `write_checked_load`.
    None,

    /// Introduce a pointer in the given storage class around the
    /// expression's type.
    ///
    /// Callers in this file pass `StorageClass::UniformConstant` when indexing
    /// a binding array, and then `OpLoad` the resource handle through the
    /// resulting pointer.
    IntroducePointer(spirv::StorageClass),

    /// Address the base through its std140-compatible type.
    ///
    /// Used when reading a `matCx2` member of a uniform struct, whose columns
    /// are then accessed individually using the member indices recorded in
    /// `std140_compat_uniform_types`.
    UseStd140CompatType,
}
89
/// Outcome of generating the access chain for a pointer expression.
enum ExpressionPointer {
    /// The pointer has been fully emitted; `pointer_id` can be used directly,
    /// e.g. as the operand of an `OpLoad`.
    Ready { pointer_id: Word },

    /// The access may only be performed when `condition` holds (presumably
    /// the result of a bounds check on a dynamic index — TODO confirm in
    /// `write_access_chain`).
    ///
    /// `access` is the not-yet-emitted instruction that computes the pointer;
    /// its `result_id` is the pointer's id. Callers in this file push `access`
    /// only inside the callback given to `write_conditional_indexed_load`.
    Conditional {
        condition: Word,
        access: Instruction,
    },
}
108
/// How generated code should leave a block once its statements have run.
///
/// NOTE(review): the code consuming this enum is outside this excerpt; the
/// variant descriptions below are inferred from names and fields.
enum BlockExit {
    /// Return from the current function.
    Return,
    /// Branch unconditionally to another block.
    Branch {
        /// SPIR-V label id of the block to branch to.
        target: Word,
    },
    /// Presumably the terminator of a loop's `continuing` block: test
    /// `condition` and either leave the loop or branch back — TODO confirm.
    BreakIf {
        /// The IR boolean expression to evaluate.
        condition: Handle<crate::Expression>,
        /// Presumably the label id of the loop's preamble/header block that is
        /// branched to when `condition` is false — TODO confirm.
        preamble_id: Word,
    },
}
130
/// Reports what generated code did with the `BlockExit` it was handed.
///
/// `#[must_use]` makes the compiler warn when a caller ignores this value.
///
/// NOTE(review): the functions returning this type are outside this excerpt;
/// presumably callers use it to decide whether the exit's target block still
/// needs to be emitted — confirm against those callers.
#[must_use]
enum BlockExitDisposition {
    /// The generated code made use of the provided `BlockExit`.
    Used,

    /// The generated code did not use the provided `BlockExit` (presumably
    /// because control already left the block some other way, such as a
    /// `break`, `continue`, or `return`).
    Discarded,
}
154
/// Branch targets available to statements nested in the loop being generated.
///
/// The derived `Default` sets both fields to `None`, presumably meaning
/// "not inside a loop".
#[derive(Clone, Copy, Default)]
struct LoopContext {
    /// Presumably the label id a `continue` branches to, if any.
    continuing_id: Option<Word>,
    /// Presumably the label id a `break` branches to, if any.
    break_id: Option<Word>,
}
160
/// Source-level information made available for emitting debug info.
///
/// NOTE(review): consumers are outside this excerpt; presumably only used
/// when debug output is enabled — confirm.
#[derive(Debug)]
pub(crate) struct DebugInfoInner<'a> {
    /// The shader's source text, borrowed for lifetime `'a`.
    pub source_code: &'a str,
    /// SPIR-V id identifying the source file (presumably an `OpString`
    /// result — TODO confirm).
    pub source_file_id: Word,
}
166
167impl Writer {
168 fn write_epilogue_position_y_flip(
173 &mut self,
174 position_id: Word,
175 body: &mut Vec<Instruction>,
176 ) -> Result<(), Error> {
177 let float_ptr_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Output);
178 let index_y_id = self.get_index_constant(1);
179 let access_id = self.id_gen.next();
180 body.push(Instruction::access_chain(
181 float_ptr_type_id,
182 access_id,
183 position_id,
184 &[index_y_id],
185 ));
186
187 let float_type_id = self.get_f32_type_id();
188 let load_id = self.id_gen.next();
189 body.push(Instruction::load(float_type_id, load_id, access_id, None));
190
191 let neg_id = self.id_gen.next();
192 body.push(Instruction::unary(
193 spirv::Op::FNegate,
194 float_type_id,
195 neg_id,
196 load_id,
197 ));
198
199 body.push(Instruction::store(access_id, neg_id, None));
200 Ok(())
201 }
202
203 fn write_epilogue_frag_depth_clamp(
205 &mut self,
206 frag_depth_id: Word,
207 body: &mut Vec<Instruction>,
208 ) -> Result<(), Error> {
209 let float_type_id = self.get_f32_type_id();
210 let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
211 let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));
212
213 let original_id = self.id_gen.next();
214 body.push(Instruction::load(
215 float_type_id,
216 original_id,
217 frag_depth_id,
218 None,
219 ));
220
221 let clamp_id = self.id_gen.next();
222 body.push(Instruction::ext_inst_gl_op(
223 self.gl450_ext_inst_id,
224 spirv::GlslStd450Op::FClamp,
225 float_type_id,
226 clamp_id,
227 &[original_id, zero_scalar_id, one_scalar_id],
228 ));
229
230 body.push(Instruction::store(frag_depth_id, clamp_id, None));
231 Ok(())
232 }
233
234 fn write_entry_point_return(
235 &mut self,
236 value_id: Word,
237 ir_result: &crate::FunctionResult,
238 result_members: &[ResultMember],
239 body: &mut Vec<Instruction>,
240 ) -> Result<Instruction, Error> {
241 for (index, res_member) in result_members.iter().enumerate() {
242 if res_member.built_in == Some(crate::BuiltIn::MeshTaskSize) {
244 return Ok(Instruction::return_value(value_id));
245 }
246 let member_value_id = match ir_result.binding {
247 Some(_) => value_id,
248 None => {
249 let member_value_id = self.id_gen.next();
250 body.push(Instruction::composite_extract(
251 res_member.type_id,
252 member_value_id,
253 value_id,
254 &[index as u32],
255 ));
256 member_value_id
257 }
258 };
259
260 self.store_io_with_f16_polyfill(body, res_member.id, member_value_id);
261
262 match res_member.built_in {
263 Some(crate::BuiltIn::Position { .. })
264 if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
265 {
266 self.write_epilogue_position_y_flip(res_member.id, body)?;
267 }
268 Some(crate::BuiltIn::FragDepth)
269 if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
270 {
271 self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
272 }
273 _ => {}
274 }
275 }
276 Ok(Instruction::return_void())
277 }
278}
279
280impl BlockContext<'_> {
    /// Emit a guard that forces the loop being generated to terminate after a
    /// bounded number of passes.
    ///
    /// A function-local `vec2<u32>` counter variable is registered in
    /// `self.function.force_loop_bounding_vars`. It is initialized to
    /// `(u32::MAX, u32::MAX)` and named `loop_bound` when
    /// `WriterFlags::DEBUG` is set. The emitted code:
    ///
    /// 1. ends `block` with an unconditional branch to a fresh check block;
    /// 2. in the check block, loads the counter and branches to `merge_id`
    ///    (leaving the loop) when both components equal zero;
    /// 3. otherwise continues in a new block. That block is also named as
    ///    the selection merge. It decrements the counter as a 64-bit value
    ///    split across two `u32`s, with component 1 as the low word and
    ///    component 0 as the high word, and stores the result back.
    ///
    /// The counter therefore starts at 2^64 - 1 and drops by one per pass.
    ///
    /// Returns the unterminated decrement block. Presumably the caller keeps
    /// emitting into it.
    fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
        let uint_type_id = self.writer.get_u32_type_id();
        let uint2_type_id = self.writer.get_vec2u_type_id();
        let uint2_ptr_type_id = self
            .writer
            .get_vec2u_pointer_type_id(spirv::StorageClass::Function);
        let bool_type_id = self.writer.get_bool_type_id();
        let bool2_type_id = self.writer.get_vec2_bool_type_id();
        let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
        let zero_uint2_const_id = self.writer.get_constant_composite(
            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::U32,
            })),
            &[zero_uint_const_id, zero_uint_const_id],
        );
        let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
        let max_uint_const_id = self
            .writer
            .get_constant_scalar(crate::Literal::U32(u32::MAX));
        let max_uint2_const_id = self.writer.get_constant_composite(
            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::U32,
            })),
            &[max_uint_const_id, max_uint_const_id],
        );

        // The counter variable, initialized to (u32::MAX, u32::MAX).
        let loop_counter_var_id = self.gen_id();
        if self.writer.flags.contains(WriterFlags::DEBUG) {
            self.writer
                .debugs
                .push(Instruction::name(loop_counter_var_id, "loop_bound"));
        }
        let var = super::LocalVariable {
            id: loop_counter_var_id,
            instruction: Instruction::variable(
                uint2_ptr_type_id,
                loop_counter_var_id,
                spirv::StorageClass::Function,
                Some(max_uint2_const_id),
            ),
        };
        self.function.force_loop_bounding_vars.push(var);

        // Close the incoming block with a jump to a new block holding the check.
        let break_if_block = self.gen_id();

        self.function
            .consume(block, Instruction::branch(break_if_block));
        block = Block::new(break_if_block);

        // Read the current counter value.
        let load_id = self.gen_id();
        block.body.push(Instruction::load(
            uint2_type_id,
            load_id,
            loop_counter_var_id,
            None,
        ));

        // all(counter == (0, 0)): true once both halves have run down to zero.
        let eq_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool2_type_id,
            eq_id,
            zero_uint2_const_id,
            load_id,
        ));
        let all_eq_id = self.gen_id();
        block.body.push(Instruction::relational(
            spirv::Op::All,
            bool_type_id,
            all_eq_id,
            eq_id,
        ));

        // Exhausted: branch to the loop's merge block. Otherwise continue into
        // the decrement block, which is also this selection's merge block.
        let inc_counter_block_id = self.gen_id();
        block.body.push(Instruction::selection_merge(
            inc_counter_block_id,
            spirv::SelectionControl::empty(),
        ));
        self.function.consume(
            block,
            Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
        );
        block = Block::new(inc_counter_block_id);

        // 64-bit decrement over two u32 lanes. The low word (component 1)
        // always drops by one. The high word (component 0) drops by one only
        // when the low word is currently zero, i.e. about to wrap to u32::MAX.
        let low_id = self.gen_id();
        block.body.push(Instruction::composite_extract(
            uint_type_id,
            low_id,
            load_id,
            &[1],
        ));
        let low_overflow_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool_type_id,
            low_overflow_id,
            low_id,
            zero_uint_const_id,
        ));
        // carry = low == 0 ? 1 : 0
        let carry_bit_id = self.gen_id();
        block.body.push(Instruction::select(
            uint_type_id,
            carry_bit_id,
            low_overflow_id,
            one_uint_const_id,
            zero_uint_const_id,
        ));
        // decrement = (carry, 1); counter -= decrement
        let decrement_id = self.gen_id();
        block.body.push(Instruction::composite_construct(
            uint2_type_id,
            decrement_id,
            &[carry_bit_id, one_uint_const_id],
        ));
        let result_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::ISub,
            uint2_type_id,
            result_id,
            load_id,
            decrement_id,
        ));
        block
            .body
            .push(Instruction::store(loop_counter_var_id, result_id, None));

        block
    }
431
432 fn maybe_write_uniform_matcx2_dynamic_access(
449 &mut self,
450 pointer: Handle<crate::Expression>,
451 block: &mut Block,
452 ) -> Result<Option<Word>, Error> {
453 let (column_pointer, component_index) = match self.fun_info[pointer]
459 .ty
460 .inner_with(&self.ir_module.types)
461 .pointer_base_type()
462 {
463 Some(resolution) => match *resolution.inner_with(&self.ir_module.types) {
464 crate::TypeInner::Scalar(_) => match self.ir_function.expressions[pointer] {
465 crate::Expression::Access { base, index } => {
466 (base, Some(GuardedIndex::Expression(index)))
467 }
468 crate::Expression::AccessIndex { base, index } => {
469 (base, Some(GuardedIndex::Known(index)))
470 }
471 _ => return Ok(None),
472 },
473 crate::TypeInner::Vector { .. } => (pointer, None),
474 _ => return Ok(None),
475 },
476 None => return Ok(None),
477 };
478
479 let crate::Expression::Access {
482 base: matrix_pointer,
483 index: column_index,
484 } = self.ir_function.expressions[column_pointer]
485 else {
486 return Ok(None);
487 };
488
489 let crate::TypeInner::Pointer {
491 base: matrix_pointer_base_type,
492 space: crate::AddressSpace::Uniform,
493 } = *self.fun_info[matrix_pointer]
494 .ty
495 .inner_with(&self.ir_module.types)
496 else {
497 return Ok(None);
498 };
499
500 let crate::TypeInner::Matrix {
502 columns,
503 rows: rows @ crate::VectorSize::Bi,
504 scalar,
505 } = self.ir_module.types[matrix_pointer_base_type].inner
506 else {
507 return Ok(None);
508 };
509
510 let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
511 columns,
512 rows,
513 scalar,
514 });
515 let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
516 let component_type_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
517 let get_column_function_id = self.writer.wrapped_functions
518 [&WrappedFunction::MatCx2GetColumn {
519 r#type: matrix_pointer_base_type,
520 }];
521
522 let matrix_load_id = self.write_checked_load(
523 matrix_pointer,
524 block,
525 AccessTypeAdjustment::None,
526 matrix_type_id,
527 )?;
528
529 let column_index_id = match *self.fun_info[column_index]
532 .ty
533 .inner_with(&self.ir_module.types)
534 {
535 crate::TypeInner::Scalar(crate::Scalar {
536 kind: crate::ScalarKind::Uint,
537 ..
538 }) => self.cached[column_index],
539 crate::TypeInner::Scalar(crate::Scalar {
540 kind: crate::ScalarKind::Sint,
541 ..
542 }) => {
543 let cast_id = self.gen_id();
544 let u32_type_id = self.writer.get_u32_type_id();
545 block.body.push(Instruction::unary(
546 spirv::Op::Bitcast,
547 u32_type_id,
548 cast_id,
549 self.cached[column_index],
550 ));
551 cast_id
552 }
553 _ => return Err(Error::Validation("Matrix access index must be u32 or i32")),
554 };
555 let column_id = self.gen_id();
556 block.body.push(Instruction::function_call(
557 column_type_id,
558 column_id,
559 get_column_function_id,
560 &[matrix_load_id, column_index_id],
561 ));
562 let result_id = match component_index {
563 Some(index) => self.write_vector_access(
564 component_type_id,
565 column_pointer,
566 Some(column_id),
567 index,
568 block,
569 )?,
570 None => column_id,
571 };
572
573 Ok(Some(result_id))
574 }
575
    /// Load a two-row (`matCx2`) matrix member of a struct in the `Uniform`
    /// address space by loading each column separately and reassembling the
    /// matrix.
    ///
    /// Applies only when `pointer` meets all of these conditions:
    /// - it is an `AccessIndex` into a pointer to a struct;
    /// - its own type is a pointer, in `AddressSpace::Uniform`, to a matrix
    ///   with two rows.
    ///
    /// Returns `Ok(None)` otherwise, so the caller can use its regular load
    /// path.
    ///
    /// The struct is addressed through its std140-compatible type
    /// (`AccessTypeAdjustment::UseStd140CompatType`). In that type the
    /// matrix's columns are consecutive members, starting at the index
    /// recorded in
    /// `std140_compat_uniform_types[&struct_type].member_indices[member_index]`.
    /// Each column is read with its own `OpAccessChain` + `OpLoad`, and the
    /// columns are combined with `OpCompositeConstruct`.
    ///
    /// If the struct access is conditional (`ExpressionPointer::Conditional`),
    /// the loads are emitted through `write_conditional_indexed_load`,
    /// presumably so they only execute when the condition holds.
    ///
    /// NOTE(review): indexing `std140_compat_uniform_types` assumes an entry
    /// exists for every uniform struct that reaches this point. A missing
    /// entry panics; confirm the registration covers all such structs.
    fn maybe_write_load_uniform_matcx2_struct_member(
        &mut self,
        pointer: Handle<crate::Expression>,
        block: &mut Block,
    ) -> Result<Option<Word>, Error> {
        // `pointer` must point to a matrix in the `Uniform` address space...
        let crate::TypeInner::Pointer {
            base: matrix_type,
            space: space @ crate::AddressSpace::Uniform,
        } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
        else {
            return Ok(None);
        };

        // ...that has exactly two rows...
        let crate::TypeInner::Matrix {
            columns,
            rows: rows @ crate::VectorSize::Bi,
            scalar,
        } = self.ir_module.types[matrix_type].inner
        else {
            return Ok(None);
        };

        // ...and be a constant-index member access...
        let crate::Expression::AccessIndex {
            base: struct_pointer,
            index: member_index,
        } = self.ir_function.expressions[pointer]
        else {
            return Ok(None);
        };

        // ...whose base is a pointer...
        let crate::TypeInner::Pointer {
            base: struct_type, ..
        } = *self.fun_info[struct_pointer]
            .ty
            .inner_with(&self.ir_module.types)
        else {
            return Ok(None);
        };

        // ...to a struct.
        let crate::TypeInner::Struct { .. } = self.ir_module.types[struct_type].inner else {
            return Ok(None);
        };

        let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
            columns,
            rows,
            scalar,
        });
        let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
        let column_pointer_type_id =
            self.get_pointer_type_id(column_type_id, map_storage_class(space));
        // In the std140-compatible struct the matrix occupies one member per
        // column, starting at `column0_index`.
        let column0_index = self.writer.std140_compat_uniform_types[&struct_type].member_indices
            [member_index as usize];
        let column_indices = (0..columns as u32)
            .map(|c| self.get_index_constant(column0_index + c))
            .collect::<ArrayVec<_, 4>>();

        // Emits one access chain + load per column, then rebuilds the matrix.
        // The id generator is passed in explicitly so this can also run inside
        // the callback given to `write_conditional_indexed_load`, which
        // supplies its own `&mut IdGenerator`.
        let load_mat_from_struct =
            |struct_pointer_id: Word, id_gen: &mut IdGenerator, block: &mut Block| -> Word {
                let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
                for index in &column_indices {
                    let column_pointer_id = id_gen.next();
                    block.body.push(Instruction::access_chain(
                        column_pointer_type_id,
                        column_pointer_id,
                        struct_pointer_id,
                        &[*index],
                    ));
                    let column_id = id_gen.next();
                    block.body.push(Instruction::load(
                        column_type_id,
                        column_id,
                        column_pointer_id,
                        None,
                    ));
                    column_ids.push(column_id);
                }
                let result_id = id_gen.next();
                block.body.push(Instruction::composite_construct(
                    matrix_type_id,
                    result_id,
                    &column_ids,
                ));
                result_id
            };

        // Address the struct through its std140-compatible type.
        let result_id = match self.write_access_chain(
            struct_pointer,
            block,
            AccessTypeAdjustment::UseStd140CompatType,
        )? {
            ExpressionPointer::Ready { pointer_id } => {
                load_mat_from_struct(pointer_id, &mut self.writer.id_gen, block)
            }
            ExpressionPointer::Conditional { condition, access } => self
                .write_conditional_indexed_load(
                    matrix_type_id,
                    condition,
                    block,
                    |id_gen, block| {
                        // Emit the deferred access chain first; its result is
                        // the struct pointer the column loads go through.
                        let pointer_id = access.result_id.unwrap();
                        block.body.push(access);
                        load_mat_from_struct(pointer_id, id_gen, block)
                    },
                ),
        };

        Ok(Some(result_id))
    }
701
702 pub(super) fn cache_expression_value(
704 &mut self,
705 expr_handle: Handle<crate::Expression>,
706 block: &mut Block,
707 ) -> Result<(), Error> {
708 let is_named_expression = self
709 .ir_function
710 .named_expressions
711 .contains_key(&expr_handle);
712
713 if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression {
714 return Ok(());
715 }
716
717 let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
718 let id = match self.ir_function.expressions[expr_handle] {
719 crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
720 crate::Expression::Constant(handle) => {
721 let init = self.ir_module.constants[handle].init;
722 self.writer.constant_ids[init]
723 }
724 crate::Expression::Override(_) => return Err(Error::Override),
725 crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
726 crate::Expression::Compose { ty, ref components } => {
727 self.temp_list.clear();
728 if self.expression_constness.is_const(expr_handle) {
729 self.temp_list.extend(
730 crate::proc::flatten_compose(
731 ty,
732 components,
733 &self.ir_function.expressions,
734 &self.ir_module.types,
735 )
736 .map(|component| self.cached[component]),
737 );
738 self.writer
739 .get_constant_composite(LookupType::Handle(ty), &self.temp_list)
740 } else {
741 self.temp_list
742 .extend(components.iter().map(|&component| self.cached[component]));
743
744 let id = self.gen_id();
745 block.body.push(Instruction::composite_construct(
746 result_type_id,
747 id,
748 &self.temp_list,
749 ));
750 id
751 }
752 }
753 crate::Expression::Splat { size, value } => {
754 let value_id = self.cached[value];
755 let components = &[value_id; 4][..size as usize];
756
757 if self.expression_constness.is_const(expr_handle) {
758 let ty = self
759 .writer
760 .get_expression_lookup_type(&self.fun_info[expr_handle].ty);
761 self.writer.get_constant_composite(ty, components)
762 } else {
763 let id = self.gen_id();
764 block.body.push(Instruction::composite_construct(
765 result_type_id,
766 id,
767 components,
768 ));
769 id
770 }
771 }
772 crate::Expression::Access { base, index } => {
773 let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
774 match *base_ty_inner {
775 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
776 0
782 }
783 _ if self.function.spilled_accesses.contains(base) => {
784 self.function.spilled_accesses.insert(expr_handle);
792 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
793 }
794 crate::TypeInner::Vector { .. } => self.write_vector_access(
795 result_type_id,
796 base,
797 None,
798 GuardedIndex::Expression(index),
799 block,
800 )?,
801 crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => {
802 match GuardedIndex::from_expression(
804 index,
805 &self.ir_function.expressions,
806 self.ir_module,
807 ) {
808 GuardedIndex::Known(value) => {
809 let id = self.gen_id();
819 let base_id = self.cached[base];
820 block.body.push(Instruction::composite_extract(
821 result_type_id,
822 id,
823 base_id,
824 &[value],
825 ));
826 id
827 }
828 GuardedIndex::Expression(_) => {
829 self.spill_to_internal_variable(base, block);
836
837 self.function.spilled_accesses.insert(expr_handle);
840 self.maybe_access_spilled_composite(
841 expr_handle,
842 block,
843 result_type_id,
844 )?
845 }
846 }
847 }
848 crate::TypeInner::BindingArray {
849 base: binding_type, ..
850 } => {
851 let result_id = match self.write_access_chain(
854 expr_handle,
855 block,
856 AccessTypeAdjustment::IntroducePointer(
857 spirv::StorageClass::UniformConstant,
858 ),
859 )? {
860 ExpressionPointer::Ready { pointer_id } => pointer_id,
861 ExpressionPointer::Conditional { .. } => {
862 return Err(Error::FeatureNotImplemented(
863 "Texture array out-of-bounds handling",
864 ));
865 }
866 };
867
868 let binding_type_id = self.get_handle_type_id(binding_type);
869
870 let load_id = self.gen_id();
871 block.body.push(Instruction::load(
872 binding_type_id,
873 load_id,
874 result_id,
875 None,
876 ));
877
878 if self.fun_info[index].uniformity.non_uniform_result.is_some() {
882 self.writer
883 .decorate_non_uniform_binding_array_access(load_id)?;
884 }
885
886 load_id
887 }
888 ref other => {
889 log::error!(
890 "Unable to access base {:?} of type {:?}",
891 self.ir_function.expressions[base],
892 other
893 );
894 return Err(Error::Validation(
895 "only vectors and arrays may be dynamically indexed by value",
896 ));
897 }
898 }
899 }
900 crate::Expression::AccessIndex { base, index } => {
901 match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
902 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
903 0
909 }
910 _ if self.function.spilled_accesses.contains(base) => {
911 self.function.spilled_accesses.insert(expr_handle);
919 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
920 }
921 crate::TypeInner::Vector { .. }
922 | crate::TypeInner::Matrix { .. }
923 | crate::TypeInner::Array { .. }
924 | crate::TypeInner::Struct { .. } => {
925 let id = self.gen_id();
930 let base_id = self.cached[base];
931 block.body.push(Instruction::composite_extract(
932 result_type_id,
933 id,
934 base_id,
935 &[index],
936 ));
937 id
938 }
939 crate::TypeInner::BindingArray {
940 base: binding_type, ..
941 } => {
942 let result_id = match self.write_access_chain(
945 expr_handle,
946 block,
947 AccessTypeAdjustment::IntroducePointer(
948 spirv::StorageClass::UniformConstant,
949 ),
950 )? {
951 ExpressionPointer::Ready { pointer_id } => pointer_id,
952 ExpressionPointer::Conditional { .. } => {
953 return Err(Error::FeatureNotImplemented(
954 "Texture array out-of-bounds handling",
955 ));
956 }
957 };
958
959 let binding_type_id = self.get_handle_type_id(binding_type);
960
961 let load_id = self.gen_id();
962 block.body.push(Instruction::load(
963 binding_type_id,
964 load_id,
965 result_id,
966 None,
967 ));
968
969 load_id
970 }
971 ref other => {
972 log::error!("Unable to access index of {other:?}");
973 return Err(Error::FeatureNotImplemented("access index for type"));
974 }
975 }
976 }
977 crate::Expression::GlobalVariable(handle) => {
978 self.writer.global_variables[handle].access_id
979 }
980 crate::Expression::Swizzle {
981 size,
982 vector,
983 pattern,
984 } => {
985 let vector_id = self.cached[vector];
986 self.temp_list.clear();
987 for &sc in pattern[..size as usize].iter() {
988 self.temp_list.push(sc as Word);
989 }
990 let id = self.gen_id();
991 block.body.push(Instruction::vector_shuffle(
992 result_type_id,
993 id,
994 vector_id,
995 vector_id,
996 &self.temp_list,
997 ));
998 id
999 }
1000 crate::Expression::Unary { op, expr } => {
1001 let id = self.gen_id();
1002 let expr_id = self.cached[expr];
1003 let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
1004
1005 let spirv_op = match op {
1006 crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
1007 Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
1008 Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
1009 _ => return Err(Error::Validation("Unexpected kind for negation")),
1010 },
1011 crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot,
1012 crate::UnaryOperator::BitwiseNot => spirv::Op::Not,
1013 };
1014
1015 block
1016 .body
1017 .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
1018 id
1019 }
1020 crate::Expression::Binary { op, left, right } => {
1021 let id = self.gen_id();
1022 let left_id = self.cached[left];
1023 let right_id = self.cached[right];
1024 let left_type_id = self.get_expression_type_id(&self.fun_info[left].ty);
1025 let right_type_id = self.get_expression_type_id(&self.fun_info[right].ty);
1026
1027 if let Some(function_id) =
1028 self.writer
1029 .wrapped_functions
1030 .get(&WrappedFunction::BinaryOp {
1031 op,
1032 left_type_id,
1033 right_type_id,
1034 })
1035 {
1036 block.body.push(Instruction::function_call(
1037 result_type_id,
1038 id,
1039 *function_id,
1040 &[left_id, right_id],
1041 ));
1042 } else {
1043 let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
1044 let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
1045
1046 let left_dimension = get_dimension(left_ty_inner);
1047 let right_dimension = get_dimension(right_ty_inner);
1048
1049 let mut reverse_operands = false;
1050
1051 let spirv_op = match op {
1052 crate::BinaryOperator::Add => match *left_ty_inner {
1053 crate::TypeInner::Scalar(scalar)
1054 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1055 crate::ScalarKind::Float => spirv::Op::FAdd,
1056 _ => spirv::Op::IAdd,
1057 },
1058 crate::TypeInner::Matrix {
1059 columns,
1060 rows,
1061 scalar,
1062 } => {
1063 self.write_matrix_matrix_column_op(
1065 block,
1066 id,
1067 result_type_id,
1068 left_id,
1069 right_id,
1070 columns,
1071 rows,
1072 scalar.width,
1073 spirv::Op::FAdd,
1074 );
1075
1076 self.cached[expr_handle] = id;
1077 return Ok(());
1078 }
1079 crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FAdd,
1080 _ => unimplemented!(),
1081 },
1082 crate::BinaryOperator::Subtract => match *left_ty_inner {
1083 crate::TypeInner::Scalar(scalar)
1084 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1085 crate::ScalarKind::Float => spirv::Op::FSub,
1086 _ => spirv::Op::ISub,
1087 },
1088 crate::TypeInner::Matrix {
1089 columns,
1090 rows,
1091 scalar,
1092 } => {
1093 self.write_matrix_matrix_column_op(
1094 block,
1095 id,
1096 result_type_id,
1097 left_id,
1098 right_id,
1099 columns,
1100 rows,
1101 scalar.width,
1102 spirv::Op::FSub,
1103 );
1104
1105 self.cached[expr_handle] = id;
1106 return Ok(());
1107 }
1108 crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FSub,
1109 _ => unimplemented!(),
1110 },
1111 crate::BinaryOperator::Multiply => {
1112 match (left_dimension, right_dimension) {
1113 (Dimension::Scalar, Dimension::Vector) => {
1114 self.write_vector_scalar_mult(
1115 block,
1116 id,
1117 result_type_id,
1118 right_id,
1119 left_id,
1120 right_ty_inner,
1121 );
1122
1123 self.cached[expr_handle] = id;
1124 return Ok(());
1125 }
1126 (Dimension::Vector, Dimension::Scalar) => {
1127 self.write_vector_scalar_mult(
1128 block,
1129 id,
1130 result_type_id,
1131 left_id,
1132 right_id,
1133 left_ty_inner,
1134 );
1135
1136 self.cached[expr_handle] = id;
1137 return Ok(());
1138 }
1139 (Dimension::Vector, Dimension::Matrix) => {
1140 spirv::Op::VectorTimesMatrix
1141 }
1142 (Dimension::Matrix, Dimension::Scalar)
1143 | (Dimension::CooperativeMatrix, Dimension::Scalar) => {
1144 spirv::Op::MatrixTimesScalar
1145 }
1146 (Dimension::Scalar, Dimension::Matrix)
1147 | (Dimension::Scalar, Dimension::CooperativeMatrix) => {
1148 reverse_operands = true;
1149 spirv::Op::MatrixTimesScalar
1150 }
1151 (Dimension::Matrix, Dimension::Vector) => {
1152 spirv::Op::MatrixTimesVector
1153 }
1154 (Dimension::Matrix, Dimension::Matrix) => {
1155 spirv::Op::MatrixTimesMatrix
1156 }
1157 (Dimension::Vector, Dimension::Vector)
1158 | (Dimension::Scalar, Dimension::Scalar)
1159 if left_ty_inner.scalar_kind()
1160 == Some(crate::ScalarKind::Float) =>
1161 {
1162 spirv::Op::FMul
1163 }
1164 (Dimension::Vector, Dimension::Vector)
1165 | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
1166 (Dimension::CooperativeMatrix, Dimension::CooperativeMatrix)
1167 | (Dimension::CooperativeMatrix, _)
1169 | (_, Dimension::CooperativeMatrix) => {
1170 unimplemented!()
1171 }
1172 }
1173 }
1174 crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
1175 Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
1176 Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
1177 Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
1178 _ => unimplemented!(),
1179 },
1180 crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
1181 Some(crate::ScalarKind::Float) => spirv::Op::FRem,
1184 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1185 unreachable!("Should have been handled by wrapped function")
1186 }
1187 _ => unimplemented!(),
1188 },
1189 crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
1190 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1191 spirv::Op::IEqual
1192 }
1193 Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
1194 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
1195 _ => unimplemented!(),
1196 },
1197 crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
1198 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1199 spirv::Op::INotEqual
1200 }
1201 Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
1202 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
1203 _ => unimplemented!(),
1204 },
1205 crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
1206 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
1207 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
1208 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
1209 _ => unimplemented!(),
1210 },
1211 crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
1212 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
1213 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
1214 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
1215 _ => unimplemented!(),
1216 },
1217 crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
1218 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
1219 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
1220 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
1221 _ => unimplemented!(),
1222 },
1223 crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
1224 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
1225 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
1226 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
1227 _ => unimplemented!(),
1228 },
1229 crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
1230 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
1231 _ => spirv::Op::BitwiseAnd,
1232 },
1233 crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
1234 crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
1235 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
1236 _ => spirv::Op::BitwiseOr,
1237 },
1238 crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
1239 crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
1240 crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
1241 crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
1242 Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
1243 Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
1244 _ => unimplemented!(),
1245 },
1246 };
1247
1248 block.body.push(Instruction::binary(
1249 spirv_op,
1250 result_type_id,
1251 id,
1252 if reverse_operands { right_id } else { left_id },
1253 if reverse_operands { left_id } else { right_id },
1254 ));
1255 }
1256 id
1257 }
1258 crate::Expression::Math {
1259 fun,
1260 arg,
1261 arg1,
1262 arg2,
1263 arg3,
1264 } => {
1265 use crate::MathFunction as Mf;
1266 enum MathOp {
1267 Ext(spirv::GlslStd450Op),
1268 Custom(Instruction),
1269 }
1270
1271 let arg0_id = self.cached[arg];
1272 let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
1273 let arg_scalar_kind = arg_ty.scalar_kind();
1274 let arg1_id = match arg1 {
1275 Some(handle) => self.cached[handle],
1276 None => 0,
1277 };
1278 let arg2_id = match arg2 {
1279 Some(handle) => self.cached[handle],
1280 None => 0,
1281 };
1282 let arg3_id = match arg3 {
1283 Some(handle) => self.cached[handle],
1284 None => 0,
1285 };
1286
1287 let id = self.gen_id();
1288 let math_op = match fun {
1289 Mf::Abs => {
1291 match arg_scalar_kind {
1292 Some(crate::ScalarKind::Float) => {
1293 MathOp::Ext(spirv::GlslStd450Op::FAbs)
1294 }
1295 Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GlslStd450Op::SAbs),
1296 Some(crate::ScalarKind::Uint) => {
1297 MathOp::Custom(Instruction::unary(
1298 spirv::Op::CopyObject, result_type_id,
1300 id,
1301 arg0_id,
1302 ))
1303 }
1304 other => unimplemented!("Unexpected abs({:?})", other),
1305 }
1306 }
1307 Mf::Min => MathOp::Ext(match arg_scalar_kind {
1308 Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FMin,
1309 Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SMin,
1310 Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::UMin,
1311 other => unimplemented!("Unexpected min({:?})", other),
1312 }),
1313 Mf::Max => MathOp::Ext(match arg_scalar_kind {
1314 Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FMax,
1315 Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SMax,
1316 Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::UMax,
1317 other => unimplemented!("Unexpected max({:?})", other),
1318 }),
1319 Mf::Clamp => match arg_scalar_kind {
1320 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GlslStd450Op::FClamp),
1324 Some(_) => {
1325 let (min_op, max_op) = match arg_scalar_kind {
1326 Some(crate::ScalarKind::Sint) => {
1327 (spirv::GlslStd450Op::SMin, spirv::GlslStd450Op::SMax)
1328 }
1329 Some(crate::ScalarKind::Uint) => {
1330 (spirv::GlslStd450Op::UMin, spirv::GlslStd450Op::UMax)
1331 }
1332 _ => unreachable!(),
1333 };
1334
1335 let max_id = self.gen_id();
1336 block.body.push(Instruction::ext_inst_gl_op(
1337 self.writer.gl450_ext_inst_id,
1338 max_op,
1339 result_type_id,
1340 max_id,
1341 &[arg0_id, arg1_id],
1342 ));
1343
1344 MathOp::Custom(Instruction::ext_inst_gl_op(
1345 self.writer.gl450_ext_inst_id,
1346 min_op,
1347 result_type_id,
1348 id,
1349 &[max_id, arg2_id],
1350 ))
1351 }
1352 other => unimplemented!("Unexpected max({:?})", other),
1353 },
1354 Mf::Saturate => {
1355 let (maybe_size, scalar) = match *arg_ty {
1356 crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
1357 crate::TypeInner::Scalar(scalar) => (None, scalar),
1358 ref other => unimplemented!("Unexpected saturate({:?})", other),
1359 };
1360 let scalar = crate::Scalar::float(scalar.width);
1361 let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
1362 let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
1363
1364 if let Some(size) = maybe_size {
1365 let ty =
1366 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1367
1368 self.temp_list.clear();
1369 self.temp_list.resize(size as _, arg1_id);
1370
1371 arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
1372
1373 self.temp_list.fill(arg2_id);
1374
1375 arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
1376 }
1377
1378 MathOp::Custom(Instruction::ext_inst_gl_op(
1379 self.writer.gl450_ext_inst_id,
1380 spirv::GlslStd450Op::FClamp,
1381 result_type_id,
1382 id,
1383 &[arg0_id, arg1_id, arg2_id],
1384 ))
1385 }
1386 Mf::Sin => MathOp::Ext(spirv::GlslStd450Op::Sin),
1388 Mf::Sinh => MathOp::Ext(spirv::GlslStd450Op::Sinh),
1389 Mf::Asin => MathOp::Ext(spirv::GlslStd450Op::Asin),
1390 Mf::Cos => MathOp::Ext(spirv::GlslStd450Op::Cos),
1391 Mf::Cosh => MathOp::Ext(spirv::GlslStd450Op::Cosh),
1392 Mf::Acos => MathOp::Ext(spirv::GlslStd450Op::Acos),
1393 Mf::Tan => MathOp::Ext(spirv::GlslStd450Op::Tan),
1394 Mf::Tanh => MathOp::Ext(spirv::GlslStd450Op::Tanh),
1395 Mf::Atan => MathOp::Ext(spirv::GlslStd450Op::Atan),
1396 Mf::Atan2 => MathOp::Ext(spirv::GlslStd450Op::Atan2),
1397 Mf::Asinh => MathOp::Ext(spirv::GlslStd450Op::Asinh),
1398 Mf::Acosh => MathOp::Ext(spirv::GlslStd450Op::Acosh),
1399 Mf::Atanh => MathOp::Ext(spirv::GlslStd450Op::Atanh),
1400 Mf::Radians => MathOp::Ext(spirv::GlslStd450Op::Radians),
1401 Mf::Degrees => MathOp::Ext(spirv::GlslStd450Op::Degrees),
1402 Mf::Ceil => MathOp::Ext(spirv::GlslStd450Op::Ceil),
1404 Mf::Round => MathOp::Ext(spirv::GlslStd450Op::RoundEven),
1405 Mf::Floor => MathOp::Ext(spirv::GlslStd450Op::Floor),
1406 Mf::Fract => MathOp::Ext(spirv::GlslStd450Op::Fract),
1407 Mf::Trunc => MathOp::Ext(spirv::GlslStd450Op::Trunc),
1408 Mf::Modf => MathOp::Ext(spirv::GlslStd450Op::ModfStruct),
1409 Mf::Frexp => MathOp::Ext(spirv::GlslStd450Op::FrexpStruct),
1410 Mf::Ldexp => MathOp::Ext(spirv::GlslStd450Op::Ldexp),
1411 Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
1413 crate::TypeInner::Vector {
1414 scalar:
1415 crate::Scalar {
1416 kind: crate::ScalarKind::Float,
1417 ..
1418 },
1419 ..
1420 } => MathOp::Custom(Instruction::binary(
1421 spirv::Op::Dot,
1422 result_type_id,
1423 id,
1424 arg0_id,
1425 arg1_id,
1426 )),
1427 crate::TypeInner::Vector { size, .. } => {
1429 self.write_dot_product(
1430 id,
1431 result_type_id,
1432 arg0_id,
1433 arg1_id,
1434 size as u32,
1435 block,
1436 |result_id, composite_id, index| {
1437 Instruction::composite_extract(
1438 result_type_id,
1439 result_id,
1440 composite_id,
1441 &[index],
1442 )
1443 },
1444 );
1445 self.cached[expr_handle] = id;
1446 return Ok(());
1447 }
1448 _ => unreachable!(
1449 "Correct TypeInner for dot product should be already validated"
1450 ),
1451 },
1452 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
1453 if self
1454 .writer
1455 .require_all(&[
1456 spirv::Capability::DotProduct,
1457 spirv::Capability::DotProductInput4x8BitPacked,
1458 ])
1459 .is_ok()
1460 {
1461 if self.writer.lang_version() < (1, 6) {
1463 self.writer.use_extension("SPV_KHR_integer_dot_product");
1467 }
1468
1469 let op = match fun {
1470 Mf::Dot4I8Packed => spirv::Op::SDot,
1471 Mf::Dot4U8Packed => spirv::Op::UDot,
1472 _ => unreachable!(),
1473 };
1474
1475 block.body.push(Instruction::ternary(
1476 op,
1477 result_type_id,
1478 id,
1479 arg0_id,
1480 arg1_id,
1481 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word,
1482 ));
1483 } else {
1484 let (extract_op, arg0_id, arg1_id) = match fun {
1486 Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
1487 Mf::Dot4I8Packed => {
1488 let new_arg0_id = self.gen_id();
1491 block.body.push(Instruction::unary(
1492 spirv::Op::Bitcast,
1493 result_type_id,
1494 new_arg0_id,
1495 arg0_id,
1496 ));
1497
1498 let new_arg1_id = self.gen_id();
1499 block.body.push(Instruction::unary(
1500 spirv::Op::Bitcast,
1501 result_type_id,
1502 new_arg1_id,
1503 arg1_id,
1504 ));
1505
1506 (spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
1507 }
1508 _ => unreachable!(),
1509 };
1510
1511 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1512
1513 const VEC_LENGTH: u8 = 4;
1514 let bit_shifts: [_; VEC_LENGTH as usize] =
1515 core::array::from_fn(|index| {
1516 self.writer
1517 .get_constant_scalar(crate::Literal::U32(index as u32 * 8))
1518 });
1519
1520 self.write_dot_product(
1521 id,
1522 result_type_id,
1523 arg0_id,
1524 arg1_id,
1525 VEC_LENGTH as Word,
1526 block,
1527 |result_id, composite_id, index| {
1528 Instruction::ternary(
1529 extract_op,
1530 result_type_id,
1531 result_id,
1532 composite_id,
1533 bit_shifts[index as usize],
1534 eight,
1535 )
1536 },
1537 );
1538 }
1539
1540 self.cached[expr_handle] = id;
1541 return Ok(());
1542 }
1543 Mf::Outer => MathOp::Custom(Instruction::binary(
1544 spirv::Op::OuterProduct,
1545 result_type_id,
1546 id,
1547 arg0_id,
1548 arg1_id,
1549 )),
1550 Mf::Cross => MathOp::Ext(spirv::GlslStd450Op::Cross),
1551 Mf::Distance => MathOp::Ext(spirv::GlslStd450Op::Distance),
1552 Mf::Length => MathOp::Ext(spirv::GlslStd450Op::Length),
1553 Mf::Normalize => MathOp::Ext(spirv::GlslStd450Op::Normalize),
1554 Mf::FaceForward => MathOp::Ext(spirv::GlslStd450Op::FaceForward),
1555 Mf::Reflect => MathOp::Ext(spirv::GlslStd450Op::Reflect),
1556 Mf::Refract => MathOp::Ext(spirv::GlslStd450Op::Refract),
1557 Mf::Exp => MathOp::Ext(spirv::GlslStd450Op::Exp),
1559 Mf::Exp2 => MathOp::Ext(spirv::GlslStd450Op::Exp2),
1560 Mf::Log => MathOp::Ext(spirv::GlslStd450Op::Log),
1561 Mf::Log2 => MathOp::Ext(spirv::GlslStd450Op::Log2),
1562 Mf::Pow => MathOp::Ext(spirv::GlslStd450Op::Pow),
1563 Mf::Sign => MathOp::Ext(match arg_scalar_kind {
1565 Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FSign,
1566 Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SSign,
1567 other => unimplemented!("Unexpected sign({:?})", other),
1568 }),
1569 Mf::Fma => MathOp::Ext(spirv::GlslStd450Op::Fma),
1570 Mf::Mix => {
1571 let selector = arg2.unwrap();
1572 let selector_ty =
1573 self.fun_info[selector].ty.inner_with(&self.ir_module.types);
1574 match (arg_ty, selector_ty) {
1575 (
1577 &crate::TypeInner::Vector { size, .. },
1578 &crate::TypeInner::Scalar(scalar),
1579 ) => {
1580 let selector_type_id =
1581 self.get_numeric_type_id(NumericType::Vector { size, scalar });
1582 self.temp_list.clear();
1583 self.temp_list.resize(size as usize, arg2_id);
1584
1585 let selector_id = self.gen_id();
1586 block.body.push(Instruction::composite_construct(
1587 selector_type_id,
1588 selector_id,
1589 &self.temp_list,
1590 ));
1591
1592 MathOp::Custom(Instruction::ext_inst_gl_op(
1593 self.writer.gl450_ext_inst_id,
1594 spirv::GlslStd450Op::FMix,
1595 result_type_id,
1596 id,
1597 &[arg0_id, arg1_id, selector_id],
1598 ))
1599 }
1600 _ => MathOp::Ext(spirv::GlslStd450Op::FMix),
1601 }
1602 }
1603 Mf::Step => MathOp::Ext(spirv::GlslStd450Op::Step),
1604 Mf::SmoothStep => MathOp::Ext(spirv::GlslStd450Op::SmoothStep),
1605 Mf::Sqrt => MathOp::Ext(spirv::GlslStd450Op::Sqrt),
1606 Mf::InverseSqrt => MathOp::Ext(spirv::GlslStd450Op::InverseSqrt),
1607 Mf::Inverse => MathOp::Ext(spirv::GlslStd450Op::MatrixInverse),
1608 Mf::Transpose => MathOp::Custom(Instruction::unary(
1609 spirv::Op::Transpose,
1610 result_type_id,
1611 id,
1612 arg0_id,
1613 )),
1614 Mf::Determinant => MathOp::Ext(spirv::GlslStd450Op::Determinant),
1615 Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
1616 spirv::Op::QuantizeToF16,
1617 result_type_id,
1618 id,
1619 arg0_id,
1620 )),
1621 Mf::ReverseBits => MathOp::Custom(Instruction::unary(
1622 spirv::Op::BitReverse,
1623 result_type_id,
1624 id,
1625 arg0_id,
1626 )),
1627 Mf::CountTrailingZeros => {
1628 let uint_id = match *arg_ty {
1629 crate::TypeInner::Vector { size, scalar } => {
1630 let ty =
1631 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1632
1633 self.temp_list.clear();
1634 self.temp_list.resize(
1635 size as _,
1636 self.writer
1637 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1638 );
1639
1640 self.writer.get_constant_composite(ty, &self.temp_list)
1641 }
1642 crate::TypeInner::Scalar(scalar) => self
1643 .writer
1644 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1645 _ => unreachable!(),
1646 };
1647
1648 let lsb_id = self.gen_id();
1649 block.body.push(Instruction::ext_inst_gl_op(
1650 self.writer.gl450_ext_inst_id,
1651 spirv::GlslStd450Op::FindILsb,
1652 result_type_id,
1653 lsb_id,
1654 &[arg0_id],
1655 ));
1656
1657 MathOp::Custom(Instruction::ext_inst_gl_op(
1658 self.writer.gl450_ext_inst_id,
1659 spirv::GlslStd450Op::UMin,
1660 result_type_id,
1661 id,
1662 &[uint_id, lsb_id],
1663 ))
1664 }
1665 Mf::CountLeadingZeros => {
1666 let (int_type_id, int_id, width) = match *arg_ty {
1667 crate::TypeInner::Vector { size, scalar } => {
1668 let ty =
1669 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1670
1671 self.temp_list.clear();
1672 self.temp_list.resize(
1673 size as _,
1674 self.writer
1675 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1676 );
1677
1678 (
1679 self.get_type_id(ty),
1680 self.writer.get_constant_composite(ty, &self.temp_list),
1681 scalar.width,
1682 )
1683 }
1684 crate::TypeInner::Scalar(scalar) => (
1685 self.get_numeric_type_id(NumericType::Scalar(scalar)),
1686 self.writer
1687 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1688 scalar.width,
1689 ),
1690 _ => unreachable!(),
1691 };
1692
1693 if width != 4 {
1694 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1695 };
1696
1697 let msb_id = self.gen_id();
1698 block.body.push(Instruction::ext_inst_gl_op(
1699 self.writer.gl450_ext_inst_id,
1700 if width != 4 {
1701 spirv::GlslStd450Op::FindILsb
1702 } else {
1703 spirv::GlslStd450Op::FindUMsb
1704 },
1705 int_type_id,
1706 msb_id,
1707 &[arg0_id],
1708 ));
1709
1710 MathOp::Custom(Instruction::binary(
1711 spirv::Op::ISub,
1712 result_type_id,
1713 id,
1714 int_id,
1715 msb_id,
1716 ))
1717 }
1718 Mf::CountOneBits => MathOp::Custom(Instruction::unary(
1719 spirv::Op::BitCount,
1720 result_type_id,
1721 id,
1722 arg0_id,
1723 )),
1724 Mf::ExtractBits => {
1725 let op = match arg_scalar_kind {
1726 Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
1727 Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
1728 other => unimplemented!("Unexpected sign({:?})", other),
1729 };
1730
1731 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1746 let width_constant = self
1747 .writer
1748 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1749
1750 let u32_type =
1751 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1752
1753 let offset_id = self.gen_id();
1755 block.body.push(Instruction::ext_inst_gl_op(
1756 self.writer.gl450_ext_inst_id,
1757 spirv::GlslStd450Op::UMin,
1758 u32_type,
1759 offset_id,
1760 &[arg1_id, width_constant],
1761 ));
1762
1763 let max_count_id = self.gen_id();
1765 block.body.push(Instruction::binary(
1766 spirv::Op::ISub,
1767 u32_type,
1768 max_count_id,
1769 width_constant,
1770 offset_id,
1771 ));
1772
1773 let count_id = self.gen_id();
1775 block.body.push(Instruction::ext_inst_gl_op(
1776 self.writer.gl450_ext_inst_id,
1777 spirv::GlslStd450Op::UMin,
1778 u32_type,
1779 count_id,
1780 &[arg2_id, max_count_id],
1781 ));
1782
1783 MathOp::Custom(Instruction::ternary(
1784 op,
1785 result_type_id,
1786 id,
1787 arg0_id,
1788 offset_id,
1789 count_id,
1790 ))
1791 }
1792 Mf::InsertBits => {
1793 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1796 let width_constant = self
1797 .writer
1798 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1799
1800 let u32_type =
1801 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1802
1803 let offset_id = self.gen_id();
1805 block.body.push(Instruction::ext_inst_gl_op(
1806 self.writer.gl450_ext_inst_id,
1807 spirv::GlslStd450Op::UMin,
1808 u32_type,
1809 offset_id,
1810 &[arg2_id, width_constant],
1811 ));
1812
1813 let max_count_id = self.gen_id();
1815 block.body.push(Instruction::binary(
1816 spirv::Op::ISub,
1817 u32_type,
1818 max_count_id,
1819 width_constant,
1820 offset_id,
1821 ));
1822
1823 let count_id = self.gen_id();
1825 block.body.push(Instruction::ext_inst_gl_op(
1826 self.writer.gl450_ext_inst_id,
1827 spirv::GlslStd450Op::UMin,
1828 u32_type,
1829 count_id,
1830 &[arg3_id, max_count_id],
1831 ));
1832
1833 MathOp::Custom(Instruction::quaternary(
1834 spirv::Op::BitFieldInsert,
1835 result_type_id,
1836 id,
1837 arg0_id,
1838 arg1_id,
1839 offset_id,
1840 count_id,
1841 ))
1842 }
1843 Mf::FirstTrailingBit => MathOp::Ext(spirv::GlslStd450Op::FindILsb),
1844 Mf::FirstLeadingBit => {
1845 if arg_ty.scalar_width() == Some(4) {
1846 let thing = match arg_scalar_kind {
1847 Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::FindUMsb,
1848 Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::FindSMsb,
1849 other => unimplemented!("Unexpected firstLeadingBit({:?})", other),
1850 };
1851 MathOp::Ext(thing)
1852 } else {
1853 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1854 }
1855 }
1856 Mf::Pack4x8unorm => MathOp::Ext(spirv::GlslStd450Op::PackUnorm4x8),
1857 Mf::Pack4x8snorm => MathOp::Ext(spirv::GlslStd450Op::PackSnorm4x8),
1858 Mf::Pack2x16float => MathOp::Ext(spirv::GlslStd450Op::PackHalf2x16),
1859 Mf::Pack2x16unorm => MathOp::Ext(spirv::GlslStd450Op::PackUnorm2x16),
1860 Mf::Pack2x16snorm => MathOp::Ext(spirv::GlslStd450Op::PackSnorm2x16),
1861 fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => {
1862 let is_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp);
1863 let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp);
1864
1865 let last_instruction =
1866 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1867 self.write_pack4x8_optimized(
1868 block,
1869 result_type_id,
1870 arg0_id,
1871 id,
1872 is_signed,
1873 should_clamp,
1874 )
1875 } else {
1876 self.write_pack4x8_polyfill(
1877 block,
1878 result_type_id,
1879 arg0_id,
1880 id,
1881 is_signed,
1882 should_clamp,
1883 )
1884 };
1885
1886 MathOp::Custom(last_instruction)
1887 }
1888 Mf::Unpack4x8unorm => MathOp::Ext(spirv::GlslStd450Op::UnpackUnorm4x8),
1889 Mf::Unpack4x8snorm => MathOp::Ext(spirv::GlslStd450Op::UnpackSnorm4x8),
1890 Mf::Unpack2x16float => MathOp::Ext(spirv::GlslStd450Op::UnpackHalf2x16),
1891 Mf::Unpack2x16unorm => MathOp::Ext(spirv::GlslStd450Op::UnpackUnorm2x16),
1892 Mf::Unpack2x16snorm => MathOp::Ext(spirv::GlslStd450Op::UnpackSnorm2x16),
1893 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
1894 let is_signed = matches!(fun, Mf::Unpack4xI8);
1895
1896 let last_instruction =
1897 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1898 self.write_unpack4x8_optimized(
1899 block,
1900 result_type_id,
1901 arg0_id,
1902 id,
1903 is_signed,
1904 )
1905 } else {
1906 self.write_unpack4x8_polyfill(
1907 block,
1908 result_type_id,
1909 arg0_id,
1910 id,
1911 is_signed,
1912 )
1913 };
1914
1915 MathOp::Custom(last_instruction)
1916 }
1917 };
1918
1919 block.body.push(match math_op {
1920 MathOp::Ext(op) => Instruction::ext_inst_gl_op(
1921 self.writer.gl450_ext_inst_id,
1922 op,
1923 result_type_id,
1924 id,
1925 &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
1926 ),
1927 MathOp::Custom(inst) => inst,
1928 });
1929 id
1930 }
1931 crate::Expression::LocalVariable(variable) => {
1932 if let Some(rq_tracker) = self
1933 .function
1934 .ray_query_initialization_tracker_variables
1935 .get(&variable)
1936 {
1937 self.ray_query_tracker_expr.insert(
1938 expr_handle,
1939 super::RayQueryTrackers {
1940 initialized_tracker: rq_tracker.id,
1941 t_max_tracker: self
1942 .function
1943 .ray_query_t_max_tracker_variables
1944 .get(&variable)
1945 .expect("Both trackers are set at the same time.")
1946 .id,
1947 },
1948 );
1949 }
1950 self.function.variables[&variable].id
1951 }
1952 crate::Expression::Load { pointer } => {
1953 self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)?
1954 }
1955 crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
1956 crate::Expression::CallResult(_)
1957 | crate::Expression::AtomicResult { .. }
1958 | crate::Expression::WorkGroupUniformLoadResult { .. }
1959 | crate::Expression::RayQueryProceedResult
1960 | crate::Expression::SubgroupBallotResult
1961 | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
1962 crate::Expression::As {
1963 expr,
1964 kind,
1965 convert,
1966 } => self.write_as_expression(expr, convert, kind, block, result_type_id)?,
1967 crate::Expression::ImageLoad {
1968 image,
1969 coordinate,
1970 array_index,
1971 sample,
1972 level,
1973 } => self.write_image_load(
1974 result_type_id,
1975 image,
1976 coordinate,
1977 array_index,
1978 level,
1979 sample,
1980 block,
1981 )?,
1982 crate::Expression::ImageSample {
1983 image,
1984 sampler,
1985 gather,
1986 coordinate,
1987 array_index,
1988 offset,
1989 level,
1990 depth_ref,
1991 clamp_to_edge,
1992 } => self.write_image_sample(
1993 result_type_id,
1994 image,
1995 sampler,
1996 gather,
1997 coordinate,
1998 array_index,
1999 offset,
2000 level,
2001 depth_ref,
2002 clamp_to_edge,
2003 block,
2004 )?,
2005 crate::Expression::Select {
2006 condition,
2007 accept,
2008 reject,
2009 } => {
2010 let id = self.gen_id();
2011 let mut condition_id = self.cached[condition];
2012 let accept_id = self.cached[accept];
2013 let reject_id = self.cached[reject];
2014
2015 let condition_ty = self.fun_info[condition]
2016 .ty
2017 .inner_with(&self.ir_module.types);
2018 let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
2019
2020 if let (
2021 &crate::TypeInner::Scalar(
2022 condition_scalar @ crate::Scalar {
2023 kind: crate::ScalarKind::Bool,
2024 ..
2025 },
2026 ),
2027 &crate::TypeInner::Vector { size, .. },
2028 ) = (condition_ty, object_ty)
2029 {
2030 self.temp_list.clear();
2031 self.temp_list.resize(size as usize, condition_id);
2032
2033 let bool_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2034 size,
2035 scalar: condition_scalar,
2036 });
2037
2038 let id = self.gen_id();
2039 block.body.push(Instruction::composite_construct(
2040 bool_vector_type_id,
2041 id,
2042 &self.temp_list,
2043 ));
2044 condition_id = id
2045 }
2046
2047 let instruction =
2048 Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
2049 block.body.push(instruction);
2050 id
2051 }
2052 crate::Expression::Derivative { axis, ctrl, expr } => {
2053 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
2054 match ctrl {
2055 Ctrl::Coarse | Ctrl::Fine => {
2056 self.writer.require_any(
2057 "DerivativeControl",
2058 &[spirv::Capability::DerivativeControl],
2059 )?;
2060 }
2061 Ctrl::None => {}
2062 }
2063 let id = self.gen_id();
2064 let expr_id = self.cached[expr];
2065 let op = match (axis, ctrl) {
2066 (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
2067 (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
2068 (Axis::X, Ctrl::None) => spirv::Op::DPdx,
2069 (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
2070 (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
2071 (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
2072 (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
2073 (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
2074 (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
2075 };
2076 block
2077 .body
2078 .push(Instruction::derivative(op, result_type_id, id, expr_id));
2079 id
2080 }
2081 crate::Expression::ImageQuery { image, query } => {
2082 self.write_image_query(result_type_id, image, query, block)?
2083 }
2084 crate::Expression::Relational { fun, argument } => {
2085 use crate::RelationalFunction as Rf;
2086 let arg_id = self.cached[argument];
2087 let op = match fun {
2088 Rf::All => spirv::Op::All,
2089 Rf::Any => spirv::Op::Any,
2090 Rf::IsNan => spirv::Op::IsNan,
2091 Rf::IsInf => spirv::Op::IsInf,
2092 };
2093 let id = self.gen_id();
2094 block
2095 .body
2096 .push(Instruction::relational(op, result_type_id, id, arg_id));
2097 id
2098 }
2099 crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
2100 crate::Expression::RayQueryGetIntersection { query, committed } => {
2101 let query_id = self.cached[query];
2102 let init_tracker_id = *self
2103 .ray_query_tracker_expr
2104 .get(&query)
2105 .expect("not a cached ray query");
2106 let func_id = self
2107 .writer
2108 .write_ray_query_get_intersection_function(committed, self.ir_module);
2109 let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap();
2110 let intersection_type_id = self.get_handle_type_id(ray_intersection);
2111 let id = self.gen_id();
2112 block.body.push(Instruction::function_call(
2113 intersection_type_id,
2114 id,
2115 func_id,
2116 &[query_id, init_tracker_id.initialized_tracker],
2117 ));
2118 id
2119 }
2120 crate::Expression::RayQueryVertexPositions { query, committed } => {
2121 self.writer.require_any(
2122 "RayQueryVertexPositions",
2123 &[spirv::Capability::RayQueryPositionFetchKHR],
2124 )?;
2125 self.write_ray_query_return_vertex_position(query, block, committed)
2126 }
2127 crate::Expression::CooperativeLoad { ref data, .. } => {
2128 self.writer.require_any(
2129 "CooperativeMatrix",
2130 &[spirv::Capability::CooperativeMatrixKHR],
2131 )?;
2132 let layout = if data.row_major {
2133 spirv::CooperativeMatrixLayout::RowMajorKHR
2134 } else {
2135 spirv::CooperativeMatrixLayout::ColumnMajorKHR
2136 };
2137 let layout_id = self.get_index_constant(layout as u32);
2138 let stride_id = self.cached[data.stride];
2139 match self.write_access_chain(data.pointer, block, AccessTypeAdjustment::None)? {
2140 ExpressionPointer::Ready { pointer_id } => {
2141 let id = self.gen_id();
2142 block.body.push(Instruction::coop_load(
2143 result_type_id,
2144 id,
2145 pointer_id,
2146 layout_id,
2147 stride_id,
2148 ));
2149 id
2150 }
2151 ExpressionPointer::Conditional { condition, access } => self
2152 .write_conditional_indexed_load(
2153 result_type_id,
2154 condition,
2155 block,
2156 |id_gen, block| {
2157 let pointer_id = access.result_id.unwrap();
2158 block.body.push(access);
2159 let id = id_gen.next();
2160 block.body.push(Instruction::coop_load(
2161 result_type_id,
2162 id,
2163 pointer_id,
2164 layout_id,
2165 stride_id,
2166 ));
2167 id
2168 },
2169 ),
2170 }
2171 }
2172 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2173 self.writer.require_any(
2174 "CooperativeMatrix",
2175 &[spirv::Capability::CooperativeMatrixKHR],
2176 )?;
2177 let a_id = self.cached[a];
2178 let b_id = self.cached[b];
2179 let c_id = self.cached[c];
2180 let id = self.gen_id();
2181 block.body.push(Instruction::coop_mul_add(
2182 result_type_id,
2183 id,
2184 a_id,
2185 b_id,
2186 c_id,
2187 ));
2188 id
2189 }
2190 };
2191
2192 self.cached[expr_handle] = id;
2193 Ok(())
2194 }
2195
2196 fn write_as_expression(
2199 &mut self,
2200 expr: Handle<crate::Expression>,
2201 convert: Option<u8>,
2202 kind: crate::ScalarKind,
2203
2204 block: &mut Block,
2205 result_type_id: u32,
2206 ) -> Result<u32, Error> {
2207 use crate::ScalarKind as Sk;
2208 let expr_id = self.cached[expr];
2209 let ty = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
2210
2211 if let crate::TypeInner::Matrix {
2216 columns,
2217 rows,
2218 scalar,
2219 } = *ty
2220 {
2221 let Some(convert) = convert else {
2222 return Ok(expr_id);
2224 };
2225
2226 if convert == scalar.width {
2227 return Ok(expr_id);
2229 }
2230
2231 if kind != Sk::Float {
2232 return Err(Error::Validation("Matrices must be floats"));
2234 }
2235
2236 let column_src_ty =
2238 self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
2239 size: rows,
2240 scalar,
2241 })));
2242
2243 let column_dst_ty =
2245 self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
2246 size: rows,
2247 scalar: crate::Scalar {
2248 kind,
2249 width: convert,
2250 },
2251 })));
2252
2253 let mut components = ArrayVec::<Word, 4>::new();
2254
2255 for column in 0..columns as usize {
2256 let column_id = self.gen_id();
2257 block.body.push(Instruction::composite_extract(
2258 column_src_ty,
2259 column_id,
2260 expr_id,
2261 &[column as u32],
2262 ));
2263
2264 let column_conv_id = self.gen_id();
2265 block.body.push(Instruction::unary(
2266 spirv::Op::FConvert,
2267 column_dst_ty,
2268 column_conv_id,
2269 column_id,
2270 ));
2271
2272 components.push(column_conv_id);
2273 }
2274
2275 let construct_id = self.gen_id();
2276
2277 block.body.push(Instruction::composite_construct(
2278 result_type_id,
2279 construct_id,
2280 &components,
2281 ));
2282
2283 return Ok(construct_id);
2284 }
2285
2286 let (src_scalar, src_size) = match *ty {
2287 crate::TypeInner::Scalar(scalar) => (scalar, None),
2288 crate::TypeInner::Vector { scalar, size } => (scalar, Some(size)),
2289 ref other => {
2290 log::error!("As source {other:?}");
2291 return Err(Error::Validation("Unexpected Expression::As source"));
2292 }
2293 };
2294
2295 enum Cast {
2296 Identity(Word),
2297 Unary(spirv::Op, Word),
2298 Binary(spirv::Op, Word, Word),
2299 Ternary(spirv::Op, Word, Word, Word),
2300 }
2301 let cast = match (src_scalar.kind, kind, convert) {
2302 (src_kind, kind, convert)
2305 if src_kind == kind
2306 && convert.filter(|&width| width != src_scalar.width).is_none() =>
2307 {
2308 Cast::Identity(expr_id)
2309 }
2310 (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject, expr_id),
2311 (_, _, None) => Cast::Unary(spirv::Op::Bitcast, expr_id),
2312 (_, Sk::Bool, Some(_)) => {
2314 let op = match src_scalar.kind {
2315 Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
2316 Sk::Float => spirv::Op::FUnordNotEqual,
2317 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(),
2318 };
2319 let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?;
2320 let zero_id = match src_size {
2321 Some(size) => {
2322 let ty = LocalType::Numeric(NumericType::Vector {
2323 size,
2324 scalar: src_scalar,
2325 })
2326 .into();
2327
2328 self.temp_list.clear();
2329 self.temp_list.resize(size as _, zero_scalar_id);
2330
2331 self.writer.get_constant_composite(ty, &self.temp_list)
2332 }
2333 None => zero_scalar_id,
2334 };
2335
2336 Cast::Binary(op, expr_id, zero_id)
2337 }
2338 (Sk::Bool, _, Some(dst_width)) => {
2340 let dst_scalar = crate::Scalar {
2341 kind,
2342 width: dst_width,
2343 };
2344 let zero_scalar_id = self.writer.get_constant_scalar_with(0, dst_scalar)?;
2345 let one_scalar_id = self.writer.get_constant_scalar_with(1, dst_scalar)?;
2346 let (accept_id, reject_id) = match src_size {
2347 Some(size) => {
2348 let ty = LocalType::Numeric(NumericType::Vector {
2349 size,
2350 scalar: dst_scalar,
2351 })
2352 .into();
2353
2354 self.temp_list.clear();
2355 self.temp_list.resize(size as _, zero_scalar_id);
2356
2357 let vec0_id = self.writer.get_constant_composite(ty, &self.temp_list);
2358
2359 self.temp_list.fill(one_scalar_id);
2360
2361 let vec1_id = self.writer.get_constant_composite(ty, &self.temp_list);
2362
2363 (vec1_id, vec0_id)
2364 }
2365 None => (one_scalar_id, zero_scalar_id),
2366 };
2367
2368 Cast::Ternary(spirv::Op::Select, expr_id, accept_id, reject_id)
2369 }
2370 (Sk::Float, Sk::Sint | Sk::Uint, Some(width)) => {
2381 let dst_scalar = crate::Scalar { kind, width };
2382 let (min, max) =
2383 crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
2384 let expr_type_id = self.get_expression_type_id(&self.fun_info[expr].ty);
2385
2386 let maybe_splat_const = |writer: &mut Writer, const_id| match src_size {
2387 None => const_id,
2388 Some(size) => {
2389 let constituent_ids = [const_id; crate::VectorSize::MAX];
2390 writer.get_constant_composite(
2391 LookupType::Local(LocalType::Numeric(NumericType::Vector {
2392 size,
2393 scalar: src_scalar,
2394 })),
2395 &constituent_ids[..size as usize],
2396 )
2397 }
2398 };
2399 let min_const_id = self.writer.get_constant_scalar(min);
2400 let min_const_id = maybe_splat_const(self.writer, min_const_id);
2401 let max_const_id = self.writer.get_constant_scalar(max);
2402 let max_const_id = maybe_splat_const(self.writer, max_const_id);
2403
2404 let clamp_id = self.gen_id();
2405 block.body.push(Instruction::ext_inst_gl_op(
2406 self.writer.gl450_ext_inst_id,
2407 spirv::GlslStd450Op::FClamp,
2408 expr_type_id,
2409 clamp_id,
2410 &[expr_id, min_const_id, max_const_id],
2411 ));
2412
2413 let op = match dst_scalar.kind {
2414 crate::ScalarKind::Sint => spirv::Op::ConvertFToS,
2415 crate::ScalarKind::Uint => spirv::Op::ConvertFToU,
2416 _ => unreachable!(),
2417 };
2418 Cast::Unary(op, clamp_id)
2419 }
2420 (Sk::Float, Sk::Float, Some(dst_width)) if src_scalar.width != dst_width => {
2421 Cast::Unary(spirv::Op::FConvert, expr_id)
2422 }
2423 (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF, expr_id),
2424 (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2425 Cast::Unary(spirv::Op::SConvert, expr_id)
2426 }
2427 (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF, expr_id),
2428 (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2429 Cast::Unary(spirv::Op::UConvert, expr_id)
2430 }
2431 (Sk::Uint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2432 Cast::Unary(spirv::Op::SConvert, expr_id)
2433 }
2434 (Sk::Sint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2435 Cast::Unary(spirv::Op::UConvert, expr_id)
2436 }
2437 _ => Cast::Unary(spirv::Op::Bitcast, expr_id),
2439 };
2440 Ok(match cast {
2441 Cast::Identity(expr) => expr,
2442 Cast::Unary(op, op1) => {
2443 let id = self.gen_id();
2444 block
2445 .body
2446 .push(Instruction::unary(op, result_type_id, id, op1));
2447 id
2448 }
2449 Cast::Binary(op, op1, op2) => {
2450 let id = self.gen_id();
2451 block
2452 .body
2453 .push(Instruction::binary(op, result_type_id, id, op1, op2));
2454 id
2455 }
2456 Cast::Ternary(op, op1, op2, op3) => {
2457 let id = self.gen_id();
2458 block
2459 .body
2460 .push(Instruction::ternary(op, result_type_id, id, op1, op2, op3));
2461 id
2462 }
2463 })
2464 }
2465
    /// Build an access chain for the pointer expression `expr_handle`.
    ///
    /// Starting at `expr_handle`, this follows `Access` and `AccessIndex`
    /// expressions through their `base` operands until it reaches a root:
    /// - a composite previously spilled to a temporary variable,
    /// - a global variable,
    /// - a local variable, or
    /// - a function argument.
    ///
    /// One index id per step is collected into `self.temp_list`. Bounds checking
    /// of non-struct indices goes through `write_access_chain_index`.
    ///
    /// `type_adjustment` selects the result type id of the `OpAccessChain`:
    /// - `None`: the type id of the expression's own resolution.
    /// - `IntroducePointer(class)`: a pointer in storage class `class` to the
    ///   expression's resolved type.
    /// - `UseStd140CompatType`: a pointer to the std140-compatible variant of the
    ///   pointee. This is only valid for `Uniform` pointers; anything else hits
    ///   `unreachable!`.
    ///
    /// Returns:
    /// - `ExpressionPointer::Ready` in two cases:
    ///   - no indices were collected, and the root id itself is returned;
    ///   - no runtime bounds check was needed, and the `OpAccessChain` has
    ///     already been pushed to `block`.
    /// - `ExpressionPointer::Conditional` when runtime bounds checks were
    ///   generated. The `OpAccessChain` is NOT pushed; the caller must emit it
    ///   guarded by `condition`.
    ///
    /// If any `Access` step indexes a binding array with a non-uniform index,
    /// the resulting pointer id is decorated through
    /// `decorate_non_uniform_binding_array_access`.
    ///
    /// # Errors
    /// Propagates errors from `write_access_chain_index` and
    /// `decorate_non_uniform_binding_array_access`.
    ///
    /// # Panics
    /// Panics (`unimplemented!`) on any other kind of pointer expression.
    fn write_access_chain(
        &mut self,
        mut expr_handle: Handle<crate::Expression>,
        block: &mut Block,
        type_adjustment: AccessTypeAdjustment,
    ) -> Result<ExpressionPointer, Error> {
        let result_type_id = {
            let resolution = &self.fun_info[expr_handle].ty;
            match type_adjustment {
                AccessTypeAdjustment::None => self.writer.get_expression_type_id(resolution),
                AccessTypeAdjustment::IntroducePointer(class) => {
                    self.writer.get_resolution_pointer_id(resolution, class)
                }
                AccessTypeAdjustment::UseStd140CompatType => {
                    match *resolution.inner_with(&self.ir_module.types) {
                        crate::TypeInner::Pointer {
                            base,
                            space: space @ crate::AddressSpace::Uniform,
                        } => self.writer.get_pointer_type_id(
                            self.writer.std140_compat_uniform_types[&base].type_id,
                            map_storage_class(space),
                        ),
                        _ => unreachable!(
                            "`UseStd140CompatType` must only be used with uniform pointer types"
                        ),
                    }
                }
            }
        };

        // Id of the boolean AND of every runtime bounds check emitted so far.
        // It stays `None` while no runtime check has been needed.
        let mut accumulated_checks = None;

        // Becomes true if any `Access` step indexes a global binding array with
        // a non-uniform index. In that case the final pointer is decorated.
        let mut is_non_uniform_binding_array = false;

        // Column index recorded by an `AccessIndex` on a uniform two-row matrix
        // struct member. It is consumed (`take()`) by the struct `AccessIndex`
        // visited after it, where it is added to the std140 member index.
        let mut prev_decomposed_matrix_index = None;

        self.temp_list.clear();
        let root_id = loop {
            // A spilled composite's temporary variable serves directly as the root.
            if let Some(spilled) = self.function.spilled_composites.get(&expr_handle) {
                break spilled.id;
            }

            expr_handle = match self.ir_function.expressions[expr_handle] {
                crate::Expression::Access { base, index } => {
                    is_non_uniform_binding_array |=
                        self.is_nonuniform_binding_array_access(base, index);

                    let index = GuardedIndex::Expression(index);
                    let index_id =
                        self.write_access_chain_index(base, index, &mut accumulated_checks, block)?;
                    self.temp_list.push(index_id);

                    base
                }
                crate::Expression::AccessIndex { base, index } => {
                    // Look through a pointer to its pointee type. Remember the
                    // address space so the std140 remapping below can check it.
                    let mut base_ty = self.fun_info[base].ty.inner_with(&self.ir_module.types);
                    let mut base_ty_handle = self.fun_info[base].ty.handle();
                    let mut pointer_space = None;
                    if let crate::TypeInner::Pointer { base, space } = *base_ty {
                        base_ty = &self.ir_module.types[base].inner;
                        base_ty_handle = Some(base);
                        pointer_space = Some(space);
                    }
                    match *base_ty {
                        // Struct member indices are emitted as plain constants with
                        // no bounds check. A struct that has a std140-compatible
                        // variant and is reached through a uniform pointer gets its
                        // member index remapped, plus any pending matrix column index.
                        crate::TypeInner::Struct { .. } => {
                            let index = match base_ty_handle.and_then(|handle| {
                                self.writer.std140_compat_uniform_types.get(&handle)
                            }) {
                                Some(std140_type_info)
                                    if pointer_space == Some(crate::AddressSpace::Uniform) =>
                                {
                                    std140_type_info.member_indices[index as usize]
                                        + prev_decomposed_matrix_index.take().unwrap_or(0)
                                }
                                _ => index,
                            };
                            let index_id = self.get_index_constant(index);
                            self.temp_list.push(index_id);
                        }
                        // Indexing a uniform two-row matrix struct member: nothing is
                        // pushed here. The index is deferred to the struct step above.
                        _ if is_uniform_matcx2_struct_member_access(
                            self.ir_function,
                            self.fun_info,
                            self.ir_module,
                            base,
                        ) =>
                        {
                            assert!(prev_decomposed_matrix_index.is_none());
                            prev_decomposed_matrix_index = Some(index);
                        }
                        // Any other indexable type: the index is a known constant but
                        // still goes through the bounds-check machinery.
                        _ => {
                            let index_id = self.write_access_chain_index(
                                base,
                                GuardedIndex::Known(index),
                                &mut accumulated_checks,
                                block,
                            )?;
                            self.temp_list.push(index_id);
                        }
                    }
                    base
                }
                crate::Expression::GlobalVariable(handle) => {
                    let gv = &self.writer.global_variables[handle];
                    break gv.access_id;
                }
                crate::Expression::LocalVariable(variable) => {
                    let local_var = &self.function.variables[&variable];
                    break local_var.id;
                }
                crate::Expression::FunctionArgument(index) => {
                    break self.function.parameter_id(index);
                }
                ref other => unimplemented!("Unexpected pointer expression {:?}", other),
            }
        };

        let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
            (
                root_id,
                ExpressionPointer::Ready {
                    pointer_id: root_id,
                },
            )
        } else {
            // Indices were collected walking from the outermost access toward the
            // root. `OpAccessChain` lists them starting from the root.
            self.temp_list.reverse();
            let pointer_id = self.gen_id();
            let access =
                Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);

            // When runtime bounds checks are pending, hand the unemitted access back
            // to the caller, which must guard it with `condition`.
            let expr_pointer = match accumulated_checks {
                Some(condition) => ExpressionPointer::Conditional { condition, access },
                None => {
                    block.body.push(access);
                    ExpressionPointer::Ready { pointer_id }
                }
            };
            (pointer_id, expr_pointer)
        };
        if is_non_uniform_binding_array {
            self.writer
                .decorate_non_uniform_binding_array_access(pointer_id)?;
        }

        Ok(expr_pointer)
    }
2661
2662 fn is_nonuniform_binding_array_access(
2663 &mut self,
2664 base: Handle<crate::Expression>,
2665 index: Handle<crate::Expression>,
2666 ) -> bool {
2667 let crate::Expression::GlobalVariable(var_handle) = self.ir_function.expressions[base]
2668 else {
2669 return false;
2670 };
2671
2672 let gvar = &self.ir_module.global_variables[var_handle];
2675 let crate::TypeInner::BindingArray { .. } = self.ir_module.types[gvar.ty].inner else {
2676 return false;
2677 };
2678
2679 self.fun_info[index].uniformity.non_uniform_result.is_some()
2680 }
2681
2682 fn write_access_chain_index(
2692 &mut self,
2693 base: Handle<crate::Expression>,
2694 index: GuardedIndex,
2695 accumulated_checks: &mut Option<Word>,
2696 block: &mut Block,
2697 ) -> Result<Word, Error> {
2698 match self.write_bounds_check(base, index, block)? {
2699 BoundsCheckResult::KnownInBounds(known_index) => {
2700 let scalar = crate::Literal::U32(known_index);
2703 Ok(self.writer.get_constant_scalar(scalar))
2704 }
2705 BoundsCheckResult::Computed(computed_index_id) => Ok(computed_index_id),
2706 BoundsCheckResult::Conditional {
2707 condition_id: condition,
2708 index_id: index,
2709 } => {
2710 self.extend_bounds_check_condition_chain(accumulated_checks, condition, block);
2711
2712 Ok(index)
2714 }
2715 }
2716 }
2717
2718 fn extend_bounds_check_condition_chain(
2737 &mut self,
2738 chain: &mut Option<Word>,
2739 comparison_id: Word,
2740 block: &mut Block,
2741 ) {
2742 match *chain {
2743 Some(ref mut prior_checks) => {
2744 let combined = self.gen_id();
2745 block.body.push(Instruction::binary(
2746 spirv::Op::LogicalAnd,
2747 self.writer.get_bool_type_id(),
2748 combined,
2749 *prior_checks,
2750 comparison_id,
2751 ));
2752 *prior_checks = combined;
2753 }
2754 None => {
2755 *chain = Some(comparison_id);
2757 }
2758 }
2759 }
2760
    /// Load the value that `pointer` points to, emitting any required access chain.
    ///
    /// Two special cases for two-row matrices in uniform buffers are tried first,
    /// in this order:
    /// - `maybe_write_uniform_matcx2_dynamic_access`;
    /// - `maybe_write_load_uniform_matcx2_struct_member`.
    ///
    /// If either produces a value id, that id is returned directly.
    ///
    /// Otherwise, `pointer` may be a `Uniform` pointer whose pointee type has an
    /// entry in `std140_compat_uniform_types`. In that case:
    /// - the access chain and the load use the std140-compatible type, and the
    ///   caller's `access_type_adjustment` is replaced by `UseStd140CompatType`;
    /// - the loaded value is converted back by calling the
    ///   `ConvertFromStd140CompatType` wrapper function, with result type
    ///   `result_type_id`.
    ///
    /// For a ready (unconditional) pointer to an `Atomic` type, the load is emitted
    /// as `OpAtomicLoad` with scope and semantics derived from the address space.
    /// A conditional access chain (pending bounds checks) is handed to
    /// `write_conditional_indexed_load`.
    ///
    /// Returns the id of the loaded value, converted back to the IR type when
    /// the std140 path was taken.
    fn write_checked_load(
        &mut self,
        pointer: Handle<crate::Expression>,
        block: &mut Block,
        access_type_adjustment: AccessTypeAdjustment,
        result_type_id: Word,
    ) -> Result<Word, Error> {
        if let Some(result_id) = self.maybe_write_uniform_matcx2_dynamic_access(pointer, block)? {
            Ok(result_id)
        } else if let Some(result_id) =
            self.maybe_write_load_uniform_matcx2_struct_member(pointer, block)?
        {
            Ok(result_id)
        } else {
            // Describes a load that must go through a std140-compatible type and
            // then be converted back to the original type.
            struct WrappedLoad {
                // Always `UseStd140CompatType` where constructed below.
                access_type_adjustment: AccessTypeAdjustment,
                // The original (non-std140) pointee type. It is the key into both
                // `std140_compat_uniform_types` and the conversion wrapper table.
                r#type: Handle<crate::Type>,
            }
            let mut wrapped_load = None;
            if let crate::TypeInner::Pointer {
                base: pointer_base_type,
                space: crate::AddressSpace::Uniform,
            } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
            {
                if self
                    .writer
                    .std140_compat_uniform_types
                    .contains_key(&pointer_base_type)
                {
                    wrapped_load = Some(WrappedLoad {
                        access_type_adjustment: AccessTypeAdjustment::UseStd140CompatType,
                        r#type: pointer_base_type,
                    });
                };
            };

            // With a wrapped load, both the loaded type and the access-chain type
            // switch to the std140-compatible variant.
            let (load_type_id, access_type_adjustment) = match wrapped_load {
                Some(ref wrapped_load) => (
                    self.writer.std140_compat_uniform_types[&wrapped_load.r#type].type_id,
                    wrapped_load.access_type_adjustment,
                ),
                None => (result_type_id, access_type_adjustment),
            };

            let load_id = match self.write_access_chain(pointer, block, access_type_adjustment)? {
                ExpressionPointer::Ready { pointer_id } => {
                    let id = self.gen_id();
                    // Pointees of `Atomic` type are read with `OpAtomicLoad`.
                    let atomic_space =
                        match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
                            crate::TypeInner::Pointer { base, space } => {
                                match self.ir_module.types[base].inner {
                                    crate::TypeInner::Atomic { .. } => Some(space),
                                    _ => None,
                                }
                            }
                            _ => None,
                        };
                    let instruction = if let Some(space) = atomic_space {
                        let (semantics, scope) = space.to_spirv_semantics_and_scope();
                        let scope_constant_id = self.get_scope_constant(scope as u32);
                        let semantics_id = self.get_index_constant(semantics.bits());
                        Instruction::atomic_load(
                            result_type_id,
                            id,
                            pointer_id,
                            scope_constant_id,
                            semantics_id,
                        )
                    } else {
                        Instruction::load(load_type_id, id, pointer_id, None)
                    };
                    block.body.push(instruction);
                    id
                }
                ExpressionPointer::Conditional { condition, access } => {
                    // NOTE(review): this path always emits a plain `OpLoad`, even
                    // when the pointee is atomic. Confirm that this is intended.
                    self.write_conditional_indexed_load(
                        load_type_id,
                        condition,
                        block,
                        move |id_gen, block| {
                            // Emits the deferred access chain followed by the load.
                            // Presumably this runs in the branch where `condition`
                            // holds (the in-bounds path).
                            let pointer_id = access.result_id.unwrap();
                            let value_id = id_gen.next();
                            block.body.push(access);
                            block.body.push(Instruction::load(
                                load_type_id,
                                value_id,
                                pointer_id,
                                None,
                            ));
                            value_id
                        },
                    )
                }
            };

            match wrapped_load {
                Some(ref wrapped_load) => {
                    // Convert the std140-compatible value back to the IR type by
                    // calling the generated wrapper function.
                    let result_id = self.gen_id();
                    let function_id = self.writer.wrapped_functions
                        [&WrappedFunction::ConvertFromStd140CompatType {
                            r#type: wrapped_load.r#type,
                        }];
                    block.body.push(Instruction::function_call(
                        result_type_id,
                        result_id,
                        function_id,
                        &[load_id],
                    ));
                    Ok(result_id)
                }
                None => Ok(load_id),
            }
        }
    }
2884
2885 fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) {
2886 use indexmap::map::Entry;
2887
2888 let spill_variable_id = match self.function.spilled_composites.entry(base) {
2890 Entry::Occupied(preexisting) => preexisting.get().id,
2891 Entry::Vacant(vacant) => {
2892 let pointer_type_id = self.writer.get_resolution_pointer_id(
2895 &self.fun_info[base].ty,
2896 spirv::StorageClass::Function,
2897 );
2898 let id = self.writer.id_gen.next();
2899 vacant.insert(super::LocalVariable {
2900 id,
2901 instruction: Instruction::variable(
2902 pointer_type_id,
2903 id,
2904 spirv::StorageClass::Function,
2905 None,
2906 ),
2907 });
2908 id
2909 }
2910 };
2911
2912 let base_id = self.cached[base];
2937 block
2938 .body
2939 .push(Instruction::store(spill_variable_id, base_id, None));
2940 }
2941
2942 fn maybe_access_spilled_composite(
2959 &mut self,
2960 access: Handle<crate::Expression>,
2961 block: &mut Block,
2962 result_type_id: Word,
2963 ) -> Result<Word, Error> {
2964 let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r);
2965 if access_uses == self.fun_info[access].ref_count {
2966 Ok(0)
2970 } else {
2971 self.write_checked_load(
2976 access,
2977 block,
2978 AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::Function),
2979 result_type_id,
2980 )
2981 }
2982 }
2983
2984 #[allow(clippy::too_many_arguments)]
2986 fn write_matrix_matrix_column_op(
2987 &mut self,
2988 block: &mut Block,
2989 result_id: Word,
2990 result_type_id: Word,
2991 left_id: Word,
2992 right_id: Word,
2993 columns: crate::VectorSize,
2994 rows: crate::VectorSize,
2995 width: u8,
2996 op: spirv::Op,
2997 ) {
2998 self.temp_list.clear();
2999
3000 let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3001 size: rows,
3002 scalar: crate::Scalar::float(width),
3003 });
3004
3005 for index in 0..columns as u32 {
3006 let column_id_left = self.gen_id();
3007 let column_id_right = self.gen_id();
3008 let column_id_res = self.gen_id();
3009
3010 block.body.push(Instruction::composite_extract(
3011 vector_type_id,
3012 column_id_left,
3013 left_id,
3014 &[index],
3015 ));
3016 block.body.push(Instruction::composite_extract(
3017 vector_type_id,
3018 column_id_right,
3019 right_id,
3020 &[index],
3021 ));
3022 block.body.push(Instruction::binary(
3023 op,
3024 vector_type_id,
3025 column_id_res,
3026 column_id_left,
3027 column_id_right,
3028 ));
3029
3030 self.temp_list.push(column_id_res);
3031 }
3032
3033 block.body.push(Instruction::composite_construct(
3034 result_type_id,
3035 result_id,
3036 &self.temp_list,
3037 ));
3038 }
3039
3040 fn write_vector_scalar_mult(
3042 &mut self,
3043 block: &mut Block,
3044 result_id: Word,
3045 result_type_id: Word,
3046 vector_id: Word,
3047 scalar_id: Word,
3048 vector: &crate::TypeInner,
3049 ) {
3050 let (size, kind) = match *vector {
3051 crate::TypeInner::Vector {
3052 size,
3053 scalar: crate::Scalar { kind, .. },
3054 } => (size, kind),
3055 _ => unreachable!(),
3056 };
3057
3058 let (op, operand_id) = match kind {
3059 crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
3060 _ => {
3061 let operand_id = self.gen_id();
3062 self.temp_list.clear();
3063 self.temp_list.resize(size as usize, scalar_id);
3064 block.body.push(Instruction::composite_construct(
3065 result_type_id,
3066 operand_id,
3067 &self.temp_list,
3068 ));
3069 (spirv::Op::IMul, operand_id)
3070 }
3071 };
3072
3073 block.body.push(Instruction::binary(
3074 op,
3075 result_type_id,
3076 result_id,
3077 vector_id,
3078 operand_id,
3079 ));
3080 }
3081
3082 #[expect(clippy::too_many_arguments)]
3089 fn write_dot_product(
3090 &mut self,
3091 result_id: Word,
3092 result_type_id: Word,
3093 arg0_id: Word,
3094 arg1_id: Word,
3095 size: u32,
3096 block: &mut Block,
3097 extractor: impl Fn(Word, Word, Word) -> Instruction,
3098 ) {
3099 let mut partial_sum = self.writer.get_constant_null(result_type_id);
3100 let last_component = size - 1;
3101 for index in 0..=last_component {
3102 let a_id = self.gen_id();
3104 block.body.push(extractor(a_id, arg0_id, index));
3105 let b_id = self.gen_id();
3106 block.body.push(extractor(b_id, arg1_id, index));
3107 let prod_id = self.gen_id();
3108 block.body.push(Instruction::binary(
3109 spirv::Op::IMul,
3110 result_type_id,
3111 prod_id,
3112 a_id,
3113 b_id,
3114 ));
3115
3116 let id = if index == last_component {
3118 result_id
3119 } else {
3120 self.gen_id()
3121 };
3122
3123 block.body.push(Instruction::binary(
3125 spirv::Op::IAdd,
3126 result_type_id,
3127 id,
3128 partial_sum,
3129 prod_id,
3130 ));
3131 partial_sum = id;
3133 }
3134 }
3135
3136 fn write_pack4x8_optimized(
3138 &mut self,
3139 block: &mut Block,
3140 result_type_id: u32,
3141 arg0_id: u32,
3142 id: u32,
3143 is_signed: bool,
3144 should_clamp: bool,
3145 ) -> Instruction {
3146 let int_type = if is_signed {
3147 crate::ScalarKind::Sint
3148 } else {
3149 crate::ScalarKind::Uint
3150 };
3151 let wide_vector_type = NumericType::Vector {
3152 size: crate::VectorSize::Quad,
3153 scalar: crate::Scalar {
3154 kind: int_type,
3155 width: 4,
3156 },
3157 };
3158 let wide_vector_type_id = self.get_numeric_type_id(wide_vector_type);
3159 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3160 size: crate::VectorSize::Quad,
3161 scalar: crate::Scalar {
3162 kind: crate::ScalarKind::Uint,
3163 width: 1,
3164 },
3165 });
3166
3167 let mut wide_vector = arg0_id;
3168 if should_clamp {
3169 let (min, max, clamp_op) = if is_signed {
3170 (
3171 crate::Literal::I32(-128),
3172 crate::Literal::I32(127),
3173 spirv::GlslStd450Op::SClamp,
3174 )
3175 } else {
3176 (
3177 crate::Literal::U32(0),
3178 crate::Literal::U32(255),
3179 spirv::GlslStd450Op::UClamp,
3180 )
3181 };
3182 let [min, max] = [min, max].map(|lit| {
3183 let scalar = self.writer.get_constant_scalar(lit);
3184 self.writer.get_constant_composite(
3185 LookupType::Local(LocalType::Numeric(wide_vector_type)),
3186 &[scalar; 4],
3187 )
3188 });
3189
3190 let clamp_id = self.gen_id();
3191 block.body.push(Instruction::ext_inst_gl_op(
3192 self.writer.gl450_ext_inst_id,
3193 clamp_op,
3194 wide_vector_type_id,
3195 clamp_id,
3196 &[wide_vector, min, max],
3197 ));
3198
3199 wide_vector = clamp_id;
3200 }
3201
3202 let packed_vector = self.gen_id();
3203 block.body.push(Instruction::unary(
3204 spirv::Op::UConvert, packed_vector_type_id,
3206 packed_vector,
3207 wide_vector,
3208 ));
3209
3210 Instruction::unary(spirv::Op::Bitcast, result_type_id, id, packed_vector)
3215 }
3216
3217 fn write_pack4x8_polyfill(
3219 &mut self,
3220 block: &mut Block,
3221 result_type_id: u32,
3222 arg0_id: u32,
3223 id: u32,
3224 is_signed: bool,
3225 should_clamp: bool,
3226 ) -> Instruction {
3227 let int_type = if is_signed {
3228 crate::ScalarKind::Sint
3229 } else {
3230 crate::ScalarKind::Uint
3231 };
3232 let uint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
3233 let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
3234 kind: int_type,
3235 width: 4,
3236 }));
3237
3238 let mut last_instruction = Instruction::new(spirv::Op::Nop);
3239
3240 let zero = self.writer.get_constant_scalar(crate::Literal::U32(0));
3241 let mut preresult = zero;
3242 block
3243 .body
3244 .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed)));
3245
3246 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
3247 const VEC_LENGTH: u8 = 4;
3248 for i in 0..u32::from(VEC_LENGTH) {
3249 let offset = self.writer.get_constant_scalar(crate::Literal::U32(i * 8));
3250 let mut extracted = self.gen_id();
3251 block.body.push(Instruction::binary(
3252 spirv::Op::CompositeExtract,
3253 int_type_id,
3254 extracted,
3255 arg0_id,
3256 i,
3257 ));
3258 if is_signed {
3259 let casted = self.gen_id();
3260 block.body.push(Instruction::unary(
3261 spirv::Op::Bitcast,
3262 uint_type_id,
3263 casted,
3264 extracted,
3265 ));
3266 extracted = casted;
3267 }
3268 if should_clamp {
3269 let (min, max, clamp_op) = if is_signed {
3270 (
3271 crate::Literal::I32(-128),
3272 crate::Literal::I32(127),
3273 spirv::GlslStd450Op::SClamp,
3274 )
3275 } else {
3276 (
3277 crate::Literal::U32(0),
3278 crate::Literal::U32(255),
3279 spirv::GlslStd450Op::UClamp,
3280 )
3281 };
3282 let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit));
3283
3284 let clamp_id = self.gen_id();
3285 block.body.push(Instruction::ext_inst_gl_op(
3286 self.writer.gl450_ext_inst_id,
3287 clamp_op,
3288 result_type_id,
3289 clamp_id,
3290 &[extracted, min, max],
3291 ));
3292
3293 extracted = clamp_id;
3294 }
3295 let is_last = i == u32::from(VEC_LENGTH - 1);
3296 if is_last {
3297 last_instruction = Instruction::quaternary(
3298 spirv::Op::BitFieldInsert,
3299 result_type_id,
3300 id,
3301 preresult,
3302 extracted,
3303 offset,
3304 eight,
3305 )
3306 } else {
3307 let new_preresult = self.gen_id();
3308 block.body.push(Instruction::quaternary(
3309 spirv::Op::BitFieldInsert,
3310 result_type_id,
3311 new_preresult,
3312 preresult,
3313 extracted,
3314 offset,
3315 eight,
3316 ));
3317 preresult = new_preresult;
3318 }
3319 }
3320 last_instruction
3321 }
3322
3323 fn write_unpack4x8_optimized(
3325 &mut self,
3326 block: &mut Block,
3327 result_type_id: u32,
3328 arg0_id: u32,
3329 id: u32,
3330 is_signed: bool,
3331 ) -> Instruction {
3332 let (int_type, convert_op) = if is_signed {
3333 (crate::ScalarKind::Sint, spirv::Op::SConvert)
3334 } else {
3335 (crate::ScalarKind::Uint, spirv::Op::UConvert)
3336 };
3337
3338 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3339 size: crate::VectorSize::Quad,
3340 scalar: crate::Scalar {
3341 kind: int_type,
3342 width: 1,
3343 },
3344 });
3345
3346 let packed_vector = self.gen_id();
3351 block.body.push(Instruction::unary(
3352 spirv::Op::Bitcast,
3353 packed_vector_type_id,
3354 packed_vector,
3355 arg0_id,
3356 ));
3357
3358 Instruction::unary(convert_op, result_type_id, id, packed_vector)
3359 }
3360
3361 fn write_unpack4x8_polyfill(
3363 &mut self,
3364 block: &mut Block,
3365 result_type_id: u32,
3366 arg0_id: u32,
3367 id: u32,
3368 is_signed: bool,
3369 ) -> Instruction {
3370 let (int_type, extract_op) = if is_signed {
3371 (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract)
3372 } else {
3373 (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract)
3374 };
3375
3376 let sint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));
3377
3378 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
3379 let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
3380 kind: int_type,
3381 width: 4,
3382 }));
3383 block
3384 .body
3385 .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed));
3386 let arg_id = if is_signed {
3387 let new_arg_id = self.gen_id();
3388 block.body.push(Instruction::unary(
3389 spirv::Op::Bitcast,
3390 sint_type_id,
3391 new_arg_id,
3392 arg0_id,
3393 ));
3394 new_arg_id
3395 } else {
3396 arg0_id
3397 };
3398
3399 const VEC_LENGTH: u8 = 4;
3400 let parts: [_; VEC_LENGTH as usize] = core::array::from_fn(|_| self.gen_id());
3401 for (i, part_id) in parts.into_iter().enumerate() {
3402 let index = self
3403 .writer
3404 .get_constant_scalar(crate::Literal::U32(i as u32 * 8));
3405 block.body.push(Instruction::ternary(
3406 extract_op,
3407 int_type_id,
3408 part_id,
3409 arg_id,
3410 index,
3411 eight,
3412 ));
3413 }
3414
3415 Instruction::composite_construct(result_type_id, id, &parts)
3416 }
3417
3418 fn write_block(
3435 &mut self,
3436 label_id: Word,
3437 naga_block: &crate::Block,
3438 exit: BlockExit,
3439 loop_context: LoopContext,
3440 debug_info: Option<&DebugInfoInner>,
3441 ) -> Result<BlockExitDisposition, Error> {
3442 let mut block = Block::new(label_id);
3443 for (statement, span) in naga_block.span_iter() {
3444 if let (Some(debug_info), false) = (
3445 debug_info,
3446 matches!(
3447 statement,
3448 &(Statement::Block(..)
3449 | Statement::Break
3450 | Statement::Continue
3451 | Statement::Kill
3452 | Statement::Return { .. }
3453 | Statement::Loop { .. })
3454 ),
3455 ) {
3456 let loc: crate::SourceLocation = span.location(debug_info.source_code);
3457 block.body.push(Instruction::line(
3458 debug_info.source_file_id,
3459 loc.line_number,
3460 loc.line_position,
3461 ));
3462 };
3463 match *statement {
3464 Statement::Emit(ref range) => {
3465 for handle in range.clone() {
3466 if !self.expression_constness.is_const(handle) {
3468 self.cache_expression_value(handle, &mut block)?;
3469 }
3470 }
3471 }
3472 Statement::Block(ref block_statements) => {
3473 let scope_id = self.gen_id();
3474 self.function.consume(block, Instruction::branch(scope_id));
3475
3476 let merge_id = self.gen_id();
3477 let merge_used = self.write_block(
3478 scope_id,
3479 block_statements,
3480 BlockExit::Branch { target: merge_id },
3481 loop_context,
3482 debug_info,
3483 )?;
3484
3485 match merge_used {
3486 BlockExitDisposition::Used => {
3487 block = Block::new(merge_id);
3488 }
3489 BlockExitDisposition::Discarded => {
3490 return Ok(BlockExitDisposition::Discarded);
3491 }
3492 }
3493 }
3494 Statement::If {
3495 condition,
3496 ref accept,
3497 ref reject,
3498 } => {
3499 if !(accept.is_empty() && reject.is_empty()) {
3505 let condition_id = self.cached[condition];
3506
3507 let merge_id = self.gen_id();
3508 block.body.push(Instruction::selection_merge(
3509 merge_id,
3510 spirv::SelectionControl::NONE,
3511 ));
3512
3513 let accept_id = if accept.is_empty() {
3514 None
3515 } else {
3516 Some(self.gen_id())
3517 };
3518 let reject_id = if reject.is_empty() {
3519 None
3520 } else {
3521 Some(self.gen_id())
3522 };
3523
3524 self.function.consume(
3525 block,
3526 Instruction::branch_conditional(
3527 condition_id,
3528 accept_id.unwrap_or(merge_id),
3529 reject_id.unwrap_or(merge_id),
3530 ),
3531 );
3532
3533 if let Some(block_id) = accept_id {
3534 let _ = self.write_block(
3539 block_id,
3540 accept,
3541 BlockExit::Branch { target: merge_id },
3542 loop_context,
3543 debug_info,
3544 )?;
3545 }
3546 if let Some(block_id) = reject_id {
3547 let _ = self.write_block(
3552 block_id,
3553 reject,
3554 BlockExit::Branch { target: merge_id },
3555 loop_context,
3556 debug_info,
3557 )?;
3558 }
3559
3560 block = Block::new(merge_id);
3561 }
3562 }
3563 Statement::Switch {
3564 selector,
3565 ref cases,
3566 } => {
3567 let selector_id = self.cached[selector];
3568
3569 let merge_id = self.gen_id();
3570 block.body.push(Instruction::selection_merge(
3571 merge_id,
3572 spirv::SelectionControl::NONE,
3573 ));
3574
3575 let mut default_id = None;
3576 let mut last_id = None;
3578
3579 let mut raw_cases = Vec::with_capacity(cases.len());
3580 let mut case_ids = Vec::with_capacity(cases.len());
3581 for case in cases.iter() {
3582 let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
3584
3585 if case.fall_through && case.body.is_empty() {
3586 last_id = Some(label_id);
3587 }
3588
3589 case_ids.push(label_id);
3590
3591 match case.value {
3592 crate::SwitchValue::I32(value) => {
3593 raw_cases.push(super::instructions::Case {
3594 value: value as Word,
3595 label_id,
3596 });
3597 }
3598 crate::SwitchValue::U32(value) => {
3599 raw_cases.push(super::instructions::Case { value, label_id });
3600 }
3601 crate::SwitchValue::Default => {
3602 default_id = Some(label_id);
3603 }
3604 }
3605 }
3606
3607 let default_id = default_id.unwrap();
3608
3609 self.function.consume(
3610 block,
3611 Instruction::switch(selector_id, default_id, &raw_cases),
3612 );
3613
3614 let inner_context = LoopContext {
3615 break_id: Some(merge_id),
3616 ..loop_context
3617 };
3618
3619 for (i, (case, label_id)) in cases
3620 .iter()
3621 .zip(case_ids.iter())
3622 .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
3623 .enumerate()
3624 {
3625 let case_finish_id = if case.fall_through {
3626 case_ids[i + 1]
3627 } else {
3628 merge_id
3629 };
3630 let _ = self.write_block(
3639 *label_id,
3640 &case.body,
3641 BlockExit::Branch {
3642 target: case_finish_id,
3643 },
3644 inner_context,
3645 debug_info,
3646 )?;
3647 }
3648
3649 block = Block::new(merge_id);
3650 }
3651 Statement::Loop {
3652 ref body,
3653 ref continuing,
3654 break_if,
3655 } => {
3656 let preamble_id = self.gen_id();
3657 self.function
3658 .consume(block, Instruction::branch(preamble_id));
3659
3660 let merge_id = self.gen_id();
3661 let body_id = self.gen_id();
3662 let continuing_id = self.gen_id();
3663
3664 block = Block::new(preamble_id);
3667 if let Some(debug_info) = debug_info {
3670 let loc: crate::SourceLocation = span.location(debug_info.source_code);
3671 block.body.push(Instruction::line(
3672 debug_info.source_file_id,
3673 loc.line_number,
3674 loc.line_position,
3675 ))
3676 }
3677 block.body.push(Instruction::loop_merge(
3678 merge_id,
3679 continuing_id,
3680 spirv::SelectionControl::NONE,
3681 ));
3682
3683 if self.force_loop_bounding {
3684 block = self.write_force_bounded_loop_instructions(block, merge_id);
3685 }
3686 self.function.consume(block, Instruction::branch(body_id));
3687
3688 let _ = self.write_block(
3692 body_id,
3693 body,
3694 BlockExit::Branch {
3695 target: continuing_id,
3696 },
3697 LoopContext {
3698 continuing_id: Some(continuing_id),
3699 break_id: Some(merge_id),
3700 },
3701 debug_info,
3702 )?;
3703
3704 let exit = match break_if {
3705 Some(condition) => BlockExit::BreakIf {
3706 condition,
3707 preamble_id,
3708 },
3709 None => BlockExit::Branch {
3710 target: preamble_id,
3711 },
3712 };
3713
3714 let _ = self.write_block(
3718 continuing_id,
3719 continuing,
3720 exit,
3721 LoopContext {
3722 continuing_id: None,
3723 break_id: Some(merge_id),
3724 },
3725 debug_info,
3726 )?;
3727
3728 block = Block::new(merge_id);
3729 }
3730 Statement::Break => {
3731 self.function
3732 .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
3733 return Ok(BlockExitDisposition::Discarded);
3734 }
3735 Statement::Continue => {
3736 self.function.consume(
3737 block,
3738 Instruction::branch(loop_context.continuing_id.unwrap()),
3739 );
3740 return Ok(BlockExitDisposition::Discarded);
3741 }
3742 Statement::Return { value: Some(value) } => {
3743 let value_id = self.cached[value];
3744 let instruction = match self.function.entry_point_context {
3745 Some(ref context) => self.writer.write_entry_point_return(
3748 value_id,
3749 self.ir_function.result.as_ref().unwrap(),
3750 &context.results,
3751 &mut block.body,
3752 )?,
3753 None => Instruction::return_value(value_id),
3754 };
3755 self.function.consume(block, instruction);
3756 return Ok(BlockExitDisposition::Discarded);
3757 }
3758 Statement::Return { value: None } => {
3759 self.function.consume(block, Instruction::return_void());
3760 return Ok(BlockExitDisposition::Discarded);
3761 }
3762 Statement::Kill => {
3763 self.function.consume(block, Instruction::kill());
3764 return Ok(BlockExitDisposition::Discarded);
3765 }
3766 Statement::ControlBarrier(flags) => {
3767 self.writer.write_control_barrier(flags, &mut block.body);
3768 }
3769 Statement::MemoryBarrier(flags) => {
3770 self.writer.write_memory_barrier(flags, &mut block);
3771 }
3772 Statement::Store { pointer, value } => {
3773 let value_id = self.cached[value];
3774 match self.write_access_chain(
3775 pointer,
3776 &mut block,
3777 AccessTypeAdjustment::None,
3778 )? {
3779 ExpressionPointer::Ready { pointer_id } => {
3780 let atomic_space = match *self.fun_info[pointer]
3781 .ty
3782 .inner_with(&self.ir_module.types)
3783 {
3784 crate::TypeInner::Pointer { base, space } => {
3785 match self.ir_module.types[base].inner {
3786 crate::TypeInner::Atomic { .. } => Some(space),
3787 _ => None,
3788 }
3789 }
3790 _ => None,
3791 };
3792 let instruction = if let Some(space) = atomic_space {
3793 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3794 let scope_constant_id = self.get_scope_constant(scope as u32);
3795 let semantics_id = self.get_index_constant(semantics.bits());
3796 Instruction::atomic_store(
3797 pointer_id,
3798 scope_constant_id,
3799 semantics_id,
3800 value_id,
3801 )
3802 } else {
3803 Instruction::store(pointer_id, value_id, None)
3804 };
3805 block.body.push(instruction);
3806 }
3807 ExpressionPointer::Conditional { condition, access } => {
3808 let mut selection = Selection::start(&mut block, ());
3809 selection.if_true(self, condition, ());
3810
3811 let pointer_id = access.result_id.unwrap();
3813 selection.block().body.push(access);
3814 selection
3815 .block()
3816 .body
3817 .push(Instruction::store(pointer_id, value_id, None));
3818
3819 selection.finish(self, ());
3822 }
3823 };
3824 }
3825 Statement::ImageStore {
3826 image,
3827 coordinate,
3828 array_index,
3829 value,
3830 } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
3831 Statement::Call {
3832 function: local_function,
3833 ref arguments,
3834 result,
3835 } => {
3836 let id = self.gen_id();
3837 self.temp_list.clear();
3838 for &argument in arguments {
3839 self.temp_list.push(self.cached[argument]);
3840 }
3841
3842 let type_id = match result {
3843 Some(expr) => {
3844 self.cached[expr] = id;
3845 self.get_expression_type_id(&self.fun_info[expr].ty)
3846 }
3847 None => self.writer.void_type,
3848 };
3849
3850 block.body.push(Instruction::function_call(
3851 type_id,
3852 id,
3853 self.writer.lookup_function[&local_function],
3854 &self.temp_list,
3855 ));
3856 }
3857 Statement::Atomic {
3858 pointer,
3859 ref fun,
3860 value,
3861 result,
3862 } => {
3863 let id = self.gen_id();
3864 let result_type_id =
3868 self.get_expression_type_id(&self.fun_info[result.unwrap_or(value)].ty);
3869
3870 if let Some(result) = result {
3871 self.cached[result] = id;
3872 }
3873
3874 let pointer_id = match self.write_access_chain(
3875 pointer,
3876 &mut block,
3877 AccessTypeAdjustment::None,
3878 )? {
3879 ExpressionPointer::Ready { pointer_id } => pointer_id,
3880 ExpressionPointer::Conditional { .. } => {
3881 return Err(Error::FeatureNotImplemented(
3882 "Atomics out-of-bounds handling",
3883 ));
3884 }
3885 };
3886
3887 let space = self.fun_info[pointer]
3888 .ty
3889 .inner_with(&self.ir_module.types)
3890 .pointer_space()
3891 .unwrap();
3892 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3893 let scope_constant_id = self.get_scope_constant(scope as u32);
3894 let semantics_id = self.get_index_constant(semantics.bits());
3895 let value_id = self.cached[value];
3896 let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
3897
3898 let crate::TypeInner::Scalar(scalar) = *value_inner else {
3899 return Err(Error::FeatureNotImplemented(
3900 "Atomics with non-scalar values",
3901 ));
3902 };
3903
3904 let instruction = match *fun {
3905 crate::AtomicFunction::Add => {
3906 let spirv_op = match scalar.kind {
3907 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3908 spirv::Op::AtomicIAdd
3909 }
3910 crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT,
3911 _ => unimplemented!(),
3912 };
3913 Instruction::atomic_binary(
3914 spirv_op,
3915 result_type_id,
3916 id,
3917 pointer_id,
3918 scope_constant_id,
3919 semantics_id,
3920 value_id,
3921 )
3922 }
3923 crate::AtomicFunction::Subtract => {
3924 let (spirv_op, value_id) = match scalar.kind {
3925 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3926 (spirv::Op::AtomicISub, value_id)
3927 }
3928 crate::ScalarKind::Float => {
3929 let neg_result_id = self.gen_id();
3932 block.body.push(Instruction::unary(
3933 spirv::Op::FNegate,
3934 result_type_id,
3935 neg_result_id,
3936 value_id,
3937 ));
3938 (spirv::Op::AtomicFAddEXT, neg_result_id)
3939 }
3940 _ => unimplemented!(),
3941 };
3942 Instruction::atomic_binary(
3943 spirv_op,
3944 result_type_id,
3945 id,
3946 pointer_id,
3947 scope_constant_id,
3948 semantics_id,
3949 value_id,
3950 )
3951 }
3952 crate::AtomicFunction::And => {
3953 let spirv_op = match scalar.kind {
3954 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3955 spirv::Op::AtomicAnd
3956 }
3957 _ => unimplemented!(),
3958 };
3959 Instruction::atomic_binary(
3960 spirv_op,
3961 result_type_id,
3962 id,
3963 pointer_id,
3964 scope_constant_id,
3965 semantics_id,
3966 value_id,
3967 )
3968 }
3969 crate::AtomicFunction::InclusiveOr => {
3970 let spirv_op = match scalar.kind {
3971 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3972 spirv::Op::AtomicOr
3973 }
3974 _ => unimplemented!(),
3975 };
3976 Instruction::atomic_binary(
3977 spirv_op,
3978 result_type_id,
3979 id,
3980 pointer_id,
3981 scope_constant_id,
3982 semantics_id,
3983 value_id,
3984 )
3985 }
3986 crate::AtomicFunction::ExclusiveOr => {
3987 let spirv_op = match scalar.kind {
3988 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3989 spirv::Op::AtomicXor
3990 }
3991 _ => unimplemented!(),
3992 };
3993 Instruction::atomic_binary(
3994 spirv_op,
3995 result_type_id,
3996 id,
3997 pointer_id,
3998 scope_constant_id,
3999 semantics_id,
4000 value_id,
4001 )
4002 }
4003 crate::AtomicFunction::Min => {
4004 let spirv_op = match scalar.kind {
4005 crate::ScalarKind::Sint => spirv::Op::AtomicSMin,
4006 crate::ScalarKind::Uint => spirv::Op::AtomicUMin,
4007 _ => unimplemented!(),
4008 };
4009 Instruction::atomic_binary(
4010 spirv_op,
4011 result_type_id,
4012 id,
4013 pointer_id,
4014 scope_constant_id,
4015 semantics_id,
4016 value_id,
4017 )
4018 }
4019 crate::AtomicFunction::Max => {
4020 let spirv_op = match scalar.kind {
4021 crate::ScalarKind::Sint => spirv::Op::AtomicSMax,
4022 crate::ScalarKind::Uint => spirv::Op::AtomicUMax,
4023 _ => unimplemented!(),
4024 };
4025 Instruction::atomic_binary(
4026 spirv_op,
4027 result_type_id,
4028 id,
4029 pointer_id,
4030 scope_constant_id,
4031 semantics_id,
4032 value_id,
4033 )
4034 }
4035 crate::AtomicFunction::Exchange { compare: None } => {
4036 Instruction::atomic_binary(
4037 spirv::Op::AtomicExchange,
4038 result_type_id,
4039 id,
4040 pointer_id,
4041 scope_constant_id,
4042 semantics_id,
4043 value_id,
4044 )
4045 }
4046 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4047 let scalar_type_id =
4048 self.get_numeric_type_id(NumericType::Scalar(scalar));
4049 let bool_type_id =
4050 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL));
4051
4052 let cas_result_id = self.gen_id();
4053 let equality_result_id = self.gen_id();
4054 let equality_operator = match scalar.kind {
4055 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
4056 spirv::Op::IEqual
4057 }
4058 _ => unimplemented!(),
4059 };
4060
4061 let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
4062 cas_instr.set_type(scalar_type_id);
4063 cas_instr.set_result(cas_result_id);
4064 cas_instr.add_operand(pointer_id);
4065 cas_instr.add_operand(scope_constant_id);
4066 cas_instr.add_operand(semantics_id); cas_instr.add_operand(semantics_id); cas_instr.add_operand(value_id);
4069 cas_instr.add_operand(self.cached[cmp]);
4070 block.body.push(cas_instr);
4071 block.body.push(Instruction::binary(
4072 equality_operator,
4073 bool_type_id,
4074 equality_result_id,
4075 cas_result_id,
4076 self.cached[cmp],
4077 ));
4078 Instruction::composite_construct(
4079 result_type_id,
4080 id,
4081 &[cas_result_id, equality_result_id],
4082 )
4083 }
4084 };
4085
4086 block.body.push(instruction);
4087 }
4088 Statement::ImageAtomic {
4089 image,
4090 coordinate,
4091 array_index,
4092 fun,
4093 value,
4094 } => {
4095 self.write_image_atomic(
4096 image,
4097 coordinate,
4098 array_index,
4099 fun,
4100 value,
4101 &mut block,
4102 )?;
4103 }
4104 Statement::WorkGroupUniformLoad { pointer, result } => {
4105 self.writer
4106 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4107 let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
4108 let id = self.write_checked_load(
4111 pointer,
4112 &mut block,
4113 AccessTypeAdjustment::None,
4114 result_type_id,
4115 )?;
4116 self.cached[result] = id;
4117 self.writer
4118 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4119 }
4120 Statement::RayQuery { query, ref fun } => {
4121 self.write_ray_query_function(query, fun, &mut block);
4122 }
4123 Statement::SubgroupBallot {
4124 result,
4125 ref predicate,
4126 } => {
4127 self.write_subgroup_ballot(predicate, result, &mut block)?;
4128 }
4129 Statement::SubgroupCollectiveOperation {
4130 ref op,
4131 ref collective_op,
4132 argument,
4133 result,
4134 } => {
4135 self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
4136 }
4137 Statement::SubgroupGather {
4138 ref mode,
4139 argument,
4140 result,
4141 } => {
4142 self.write_subgroup_gather(mode, argument, result, &mut block)?;
4143 }
4144 Statement::CooperativeStore { target, ref data } => {
4145 let target_id = self.cached[target];
4146 let layout = if data.row_major {
4147 spirv::CooperativeMatrixLayout::RowMajorKHR
4148 } else {
4149 spirv::CooperativeMatrixLayout::ColumnMajorKHR
4150 };
4151 let layout_id = self.get_index_constant(layout as u32);
4152 let stride_id = self.cached[data.stride];
4153 match self.write_access_chain(
4154 data.pointer,
4155 &mut block,
4156 AccessTypeAdjustment::None,
4157 )? {
4158 ExpressionPointer::Ready { pointer_id } => {
4159 block.body.push(Instruction::coop_store(
4160 target_id, pointer_id, layout_id, stride_id,
4161 ));
4162 }
4163 ExpressionPointer::Conditional { condition, access } => {
4164 let mut selection = Selection::start(&mut block, ());
4165 selection.if_true(self, condition, ());
4166
4167 let pointer_id = access.result_id.unwrap();
4169 selection.block().body.push(access);
4170 selection.block().body.push(Instruction::coop_store(
4171 target_id, pointer_id, layout_id, stride_id,
4172 ));
4173
4174 selection.finish(self, ());
4177 }
4178 };
4179 }
4180 Statement::RayPipelineFunction(_) => unreachable!(),
4181 }
4182 }
4183
4184 let termination = match exit {
4185 BlockExit::Return => match self.ir_function.result {
4188 Some(ref result) if self.function.entry_point_context.is_none() => {
4189 let type_id = self.get_handle_type_id(result.ty);
4190 let null_id = self.writer.get_constant_null(type_id);
4191 Instruction::return_value(null_id)
4192 }
4193 _ => Instruction::return_void(),
4194 },
4195 BlockExit::Branch { target } => Instruction::branch(target),
4196 BlockExit::BreakIf {
4197 condition,
4198 preamble_id,
4199 } => {
4200 let condition_id = self.cached[condition];
4201
4202 Instruction::branch_conditional(
4203 condition_id,
4204 loop_context.break_id.unwrap(),
4205 preamble_id,
4206 )
4207 }
4208 };
4209
4210 self.function.consume(block, termination);
4211 Ok(BlockExitDisposition::Used)
4212 }
4213
4214 pub(super) fn write_function_body(
4215 &mut self,
4216 entry_id: Word,
4217 debug_info: Option<&DebugInfoInner>,
4218 ) -> Result<(), Error> {
4219 let _ = self.write_block(
4222 entry_id,
4223 &self.ir_function.body,
4224 BlockExit::Return,
4225 LoopContext::default(),
4226 debug_info,
4227 )?;
4228
4229 Ok(())
4230 }
4231}