1use alloc::vec::Vec;
6
7use arrayvec::ArrayVec;
8use spirv::Word;
9
10use super::{
11 helpers::map_storage_class, index::BoundsCheckResult, selection::Selection, Block,
12 BlockContext, Dimension, Error, IdGenerator, Instruction, LocalType, LookupType, NumericType,
13 ResultMember, WrappedFunction, Writer, WriterFlags,
14};
15use crate::{
16 arena::Handle, back::spv::helpers::is_uniform_matcx2_struct_member_access,
17 proc::index::GuardedIndex, Statement,
18};
19
20fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
21 match *type_inner {
22 crate::TypeInner::Scalar(_) => Dimension::Scalar,
23 crate::TypeInner::Vector { .. } => Dimension::Vector,
24 crate::TypeInner::Matrix { .. } => Dimension::Matrix,
25 crate::TypeInner::CooperativeMatrix { .. } => Dimension::CooperativeMatrix,
26 _ => unreachable!(),
27 }
28}
29
30#[derive(Copy, Clone)]
42enum AccessTypeAdjustment {
43 None,
52
53 IntroducePointer(spirv::StorageClass),
77
78 UseStd140CompatType,
88}
89
90enum ExpressionPointer {
94 Ready { pointer_id: Word },
97
98 Conditional {
104 condition: Word,
105 access: Instruction,
106 },
107}
108
109enum BlockExit {
111 Return,
113 Branch {
115 target: Word,
117 },
118 BreakIf {
124 condition: Handle<crate::Expression>,
126 preamble_id: Word,
128 },
129}
130
/// Reports whether a requested `BlockExit` actually ended up being emitted.
///
/// Marked `#[must_use]` so callers cannot silently ignore the outcome.
///
/// NOTE(review): no producer or consumer is visible in this part of the
/// file; variant meanings are inferred from their names.
#[must_use]
enum BlockExitDisposition {
    /// The exit was written as the block's terminator.
    Used,

