1use alloc::vec::Vec;
6
7use arrayvec::ArrayVec;
8use spirv::Word;
9
10use super::{
11 helpers::map_storage_class, index::BoundsCheckResult, selection::Selection, Block,
12 BlockContext, Dimension, Error, IdGenerator, Instruction, LocalType, LookupType, NumericType,
13 ResultMember, WrappedFunction, Writer, WriterFlags,
14};
15use crate::{
16 arena::Handle, back::spv::helpers::is_uniform_matcx2_struct_member_access,
17 proc::index::GuardedIndex, Statement,
18};
19
20fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
21 match *type_inner {
22 crate::TypeInner::Scalar(_) => Dimension::Scalar,
23 crate::TypeInner::Vector { .. } => Dimension::Vector,
24 crate::TypeInner::Matrix { .. } => Dimension::Matrix,
25 crate::TypeInner::CooperativeMatrix { .. } => Dimension::CooperativeMatrix,
26 _ => unreachable!(),
27 }
28}
29
/// Selects how `write_access_chain` / `write_checked_load` derive the SPIR-V
/// type of the pointer they produce from the Naga IR expression's type.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// call sites in this file; confirm against `write_access_chain`.
#[derive(Copy, Clone)]
enum AccessTypeAdjustment {
    /// No adjustment: the pointer type follows the Naga IR type directly.
    /// Used, for example, when loading a whole uniform `matCx2` value.
    None,

    /// The expression's type is wrapped in a pointer in the given storage
    /// class. Used for binding-array element accesses with
    /// `StorageClass::UniformConstant`, whose result is then loaded as the
    /// binding's element type.
    IntroducePointer(spirv::StorageClass),

    /// Use the std140-compatible replacement type for the accessed value.
    /// Used when accessing a uniform struct that holds a `matCx2` member,
    /// whose columns are addressed through `std140_compat_uniform_types`.
    UseStd140CompatType,
}
89
/// The outcome of generating an access chain for a pointer expression.
enum ExpressionPointer {
    /// The pointer has already been emitted; `pointer_id` is its SPIR-V id
    /// and can be used directly.
    Ready { pointer_id: Word },

    /// The pointer may only be formed when `condition` holds (presumably the
    /// result of a bounds check — confirm). `access` is the not-yet-emitted
    /// instruction that produces the pointer: callers push it inside the
    /// guarded code (e.g. the callback given to
    /// `write_conditional_indexed_load`) and take the pointer id from
    /// `access.result_id`.
    Conditional {
        condition: Word,
        access: Instruction,
    },
}
108
/// How a block of generated code should be terminated.
///
/// NOTE(review): this enum is consumed outside this chunk; the variant
/// descriptions below are based on the variant and field names — confirm.
enum BlockExit {
    /// Presumably: terminate the block with a return.
    Return,
    /// Presumably: branch unconditionally to the block with id `target`.
    Branch {
        target: Word,
    },
    /// Presumably: conditionally leave a loop based on the Naga IR
    /// `condition` expression (a loop's `break_if`), with `preamble_id`
    /// being the id of the block that evaluates it — confirm.
    BreakIf {
        condition: Handle<crate::Expression>,
        preamble_id: Word,
    },
}
130
/// Tells the caller what happened to a requested block exit.
///
/// Marked `#[must_use]` so the result cannot be silently ignored.
/// NOTE(review): producers of this value are outside this chunk; the variant
/// meanings below are inferred from their names — confirm.
#[must_use]
enum BlockExitDisposition {
    /// Presumably: the requested exit was emitted.
    Used,

    /// Presumably: the requested exit was not emitted (e.g. because the block
    /// was already terminated) — confirm.
    Discarded,
}
154
/// Branch targets of the innermost enclosing loop.
///
/// The derived `Default` sets both targets to `None` (presumably meaning
/// "not inside a loop" — confirm).
#[derive(Clone, Copy, Default)]
struct LoopContext {
    /// Id of the block that `continue` presumably branches to — confirm.
    continuing_id: Option<Word>,
    /// Id of the block that `break` presumably branches to — confirm.
    break_id: Option<Word>,
}
160
/// Source-level debug information used while writing the module.
#[derive(Debug)]
pub(crate) struct DebugInfoInner<'a> {
    /// The shader source text, borrowed for `'a`.
    pub source_code: &'a str,
    /// SPIR-V id identifying the source file (presumably an `OpString` used
    /// by `OpSource`/`OpLine` — confirm).
    pub source_file_id: Word,
}
166
167impl Writer {
168 fn write_epilogue_position_y_flip(
173 &mut self,
174 position_id: Word,
175 body: &mut Vec<Instruction>,
176 ) -> Result<(), Error> {
177 let float_ptr_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Output);
178 let index_y_id = self.get_index_constant(1);
179 let access_id = self.id_gen.next();
180 body.push(Instruction::access_chain(
181 float_ptr_type_id,
182 access_id,
183 position_id,
184 &[index_y_id],
185 ));
186
187 let float_type_id = self.get_f32_type_id();
188 let load_id = self.id_gen.next();
189 body.push(Instruction::load(float_type_id, load_id, access_id, None));
190
191 let neg_id = self.id_gen.next();
192 body.push(Instruction::unary(
193 spirv::Op::FNegate,
194 float_type_id,
195 neg_id,
196 load_id,
197 ));
198
199 body.push(Instruction::store(access_id, neg_id, None));
200 Ok(())
201 }
202
203 fn write_epilogue_frag_depth_clamp(
205 &mut self,
206 frag_depth_id: Word,
207 body: &mut Vec<Instruction>,
208 ) -> Result<(), Error> {
209 let float_type_id = self.get_f32_type_id();
210 let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
211 let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));
212
213 let original_id = self.id_gen.next();
214 body.push(Instruction::load(
215 float_type_id,
216 original_id,
217 frag_depth_id,
218 None,
219 ));
220
221 let clamp_id = self.id_gen.next();
222 body.push(Instruction::ext_inst_gl_op(
223 self.gl450_ext_inst_id,
224 spirv::GlslStd450Op::FClamp,
225 float_type_id,
226 clamp_id,
227 &[original_id, zero_scalar_id, one_scalar_id],
228 ));
229
230 body.push(Instruction::store(frag_depth_id, clamp_id, None));
231 Ok(())
232 }
233
234 fn write_entry_point_return(
235 &mut self,
236 value_id: Word,
237 ir_result: &crate::FunctionResult,
238 result_members: &[ResultMember],
239 body: &mut Vec<Instruction>,
240 ) -> Result<Instruction, Error> {
241 for (index, res_member) in result_members.iter().enumerate() {
242 if res_member.built_in == Some(crate::BuiltIn::MeshTaskSize) {
244 return Ok(Instruction::return_value(value_id));
245 }
246 let member_value_id = match ir_result.binding {
247 Some(_) => value_id,
248 None => {
249 let member_value_id = self.id_gen.next();
250 body.push(Instruction::composite_extract(
251 res_member.type_id,
252 member_value_id,
253 value_id,
254 &[index as u32],
255 ));
256 member_value_id
257 }
258 };
259
260 self.store_io_with_f16_polyfill(body, res_member.id, member_value_id);
261
262 match res_member.built_in {
263 Some(crate::BuiltIn::Position { .. })
264 if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
265 {
266 self.write_epilogue_position_y_flip(res_member.id, body)?;
267 }
268 Some(crate::BuiltIn::FragDepth)
269 if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
270 {
271 self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
272 }
273 _ => {}
274 }
275 }
276 Ok(Instruction::return_void())
277 }
278}
279
280impl BlockContext<'_> {
    /// Emits a guard that forces the loop currently being generated to
    /// terminate after a bounded number of iterations.
    ///
    /// A function-local `vec2<u32>` counter is created (named `loop_bound`
    /// when `WriterFlags::DEBUG` is set) and initialized to
    /// `(u32::MAX, u32::MAX)`. Component `[0]` acts as the high word and
    /// component `[1]` as the low word of a two-word count, so the loop can run
    /// roughly 2^64 times before being cut off.
    ///
    /// `block` is terminated with a branch to a new block that loads the
    /// counter and branches to `merge_id` once both components are zero.
    /// Otherwise control continues in a block that decrements the counter
    /// and stores it back. That decrement block is returned unterminated, so
    /// the caller can keep appending to it.
    fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
        // Types and constants used below. Order matters for id allocation.
        let uint_type_id = self.writer.get_u32_type_id();
        let uint2_type_id = self.writer.get_vec2u_type_id();
        let uint2_ptr_type_id = self
            .writer
            .get_vec2u_pointer_type_id(spirv::StorageClass::Function);
        let bool_type_id = self.writer.get_bool_type_id();
        let bool2_type_id = self.writer.get_vec2_bool_type_id();
        let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
        let zero_uint2_const_id = self.writer.get_constant_composite(
            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::U32,
            })),
            &[zero_uint_const_id, zero_uint_const_id],
        );
        let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
        let max_uint_const_id = self
            .writer
            .get_constant_scalar(crate::Literal::U32(u32::MAX));
        // Initial counter value: (u32::MAX, u32::MAX).
        let max_uint2_const_id = self.writer.get_constant_composite(
            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::U32,
            })),
            &[max_uint_const_id, max_uint_const_id],
        );

        let loop_counter_var_id = self.gen_id();
        if self.writer.flags.contains(WriterFlags::DEBUG) {
            self.writer
                .debugs
                .push(Instruction::name(loop_counter_var_id, "loop_bound"));
        }
        // The variable is not emitted into `block`. It is recorded in
        // `force_loop_bounding_vars` (presumably declared later together with
        // the function's other locals — confirm).
        let var = super::LocalVariable {
            id: loop_counter_var_id,
            instruction: Instruction::variable(
                uint2_ptr_type_id,
                loop_counter_var_id,
                spirv::StorageClass::Function,
                Some(max_uint2_const_id),
            ),
        };
        self.function.force_loop_bounding_vars.push(var);

        let break_if_block = self.gen_id();

        // Finish the incoming block and start the block that tests the counter.
        self.function
            .consume(block, Instruction::branch(break_if_block));
        block = Block::new(break_if_block);

        let load_id = self.gen_id();
        block.body.push(Instruction::load(
            uint2_type_id,
            load_id,
            loop_counter_var_id,
            None,
        ));

        // all(counter == (0, 0))
        let eq_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool2_type_id,
            eq_id,
            zero_uint2_const_id,
            load_id,
        ));
        let all_eq_id = self.gen_id();
        block.body.push(Instruction::relational(
            spirv::Op::All,
            bool_type_id,
            all_eq_id,
            eq_id,
        ));

        // Leave the loop (branch to `merge_id`) when the counter is exhausted.
        // Otherwise fall through to the decrement block, which is also the
        // selection's merge block.
        let inc_counter_block_id = self.gen_id();
        block.body.push(Instruction::selection_merge(
            inc_counter_block_id,
            spirv::SelectionControl::empty(),
        ));
        self.function.consume(
            block,
            Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
        );
        block = Block::new(inc_counter_block_id);

        // Decrement the two-word counter: subtract 1 from the low word
        // (component [1]). When the low word is currently 0 it wraps to
        // u32::MAX, so also borrow 1 from the high word (component [0]).
        let low_id = self.gen_id();
        block.body.push(Instruction::composite_extract(
            uint_type_id,
            low_id,
            load_id,
            &[1],
        ));
        let low_overflow_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool_type_id,
            low_overflow_id,
            low_id,
            zero_uint_const_id,
        ));
        // Borrow amount for the high word: 1 if the low word is 0, else 0.
        let carry_bit_id = self.gen_id();
        block.body.push(Instruction::select(
            uint_type_id,
            carry_bit_id,
            low_overflow_id,
            one_uint_const_id,
            zero_uint_const_id,
        ));
        // (borrow, 1), subtracted component-wise from (high, low).
        let decrement_id = self.gen_id();
        block.body.push(Instruction::composite_construct(
            uint2_type_id,
            decrement_id,
            &[carry_bit_id, one_uint_const_id],
        ));
        let result_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::ISub,
            uint2_type_id,
            result_id,
            load_id,
            decrement_id,
        ));
        block
            .body
            .push(Instruction::store(loop_counter_var_id, result_id, None));

        block
    }
431
    /// Handles an access into a two-row matrix (`matCx2`) in the `Uniform`
    /// address space whose column is selected by a dynamic index.
    ///
    /// Recognized shapes of `pointer`:
    /// - a pointer to a column vector, `Access { base: matrix_pointer, index }`;
    /// - a pointer to a scalar component of such a column, i.e. an `Access`
    ///   or `AccessIndex` whose base is the column pointer above.
    ///
    /// The column must be selected by `Access` (not `AccessIndex`), and the
    /// matrix pointer must be a `Uniform` pointer to a matrix with two rows.
    /// For any other shape this returns `Ok(None)` and emits nothing.
    ///
    /// When it matches, the whole matrix is loaded with `write_checked_load`.
    /// The column index is converted to `u32` (bitcast from `i32` if needed).
    /// The column is then obtained by calling the wrapped `MatCx2GetColumn`
    /// helper function. For the scalar case, the component is selected with
    /// `write_vector_access`. Returns `Ok(Some(id))` of the resulting value.
    ///
    /// NOTE(review): presumably needed because uniform `matCx2` values use a
    /// std140-compatible layout whose columns cannot be reached with a dynamic
    /// `OpAccessChain` index — confirm.
    ///
    /// # Errors
    /// Returns `Error::Validation` if the column index is neither a `u32` nor
    /// an `i32` scalar. Propagates errors from `write_checked_load` and
    /// `write_vector_access`.
    ///
    /// # Panics
    /// Indexing `wrapped_functions` presumably panics if no `MatCx2GetColumn`
    /// helper was registered for this matrix type — confirm.
    fn maybe_write_uniform_matcx2_dynamic_access(
        &mut self,
        pointer: Handle<crate::Expression>,
        block: &mut Block,
    ) -> Result<Option<Word>, Error> {
        // Split `pointer` into the pointer to the column vector and, when
        // `pointer` addresses a single scalar component, that component's index.
        let (column_pointer, component_index) = match self.fun_info[pointer]
            .ty
            .inner_with(&self.ir_module.types)
            .pointer_base_type()
        {
            Some(resolution) => match *resolution.inner_with(&self.ir_module.types) {
                crate::TypeInner::Scalar(_) => match self.ir_function.expressions[pointer] {
                    crate::Expression::Access { base, index } => {
                        (base, Some(GuardedIndex::Expression(index)))
                    }
                    crate::Expression::AccessIndex { base, index } => {
                        (base, Some(GuardedIndex::Known(index)))
                    }
                    _ => return Ok(None),
                },
                crate::TypeInner::Vector { .. } => (pointer, None),
                _ => return Ok(None),
            },
            None => return Ok(None),
        };

        // Only dynamic column selection (`Access`) is handled here.
        let crate::Expression::Access {
            base: matrix_pointer,
            index: column_index,
        } = self.ir_function.expressions[column_pointer]
        else {
            return Ok(None);
        };

        // The matrix must be reached through a `Uniform` pointer...
        let crate::TypeInner::Pointer {
            base: matrix_pointer_base_type,
            space: crate::AddressSpace::Uniform,
        } = *self.fun_info[matrix_pointer]
            .ty
            .inner_with(&self.ir_module.types)
        else {
            return Ok(None);
        };

        // ...and must have exactly two rows.
        let crate::TypeInner::Matrix {
            columns,
            rows: rows @ crate::VectorSize::Bi,
            scalar,
        } = self.ir_module.types[matrix_pointer_base_type].inner
        else {
            return Ok(None);
        };

        let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
            columns,
            rows,
            scalar,
        });
        let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
        let component_type_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
        let get_column_function_id = self.writer.wrapped_functions
            [&WrappedFunction::MatCx2GetColumn {
                r#type: matrix_pointer_base_type,
            }];

        // Load the whole matrix; the column is selected from the loaded value
        // rather than through an access chain.
        let matrix_load_id = self.write_checked_load(
            matrix_pointer,
            block,
            AccessTypeAdjustment::None,
            matrix_type_id,
        )?;

        // The helper takes a `u32` column index, so reinterpret `i32` indices.
        let column_index_id = match *self.fun_info[column_index]
            .ty
            .inner_with(&self.ir_module.types)
        {
            crate::TypeInner::Scalar(crate::Scalar {
                kind: crate::ScalarKind::Uint,
                ..
            }) => self.cached[column_index],
            crate::TypeInner::Scalar(crate::Scalar {
                kind: crate::ScalarKind::Sint,
                ..
            }) => {
                let cast_id = self.gen_id();
                let u32_type_id = self.writer.get_u32_type_id();
                block.body.push(Instruction::unary(
                    spirv::Op::Bitcast,
                    u32_type_id,
                    cast_id,
                    self.cached[column_index],
                ));
                cast_id
            }
            _ => return Err(Error::Validation("Matrix access index must be u32 or i32")),
        };
        let column_id = self.gen_id();
        block.body.push(Instruction::function_call(
            column_type_id,
            column_id,
            get_column_function_id,
            &[matrix_load_id, column_index_id],
        ));
        // For a component access, pick the component out of the column value
        // just computed (passed as `Some(column_id)`).
        let result_id = match component_index {
            Some(index) => self.write_vector_access(
                component_type_id,
                column_pointer,
                Some(column_id),
                index,
                block,
            )?,
            None => column_id,
        };

        Ok(Some(result_id))
    }