    /// The exit was not written, presumably because the block was already
    /// terminated by some other instruction.
    Discarded,
}
154
155#[derive(Clone, Copy, Default)]
156struct LoopContext {
157 continuing_id: Option<Word>,
158 break_id: Option<Word>,
159}
160
161#[derive(Debug)]
162pub(crate) struct DebugInfoInner<'a> {
163 pub source_code: &'a str,
164 pub source_file_id: Word,
165}
166
167impl Writer {
168 fn write_epilogue_position_y_flip(
173 &mut self,
174 position_id: Word,
175 body: &mut Vec<Instruction>,
176 ) -> Result<(), Error> {
177 let float_ptr_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Output);
178 let index_y_id = self.get_index_constant(1);
179 let access_id = self.id_gen.next();
180 body.push(Instruction::access_chain(
181 float_ptr_type_id,
182 access_id,
183 position_id,
184 &[index_y_id],
185 ));
186
187 let float_type_id = self.get_f32_type_id();
188 let load_id = self.id_gen.next();
189 body.push(Instruction::load(float_type_id, load_id, access_id, None));
190
191 let neg_id = self.id_gen.next();
192 body.push(Instruction::unary(
193 spirv::Op::FNegate,
194 float_type_id,
195 neg_id,
196 load_id,
197 ));
198
199 body.push(Instruction::store(access_id, neg_id, None));
200 Ok(())
201 }
202
203 fn write_epilogue_frag_depth_clamp(
205 &mut self,
206 frag_depth_id: Word,
207 body: &mut Vec<Instruction>,
208 ) -> Result<(), Error> {
209 let float_type_id = self.get_f32_type_id();
210 let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
211 let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));
212
213 let original_id = self.id_gen.next();
214 body.push(Instruction::load(
215 float_type_id,
216 original_id,
217 frag_depth_id,
218 None,
219 ));
220
221 let clamp_id = self.id_gen.next();
222 body.push(Instruction::ext_inst_gl_op(
223 self.gl450_ext_inst_id,
224 spirv::GlslStd450Op::FClamp,
225 float_type_id,
226 clamp_id,
227 &[original_id, zero_scalar_id, one_scalar_id],
228 ));
229
230 body.push(Instruction::store(frag_depth_id, clamp_id, None));
231 Ok(())
232 }
233
234 fn write_entry_point_return(
235 &mut self,
236 value_id: Word,
237 ir_result: &crate::FunctionResult,
238 result_members: &[ResultMember],
239 body: &mut Vec<Instruction>,
240 ) -> Result<Instruction, Error> {
241 for (index, res_member) in result_members.iter().enumerate() {
242 if res_member.built_in == Some(crate::BuiltIn::MeshTaskSize) {
244 return Ok(Instruction::return_value(value_id));
245 }
246 let member_value_id = match ir_result.binding {
247 Some(_) => value_id,
248 None => {
249 let member_value_id = self.id_gen.next();
250 body.push(Instruction::composite_extract(
251 res_member.type_id,
252 member_value_id,
253 value_id,
254 &[index as u32],
255 ));
256 member_value_id
257 }
258 };
259
260 self.store_io_with_f16_polyfill(body, res_member.id, member_value_id);
261
262 match res_member.built_in {
263 Some(crate::BuiltIn::Position { .. })
264 if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
265 {
266 self.write_epilogue_position_y_flip(res_member.id, body)?;
267 }
268 Some(crate::BuiltIn::FragDepth)
269 if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
270 {
271 self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
272 }
273 _ => {}
274 }
275 }
276 Ok(Instruction::return_void())
277 }
278}
279
280impl BlockContext<'_> {
281 fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
294 let uint_type_id = self.writer.get_u32_type_id();
295 let uint2_type_id = self.writer.get_vec2u_type_id();
296 let uint2_ptr_type_id = self
297 .writer
298 .get_vec2u_pointer_type_id(spirv::StorageClass::Function);
299 let bool_type_id = self.writer.get_bool_type_id();
300 let bool2_type_id = self.writer.get_vec2_bool_type_id();
301 let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
302 let zero_uint2_const_id = self.writer.get_constant_composite(
303 LookupType::Local(LocalType::Numeric(NumericType::Vector {
304 size: crate::VectorSize::Bi,
305 scalar: crate::Scalar::U32,
306 })),
307 &[zero_uint_const_id, zero_uint_const_id],
308 );
309 let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
310 let max_uint_const_id = self
311 .writer
312 .get_constant_scalar(crate::Literal::U32(u32::MAX));
313 let max_uint2_const_id = self.writer.get_constant_composite(
314 LookupType::Local(LocalType::Numeric(NumericType::Vector {
315 size: crate::VectorSize::Bi,
316 scalar: crate::Scalar::U32,
317 })),
318 &[max_uint_const_id, max_uint_const_id],
319 );
320
321 let loop_counter_var_id = self.gen_id();
322 if self.writer.flags.contains(WriterFlags::DEBUG) {
323 self.writer
324 .debugs
325 .push(Instruction::name(loop_counter_var_id, "loop_bound"));
326 }
327 let var = super::LocalVariable {
328 id: loop_counter_var_id,
329 instruction: Instruction::variable(
330 uint2_ptr_type_id,
331 loop_counter_var_id,
332 spirv::StorageClass::Function,
333 Some(max_uint2_const_id),
334 ),
335 };
336 self.function.force_loop_bounding_vars.push(var);
337
338 let break_if_block = self.gen_id();
339
340 self.function
341 .consume(block, Instruction::branch(break_if_block));
342 block = Block::new(break_if_block);
343
344 let load_id = self.gen_id();
347 block.body.push(Instruction::load(
348 uint2_type_id,
349 load_id,
350 loop_counter_var_id,
351 None,
352 ));
353
354 let eq_id = self.gen_id();
357 block.body.push(Instruction::binary(
358 spirv::Op::IEqual,
359 bool2_type_id,
360 eq_id,
361 zero_uint2_const_id,
362 load_id,
363 ));
364 let all_eq_id = self.gen_id();
365 block.body.push(Instruction::relational(
366 spirv::Op::All,
367 bool_type_id,
368 all_eq_id,
369 eq_id,
370 ));
371
372 let inc_counter_block_id = self.gen_id();
373 block.body.push(Instruction::selection_merge(
374 inc_counter_block_id,
375 spirv::SelectionControl::empty(),
376 ));
377 self.function.consume(
378 block,
379 Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
380 );
381 block = Block::new(inc_counter_block_id);
382
383 let low_id = self.gen_id();
389 block.body.push(Instruction::composite_extract(
390 uint_type_id,
391 low_id,
392 load_id,
393 &[1],
394 ));
395 let low_overflow_id = self.gen_id();
396 block.body.push(Instruction::binary(
397 spirv::Op::IEqual,
398 bool_type_id,
399 low_overflow_id,
400 low_id,
401 zero_uint_const_id,
402 ));
403 let carry_bit_id = self.gen_id();
404 block.body.push(Instruction::select(
405 uint_type_id,
406 carry_bit_id,
407 low_overflow_id,
408 one_uint_const_id,
409 zero_uint_const_id,
410 ));
411 let decrement_id = self.gen_id();
412 block.body.push(Instruction::composite_construct(
413 uint2_type_id,
414 decrement_id,
415 &[carry_bit_id, one_uint_const_id],
416 ));
417 let result_id = self.gen_id();
418 block.body.push(Instruction::binary(
419 spirv::Op::ISub,
420 uint2_type_id,
421 result_id,
422 load_id,
423 decrement_id,
424 ));
425 block
426 .body
427 .push(Instruction::store(loop_counter_var_id, result_id, None));
428
429 block
430 }
431
432 fn maybe_write_uniform_matcx2_dynamic_access(
449 &mut self,
450 pointer: Handle<crate::Expression>,
451 block: &mut Block,
452 ) -> Result<Option<Word>, Error> {
453 let (column_pointer, component_index) = match self.fun_info[pointer]
459 .ty
460 .inner_with(&self.ir_module.types)
461 .pointer_base_type()
462 {
463 Some(resolution) => match *resolution.inner_with(&self.ir_module.types) {
464 crate::TypeInner::Scalar(_) => match self.ir_function.expressions[pointer] {
465 crate::Expression::Access { base, index } => {
466 (base, Some(GuardedIndex::Expression(index)))
467 }
468 crate::Expression::AccessIndex { base, index } => {
469 (base, Some(GuardedIndex::Known(index)))
470 }
471 _ => return Ok(None),
472 },
473 crate::TypeInner::Vector { .. } => (pointer, None),
474 _ => return Ok(None),
475 },
476 None => return Ok(None),
477 };
478
479 let crate::Expression::Access {
482 base: matrix_pointer,
483 index: column_index,
484 } = self.ir_function.expressions[column_pointer]
485 else {
486 return Ok(None);
487 };
488
489 let crate::TypeInner::Pointer {
491 base: matrix_pointer_base_type,
492 space: crate::AddressSpace::Uniform,
493 } = *self.fun_info[matrix_pointer]
494 .ty
495 .inner_with(&self.ir_module.types)
496 else {
497 return Ok(None);
498 };
499
500 let crate::TypeInner::Matrix {
502 columns,
503 rows: rows @ crate::VectorSize::Bi,
504 scalar,
505 } = self.ir_module.types[matrix_pointer_base_type].inner
506 else {
507 return Ok(None);
508 };
509
510 let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
511 columns,
512 rows,
513 scalar,
514 });
515 let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
516 let component_type_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
517 let get_column_function_id = self.writer.wrapped_functions
518 [&WrappedFunction::MatCx2GetColumn {
519 r#type: matrix_pointer_base_type,
520 }];
521
522 let matrix_load_id = self.write_checked_load(
523 matrix_pointer,
524 block,
525 AccessTypeAdjustment::None,
526 matrix_type_id,
527 )?;
528
529 let column_index_id = match *self.fun_info[column_index]
532 .ty
533 .inner_with(&self.ir_module.types)
534 {
535 crate::TypeInner::Scalar(crate::Scalar {
536 kind: crate::ScalarKind::Uint,
537 ..
538 }) => self.cached[column_index],
539 crate::TypeInner::Scalar(crate::Scalar {
540 kind: crate::ScalarKind::Sint,
541 ..
542 }) => {
543 let cast_id = self.gen_id();
544 let u32_type_id = self.writer.get_u32_type_id();
545 block.body.push(Instruction::unary(
546 spirv::Op::Bitcast,
547 u32_type_id,
548 cast_id,
549 self.cached[column_index],
550 ));
551 cast_id
552 }
553 _ => return Err(Error::Validation("Matrix access index must be u32 or i32")),
554 };
555 let column_id = self.gen_id();
556 block.body.push(Instruction::function_call(
557 column_type_id,
558 column_id,
559 get_column_function_id,
560 &[matrix_load_id, column_index_id],
561 ));
562 let result_id = match component_index {
563 Some(index) => self.write_vector_access(
564 component_type_id,
565 column_pointer,
566 Some(column_id),
567 index,
568 block,
569 )?,
570 None => column_id,
571 };
572
573 Ok(Some(result_id))
574 }
575
576 fn maybe_write_load_uniform_matcx2_struct_member(
588 &mut self,
589 pointer: Handle<crate::Expression>,
590 block: &mut Block,
591 ) -> Result<Option<Word>, Error> {
592 let crate::TypeInner::Pointer {
594 base: matrix_type,
595 space: space @ crate::AddressSpace::Uniform,
596 } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
597 else {
598 return Ok(None);
599 };
600
601 let crate::TypeInner::Matrix {
602 columns,
603 rows: rows @ crate::VectorSize::Bi,
604 scalar,
605 } = self.ir_module.types[matrix_type].inner
606 else {
607 return Ok(None);
608 };
609
610 let crate::Expression::AccessIndex {
613 base: struct_pointer,
614 index: member_index,
615 } = self.ir_function.expressions[pointer]
616 else {
617 return Ok(None);
618 };
619
620 let crate::TypeInner::Pointer {
621 base: struct_type, ..
622 } = *self.fun_info[struct_pointer]
623 .ty
624 .inner_with(&self.ir_module.types)
625 else {
626 return Ok(None);
627 };
628
629 let crate::TypeInner::Struct { .. } = self.ir_module.types[struct_type].inner else {
630 return Ok(None);
631 };
632
633 let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
634 columns,
635 rows,
636 scalar,
637 });
638 let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
639 let column_pointer_type_id =
640 self.get_pointer_type_id(column_type_id, map_storage_class(space));
641 let column0_index = self.writer.std140_compat_uniform_types[&struct_type].member_indices
642 [member_index as usize];
643 let column_indices = (0..columns as u32)
644 .map(|c| self.get_index_constant(column0_index + c))
645 .collect::<ArrayVec<_, 4>>();
646
647 let load_mat_from_struct =
650 |struct_pointer_id: Word, id_gen: &mut IdGenerator, block: &mut Block| -> Word {
651 let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
652 for index in &column_indices {
653 let column_pointer_id = id_gen.next();
654 block.body.push(Instruction::access_chain(
655 column_pointer_type_id,
656 column_pointer_id,
657 struct_pointer_id,
658 &[*index],
659 ));
660 let column_id = id_gen.next();
661 block.body.push(Instruction::load(
662 column_type_id,
663 column_id,
664 column_pointer_id,
665 None,
666 ));
667 column_ids.push(column_id);
668 }
669 let result_id = id_gen.next();
670 block.body.push(Instruction::composite_construct(
671 matrix_type_id,
672 result_id,
673 &column_ids,
674 ));
675 result_id
676 };
677
678 let result_id = match self.write_access_chain(
679 struct_pointer,
680 block,
681 AccessTypeAdjustment::UseStd140CompatType,
682 )? {
683 ExpressionPointer::Ready { pointer_id } => {
684 load_mat_from_struct(pointer_id, &mut self.writer.id_gen, block)
685 }
686 ExpressionPointer::Conditional { condition, access } => self
687 .write_conditional_indexed_load(
688 matrix_type_id,
689 condition,
690 block,
691 |id_gen, block| {
692 let pointer_id = access.result_id.unwrap();
693 block.body.push(access);
694 load_mat_from_struct(pointer_id, id_gen, block)
695 },
696 ),
697 };
698
699 Ok(Some(result_id))
700 }
701
702 pub(super) fn cache_expression_value(
704 &mut self,
705 expr_handle: Handle<crate::Expression>,
706 block: &mut Block,
707 ) -> Result<(), Error> {
708 let is_named_expression = self
709 .ir_function
710 .named_expressions
711 .contains_key(&expr_handle);
712
713 if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression {
714 return Ok(());
715 }
716
717 let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
718 let id = match self.ir_function.expressions[expr_handle] {
719 crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
720 crate::Expression::Constant(handle) => {
721 let init = self.ir_module.constants[handle].init;
722 self.writer.constant_ids[init]
723 }
724 crate::Expression::Override(_) => return Err(Error::Override),
725 crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
726 crate::Expression::Compose { ty, ref components } => {
727 self.temp_list.clear();
728 if self.expression_constness.is_const(expr_handle) {
729 self.temp_list.extend(
730 crate::proc::flatten_compose(
731 ty,
732 components,
733 &self.ir_function.expressions,
734 &self.ir_module.types,
735 )
736 .map(|component| self.cached[component]),
737 );
738 self.writer
739 .get_constant_composite(LookupType::Handle(ty), &self.temp_list)
740 } else {
741 self.temp_list
742 .extend(components.iter().map(|&component| self.cached[component]));
743
744 let id = self.gen_id();
745 block.body.push(Instruction::composite_construct(
746 result_type_id,
747 id,
748 &self.temp_list,
749 ));
750 id
751 }
752 }
753 crate::Expression::Splat { size, value } => {
754 let value_id = self.cached[value];
755 let components = &[value_id; 4][..size as usize];
756
757 if self.expression_constness.is_const(expr_handle) {
758 let ty = self
759 .writer
760 .get_expression_lookup_type(&self.fun_info[expr_handle].ty);
761 self.writer.get_constant_composite(ty, components)
762 } else {
763 let id = self.gen_id();
764 block.body.push(Instruction::composite_construct(
765 result_type_id,
766 id,
767 components,
768 ));
769 id
770 }
771 }
772 crate::Expression::Access { base, index } => {
773 let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
774 match *base_ty_inner {
775 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
776 0
782 }
783 _ if self.function.spilled_accesses.contains(base) => {
784 self.function.spilled_accesses.insert(expr_handle);
792 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
793 }
794 crate::TypeInner::Vector { .. } => self.write_vector_access(
795 result_type_id,
796 base,
797 None,
798 GuardedIndex::Expression(index),
799 block,
800 )?,
801 crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => {
802 match GuardedIndex::from_expression(
804 index,
805 &self.ir_function.expressions,
806 self.ir_module,
807 ) {
808 GuardedIndex::Known(value) => {
809 let id = self.gen_id();
819 let base_id = self.cached[base];
820 block.body.push(Instruction::composite_extract(
821 result_type_id,
822 id,
823 base_id,
824 &[value],
825 ));
826 id
827 }
828 GuardedIndex::Expression(_) => {
829 self.spill_to_internal_variable(base, block);
836
837 self.function.spilled_accesses.insert(expr_handle);
840 self.maybe_access_spilled_composite(
841 expr_handle,
842 block,
843 result_type_id,
844 )?
845 }
846 }
847 }
848 crate::TypeInner::BindingArray {
849 base: binding_type, ..
850 } => {
851 let result_id = match self.write_access_chain(
854 expr_handle,
855 block,
856 AccessTypeAdjustment::IntroducePointer(
857 spirv::StorageClass::UniformConstant,
858 ),
859 )? {
860 ExpressionPointer::Ready { pointer_id } => pointer_id,
861 ExpressionPointer::Conditional { .. } => {
862 return Err(Error::FeatureNotImplemented(
863 "Texture array out-of-bounds handling",
864 ));
865 }
866 };
867
868 let binding_type_id = self.get_handle_type_id(binding_type);
869
870 let load_id = self.gen_id();
871 block.body.push(Instruction::load(
872 binding_type_id,
873 load_id,
874 result_id,
875 None,
876 ));
877
878 if self.fun_info[index].uniformity.non_uniform_result.is_some() {
882 self.writer
883 .decorate_non_uniform_binding_array_access(load_id)?;
884 }
885
886 load_id
887 }
888 ref other => {
889 log::error!(
890 "Unable to access base {:?} of type {:?}",
891 self.ir_function.expressions[base],
892 other
893 );
894 return Err(Error::Validation(
895 "only vectors and arrays may be dynamically indexed by value",
896 ));
897 }
898 }
899 }
900 crate::Expression::AccessIndex { base, index } => {
901 match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
902 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
903 0
909 }
910 _ if self.function.spilled_accesses.contains(base) => {
911 self.function.spilled_accesses.insert(expr_handle);
919 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
920 }
921 crate::TypeInner::Vector { .. }
922 | crate::TypeInner::Matrix { .. }
923 | crate::TypeInner::Array { .. }
924 | crate::TypeInner::Struct { .. } => {
925 let id = self.gen_id();
930 let base_id = self.cached[base];
931 block.body.push(Instruction::composite_extract(
932 result_type_id,
933 id,
934 base_id,
935 &[index],
936 ));
937 id
938 }
939 crate::TypeInner::BindingArray {
940 base: binding_type, ..
941 } => {
942 let result_id = match self.write_access_chain(
945 expr_handle,
946 block,
947 AccessTypeAdjustment::IntroducePointer(
948 spirv::StorageClass::UniformConstant,
949 ),
950 )? {
951 ExpressionPointer::Ready { pointer_id } => pointer_id,
952 ExpressionPointer::Conditional { .. } => {
953 return Err(Error::FeatureNotImplemented(
954 "Texture array out-of-bounds handling",
955 ));
956 }
957 };
958
959 let binding_type_id = self.get_handle_type_id(binding_type);
960
961 let load_id = self.gen_id();
962 block.body.push(Instruction::load(
963 binding_type_id,
964 load_id,
965 result_id,
966 None,
967 ));
968
969 load_id
970 }
971 ref other => {
972 log::error!("Unable to access index of {other:?}");
973 return Err(Error::FeatureNotImplemented("access index for type"));
974 }
975 }
976 }
977 crate::Expression::GlobalVariable(handle) => {
978 self.writer.global_variables[handle].access_id
979 }
980 crate::Expression::Swizzle {
981 size,
982 vector,
983 pattern,
984 } => {
985 let vector_id = self.cached[vector];
986 self.temp_list.clear();
987 for &sc in pattern[..size as usize].iter() {
988 self.temp_list.push(sc as Word);
989 }
990 let id = self.gen_id();
991 block.body.push(Instruction::vector_shuffle(
992 result_type_id,
993 id,
994 vector_id,
995 vector_id,
996 &self.temp_list,
997 ));
998 id
999 }
1000 crate::Expression::Unary { op, expr } => {
1001 let id = self.gen_id();
1002 let expr_id = self.cached[expr];
1003 let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
1004
1005 let spirv_op = match op {
1006 crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
1007 Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
1008 Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
1009 _ => return Err(Error::Validation("Unexpected kind for negation")),
1010 },
1011 crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot,
1012 crate::UnaryOperator::BitwiseNot => spirv::Op::Not,
1013 };
1014
1015 block
1016 .body
1017 .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
1018 id
1019 }
1020 crate::Expression::Binary { op, left, right } => {
1021 let id = self.gen_id();
1022 let left_id = self.cached[left];
1023 let right_id = self.cached[right];
1024 let left_type_id = self.get_expression_type_id(&self.fun_info[left].ty);
1025 let right_type_id = self.get_expression_type_id(&self.fun_info[right].ty);
1026
1027 if let Some(function_id) =
1028 self.writer
1029 .wrapped_functions
1030 .get(&WrappedFunction::BinaryOp {
1031 op,
1032 left_type_id,
1033 right_type_id,
1034 })
1035 {
1036 block.body.push(Instruction::function_call(
1037 result_type_id,
1038 id,
1039 *function_id,
1040 &[left_id, right_id],
1041 ));
1042 } else {
1043 let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
1044 let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
1045
1046 let left_dimension = get_dimension(left_ty_inner);
1047 let right_dimension = get_dimension(right_ty_inner);
1048
1049 let mut reverse_operands = false;
1050
1051 let spirv_op = match op {
1052 crate::BinaryOperator::Add => match *left_ty_inner {
1053 crate::TypeInner::Scalar(scalar)
1054 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1055 crate::ScalarKind::Float => spirv::Op::FAdd,
1056 _ => spirv::Op::IAdd,
1057 },
1058 crate::TypeInner::Matrix {
1059 columns,
1060 rows,
1061 scalar,
1062 } => {
1063 self.write_matrix_matrix_column_op(
1065 block,
1066 id,
1067 result_type_id,
1068 left_id,
1069 right_id,
1070 columns,
1071 rows,
1072 scalar.width,
1073 spirv::Op::FAdd,
1074 );
1075
1076 self.cached[expr_handle] = id;
1077 return Ok(());
1078 }
1079 crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FAdd,
1080 _ => unimplemented!(),
1081 },
1082 crate::BinaryOperator::Subtract => match *left_ty_inner {
1083 crate::TypeInner::Scalar(scalar)
1084 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1085 crate::ScalarKind::Float => spirv::Op::FSub,
1086 _ => spirv::Op::ISub,
1087 },
1088 crate::TypeInner::Matrix {
1089 columns,
1090 rows,
1091 scalar,
1092 } => {
1093 self.write_matrix_matrix_column_op(
1094 block,
1095 id,
1096 result_type_id,
1097 left_id,
1098 right_id,
1099 columns,
1100 rows,
1101 scalar.width,
1102 spirv::Op::FSub,
1103 );
1104
1105 self.cached[expr_handle] = id;
1106 return Ok(());
1107 }
1108 crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FSub,
1109 _ => unimplemented!(),
1110 },
1111 crate::BinaryOperator::Multiply => {
1112 match (left_dimension, right_dimension) {
1113 (Dimension::Scalar, Dimension::Vector) => {
1114 self.write_vector_scalar_mult(
1115 block,
1116 id,
1117 result_type_id,
1118 right_id,
1119 left_id,
1120 right_ty_inner,
1121 );
1122
1123 self.cached[expr_handle] = id;
1124 return Ok(());
1125 }
1126 (Dimension::Vector, Dimension::Scalar) => {
1127 self.write_vector_scalar_mult(
1128 block,
1129 id,
1130 result_type_id,
1131 left_id,
1132 right_id,
1133 left_ty_inner,
1134 );
1135
1136 self.cached[expr_handle] = id;
1137 return Ok(());
1138 }
1139 (Dimension::Vector, Dimension::Matrix) => {
1140 spirv::Op::VectorTimesMatrix
1141 }
1142 (Dimension::Matrix, Dimension::Scalar)
1143 | (Dimension::CooperativeMatrix, Dimension::Scalar) => {
1144 spirv::Op::MatrixTimesScalar
1145 }
1146 (Dimension::Scalar, Dimension::Matrix)
1147 | (Dimension::Scalar, Dimension::CooperativeMatrix) => {
1148 reverse_operands = true;
1149 spirv::Op::MatrixTimesScalar
1150 }
1151 (Dimension::Matrix, Dimension::Vector) => {
1152 spirv::Op::MatrixTimesVector
1153 }
1154 (Dimension::Matrix, Dimension::Matrix) => {
1155 spirv::Op::MatrixTimesMatrix
1156 }
1157 (Dimension::Vector, Dimension::Vector)
1158 | (Dimension::Scalar, Dimension::Scalar)
1159 if left_ty_inner.scalar_kind()
1160 == Some(crate::ScalarKind::Float) =>
1161 {
1162 spirv::Op::FMul
1163 }
1164 (Dimension::Vector, Dimension::Vector)
1165 | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
1166 (Dimension::CooperativeMatrix, Dimension::CooperativeMatrix)
1167 | (Dimension::CooperativeMatrix, _)
1169 | (_, Dimension::CooperativeMatrix) => {
1170 unimplemented!()
1171 }
1172 }
1173 }
1174 crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
1175 Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
1176 Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
1177 Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
1178 _ => unimplemented!(),
1179 },
1180 crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
1181 Some(crate::ScalarKind::Float) => spirv::Op::FRem,
1182 Some(crate::ScalarKind::Sint) => {
1183 unreachable!("signed modulo must be lowered via the wrapped path")
1190 }
1191 Some(crate::ScalarKind::Uint) => {
1192 assert!(!self.writer.emit_int_div_checks);
1193 spirv::Op::UMod
1194 }
1195 _ => unimplemented!(),
1196 },
1197 crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
1198 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1199 spirv::Op::IEqual
1200 }
1201 Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
1202 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
1203 _ => unimplemented!(),
1204 },
1205 crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
1206 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1207 spirv::Op::INotEqual
1208 }
1209 Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
1210 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
1211 _ => unimplemented!(),
1212 },
1213 crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
1214 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
1215 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
1216 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
1217 _ => unimplemented!(),
1218 },
1219 crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
1220 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
1221 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
1222 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
1223 _ => unimplemented!(),
1224 },
1225 crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
1226 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
1227 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
1228 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
1229 _ => unimplemented!(),
1230 },
1231 crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
1232 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
1233 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
1234 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
1235 _ => unimplemented!(),
1236 },
1237 crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
1238 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
1239 _ => spirv::Op::BitwiseAnd,
1240 },
1241 crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
1242 crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
1243 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
1244 _ => spirv::Op::BitwiseOr,
1245 },
1246 crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
1247 crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
1248 crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
1249 crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
1250 Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
1251 Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
1252 _ => unimplemented!(),
1253 },
1254 };
1255
1256 block.body.push(Instruction::binary(
1257 spirv_op,
1258 result_type_id,
1259 id,
1260 if reverse_operands { right_id } else { left_id },
1261 if reverse_operands { left_id } else { right_id },
1262 ));
1263 }
1264 id
1265 }
1266 crate::Expression::Math {
1267 fun,
1268 arg,
1269 arg1,
1270 arg2,
1271 arg3,
1272 } => {
1273 use crate::MathFunction as Mf;
1274 enum MathOp {
1275 Ext(spirv::GlslStd450Op),
1276 Custom(Instruction),
1277 }
1278
1279 let arg0_id = self.cached[arg];
1280 let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
1281 let arg_scalar_kind = arg_ty.scalar_kind();
1282 let arg1_id = match arg1 {
1283 Some(handle) => self.cached[handle],
1284 None => 0,
1285 };
1286 let arg2_id = match arg2 {
1287 Some(handle) => self.cached[handle],
1288 None => 0,
1289 };
1290 let arg3_id = match arg3 {
1291 Some(handle) => self.cached[handle],
1292 None => 0,
1293 };
1294
1295 let id = self.gen_id();
1296 let math_op = match fun {
1297 Mf::Abs => {
1299 match arg_scalar_kind {
1300 Some(crate::ScalarKind::Float) => {
1301 MathOp::Ext(spirv::GlslStd450Op::FAbs)
1302 }
1303 Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GlslStd450Op::SAbs),
1304 Some(crate::ScalarKind::Uint) => {
1305 MathOp::Custom(Instruction::unary(
1306 spirv::Op::CopyObject, result_type_id,
1308 id,
1309 arg0_id,
1310 ))
1311 }
1312 other => unimplemented!("Unexpected abs({:?})", other),
1313 }
1314 }
1315 Mf::Min => MathOp::Ext(match arg_scalar_kind {
1316 Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FMin,
1317 Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SMin,
1318 Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::UMin,
1319 other => unimplemented!("Unexpected min({:?})", other),
1320 }),
1321 Mf::Max => MathOp::Ext(match arg_scalar_kind {
1322 Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FMax,
1323 Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SMax,
1324 Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::UMax,
1325 other => unimplemented!("Unexpected max({:?})", other),
1326 }),
1327 Mf::Clamp => match arg_scalar_kind {
1328 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GlslStd450Op::FClamp),
1332 Some(_) => {
1333 let (min_op, max_op) = match arg_scalar_kind {
1334 Some(crate::ScalarKind::Sint) => {
1335 (spirv::GlslStd450Op::SMin, spirv::GlslStd450Op::SMax)
1336 }
1337 Some(crate::ScalarKind::Uint) => {
1338 (spirv::GlslStd450Op::UMin, spirv::GlslStd450Op::UMax)
1339 }
1340 _ => unreachable!(),
1341 };
1342
1343 let max_id = self.gen_id();
1344 block.body.push(Instruction::ext_inst_gl_op(
1345 self.writer.gl450_ext_inst_id,
1346 max_op,
1347 result_type_id,
1348 max_id,
1349 &[arg0_id, arg1_id],
1350 ));
1351
1352 MathOp::Custom(Instruction::ext_inst_gl_op(
1353 self.writer.gl450_ext_inst_id,
1354 min_op,
1355 result_type_id,
1356 id,
1357 &[max_id, arg2_id],
1358 ))
1359 }
1360 other => unimplemented!("Unexpected max({:?})", other),
1361 },
1362 Mf::Saturate => {
1363 let (maybe_size, scalar) = match *arg_ty {
1364 crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
1365 crate::TypeInner::Scalar(scalar) => (None, scalar),
1366 ref other => unimplemented!("Unexpected saturate({:?})", other),
1367 };
1368 let scalar = crate::Scalar::float(scalar.width);
1369 let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
1370 let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
1371
1372 if let Some(size) = maybe_size {
1373 let ty =
1374 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1375
1376 self.temp_list.clear();
1377 self.temp_list.resize(size as _, arg1_id);
1378
1379 arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
1380
1381 self.temp_list.fill(arg2_id);
1382
1383 arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
1384 }
1385
1386 MathOp::Custom(Instruction::ext_inst_gl_op(
1387 self.writer.gl450_ext_inst_id,
1388 spirv::GlslStd450Op::FClamp,
1389 result_type_id,
1390 id,
1391 &[arg0_id, arg1_id, arg2_id],
1392 ))
1393 }
1394 Mf::Sin => MathOp::Ext(spirv::GlslStd450Op::Sin),
1396 Mf::Sinh => MathOp::Ext(spirv::GlslStd450Op::Sinh),
1397 Mf::Asin => MathOp::Ext(spirv::GlslStd450Op::Asin),
1398 Mf::Cos => MathOp::Ext(spirv::GlslStd450Op::Cos),
1399 Mf::Cosh => MathOp::Ext(spirv::GlslStd450Op::Cosh),
1400 Mf::Acos => MathOp::Ext(spirv::GlslStd450Op::Acos),
1401 Mf::Tan => MathOp::Ext(spirv::GlslStd450Op::Tan),
1402 Mf::Tanh => MathOp::Ext(spirv::GlslStd450Op::Tanh),
1403 Mf::Atan => MathOp::Ext(spirv::GlslStd450Op::Atan),
1404 Mf::Atan2 => MathOp::Ext(spirv::GlslStd450Op::Atan2),
1405 Mf::Asinh => MathOp::Ext(spirv::GlslStd450Op::Asinh),
1406 Mf::Acosh => MathOp::Ext(spirv::GlslStd450Op::Acosh),
1407 Mf::Atanh => MathOp::Ext(spirv::GlslStd450Op::Atanh),
1408 Mf::Radians => MathOp::Ext(spirv::GlslStd450Op::Radians),
1409 Mf::Degrees => MathOp::Ext(spirv::GlslStd450Op::Degrees),
1410 Mf::Ceil => MathOp::Ext(spirv::GlslStd450Op::Ceil),
1412 Mf::Round => MathOp::Ext(spirv::GlslStd450Op::RoundEven),
1413 Mf::Floor => MathOp::Ext(spirv::GlslStd450Op::Floor),
1414 Mf::Fract => MathOp::Ext(spirv::GlslStd450Op::Fract),
1415 Mf::Trunc => MathOp::Ext(spirv::GlslStd450Op::Trunc),
1416 Mf::Modf => MathOp::Ext(spirv::GlslStd450Op::ModfStruct),
1417 Mf::Frexp => MathOp::Ext(spirv::GlslStd450Op::FrexpStruct),
1418 Mf::Ldexp => MathOp::Ext(spirv::GlslStd450Op::Ldexp),
1419 Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
1421 crate::TypeInner::Vector {
1422 scalar:
1423 crate::Scalar {
1424 kind: crate::ScalarKind::Float,
1425 ..
1426 },
1427 ..
1428 } => MathOp::Custom(Instruction::binary(
1429 spirv::Op::Dot,
1430 result_type_id,
1431 id,
1432 arg0_id,
1433 arg1_id,
1434 )),
1435 crate::TypeInner::Vector { size, .. } => {
1437 self.write_dot_product(
1438 id,
1439 result_type_id,
1440 arg0_id,
1441 arg1_id,
1442 size as u32,
1443 block,
1444 |result_id, composite_id, index| {
1445 Instruction::composite_extract(
1446 result_type_id,
1447 result_id,
1448 composite_id,
1449 &[index],
1450 )
1451 },
1452 );
1453 self.cached[expr_handle] = id;
1454 return Ok(());
1455 }
1456 _ => unreachable!(
1457 "Correct TypeInner for dot product should be already validated"
1458 ),
1459 },
1460 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
1461 if self
1462 .writer
1463 .require_all(&[
1464 spirv::Capability::DotProduct,
1465 spirv::Capability::DotProductInput4x8BitPacked,
1466 ])
1467 .is_ok()
1468 {
1469 if self.writer.lang_version() < (1, 6) {
1471 self.writer.use_extension("SPV_KHR_integer_dot_product");
1475 }
1476
1477 let op = match fun {
1478 Mf::Dot4I8Packed => spirv::Op::SDot,
1479 Mf::Dot4U8Packed => spirv::Op::UDot,
1480 _ => unreachable!(),
1481 };
1482
1483 block.body.push(Instruction::ternary(
1484 op,
1485 result_type_id,
1486 id,
1487 arg0_id,
1488 arg1_id,
1489 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word,
1490 ));
1491 } else {
1492 let (extract_op, arg0_id, arg1_id) = match fun {
1494 Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
1495 Mf::Dot4I8Packed => {
1496 let new_arg0_id = self.gen_id();
1499 block.body.push(Instruction::unary(
1500 spirv::Op::Bitcast,
1501 result_type_id,
1502 new_arg0_id,
1503 arg0_id,
1504 ));
1505
1506 let new_arg1_id = self.gen_id();
1507 block.body.push(Instruction::unary(
1508 spirv::Op::Bitcast,
1509 result_type_id,
1510 new_arg1_id,
1511 arg1_id,
1512 ));
1513
1514 (spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
1515 }
1516 _ => unreachable!(),
1517 };
1518
1519 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1520
1521 const VEC_LENGTH: u8 = 4;
1522 let bit_shifts: [_; VEC_LENGTH as usize] =
1523 core::array::from_fn(|index| {
1524 self.writer
1525 .get_constant_scalar(crate::Literal::U32(index as u32 * 8))
1526 });
1527
1528 self.write_dot_product(
1529 id,
1530 result_type_id,
1531 arg0_id,
1532 arg1_id,
1533 VEC_LENGTH as Word,
1534 block,
1535 |result_id, composite_id, index| {
1536 Instruction::ternary(
1537 extract_op,
1538 result_type_id,
1539 result_id,
1540 composite_id,
1541 bit_shifts[index as usize],
1542 eight,
1543 )
1544 },
1545 );
1546 }
1547
1548 self.cached[expr_handle] = id;
1549 return Ok(());
1550 }
1551 Mf::Outer => MathOp::Custom(Instruction::binary(
1552 spirv::Op::OuterProduct,
1553 result_type_id,
1554 id,
1555 arg0_id,
1556 arg1_id,
1557 )),
1558 Mf::Cross => MathOp::Ext(spirv::GlslStd450Op::Cross),
1559 Mf::Distance => MathOp::Ext(spirv::GlslStd450Op::Distance),
1560 Mf::Length => MathOp::Ext(spirv::GlslStd450Op::Length),
1561 Mf::Normalize => MathOp::Ext(spirv::GlslStd450Op::Normalize),
1562 Mf::FaceForward => MathOp::Ext(spirv::GlslStd450Op::FaceForward),
1563 Mf::Reflect => MathOp::Ext(spirv::GlslStd450Op::Reflect),
1564 Mf::Refract => MathOp::Ext(spirv::GlslStd450Op::Refract),
1565 Mf::Exp => MathOp::Ext(spirv::GlslStd450Op::Exp),
1567 Mf::Exp2 => MathOp::Ext(spirv::GlslStd450Op::Exp2),
1568 Mf::Log => MathOp::Ext(spirv::GlslStd450Op::Log),
1569 Mf::Log2 => MathOp::Ext(spirv::GlslStd450Op::Log2),
1570 Mf::Pow => MathOp::Ext(spirv::GlslStd450Op::Pow),
1571 Mf::Sign => MathOp::Ext(match arg_scalar_kind {
1573 Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FSign,
1574 Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SSign,
1575 other => unimplemented!("Unexpected sign({:?})", other),
1576 }),
1577 Mf::Fma => MathOp::Ext(spirv::GlslStd450Op::Fma),
1578 Mf::Mix => {
1579 let selector = arg2.unwrap();
1580 let selector_ty =
1581 self.fun_info[selector].ty.inner_with(&self.ir_module.types);
1582 match (arg_ty, selector_ty) {
1583 (
1585 &crate::TypeInner::Vector { size, .. },
1586 &crate::TypeInner::Scalar(scalar),
1587 ) => {
1588 let selector_type_id =
1589 self.get_numeric_type_id(NumericType::Vector { size, scalar });
1590 self.temp_list.clear();
1591 self.temp_list.resize(size as usize, arg2_id);
1592
1593 let selector_id = self.gen_id();
1594 block.body.push(Instruction::composite_construct(
1595 selector_type_id,
1596 selector_id,
1597 &self.temp_list,
1598 ));
1599
1600 MathOp::Custom(Instruction::ext_inst_gl_op(
1601 self.writer.gl450_ext_inst_id,
1602 spirv::GlslStd450Op::FMix,
1603 result_type_id,
1604 id,
1605 &[arg0_id, arg1_id, selector_id],
1606 ))
1607 }
1608 _ => MathOp::Ext(spirv::GlslStd450Op::FMix),
1609 }
1610 }
1611 Mf::Step => MathOp::Ext(spirv::GlslStd450Op::Step),
1612 Mf::SmoothStep => MathOp::Ext(spirv::GlslStd450Op::SmoothStep),
1613 Mf::Sqrt => MathOp::Ext(spirv::GlslStd450Op::Sqrt),
1614 Mf::InverseSqrt => MathOp::Ext(spirv::GlslStd450Op::InverseSqrt),
1615 Mf::Inverse => MathOp::Ext(spirv::GlslStd450Op::MatrixInverse),
1616 Mf::Transpose => MathOp::Custom(Instruction::unary(
1617 spirv::Op::Transpose,
1618 result_type_id,
1619 id,
1620 arg0_id,
1621 )),
1622 Mf::Determinant => MathOp::Ext(spirv::GlslStd450Op::Determinant),
1623 Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
1624 spirv::Op::QuantizeToF16,
1625 result_type_id,
1626 id,
1627 arg0_id,
1628 )),
1629 Mf::ReverseBits => MathOp::Custom(Instruction::unary(
1630 spirv::Op::BitReverse,
1631 result_type_id,
1632 id,
1633 arg0_id,
1634 )),
1635 Mf::CountTrailingZeros => {
1636 let uint_id = match *arg_ty {
1637 crate::TypeInner::Vector { size, scalar } => {
1638 let ty =
1639 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1640
1641 self.temp_list.clear();
1642 self.temp_list.resize(
1643 size as _,
1644 self.writer
1645 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1646 );
1647
1648 self.writer.get_constant_composite(ty, &self.temp_list)
1649 }
1650 crate::TypeInner::Scalar(scalar) => self
1651 .writer
1652 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1653 _ => unreachable!(),
1654 };
1655
1656 let lsb_id = self.gen_id();
1657 block.body.push(Instruction::ext_inst_gl_op(
1658 self.writer.gl450_ext_inst_id,
1659 spirv::GlslStd450Op::FindILsb,
1660 result_type_id,
1661 lsb_id,
1662 &[arg0_id],
1663 ));
1664
1665 MathOp::Custom(Instruction::ext_inst_gl_op(
1666 self.writer.gl450_ext_inst_id,
1667 spirv::GlslStd450Op::UMin,
1668 result_type_id,
1669 id,
1670 &[uint_id, lsb_id],
1671 ))
1672 }
1673 Mf::CountLeadingZeros => {
1674 let (int_type_id, int_id, width) = match *arg_ty {
1675 crate::TypeInner::Vector { size, scalar } => {
1676 let ty =
1677 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1678
1679 self.temp_list.clear();
1680 self.temp_list.resize(
1681 size as _,
1682 self.writer
1683 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1684 );
1685
1686 (
1687 self.get_type_id(ty),
1688 self.writer.get_constant_composite(ty, &self.temp_list),
1689 scalar.width,
1690 )
1691 }
1692 crate::TypeInner::Scalar(scalar) => (
1693 self.get_numeric_type_id(NumericType::Scalar(scalar)),
1694 self.writer
1695 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1696 scalar.width,
1697 ),
1698 _ => unreachable!(),
1699 };
1700
1701 if width != 4 {
1702 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1703 };
1704
1705 let msb_id = self.gen_id();
1706 block.body.push(Instruction::ext_inst_gl_op(
1707 self.writer.gl450_ext_inst_id,
1708 if width != 4 {
1709 spirv::GlslStd450Op::FindILsb
1710 } else {
1711 spirv::GlslStd450Op::FindUMsb
1712 },
1713 int_type_id,
1714 msb_id,
1715 &[arg0_id],
1716 ));
1717
1718 MathOp::Custom(Instruction::binary(
1719 spirv::Op::ISub,
1720 result_type_id,
1721 id,
1722 int_id,
1723 msb_id,
1724 ))
1725 }
1726 Mf::CountOneBits => MathOp::Custom(Instruction::unary(
1727 spirv::Op::BitCount,
1728 result_type_id,
1729 id,
1730 arg0_id,
1731 )),
1732 Mf::ExtractBits => {
1733 let op = match arg_scalar_kind {
1734 Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
1735 Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
1736 other => unimplemented!("Unexpected sign({:?})", other),
1737 };
1738
1739 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1754 let width_constant = self
1755 .writer
1756 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1757
1758 let u32_type =
1759 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1760
1761 let offset_id = self.gen_id();
1763 block.body.push(Instruction::ext_inst_gl_op(
1764 self.writer.gl450_ext_inst_id,
1765 spirv::GlslStd450Op::UMin,
1766 u32_type,
1767 offset_id,
1768 &[arg1_id, width_constant],
1769 ));
1770
1771 let max_count_id = self.gen_id();
1773 block.body.push(Instruction::binary(
1774 spirv::Op::ISub,
1775 u32_type,
1776 max_count_id,
1777 width_constant,
1778 offset_id,
1779 ));
1780
1781 let count_id = self.gen_id();
1783 block.body.push(Instruction::ext_inst_gl_op(
1784 self.writer.gl450_ext_inst_id,
1785 spirv::GlslStd450Op::UMin,
1786 u32_type,
1787 count_id,
1788 &[arg2_id, max_count_id],
1789 ));
1790
1791 MathOp::Custom(Instruction::ternary(
1792 op,
1793 result_type_id,
1794 id,
1795 arg0_id,
1796 offset_id,
1797 count_id,
1798 ))
1799 }
1800 Mf::InsertBits => {
1801 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1804 let width_constant = self
1805 .writer
1806 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1807
1808 let u32_type =
1809 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1810
1811 let offset_id = self.gen_id();
1813 block.body.push(Instruction::ext_inst_gl_op(
1814 self.writer.gl450_ext_inst_id,
1815 spirv::GlslStd450Op::UMin,
1816 u32_type,
1817 offset_id,
1818 &[arg2_id, width_constant],
1819 ));
1820
1821 let max_count_id = self.gen_id();
1823 block.body.push(Instruction::binary(
1824 spirv::Op::ISub,
1825 u32_type,
1826 max_count_id,
1827 width_constant,
1828 offset_id,
1829 ));
1830
1831 let count_id = self.gen_id();
1833 block.body.push(Instruction::ext_inst_gl_op(
1834 self.writer.gl450_ext_inst_id,
1835 spirv::GlslStd450Op::UMin,
1836 u32_type,
1837 count_id,
1838 &[arg3_id, max_count_id],
1839 ));
1840
1841 MathOp::Custom(Instruction::quaternary(
1842 spirv::Op::BitFieldInsert,
1843 result_type_id,
1844 id,
1845 arg0_id,
1846 arg1_id,
1847 offset_id,
1848 count_id,
1849 ))
1850 }
1851 Mf::FirstTrailingBit => MathOp::Ext(spirv::GlslStd450Op::FindILsb),
1852 Mf::FirstLeadingBit => {
1853 if arg_ty.scalar_width() == Some(4) {
1854 let thing = match arg_scalar_kind {
1855 Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::FindUMsb,
1856 Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::FindSMsb,
1857 other => unimplemented!("Unexpected firstLeadingBit({:?})", other),
1858 };
1859 MathOp::Ext(thing)
1860 } else {
1861 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1862 }
1863 }
1864 Mf::Pack4x8unorm => MathOp::Ext(spirv::GlslStd450Op::PackUnorm4x8),
1865 Mf::Pack4x8snorm => MathOp::Ext(spirv::GlslStd450Op::PackSnorm4x8),
1866 Mf::Pack2x16float => MathOp::Ext(spirv::GlslStd450Op::PackHalf2x16),
1867 Mf::Pack2x16unorm => MathOp::Ext(spirv::GlslStd450Op::PackUnorm2x16),
1868 Mf::Pack2x16snorm => MathOp::Ext(spirv::GlslStd450Op::PackSnorm2x16),
1869 fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => {
1870 let is_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp);
1871 let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp);
1872
1873 let last_instruction =
1874 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1875 self.write_pack4x8_optimized(
1876 block,
1877 result_type_id,
1878 arg0_id,
1879 id,
1880 is_signed,
1881 should_clamp,
1882 )
1883 } else {
1884 self.write_pack4x8_polyfill(
1885 block,
1886 result_type_id,
1887 arg0_id,
1888 id,
1889 is_signed,
1890 should_clamp,
1891 )
1892 };
1893
1894 MathOp::Custom(last_instruction)
1895 }
1896 Mf::Unpack4x8unorm => MathOp::Ext(spirv::GlslStd450Op::UnpackUnorm4x8),
1897 Mf::Unpack4x8snorm => MathOp::Ext(spirv::GlslStd450Op::UnpackSnorm4x8),
1898 Mf::Unpack2x16float => MathOp::Ext(spirv::GlslStd450Op::UnpackHalf2x16),
1899 Mf::Unpack2x16unorm => MathOp::Ext(spirv::GlslStd450Op::UnpackUnorm2x16),
1900 Mf::Unpack2x16snorm => MathOp::Ext(spirv::GlslStd450Op::UnpackSnorm2x16),
1901 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
1902 let is_signed = matches!(fun, Mf::Unpack4xI8);
1903
1904 let last_instruction =
1905 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1906 self.write_unpack4x8_optimized(
1907 block,
1908 result_type_id,
1909 arg0_id,
1910 id,
1911 is_signed,
1912 )
1913 } else {
1914 self.write_unpack4x8_polyfill(
1915 block,
1916 result_type_id,
1917 arg0_id,
1918 id,
1919 is_signed,
1920 )
1921 };
1922
1923 MathOp::Custom(last_instruction)
1924 }
1925 };
1926
1927 block.body.push(match math_op {
1928 MathOp::Ext(op) => Instruction::ext_inst_gl_op(
1929 self.writer.gl450_ext_inst_id,
1930 op,
1931 result_type_id,
1932 id,
1933 &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
1934 ),
1935 MathOp::Custom(inst) => inst,
1936 });
1937 id
1938 }
1939 crate::Expression::LocalVariable(variable) => {
1940 if let Some(rq_tracker) = self
1941 .function
1942 .ray_query_initialization_tracker_variables
1943 .get(&variable)
1944 {
1945 self.ray_query_tracker_expr.insert(
1946 expr_handle,
1947 super::RayQueryTrackers {
1948 initialized_tracker: rq_tracker.id,
1949 t_max_tracker: self
1950 .function
1951 .ray_query_t_max_tracker_variables
1952 .get(&variable)
1953 .expect("Both trackers are set at the same time.")
1954 .id,
1955 },
1956 );
1957 }
1958 self.function.variables[&variable].id
1959 }
1960 crate::Expression::Load { pointer } => {
1961 self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)?
1962 }
1963 crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
1964 crate::Expression::CallResult(_)
1965 | crate::Expression::AtomicResult { .. }
1966 | crate::Expression::WorkGroupUniformLoadResult { .. }
1967 | crate::Expression::RayQueryProceedResult
1968 | crate::Expression::SubgroupBallotResult
1969 | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
1970 crate::Expression::As {
1971 expr,
1972 kind,
1973 convert,
1974 } => self.write_as_expression(expr, convert, kind, block, result_type_id)?,
1975 crate::Expression::ImageLoad {
1976 image,
1977 coordinate,
1978 array_index,
1979 sample,
1980 level,
1981 } => self.write_image_load(
1982 result_type_id,
1983 image,
1984 coordinate,
1985 array_index,
1986 level,
1987 sample,
1988 block,
1989 )?,
1990 crate::Expression::ImageSample {
1991 image,
1992 sampler,
1993 gather,
1994 coordinate,
1995 array_index,
1996 offset,
1997 level,
1998 depth_ref,
1999 clamp_to_edge,
2000 } => self.write_image_sample(
2001 result_type_id,
2002 image,
2003 sampler,
2004 gather,
2005 coordinate,
2006 array_index,
2007 offset,
2008 level,
2009 depth_ref,
2010 clamp_to_edge,
2011 block,
2012 )?,
2013 crate::Expression::Select {
2014 condition,
2015 accept,
2016 reject,
2017 } => {
2018 let id = self.gen_id();
2019 let mut condition_id = self.cached[condition];
2020 let accept_id = self.cached[accept];
2021 let reject_id = self.cached[reject];
2022
2023 let condition_ty = self.fun_info[condition]
2024 .ty
2025 .inner_with(&self.ir_module.types);
2026 let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
2027
2028 if let (
2029 &crate::TypeInner::Scalar(
2030 condition_scalar @ crate::Scalar {
2031 kind: crate::ScalarKind::Bool,
2032 ..
2033 },
2034 ),
2035 &crate::TypeInner::Vector { size, .. },
2036 ) = (condition_ty, object_ty)
2037 {
2038 self.temp_list.clear();
2039 self.temp_list.resize(size as usize, condition_id);
2040
2041 let bool_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2042 size,
2043 scalar: condition_scalar,
2044 });
2045
2046 let id = self.gen_id();
2047 block.body.push(Instruction::composite_construct(
2048 bool_vector_type_id,
2049 id,
2050 &self.temp_list,
2051 ));
2052 condition_id = id
2053 }
2054
2055 let instruction =
2056 Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
2057 block.body.push(instruction);
2058 id
2059 }
2060 crate::Expression::Derivative { axis, ctrl, expr } => {
2061 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
2062 match ctrl {
2063 Ctrl::Coarse | Ctrl::Fine => {
2064 self.writer.require_any(
2065 "DerivativeControl",
2066 &[spirv::Capability::DerivativeControl],
2067 )?;
2068 }
2069 Ctrl::None => {}
2070 }
2071 let id = self.gen_id();
2072 let expr_id = self.cached[expr];
2073 let op = match (axis, ctrl) {
2074 (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
2075 (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
2076 (Axis::X, Ctrl::None) => spirv::Op::DPdx,
2077 (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
2078 (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
2079 (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
2080 (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
2081 (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
2082 (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
2083 };
2084 block
2085 .body
2086 .push(Instruction::derivative(op, result_type_id, id, expr_id));
2087 id
2088 }
2089 crate::Expression::ImageQuery { image, query } => {
2090 self.write_image_query(result_type_id, image, query, block)?
2091 }
2092 crate::Expression::Relational { fun, argument } => {
2093 use crate::RelationalFunction as Rf;
2094 let arg_id = self.cached[argument];
2095 let op = match fun {
2096 Rf::All => spirv::Op::All,
2097 Rf::Any => spirv::Op::Any,
2098 Rf::IsNan => spirv::Op::IsNan,
2099 Rf::IsInf => spirv::Op::IsInf,
2100 };
2101 let id = self.gen_id();
2102 block
2103 .body
2104 .push(Instruction::relational(op, result_type_id, id, arg_id));
2105 id
2106 }
2107 crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
2108 crate::Expression::RayQueryGetIntersection { query, committed } => {
2109 let query_id = self.cached[query];
2110 let init_tracker_id = *self
2111 .ray_query_tracker_expr
2112 .get(&query)
2113 .expect("not a cached ray query");
2114 let func_id = self
2115 .writer
2116 .write_ray_query_get_intersection_function(committed, self.ir_module);
2117 let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap();
2118 let intersection_type_id = self.get_handle_type_id(ray_intersection);
2119 let id = self.gen_id();
2120 block.body.push(Instruction::function_call(
2121 intersection_type_id,
2122 id,
2123 func_id,
2124 &[query_id, init_tracker_id.initialized_tracker],
2125 ));
2126 id
2127 }
2128 crate::Expression::RayQueryVertexPositions { query, committed } => {
2129 self.writer.require_any(
2130 "RayQueryVertexPositions",
2131 &[spirv::Capability::RayQueryPositionFetchKHR],
2132 )?;
2133 self.write_ray_query_return_vertex_position(query, block, committed)
2134 }
2135 crate::Expression::CooperativeLoad { ref data, .. } => {
2136 self.writer.require_any(
2137 "CooperativeMatrix",
2138 &[spirv::Capability::CooperativeMatrixKHR],
2139 )?;
2140 let layout = if data.row_major {
2141 spirv::CooperativeMatrixLayout::RowMajorKHR
2142 } else {
2143 spirv::CooperativeMatrixLayout::ColumnMajorKHR
2144 };
2145 let layout_id = self.get_index_constant(layout as u32);
2146 let stride_id = self.cached[data.stride];
2147 match self.write_access_chain(data.pointer, block, AccessTypeAdjustment::None)? {
2148 ExpressionPointer::Ready { pointer_id } => {
2149 let id = self.gen_id();
2150 block.body.push(Instruction::coop_load(
2151 result_type_id,
2152 id,
2153 pointer_id,
2154 layout_id,
2155 stride_id,
2156 ));
2157 id
2158 }
2159 ExpressionPointer::Conditional { condition, access } => self
2160 .write_conditional_indexed_load(
2161 result_type_id,
2162 condition,
2163 block,
2164 |id_gen, block| {
2165 let pointer_id = access.result_id.unwrap();
2166 block.body.push(access);
2167 let id = id_gen.next();
2168 block.body.push(Instruction::coop_load(
2169 result_type_id,
2170 id,
2171 pointer_id,
2172 layout_id,
2173 stride_id,
2174 ));
2175 id
2176 },
2177 ),
2178 }
2179 }
2180 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2181 self.writer.require_any(
2182 "CooperativeMatrix",
2183 &[spirv::Capability::CooperativeMatrixKHR],
2184 )?;
2185 let a_id = self.cached[a];
2186 let b_id = self.cached[b];
2187 let c_id = self.cached[c];
2188 let id = self.gen_id();
2189 block.body.push(Instruction::coop_mul_add(
2190 result_type_id,
2191 id,
2192 a_id,
2193 b_id,
2194 c_id,
2195 ));
2196 id
2197 }
2198 };
2199
2200 self.cached[expr_handle] = id;
2201 Ok(())
2202 }
2203
2204 fn write_as_expression(
2207 &mut self,
2208 expr: Handle<crate::Expression>,
2209 convert: Option<u8>,
2210 kind: crate::ScalarKind,
2211
2212 block: &mut Block,
2213 result_type_id: u32,
2214 ) -> Result<u32, Error> {
2215 use crate::ScalarKind as Sk;
2216 let expr_id = self.cached[expr];
2217 let ty = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
2218
2219 if let crate::TypeInner::Matrix {
2224 columns,
2225 rows,
2226 scalar,
2227 } = *ty
2228 {
2229 let Some(convert) = convert else {
2230 return Ok(expr_id);
2232 };
2233
2234 if convert == scalar.width {
2235 return Ok(expr_id);
2237 }
2238
2239 if kind != Sk::Float {
2240 return Err(Error::Validation("Matrices must be floats"));
2242 }
2243
2244 let column_src_ty =
2246 self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
2247 size: rows,
2248 scalar,
2249 })));
2250
2251 let column_dst_ty =
2253 self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
2254 size: rows,
2255 scalar: crate::Scalar {
2256 kind,
2257 width: convert,
2258 },
2259 })));
2260
2261 let mut components = ArrayVec::<Word, 4>::new();
2262
2263 for column in 0..columns as usize {
2264 let column_id = self.gen_id();
2265 block.body.push(Instruction::composite_extract(
2266 column_src_ty,
2267 column_id,
2268 expr_id,
2269 &[column as u32],
2270 ));
2271
2272 let column_conv_id = self.gen_id();
2273 block.body.push(Instruction::unary(
2274 spirv::Op::FConvert,
2275 column_dst_ty,
2276 column_conv_id,
2277 column_id,
2278 ));
2279
2280 components.push(column_conv_id);
2281 }
2282
2283 let construct_id = self.gen_id();
2284
2285 block.body.push(Instruction::composite_construct(
2286 result_type_id,
2287 construct_id,
2288 &components,
2289 ));
2290
2291 return Ok(construct_id);
2292 }
2293
2294 let (src_scalar, src_size) = match *ty {
2295 crate::TypeInner::Scalar(scalar) => (scalar, None),
2296 crate::TypeInner::Vector { scalar, size } => (scalar, Some(size)),
2297 ref other => {
2298 log::error!("As source {other:?}");
2299 return Err(Error::Validation("Unexpected Expression::As source"));
2300 }
2301 };
2302
2303 enum Cast {
2304 Identity(Word),
2305 Unary(spirv::Op, Word),
2306 Binary(spirv::Op, Word, Word),
2307 Ternary(spirv::Op, Word, Word, Word),
2308 }
2309 let cast = match (src_scalar.kind, kind, convert) {
2310 (src_kind, kind, convert)
2313 if src_kind == kind
2314 && convert.filter(|&width| width != src_scalar.width).is_none() =>
2315 {
2316 Cast::Identity(expr_id)
2317 }
2318 (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject, expr_id),
2319 (_, _, None) => Cast::Unary(spirv::Op::Bitcast, expr_id),
2320 (_, Sk::Bool, Some(_)) => {
2322 let op = match src_scalar.kind {
2323 Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
2324 Sk::Float => spirv::Op::FUnordNotEqual,
2325 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(),
2326 };
2327 let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?;
2328 let zero_id = match src_size {
2329 Some(size) => {
2330 let ty = LocalType::Numeric(NumericType::Vector {
2331 size,
2332 scalar: src_scalar,
2333 })
2334 .into();
2335
2336 self.temp_list.clear();
2337 self.temp_list.resize(size as _, zero_scalar_id);
2338
2339 self.writer.get_constant_composite(ty, &self.temp_list)
2340 }
2341 None => zero_scalar_id,
2342 };
2343
2344 Cast::Binary(op, expr_id, zero_id)
2345 }
2346 (Sk::Bool, _, Some(dst_width)) => {
2348 let dst_scalar = crate::Scalar {
2349 kind,
2350 width: dst_width,
2351 };
2352 let zero_scalar_id = self.writer.get_constant_scalar_with(0, dst_scalar)?;
2353 let one_scalar_id = self.writer.get_constant_scalar_with(1, dst_scalar)?;
2354 let (accept_id, reject_id) = match src_size {
2355 Some(size) => {
2356 let ty = LocalType::Numeric(NumericType::Vector {
2357 size,
2358 scalar: dst_scalar,
2359 })
2360 .into();
2361
2362 self.temp_list.clear();
2363 self.temp_list.resize(size as _, zero_scalar_id);
2364
2365 let vec0_id = self.writer.get_constant_composite(ty, &self.temp_list);
2366
2367 self.temp_list.fill(one_scalar_id);
2368
2369 let vec1_id = self.writer.get_constant_composite(ty, &self.temp_list);
2370
2371 (vec1_id, vec0_id)
2372 }
2373 None => (one_scalar_id, zero_scalar_id),
2374 };
2375
2376 Cast::Ternary(spirv::Op::Select, expr_id, accept_id, reject_id)
2377 }
2378 (Sk::Float, Sk::Sint | Sk::Uint, Some(width)) => {
2389 let dst_scalar = crate::Scalar { kind, width };
2390 let (min, max) =
2391 crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
2392 let expr_type_id = self.get_expression_type_id(&self.fun_info[expr].ty);
2393
2394 let maybe_splat_const = |writer: &mut Writer, const_id| match src_size {
2395 None => const_id,
2396 Some(size) => {
2397 let constituent_ids = [const_id; crate::VectorSize::MAX];
2398 writer.get_constant_composite(
2399 LookupType::Local(LocalType::Numeric(NumericType::Vector {
2400 size,
2401 scalar: src_scalar,
2402 })),
2403 &constituent_ids[..size as usize],
2404 )
2405 }
2406 };
2407 let min_const_id = self.writer.get_constant_scalar(min);
2408 let min_const_id = maybe_splat_const(self.writer, min_const_id);
2409 let max_const_id = self.writer.get_constant_scalar(max);
2410 let max_const_id = maybe_splat_const(self.writer, max_const_id);
2411
2412 let clamp_id = self.gen_id();
2413 block.body.push(Instruction::ext_inst_gl_op(
2414 self.writer.gl450_ext_inst_id,
2415 spirv::GlslStd450Op::FClamp,
2416 expr_type_id,
2417 clamp_id,
2418 &[expr_id, min_const_id, max_const_id],
2419 ));
2420
2421 let op = match dst_scalar.kind {
2422 crate::ScalarKind::Sint => spirv::Op::ConvertFToS,
2423 crate::ScalarKind::Uint => spirv::Op::ConvertFToU,
2424 _ => unreachable!(),
2425 };
2426 Cast::Unary(op, clamp_id)
2427 }
2428 (Sk::Float, Sk::Float, Some(dst_width)) if src_scalar.width != dst_width => {
2429 Cast::Unary(spirv::Op::FConvert, expr_id)
2430 }
2431 (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF, expr_id),
2432 (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2433 Cast::Unary(spirv::Op::SConvert, expr_id)
2434 }
2435 (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF, expr_id),
2436 (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2437 Cast::Unary(spirv::Op::UConvert, expr_id)
2438 }
2439 (Sk::Uint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2440 Cast::Unary(spirv::Op::SConvert, expr_id)
2441 }
2442 (Sk::Sint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2443 Cast::Unary(spirv::Op::UConvert, expr_id)
2444 }
2445 _ => Cast::Unary(spirv::Op::Bitcast, expr_id),
2447 };
2448 Ok(match cast {
2449 Cast::Identity(expr) => expr,
2450 Cast::Unary(op, op1) => {
2451 let id = self.gen_id();
2452 block
2453 .body
2454 .push(Instruction::unary(op, result_type_id, id, op1));
2455 id
2456 }
2457 Cast::Binary(op, op1, op2) => {
2458 let id = self.gen_id();
2459 block
2460 .body
2461 .push(Instruction::binary(op, result_type_id, id, op1, op2));
2462 id
2463 }
2464 Cast::Ternary(op, op1, op2, op3) => {
2465 let id = self.gen_id();
2466 block
2467 .body
2468 .push(Instruction::ternary(op, result_type_id, id, op1, op2, op3));
2469 id
2470 }
2471 })
2472 }
2473
    /// Emit the access chain (if any) for the pointer expression `expr_handle`.
    ///
    /// Walks from `expr_handle` back through its chain of `Access` and
    /// `AccessIndex` expressions until it reaches a root:
    /// - a spilled composite,
    /// - a global variable,
    /// - a local variable, or
    /// - a function argument.
    ///
    /// The index ids collected along the way (in `self.temp_list`) become the
    /// operands of a single `OpAccessChain` based on that root.
    ///
    /// `type_adjustment` chooses the access chain's result type; see
    /// [`AccessTypeAdjustment`].
    ///
    /// Returns:
    /// - [`ExpressionPointer::Ready`] with the root id itself when no indices
    ///   were collected;
    /// - [`ExpressionPointer::Ready`] with the new pointer id, after pushing the
    ///   `OpAccessChain` onto `block`, when no runtime bounds check is pending;
    /// - [`ExpressionPointer::Conditional`] with the combined bounds-check
    ///   condition and the *unemitted* `OpAccessChain`, so the caller can place
    ///   it behind that condition.
    ///
    /// Any bounds-check instructions are pushed onto `block` while walking.
    /// Errors from `write_access_chain_index` and from decorating non-uniform
    /// binding-array accesses are propagated.
    ///
    /// Panics (`unimplemented!`) on a root expression kind not listed above.
    fn write_access_chain(
        &mut self,
        mut expr_handle: Handle<crate::Expression>,
        block: &mut Block,
        type_adjustment: AccessTypeAdjustment,
    ) -> Result<ExpressionPointer, Error> {
        let result_type_id = {
            let resolution = &self.fun_info[expr_handle].ty;
            match type_adjustment {
                AccessTypeAdjustment::None => self.writer.get_expression_type_id(resolution),
                AccessTypeAdjustment::IntroducePointer(class) => {
                    self.writer.get_resolution_pointer_id(resolution, class)
                }
                AccessTypeAdjustment::UseStd140CompatType => {
                    // Point at the std140-compatible replacement of the
                    // uniform's pointee type, keeping the storage class.
                    match *resolution.inner_with(&self.ir_module.types) {
                        crate::TypeInner::Pointer {
                            base,
                            space: space @ crate::AddressSpace::Uniform,
                        } => self.writer.get_pointer_type_id(
                            self.writer.std140_compat_uniform_types[&base].type_id,
                            map_storage_class(space),
                        ),
                        _ => unreachable!(
                            "`UseStd140CompatType` must only be used with uniform pointer types"
                        ),
                    }
                }
            }
        };

        // Conjunction of every runtime bounds-check condition produced while
        // walking the chain, or `None` if no index needed one.
        let mut accumulated_checks = None;

        // Set when some `Access` step indexes a global binding array with a
        // non-uniform index; the final pointer is then decorated.
        let mut is_non_uniform_binding_array = false;

        // Index taken from an `AccessIndex` that
        // `is_uniform_matcx2_struct_member_access` accepts. No index is pushed
        // for that step. Instead, the value is added to the std140 member index
        // of the enclosing struct access on a later iteration.
        // NOTE(review): presumably the std140-compatible struct stores such a
        // matrix as consecutive per-column members — confirm.
        let mut prev_decomposed_matrix_index = None;

        // Indices are pushed leaf-first while walking toward the root. They
        // are reversed before the access chain is built.
        self.temp_list.clear();
        let root_id = loop {
            // A spilled composite's internal variable serves as the root.
            if let Some(spilled) = self.function.spilled_composites.get(&expr_handle) {
                break spilled.id;
            }

            expr_handle = match self.ir_function.expressions[expr_handle] {
                crate::Expression::Access { base, index } => {
                    is_non_uniform_binding_array |=
                        self.is_nonuniform_binding_array_access(base, index);

                    let index = GuardedIndex::Expression(index);
                    let index_id =
                        self.write_access_chain_index(base, index, &mut accumulated_checks, block)?;
                    self.temp_list.push(index_id);

                    base
                }
                crate::Expression::AccessIndex { base, index } => {
                    // Look through a pointer to the type actually being
                    // indexed, remembering the pointer's address space.
                    let mut base_ty = self.fun_info[base].ty.inner_with(&self.ir_module.types);
                    let mut base_ty_handle = self.fun_info[base].ty.handle();
                    let mut pointer_space = None;
                    if let crate::TypeInner::Pointer { base, space } = *base_ty {
                        base_ty = &self.ir_module.types[base].inner;
                        base_ty_handle = Some(base);
                        pointer_space = Some(space);
                    }
                    match *base_ty {
                        crate::TypeInner::Struct { .. } => {
                            // Struct member indices are pushed as constants
                            // with no bounds check. For a uniform struct with
                            // a std140-compatible replacement, the index is
                            // remapped into the replacement's member numbering.
                            let index = match base_ty_handle.and_then(|handle| {
                                self.writer.std140_compat_uniform_types.get(&handle)
                            }) {
                                Some(std140_type_info)
                                    if pointer_space == Some(crate::AddressSpace::Uniform) =>
                                {
                                    std140_type_info.member_indices[index as usize]
                                        + prev_decomposed_matrix_index.take().unwrap_or(0)
                                }
                                _ => index,
                            };
                            let index_id = self.get_index_constant(index);
                            self.temp_list.push(index_id);
                        }
                        _ if is_uniform_matcx2_struct_member_access(
                            self.ir_function,
                            self.fun_info,
                            self.ir_module,
                            base,
                        ) =>
                        {
                            // Defer this index. It is folded into the member
                            // index of the enclosing struct access above.
                            assert!(prev_decomposed_matrix_index.is_none());
                            prev_decomposed_matrix_index = Some(index);
                        }
                        _ => {
                            // Any other indexable type: route the constant
                            // index through the same bounds-check path as a
                            // dynamic one.
                            let index_id = self.write_access_chain_index(
                                base,
                                GuardedIndex::Known(index),
                                &mut accumulated_checks,
                                block,
                            )?;
                            self.temp_list.push(index_id);
                        }
                    }
                    base
                }
                crate::Expression::GlobalVariable(handle) => {
                    let gv = &self.writer.global_variables[handle];
                    break gv.access_id;
                }
                crate::Expression::LocalVariable(variable) => {
                    let local_var = &self.function.variables[&variable];
                    break local_var.id;
                }
                crate::Expression::FunctionArgument(index) => {
                    break self.function.parameter_id(index);
                }
                ref other => unimplemented!("Unexpected pointer expression {:?}", other),
            }
        };

        let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
            // No indexing happened: the root itself is the pointer.
            (
                root_id,
                ExpressionPointer::Ready {
                    pointer_id: root_id,
                },
            )
        } else {
            // Restore root-to-leaf order for the access chain operands.
            self.temp_list.reverse();
            let pointer_id = self.gen_id();
            let access =
                Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);

            // With a pending bounds check, hand the instruction back unemitted
            // so the caller can put it behind the check. Otherwise emit it now.
            let expr_pointer = match accumulated_checks {
                Some(condition) => ExpressionPointer::Conditional { condition, access },
                None => {
                    block.body.push(access);
                    ExpressionPointer::Ready { pointer_id }
                }
            };
            (pointer_id, expr_pointer)
        };
        if is_non_uniform_binding_array {
            self.writer
                .decorate_non_uniform_binding_array_access(pointer_id)?;
        }

        Ok(expr_pointer)
    }
2669
2670 fn is_nonuniform_binding_array_access(
2671 &mut self,
2672 base: Handle<crate::Expression>,
2673 index: Handle<crate::Expression>,
2674 ) -> bool {
2675 let crate::Expression::GlobalVariable(var_handle) = self.ir_function.expressions[base]
2676 else {
2677 return false;
2678 };
2679
2680 let gvar = &self.ir_module.global_variables[var_handle];
2683 let crate::TypeInner::BindingArray { .. } = self.ir_module.types[gvar.ty].inner else {
2684 return false;
2685 };
2686
2687 self.fun_info[index].uniformity.non_uniform_result.is_some()
2688 }
2689
2690 fn write_access_chain_index(
2700 &mut self,
2701 base: Handle<crate::Expression>,
2702 index: GuardedIndex,
2703 accumulated_checks: &mut Option<Word>,
2704 block: &mut Block,
2705 ) -> Result<Word, Error> {
2706 match self.write_bounds_check(base, index, block)? {
2707 BoundsCheckResult::KnownInBounds(known_index) => {
2708 let scalar = crate::Literal::U32(known_index);
2711 Ok(self.writer.get_constant_scalar(scalar))
2712 }
2713 BoundsCheckResult::Computed(computed_index_id) => Ok(computed_index_id),
2714 BoundsCheckResult::Conditional {
2715 condition_id: condition,
2716 index_id: index,
2717 } => {
2718 self.extend_bounds_check_condition_chain(accumulated_checks, condition, block);
2719
2720 Ok(index)
2722 }
2723 }
2724 }
2725
2726 fn extend_bounds_check_condition_chain(
2745 &mut self,
2746 chain: &mut Option<Word>,
2747 comparison_id: Word,
2748 block: &mut Block,
2749 ) {
2750 match *chain {
2751 Some(ref mut prior_checks) => {
2752 let combined = self.gen_id();
2753 block.body.push(Instruction::binary(
2754 spirv::Op::LogicalAnd,
2755 self.writer.get_bool_type_id(),
2756 combined,
2757 *prior_checks,
2758 comparison_id,
2759 ));
2760 *prior_checks = combined;
2761 }
2762 None => {
2763 *chain = Some(comparison_id);
2765 }
2766 }
2767 }
2768
    /// Load the value that the pointer expression `pointer` points to and
    /// return the id of the result, whose type is `result_type_id`.
    ///
    /// Steps, in order:
    /// 1. Try `maybe_write_uniform_matcx2_dynamic_access`, then
    ///    `maybe_write_load_uniform_matcx2_struct_member`. If either returns
    ///    an id, that id is returned unchanged.
    /// 2. Check for a std140 replacement. If `pointer` is a `Uniform` pointer
    ///    whose pointee type has a std140-compatible replacement:
    ///    - the load reads the replacement type through a
    ///      `UseStd140CompatType` access chain;
    ///    - the loaded value is passed to the `ConvertFromStd140CompatType`
    ///      wrapper function, which produces the `result_type_id` value.
    /// 3. Otherwise, the access chain uses the caller's
    ///    `access_type_adjustment`.
    ///
    /// Pointees of atomic type are read with `OpAtomicLoad`. When the access
    /// chain carries a pending runtime bounds check, the load is emitted
    /// through `write_conditional_indexed_load`.
    fn write_checked_load(
        &mut self,
        pointer: Handle<crate::Expression>,
        block: &mut Block,
        access_type_adjustment: AccessTypeAdjustment,
        result_type_id: Word,
    ) -> Result<Word, Error> {
        if let Some(result_id) = self.maybe_write_uniform_matcx2_dynamic_access(pointer, block)? {
            Ok(result_id)
        } else if let Some(result_id) =
            self.maybe_write_load_uniform_matcx2_struct_member(pointer, block)?
        {
            Ok(result_id)
        } else {
            // A load that goes through the std140-compatible type and is then
            // converted back to the original type.
            struct WrappedLoad {
                // Adjustment making the access chain yield a pointer to the
                // std140-compatible type.
                access_type_adjustment: AccessTypeAdjustment,
                // The original pointee type (the key into
                // `std140_compat_uniform_types`).
                r#type: Handle<crate::Type>,
            }
            let mut wrapped_load = None;
            if let crate::TypeInner::Pointer {
                base: pointer_base_type,
                space: crate::AddressSpace::Uniform,
            } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
            {
                if self
                    .writer
                    .std140_compat_uniform_types
                    .contains_key(&pointer_base_type)
                {
                    wrapped_load = Some(WrappedLoad {
                        access_type_adjustment: AccessTypeAdjustment::UseStd140CompatType,
                        r#type: pointer_base_type,
                    });
                };
            };

            // A wrapped load reads the std140-compatible type. Otherwise read
            // `result_type_id` with the caller's requested adjustment.
            let (load_type_id, access_type_adjustment) = match wrapped_load {
                Some(ref wrapped_load) => (
                    self.writer.std140_compat_uniform_types[&wrapped_load.r#type].type_id,
                    wrapped_load.access_type_adjustment,
                ),
                None => (result_type_id, access_type_adjustment),
            };

            let load_id = match self.write_access_chain(pointer, block, access_type_adjustment)? {
                ExpressionPointer::Ready { pointer_id } => {
                    let id = self.gen_id();
                    // Atomic pointees must be read with `OpAtomicLoad`, using
                    // the scope/semantics of their address space.
                    let atomic_space =
                        match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
                            crate::TypeInner::Pointer { base, space } => {
                                match self.ir_module.types[base].inner {
                                    crate::TypeInner::Atomic { .. } => Some(space),
                                    _ => None,
                                }
                            }
                            _ => None,
                        };
                    let instruction = if let Some(space) = atomic_space {
                        let (semantics, scope) = space.to_spirv_semantics_and_scope();
                        let scope_constant_id = self.get_scope_constant(scope as u32);
                        let semantics_id = self.get_index_constant(semantics.bits());
                        Instruction::atomic_load(
                            result_type_id,
                            id,
                            pointer_id,
                            scope_constant_id,
                            semantics_id,
                        )
                    } else {
                        Instruction::load(load_type_id, id, pointer_id, None)
                    };
                    block.body.push(instruction);
                    id
                }
                ExpressionPointer::Conditional { condition, access } => {
                    // NOTE(review): this path always emits a plain `OpLoad`;
                    // atomic pointees are only special-cased on the `Ready`
                    // path above.
                    self.write_conditional_indexed_load(
                        load_type_id,
                        condition,
                        block,
                        move |id_gen, block| {
                            // Emit the deferred access chain, then load
                            // through it; the id of the loaded value is
                            // returned.
                            let pointer_id = access.result_id.unwrap();
                            let value_id = id_gen.next();
                            block.body.push(access);
                            block.body.push(Instruction::load(
                                load_type_id,
                                value_id,
                                pointer_id,
                                None,
                            ));
                            value_id
                        },
                    )
                }
            };

            match wrapped_load {
                Some(ref wrapped_load) => {
                    let result_id = self.gen_id();
                    // Convert the std140-compatible value back to the original
                    // type by calling the generated wrapper function.
                    let function_id = self.writer.wrapped_functions
                        [&WrappedFunction::ConvertFromStd140CompatType {
                            r#type: wrapped_load.r#type,
                        }];
                    block.body.push(Instruction::function_call(
                        result_type_id,
                        result_id,
                        function_id,
                        &[load_id],
                    ));
                    Ok(result_id)
                }
                None => Ok(load_id),
            }
        }
    }