575
    /// Loads a two-row matrix (`matCx2`) that is a member of a struct in the
    /// `Uniform` address space, reassembling it from separately stored columns.
    ///
    /// `pointer` must be an `AccessIndex` selecting member `member_index` of a
    /// struct pointer, and its own type must be a `Uniform` pointer to a matrix
    /// with two rows. For anything else this returns `Ok(None)` and emits
    /// nothing.
    ///
    /// Column layout:
    /// - In the struct's std140-compatible type (looked up in
    ///   `std140_compat_uniform_types`), `member_indices[member_index]` is the
    ///   member index of the matrix's column 0.
    /// - Column `c` is at that index plus `c`.
    ///
    /// Emission:
    /// - An access chain to the struct is generated with
    ///   `AccessTypeAdjustment::UseStd140CompatType`.
    /// - Each column is loaded through its own `OpAccessChain`.
    /// - The matrix is rebuilt with `OpCompositeConstruct`.
    /// - If the struct access chain is `Conditional`, all of this happens
    ///   inside `write_conditional_indexed_load`.
    ///
    /// Returns `Ok(Some(id))` of the reconstructed matrix.
    ///
    /// # Errors
    /// Propagates errors from `write_access_chain`.
    ///
    /// # Panics
    /// Indexing `std140_compat_uniform_types` presumably panics if the struct
    /// type has no std140-compatible entry — confirm callers guarantee it.
    fn maybe_write_load_uniform_matcx2_struct_member(
        &mut self,
        pointer: Handle<crate::Expression>,
        block: &mut Block,
    ) -> Result<Option<Word>, Error> {
        // `pointer` must point into `Uniform` memory...
        let crate::TypeInner::Pointer {
            base: matrix_type,
            space: space @ crate::AddressSpace::Uniform,
        } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
        else {
            return Ok(None);
        };

        // ...at a matrix with exactly two rows...
        let crate::TypeInner::Matrix {
            columns,
            rows: rows @ crate::VectorSize::Bi,
            scalar,
        } = self.ir_module.types[matrix_type].inner
        else {
            return Ok(None);
        };

        // ...selected as a member of some base via `AccessIndex`...
        let crate::Expression::AccessIndex {
            base: struct_pointer,
            index: member_index,
        } = self.ir_function.expressions[pointer]
        else {
            return Ok(None);
        };

        let crate::TypeInner::Pointer {
            base: struct_type, ..
        } = *self.fun_info[struct_pointer]
            .ty
            .inner_with(&self.ir_module.types)
        else {
            return Ok(None);
        };

        // ...where that base is a struct.
        let crate::TypeInner::Struct { .. } = self.ir_module.types[struct_type].inner else {
            return Ok(None);
        };

        let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
            columns,
            rows,
            scalar,
        });
        let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
        let column_pointer_type_id =
            self.get_pointer_type_id(column_type_id, map_storage_class(space));
        // Member index of column 0 in the std140-compatible struct; the
        // remaining columns follow consecutively.
        let column0_index = self.writer.std140_compat_uniform_types[&struct_type].member_indices
            [member_index as usize];
        let column_indices = (0..columns as u32)
            .map(|c| self.get_index_constant(column0_index + c))
            .collect::<ArrayVec<_, 4>>();

        // Loads every column from the struct at `struct_pointer_id` and builds
        // the matrix. It takes the id generator explicitly rather than `self`,
        // so it can also run inside the `write_conditional_indexed_load`
        // callback, which supplies its own `id_gen`.
        let load_mat_from_struct =
            |struct_pointer_id: Word, id_gen: &mut IdGenerator, block: &mut Block| -> Word {
                let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
                for index in &column_indices {
                    let column_pointer_id = id_gen.next();
                    block.body.push(Instruction::access_chain(
                        column_pointer_type_id,
                        column_pointer_id,
                        struct_pointer_id,
                        &[*index],
                    ));
                    let column_id = id_gen.next();
                    block.body.push(Instruction::load(
                        column_type_id,
                        column_id,
                        column_pointer_id,
                        None,
                    ));
                    column_ids.push(column_id);
                }
                let result_id = id_gen.next();
                block.body.push(Instruction::composite_construct(
                    matrix_type_id,
                    result_id,
                    &column_ids,
                ));
                result_id
            };

        let result_id = match self.write_access_chain(
            struct_pointer,
            block,
            AccessTypeAdjustment::UseStd140CompatType,
        )? {
            ExpressionPointer::Ready { pointer_id } => {
                load_mat_from_struct(pointer_id, &mut self.writer.id_gen, block)
            }
            // Emit the struct pointer and the column loads only inside the
            // guarded code.
            ExpressionPointer::Conditional { condition, access } => self
                .write_conditional_indexed_load(
                    matrix_type_id,
                    condition,
                    block,
                    |id_gen, block| {
                        let pointer_id = access.result_id.unwrap();
                        block.body.push(access);
                        load_mat_from_struct(pointer_id, id_gen, block)
                    },
                ),
        };

        Ok(Some(result_id))
    }