2892
    /// Store the cached value of `base` into an internal `Function`-class
    /// variable. The variable is created the first time `base` is spilled.
    ///
    /// How the variable is recorded:
    /// - It goes into `self.function.spilled_composites`.
    /// - Only its `OpVariable` instruction is recorded there; nothing is
    ///   pushed onto `block` for the declaration.
    /// - `write_access_chain` uses it as the root of access chains that reach
    ///   `base`.
    ///
    /// Every call pushes an `OpStore` of `self.cached[base]` into the variable
    /// onto `block`.
    fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) {
        use indexmap::map::Entry;

        // Reuse the variable if `base` has already been spilled.
        let spill_variable_id = match self.function.spilled_composites.entry(base) {
            Entry::Occupied(preexisting) => preexisting.get().id,
            Entry::Vacant(vacant) => {
                // Pointer type, in the `Function` storage class, to the type
                // of `base`.
                let pointer_type_id = self.writer.get_resolution_pointer_id(
                    &self.fun_info[base].ty,
                    spirv::StorageClass::Function,
                );
                let id = self.writer.id_gen.next();
                vacant.insert(super::LocalVariable {
                    id,
                    instruction: Instruction::variable(
                        pointer_type_id,
                        id,
                        spirv::StorageClass::Function,
                        None,
                    ),
                });
                id
            }
        };

        // NOTE(review): assumes `base` already has a cached value in
        // `self.cached` at this point.
        let base_id = self.cached[base];
        block
            .body
            .push(Instruction::store(spill_variable_id, base_id, None));
    }
2949
2950 fn maybe_access_spilled_composite(
2967 &mut self,
2968 access: Handle<crate::Expression>,
2969 block: &mut Block,
2970 result_type_id: Word,
2971 ) -> Result<Word, Error> {
2972 let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r);
2973 if access_uses == self.fun_info[access].ref_count {
2974 Ok(0)
2978 } else {
2979 self.write_checked_load(
2984 access,
2985 block,
2986 AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::Function),
2987 result_type_id,
2988 )
2989 }
2990 }
2991
2992 #[allow(clippy::too_many_arguments)]
2994 fn write_matrix_matrix_column_op(
2995 &mut self,
2996 block: &mut Block,
2997 result_id: Word,
2998 result_type_id: Word,
2999 left_id: Word,
3000 right_id: Word,
3001 columns: crate::VectorSize,
3002 rows: crate::VectorSize,
3003 width: u8,
3004 op: spirv::Op,
3005 ) {
3006 self.temp_list.clear();
3007
3008 let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3009 size: rows,
3010 scalar: crate::Scalar::float(width),
3011 });
3012
3013 for index in 0..columns as u32 {
3014 let column_id_left = self.gen_id();
3015 let column_id_right = self.gen_id();
3016 let column_id_res = self.gen_id();
3017
3018 block.body.push(Instruction::composite_extract(
3019 vector_type_id,
3020 column_id_left,
3021 left_id,
3022 &[index],
3023 ));
3024 block.body.push(Instruction::composite_extract(
3025 vector_type_id,
3026 column_id_right,
3027 right_id,
3028 &[index],
3029 ));
3030 block.body.push(Instruction::binary(
3031 op,
3032 vector_type_id,
3033 column_id_res,
3034 column_id_left,
3035 column_id_right,
3036 ));
3037
3038 self.temp_list.push(column_id_res);
3039 }
3040
3041 block.body.push(Instruction::composite_construct(
3042 result_type_id,
3043 result_id,
3044 &self.temp_list,
3045 ));
3046 }
3047
3048 fn write_vector_scalar_mult(
3050 &mut self,
3051 block: &mut Block,
3052 result_id: Word,
3053 result_type_id: Word,
3054 vector_id: Word,
3055 scalar_id: Word,
3056 vector: &crate::TypeInner,
3057 ) {
3058 let (size, kind) = match *vector {
3059 crate::TypeInner::Vector {
3060 size,
3061 scalar: crate::Scalar { kind, .. },
3062 } => (size, kind),
3063 _ => unreachable!(),
3064 };
3065
3066 let (op, operand_id) = match kind {
3067 crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
3068 _ => {
3069 let operand_id = self.gen_id();
3070 self.temp_list.clear();
3071 self.temp_list.resize(size as usize, scalar_id);
3072 block.body.push(Instruction::composite_construct(
3073 result_type_id,
3074 operand_id,
3075 &self.temp_list,
3076 ));
3077 (spirv::Op::IMul, operand_id)
3078 }
3079 };
3080
3081 block.body.push(Instruction::binary(
3082 op,
3083 result_type_id,
3084 result_id,
3085 vector_id,
3086 operand_id,
3087 ));
3088 }
3089
3090 #[expect(clippy::too_many_arguments)]
3097 fn write_dot_product(
3098 &mut self,
3099 result_id: Word,
3100 result_type_id: Word,
3101 arg0_id: Word,
3102 arg1_id: Word,
3103 size: u32,
3104 block: &mut Block,
3105 extractor: impl Fn(Word, Word, Word) -> Instruction,
3106 ) {
3107 let mut partial_sum = self.writer.get_constant_null(result_type_id);
3108 let last_component = size - 1;
3109 for index in 0..=last_component {
3110 let a_id = self.gen_id();
3112 block.body.push(extractor(a_id, arg0_id, index));
3113 let b_id = self.gen_id();
3114 block.body.push(extractor(b_id, arg1_id, index));
3115 let prod_id = self.gen_id();
3116 block.body.push(Instruction::binary(
3117 spirv::Op::IMul,
3118 result_type_id,
3119 prod_id,
3120 a_id,
3121 b_id,
3122 ));
3123
3124 let id = if index == last_component {
3126 result_id
3127 } else {
3128 self.gen_id()
3129 };
3130
3131 block.body.push(Instruction::binary(
3133 spirv::Op::IAdd,
3134 result_type_id,
3135 id,
3136 partial_sum,
3137 prod_id,
3138 ));
3139 partial_sum = id;
3141 }
3142 }
3143
3144 fn write_pack4x8_optimized(
3146 &mut self,
3147 block: &mut Block,
3148 result_type_id: u32,
3149 arg0_id: u32,
3150 id: u32,
3151 is_signed: bool,
3152 should_clamp: bool,
3153 ) -> Instruction {
3154 let int_type = if is_signed {
3155 crate::ScalarKind::Sint
3156 } else {
3157 crate::ScalarKind::Uint
3158 };
3159 let wide_vector_type = NumericType::Vector {
3160 size: crate::VectorSize::Quad,
3161 scalar: crate::Scalar {
3162 kind: int_type,
3163 width: 4,
3164 },
3165 };
3166 let wide_vector_type_id = self.get_numeric_type_id(wide_vector_type);
3167 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3168 size: crate::VectorSize::Quad,
3169 scalar: crate::Scalar {
3170 kind: crate::ScalarKind::Uint,
3171 width: 1,
3172 },
3173 });
3174
3175 let mut wide_vector = arg0_id;
3176 if should_clamp {
3177 let (min, max, clamp_op) = if is_signed {
3178 (
3179 crate::Literal::I32(-128),
3180 crate::Literal::I32(127),
3181 spirv::GlslStd450Op::SClamp,
3182 )
3183 } else {
3184 (
3185 crate::Literal::U32(0),
3186 crate::Literal::U32(255),
3187 spirv::GlslStd450Op::UClamp,
3188 )
3189 };
3190 let [min, max] = [min, max].map(|lit| {
3191 let scalar = self.writer.get_constant_scalar(lit);
3192 self.writer.get_constant_composite(
3193 LookupType::Local(LocalType::Numeric(wide_vector_type)),
3194 &[scalar; 4],
3195 )
3196 });
3197
3198 let clamp_id = self.gen_id();
3199 block.body.push(Instruction::ext_inst_gl_op(
3200 self.writer.gl450_ext_inst_id,
3201 clamp_op,
3202 wide_vector_type_id,
3203 clamp_id,
3204 &[wide_vector, min, max],
3205 ));
3206
3207 wide_vector = clamp_id;
3208 }
3209
3210 let packed_vector = self.gen_id();
3211 block.body.push(Instruction::unary(
3212 spirv::Op::UConvert, packed_vector_type_id,
3214 packed_vector,
3215 wide_vector,
3216 ));
3217
3218 Instruction::unary(spirv::Op::Bitcast, result_type_id, id, packed_vector)
3223 }
3224
3225 fn write_pack4x8_polyfill(
3227 &mut self,
3228 block: &mut Block,
3229 result_type_id: u32,
3230 arg0_id: u32,
3231 id: u32,
3232 is_signed: bool,
3233 should_clamp: bool,
3234 ) -> Instruction {
3235 let int_type = if is_signed {
3236 crate::ScalarKind::Sint
3237 } else {
3238 crate::ScalarKind::Uint
3239 };
3240 let uint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
3241 let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
3242 kind: int_type,
3243 width: 4,
3244 }));
3245
3246 let mut last_instruction = Instruction::new(spirv::Op::Nop);
3247
3248 let zero = self.writer.get_constant_scalar(crate::Literal::U32(0));
3249 let mut preresult = zero;
3250 block
3251 .body
3252 .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed)));
3253
3254 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
3255 const VEC_LENGTH: u8 = 4;
3256 for i in 0..u32::from(VEC_LENGTH) {
3257 let offset = self.writer.get_constant_scalar(crate::Literal::U32(i * 8));
3258 let mut extracted = self.gen_id();
3259 block.body.push(Instruction::binary(
3260 spirv::Op::CompositeExtract,
3261 int_type_id,
3262 extracted,
3263 arg0_id,
3264 i,
3265 ));
3266 if is_signed {
3267 let casted = self.gen_id();
3268 block.body.push(Instruction::unary(
3269 spirv::Op::Bitcast,
3270 uint_type_id,
3271 casted,
3272 extracted,
3273 ));
3274 extracted = casted;
3275 }
3276 if should_clamp {
3277 let (min, max, clamp_op) = if is_signed {
3278 (
3279 crate::Literal::I32(-128),
3280 crate::Literal::I32(127),
3281 spirv::GlslStd450Op::SClamp,
3282 )
3283 } else {
3284 (
3285 crate::Literal::U32(0),
3286 crate::Literal::U32(255),
3287 spirv::GlslStd450Op::UClamp,
3288 )
3289 };
3290 let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit));
3291
3292 let clamp_id = self.gen_id();
3293 block.body.push(Instruction::ext_inst_gl_op(
3294 self.writer.gl450_ext_inst_id,
3295 clamp_op,
3296 result_type_id,
3297 clamp_id,
3298 &[extracted, min, max],
3299 ));
3300
3301 extracted = clamp_id;
3302 }
3303 let is_last = i == u32::from(VEC_LENGTH - 1);
3304 if is_last {
3305 last_instruction = Instruction::quaternary(
3306 spirv::Op::BitFieldInsert,
3307 result_type_id,
3308 id,
3309 preresult,
3310 extracted,
3311 offset,
3312 eight,
3313 )
3314 } else {
3315 let new_preresult = self.gen_id();
3316 block.body.push(Instruction::quaternary(
3317 spirv::Op::BitFieldInsert,
3318 result_type_id,
3319 new_preresult,
3320 preresult,
3321 extracted,
3322 offset,
3323 eight,
3324 ));
3325 preresult = new_preresult;
3326 }
3327 }
3328 last_instruction
3329 }
3330
3331 fn write_unpack4x8_optimized(
3333 &mut self,
3334 block: &mut Block,
3335 result_type_id: u32,
3336 arg0_id: u32,
3337 id: u32,
3338 is_signed: bool,
3339 ) -> Instruction {
3340 let (int_type, convert_op) = if is_signed {
3341 (crate::ScalarKind::Sint, spirv::Op::SConvert)
3342 } else {
3343 (crate::ScalarKind::Uint, spirv::Op::UConvert)
3344 };
3345
3346 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3347 size: crate::VectorSize::Quad,
3348 scalar: crate::Scalar {
3349 kind: int_type,
3350 width: 1,
3351 },
3352 });
3353
3354 let packed_vector = self.gen_id();
3359 block.body.push(Instruction::unary(
3360 spirv::Op::Bitcast,
3361 packed_vector_type_id,
3362 packed_vector,
3363 arg0_id,
3364 ));
3365
3366 Instruction::unary(convert_op, result_type_id, id, packed_vector)
3367 }
3368
3369 fn write_unpack4x8_polyfill(
3371 &mut self,
3372 block: &mut Block,
3373 result_type_id: u32,
3374 arg0_id: u32,
3375 id: u32,
3376 is_signed: bool,
3377 ) -> Instruction {
3378 let (int_type, extract_op) = if is_signed {
3379 (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract)
3380 } else {
3381 (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract)
3382 };
3383
3384 let sint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));
3385
3386 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
3387 let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
3388 kind: int_type,
3389 width: 4,
3390 }));
3391 block
3392 .body
3393 .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed));
3394 let arg_id = if is_signed {
3395 let new_arg_id = self.gen_id();
3396 block.body.push(Instruction::unary(
3397 spirv::Op::Bitcast,
3398 sint_type_id,
3399 new_arg_id,
3400 arg0_id,
3401 ));
3402 new_arg_id
3403 } else {
3404 arg0_id
3405 };
3406
3407 const VEC_LENGTH: u8 = 4;
3408 let parts: [_; VEC_LENGTH as usize] = core::array::from_fn(|_| self.gen_id());
3409 for (i, part_id) in parts.into_iter().enumerate() {
3410 let index = self
3411 .writer
3412 .get_constant_scalar(crate::Literal::U32(i as u32 * 8));
3413 block.body.push(Instruction::ternary(
3414 extract_op,
3415 int_type_id,
3416 part_id,
3417 arg_id,
3418 index,
3419 eight,
3420 ));
3421 }
3422
3423 Instruction::composite_construct(result_type_id, id, &parts)
3424 }
3425
3426 fn write_block(
3443 &mut self,
3444 label_id: Word,
3445 naga_block: &crate::Block,
3446 exit: BlockExit,
3447 loop_context: LoopContext,
3448 debug_info: Option<&DebugInfoInner>,
3449 ) -> Result<BlockExitDisposition, Error> {
3450 let mut block = Block::new(label_id);
3451 for (statement, span) in naga_block.span_iter() {
3452 if let (Some(debug_info), false) = (
3453 debug_info,
3454 matches!(
3455 statement,
3456 &(Statement::Block(..)
3457 | Statement::Break
3458 | Statement::Continue
3459 | Statement::Kill
3460 | Statement::Return { .. }
3461 | Statement::Loop { .. })
3462 ),
3463 ) {
3464 let loc: crate::SourceLocation = span.location(debug_info.source_code);
3465 block.body.push(Instruction::line(
3466 debug_info.source_file_id,
3467 loc.line_number,
3468 loc.line_position,
3469 ));
3470 };
3471 match *statement {
3472 Statement::Emit(ref range) => {
3473 for handle in range.clone() {
3474 if !self.expression_constness.is_const(handle) {
3476 self.cache_expression_value(handle, &mut block)?;
3477 }
3478 }
3479 }
3480 Statement::Block(ref block_statements) => {
3481 let scope_id = self.gen_id();
3482 self.function.consume(block, Instruction::branch(scope_id));
3483
3484 let merge_id = self.gen_id();
3485 let merge_used = self.write_block(
3486 scope_id,
3487 block_statements,
3488 BlockExit::Branch { target: merge_id },
3489 loop_context,
3490 debug_info,
3491 )?;
3492
3493 match merge_used {
3494 BlockExitDisposition::Used => {
3495 block = Block::new(merge_id);
3496 }
3497 BlockExitDisposition::Discarded => {
3498 return Ok(BlockExitDisposition::Discarded);
3499 }
3500 }
3501 }
3502 Statement::If {
3503 condition,
3504 ref accept,
3505 ref reject,
3506 } => {
3507 if !(accept.is_empty() && reject.is_empty()) {
3513 let condition_id = self.cached[condition];
3514
3515 let merge_id = self.gen_id();
3516 block.body.push(Instruction::selection_merge(
3517 merge_id,
3518 spirv::SelectionControl::NONE,
3519 ));
3520
3521 let accept_id = if accept.is_empty() {
3522 None
3523 } else {
3524 Some(self.gen_id())
3525 };
3526 let reject_id = if reject.is_empty() {
3527 None
3528 } else {
3529 Some(self.gen_id())
3530 };
3531
3532 self.function.consume(
3533 block,
3534 Instruction::branch_conditional(
3535 condition_id,
3536 accept_id.unwrap_or(merge_id),
3537 reject_id.unwrap_or(merge_id),
3538 ),
3539 );
3540
3541 if let Some(block_id) = accept_id {
3542 let _ = self.write_block(
3547 block_id,
3548 accept,
3549 BlockExit::Branch { target: merge_id },
3550 loop_context,
3551 debug_info,
3552 )?;
3553 }
3554 if let Some(block_id) = reject_id {
3555 let _ = self.write_block(
3560 block_id,
3561 reject,
3562 BlockExit::Branch { target: merge_id },
3563 loop_context,
3564 debug_info,
3565 )?;
3566 }
3567
3568 block = Block::new(merge_id);
3569 }
3570 }
3571 Statement::Switch {
3572 selector,
3573 ref cases,
3574 } => {
3575 let selector_id = self.cached[selector];
3576
3577 let merge_id = self.gen_id();
3578 block.body.push(Instruction::selection_merge(
3579 merge_id,
3580 spirv::SelectionControl::NONE,
3581 ));
3582
3583 let mut default_id = None;
3584 let mut last_id = None;
3586
3587 let mut raw_cases = Vec::with_capacity(cases.len());
3588 let mut case_ids = Vec::with_capacity(cases.len());
3589 for case in cases.iter() {
3590 let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
3592
3593 if case.fall_through && case.body.is_empty() {
3594 last_id = Some(label_id);
3595 }
3596
3597 case_ids.push(label_id);
3598
3599 match case.value {
3600 crate::SwitchValue::I32(value) => {
3601 raw_cases.push(super::instructions::Case {
3602 value: value as Word,
3603 label_id,
3604 });
3605 }
3606 crate::SwitchValue::U32(value) => {
3607 raw_cases.push(super::instructions::Case { value, label_id });
3608 }
3609 crate::SwitchValue::Default => {
3610 default_id = Some(label_id);
3611 }
3612 }
3613 }
3614
3615 let default_id = default_id.unwrap();
3616
3617 self.function.consume(
3618 block,
3619 Instruction::switch(selector_id, default_id, &raw_cases),
3620 );
3621
3622 let inner_context = LoopContext {
3623 break_id: Some(merge_id),
3624 ..loop_context
3625 };
3626
3627 for (i, (case, label_id)) in cases
3628 .iter()
3629 .zip(case_ids.iter())
3630 .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
3631 .enumerate()
3632 {
3633 let case_finish_id = if case.fall_through {
3634 case_ids[i + 1]
3635 } else {
3636 merge_id
3637 };
3638 let _ = self.write_block(
3647 *label_id,
3648 &case.body,
3649 BlockExit::Branch {
3650 target: case_finish_id,
3651 },
3652 inner_context,
3653 debug_info,
3654 )?;
3655 }
3656
3657 block = Block::new(merge_id);
3658 }
3659 Statement::Loop {
3660 ref body,
3661 ref continuing,
3662 break_if,
3663 } => {
3664 let preamble_id = self.gen_id();
3665 self.function
3666 .consume(block, Instruction::branch(preamble_id));
3667
3668 let merge_id = self.gen_id();
3669 let body_id = self.gen_id();
3670 let continuing_id = self.gen_id();
3671
3672 block = Block::new(preamble_id);
3675 if let Some(debug_info) = debug_info {
3678 let loc: crate::SourceLocation = span.location(debug_info.source_code);
3679 block.body.push(Instruction::line(
3680 debug_info.source_file_id,
3681 loc.line_number,
3682 loc.line_position,
3683 ))
3684 }
3685 block.body.push(Instruction::loop_merge(
3686 merge_id,
3687 continuing_id,
3688 spirv::SelectionControl::NONE,
3689 ));
3690
3691 if self.force_loop_bounding {
3692 block = self.write_force_bounded_loop_instructions(block, merge_id);
3693 }
3694 self.function.consume(block, Instruction::branch(body_id));
3695
3696 let _ = self.write_block(
3700 body_id,
3701 body,
3702 BlockExit::Branch {
3703 target: continuing_id,
3704 },
3705 LoopContext {
3706 continuing_id: Some(continuing_id),
3707 break_id: Some(merge_id),
3708 },
3709 debug_info,
3710 )?;
3711
3712 let exit = match break_if {
3713 Some(condition) => BlockExit::BreakIf {
3714 condition,
3715 preamble_id,
3716 },
3717 None => BlockExit::Branch {
3718 target: preamble_id,
3719 },
3720 };
3721
3722 let _ = self.write_block(
3726 continuing_id,
3727 continuing,
3728 exit,
3729 LoopContext {
3730 continuing_id: None,
3731 break_id: Some(merge_id),
3732 },
3733 debug_info,
3734 )?;
3735
3736 block = Block::new(merge_id);
3737 }
3738 Statement::Break => {
3739 self.function
3740 .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
3741 return Ok(BlockExitDisposition::Discarded);
3742 }
3743 Statement::Continue => {
3744 self.function.consume(
3745 block,
3746 Instruction::branch(loop_context.continuing_id.unwrap()),
3747 );
3748 return Ok(BlockExitDisposition::Discarded);
3749 }
3750 Statement::Return { value: Some(value) } => {
3751 let value_id = self.cached[value];
3752 let instruction = match self.function.entry_point_context {
3753 Some(ref context) => self.writer.write_entry_point_return(
3756 value_id,
3757 self.ir_function.result.as_ref().unwrap(),
3758 &context.results,
3759 &mut block.body,
3760 )?,
3761 None => Instruction::return_value(value_id),
3762 };
3763 self.function.consume(block, instruction);
3764 return Ok(BlockExitDisposition::Discarded);
3765 }
3766 Statement::Return { value: None } => {
3767 self.function.consume(block, Instruction::return_void());
3768 return Ok(BlockExitDisposition::Discarded);
3769 }
3770 Statement::Kill => {
3771 self.function.consume(block, Instruction::kill());
3772 return Ok(BlockExitDisposition::Discarded);
3773 }
3774 Statement::ControlBarrier(flags) => {
3775 self.writer.write_control_barrier(flags, &mut block.body);
3776 }
3777 Statement::MemoryBarrier(flags) => {
3778 self.writer.write_memory_barrier(flags, &mut block);
3779 }
3780 Statement::Store { pointer, value } => {
3781 let value_id = self.cached[value];
3782 match self.write_access_chain(
3783 pointer,
3784 &mut block,
3785 AccessTypeAdjustment::None,
3786 )? {
3787 ExpressionPointer::Ready { pointer_id } => {
3788 let atomic_space = match *self.fun_info[pointer]
3789 .ty
3790 .inner_with(&self.ir_module.types)
3791 {
3792 crate::TypeInner::Pointer { base, space } => {
3793 match self.ir_module.types[base].inner {
3794 crate::TypeInner::Atomic { .. } => Some(space),
3795 _ => None,
3796 }
3797 }
3798 _ => None,
3799 };
3800 let instruction = if let Some(space) = atomic_space {
3801 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3802 let scope_constant_id = self.get_scope_constant(scope as u32);
3803 let semantics_id = self.get_index_constant(semantics.bits());
3804 Instruction::atomic_store(
3805 pointer_id,
3806 scope_constant_id,
3807 semantics_id,
3808 value_id,
3809 )
3810 } else {
3811 Instruction::store(pointer_id, value_id, None)
3812 };
3813 block.body.push(instruction);
3814 }
3815 ExpressionPointer::Conditional { condition, access } => {
3816 let mut selection = Selection::start(&mut block, ());
3817 selection.if_true(self, condition, ());
3818
3819 let pointer_id = access.result_id.unwrap();
3821 selection.block().body.push(access);
3822 selection
3823 .block()
3824 .body
3825 .push(Instruction::store(pointer_id, value_id, None));
3826
3827 selection.finish(self, ());
3830 }
3831 };
3832 }
3833 Statement::ImageStore {
3834 image,
3835 coordinate,
3836 array_index,
3837 value,
3838 } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
3839 Statement::Call {
3840 function: local_function,
3841 ref arguments,
3842 result,
3843 } => {
3844 let id = self.gen_id();
3845 self.temp_list.clear();
3846 for &argument in arguments {
3847 self.temp_list.push(self.cached[argument]);
3848 }
3849
3850 let type_id = match result {
3851 Some(expr) => {
3852 self.cached[expr] = id;
3853 self.get_expression_type_id(&self.fun_info[expr].ty)
3854 }
3855 None => self.writer.void_type,
3856 };
3857
3858 block.body.push(Instruction::function_call(
3859 type_id,
3860 id,
3861 self.writer.lookup_function[&local_function],
3862 &self.temp_list,
3863 ));
3864 }
3865 Statement::Atomic {
3866 pointer,
3867 ref fun,
3868 value,
3869 result,
3870 } => {
3871 let id = self.gen_id();
3872 let result_type_id =
3876 self.get_expression_type_id(&self.fun_info[result.unwrap_or(value)].ty);
3877
3878 if let Some(result) = result {
3879 self.cached[result] = id;
3880 }
3881
3882 let pointer_id = match self.write_access_chain(
3883 pointer,
3884 &mut block,
3885 AccessTypeAdjustment::None,
3886 )? {
3887 ExpressionPointer::Ready { pointer_id } => pointer_id,
3888 ExpressionPointer::Conditional { .. } => {
3889 return Err(Error::FeatureNotImplemented(
3890 "Atomics out-of-bounds handling",
3891 ));
3892 }
3893 };
3894
3895 let space = self.fun_info[pointer]
3896 .ty
3897 .inner_with(&self.ir_module.types)
3898 .pointer_space()
3899 .unwrap();
3900 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3901 let scope_constant_id = self.get_scope_constant(scope as u32);
3902 let semantics_id = self.get_index_constant(semantics.bits());
3903 let value_id = self.cached[value];
3904 let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
3905
3906 let crate::TypeInner::Scalar(scalar) = *value_inner else {
3907 return Err(Error::FeatureNotImplemented(
3908 "Atomics with non-scalar values",
3909 ));
3910 };
3911
3912 let instruction = match *fun {
3913 crate::AtomicFunction::Add => {
3914 let spirv_op = match scalar.kind {
3915 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3916 spirv::Op::AtomicIAdd
3917 }
3918 crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT,
3919 _ => unimplemented!(),
3920 };
3921 Instruction::atomic_binary(
3922 spirv_op,
3923 result_type_id,
3924 id,
3925 pointer_id,
3926 scope_constant_id,
3927 semantics_id,
3928 value_id,
3929 )
3930 }
3931 crate::AtomicFunction::Subtract => {
3932 let (spirv_op, value_id) = match scalar.kind {
3933 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3934 (spirv::Op::AtomicISub, value_id)
3935 }
3936 crate::ScalarKind::Float => {
3937 let neg_result_id = self.gen_id();
3940 block.body.push(Instruction::unary(
3941 spirv::Op::FNegate,
3942 result_type_id,
3943 neg_result_id,
3944 value_id,
3945 ));
3946 (spirv::Op::AtomicFAddEXT, neg_result_id)
3947 }
3948 _ => unimplemented!(),
3949 };
3950 Instruction::atomic_binary(
3951 spirv_op,
3952 result_type_id,
3953 id,
3954 pointer_id,
3955 scope_constant_id,
3956 semantics_id,
3957 value_id,
3958 )
3959 }
3960 crate::AtomicFunction::And => {
3961 let spirv_op = match scalar.kind {
3962 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3963 spirv::Op::AtomicAnd
3964 }
3965 _ => unimplemented!(),
3966 };
3967 Instruction::atomic_binary(
3968 spirv_op,
3969 result_type_id,
3970 id,
3971 pointer_id,
3972 scope_constant_id,
3973 semantics_id,
3974 value_id,
3975 )
3976 }
3977 crate::AtomicFunction::InclusiveOr => {
3978 let spirv_op = match scalar.kind {
3979 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3980 spirv::Op::AtomicOr
3981 }
3982 _ => unimplemented!(),
3983 };
3984 Instruction::atomic_binary(
3985 spirv_op,
3986 result_type_id,
3987 id,
3988 pointer_id,
3989 scope_constant_id,
3990 semantics_id,
3991 value_id,
3992 )
3993 }
3994 crate::AtomicFunction::ExclusiveOr => {
3995 let spirv_op = match scalar.kind {
3996 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3997 spirv::Op::AtomicXor
3998 }
3999 _ => unimplemented!(),
4000 };
4001 Instruction::atomic_binary(
4002 spirv_op,
4003 result_type_id,
4004 id,
4005 pointer_id,
4006 scope_constant_id,
4007 semantics_id,
4008 value_id,
4009 )
4010 }
4011 crate::AtomicFunction::Min => {
4012 let spirv_op = match scalar.kind {
4013 crate::ScalarKind::Sint => spirv::Op::AtomicSMin,
4014 crate::ScalarKind::Uint => spirv::Op::AtomicUMin,
4015 _ => unimplemented!(),
4016 };
4017 Instruction::atomic_binary(
4018 spirv_op,
4019 result_type_id,
4020 id,
4021 pointer_id,
4022 scope_constant_id,
4023 semantics_id,
4024 value_id,
4025 )
4026 }
4027 crate::AtomicFunction::Max => {
4028 let spirv_op = match scalar.kind {
4029 crate::ScalarKind::Sint => spirv::Op::AtomicSMax,
4030 crate::ScalarKind::Uint => spirv::Op::AtomicUMax,
4031 _ => unimplemented!(),
4032 };
4033 Instruction::atomic_binary(
4034 spirv_op,
4035 result_type_id,
4036 id,
4037 pointer_id,
4038 scope_constant_id,
4039 semantics_id,
4040 value_id,
4041 )
4042 }
4043 crate::AtomicFunction::Exchange { compare: None } => {
4044 Instruction::atomic_binary(
4045 spirv::Op::AtomicExchange,
4046 result_type_id,
4047 id,
4048 pointer_id,
4049 scope_constant_id,
4050 semantics_id,
4051 value_id,
4052 )
4053 }
4054 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4055 let scalar_type_id =
4056 self.get_numeric_type_id(NumericType::Scalar(scalar));
4057 let bool_type_id =
4058 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL));
4059
4060 let cas_result_id = self.gen_id();
4061 let equality_result_id = self.gen_id();
4062 let equality_operator = match scalar.kind {
4063 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
4064 spirv::Op::IEqual
4065 }
4066 _ => unimplemented!(),
4067 };
4068
4069 let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
4070 cas_instr.set_type(scalar_type_id);
4071 cas_instr.set_result(cas_result_id);
4072 cas_instr.add_operand(pointer_id);
4073 cas_instr.add_operand(scope_constant_id);
4074 cas_instr.add_operand(semantics_id); cas_instr.add_operand(semantics_id); cas_instr.add_operand(value_id);
4077 cas_instr.add_operand(self.cached[cmp]);
4078 block.body.push(cas_instr);
4079 block.body.push(Instruction::binary(
4080 equality_operator,
4081 bool_type_id,
4082 equality_result_id,
4083 cas_result_id,
4084 self.cached[cmp],
4085 ));
4086 Instruction::composite_construct(
4087 result_type_id,
4088 id,
4089 &[cas_result_id, equality_result_id],
4090 )
4091 }
4092 };
4093
4094 block.body.push(instruction);
4095 }
4096 Statement::ImageAtomic {
4097 image,
4098 coordinate,
4099 array_index,
4100 fun,
4101 value,
4102 } => {
4103 self.write_image_atomic(
4104 image,
4105 coordinate,
4106 array_index,
4107 fun,
4108 value,
4109 &mut block,
4110 )?;
4111 }
4112 Statement::WorkGroupUniformLoad { pointer, result } => {
4113 self.writer
4114 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4115 let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
4116 let id = self.write_checked_load(
4119 pointer,
4120 &mut block,
4121 AccessTypeAdjustment::None,
4122 result_type_id,
4123 )?;
4124 self.cached[result] = id;
4125 self.writer
4126 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4127 }
4128 Statement::RayQuery { query, ref fun } => {
4129 self.write_ray_query_function(query, fun, &mut block);
4130 }
4131 Statement::SubgroupBallot {
4132 result,
4133 ref predicate,
4134 } => {
4135 self.write_subgroup_ballot(predicate, result, &mut block)?;
4136 }
4137 Statement::SubgroupCollectiveOperation {
4138 ref op,
4139 ref collective_op,
4140 argument,
4141 result,
4142 } => {
4143 self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
4144 }
4145 Statement::SubgroupGather {
4146 ref mode,
4147 argument,
4148 result,
4149 } => {
4150 self.write_subgroup_gather(mode, argument, result, &mut block)?;
4151 }
4152 Statement::CooperativeStore { target, ref data } => {
4153 let target_id = self.cached[target];
4154 let layout = if data.row_major {
4155 spirv::CooperativeMatrixLayout::RowMajorKHR
4156 } else {
4157 spirv::CooperativeMatrixLayout::ColumnMajorKHR
4158 };
4159 let layout_id = self.get_index_constant(layout as u32);
4160 let stride_id = self.cached[data.stride];
4161 match self.write_access_chain(
4162 data.pointer,
4163 &mut block,
4164 AccessTypeAdjustment::None,
4165 )? {
4166 ExpressionPointer::Ready { pointer_id } => {
4167 block.body.push(Instruction::coop_store(
4168 target_id, pointer_id, layout_id, stride_id,
4169 ));
4170 }
4171 ExpressionPointer::Conditional { condition, access } => {
4172 let mut selection = Selection::start(&mut block, ());
4173 selection.if_true(self, condition, ());
4174
4175 let pointer_id = access.result_id.unwrap();
4177 selection.block().body.push(access);
4178 selection.block().body.push(Instruction::coop_store(
4179 target_id, pointer_id, layout_id, stride_id,
4180 ));
4181
4182 selection.finish(self, ());
4185 }
4186 };
4187 }
4188 Statement::RayPipelineFunction(ref fun) => {
4189 self.write_ray_tracing_pipeline_function(fun, &mut block);
4190 }
4191 }
4192 }
4193
4194 let termination = match exit {
4195 BlockExit::Return => match self.ir_function.result {
4198 Some(ref result) if self.function.entry_point_context.is_none() => {
4199 let type_id = self.get_handle_type_id(result.ty);
4200 let null_id = self.writer.get_constant_null(type_id);
4201 Instruction::return_value(null_id)
4202 }
4203 _ => Instruction::return_void(),
4204 },
4205 BlockExit::Branch { target } => Instruction::branch(target),
4206 BlockExit::BreakIf {
4207 condition,
4208 preamble_id,
4209 } => {
4210 let condition_id = self.cached[condition];
4211
4212 Instruction::branch_conditional(
4213 condition_id,
4214 loop_context.break_id.unwrap(),
4215 preamble_id,
4216 )
4217 }
4218 };
4219
4220 self.function.consume(block, termination);
4221 Ok(BlockExitDisposition::Used)
4222 }
4223
4224 pub(super) fn write_function_body(
4225 &mut self,
4226 entry_id: Word,
4227 debug_info: Option<&DebugInfoInner>,
4228 ) -> Result<(), Error> {
4229 let _ = self.write_block(
4232 entry_id,
4233 &self.ir_function.body,
4234 BlockExit::Return,
4235 LoopContext::default(),
4236 debug_info,
4237 )?;
4238
4239 Ok(())
4240 }
4241}