701
702 pub(super) fn cache_expression_value(
704 &mut self,
705 expr_handle: Handle<crate::Expression>,
706 block: &mut Block,
707 ) -> Result<(), Error> {
708 let is_named_expression = self
709 .ir_function
710 .named_expressions
711 .contains_key(&expr_handle);
712
713 if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression {
714 return Ok(());
715 }
716
717 let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
718 let id = match self.ir_function.expressions[expr_handle] {
719 crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
720 crate::Expression::Constant(handle) => {
721 let init = self.ir_module.constants[handle].init;
722 self.writer.constant_ids[init]
723 }
724 crate::Expression::Override(_) => return Err(Error::Override),
725 crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
726 crate::Expression::Compose { ty, ref components } => {
727 self.temp_list.clear();
728 if self.expression_constness.is_const(expr_handle) {
729 self.temp_list.extend(
730 crate::proc::flatten_compose(
731 ty,
732 components,
733 &self.ir_function.expressions,
734 &self.ir_module.types,
735 )
736 .map(|component| self.cached[component]),
737 );
738 self.writer
739 .get_constant_composite(LookupType::Handle(ty), &self.temp_list)
740 } else {
741 self.temp_list
742 .extend(components.iter().map(|&component| self.cached[component]));
743
744 let id = self.gen_id();
745 block.body.push(Instruction::composite_construct(
746 result_type_id,
747 id,
748 &self.temp_list,
749 ));
750 id
751 }
752 }
753 crate::Expression::Splat { size, value } => {
754 let value_id = self.cached[value];
755 let components = &[value_id; 4][..size as usize];
756
757 if self.expression_constness.is_const(expr_handle) {
758 let ty = self
759 .writer
760 .get_expression_lookup_type(&self.fun_info[expr_handle].ty);
761 self.writer.get_constant_composite(ty, components)
762 } else {
763 let id = self.gen_id();
764 block.body.push(Instruction::composite_construct(
765 result_type_id,
766 id,
767 components,
768 ));
769 id
770 }
771 }
772 crate::Expression::Access { base, index } => {
773 let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
774 match *base_ty_inner {
775 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
776 0
782 }
783 _ if self.function.spilled_accesses.contains(base) => {
784 self.function.spilled_accesses.insert(expr_handle);
792 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
793 }
794 crate::TypeInner::Vector { .. } => self.write_vector_access(
795 result_type_id,
796 base,
797 None,
798 GuardedIndex::Expression(index),
799 block,
800 )?,
801 crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => {
802 match GuardedIndex::from_expression(
804 index,
805 &self.ir_function.expressions,
806 self.ir_module,
807 ) {
808 GuardedIndex::Known(value) => {
809 let id = self.gen_id();
819 let base_id = self.cached[base];
820 block.body.push(Instruction::composite_extract(
821 result_type_id,
822 id,
823 base_id,
824 &[value],
825 ));
826 id
827 }
828 GuardedIndex::Expression(_) => {
829 self.spill_to_internal_variable(base, block);
836
837 self.function.spilled_accesses.insert(expr_handle);
840 self.maybe_access_spilled_composite(
841 expr_handle,
842 block,
843 result_type_id,
844 )?
845 }
846 }
847 }
848 crate::TypeInner::BindingArray {
849 base: binding_type, ..
850 } => {
851 let result_id = match self.write_access_chain(
854 expr_handle,
855 block,
856 AccessTypeAdjustment::IntroducePointer(
857 spirv::StorageClass::UniformConstant,
858 ),
859 )? {
860 ExpressionPointer::Ready { pointer_id } => pointer_id,
861 ExpressionPointer::Conditional { .. } => {
862 return Err(Error::FeatureNotImplemented(
863 "Texture array out-of-bounds handling",
864 ));
865 }
866 };
867
868 let binding_type_id = self.get_handle_type_id(binding_type);
869
870 let load_id = self.gen_id();
871 block.body.push(Instruction::load(
872 binding_type_id,
873 load_id,
874 result_id,
875 None,
876 ));
877
878 if self.fun_info[index].uniformity.non_uniform_result.is_some() {
882 self.writer
883 .decorate_non_uniform_binding_array_access(load_id)?;
884 }
885
886 load_id
887 }
888 ref other => {
889 log::error!(
890 "Unable to access base {:?} of type {:?}",
891 self.ir_function.expressions[base],
892 other
893 );
894 return Err(Error::Validation(
895 "only vectors and arrays may be dynamically indexed by value",
896 ));
897 }
898 }
899 }
900 crate::Expression::AccessIndex { base, index } => {
901 match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
902 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
903 0
909 }
910 _ if self.function.spilled_accesses.contains(base) => {
911 self.function.spilled_accesses.insert(expr_handle);
919 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
920 }
921 crate::TypeInner::Vector { .. }
922 | crate::TypeInner::Matrix { .. }
923 | crate::TypeInner::Array { .. }
924 | crate::TypeInner::Struct { .. } => {
925 let id = self.gen_id();
930 let base_id = self.cached[base];
931 block.body.push(Instruction::composite_extract(
932 result_type_id,
933 id,
934 base_id,
935 &[index],
936 ));
937 id
938 }
939 crate::TypeInner::BindingArray {
940 base: binding_type, ..
941 } => {
942 let result_id = match self.write_access_chain(
945 expr_handle,
946 block,
947 AccessTypeAdjustment::IntroducePointer(
948 spirv::StorageClass::UniformConstant,
949 ),
950 )? {
951 ExpressionPointer::Ready { pointer_id } => pointer_id,
952 ExpressionPointer::Conditional { .. } => {
953 return Err(Error::FeatureNotImplemented(
954 "Texture array out-of-bounds handling",
955 ));
956 }
957 };
958
959 let binding_type_id = self.get_handle_type_id(binding_type);
960
961 let load_id = self.gen_id();
962 block.body.push(Instruction::load(
963 binding_type_id,
964 load_id,
965 result_id,
966 None,
967 ));
968
969 load_id
970 }
971 ref other => {
972 log::error!("Unable to access index of {other:?}");
973 return Err(Error::FeatureNotImplemented("access index for type"));
974 }
975 }
976 }
977 crate::Expression::GlobalVariable(handle) => {
978 self.writer.global_variables[handle].access_id
979 }
980 crate::Expression::Swizzle {
981 size,
982 vector,
983 pattern,
984 } => {
985 let vector_id = self.cached[vector];
986 self.temp_list.clear();
987 for &sc in pattern[..size as usize].iter() {
988 self.temp_list.push(sc as Word);
989 }
990 let id = self.gen_id();
991 block.body.push(Instruction::vector_shuffle(
992 result_type_id,
993 id,
994 vector_id,
995 vector_id,
996 &self.temp_list,
997 ));
998 id
999 }
1000 crate::Expression::Unary { op, expr } => {
1001 let id = self.gen_id();
1002 let expr_id = self.cached[expr];
1003 let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
1004
1005 let spirv_op = match op {
1006 crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
1007 Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
1008 Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
1009 _ => return Err(Error::Validation("Unexpected kind for negation")),
1010 },
1011 crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot,
1012 crate::UnaryOperator::BitwiseNot => spirv::Op::Not,
1013 };
1014
1015 block
1016 .body
1017 .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
1018 id
1019 }
1020 crate::Expression::Binary { op, left, right } => {
1021 let id = self.gen_id();
1022 let left_id = self.cached[left];
1023 let right_id = self.cached[right];
1024 let left_type_id = self.get_expression_type_id(&self.fun_info[left].ty);
1025 let right_type_id = self.get_expression_type_id(&self.fun_info[right].ty);
1026
1027 if let Some(function_id) =
1028 self.writer
1029 .wrapped_functions
1030 .get(&WrappedFunction::BinaryOp {
1031 op,
1032 left_type_id,
1033 right_type_id,
1034 })
1035 {
1036 block.body.push(Instruction::function_call(
1037 result_type_id,
1038 id,
1039 *function_id,
1040 &[left_id, right_id],
1041 ));
1042 } else {
1043 let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
1044 let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
1045
1046 let left_dimension = get_dimension(left_ty_inner);
1047 let right_dimension = get_dimension(right_ty_inner);
1048
1049 let mut reverse_operands = false;
1050
1051 let spirv_op = match op {
1052 crate::BinaryOperator::Add => match *left_ty_inner {
1053 crate::TypeInner::Scalar(scalar)
1054 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1055 crate::ScalarKind::Float => spirv::Op::FAdd,
1056 _ => spirv::Op::IAdd,
1057 },
1058 crate::TypeInner::Matrix {
1059 columns,
1060 rows,
1061 scalar,
1062 } => {
1063 self.write_matrix_matrix_column_op(
1065 block,
1066 id,
1067 result_type_id,
1068 left_id,
1069 right_id,
1070 columns,
1071 rows,
1072 scalar.width,
1073 spirv::Op::FAdd,
1074 );
1075
1076 self.cached[expr_handle] = id;
1077 return Ok(());
1078 }
1079 crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FAdd,
1080 _ => unimplemented!(),
1081 },
1082 crate::BinaryOperator::Subtract => match *left_ty_inner {
1083 crate::TypeInner::Scalar(scalar)
1084 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1085 crate::ScalarKind::Float => spirv::Op::FSub,
1086 _ => spirv::Op::ISub,
1087 },
1088 crate::TypeInner::Matrix {
1089 columns,
1090 rows,
1091 scalar,
1092 } => {
1093 self.write_matrix_matrix_column_op(
1094 block,
1095 id,
1096 result_type_id,
1097 left_id,
1098 right_id,
1099 columns,
1100 rows,
1101 scalar.width,
1102 spirv::Op::FSub,
1103 );
1104
1105 self.cached[expr_handle] = id;
1106 return Ok(());
1107 }
1108 crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FSub,
1109 _ => unimplemented!(),
1110 },
1111 crate::BinaryOperator::Multiply => {
1112 match (left_dimension, right_dimension) {
1113 (Dimension::Scalar, Dimension::Vector) => {
1114 self.write_vector_scalar_mult(
1115 block,
1116 id,
1117 result_type_id,
1118 right_id,
1119 left_id,
1120 right_ty_inner,
1121 );
1122
1123 self.cached[expr_handle] = id;
1124 return Ok(());
1125 }
1126 (Dimension::Vector, Dimension::Scalar) => {
1127 self.write_vector_scalar_mult(
1128 block,
1129 id,
1130 result_type_id,
1131 left_id,
1132 right_id,
1133 left_ty_inner,
1134 );
1135
1136 self.cached[expr_handle] = id;
1137 return Ok(());
1138 }
1139 (Dimension::Vector, Dimension::Matrix) => {
1140 spirv::Op::VectorTimesMatrix
1141 }
1142 (Dimension::Matrix, Dimension::Scalar)
1143 | (Dimension::CooperativeMatrix, Dimension::Scalar) => {
1144 spirv::Op::MatrixTimesScalar
1145 }
1146 (Dimension::Scalar, Dimension::Matrix)
1147 | (Dimension::Scalar, Dimension::CooperativeMatrix) => {
1148 reverse_operands = true;
1149 spirv::Op::MatrixTimesScalar
1150 }
1151 (Dimension::Matrix, Dimension::Vector) => {
1152 spirv::Op::MatrixTimesVector
1153 }
1154 (Dimension::Matrix, Dimension::Matrix) => {
1155 spirv::Op::MatrixTimesMatrix
1156 }
1157 (Dimension::Vector, Dimension::Vector)
1158 | (Dimension::Scalar, Dimension::Scalar)
1159 if left_ty_inner.scalar_kind()
1160 == Some(crate::ScalarKind::Float) =>
1161 {
1162 spirv::Op::FMul
1163 }
1164 (Dimension::Vector, Dimension::Vector)
1165 | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
1166 (Dimension::CooperativeMatrix, Dimension::CooperativeMatrix)
1167 | (Dimension::CooperativeMatrix, _)
1169 | (_, Dimension::CooperativeMatrix) => {
1170 unimplemented!()
1171 }
1172 }
1173 }
1174 crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
1175 Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
1176 Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
1177 Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
1178 _ => unimplemented!(),
1179 },
1180 crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
1181 Some(crate::ScalarKind::Float) => spirv::Op::FRem,
1182 Some(crate::ScalarKind::Sint) => {
1183 assert!(!self.writer.emit_int_div_checks);
1184 spirv::Op::SRem
1185 }
1186 Some(crate::ScalarKind::Uint) => {
1187 assert!(!self.writer.emit_int_div_checks);
1188 spirv::Op::UMod
1189 }
1190 _ => unimplemented!(),
1191 },
1192 crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
1193 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1194 spirv::Op::IEqual
1195 }
1196 Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
1197 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
1198 _ => unimplemented!(),
1199 },
1200 crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
1201 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1202 spirv::Op::INotEqual
1203 }
1204 Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
1205 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
1206 _ => unimplemented!(),
1207 },
1208 crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
1209 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
1210 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
1211 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
1212 _ => unimplemented!(),
1213 },
1214 crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
1215 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
1216 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
1217 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
1218 _ => unimplemented!(),
1219 },
1220 crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
1221 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
1222 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
1223 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
1224 _ => unimplemented!(),
1225 },
1226 crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
1227 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
1228 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
1229 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
1230 _ => unimplemented!(),
1231 },
1232 crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
1233 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
1234 _ => spirv::Op::BitwiseAnd,
1235 },
1236 crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
1237 crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
1238 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
1239 _ => spirv::Op::BitwiseOr,
1240 },
1241 crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
1242 crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
1243 crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
1244 crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
1245 Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
1246 Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
1247 _ => unimplemented!(),
1248 },
1249 };
1250
1251 block.body.push(Instruction::binary(
1252 spirv_op,
1253 result_type_id,
1254 id,
1255 if reverse_operands { right_id } else { left_id },
1256 if reverse_operands { left_id } else { right_id },
1257 ));
1258 }
1259 id
1260 }
1261 crate::Expression::Math {
1262 fun,
1263 arg,
1264 arg1,
1265 arg2,
1266 arg3,
1267 } => {
1268 use crate::MathFunction as Mf;
1269 enum MathOp {
1270 Ext(spirv::GlslStd450Op),
1271 Custom(Instruction),
1272 }
1273
1274 let arg0_id = self.cached[arg];
1275 let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
1276 let arg_scalar_kind = arg_ty.scalar_kind();
1277 let arg1_id = match arg1 {
1278 Some(handle) => self.cached[handle],
1279 None => 0,
1280 };
1281 let arg2_id = match arg2 {
1282 Some(handle) => self.cached[handle],
1283 None => 0,
1284 };
1285 let arg3_id = match arg3 {
1286 Some(handle) => self.cached[handle],
1287 None => 0,
1288 };
1289
1290 let id = self.gen_id();
1291 let math_op = match fun {
1292 Mf::Abs => {
1294 match arg_scalar_kind {
1295 Some(crate::ScalarKind::Float) => {
1296 MathOp::Ext(spirv::GlslStd450Op::FAbs)
1297 }
1298 Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GlslStd450Op::SAbs),
1299 Some(crate::ScalarKind::Uint) => {
1300 MathOp::Custom(Instruction::unary(
1301 spirv::Op::CopyObject, result_type_id,
1303 id,
1304 arg0_id,
1305 ))
1306 }
1307 other => unimplemented!("Unexpected abs({:?})", other),
1308 }
1309 }
1310 Mf::Min => MathOp::Ext(match arg_scalar_kind {
1311 Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FMin,
1312 Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SMin,
1313 Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::UMin,
1314 other => unimplemented!("Unexpected min({:?})", other),
1315 }),
1316 Mf::Max => MathOp::Ext(match arg_scalar_kind {
1317 Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FMax,
1318 Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SMax,
1319 Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::UMax,
1320 other => unimplemented!("Unexpected max({:?})", other),
1321 }),
1322 Mf::Clamp => match arg_scalar_kind {
1323 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GlslStd450Op::FClamp),
1327 Some(_) => {
1328 let (min_op, max_op) = match arg_scalar_kind {
1329 Some(crate::ScalarKind::Sint) => {
1330 (spirv::GlslStd450Op::SMin, spirv::GlslStd450Op::SMax)
1331 }
1332 Some(crate::ScalarKind::Uint) => {
1333 (spirv::GlslStd450Op::UMin, spirv::GlslStd450Op::UMax)
1334 }
1335 _ => unreachable!(),
1336 };
1337
1338 let max_id = self.gen_id();
1339 block.body.push(Instruction::ext_inst_gl_op(
1340 self.writer.gl450_ext_inst_id,
1341 max_op,
1342 result_type_id,
1343 max_id,
1344 &[arg0_id, arg1_id],
1345 ));
1346
1347 MathOp::Custom(Instruction::ext_inst_gl_op(
1348 self.writer.gl450_ext_inst_id,
1349 min_op,
1350 result_type_id,
1351 id,
1352 &[max_id, arg2_id],
1353 ))
1354 }
1355 other => unimplemented!("Unexpected max({:?})", other),
1356 },
1357 Mf::Saturate => {
1358 let (maybe_size, scalar) = match *arg_ty {
1359 crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
1360 crate::TypeInner::Scalar(scalar) => (None, scalar),
1361 ref other => unimplemented!("Unexpected saturate({:?})", other),
1362 };
1363 let scalar = crate::Scalar::float(scalar.width);
1364 let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
1365 let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
1366
1367 if let Some(size) = maybe_size {
1368 let ty =
1369 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1370
1371 self.temp_list.clear();
1372 self.temp_list.resize(size as _, arg1_id);
1373
1374 arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
1375
1376 self.temp_list.fill(arg2_id);
1377
1378 arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
1379 }
1380
1381 MathOp::Custom(Instruction::ext_inst_gl_op(
1382 self.writer.gl450_ext_inst_id,
1383 spirv::GlslStd450Op::FClamp,
1384 result_type_id,
1385 id,
1386 &[arg0_id, arg1_id, arg2_id],
1387 ))
1388 }
1389 Mf::Sin => MathOp::Ext(spirv::GlslStd450Op::Sin),
1391 Mf::Sinh => MathOp::Ext(spirv::GlslStd450Op::Sinh),
1392 Mf::Asin => MathOp::Ext(spirv::GlslStd450Op::Asin),
1393 Mf::Cos => MathOp::Ext(spirv::GlslStd450Op::Cos),
1394 Mf::Cosh => MathOp::Ext(spirv::GlslStd450Op::Cosh),
1395 Mf::Acos => MathOp::Ext(spirv::GlslStd450Op::Acos),
1396 Mf::Tan => MathOp::Ext(spirv::GlslStd450Op::Tan),
1397 Mf::Tanh => MathOp::Ext(spirv::GlslStd450Op::Tanh),
1398 Mf::Atan => MathOp::Ext(spirv::GlslStd450Op::Atan),
1399 Mf::Atan2 => MathOp::Ext(spirv::GlslStd450Op::Atan2),
1400 Mf::Asinh => MathOp::Ext(spirv::GlslStd450Op::Asinh),
1401 Mf::Acosh => MathOp::Ext(spirv::GlslStd450Op::Acosh),
1402 Mf::Atanh => MathOp::Ext(spirv::GlslStd450Op::Atanh),
1403 Mf::Radians => MathOp::Ext(spirv::GlslStd450Op::Radians),
1404 Mf::Degrees => MathOp::Ext(spirv::GlslStd450Op::Degrees),
1405 Mf::Ceil => MathOp::Ext(spirv::GlslStd450Op::Ceil),
1407 Mf::Round => MathOp::Ext(spirv::GlslStd450Op::RoundEven),
1408 Mf::Floor => MathOp::Ext(spirv::GlslStd450Op::Floor),
1409 Mf::Fract => MathOp::Ext(spirv::GlslStd450Op::Fract),
1410 Mf::Trunc => MathOp::Ext(spirv::GlslStd450Op::Trunc),
1411 Mf::Modf => MathOp::Ext(spirv::GlslStd450Op::ModfStruct),
1412 Mf::Frexp => MathOp::Ext(spirv::GlslStd450Op::FrexpStruct),
1413 Mf::Ldexp => MathOp::Ext(spirv::GlslStd450Op::Ldexp),
1414 Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
1416 crate::TypeInner::Vector {
1417 scalar:
1418 crate::Scalar {
1419 kind: crate::ScalarKind::Float,
1420 ..
1421 },
1422 ..
1423 } => MathOp::Custom(Instruction::binary(
1424 spirv::Op::Dot,
1425 result_type_id,
1426 id,
1427 arg0_id,
1428 arg1_id,
1429 )),
1430 crate::TypeInner::Vector { size, .. } => {
1432 self.write_dot_product(
1433 id,
1434 result_type_id,
1435 arg0_id,
1436 arg1_id,
1437 size as u32,
1438 block,
1439 |result_id, composite_id, index| {
1440 Instruction::composite_extract(
1441 result_type_id,
1442 result_id,
1443 composite_id,
1444 &[index],
1445 )
1446 },
1447 );
1448 self.cached[expr_handle] = id;
1449 return Ok(());
1450 }
1451 _ => unreachable!(
1452 "Correct TypeInner for dot product should be already validated"
1453 ),
1454 },
1455 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
1456 if self
1457 .writer
1458 .require_all(&[
1459 spirv::Capability::DotProduct,
1460 spirv::Capability::DotProductInput4x8BitPacked,
1461 ])
1462 .is_ok()
1463 {
1464 if self.writer.lang_version() < (1, 6) {
1466 self.writer.use_extension("SPV_KHR_integer_dot_product");
1470 }
1471
1472 let op = match fun {
1473 Mf::Dot4I8Packed => spirv::Op::SDot,
1474 Mf::Dot4U8Packed => spirv::Op::UDot,
1475 _ => unreachable!(),
1476 };
1477
1478 block.body.push(Instruction::ternary(
1479 op,
1480 result_type_id,
1481 id,
1482 arg0_id,
1483 arg1_id,
1484 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word,
1485 ));
1486 } else {
1487 let (extract_op, arg0_id, arg1_id) = match fun {
1489 Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
1490 Mf::Dot4I8Packed => {
1491 let new_arg0_id = self.gen_id();
1494 block.body.push(Instruction::unary(
1495 spirv::Op::Bitcast,
1496 result_type_id,
1497 new_arg0_id,
1498 arg0_id,
1499 ));
1500
1501 let new_arg1_id = self.gen_id();
1502 block.body.push(Instruction::unary(
1503 spirv::Op::Bitcast,
1504 result_type_id,
1505 new_arg1_id,
1506 arg1_id,
1507 ));
1508
1509 (spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
1510 }
1511 _ => unreachable!(),
1512 };
1513
1514 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1515
1516 const VEC_LENGTH: u8 = 4;
1517 let bit_shifts: [_; VEC_LENGTH as usize] =
1518 core::array::from_fn(|index| {
1519 self.writer
1520 .get_constant_scalar(crate::Literal::U32(index as u32 * 8))
1521 });
1522
1523 self.write_dot_product(
1524 id,
1525 result_type_id,
1526 arg0_id,
1527 arg1_id,
1528 VEC_LENGTH as Word,
1529 block,
1530 |result_id, composite_id, index| {
1531 Instruction::ternary(
1532 extract_op,
1533 result_type_id,
1534 result_id,
1535 composite_id,
1536 bit_shifts[index as usize],
1537 eight,
1538 )
1539 },
1540 );
1541 }
1542
1543 self.cached[expr_handle] = id;
1544 return Ok(());
1545 }
1546 Mf::Outer => MathOp::Custom(Instruction::binary(
1547 spirv::Op::OuterProduct,
1548 result_type_id,
1549 id,
1550 arg0_id,
1551 arg1_id,
1552 )),
1553 Mf::Cross => MathOp::Ext(spirv::GlslStd450Op::Cross),
1554 Mf::Distance => MathOp::Ext(spirv::GlslStd450Op::Distance),
1555 Mf::Length => MathOp::Ext(spirv::GlslStd450Op::Length),
1556 Mf::Normalize => MathOp::Ext(spirv::GlslStd450Op::Normalize),
1557 Mf::FaceForward => MathOp::Ext(spirv::GlslStd450Op::FaceForward),
1558 Mf::Reflect => MathOp::Ext(spirv::GlslStd450Op::Reflect),
1559 Mf::Refract => MathOp::Ext(spirv::GlslStd450Op::Refract),
1560 Mf::Exp => MathOp::Ext(spirv::GlslStd450Op::Exp),
1562 Mf::Exp2 => MathOp::Ext(spirv::GlslStd450Op::Exp2),
1563 Mf::Log => MathOp::Ext(spirv::GlslStd450Op::Log),
1564 Mf::Log2 => MathOp::Ext(spirv::GlslStd450Op::Log2),
1565 Mf::Pow => MathOp::Ext(spirv::GlslStd450Op::Pow),
1566 Mf::Sign => MathOp::Ext(match arg_scalar_kind {
1568 Some(crate::ScalarKind::Float) => spirv::GlslStd450Op::FSign,
1569 Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::SSign,
1570 other => unimplemented!("Unexpected sign({:?})", other),
1571 }),
1572 Mf::Fma => MathOp::Ext(spirv::GlslStd450Op::Fma),
1573 Mf::Mix => {
1574 let selector = arg2.unwrap();
1575 let selector_ty =
1576 self.fun_info[selector].ty.inner_with(&self.ir_module.types);
1577 match (arg_ty, selector_ty) {
1578 (
1580 &crate::TypeInner::Vector { size, .. },
1581 &crate::TypeInner::Scalar(scalar),
1582 ) => {
1583 let selector_type_id =
1584 self.get_numeric_type_id(NumericType::Vector { size, scalar });
1585 self.temp_list.clear();
1586 self.temp_list.resize(size as usize, arg2_id);
1587
1588 let selector_id = self.gen_id();
1589 block.body.push(Instruction::composite_construct(
1590 selector_type_id,
1591 selector_id,
1592 &self.temp_list,
1593 ));
1594
1595 MathOp::Custom(Instruction::ext_inst_gl_op(
1596 self.writer.gl450_ext_inst_id,
1597 spirv::GlslStd450Op::FMix,
1598 result_type_id,
1599 id,
1600 &[arg0_id, arg1_id, selector_id],
1601 ))
1602 }
1603 _ => MathOp::Ext(spirv::GlslStd450Op::FMix),
1604 }
1605 }
1606 Mf::Step => MathOp::Ext(spirv::GlslStd450Op::Step),
1607 Mf::SmoothStep => MathOp::Ext(spirv::GlslStd450Op::SmoothStep),
1608 Mf::Sqrt => MathOp::Ext(spirv::GlslStd450Op::Sqrt),
1609 Mf::InverseSqrt => MathOp::Ext(spirv::GlslStd450Op::InverseSqrt),
1610 Mf::Inverse => MathOp::Ext(spirv::GlslStd450Op::MatrixInverse),
1611 Mf::Transpose => MathOp::Custom(Instruction::unary(
1612 spirv::Op::Transpose,
1613 result_type_id,
1614 id,
1615 arg0_id,
1616 )),
1617 Mf::Determinant => MathOp::Ext(spirv::GlslStd450Op::Determinant),
1618 Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
1619 spirv::Op::QuantizeToF16,
1620 result_type_id,
1621 id,
1622 arg0_id,
1623 )),
1624 Mf::ReverseBits => MathOp::Custom(Instruction::unary(
1625 spirv::Op::BitReverse,
1626 result_type_id,
1627 id,
1628 arg0_id,
1629 )),
1630 Mf::CountTrailingZeros => {
1631 let uint_id = match *arg_ty {
1632 crate::TypeInner::Vector { size, scalar } => {
1633 let ty =
1634 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1635
1636 self.temp_list.clear();
1637 self.temp_list.resize(
1638 size as _,
1639 self.writer
1640 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1641 );
1642
1643 self.writer.get_constant_composite(ty, &self.temp_list)
1644 }
1645 crate::TypeInner::Scalar(scalar) => self
1646 .writer
1647 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1648 _ => unreachable!(),
1649 };
1650
1651 let lsb_id = self.gen_id();
1652 block.body.push(Instruction::ext_inst_gl_op(
1653 self.writer.gl450_ext_inst_id,
1654 spirv::GlslStd450Op::FindILsb,
1655 result_type_id,
1656 lsb_id,
1657 &[arg0_id],
1658 ));
1659
1660 MathOp::Custom(Instruction::ext_inst_gl_op(
1661 self.writer.gl450_ext_inst_id,
1662 spirv::GlslStd450Op::UMin,
1663 result_type_id,
1664 id,
1665 &[uint_id, lsb_id],
1666 ))
1667 }
1668 Mf::CountLeadingZeros => {
1669 let (int_type_id, int_id, width) = match *arg_ty {
1670 crate::TypeInner::Vector { size, scalar } => {
1671 let ty =
1672 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1673
1674 self.temp_list.clear();
1675 self.temp_list.resize(
1676 size as _,
1677 self.writer
1678 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1679 );
1680
1681 (
1682 self.get_type_id(ty),
1683 self.writer.get_constant_composite(ty, &self.temp_list),
1684 scalar.width,
1685 )
1686 }
1687 crate::TypeInner::Scalar(scalar) => (
1688 self.get_numeric_type_id(NumericType::Scalar(scalar)),
1689 self.writer
1690 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1691 scalar.width,
1692 ),
1693 _ => unreachable!(),
1694 };
1695
1696 if width != 4 {
1697 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1698 };
1699
1700 let msb_id = self.gen_id();
1701 block.body.push(Instruction::ext_inst_gl_op(
1702 self.writer.gl450_ext_inst_id,
1703 if width != 4 {
1704 spirv::GlslStd450Op::FindILsb
1705 } else {
1706 spirv::GlslStd450Op::FindUMsb
1707 },
1708 int_type_id,
1709 msb_id,
1710 &[arg0_id],
1711 ));
1712
1713 MathOp::Custom(Instruction::binary(
1714 spirv::Op::ISub,
1715 result_type_id,
1716 id,
1717 int_id,
1718 msb_id,
1719 ))
1720 }
1721 Mf::CountOneBits => MathOp::Custom(Instruction::unary(
1722 spirv::Op::BitCount,
1723 result_type_id,
1724 id,
1725 arg0_id,
1726 )),
1727 Mf::ExtractBits => {
1728 let op = match arg_scalar_kind {
1729 Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
1730 Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
1731 other => unimplemented!("Unexpected sign({:?})", other),
1732 };
1733
1734 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1749 let width_constant = self
1750 .writer
1751 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1752
1753 let u32_type =
1754 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1755
1756 let offset_id = self.gen_id();
1758 block.body.push(Instruction::ext_inst_gl_op(
1759 self.writer.gl450_ext_inst_id,
1760 spirv::GlslStd450Op::UMin,
1761 u32_type,
1762 offset_id,
1763 &[arg1_id, width_constant],
1764 ));
1765
1766 let max_count_id = self.gen_id();
1768 block.body.push(Instruction::binary(
1769 spirv::Op::ISub,
1770 u32_type,
1771 max_count_id,
1772 width_constant,
1773 offset_id,
1774 ));
1775
1776 let count_id = self.gen_id();
1778 block.body.push(Instruction::ext_inst_gl_op(
1779 self.writer.gl450_ext_inst_id,
1780 spirv::GlslStd450Op::UMin,
1781 u32_type,
1782 count_id,
1783 &[arg2_id, max_count_id],
1784 ));
1785
1786 MathOp::Custom(Instruction::ternary(
1787 op,
1788 result_type_id,
1789 id,
1790 arg0_id,
1791 offset_id,
1792 count_id,
1793 ))
1794 }
1795 Mf::InsertBits => {
1796 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1799 let width_constant = self
1800 .writer
1801 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1802
1803 let u32_type =
1804 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1805
1806 let offset_id = self.gen_id();
1808 block.body.push(Instruction::ext_inst_gl_op(
1809 self.writer.gl450_ext_inst_id,
1810 spirv::GlslStd450Op::UMin,
1811 u32_type,
1812 offset_id,
1813 &[arg2_id, width_constant],
1814 ));
1815
1816 let max_count_id = self.gen_id();
1818 block.body.push(Instruction::binary(
1819 spirv::Op::ISub,
1820 u32_type,
1821 max_count_id,
1822 width_constant,
1823 offset_id,
1824 ));
1825
1826 let count_id = self.gen_id();
1828 block.body.push(Instruction::ext_inst_gl_op(
1829 self.writer.gl450_ext_inst_id,
1830 spirv::GlslStd450Op::UMin,
1831 u32_type,
1832 count_id,
1833 &[arg3_id, max_count_id],
1834 ));
1835
1836 MathOp::Custom(Instruction::quaternary(
1837 spirv::Op::BitFieldInsert,
1838 result_type_id,
1839 id,
1840 arg0_id,
1841 arg1_id,
1842 offset_id,
1843 count_id,
1844 ))
1845 }
1846 Mf::FirstTrailingBit => MathOp::Ext(spirv::GlslStd450Op::FindILsb),
1847 Mf::FirstLeadingBit => {
1848 if arg_ty.scalar_width() == Some(4) {
1849 let thing = match arg_scalar_kind {
1850 Some(crate::ScalarKind::Uint) => spirv::GlslStd450Op::FindUMsb,
1851 Some(crate::ScalarKind::Sint) => spirv::GlslStd450Op::FindSMsb,
1852 other => unimplemented!("Unexpected firstLeadingBit({:?})", other),
1853 };
1854 MathOp::Ext(thing)
1855 } else {
1856 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1857 }
1858 }
1859 Mf::Pack4x8unorm => MathOp::Ext(spirv::GlslStd450Op::PackUnorm4x8),
1860 Mf::Pack4x8snorm => MathOp::Ext(spirv::GlslStd450Op::PackSnorm4x8),
1861 Mf::Pack2x16float => MathOp::Ext(spirv::GlslStd450Op::PackHalf2x16),
1862 Mf::Pack2x16unorm => MathOp::Ext(spirv::GlslStd450Op::PackUnorm2x16),
1863 Mf::Pack2x16snorm => MathOp::Ext(spirv::GlslStd450Op::PackSnorm2x16),
1864 fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => {
1865 let is_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp);
1866 let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp);
1867
1868 let last_instruction =
1869 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1870 self.write_pack4x8_optimized(
1871 block,
1872 result_type_id,
1873 arg0_id,
1874 id,
1875 is_signed,
1876 should_clamp,
1877 )
1878 } else {
1879 self.write_pack4x8_polyfill(
1880 block,
1881 result_type_id,
1882 arg0_id,
1883 id,
1884 is_signed,
1885 should_clamp,
1886 )
1887 };
1888
1889 MathOp::Custom(last_instruction)
1890 }
1891 Mf::Unpack4x8unorm => MathOp::Ext(spirv::GlslStd450Op::UnpackUnorm4x8),
1892 Mf::Unpack4x8snorm => MathOp::Ext(spirv::GlslStd450Op::UnpackSnorm4x8),
1893 Mf::Unpack2x16float => MathOp::Ext(spirv::GlslStd450Op::UnpackHalf2x16),
1894 Mf::Unpack2x16unorm => MathOp::Ext(spirv::GlslStd450Op::UnpackUnorm2x16),
1895 Mf::Unpack2x16snorm => MathOp::Ext(spirv::GlslStd450Op::UnpackSnorm2x16),
1896 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
1897 let is_signed = matches!(fun, Mf::Unpack4xI8);
1898
1899 let last_instruction =
1900 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1901 self.write_unpack4x8_optimized(
1902 block,
1903 result_type_id,
1904 arg0_id,
1905 id,
1906 is_signed,
1907 )
1908 } else {
1909 self.write_unpack4x8_polyfill(
1910 block,
1911 result_type_id,
1912 arg0_id,
1913 id,
1914 is_signed,
1915 )
1916 };
1917
1918 MathOp::Custom(last_instruction)
1919 }
1920 };
1921
1922 block.body.push(match math_op {
1923 MathOp::Ext(op) => Instruction::ext_inst_gl_op(
1924 self.writer.gl450_ext_inst_id,
1925 op,
1926 result_type_id,
1927 id,
1928 &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
1929 ),
1930 MathOp::Custom(inst) => inst,
1931 });
1932 id
1933 }
1934 crate::Expression::LocalVariable(variable) => {
1935 if let Some(rq_tracker) = self
1936 .function
1937 .ray_query_initialization_tracker_variables
1938 .get(&variable)
1939 {
1940 self.ray_query_tracker_expr.insert(
1941 expr_handle,
1942 super::RayQueryTrackers {
1943 initialized_tracker: rq_tracker.id,
1944 t_max_tracker: self
1945 .function
1946 .ray_query_t_max_tracker_variables
1947 .get(&variable)
1948 .expect("Both trackers are set at the same time.")
1949 .id,
1950 },
1951 );
1952 }
1953 self.function.variables[&variable].id
1954 }
1955 crate::Expression::Load { pointer } => {
1956 self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)?
1957 }
1958 crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
1959 crate::Expression::CallResult(_)
1960 | crate::Expression::AtomicResult { .. }
1961 | crate::Expression::WorkGroupUniformLoadResult { .. }
1962 | crate::Expression::RayQueryProceedResult
1963 | crate::Expression::SubgroupBallotResult
1964 | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
1965 crate::Expression::As {
1966 expr,
1967 kind,
1968 convert,
1969 } => self.write_as_expression(expr, convert, kind, block, result_type_id)?,
1970 crate::Expression::ImageLoad {
1971 image,
1972 coordinate,
1973 array_index,
1974 sample,
1975 level,
1976 } => self.write_image_load(
1977 result_type_id,
1978 image,
1979 coordinate,
1980 array_index,
1981 level,
1982 sample,
1983 block,
1984 )?,
1985 crate::Expression::ImageSample {
1986 image,
1987 sampler,
1988 gather,
1989 coordinate,
1990 array_index,
1991 offset,
1992 level,
1993 depth_ref,
1994 clamp_to_edge,
1995 } => self.write_image_sample(
1996 result_type_id,
1997 image,
1998 sampler,
1999 gather,
2000 coordinate,
2001 array_index,
2002 offset,
2003 level,
2004 depth_ref,
2005 clamp_to_edge,
2006 block,
2007 )?,
2008 crate::Expression::Select {
2009 condition,
2010 accept,
2011 reject,
2012 } => {
2013 let id = self.gen_id();
2014 let mut condition_id = self.cached[condition];
2015 let accept_id = self.cached[accept];
2016 let reject_id = self.cached[reject];
2017
2018 let condition_ty = self.fun_info[condition]
2019 .ty
2020 .inner_with(&self.ir_module.types);
2021 let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
2022
2023 if let (
2024 &crate::TypeInner::Scalar(
2025 condition_scalar @ crate::Scalar {
2026 kind: crate::ScalarKind::Bool,
2027 ..
2028 },
2029 ),
2030 &crate::TypeInner::Vector { size, .. },
2031 ) = (condition_ty, object_ty)
2032 {
2033 self.temp_list.clear();
2034 self.temp_list.resize(size as usize, condition_id);
2035
2036 let bool_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2037 size,
2038 scalar: condition_scalar,
2039 });
2040
2041 let id = self.gen_id();
2042 block.body.push(Instruction::composite_construct(
2043 bool_vector_type_id,
2044 id,
2045 &self.temp_list,
2046 ));
2047 condition_id = id
2048 }
2049
2050 let instruction =
2051 Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
2052 block.body.push(instruction);
2053 id
2054 }
2055 crate::Expression::Derivative { axis, ctrl, expr } => {
2056 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
2057 match ctrl {
2058 Ctrl::Coarse | Ctrl::Fine => {
2059 self.writer.require_any(
2060 "DerivativeControl",
2061 &[spirv::Capability::DerivativeControl],
2062 )?;
2063 }
2064 Ctrl::None => {}
2065 }
2066 let id = self.gen_id();
2067 let expr_id = self.cached[expr];
2068 let op = match (axis, ctrl) {
2069 (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
2070 (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
2071 (Axis::X, Ctrl::None) => spirv::Op::DPdx,
2072 (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
2073 (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
2074 (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
2075 (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
2076 (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
2077 (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
2078 };
2079 block
2080 .body
2081 .push(Instruction::derivative(op, result_type_id, id, expr_id));
2082 id
2083 }
2084 crate::Expression::ImageQuery { image, query } => {
2085 self.write_image_query(result_type_id, image, query, block)?
2086 }
2087 crate::Expression::Relational { fun, argument } => {
2088 use crate::RelationalFunction as Rf;
2089 let arg_id = self.cached[argument];
2090 let op = match fun {
2091 Rf::All => spirv::Op::All,
2092 Rf::Any => spirv::Op::Any,
2093 Rf::IsNan => spirv::Op::IsNan,
2094 Rf::IsInf => spirv::Op::IsInf,
2095 };
2096 let id = self.gen_id();
2097 block
2098 .body
2099 .push(Instruction::relational(op, result_type_id, id, arg_id));
2100 id
2101 }
2102 crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
2103 crate::Expression::RayQueryGetIntersection { query, committed } => {
2104 let query_id = self.cached[query];
2105 let init_tracker_id = *self
2106 .ray_query_tracker_expr
2107 .get(&query)
2108 .expect("not a cached ray query");
2109 let func_id = self
2110 .writer
2111 .write_ray_query_get_intersection_function(committed, self.ir_module);
2112 let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap();
2113 let intersection_type_id = self.get_handle_type_id(ray_intersection);
2114 let id = self.gen_id();
2115 block.body.push(Instruction::function_call(
2116 intersection_type_id,
2117 id,
2118 func_id,
2119 &[query_id, init_tracker_id.initialized_tracker],
2120 ));
2121 id
2122 }
2123 crate::Expression::RayQueryVertexPositions { query, committed } => {
2124 self.writer.require_any(
2125 "RayQueryVertexPositions",
2126 &[spirv::Capability::RayQueryPositionFetchKHR],
2127 )?;
2128 self.write_ray_query_return_vertex_position(query, block, committed)
2129 }
2130 crate::Expression::CooperativeLoad { ref data, .. } => {
2131 self.writer.require_any(
2132 "CooperativeMatrix",
2133 &[spirv::Capability::CooperativeMatrixKHR],
2134 )?;
2135 let layout = if data.row_major {
2136 spirv::CooperativeMatrixLayout::RowMajorKHR
2137 } else {
2138 spirv::CooperativeMatrixLayout::ColumnMajorKHR
2139 };
2140 let layout_id = self.get_index_constant(layout as u32);
2141 let stride_id = self.cached[data.stride];
2142 match self.write_access_chain(data.pointer, block, AccessTypeAdjustment::None)? {
2143 ExpressionPointer::Ready { pointer_id } => {
2144 let id = self.gen_id();
2145 block.body.push(Instruction::coop_load(
2146 result_type_id,
2147 id,
2148 pointer_id,
2149 layout_id,
2150 stride_id,
2151 ));
2152 id
2153 }
2154 ExpressionPointer::Conditional { condition, access } => self
2155 .write_conditional_indexed_load(
2156 result_type_id,
2157 condition,
2158 block,
2159 |id_gen, block| {
2160 let pointer_id = access.result_id.unwrap();
2161 block.body.push(access);
2162 let id = id_gen.next();
2163 block.body.push(Instruction::coop_load(
2164 result_type_id,
2165 id,
2166 pointer_id,
2167 layout_id,
2168 stride_id,
2169 ));
2170 id
2171 },
2172 ),
2173 }
2174 }
2175 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2176 self.writer.require_any(
2177 "CooperativeMatrix",
2178 &[spirv::Capability::CooperativeMatrixKHR],
2179 )?;
2180 let a_id = self.cached[a];
2181 let b_id = self.cached[b];
2182 let c_id = self.cached[c];
2183 let id = self.gen_id();
2184 block.body.push(Instruction::coop_mul_add(
2185 result_type_id,
2186 id,
2187 a_id,
2188 b_id,
2189 c_id,
2190 ));
2191 id
2192 }
2193 };
2194
2195 self.cached[expr_handle] = id;
2196 Ok(())
2197 }
2198
2199 fn write_as_expression(
2202 &mut self,
2203 expr: Handle<crate::Expression>,
2204 convert: Option<u8>,
2205 kind: crate::ScalarKind,
2206
2207 block: &mut Block,
2208 result_type_id: u32,
2209 ) -> Result<u32, Error> {
2210 use crate::ScalarKind as Sk;
2211 let expr_id = self.cached[expr];
2212 let ty = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
2213
2214 if let crate::TypeInner::Matrix {
2219 columns,
2220 rows,
2221 scalar,
2222 } = *ty
2223 {
2224 let Some(convert) = convert else {
2225 return Ok(expr_id);
2227 };
2228
2229 if convert == scalar.width {
2230 return Ok(expr_id);
2232 }
2233
2234 if kind != Sk::Float {
2235 return Err(Error::Validation("Matrices must be floats"));
2237 }
2238
2239 let column_src_ty =
2241 self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
2242 size: rows,
2243 scalar,
2244 })));
2245
2246 let column_dst_ty =
2248 self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
2249 size: rows,
2250 scalar: crate::Scalar {
2251 kind,
2252 width: convert,
2253 },
2254 })));
2255
2256 let mut components = ArrayVec::<Word, 4>::new();
2257
2258 for column in 0..columns as usize {
2259 let column_id = self.gen_id();
2260 block.body.push(Instruction::composite_extract(
2261 column_src_ty,
2262 column_id,
2263 expr_id,
2264 &[column as u32],
2265 ));
2266
2267 let column_conv_id = self.gen_id();
2268 block.body.push(Instruction::unary(
2269 spirv::Op::FConvert,
2270 column_dst_ty,
2271 column_conv_id,
2272 column_id,
2273 ));
2274
2275 components.push(column_conv_id);
2276 }
2277
2278 let construct_id = self.gen_id();
2279
2280 block.body.push(Instruction::composite_construct(
2281 result_type_id,
2282 construct_id,
2283 &components,
2284 ));
2285
2286 return Ok(construct_id);
2287 }
2288
2289 let (src_scalar, src_size) = match *ty {
2290 crate::TypeInner::Scalar(scalar) => (scalar, None),
2291 crate::TypeInner::Vector { scalar, size } => (scalar, Some(size)),
2292 ref other => {
2293 log::error!("As source {other:?}");
2294 return Err(Error::Validation("Unexpected Expression::As source"));
2295 }
2296 };
2297
2298 enum Cast {
2299 Identity(Word),
2300 Unary(spirv::Op, Word),
2301 Binary(spirv::Op, Word, Word),
2302 Ternary(spirv::Op, Word, Word, Word),
2303 }
2304 let cast = match (src_scalar.kind, kind, convert) {
2305 (src_kind, kind, convert)
2308 if src_kind == kind
2309 && convert.filter(|&width| width != src_scalar.width).is_none() =>
2310 {
2311 Cast::Identity(expr_id)
2312 }
2313 (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject, expr_id),
2314 (_, _, None) => Cast::Unary(spirv::Op::Bitcast, expr_id),
2315 (_, Sk::Bool, Some(_)) => {
2317 let op = match src_scalar.kind {
2318 Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
2319 Sk::Float => spirv::Op::FUnordNotEqual,
2320 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(),
2321 };
2322 let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?;
2323 let zero_id = match src_size {
2324 Some(size) => {
2325 let ty = LocalType::Numeric(NumericType::Vector {
2326 size,
2327 scalar: src_scalar,
2328 })
2329 .into();
2330
2331 self.temp_list.clear();
2332 self.temp_list.resize(size as _, zero_scalar_id);
2333
2334 self.writer.get_constant_composite(ty, &self.temp_list)
2335 }
2336 None => zero_scalar_id,
2337 };
2338
2339 Cast::Binary(op, expr_id, zero_id)
2340 }
2341 (Sk::Bool, _, Some(dst_width)) => {
2343 let dst_scalar = crate::Scalar {
2344 kind,
2345 width: dst_width,
2346 };
2347 let zero_scalar_id = self.writer.get_constant_scalar_with(0, dst_scalar)?;
2348 let one_scalar_id = self.writer.get_constant_scalar_with(1, dst_scalar)?;
2349 let (accept_id, reject_id) = match src_size {
2350 Some(size) => {
2351 let ty = LocalType::Numeric(NumericType::Vector {
2352 size,
2353 scalar: dst_scalar,
2354 })
2355 .into();
2356
2357 self.temp_list.clear();
2358 self.temp_list.resize(size as _, zero_scalar_id);
2359
2360 let vec0_id = self.writer.get_constant_composite(ty, &self.temp_list);
2361
2362 self.temp_list.fill(one_scalar_id);
2363
2364 let vec1_id = self.writer.get_constant_composite(ty, &self.temp_list);
2365
2366 (vec1_id, vec0_id)
2367 }
2368 None => (one_scalar_id, zero_scalar_id),
2369 };
2370
2371 Cast::Ternary(spirv::Op::Select, expr_id, accept_id, reject_id)
2372 }
2373 (Sk::Float, Sk::Sint | Sk::Uint, Some(width)) => {
2384 let dst_scalar = crate::Scalar { kind, width };
2385 let (min, max) =
2386 crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
2387 let expr_type_id = self.get_expression_type_id(&self.fun_info[expr].ty);
2388
2389 let maybe_splat_const = |writer: &mut Writer, const_id| match src_size {
2390 None => const_id,
2391 Some(size) => {
2392 let constituent_ids = [const_id; crate::VectorSize::MAX];
2393 writer.get_constant_composite(
2394 LookupType::Local(LocalType::Numeric(NumericType::Vector {
2395 size,
2396 scalar: src_scalar,
2397 })),
2398 &constituent_ids[..size as usize],
2399 )
2400 }
2401 };
2402 let min_const_id = self.writer.get_constant_scalar(min);
2403 let min_const_id = maybe_splat_const(self.writer, min_const_id);
2404 let max_const_id = self.writer.get_constant_scalar(max);
2405 let max_const_id = maybe_splat_const(self.writer, max_const_id);
2406
2407 let clamp_id = self.gen_id();
2408 block.body.push(Instruction::ext_inst_gl_op(
2409 self.writer.gl450_ext_inst_id,
2410 spirv::GlslStd450Op::FClamp,
2411 expr_type_id,
2412 clamp_id,
2413 &[expr_id, min_const_id, max_const_id],
2414 ));
2415
2416 let op = match dst_scalar.kind {
2417 crate::ScalarKind::Sint => spirv::Op::ConvertFToS,
2418 crate::ScalarKind::Uint => spirv::Op::ConvertFToU,
2419 _ => unreachable!(),
2420 };
2421 Cast::Unary(op, clamp_id)
2422 }
2423 (Sk::Float, Sk::Float, Some(dst_width)) if src_scalar.width != dst_width => {
2424 Cast::Unary(spirv::Op::FConvert, expr_id)
2425 }
2426 (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF, expr_id),
2427 (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2428 Cast::Unary(spirv::Op::SConvert, expr_id)
2429 }
2430 (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF, expr_id),
2431 (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2432 Cast::Unary(spirv::Op::UConvert, expr_id)
2433 }
2434 (Sk::Uint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2435 Cast::Unary(spirv::Op::SConvert, expr_id)
2436 }
2437 (Sk::Sint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2438 Cast::Unary(spirv::Op::UConvert, expr_id)
2439 }
2440 _ => Cast::Unary(spirv::Op::Bitcast, expr_id),
2442 };
2443 Ok(match cast {
2444 Cast::Identity(expr) => expr,
2445 Cast::Unary(op, op1) => {
2446 let id = self.gen_id();
2447 block
2448 .body
2449 .push(Instruction::unary(op, result_type_id, id, op1));
2450 id
2451 }
2452 Cast::Binary(op, op1, op2) => {
2453 let id = self.gen_id();
2454 block
2455 .body
2456 .push(Instruction::binary(op, result_type_id, id, op1, op2));
2457 id
2458 }
2459 Cast::Ternary(op, op1, op2, op3) => {
2460 let id = self.gen_id();
2461 block
2462 .body
2463 .push(Instruction::ternary(op, result_type_id, id, op1, op2, op3));
2464 id
2465 }
2466 })
2467 }
2468
    /// Emit the access chain for the pointer expression `expr_handle`.
    ///
    /// Walks `Access` / `AccessIndex` expressions from `expr_handle` down to a
    /// root pointer. The root is one of: the temporary variable of a spilled
    /// composite, a global variable, a local variable, or a function argument.
    /// Index ids are collected in `self.temp_list` outermost-first and
    /// reversed before the `OpAccessChain` is built.
    ///
    /// `type_adjustment` selects the result type of the access chain:
    /// - `None`: the expression's own type;
    /// - `IntroducePointer(class)`: a pointer in `class` to the expression's type;
    /// - `UseStd140CompatType`: a uniform-class pointer to the std140
    ///   compatible variant of the pointee type.
    ///
    /// Returns:
    /// - `ExpressionPointer::Ready { pointer_id: root }` when no indices are needed;
    /// - `ExpressionPointer::Ready` with a fresh id when indices are needed and
    ///   no bounds checks were accumulated; the `OpAccessChain` is pushed to `block`;
    /// - `ExpressionPointer::Conditional` when bounds checks were accumulated.
    ///   The `OpAccessChain` is NOT pushed; the caller must emit `access`
    ///   guarded by `condition`.
    ///
    /// If an `Access` step indexes a binding-array global with a non-uniform
    /// index, the resulting pointer id is passed to
    /// `decorate_non_uniform_binding_array_access`.
    ///
    /// # Panics
    /// - `UseStd140CompatType` used on an expression that is not a uniform pointer.
    /// - An unexpected expression kind in the chain (`unimplemented!`).
    fn write_access_chain(
        &mut self,
        mut expr_handle: Handle<crate::Expression>,
        block: &mut Block,
        type_adjustment: AccessTypeAdjustment,
    ) -> Result<ExpressionPointer, Error> {
        // Result type of the final `OpAccessChain`. It is computed from the
        // outermost expression, before `expr_handle` is walked toward the root.
        let result_type_id = {
            let resolution = &self.fun_info[expr_handle].ty;
            match type_adjustment {
                AccessTypeAdjustment::None => self.writer.get_expression_type_id(resolution),
                AccessTypeAdjustment::IntroducePointer(class) => {
                    self.writer.get_resolution_pointer_id(resolution, class)
                }
                AccessTypeAdjustment::UseStd140CompatType => {
                    match *resolution.inner_with(&self.ir_module.types) {
                        crate::TypeInner::Pointer {
                            base,
                            space: space @ crate::AddressSpace::Uniform,
                        } => self.writer.get_pointer_type_id(
                            self.writer.std140_compat_uniform_types[&base].type_id,
                            map_storage_class(space),
                        ),
                        _ => unreachable!(
                            "`UseStd140CompatType` must only be used with uniform pointer types"
                        ),
                    }
                }
            }
        };

        // Id of the conjunction of all bounds-check conditions emitted so far,
        // or `None` if no step required a conditional check.
        let mut accumulated_checks = None;

        // Becomes true once any `Access` step indexes a binding array with a
        // non-uniform index.
        let mut is_non_uniform_binding_array = false;

        // Holds the column index of an `AccessIndex` whose base is a uniform
        // `matCx2` struct member (per `is_uniform_matcx2_struct_member_access`).
        // On the next iteration the enclosing struct is reached, and this value
        // is added to the std140 member index there.
        let mut prev_decomposed_matrix_index = None;

        self.temp_list.clear();
        let root_id = loop {
            // A spilled composite is accessed through its temporary variable,
            // which then serves as the root of the chain.
            if let Some(spilled) = self.function.spilled_composites.get(&expr_handle) {
                break spilled.id;
            }

            expr_handle = match self.ir_function.expressions[expr_handle] {
                crate::Expression::Access { base, index } => {
                    is_non_uniform_binding_array |=
                        self.is_nonuniform_binding_array_access(base, index);

                    let index = GuardedIndex::Expression(index);
                    let index_id =
                        self.write_access_chain_index(base, index, &mut accumulated_checks, block)?;
                    self.temp_list.push(index_id);

                    base
                }
                crate::Expression::AccessIndex { base, index } => {
                    // Look through a pointer to the pointee type, remembering
                    // its handle and address space for the std140 remapping.
                    let mut base_ty = self.fun_info[base].ty.inner_with(&self.ir_module.types);
                    let mut base_ty_handle = self.fun_info[base].ty.handle();
                    let mut pointer_space = None;
                    if let crate::TypeInner::Pointer { base, space } = *base_ty {
                        base_ty = &self.ir_module.types[base].inner;
                        base_ty_handle = Some(base);
                        pointer_space = Some(space);
                    }
                    match *base_ty {
                        // Struct member: no bounds check. Uniform structs with a
                        // std140 compatible variant use the remapped member index,
                        // plus any pending decomposed-matrix column index.
                        crate::TypeInner::Struct { .. } => {
                            let index = match base_ty_handle.and_then(|handle| {
                                self.writer.std140_compat_uniform_types.get(&handle)
                            }) {
                                Some(std140_type_info)
                                    if pointer_space == Some(crate::AddressSpace::Uniform) =>
                                {
                                    std140_type_info.member_indices[index as usize]
                                        + prev_decomposed_matrix_index.take().unwrap_or(0)
                                }
                                _ => index,
                            };
                            let index_id = self.get_index_constant(index);
                            self.temp_list.push(index_id);
                        }
                        // Column of a uniform `matCx2` struct member: push nothing
                        // now; the column is folded into the struct member index
                        // on the next iteration.
                        _ if is_uniform_matcx2_struct_member_access(
                            self.ir_function,
                            self.fun_info,
                            self.ir_module,
                            base,
                        ) =>
                        {
                            assert!(prev_decomposed_matrix_index.is_none());
                            prev_decomposed_matrix_index = Some(index);
                        }
                        // Any other constant index still goes through the
                        // bounds-check policy.
                        _ => {
                            let index_id = self.write_access_chain_index(
                                base,
                                GuardedIndex::Known(index),
                                &mut accumulated_checks,
                                block,
                            )?;
                            self.temp_list.push(index_id);
                        }
                    }
                    base
                }
                crate::Expression::GlobalVariable(handle) => {
                    let gv = &self.writer.global_variables[handle];
                    break gv.access_id;
                }
                crate::Expression::LocalVariable(variable) => {
                    let local_var = &self.function.variables[&variable];
                    break local_var.id;
                }
                crate::Expression::FunctionArgument(index) => {
                    break self.function.parameter_id(index);
                }
                ref other => unimplemented!("Unexpected pointer expression {:?}", other),
            }
        };

        let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
            (
                root_id,
                ExpressionPointer::Ready {
                    pointer_id: root_id,
                },
            )
        } else {
            // Indices were collected outermost-first; `OpAccessChain` wants them
            // root-first.
            self.temp_list.reverse();
            let pointer_id = self.gen_id();
            let access =
                Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);

            // With pending bounds checks, the caller emits the guarded access.
            // Without them, the access can be emitted right away.
            let expr_pointer = match accumulated_checks {
                Some(condition) => ExpressionPointer::Conditional { condition, access },
                None => {
                    block.body.push(access);
                    ExpressionPointer::Ready { pointer_id }
                }
            };
            (pointer_id, expr_pointer)
        };
        if is_non_uniform_binding_array {
            self.writer
                .decorate_non_uniform_binding_array_access(pointer_id)?;
        }

        Ok(expr_pointer)
    }
2664
2665 fn is_nonuniform_binding_array_access(
2666 &mut self,
2667 base: Handle<crate::Expression>,
2668 index: Handle<crate::Expression>,
2669 ) -> bool {
2670 let crate::Expression::GlobalVariable(var_handle) = self.ir_function.expressions[base]
2671 else {
2672 return false;
2673 };
2674
2675 let gvar = &self.ir_module.global_variables[var_handle];
2678 let crate::TypeInner::BindingArray { .. } = self.ir_module.types[gvar.ty].inner else {
2679 return false;
2680 };
2681
2682 self.fun_info[index].uniformity.non_uniform_result.is_some()
2683 }
2684
2685 fn write_access_chain_index(
2695 &mut self,
2696 base: Handle<crate::Expression>,
2697 index: GuardedIndex,
2698 accumulated_checks: &mut Option<Word>,
2699 block: &mut Block,
2700 ) -> Result<Word, Error> {
2701 match self.write_bounds_check(base, index, block)? {
2702 BoundsCheckResult::KnownInBounds(known_index) => {
2703 let scalar = crate::Literal::U32(known_index);
2706 Ok(self.writer.get_constant_scalar(scalar))
2707 }
2708 BoundsCheckResult::Computed(computed_index_id) => Ok(computed_index_id),
2709 BoundsCheckResult::Conditional {
2710 condition_id: condition,
2711 index_id: index,
2712 } => {
2713 self.extend_bounds_check_condition_chain(accumulated_checks, condition, block);
2714
2715 Ok(index)
2717 }
2718 }
2719 }
2720
2721 fn extend_bounds_check_condition_chain(
2740 &mut self,
2741 chain: &mut Option<Word>,
2742 comparison_id: Word,
2743 block: &mut Block,
2744 ) {
2745 match *chain {
2746 Some(ref mut prior_checks) => {
2747 let combined = self.gen_id();
2748 block.body.push(Instruction::binary(
2749 spirv::Op::LogicalAnd,
2750 self.writer.get_bool_type_id(),
2751 combined,
2752 *prior_checks,
2753 comparison_id,
2754 ));
2755 *prior_checks = combined;
2756 }
2757 None => {
2758 *chain = Some(comparison_id);
2760 }
2761 }
2762 }
2763
    /// Load the value `pointer` points to, honoring the bounds-check policy.
    ///
    /// The uniform `matCx2` special paths are tried first:
    /// `maybe_write_uniform_matcx2_dynamic_access`, then
    /// `maybe_write_load_uniform_matcx2_struct_member`. Otherwise:
    /// - If `pointer` is a uniform-space pointer whose pointee has an entry in
    ///   `std140_compat_uniform_types`, the access chain and the load use the
    ///   std140 compatible type. The loaded value is then passed to the
    ///   matching `ConvertFromStd140CompatType` wrapper function to produce a
    ///   value of `result_type_id`.
    /// - On the unconditional path, atomic pointees are read with
    ///   `OpAtomicLoad`; all other pointees use `OpLoad`.
    /// - If the access chain needs runtime bounds checks, the access and a
    ///   plain `OpLoad` are emitted through `write_conditional_indexed_load`.
    ///
    /// Returns the id of the loaded (and, if needed, converted) value.
    fn write_checked_load(
        &mut self,
        pointer: Handle<crate::Expression>,
        block: &mut Block,
        access_type_adjustment: AccessTypeAdjustment,
        result_type_id: Word,
    ) -> Result<Word, Error> {
        if let Some(result_id) = self.maybe_write_uniform_matcx2_dynamic_access(pointer, block)? {
            Ok(result_id)
        } else if let Some(result_id) =
            self.maybe_write_load_uniform_matcx2_struct_member(pointer, block)?
        {
            Ok(result_id)
        } else {
            // Describes a load that must go through a std140 compatible type and
            // then be converted back to the regular type.
            struct WrappedLoad {
                access_type_adjustment: AccessTypeAdjustment,
                r#type: Handle<crate::Type>,
            }
            let mut wrapped_load = None;
            if let crate::TypeInner::Pointer {
                base: pointer_base_type,
                space: crate::AddressSpace::Uniform,
            } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
            {
                if self
                    .writer
                    .std140_compat_uniform_types
                    .contains_key(&pointer_base_type)
                {
                    wrapped_load = Some(WrappedLoad {
                        access_type_adjustment: AccessTypeAdjustment::UseStd140CompatType,
                        r#type: pointer_base_type,
                    });
                };
            };

            // For a wrapped load, both the access chain and the `OpLoad` use the
            // std140 compatible type. Otherwise the caller's choices are used.
            let (load_type_id, access_type_adjustment) = match wrapped_load {
                Some(ref wrapped_load) => (
                    self.writer.std140_compat_uniform_types[&wrapped_load.r#type].type_id,
                    wrapped_load.access_type_adjustment,
                ),
                None => (result_type_id, access_type_adjustment),
            };

            let load_id = match self.write_access_chain(pointer, block, access_type_adjustment)? {
                ExpressionPointer::Ready { pointer_id } => {
                    let id = self.gen_id();
                    // `Some(space)` when the pointee is an atomic type.
                    let atomic_space =
                        match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
                            crate::TypeInner::Pointer { base, space } => {
                                match self.ir_module.types[base].inner {
                                    crate::TypeInner::Atomic { .. } => Some(space),
                                    _ => None,
                                }
                            }
                            _ => None,
                        };
                    let instruction = if let Some(space) = atomic_space {
                        let (semantics, scope) = space.to_spirv_semantics_and_scope();
                        let scope_constant_id = self.get_scope_constant(scope as u32);
                        let semantics_id = self.get_index_constant(semantics.bits());
                        Instruction::atomic_load(
                            result_type_id,
                            id,
                            pointer_id,
                            scope_constant_id,
                            semantics_id,
                        )
                    } else {
                        Instruction::load(load_type_id, id, pointer_id, None)
                    };
                    block.body.push(instruction);
                    id
                }
                ExpressionPointer::Conditional { condition, access } => {
                    // NOTE(review): this path always emits a plain `OpLoad`, even
                    // when the pointee is atomic (unlike the `Ready` path above).
                    self.write_conditional_indexed_load(
                        load_type_id,
                        condition,
                        block,
                        move |id_gen, block| {
                            // In-bounds branch: emit the deferred access chain,
                            // then load through it.
                            let pointer_id = access.result_id.unwrap();
                            let value_id = id_gen.next();
                            block.body.push(access);
                            block.body.push(Instruction::load(
                                load_type_id,
                                value_id,
                                pointer_id,
                                None,
                            ));
                            value_id
                        },
                    )
                }
            };

            match wrapped_load {
                Some(ref wrapped_load) => {
                    // Convert the std140 compatible value back to the regular type.
                    let result_id = self.gen_id();
                    let function_id = self.writer.wrapped_functions
                        [&WrappedFunction::ConvertFromStd140CompatType {
                            r#type: wrapped_load.r#type,
                        }];
                    block.body.push(Instruction::function_call(
                        result_type_id,
                        result_id,
                        function_id,
                        &[load_id],
                    ));
                    Ok(result_id)
                }
                None => Ok(load_id),
            }
        }
    }
2887
2888 fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) {
2889 use indexmap::map::Entry;
2890
2891 let spill_variable_id = match self.function.spilled_composites.entry(base) {
2893 Entry::Occupied(preexisting) => preexisting.get().id,
2894 Entry::Vacant(vacant) => {
2895 let pointer_type_id = self.writer.get_resolution_pointer_id(
2898 &self.fun_info[base].ty,
2899 spirv::StorageClass::Function,
2900 );
2901 let id = self.writer.id_gen.next();
2902 vacant.insert(super::LocalVariable {
2903 id,
2904 instruction: Instruction::variable(
2905 pointer_type_id,
2906 id,
2907 spirv::StorageClass::Function,
2908 None,
2909 ),
2910 });
2911 id
2912 }
2913 };
2914
2915 let base_id = self.cached[base];
2940 block
2941 .body
2942 .push(Instruction::store(spill_variable_id, base_id, None));
2943 }
2944
2945 fn maybe_access_spilled_composite(
2962 &mut self,
2963 access: Handle<crate::Expression>,
2964 block: &mut Block,
2965 result_type_id: Word,
2966 ) -> Result<Word, Error> {
2967 let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r);
2968 if access_uses == self.fun_info[access].ref_count {
2969 Ok(0)
2973 } else {
2974 self.write_checked_load(
2979 access,
2980 block,
2981 AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::Function),
2982 result_type_id,
2983 )
2984 }
2985 }
2986
2987 #[allow(clippy::too_many_arguments)]
2989 fn write_matrix_matrix_column_op(
2990 &mut self,
2991 block: &mut Block,
2992 result_id: Word,
2993 result_type_id: Word,
2994 left_id: Word,
2995 right_id: Word,
2996 columns: crate::VectorSize,
2997 rows: crate::VectorSize,
2998 width: u8,
2999 op: spirv::Op,
3000 ) {
3001 self.temp_list.clear();
3002
3003 let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3004 size: rows,
3005 scalar: crate::Scalar::float(width),
3006 });
3007
3008 for index in 0..columns as u32 {
3009 let column_id_left = self.gen_id();
3010 let column_id_right = self.gen_id();
3011 let column_id_res = self.gen_id();
3012
3013 block.body.push(Instruction::composite_extract(
3014 vector_type_id,
3015 column_id_left,
3016 left_id,
3017 &[index],
3018 ));
3019 block.body.push(Instruction::composite_extract(
3020 vector_type_id,
3021 column_id_right,
3022 right_id,
3023 &[index],
3024 ));
3025 block.body.push(Instruction::binary(
3026 op,
3027 vector_type_id,
3028 column_id_res,
3029 column_id_left,
3030 column_id_right,
3031 ));
3032
3033 self.temp_list.push(column_id_res);
3034 }
3035
3036 block.body.push(Instruction::composite_construct(
3037 result_type_id,
3038 result_id,
3039 &self.temp_list,
3040 ));
3041 }
3042
3043 fn write_vector_scalar_mult(
3045 &mut self,
3046 block: &mut Block,
3047 result_id: Word,
3048 result_type_id: Word,
3049 vector_id: Word,
3050 scalar_id: Word,
3051 vector: &crate::TypeInner,
3052 ) {
3053 let (size, kind) = match *vector {
3054 crate::TypeInner::Vector {
3055 size,
3056 scalar: crate::Scalar { kind, .. },
3057 } => (size, kind),
3058 _ => unreachable!(),
3059 };
3060
3061 let (op, operand_id) = match kind {
3062 crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
3063 _ => {
3064 let operand_id = self.gen_id();
3065 self.temp_list.clear();
3066 self.temp_list.resize(size as usize, scalar_id);
3067 block.body.push(Instruction::composite_construct(
3068 result_type_id,
3069 operand_id,
3070 &self.temp_list,
3071 ));
3072 (spirv::Op::IMul, operand_id)
3073 }
3074 };
3075
3076 block.body.push(Instruction::binary(
3077 op,
3078 result_type_id,
3079 result_id,
3080 vector_id,
3081 operand_id,
3082 ));
3083 }
3084
3085 #[expect(clippy::too_many_arguments)]
3092 fn write_dot_product(
3093 &mut self,
3094 result_id: Word,
3095 result_type_id: Word,
3096 arg0_id: Word,
3097 arg1_id: Word,
3098 size: u32,
3099 block: &mut Block,
3100 extractor: impl Fn(Word, Word, Word) -> Instruction,
3101 ) {
3102 let mut partial_sum = self.writer.get_constant_null(result_type_id);
3103 let last_component = size - 1;
3104 for index in 0..=last_component {
3105 let a_id = self.gen_id();
3107 block.body.push(extractor(a_id, arg0_id, index));
3108 let b_id = self.gen_id();
3109 block.body.push(extractor(b_id, arg1_id, index));
3110 let prod_id = self.gen_id();
3111 block.body.push(Instruction::binary(
3112 spirv::Op::IMul,
3113 result_type_id,
3114 prod_id,
3115 a_id,
3116 b_id,
3117 ));
3118
3119 let id = if index == last_component {
3121 result_id
3122 } else {
3123 self.gen_id()
3124 };
3125
3126 block.body.push(Instruction::binary(
3128 spirv::Op::IAdd,
3129 result_type_id,
3130 id,
3131 partial_sum,
3132 prod_id,
3133 ));
3134 partial_sum = id;
3136 }
3137 }
3138
3139 fn write_pack4x8_optimized(
3141 &mut self,
3142 block: &mut Block,
3143 result_type_id: u32,
3144 arg0_id: u32,
3145 id: u32,
3146 is_signed: bool,
3147 should_clamp: bool,
3148 ) -> Instruction {
3149 let int_type = if is_signed {
3150 crate::ScalarKind::Sint
3151 } else {
3152 crate::ScalarKind::Uint
3153 };
3154 let wide_vector_type = NumericType::Vector {
3155 size: crate::VectorSize::Quad,
3156 scalar: crate::Scalar {
3157 kind: int_type,
3158 width: 4,
3159 },
3160 };
3161 let wide_vector_type_id = self.get_numeric_type_id(wide_vector_type);
3162 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3163 size: crate::VectorSize::Quad,
3164 scalar: crate::Scalar {
3165 kind: crate::ScalarKind::Uint,
3166 width: 1,
3167 },
3168 });
3169
3170 let mut wide_vector = arg0_id;
3171 if should_clamp {
3172 let (min, max, clamp_op) = if is_signed {
3173 (
3174 crate::Literal::I32(-128),
3175 crate::Literal::I32(127),
3176 spirv::GlslStd450Op::SClamp,
3177 )
3178 } else {
3179 (
3180 crate::Literal::U32(0),
3181 crate::Literal::U32(255),
3182 spirv::GlslStd450Op::UClamp,
3183 )
3184 };
3185 let [min, max] = [min, max].map(|lit| {
3186 let scalar = self.writer.get_constant_scalar(lit);
3187 self.writer.get_constant_composite(
3188 LookupType::Local(LocalType::Numeric(wide_vector_type)),
3189 &[scalar; 4],
3190 )
3191 });
3192
3193 let clamp_id = self.gen_id();
3194 block.body.push(Instruction::ext_inst_gl_op(
3195 self.writer.gl450_ext_inst_id,
3196 clamp_op,
3197 wide_vector_type_id,
3198 clamp_id,
3199 &[wide_vector, min, max],
3200 ));
3201
3202 wide_vector = clamp_id;
3203 }
3204
3205 let packed_vector = self.gen_id();
3206 block.body.push(Instruction::unary(
3207 spirv::Op::UConvert, packed_vector_type_id,
3209 packed_vector,
3210 wide_vector,
3211 ));
3212
3213 Instruction::unary(spirv::Op::Bitcast, result_type_id, id, packed_vector)
3218 }
3219
    /// Polyfill that packs four 32-bit integer lanes into one scalar with bit-field inserts.
    ///
    /// Presumably this backs the WGSL `pack4xI8`/`pack4xU8` builtins and their
    /// `Clamp` variants (confirm at the call site). For each lane `i` of `arg0_id`:
    /// 1. extract it as `int_type` (`i32` if `is_signed`, else `u32`);
    /// 2. if `is_signed`, bitcast it to `u32`;
    /// 3. if `should_clamp`, clamp it (`SClamp` to [-128, 127] or `UClamp`
    ///    to [0, 255]);
    /// 4. `OpBitFieldInsert` 8 bits of it at bit offset `8 * i` into the
    ///    running result, which starts as the constant `0u`.
    ///
    /// Every instruction except the final `OpBitFieldInsert` is pushed to
    /// `block`. That final one, producing `id`, is returned for the caller to emit.
    fn write_pack4x8_polyfill(
        &mut self,
        block: &mut Block,
        result_type_id: u32,
        arg0_id: u32,
        id: u32,
        is_signed: bool,
        should_clamp: bool,
    ) -> Instruction {
        let int_type = if is_signed {
            crate::ScalarKind::Sint
        } else {
            crate::ScalarKind::Uint
        };
        let uint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
        let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
            kind: int_type,
            width: 4,
        }));

        // Placeholder; it is always overwritten on the last iteration, since
        // `VEC_LENGTH` is non-zero.
        let mut last_instruction = Instruction::new(spirv::Op::Nop);

        let zero = self.writer.get_constant_scalar(crate::Literal::U32(0));
        let mut preresult = zero;
        // Capacity hint only.
        block
            .body
            .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed)));

        let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
        const VEC_LENGTH: u8 = 4;
        for i in 0..u32::from(VEC_LENGTH) {
            let offset = self.writer.get_constant_scalar(crate::Literal::U32(i * 8));
            // `OpCompositeExtract` with a single literal index, built via the
            // generic two-operand constructor.
            let mut extracted = self.gen_id();
            block.body.push(Instruction::binary(
                spirv::Op::CompositeExtract,
                int_type_id,
                extracted,
                arg0_id,
                i,
            ));
            if is_signed {
                let casted = self.gen_id();
                block.body.push(Instruction::unary(
                    spirv::Op::Bitcast,
                    uint_type_id,
                    casted,
                    extracted,
                ));
                extracted = casted;
            }
            // NOTE(review): when both `is_signed` and `should_clamp` are set,
            // `SClamp` runs on the already-bitcast `u32` value, typed as
            // `result_type_id` and with `i32` bound constants. SClamp treats its
            // operands as signed; confirm this validates for the target.
            if should_clamp {
                let (min, max, clamp_op) = if is_signed {
                    (
                        crate::Literal::I32(-128),
                        crate::Literal::I32(127),
                        spirv::GlslStd450Op::SClamp,
                    )
                } else {
                    (
                        crate::Literal::U32(0),
                        crate::Literal::U32(255),
                        spirv::GlslStd450Op::UClamp,
                    )
                };
                let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit));

                let clamp_id = self.gen_id();
                block.body.push(Instruction::ext_inst_gl_op(
                    self.writer.gl450_ext_inst_id,
                    clamp_op,
                    result_type_id,
                    clamp_id,
                    &[extracted, min, max],
                ));

                extracted = clamp_id;
            }
            let is_last = i == u32::from(VEC_LENGTH - 1);
            if is_last {
                // The final insert produces the caller-provided `id` and is
                // returned rather than pushed.
                last_instruction = Instruction::quaternary(
                    spirv::Op::BitFieldInsert,
                    result_type_id,
                    id,
                    preresult,
                    extracted,
                    offset,
                    eight,
                )
            } else {
                let new_preresult = self.gen_id();
                block.body.push(Instruction::quaternary(
                    spirv::Op::BitFieldInsert,
                    result_type_id,
                    new_preresult,
                    preresult,
                    extracted,
                    offset,
                    eight,
                ));
                preresult = new_preresult;
            }
        }
        last_instruction
    }
3325
3326 fn write_unpack4x8_optimized(
3328 &mut self,
3329 block: &mut Block,
3330 result_type_id: u32,
3331 arg0_id: u32,
3332 id: u32,
3333 is_signed: bool,
3334 ) -> Instruction {
3335 let (int_type, convert_op) = if is_signed {
3336 (crate::ScalarKind::Sint, spirv::Op::SConvert)
3337 } else {
3338 (crate::ScalarKind::Uint, spirv::Op::UConvert)
3339 };
3340
3341 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3342 size: crate::VectorSize::Quad,
3343 scalar: crate::Scalar {
3344 kind: int_type,
3345 width: 1,
3346 },
3347 });
3348
3349 let packed_vector = self.gen_id();
3354 block.body.push(Instruction::unary(
3355 spirv::Op::Bitcast,
3356 packed_vector_type_id,
3357 packed_vector,
3358 arg0_id,
3359 ));
3360
3361 Instruction::unary(convert_op, result_type_id, id, packed_vector)
3362 }
3363
    /// Polyfill that unpacks a 32-bit scalar into four integer lanes with bit-field extracts.
    ///
    /// If `is_signed`, `arg0_id` is first bitcast to `i32`. Then, for each lane
    /// `i` in `0..4`, 8 bits at bit offset `8 * i` are extracted:
    /// `OpBitFieldSExtract` when `is_signed`, otherwise `OpBitFieldUExtract`.
    /// The result type is `i32`/`u32` respectively. The extracts are pushed to
    /// `block`. The returned `OpCompositeConstruct`, which combines the four
    /// lanes into `id`, is left for the caller to emit.
    fn write_unpack4x8_polyfill(
        &mut self,
        block: &mut Block,
        result_type_id: u32,
        arg0_id: u32,
        id: u32,
        is_signed: bool,
    ) -> Instruction {
        let (int_type, extract_op) = if is_signed {
            (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract)
        } else {
            (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract)
        };

        let sint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));

        let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
        let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
            kind: int_type,
            width: 4,
        }));
        // Capacity hint only; this over-reserves, since one extract is pushed per lane.
        block
            .body
            .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed));
        // Signed extraction operates on an `i32` view of the packed word.
        let arg_id = if is_signed {
            let new_arg_id = self.gen_id();
            block.body.push(Instruction::unary(
                spirv::Op::Bitcast,
                sint_type_id,
                new_arg_id,
                arg0_id,
            ));
            new_arg_id
        } else {
            arg0_id
        };

        const VEC_LENGTH: u8 = 4;
        // Ids for all four lanes are allocated before any extract is emitted.
        let parts: [_; VEC_LENGTH as usize] = core::array::from_fn(|_| self.gen_id());
        for (i, part_id) in parts.into_iter().enumerate() {
            let index = self
                .writer
                .get_constant_scalar(crate::Literal::U32(i as u32 * 8));
            block.body.push(Instruction::ternary(
                extract_op,
                int_type_id,
                part_id,
                arg_id,
                index,
                eight,
            ));
        }

        Instruction::composite_construct(result_type_id, id, &parts)
    }
3420
3421 fn write_block(
3438 &mut self,
3439 label_id: Word,
3440 naga_block: &crate::Block,
3441 exit: BlockExit,
3442 loop_context: LoopContext,
3443 debug_info: Option<&DebugInfoInner>,
3444 ) -> Result<BlockExitDisposition, Error> {
3445 let mut block = Block::new(label_id);
3446 for (statement, span) in naga_block.span_iter() {
3447 if let (Some(debug_info), false) = (
3448 debug_info,
3449 matches!(
3450 statement,
3451 &(Statement::Block(..)
3452 | Statement::Break
3453 | Statement::Continue
3454 | Statement::Kill
3455 | Statement::Return { .. }
3456 | Statement::Loop { .. })
3457 ),
3458 ) {
3459 let loc: crate::SourceLocation = span.location(debug_info.source_code);
3460 block.body.push(Instruction::line(
3461 debug_info.source_file_id,
3462 loc.line_number,
3463 loc.line_position,
3464 ));
3465 };
3466 match *statement {
3467 Statement::Emit(ref range) => {
3468 for handle in range.clone() {
3469 if !self.expression_constness.is_const(handle) {
3471 self.cache_expression_value(handle, &mut block)?;
3472 }
3473 }
3474 }
3475 Statement::Block(ref block_statements) => {
3476 let scope_id = self.gen_id();
3477 self.function.consume(block, Instruction::branch(scope_id));
3478
3479 let merge_id = self.gen_id();
3480 let merge_used = self.write_block(
3481 scope_id,
3482 block_statements,
3483 BlockExit::Branch { target: merge_id },
3484 loop_context,
3485 debug_info,
3486 )?;
3487
3488 match merge_used {
3489 BlockExitDisposition::Used => {
3490 block = Block::new(merge_id);
3491 }
3492 BlockExitDisposition::Discarded => {
3493 return Ok(BlockExitDisposition::Discarded);
3494 }
3495 }
3496 }
3497 Statement::If {
3498 condition,
3499 ref accept,
3500 ref reject,
3501 } => {
3502 if !(accept.is_empty() && reject.is_empty()) {
3508 let condition_id = self.cached[condition];
3509
3510 let merge_id = self.gen_id();
3511 block.body.push(Instruction::selection_merge(
3512 merge_id,
3513 spirv::SelectionControl::NONE,
3514 ));
3515
3516 let accept_id = if accept.is_empty() {
3517 None
3518 } else {
3519 Some(self.gen_id())
3520 };
3521 let reject_id = if reject.is_empty() {
3522 None
3523 } else {
3524 Some(self.gen_id())
3525 };
3526
3527 self.function.consume(
3528 block,
3529 Instruction::branch_conditional(
3530 condition_id,
3531 accept_id.unwrap_or(merge_id),
3532 reject_id.unwrap_or(merge_id),
3533 ),
3534 );
3535
3536 if let Some(block_id) = accept_id {
3537 let _ = self.write_block(
3542 block_id,
3543 accept,
3544 BlockExit::Branch { target: merge_id },
3545 loop_context,
3546 debug_info,
3547 )?;
3548 }
3549 if let Some(block_id) = reject_id {
3550 let _ = self.write_block(
3555 block_id,
3556 reject,
3557 BlockExit::Branch { target: merge_id },
3558 loop_context,
3559 debug_info,
3560 )?;
3561 }
3562
3563 block = Block::new(merge_id);
3564 }
3565 }
3566 Statement::Switch {
3567 selector,
3568 ref cases,
3569 } => {
3570 let selector_id = self.cached[selector];
3571
3572 let merge_id = self.gen_id();
3573 block.body.push(Instruction::selection_merge(
3574 merge_id,
3575 spirv::SelectionControl::NONE,
3576 ));
3577
3578 let mut default_id = None;
3579 let mut last_id = None;
3581
3582 let mut raw_cases = Vec::with_capacity(cases.len());
3583 let mut case_ids = Vec::with_capacity(cases.len());
3584 for case in cases.iter() {
3585 let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
3587
3588 if case.fall_through && case.body.is_empty() {
3589 last_id = Some(label_id);
3590 }
3591
3592 case_ids.push(label_id);
3593
3594 match case.value {
3595 crate::SwitchValue::I32(value) => {
3596 raw_cases.push(super::instructions::Case {
3597 value: value as Word,
3598 label_id,
3599 });
3600 }
3601 crate::SwitchValue::U32(value) => {
3602 raw_cases.push(super::instructions::Case { value, label_id });
3603 }
3604 crate::SwitchValue::Default => {
3605 default_id = Some(label_id);
3606 }
3607 }
3608 }
3609
3610 let default_id = default_id.unwrap();
3611
3612 self.function.consume(
3613 block,
3614 Instruction::switch(selector_id, default_id, &raw_cases),
3615 );
3616
3617 let inner_context = LoopContext {
3618 break_id: Some(merge_id),
3619 ..loop_context
3620 };
3621
3622 for (i, (case, label_id)) in cases
3623 .iter()
3624 .zip(case_ids.iter())
3625 .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
3626 .enumerate()
3627 {
3628 let case_finish_id = if case.fall_through {
3629 case_ids[i + 1]
3630 } else {
3631 merge_id
3632 };
3633 let _ = self.write_block(
3642 *label_id,
3643 &case.body,
3644 BlockExit::Branch {
3645 target: case_finish_id,
3646 },
3647 inner_context,
3648 debug_info,
3649 )?;
3650 }
3651
3652 block = Block::new(merge_id);
3653 }
3654 Statement::Loop {
3655 ref body,
3656 ref continuing,
3657 break_if,
3658 } => {
3659 let preamble_id = self.gen_id();
3660 self.function
3661 .consume(block, Instruction::branch(preamble_id));
3662
3663 let merge_id = self.gen_id();
3664 let body_id = self.gen_id();
3665 let continuing_id = self.gen_id();
3666
3667 block = Block::new(preamble_id);
3670 if let Some(debug_info) = debug_info {
3673 let loc: crate::SourceLocation = span.location(debug_info.source_code);
3674 block.body.push(Instruction::line(
3675 debug_info.source_file_id,
3676 loc.line_number,
3677 loc.line_position,
3678 ))
3679 }
3680 block.body.push(Instruction::loop_merge(
3681 merge_id,
3682 continuing_id,
3683 spirv::SelectionControl::NONE,
3684 ));
3685
3686 if self.force_loop_bounding {
3687 block = self.write_force_bounded_loop_instructions(block, merge_id);
3688 }
3689 self.function.consume(block, Instruction::branch(body_id));
3690
3691 let _ = self.write_block(
3695 body_id,
3696 body,
3697 BlockExit::Branch {
3698 target: continuing_id,
3699 },
3700 LoopContext {
3701 continuing_id: Some(continuing_id),
3702 break_id: Some(merge_id),
3703 },
3704 debug_info,
3705 )?;
3706
3707 let exit = match break_if {
3708 Some(condition) => BlockExit::BreakIf {
3709 condition,
3710 preamble_id,
3711 },
3712 None => BlockExit::Branch {
3713 target: preamble_id,
3714 },
3715 };
3716
3717 let _ = self.write_block(
3721 continuing_id,
3722 continuing,
3723 exit,
3724 LoopContext {
3725 continuing_id: None,
3726 break_id: Some(merge_id),
3727 },
3728 debug_info,
3729 )?;
3730
3731 block = Block::new(merge_id);
3732 }
3733 Statement::Break => {
3734 self.function
3735 .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
3736 return Ok(BlockExitDisposition::Discarded);
3737 }
3738 Statement::Continue => {
3739 self.function.consume(
3740 block,
3741 Instruction::branch(loop_context.continuing_id.unwrap()),
3742 );
3743 return Ok(BlockExitDisposition::Discarded);
3744 }
3745 Statement::Return { value: Some(value) } => {
3746 let value_id = self.cached[value];
3747 let instruction = match self.function.entry_point_context {
3748 Some(ref context) => self.writer.write_entry_point_return(
3751 value_id,
3752 self.ir_function.result.as_ref().unwrap(),
3753 &context.results,
3754 &mut block.body,
3755 )?,
3756 None => Instruction::return_value(value_id),
3757 };
3758 self.function.consume(block, instruction);
3759 return Ok(BlockExitDisposition::Discarded);
3760 }
3761 Statement::Return { value: None } => {
3762 self.function.consume(block, Instruction::return_void());
3763 return Ok(BlockExitDisposition::Discarded);
3764 }
3765 Statement::Kill => {
3766 self.function.consume(block, Instruction::kill());
3767 return Ok(BlockExitDisposition::Discarded);
3768 }
3769 Statement::ControlBarrier(flags) => {
3770 self.writer.write_control_barrier(flags, &mut block.body);
3771 }
3772 Statement::MemoryBarrier(flags) => {
3773 self.writer.write_memory_barrier(flags, &mut block);
3774 }
3775 Statement::Store { pointer, value } => {
3776 let value_id = self.cached[value];
3777 match self.write_access_chain(
3778 pointer,
3779 &mut block,
3780 AccessTypeAdjustment::None,
3781 )? {
3782 ExpressionPointer::Ready { pointer_id } => {
3783 let atomic_space = match *self.fun_info[pointer]
3784 .ty
3785 .inner_with(&self.ir_module.types)
3786 {
3787 crate::TypeInner::Pointer { base, space } => {
3788 match self.ir_module.types[base].inner {
3789 crate::TypeInner::Atomic { .. } => Some(space),
3790 _ => None,
3791 }
3792 }
3793 _ => None,
3794 };
3795 let instruction = if let Some(space) = atomic_space {
3796 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3797 let scope_constant_id = self.get_scope_constant(scope as u32);
3798 let semantics_id = self.get_index_constant(semantics.bits());
3799 Instruction::atomic_store(
3800 pointer_id,
3801 scope_constant_id,
3802 semantics_id,
3803 value_id,
3804 )
3805 } else {
3806 Instruction::store(pointer_id, value_id, None)
3807 };
3808 block.body.push(instruction);
3809 }
3810 ExpressionPointer::Conditional { condition, access } => {
3811 let mut selection = Selection::start(&mut block, ());
3812 selection.if_true(self, condition, ());
3813
3814 let pointer_id = access.result_id.unwrap();
3816 selection.block().body.push(access);
3817 selection
3818 .block()
3819 .body
3820 .push(Instruction::store(pointer_id, value_id, None));
3821
3822 selection.finish(self, ());
3825 }
3826 };
3827 }
3828 Statement::ImageStore {
3829 image,
3830 coordinate,
3831 array_index,
3832 value,
3833 } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
3834 Statement::Call {
3835 function: local_function,
3836 ref arguments,
3837 result,
3838 } => {
3839 let id = self.gen_id();
3840 self.temp_list.clear();
3841 for &argument in arguments {
3842 self.temp_list.push(self.cached[argument]);
3843 }
3844
3845 let type_id = match result {
3846 Some(expr) => {
3847 self.cached[expr] = id;
3848 self.get_expression_type_id(&self.fun_info[expr].ty)
3849 }
3850 None => self.writer.void_type,
3851 };
3852
3853 block.body.push(Instruction::function_call(
3854 type_id,
3855 id,
3856 self.writer.lookup_function[&local_function],
3857 &self.temp_list,
3858 ));
3859 }
3860 Statement::Atomic {
3861 pointer,
3862 ref fun,
3863 value,
3864 result,
3865 } => {
3866 let id = self.gen_id();
3867 let result_type_id =
3871 self.get_expression_type_id(&self.fun_info[result.unwrap_or(value)].ty);
3872
3873 if let Some(result) = result {
3874 self.cached[result] = id;
3875 }
3876
3877 let pointer_id = match self.write_access_chain(
3878 pointer,
3879 &mut block,
3880 AccessTypeAdjustment::None,
3881 )? {
3882 ExpressionPointer::Ready { pointer_id } => pointer_id,
3883 ExpressionPointer::Conditional { .. } => {
3884 return Err(Error::FeatureNotImplemented(
3885 "Atomics out-of-bounds handling",
3886 ));
3887 }
3888 };
3889
3890 let space = self.fun_info[pointer]
3891 .ty
3892 .inner_with(&self.ir_module.types)
3893 .pointer_space()
3894 .unwrap();
3895 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3896 let scope_constant_id = self.get_scope_constant(scope as u32);
3897 let semantics_id = self.get_index_constant(semantics.bits());
3898 let value_id = self.cached[value];
3899 let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
3900
3901 let crate::TypeInner::Scalar(scalar) = *value_inner else {
3902 return Err(Error::FeatureNotImplemented(
3903 "Atomics with non-scalar values",
3904 ));
3905 };
3906
3907 let instruction = match *fun {
3908 crate::AtomicFunction::Add => {
3909 let spirv_op = match scalar.kind {
3910 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3911 spirv::Op::AtomicIAdd
3912 }
3913 crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT,
3914 _ => unimplemented!(),
3915 };
3916 Instruction::atomic_binary(
3917 spirv_op,
3918 result_type_id,
3919 id,
3920 pointer_id,
3921 scope_constant_id,
3922 semantics_id,
3923 value_id,
3924 )
3925 }
3926 crate::AtomicFunction::Subtract => {
3927 let (spirv_op, value_id) = match scalar.kind {
3928 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3929 (spirv::Op::AtomicISub, value_id)
3930 }
3931 crate::ScalarKind::Float => {
3932 let neg_result_id = self.gen_id();
3935 block.body.push(Instruction::unary(
3936 spirv::Op::FNegate,
3937 result_type_id,
3938 neg_result_id,
3939 value_id,
3940 ));
3941 (spirv::Op::AtomicFAddEXT, neg_result_id)
3942 }
3943 _ => unimplemented!(),
3944 };
3945 Instruction::atomic_binary(
3946 spirv_op,
3947 result_type_id,
3948 id,
3949 pointer_id,
3950 scope_constant_id,
3951 semantics_id,
3952 value_id,
3953 )
3954 }
3955 crate::AtomicFunction::And => {
3956 let spirv_op = match scalar.kind {
3957 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3958 spirv::Op::AtomicAnd
3959 }
3960 _ => unimplemented!(),
3961 };
3962 Instruction::atomic_binary(
3963 spirv_op,
3964 result_type_id,
3965 id,
3966 pointer_id,
3967 scope_constant_id,
3968 semantics_id,
3969 value_id,
3970 )
3971 }
3972 crate::AtomicFunction::InclusiveOr => {
3973 let spirv_op = match scalar.kind {
3974 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3975 spirv::Op::AtomicOr
3976 }
3977 _ => unimplemented!(),
3978 };
3979 Instruction::atomic_binary(
3980 spirv_op,
3981 result_type_id,
3982 id,
3983 pointer_id,
3984 scope_constant_id,
3985 semantics_id,
3986 value_id,
3987 )
3988 }
3989 crate::AtomicFunction::ExclusiveOr => {
3990 let spirv_op = match scalar.kind {
3991 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3992 spirv::Op::AtomicXor
3993 }
3994 _ => unimplemented!(),
3995 };
3996 Instruction::atomic_binary(
3997 spirv_op,
3998 result_type_id,
3999 id,
4000 pointer_id,
4001 scope_constant_id,
4002 semantics_id,
4003 value_id,
4004 )
4005 }
4006 crate::AtomicFunction::Min => {
4007 let spirv_op = match scalar.kind {
4008 crate::ScalarKind::Sint => spirv::Op::AtomicSMin,
4009 crate::ScalarKind::Uint => spirv::Op::AtomicUMin,
4010 _ => unimplemented!(),
4011 };
4012 Instruction::atomic_binary(
4013 spirv_op,
4014 result_type_id,
4015 id,
4016 pointer_id,
4017 scope_constant_id,
4018 semantics_id,
4019 value_id,
4020 )
4021 }
4022 crate::AtomicFunction::Max => {
4023 let spirv_op = match scalar.kind {
4024 crate::ScalarKind::Sint => spirv::Op::AtomicSMax,
4025 crate::ScalarKind::Uint => spirv::Op::AtomicUMax,
4026 _ => unimplemented!(),
4027 };
4028 Instruction::atomic_binary(
4029 spirv_op,
4030 result_type_id,
4031 id,
4032 pointer_id,
4033 scope_constant_id,
4034 semantics_id,
4035 value_id,
4036 )
4037 }
4038 crate::AtomicFunction::Exchange { compare: None } => {
4039 Instruction::atomic_binary(
4040 spirv::Op::AtomicExchange,
4041 result_type_id,
4042 id,
4043 pointer_id,
4044 scope_constant_id,
4045 semantics_id,
4046 value_id,
4047 )
4048 }
4049 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4050 let scalar_type_id =
4051 self.get_numeric_type_id(NumericType::Scalar(scalar));
4052 let bool_type_id =
4053 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL));
4054
4055 let cas_result_id = self.gen_id();
4056 let equality_result_id = self.gen_id();
4057 let equality_operator = match scalar.kind {
4058 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
4059 spirv::Op::IEqual
4060 }
4061 _ => unimplemented!(),
4062 };
4063
4064 let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
4065 cas_instr.set_type(scalar_type_id);
4066 cas_instr.set_result(cas_result_id);
4067 cas_instr.add_operand(pointer_id);
4068 cas_instr.add_operand(scope_constant_id);
4069 cas_instr.add_operand(semantics_id); cas_instr.add_operand(semantics_id); cas_instr.add_operand(value_id);
4072 cas_instr.add_operand(self.cached[cmp]);
4073 block.body.push(cas_instr);
4074 block.body.push(Instruction::binary(
4075 equality_operator,
4076 bool_type_id,
4077 equality_result_id,
4078 cas_result_id,
4079 self.cached[cmp],
4080 ));
4081 Instruction::composite_construct(
4082 result_type_id,
4083 id,
4084 &[cas_result_id, equality_result_id],
4085 )
4086 }
4087 };
4088
4089 block.body.push(instruction);
4090 }
4091 Statement::ImageAtomic {
4092 image,
4093 coordinate,
4094 array_index,
4095 fun,
4096 value,
4097 } => {
4098 self.write_image_atomic(
4099 image,
4100 coordinate,
4101 array_index,
4102 fun,
4103 value,
4104 &mut block,
4105 )?;
4106 }
4107 Statement::WorkGroupUniformLoad { pointer, result } => {
4108 self.writer
4109 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4110 let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
4111 let id = self.write_checked_load(
4114 pointer,
4115 &mut block,
4116 AccessTypeAdjustment::None,
4117 result_type_id,
4118 )?;
4119 self.cached[result] = id;
4120 self.writer
4121 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4122 }
4123 Statement::RayQuery { query, ref fun } => {
4124 self.write_ray_query_function(query, fun, &mut block);
4125 }
4126 Statement::SubgroupBallot {
4127 result,
4128 ref predicate,
4129 } => {
4130 self.write_subgroup_ballot(predicate, result, &mut block)?;
4131 }
4132 Statement::SubgroupCollectiveOperation {
4133 ref op,
4134 ref collective_op,
4135 argument,
4136 result,
4137 } => {
4138 self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
4139 }
4140 Statement::SubgroupGather {
4141 ref mode,
4142 argument,
4143 result,
4144 } => {
4145 self.write_subgroup_gather(mode, argument, result, &mut block)?;
4146 }
4147 Statement::CooperativeStore { target, ref data } => {
4148 let target_id = self.cached[target];
4149 let layout = if data.row_major {
4150 spirv::CooperativeMatrixLayout::RowMajorKHR
4151 } else {
4152 spirv::CooperativeMatrixLayout::ColumnMajorKHR
4153 };
4154 let layout_id = self.get_index_constant(layout as u32);
4155 let stride_id = self.cached[data.stride];
4156 match self.write_access_chain(
4157 data.pointer,
4158 &mut block,
4159 AccessTypeAdjustment::None,
4160 )? {
4161 ExpressionPointer::Ready { pointer_id } => {
4162 block.body.push(Instruction::coop_store(
4163 target_id, pointer_id, layout_id, stride_id,
4164 ));
4165 }
4166 ExpressionPointer::Conditional { condition, access } => {
4167 let mut selection = Selection::start(&mut block, ());
4168 selection.if_true(self, condition, ());
4169
4170 let pointer_id = access.result_id.unwrap();
4172 selection.block().body.push(access);
4173 selection.block().body.push(Instruction::coop_store(
4174 target_id, pointer_id, layout_id, stride_id,
4175 ));
4176
4177 selection.finish(self, ());
4180 }
4181 };
4182 }
4183 Statement::RayPipelineFunction(ref fun) => {
4184 self.write_ray_tracing_pipeline_function(fun, &mut block);
4185 }
4186 }
4187 }
4188
4189 let termination = match exit {
4190 BlockExit::Return => match self.ir_function.result {
4193 Some(ref result) if self.function.entry_point_context.is_none() => {
4194 let type_id = self.get_handle_type_id(result.ty);
4195 let null_id = self.writer.get_constant_null(type_id);
4196 Instruction::return_value(null_id)
4197 }
4198 _ => Instruction::return_void(),
4199 },
4200 BlockExit::Branch { target } => Instruction::branch(target),
4201 BlockExit::BreakIf {
4202 condition,
4203 preamble_id,
4204 } => {
4205 let condition_id = self.cached[condition];
4206
4207 Instruction::branch_conditional(
4208 condition_id,
4209 loop_context.break_id.unwrap(),
4210 preamble_id,
4211 )
4212 }
4213 };
4214
4215 self.function.consume(block, termination);
4216 Ok(BlockExitDisposition::Used)
4217 }
4218
4219 pub(super) fn write_function_body(
4220 &mut self,
4221 entry_id: Word,
4222 debug_info: Option<&DebugInfoInner>,
4223 ) -> Result<(), Error> {
4224 let _ = self.write_block(
4227 entry_id,
4228 &self.ir_function.body,
4229 BlockExit::Return,
4230 LoopContext::default(),
4231 debug_info,
4232 )?;
4233
4234 Ok(())
4235 }
4236}