1use alloc::vec::Vec;
6
7use arrayvec::ArrayVec;
8use spirv::Word;
9
10use super::{
11 helpers::map_storage_class, index::BoundsCheckResult, selection::Selection, Block,
12 BlockContext, Dimension, Error, IdGenerator, Instruction, LocalType, LookupType, NumericType,
13 ResultMember, WrappedFunction, Writer, WriterFlags,
14};
15use crate::{
16 arena::Handle, back::spv::helpers::is_uniform_matcx2_struct_member_access,
17 proc::index::GuardedIndex, Statement,
18};
19
20fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
21 match *type_inner {
22 crate::TypeInner::Scalar(_) => Dimension::Scalar,
23 crate::TypeInner::Vector { .. } => Dimension::Vector,
24 crate::TypeInner::Matrix { .. } => Dimension::Matrix,
25 crate::TypeInner::CooperativeMatrix { .. } => Dimension::CooperativeMatrix,
26 _ => unreachable!(),
27 }
28}
29
30#[derive(Copy, Clone)]
42enum AccessTypeAdjustment {
43 None,
52
53 IntroducePointer(spirv::StorageClass),
77
78 UseStd140CompatType,
88}
89
90enum ExpressionPointer {
94 Ready { pointer_id: Word },
97
98 Conditional {
104 condition: Word,
105 access: Instruction,
106 },
107}
108
109enum BlockExit {
111 Return,
113 Branch {
115 target: Word,
117 },
118 BreakIf {
124 condition: Handle<crate::Expression>,
126 preamble_id: Word,
128 },
129}
130
/// Two-state answer: something was either `Used` or `Discarded`.
///
/// The type is `#[must_use]`, so silently dropping a value of it produces a
/// compiler warning.
///
/// NOTE(review): presumably reports what happened to a [`BlockExit`] handed to
/// a statement writer; confirm against the code that returns it.
#[must_use]
enum BlockExitDisposition {
    /// The item in question was used.
    Used,

    /// The item in question was discarded.
    Discarded,
}
154
155#[derive(Clone, Copy, Default)]
156struct LoopContext {
157 continuing_id: Option<Word>,
158 break_id: Option<Word>,
159}
160
161#[derive(Debug)]
162pub(crate) struct DebugInfoInner<'a> {
163 pub source_code: &'a str,
164 pub source_file_id: Word,
165}
166
167impl Writer {
168 fn write_epilogue_position_y_flip(
173 &mut self,
174 position_id: Word,
175 body: &mut Vec<Instruction>,
176 ) -> Result<(), Error> {
177 let float_ptr_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Output);
178 let index_y_id = self.get_index_constant(1);
179 let access_id = self.id_gen.next();
180 body.push(Instruction::access_chain(
181 float_ptr_type_id,
182 access_id,
183 position_id,
184 &[index_y_id],
185 ));
186
187 let float_type_id = self.get_f32_type_id();
188 let load_id = self.id_gen.next();
189 body.push(Instruction::load(float_type_id, load_id, access_id, None));
190
191 let neg_id = self.id_gen.next();
192 body.push(Instruction::unary(
193 spirv::Op::FNegate,
194 float_type_id,
195 neg_id,
196 load_id,
197 ));
198
199 body.push(Instruction::store(access_id, neg_id, None));
200 Ok(())
201 }
202
203 fn write_epilogue_frag_depth_clamp(
205 &mut self,
206 frag_depth_id: Word,
207 body: &mut Vec<Instruction>,
208 ) -> Result<(), Error> {
209 let float_type_id = self.get_f32_type_id();
210 let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
211 let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));
212
213 let original_id = self.id_gen.next();
214 body.push(Instruction::load(
215 float_type_id,
216 original_id,
217 frag_depth_id,
218 None,
219 ));
220
221 let clamp_id = self.id_gen.next();
222 body.push(Instruction::ext_inst_gl_op(
223 self.gl450_ext_inst_id,
224 spirv::GLOp::FClamp,
225 float_type_id,
226 clamp_id,
227 &[original_id, zero_scalar_id, one_scalar_id],
228 ));
229
230 body.push(Instruction::store(frag_depth_id, clamp_id, None));
231 Ok(())
232 }
233
234 fn write_entry_point_return(
235 &mut self,
236 value_id: Word,
237 ir_result: &crate::FunctionResult,
238 result_members: &[ResultMember],
239 body: &mut Vec<Instruction>,
240 task_payload: Option<Word>,
241 ) -> Result<Instruction, Error> {
242 for (index, res_member) in result_members.iter().enumerate() {
243 if res_member.built_in == Some(crate::BuiltIn::MeshTaskSize) {
245 continue;
246 }
247 let member_value_id = match ir_result.binding {
248 Some(_) => value_id,
249 None => {
250 let member_value_id = self.id_gen.next();
251 body.push(Instruction::composite_extract(
252 res_member.type_id,
253 member_value_id,
254 value_id,
255 &[index as u32],
256 ));
257 member_value_id
258 }
259 };
260
261 self.store_io_with_f16_polyfill(body, res_member.id, member_value_id);
262
263 match res_member.built_in {
264 Some(crate::BuiltIn::Position { .. })
265 if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
266 {
267 self.write_epilogue_position_y_flip(res_member.id, body)?;
268 }
269 Some(crate::BuiltIn::FragDepth)
270 if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
271 {
272 self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
273 }
274 _ => {}
275 }
276 }
277 self.try_write_entry_point_task_return(
278 value_id,
279 ir_result,
280 result_members,
281 body,
282 task_payload,
283 )
284 }
285}
286
287impl BlockContext<'_> {
288 fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
301 let uint_type_id = self.writer.get_u32_type_id();
302 let uint2_type_id = self.writer.get_vec2u_type_id();
303 let uint2_ptr_type_id = self
304 .writer
305 .get_vec2u_pointer_type_id(spirv::StorageClass::Function);
306 let bool_type_id = self.writer.get_bool_type_id();
307 let bool2_type_id = self.writer.get_vec2_bool_type_id();
308 let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
309 let zero_uint2_const_id = self.writer.get_constant_composite(
310 LookupType::Local(LocalType::Numeric(NumericType::Vector {
311 size: crate::VectorSize::Bi,
312 scalar: crate::Scalar::U32,
313 })),
314 &[zero_uint_const_id, zero_uint_const_id],
315 );
316 let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
317 let max_uint_const_id = self
318 .writer
319 .get_constant_scalar(crate::Literal::U32(u32::MAX));
320 let max_uint2_const_id = self.writer.get_constant_composite(
321 LookupType::Local(LocalType::Numeric(NumericType::Vector {
322 size: crate::VectorSize::Bi,
323 scalar: crate::Scalar::U32,
324 })),
325 &[max_uint_const_id, max_uint_const_id],
326 );
327
328 let loop_counter_var_id = self.gen_id();
329 if self.writer.flags.contains(WriterFlags::DEBUG) {
330 self.writer
331 .debugs
332 .push(Instruction::name(loop_counter_var_id, "loop_bound"));
333 }
334 let var = super::LocalVariable {
335 id: loop_counter_var_id,
336 instruction: Instruction::variable(
337 uint2_ptr_type_id,
338 loop_counter_var_id,
339 spirv::StorageClass::Function,
340 Some(max_uint2_const_id),
341 ),
342 };
343 self.function.force_loop_bounding_vars.push(var);
344
345 let break_if_block = self.gen_id();
346
347 self.function
348 .consume(block, Instruction::branch(break_if_block));
349 block = Block::new(break_if_block);
350
351 let load_id = self.gen_id();
354 block.body.push(Instruction::load(
355 uint2_type_id,
356 load_id,
357 loop_counter_var_id,
358 None,
359 ));
360
361 let eq_id = self.gen_id();
364 block.body.push(Instruction::binary(
365 spirv::Op::IEqual,
366 bool2_type_id,
367 eq_id,
368 zero_uint2_const_id,
369 load_id,
370 ));
371 let all_eq_id = self.gen_id();
372 block.body.push(Instruction::relational(
373 spirv::Op::All,
374 bool_type_id,
375 all_eq_id,
376 eq_id,
377 ));
378
379 let inc_counter_block_id = self.gen_id();
380 block.body.push(Instruction::selection_merge(
381 inc_counter_block_id,
382 spirv::SelectionControl::empty(),
383 ));
384 self.function.consume(
385 block,
386 Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
387 );
388 block = Block::new(inc_counter_block_id);
389
390 let low_id = self.gen_id();
396 block.body.push(Instruction::composite_extract(
397 uint_type_id,
398 low_id,
399 load_id,
400 &[1],
401 ));
402 let low_overflow_id = self.gen_id();
403 block.body.push(Instruction::binary(
404 spirv::Op::IEqual,
405 bool_type_id,
406 low_overflow_id,
407 low_id,
408 zero_uint_const_id,
409 ));
410 let carry_bit_id = self.gen_id();
411 block.body.push(Instruction::select(
412 uint_type_id,
413 carry_bit_id,
414 low_overflow_id,
415 one_uint_const_id,
416 zero_uint_const_id,
417 ));
418 let decrement_id = self.gen_id();
419 block.body.push(Instruction::composite_construct(
420 uint2_type_id,
421 decrement_id,
422 &[carry_bit_id, one_uint_const_id],
423 ));
424 let result_id = self.gen_id();
425 block.body.push(Instruction::binary(
426 spirv::Op::ISub,
427 uint2_type_id,
428 result_id,
429 load_id,
430 decrement_id,
431 ));
432 block
433 .body
434 .push(Instruction::store(loop_counter_var_id, result_id, None));
435
436 block
437 }
438
439 fn maybe_write_uniform_matcx2_dynamic_access(
456 &mut self,
457 pointer: Handle<crate::Expression>,
458 block: &mut Block,
459 ) -> Result<Option<Word>, Error> {
460 let (column_pointer, component_index) = match self.fun_info[pointer]
466 .ty
467 .inner_with(&self.ir_module.types)
468 .pointer_base_type()
469 {
470 Some(resolution) => match *resolution.inner_with(&self.ir_module.types) {
471 crate::TypeInner::Scalar(_) => match self.ir_function.expressions[pointer] {
472 crate::Expression::Access { base, index } => {
473 (base, Some(GuardedIndex::Expression(index)))
474 }
475 crate::Expression::AccessIndex { base, index } => {
476 (base, Some(GuardedIndex::Known(index)))
477 }
478 _ => return Ok(None),
479 },
480 crate::TypeInner::Vector { .. } => (pointer, None),
481 _ => return Ok(None),
482 },
483 None => return Ok(None),
484 };
485
486 let crate::Expression::Access {
489 base: matrix_pointer,
490 index: column_index,
491 } = self.ir_function.expressions[column_pointer]
492 else {
493 return Ok(None);
494 };
495
496 let crate::TypeInner::Pointer {
498 base: matrix_pointer_base_type,
499 space: crate::AddressSpace::Uniform,
500 } = *self.fun_info[matrix_pointer]
501 .ty
502 .inner_with(&self.ir_module.types)
503 else {
504 return Ok(None);
505 };
506
507 let crate::TypeInner::Matrix {
509 columns,
510 rows: rows @ crate::VectorSize::Bi,
511 scalar,
512 } = self.ir_module.types[matrix_pointer_base_type].inner
513 else {
514 return Ok(None);
515 };
516
517 let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
518 columns,
519 rows,
520 scalar,
521 });
522 let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
523 let component_type_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
524 let get_column_function_id = self.writer.wrapped_functions
525 [&WrappedFunction::MatCx2GetColumn {
526 r#type: matrix_pointer_base_type,
527 }];
528
529 let matrix_load_id = self.write_checked_load(
530 matrix_pointer,
531 block,
532 AccessTypeAdjustment::None,
533 matrix_type_id,
534 )?;
535
536 let column_index_id = match *self.fun_info[column_index]
539 .ty
540 .inner_with(&self.ir_module.types)
541 {
542 crate::TypeInner::Scalar(crate::Scalar {
543 kind: crate::ScalarKind::Uint,
544 ..
545 }) => self.cached[column_index],
546 crate::TypeInner::Scalar(crate::Scalar {
547 kind: crate::ScalarKind::Sint,
548 ..
549 }) => {
550 let cast_id = self.gen_id();
551 let u32_type_id = self.writer.get_u32_type_id();
552 block.body.push(Instruction::unary(
553 spirv::Op::Bitcast,
554 u32_type_id,
555 cast_id,
556 self.cached[column_index],
557 ));
558 cast_id
559 }
560 _ => return Err(Error::Validation("Matrix access index must be u32 or i32")),
561 };
562 let column_id = self.gen_id();
563 block.body.push(Instruction::function_call(
564 column_type_id,
565 column_id,
566 get_column_function_id,
567 &[matrix_load_id, column_index_id],
568 ));
569 let result_id = match component_index {
570 Some(index) => self.write_vector_access(
571 component_type_id,
572 column_pointer,
573 Some(column_id),
574 index,
575 block,
576 )?,
577 None => column_id,
578 };
579
580 Ok(Some(result_id))
581 }
582
583 fn maybe_write_load_uniform_matcx2_struct_member(
595 &mut self,
596 pointer: Handle<crate::Expression>,
597 block: &mut Block,
598 ) -> Result<Option<Word>, Error> {
599 let crate::TypeInner::Pointer {
601 base: matrix_type,
602 space: space @ crate::AddressSpace::Uniform,
603 } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
604 else {
605 return Ok(None);
606 };
607
608 let crate::TypeInner::Matrix {
609 columns,
610 rows: rows @ crate::VectorSize::Bi,
611 scalar,
612 } = self.ir_module.types[matrix_type].inner
613 else {
614 return Ok(None);
615 };
616
617 let crate::Expression::AccessIndex {
620 base: struct_pointer,
621 index: member_index,
622 } = self.ir_function.expressions[pointer]
623 else {
624 return Ok(None);
625 };
626
627 let crate::TypeInner::Pointer {
628 base: struct_type, ..
629 } = *self.fun_info[struct_pointer]
630 .ty
631 .inner_with(&self.ir_module.types)
632 else {
633 return Ok(None);
634 };
635
636 let crate::TypeInner::Struct { .. } = self.ir_module.types[struct_type].inner else {
637 return Ok(None);
638 };
639
640 let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
641 columns,
642 rows,
643 scalar,
644 });
645 let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
646 let column_pointer_type_id =
647 self.get_pointer_type_id(column_type_id, map_storage_class(space));
648 let column0_index = self.writer.std140_compat_uniform_types[&struct_type].member_indices
649 [member_index as usize];
650 let column_indices = (0..columns as u32)
651 .map(|c| self.get_index_constant(column0_index + c))
652 .collect::<ArrayVec<_, 4>>();
653
654 let load_mat_from_struct =
657 |struct_pointer_id: Word, id_gen: &mut IdGenerator, block: &mut Block| -> Word {
658 let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
659 for index in &column_indices {
660 let column_pointer_id = id_gen.next();
661 block.body.push(Instruction::access_chain(
662 column_pointer_type_id,
663 column_pointer_id,
664 struct_pointer_id,
665 &[*index],
666 ));
667 let column_id = id_gen.next();
668 block.body.push(Instruction::load(
669 column_type_id,
670 column_id,
671 column_pointer_id,
672 None,
673 ));
674 column_ids.push(column_id);
675 }
676 let result_id = id_gen.next();
677 block.body.push(Instruction::composite_construct(
678 matrix_type_id,
679 result_id,
680 &column_ids,
681 ));
682 result_id
683 };
684
685 let result_id = match self.write_access_chain(
686 struct_pointer,
687 block,
688 AccessTypeAdjustment::UseStd140CompatType,
689 )? {
690 ExpressionPointer::Ready { pointer_id } => {
691 load_mat_from_struct(pointer_id, &mut self.writer.id_gen, block)
692 }
693 ExpressionPointer::Conditional { condition, access } => self
694 .write_conditional_indexed_load(
695 matrix_type_id,
696 condition,
697 block,
698 |id_gen, block| {
699 let pointer_id = access.result_id.unwrap();
700 block.body.push(access);
701 load_mat_from_struct(pointer_id, id_gen, block)
702 },
703 ),
704 };
705
706 Ok(Some(result_id))
707 }
708
709 pub(super) fn cache_expression_value(
711 &mut self,
712 expr_handle: Handle<crate::Expression>,
713 block: &mut Block,
714 ) -> Result<(), Error> {
715 let is_named_expression = self
716 .ir_function
717 .named_expressions
718 .contains_key(&expr_handle);
719
720 if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression {
721 return Ok(());
722 }
723
724 let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
725 let id = match self.ir_function.expressions[expr_handle] {
726 crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
727 crate::Expression::Constant(handle) => {
728 let init = self.ir_module.constants[handle].init;
729 self.writer.constant_ids[init]
730 }
731 crate::Expression::Override(_) => return Err(Error::Override),
732 crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
733 crate::Expression::Compose { ty, ref components } => {
734 self.temp_list.clear();
735 if self.expression_constness.is_const(expr_handle) {
736 self.temp_list.extend(
737 crate::proc::flatten_compose(
738 ty,
739 components,
740 &self.ir_function.expressions,
741 &self.ir_module.types,
742 )
743 .map(|component| self.cached[component]),
744 );
745 self.writer
746 .get_constant_composite(LookupType::Handle(ty), &self.temp_list)
747 } else {
748 self.temp_list
749 .extend(components.iter().map(|&component| self.cached[component]));
750
751 let id = self.gen_id();
752 block.body.push(Instruction::composite_construct(
753 result_type_id,
754 id,
755 &self.temp_list,
756 ));
757 id
758 }
759 }
760 crate::Expression::Splat { size, value } => {
761 let value_id = self.cached[value];
762 let components = &[value_id; 4][..size as usize];
763
764 if self.expression_constness.is_const(expr_handle) {
765 let ty = self
766 .writer
767 .get_expression_lookup_type(&self.fun_info[expr_handle].ty);
768 self.writer.get_constant_composite(ty, components)
769 } else {
770 let id = self.gen_id();
771 block.body.push(Instruction::composite_construct(
772 result_type_id,
773 id,
774 components,
775 ));
776 id
777 }
778 }
779 crate::Expression::Access { base, index } => {
780 let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
781 match *base_ty_inner {
782 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
783 0
789 }
790 _ if self.function.spilled_accesses.contains(base) => {
791 self.function.spilled_accesses.insert(expr_handle);
799 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
800 }
801 crate::TypeInner::Vector { .. } => self.write_vector_access(
802 result_type_id,
803 base,
804 None,
805 GuardedIndex::Expression(index),
806 block,
807 )?,
808 crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => {
809 match GuardedIndex::from_expression(
811 index,
812 &self.ir_function.expressions,
813 self.ir_module,
814 ) {
815 GuardedIndex::Known(value) => {
816 let id = self.gen_id();
826 let base_id = self.cached[base];
827 block.body.push(Instruction::composite_extract(
828 result_type_id,
829 id,
830 base_id,
831 &[value],
832 ));
833 id
834 }
835 GuardedIndex::Expression(_) => {
836 self.spill_to_internal_variable(base, block);
843
844 self.function.spilled_accesses.insert(expr_handle);
847 self.maybe_access_spilled_composite(
848 expr_handle,
849 block,
850 result_type_id,
851 )?
852 }
853 }
854 }
855 crate::TypeInner::BindingArray {
856 base: binding_type, ..
857 } => {
858 let result_id = match self.write_access_chain(
861 expr_handle,
862 block,
863 AccessTypeAdjustment::IntroducePointer(
864 spirv::StorageClass::UniformConstant,
865 ),
866 )? {
867 ExpressionPointer::Ready { pointer_id } => pointer_id,
868 ExpressionPointer::Conditional { .. } => {
869 return Err(Error::FeatureNotImplemented(
870 "Texture array out-of-bounds handling",
871 ));
872 }
873 };
874
875 let binding_type_id = self.get_handle_type_id(binding_type);
876
877 let load_id = self.gen_id();
878 block.body.push(Instruction::load(
879 binding_type_id,
880 load_id,
881 result_id,
882 None,
883 ));
884
885 if self.fun_info[index].uniformity.non_uniform_result.is_some() {
889 self.writer
890 .decorate_non_uniform_binding_array_access(load_id)?;
891 }
892
893 load_id
894 }
895 ref other => {
896 log::error!(
897 "Unable to access base {:?} of type {:?}",
898 self.ir_function.expressions[base],
899 other
900 );
901 return Err(Error::Validation(
902 "only vectors and arrays may be dynamically indexed by value",
903 ));
904 }
905 }
906 }
907 crate::Expression::AccessIndex { base, index } => {
908 match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
909 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
910 0
916 }
917 _ if self.function.spilled_accesses.contains(base) => {
918 self.function.spilled_accesses.insert(expr_handle);
926 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
927 }
928 crate::TypeInner::Vector { .. }
929 | crate::TypeInner::Matrix { .. }
930 | crate::TypeInner::Array { .. }
931 | crate::TypeInner::Struct { .. } => {
932 let id = self.gen_id();
937 let base_id = self.cached[base];
938 block.body.push(Instruction::composite_extract(
939 result_type_id,
940 id,
941 base_id,
942 &[index],
943 ));
944 id
945 }
946 crate::TypeInner::BindingArray {
947 base: binding_type, ..
948 } => {
949 let result_id = match self.write_access_chain(
952 expr_handle,
953 block,
954 AccessTypeAdjustment::IntroducePointer(
955 spirv::StorageClass::UniformConstant,
956 ),
957 )? {
958 ExpressionPointer::Ready { pointer_id } => pointer_id,
959 ExpressionPointer::Conditional { .. } => {
960 return Err(Error::FeatureNotImplemented(
961 "Texture array out-of-bounds handling",
962 ));
963 }
964 };
965
966 let binding_type_id = self.get_handle_type_id(binding_type);
967
968 let load_id = self.gen_id();
969 block.body.push(Instruction::load(
970 binding_type_id,
971 load_id,
972 result_id,
973 None,
974 ));
975
976 load_id
977 }
978 ref other => {
979 log::error!("Unable to access index of {other:?}");
980 return Err(Error::FeatureNotImplemented("access index for type"));
981 }
982 }
983 }
984 crate::Expression::GlobalVariable(handle) => {
985 self.writer.global_variables[handle].access_id
986 }
987 crate::Expression::Swizzle {
988 size,
989 vector,
990 pattern,
991 } => {
992 let vector_id = self.cached[vector];
993 self.temp_list.clear();
994 for &sc in pattern[..size as usize].iter() {
995 self.temp_list.push(sc as Word);
996 }
997 let id = self.gen_id();
998 block.body.push(Instruction::vector_shuffle(
999 result_type_id,
1000 id,
1001 vector_id,
1002 vector_id,
1003 &self.temp_list,
1004 ));
1005 id
1006 }
1007 crate::Expression::Unary { op, expr } => {
1008 let id = self.gen_id();
1009 let expr_id = self.cached[expr];
1010 let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
1011
1012 let spirv_op = match op {
1013 crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
1014 Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
1015 Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
1016 _ => return Err(Error::Validation("Unexpected kind for negation")),
1017 },
1018 crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot,
1019 crate::UnaryOperator::BitwiseNot => spirv::Op::Not,
1020 };
1021
1022 block
1023 .body
1024 .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
1025 id
1026 }
1027 crate::Expression::Binary { op, left, right } => {
1028 let id = self.gen_id();
1029 let left_id = self.cached[left];
1030 let right_id = self.cached[right];
1031 let left_type_id = self.get_expression_type_id(&self.fun_info[left].ty);
1032 let right_type_id = self.get_expression_type_id(&self.fun_info[right].ty);
1033
1034 if let Some(function_id) =
1035 self.writer
1036 .wrapped_functions
1037 .get(&WrappedFunction::BinaryOp {
1038 op,
1039 left_type_id,
1040 right_type_id,
1041 })
1042 {
1043 block.body.push(Instruction::function_call(
1044 result_type_id,
1045 id,
1046 *function_id,
1047 &[left_id, right_id],
1048 ));
1049 } else {
1050 let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
1051 let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
1052
1053 let left_dimension = get_dimension(left_ty_inner);
1054 let right_dimension = get_dimension(right_ty_inner);
1055
1056 let mut reverse_operands = false;
1057
1058 let spirv_op = match op {
1059 crate::BinaryOperator::Add => match *left_ty_inner {
1060 crate::TypeInner::Scalar(scalar)
1061 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1062 crate::ScalarKind::Float => spirv::Op::FAdd,
1063 _ => spirv::Op::IAdd,
1064 },
1065 crate::TypeInner::Matrix {
1066 columns,
1067 rows,
1068 scalar,
1069 } => {
1070 self.write_matrix_matrix_column_op(
1072 block,
1073 id,
1074 result_type_id,
1075 left_id,
1076 right_id,
1077 columns,
1078 rows,
1079 scalar.width,
1080 spirv::Op::FAdd,
1081 );
1082
1083 self.cached[expr_handle] = id;
1084 return Ok(());
1085 }
1086 crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FAdd,
1087 _ => unimplemented!(),
1088 },
1089 crate::BinaryOperator::Subtract => match *left_ty_inner {
1090 crate::TypeInner::Scalar(scalar)
1091 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1092 crate::ScalarKind::Float => spirv::Op::FSub,
1093 _ => spirv::Op::ISub,
1094 },
1095 crate::TypeInner::Matrix {
1096 columns,
1097 rows,
1098 scalar,
1099 } => {
1100 self.write_matrix_matrix_column_op(
1101 block,
1102 id,
1103 result_type_id,
1104 left_id,
1105 right_id,
1106 columns,
1107 rows,
1108 scalar.width,
1109 spirv::Op::FSub,
1110 );
1111
1112 self.cached[expr_handle] = id;
1113 return Ok(());
1114 }
1115 crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FSub,
1116 _ => unimplemented!(),
1117 },
1118 crate::BinaryOperator::Multiply => {
1119 match (left_dimension, right_dimension) {
1120 (Dimension::Scalar, Dimension::Vector) => {
1121 self.write_vector_scalar_mult(
1122 block,
1123 id,
1124 result_type_id,
1125 right_id,
1126 left_id,
1127 right_ty_inner,
1128 );
1129
1130 self.cached[expr_handle] = id;
1131 return Ok(());
1132 }
1133 (Dimension::Vector, Dimension::Scalar) => {
1134 self.write_vector_scalar_mult(
1135 block,
1136 id,
1137 result_type_id,
1138 left_id,
1139 right_id,
1140 left_ty_inner,
1141 );
1142
1143 self.cached[expr_handle] = id;
1144 return Ok(());
1145 }
1146 (Dimension::Vector, Dimension::Matrix) => {
1147 spirv::Op::VectorTimesMatrix
1148 }
1149 (Dimension::Matrix, Dimension::Scalar)
1150 | (Dimension::CooperativeMatrix, Dimension::Scalar) => {
1151 spirv::Op::MatrixTimesScalar
1152 }
1153 (Dimension::Scalar, Dimension::Matrix)
1154 | (Dimension::Scalar, Dimension::CooperativeMatrix) => {
1155 reverse_operands = true;
1156 spirv::Op::MatrixTimesScalar
1157 }
1158 (Dimension::Matrix, Dimension::Vector) => {
1159 spirv::Op::MatrixTimesVector
1160 }
1161 (Dimension::Matrix, Dimension::Matrix) => {
1162 spirv::Op::MatrixTimesMatrix
1163 }
1164 (Dimension::Vector, Dimension::Vector)
1165 | (Dimension::Scalar, Dimension::Scalar)
1166 if left_ty_inner.scalar_kind()
1167 == Some(crate::ScalarKind::Float) =>
1168 {
1169 spirv::Op::FMul
1170 }
1171 (Dimension::Vector, Dimension::Vector)
1172 | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
1173 (Dimension::CooperativeMatrix, Dimension::CooperativeMatrix)
1174 | (Dimension::CooperativeMatrix, _)
1176 | (_, Dimension::CooperativeMatrix) => {
1177 unimplemented!()
1178 }
1179 }
1180 }
1181 crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
1182 Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
1183 Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
1184 Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
1185 _ => unimplemented!(),
1186 },
1187 crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
1188 Some(crate::ScalarKind::Float) => spirv::Op::FRem,
1191 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1192 unreachable!("Should have been handled by wrapped function")
1193 }
1194 _ => unimplemented!(),
1195 },
1196 crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
1197 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1198 spirv::Op::IEqual
1199 }
1200 Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
1201 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
1202 _ => unimplemented!(),
1203 },
1204 crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
1205 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1206 spirv::Op::INotEqual
1207 }
1208 Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
1209 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
1210 _ => unimplemented!(),
1211 },
1212 crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
1213 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
1214 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
1215 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
1216 _ => unimplemented!(),
1217 },
1218 crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
1219 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
1220 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
1221 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
1222 _ => unimplemented!(),
1223 },
1224 crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
1225 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
1226 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
1227 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
1228 _ => unimplemented!(),
1229 },
1230 crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
1231 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
1232 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
1233 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
1234 _ => unimplemented!(),
1235 },
1236 crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
1237 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
1238 _ => spirv::Op::BitwiseAnd,
1239 },
1240 crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
1241 crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
1242 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
1243 _ => spirv::Op::BitwiseOr,
1244 },
1245 crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
1246 crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
1247 crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
1248 crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
1249 Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
1250 Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
1251 _ => unimplemented!(),
1252 },
1253 };
1254
1255 block.body.push(Instruction::binary(
1256 spirv_op,
1257 result_type_id,
1258 id,
1259 if reverse_operands { right_id } else { left_id },
1260 if reverse_operands { left_id } else { right_id },
1261 ));
1262 }
1263 id
1264 }
1265 crate::Expression::Math {
1266 fun,
1267 arg,
1268 arg1,
1269 arg2,
1270 arg3,
1271 } => {
1272 use crate::MathFunction as Mf;
1273 enum MathOp {
1274 Ext(spirv::GLOp),
1275 Custom(Instruction),
1276 }
1277
1278 let arg0_id = self.cached[arg];
1279 let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
1280 let arg_scalar_kind = arg_ty.scalar_kind();
1281 let arg1_id = match arg1 {
1282 Some(handle) => self.cached[handle],
1283 None => 0,
1284 };
1285 let arg2_id = match arg2 {
1286 Some(handle) => self.cached[handle],
1287 None => 0,
1288 };
1289 let arg3_id = match arg3 {
1290 Some(handle) => self.cached[handle],
1291 None => 0,
1292 };
1293
1294 let id = self.gen_id();
1295 let math_op = match fun {
1296 Mf::Abs => {
1298 match arg_scalar_kind {
1299 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs),
1300 Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs),
1301 Some(crate::ScalarKind::Uint) => {
1302 MathOp::Custom(Instruction::unary(
1303 spirv::Op::CopyObject, result_type_id,
1305 id,
1306 arg0_id,
1307 ))
1308 }
1309 other => unimplemented!("Unexpected abs({:?})", other),
1310 }
1311 }
1312 Mf::Min => MathOp::Ext(match arg_scalar_kind {
1313 Some(crate::ScalarKind::Float) => spirv::GLOp::FMin,
1314 Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin,
1315 Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin,
1316 other => unimplemented!("Unexpected min({:?})", other),
1317 }),
1318 Mf::Max => MathOp::Ext(match arg_scalar_kind {
1319 Some(crate::ScalarKind::Float) => spirv::GLOp::FMax,
1320 Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax,
1321 Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax,
1322 other => unimplemented!("Unexpected max({:?})", other),
1323 }),
1324 Mf::Clamp => match arg_scalar_kind {
1325 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FClamp),
1329 Some(_) => {
1330 let (min_op, max_op) = match arg_scalar_kind {
1331 Some(crate::ScalarKind::Sint) => {
1332 (spirv::GLOp::SMin, spirv::GLOp::SMax)
1333 }
1334 Some(crate::ScalarKind::Uint) => {
1335 (spirv::GLOp::UMin, spirv::GLOp::UMax)
1336 }
1337 _ => unreachable!(),
1338 };
1339
1340 let max_id = self.gen_id();
1341 block.body.push(Instruction::ext_inst_gl_op(
1342 self.writer.gl450_ext_inst_id,
1343 max_op,
1344 result_type_id,
1345 max_id,
1346 &[arg0_id, arg1_id],
1347 ));
1348
1349 MathOp::Custom(Instruction::ext_inst_gl_op(
1350 self.writer.gl450_ext_inst_id,
1351 min_op,
1352 result_type_id,
1353 id,
1354 &[max_id, arg2_id],
1355 ))
1356 }
1357 other => unimplemented!("Unexpected max({:?})", other),
1358 },
1359 Mf::Saturate => {
1360 let (maybe_size, scalar) = match *arg_ty {
1361 crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
1362 crate::TypeInner::Scalar(scalar) => (None, scalar),
1363 ref other => unimplemented!("Unexpected saturate({:?})", other),
1364 };
1365 let scalar = crate::Scalar::float(scalar.width);
1366 let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
1367 let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
1368
1369 if let Some(size) = maybe_size {
1370 let ty =
1371 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1372
1373 self.temp_list.clear();
1374 self.temp_list.resize(size as _, arg1_id);
1375
1376 arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
1377
1378 self.temp_list.fill(arg2_id);
1379
1380 arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
1381 }
1382
1383 MathOp::Custom(Instruction::ext_inst_gl_op(
1384 self.writer.gl450_ext_inst_id,
1385 spirv::GLOp::FClamp,
1386 result_type_id,
1387 id,
1388 &[arg0_id, arg1_id, arg2_id],
1389 ))
1390 }
1391 Mf::Sin => MathOp::Ext(spirv::GLOp::Sin),
1393 Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh),
1394 Mf::Asin => MathOp::Ext(spirv::GLOp::Asin),
1395 Mf::Cos => MathOp::Ext(spirv::GLOp::Cos),
1396 Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh),
1397 Mf::Acos => MathOp::Ext(spirv::GLOp::Acos),
1398 Mf::Tan => MathOp::Ext(spirv::GLOp::Tan),
1399 Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh),
1400 Mf::Atan => MathOp::Ext(spirv::GLOp::Atan),
1401 Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2),
1402 Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh),
1403 Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh),
1404 Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh),
1405 Mf::Radians => MathOp::Ext(spirv::GLOp::Radians),
1406 Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees),
1407 Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil),
1409 Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven),
1410 Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
1411 Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
1412 Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
1413 Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct),
1414 Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct),
1415 Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
1416 Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
1418 crate::TypeInner::Vector {
1419 scalar:
1420 crate::Scalar {
1421 kind: crate::ScalarKind::Float,
1422 ..
1423 },
1424 ..
1425 } => MathOp::Custom(Instruction::binary(
1426 spirv::Op::Dot,
1427 result_type_id,
1428 id,
1429 arg0_id,
1430 arg1_id,
1431 )),
1432 crate::TypeInner::Vector { size, .. } => {
1434 self.write_dot_product(
1435 id,
1436 result_type_id,
1437 arg0_id,
1438 arg1_id,
1439 size as u32,
1440 block,
1441 |result_id, composite_id, index| {
1442 Instruction::composite_extract(
1443 result_type_id,
1444 result_id,
1445 composite_id,
1446 &[index],
1447 )
1448 },
1449 );
1450 self.cached[expr_handle] = id;
1451 return Ok(());
1452 }
1453 _ => unreachable!(
1454 "Correct TypeInner for dot product should be already validated"
1455 ),
1456 },
1457 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
1458 if self
1459 .writer
1460 .require_all(&[
1461 spirv::Capability::DotProduct,
1462 spirv::Capability::DotProductInput4x8BitPacked,
1463 ])
1464 .is_ok()
1465 {
1466 if self.writer.lang_version() < (1, 6) {
1468 self.writer.use_extension("SPV_KHR_integer_dot_product");
1472 }
1473
1474 let op = match fun {
1475 Mf::Dot4I8Packed => spirv::Op::SDot,
1476 Mf::Dot4U8Packed => spirv::Op::UDot,
1477 _ => unreachable!(),
1478 };
1479
1480 block.body.push(Instruction::ternary(
1481 op,
1482 result_type_id,
1483 id,
1484 arg0_id,
1485 arg1_id,
1486 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word,
1487 ));
1488 } else {
1489 let (extract_op, arg0_id, arg1_id) = match fun {
1491 Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
1492 Mf::Dot4I8Packed => {
1493 let new_arg0_id = self.gen_id();
1496 block.body.push(Instruction::unary(
1497 spirv::Op::Bitcast,
1498 result_type_id,
1499 new_arg0_id,
1500 arg0_id,
1501 ));
1502
1503 let new_arg1_id = self.gen_id();
1504 block.body.push(Instruction::unary(
1505 spirv::Op::Bitcast,
1506 result_type_id,
1507 new_arg1_id,
1508 arg1_id,
1509 ));
1510
1511 (spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
1512 }
1513 _ => unreachable!(),
1514 };
1515
1516 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1517
1518 const VEC_LENGTH: u8 = 4;
1519 let bit_shifts: [_; VEC_LENGTH as usize] =
1520 core::array::from_fn(|index| {
1521 self.writer
1522 .get_constant_scalar(crate::Literal::U32(index as u32 * 8))
1523 });
1524
1525 self.write_dot_product(
1526 id,
1527 result_type_id,
1528 arg0_id,
1529 arg1_id,
1530 VEC_LENGTH as Word,
1531 block,
1532 |result_id, composite_id, index| {
1533 Instruction::ternary(
1534 extract_op,
1535 result_type_id,
1536 result_id,
1537 composite_id,
1538 bit_shifts[index as usize],
1539 eight,
1540 )
1541 },
1542 );
1543 }
1544
1545 self.cached[expr_handle] = id;
1546 return Ok(());
1547 }
1548 Mf::Outer => MathOp::Custom(Instruction::binary(
1549 spirv::Op::OuterProduct,
1550 result_type_id,
1551 id,
1552 arg0_id,
1553 arg1_id,
1554 )),
1555 Mf::Cross => MathOp::Ext(spirv::GLOp::Cross),
1556 Mf::Distance => MathOp::Ext(spirv::GLOp::Distance),
1557 Mf::Length => MathOp::Ext(spirv::GLOp::Length),
1558 Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize),
1559 Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward),
1560 Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect),
1561 Mf::Refract => MathOp::Ext(spirv::GLOp::Refract),
1562 Mf::Exp => MathOp::Ext(spirv::GLOp::Exp),
1564 Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2),
1565 Mf::Log => MathOp::Ext(spirv::GLOp::Log),
1566 Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2),
1567 Mf::Pow => MathOp::Ext(spirv::GLOp::Pow),
1568 Mf::Sign => MathOp::Ext(match arg_scalar_kind {
1570 Some(crate::ScalarKind::Float) => spirv::GLOp::FSign,
1571 Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign,
1572 other => unimplemented!("Unexpected sign({:?})", other),
1573 }),
1574 Mf::Fma => MathOp::Ext(spirv::GLOp::Fma),
1575 Mf::Mix => {
1576 let selector = arg2.unwrap();
1577 let selector_ty =
1578 self.fun_info[selector].ty.inner_with(&self.ir_module.types);
1579 match (arg_ty, selector_ty) {
1580 (
1582 &crate::TypeInner::Vector { size, .. },
1583 &crate::TypeInner::Scalar(scalar),
1584 ) => {
1585 let selector_type_id =
1586 self.get_numeric_type_id(NumericType::Vector { size, scalar });
1587 self.temp_list.clear();
1588 self.temp_list.resize(size as usize, arg2_id);
1589
1590 let selector_id = self.gen_id();
1591 block.body.push(Instruction::composite_construct(
1592 selector_type_id,
1593 selector_id,
1594 &self.temp_list,
1595 ));
1596
1597 MathOp::Custom(Instruction::ext_inst_gl_op(
1598 self.writer.gl450_ext_inst_id,
1599 spirv::GLOp::FMix,
1600 result_type_id,
1601 id,
1602 &[arg0_id, arg1_id, selector_id],
1603 ))
1604 }
1605 _ => MathOp::Ext(spirv::GLOp::FMix),
1606 }
1607 }
1608 Mf::Step => MathOp::Ext(spirv::GLOp::Step),
1609 Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep),
1610 Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt),
1611 Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt),
1612 Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse),
1613 Mf::Transpose => MathOp::Custom(Instruction::unary(
1614 spirv::Op::Transpose,
1615 result_type_id,
1616 id,
1617 arg0_id,
1618 )),
1619 Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
1620 Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
1621 spirv::Op::QuantizeToF16,
1622 result_type_id,
1623 id,
1624 arg0_id,
1625 )),
1626 Mf::ReverseBits => MathOp::Custom(Instruction::unary(
1627 spirv::Op::BitReverse,
1628 result_type_id,
1629 id,
1630 arg0_id,
1631 )),
1632 Mf::CountTrailingZeros => {
1633 let uint_id = match *arg_ty {
1634 crate::TypeInner::Vector { size, scalar } => {
1635 let ty =
1636 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1637
1638 self.temp_list.clear();
1639 self.temp_list.resize(
1640 size as _,
1641 self.writer
1642 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1643 );
1644
1645 self.writer.get_constant_composite(ty, &self.temp_list)
1646 }
1647 crate::TypeInner::Scalar(scalar) => self
1648 .writer
1649 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1650 _ => unreachable!(),
1651 };
1652
1653 let lsb_id = self.gen_id();
1654 block.body.push(Instruction::ext_inst_gl_op(
1655 self.writer.gl450_ext_inst_id,
1656 spirv::GLOp::FindILsb,
1657 result_type_id,
1658 lsb_id,
1659 &[arg0_id],
1660 ));
1661
1662 MathOp::Custom(Instruction::ext_inst_gl_op(
1663 self.writer.gl450_ext_inst_id,
1664 spirv::GLOp::UMin,
1665 result_type_id,
1666 id,
1667 &[uint_id, lsb_id],
1668 ))
1669 }
1670 Mf::CountLeadingZeros => {
1671 let (int_type_id, int_id, width) = match *arg_ty {
1672 crate::TypeInner::Vector { size, scalar } => {
1673 let ty =
1674 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1675
1676 self.temp_list.clear();
1677 self.temp_list.resize(
1678 size as _,
1679 self.writer
1680 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1681 );
1682
1683 (
1684 self.get_type_id(ty),
1685 self.writer.get_constant_composite(ty, &self.temp_list),
1686 scalar.width,
1687 )
1688 }
1689 crate::TypeInner::Scalar(scalar) => (
1690 self.get_numeric_type_id(NumericType::Scalar(scalar)),
1691 self.writer
1692 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1693 scalar.width,
1694 ),
1695 _ => unreachable!(),
1696 };
1697
1698 if width != 4 {
1699 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1700 };
1701
1702 let msb_id = self.gen_id();
1703 block.body.push(Instruction::ext_inst_gl_op(
1704 self.writer.gl450_ext_inst_id,
1705 if width != 4 {
1706 spirv::GLOp::FindILsb
1707 } else {
1708 spirv::GLOp::FindUMsb
1709 },
1710 int_type_id,
1711 msb_id,
1712 &[arg0_id],
1713 ));
1714
1715 MathOp::Custom(Instruction::binary(
1716 spirv::Op::ISub,
1717 result_type_id,
1718 id,
1719 int_id,
1720 msb_id,
1721 ))
1722 }
1723 Mf::CountOneBits => MathOp::Custom(Instruction::unary(
1724 spirv::Op::BitCount,
1725 result_type_id,
1726 id,
1727 arg0_id,
1728 )),
1729 Mf::ExtractBits => {
1730 let op = match arg_scalar_kind {
1731 Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
1732 Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
1733 other => unimplemented!("Unexpected sign({:?})", other),
1734 };
1735
1736 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1751 let width_constant = self
1752 .writer
1753 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1754
1755 let u32_type =
1756 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1757
1758 let offset_id = self.gen_id();
1760 block.body.push(Instruction::ext_inst_gl_op(
1761 self.writer.gl450_ext_inst_id,
1762 spirv::GLOp::UMin,
1763 u32_type,
1764 offset_id,
1765 &[arg1_id, width_constant],
1766 ));
1767
1768 let max_count_id = self.gen_id();
1770 block.body.push(Instruction::binary(
1771 spirv::Op::ISub,
1772 u32_type,
1773 max_count_id,
1774 width_constant,
1775 offset_id,
1776 ));
1777
1778 let count_id = self.gen_id();
1780 block.body.push(Instruction::ext_inst_gl_op(
1781 self.writer.gl450_ext_inst_id,
1782 spirv::GLOp::UMin,
1783 u32_type,
1784 count_id,
1785 &[arg2_id, max_count_id],
1786 ));
1787
1788 MathOp::Custom(Instruction::ternary(
1789 op,
1790 result_type_id,
1791 id,
1792 arg0_id,
1793 offset_id,
1794 count_id,
1795 ))
1796 }
1797 Mf::InsertBits => {
1798 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1801 let width_constant = self
1802 .writer
1803 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1804
1805 let u32_type =
1806 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1807
1808 let offset_id = self.gen_id();
1810 block.body.push(Instruction::ext_inst_gl_op(
1811 self.writer.gl450_ext_inst_id,
1812 spirv::GLOp::UMin,
1813 u32_type,
1814 offset_id,
1815 &[arg2_id, width_constant],
1816 ));
1817
1818 let max_count_id = self.gen_id();
1820 block.body.push(Instruction::binary(
1821 spirv::Op::ISub,
1822 u32_type,
1823 max_count_id,
1824 width_constant,
1825 offset_id,
1826 ));
1827
1828 let count_id = self.gen_id();
1830 block.body.push(Instruction::ext_inst_gl_op(
1831 self.writer.gl450_ext_inst_id,
1832 spirv::GLOp::UMin,
1833 u32_type,
1834 count_id,
1835 &[arg3_id, max_count_id],
1836 ));
1837
1838 MathOp::Custom(Instruction::quaternary(
1839 spirv::Op::BitFieldInsert,
1840 result_type_id,
1841 id,
1842 arg0_id,
1843 arg1_id,
1844 offset_id,
1845 count_id,
1846 ))
1847 }
1848 Mf::FirstTrailingBit => MathOp::Ext(spirv::GLOp::FindILsb),
1849 Mf::FirstLeadingBit => {
1850 if arg_ty.scalar_width() == Some(4) {
1851 let thing = match arg_scalar_kind {
1852 Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
1853 Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
1854 other => unimplemented!("Unexpected firstLeadingBit({:?})", other),
1855 };
1856 MathOp::Ext(thing)
1857 } else {
1858 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1859 }
1860 }
1861 Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
1862 Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
1863 Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
1864 Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16),
1865 Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
1866 fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => {
1867 let is_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp);
1868 let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp);
1869
1870 let last_instruction =
1871 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1872 self.write_pack4x8_optimized(
1873 block,
1874 result_type_id,
1875 arg0_id,
1876 id,
1877 is_signed,
1878 should_clamp,
1879 )
1880 } else {
1881 self.write_pack4x8_polyfill(
1882 block,
1883 result_type_id,
1884 arg0_id,
1885 id,
1886 is_signed,
1887 should_clamp,
1888 )
1889 };
1890
1891 MathOp::Custom(last_instruction)
1892 }
1893 Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
1894 Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
1895 Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
1896 Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16),
1897 Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
1898 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
1899 let is_signed = matches!(fun, Mf::Unpack4xI8);
1900
1901 let last_instruction =
1902 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1903 self.write_unpack4x8_optimized(
1904 block,
1905 result_type_id,
1906 arg0_id,
1907 id,
1908 is_signed,
1909 )
1910 } else {
1911 self.write_unpack4x8_polyfill(
1912 block,
1913 result_type_id,
1914 arg0_id,
1915 id,
1916 is_signed,
1917 )
1918 };
1919
1920 MathOp::Custom(last_instruction)
1921 }
1922 };
1923
1924 block.body.push(match math_op {
1925 MathOp::Ext(op) => Instruction::ext_inst_gl_op(
1926 self.writer.gl450_ext_inst_id,
1927 op,
1928 result_type_id,
1929 id,
1930 &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
1931 ),
1932 MathOp::Custom(inst) => inst,
1933 });
1934 id
1935 }
1936 crate::Expression::LocalVariable(variable) => {
1937 if let Some(rq_tracker) = self
1938 .function
1939 .ray_query_initialization_tracker_variables
1940 .get(&variable)
1941 {
1942 self.ray_query_tracker_expr.insert(
1943 expr_handle,
1944 super::RayQueryTrackers {
1945 initialized_tracker: rq_tracker.id,
1946 t_max_tracker: self
1947 .function
1948 .ray_query_t_max_tracker_variables
1949 .get(&variable)
1950 .expect("Both trackers are set at the same time.")
1951 .id,
1952 },
1953 );
1954 }
1955 self.function.variables[&variable].id
1956 }
1957 crate::Expression::Load { pointer } => {
1958 self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)?
1959 }
1960 crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
1961 crate::Expression::CallResult(_)
1962 | crate::Expression::AtomicResult { .. }
1963 | crate::Expression::WorkGroupUniformLoadResult { .. }
1964 | crate::Expression::RayQueryProceedResult
1965 | crate::Expression::SubgroupBallotResult
1966 | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
1967 crate::Expression::As {
1968 expr,
1969 kind,
1970 convert,
1971 } => self.write_as_expression(expr, convert, kind, block, result_type_id)?,
1972 crate::Expression::ImageLoad {
1973 image,
1974 coordinate,
1975 array_index,
1976 sample,
1977 level,
1978 } => self.write_image_load(
1979 result_type_id,
1980 image,
1981 coordinate,
1982 array_index,
1983 level,
1984 sample,
1985 block,
1986 )?,
1987 crate::Expression::ImageSample {
1988 image,
1989 sampler,
1990 gather,
1991 coordinate,
1992 array_index,
1993 offset,
1994 level,
1995 depth_ref,
1996 clamp_to_edge,
1997 } => self.write_image_sample(
1998 result_type_id,
1999 image,
2000 sampler,
2001 gather,
2002 coordinate,
2003 array_index,
2004 offset,
2005 level,
2006 depth_ref,
2007 clamp_to_edge,
2008 block,
2009 )?,
2010 crate::Expression::Select {
2011 condition,
2012 accept,
2013 reject,
2014 } => {
2015 let id = self.gen_id();
2016 let mut condition_id = self.cached[condition];
2017 let accept_id = self.cached[accept];
2018 let reject_id = self.cached[reject];
2019
2020 let condition_ty = self.fun_info[condition]
2021 .ty
2022 .inner_with(&self.ir_module.types);
2023 let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
2024
2025 if let (
2026 &crate::TypeInner::Scalar(
2027 condition_scalar @ crate::Scalar {
2028 kind: crate::ScalarKind::Bool,
2029 ..
2030 },
2031 ),
2032 &crate::TypeInner::Vector { size, .. },
2033 ) = (condition_ty, object_ty)
2034 {
2035 self.temp_list.clear();
2036 self.temp_list.resize(size as usize, condition_id);
2037
2038 let bool_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2039 size,
2040 scalar: condition_scalar,
2041 });
2042
2043 let id = self.gen_id();
2044 block.body.push(Instruction::composite_construct(
2045 bool_vector_type_id,
2046 id,
2047 &self.temp_list,
2048 ));
2049 condition_id = id
2050 }
2051
2052 let instruction =
2053 Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
2054 block.body.push(instruction);
2055 id
2056 }
2057 crate::Expression::Derivative { axis, ctrl, expr } => {
2058 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
2059 match ctrl {
2060 Ctrl::Coarse | Ctrl::Fine => {
2061 self.writer.require_any(
2062 "DerivativeControl",
2063 &[spirv::Capability::DerivativeControl],
2064 )?;
2065 }
2066 Ctrl::None => {}
2067 }
2068 let id = self.gen_id();
2069 let expr_id = self.cached[expr];
2070 let op = match (axis, ctrl) {
2071 (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
2072 (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
2073 (Axis::X, Ctrl::None) => spirv::Op::DPdx,
2074 (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
2075 (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
2076 (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
2077 (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
2078 (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
2079 (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
2080 };
2081 block
2082 .body
2083 .push(Instruction::derivative(op, result_type_id, id, expr_id));
2084 id
2085 }
2086 crate::Expression::ImageQuery { image, query } => {
2087 self.write_image_query(result_type_id, image, query, block)?
2088 }
2089 crate::Expression::Relational { fun, argument } => {
2090 use crate::RelationalFunction as Rf;
2091 let arg_id = self.cached[argument];
2092 let op = match fun {
2093 Rf::All => spirv::Op::All,
2094 Rf::Any => spirv::Op::Any,
2095 Rf::IsNan => spirv::Op::IsNan,
2096 Rf::IsInf => spirv::Op::IsInf,
2097 };
2098 let id = self.gen_id();
2099 block
2100 .body
2101 .push(Instruction::relational(op, result_type_id, id, arg_id));
2102 id
2103 }
2104 crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
2105 crate::Expression::RayQueryGetIntersection { query, committed } => {
2106 let query_id = self.cached[query];
2107 let init_tracker_id = *self
2108 .ray_query_tracker_expr
2109 .get(&query)
2110 .expect("not a cached ray query");
2111 let func_id = self
2112 .writer
2113 .write_ray_query_get_intersection_function(committed, self.ir_module);
2114 let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap();
2115 let intersection_type_id = self.get_handle_type_id(ray_intersection);
2116 let id = self.gen_id();
2117 block.body.push(Instruction::function_call(
2118 intersection_type_id,
2119 id,
2120 func_id,
2121 &[query_id, init_tracker_id.initialized_tracker],
2122 ));
2123 id
2124 }
2125 crate::Expression::RayQueryVertexPositions { query, committed } => {
2126 self.writer.require_any(
2127 "RayQueryVertexPositions",
2128 &[spirv::Capability::RayQueryPositionFetchKHR],
2129 )?;
2130 self.write_ray_query_return_vertex_position(query, block, committed)
2131 }
2132 crate::Expression::CooperativeLoad { ref data, .. } => {
2133 self.writer.require_any(
2134 "CooperativeMatrix",
2135 &[spirv::Capability::CooperativeMatrixKHR],
2136 )?;
2137 let layout = if data.row_major {
2138 spirv::CooperativeMatrixLayout::RowMajorKHR
2139 } else {
2140 spirv::CooperativeMatrixLayout::ColumnMajorKHR
2141 };
2142 let layout_id = self.get_index_constant(layout as u32);
2143 let stride_id = self.cached[data.stride];
2144 match self.write_access_chain(data.pointer, block, AccessTypeAdjustment::None)? {
2145 ExpressionPointer::Ready { pointer_id } => {
2146 let id = self.gen_id();
2147 block.body.push(Instruction::coop_load(
2148 result_type_id,
2149 id,
2150 pointer_id,
2151 layout_id,
2152 stride_id,
2153 ));
2154 id
2155 }
2156 ExpressionPointer::Conditional { condition, access } => self
2157 .write_conditional_indexed_load(
2158 result_type_id,
2159 condition,
2160 block,
2161 |id_gen, block| {
2162 let pointer_id = access.result_id.unwrap();
2163 block.body.push(access);
2164 let id = id_gen.next();
2165 block.body.push(Instruction::coop_load(
2166 result_type_id,
2167 id,
2168 pointer_id,
2169 layout_id,
2170 stride_id,
2171 ));
2172 id
2173 },
2174 ),
2175 }
2176 }
2177 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2178 self.writer.require_any(
2179 "CooperativeMatrix",
2180 &[spirv::Capability::CooperativeMatrixKHR],
2181 )?;
2182 let a_id = self.cached[a];
2183 let b_id = self.cached[b];
2184 let c_id = self.cached[c];
2185 let id = self.gen_id();
2186 block.body.push(Instruction::coop_mul_add(
2187 result_type_id,
2188 id,
2189 a_id,
2190 b_id,
2191 c_id,
2192 ));
2193 id
2194 }
2195 };
2196
2197 self.cached[expr_handle] = id;
2198 Ok(())
2199 }
2200
2201 fn write_as_expression(
2204 &mut self,
2205 expr: Handle<crate::Expression>,
2206 convert: Option<u8>,
2207 kind: crate::ScalarKind,
2208
2209 block: &mut Block,
2210 result_type_id: u32,
2211 ) -> Result<u32, Error> {
2212 use crate::ScalarKind as Sk;
2213 let expr_id = self.cached[expr];
2214 let ty = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
2215
2216 if let crate::TypeInner::Matrix {
2221 columns,
2222 rows,
2223 scalar,
2224 } = *ty
2225 {
2226 let Some(convert) = convert else {
2227 return Ok(expr_id);
2229 };
2230
2231 if convert == scalar.width {
2232 return Ok(expr_id);
2234 }
2235
2236 if kind != Sk::Float {
2237 return Err(Error::Validation("Matrices must be floats"));
2239 }
2240
2241 let column_src_ty =
2243 self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
2244 size: rows,
2245 scalar,
2246 })));
2247
2248 let column_dst_ty =
2250 self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
2251 size: rows,
2252 scalar: crate::Scalar {
2253 kind,
2254 width: convert,
2255 },
2256 })));
2257
2258 let mut components = ArrayVec::<Word, 4>::new();
2259
2260 for column in 0..columns as usize {
2261 let column_id = self.gen_id();
2262 block.body.push(Instruction::composite_extract(
2263 column_src_ty,
2264 column_id,
2265 expr_id,
2266 &[column as u32],
2267 ));
2268
2269 let column_conv_id = self.gen_id();
2270 block.body.push(Instruction::unary(
2271 spirv::Op::FConvert,
2272 column_dst_ty,
2273 column_conv_id,
2274 column_id,
2275 ));
2276
2277 components.push(column_conv_id);
2278 }
2279
2280 let construct_id = self.gen_id();
2281
2282 block.body.push(Instruction::composite_construct(
2283 result_type_id,
2284 construct_id,
2285 &components,
2286 ));
2287
2288 return Ok(construct_id);
2289 }
2290
2291 let (src_scalar, src_size) = match *ty {
2292 crate::TypeInner::Scalar(scalar) => (scalar, None),
2293 crate::TypeInner::Vector { scalar, size } => (scalar, Some(size)),
2294 ref other => {
2295 log::error!("As source {other:?}");
2296 return Err(Error::Validation("Unexpected Expression::As source"));
2297 }
2298 };
2299
2300 enum Cast {
2301 Identity(Word),
2302 Unary(spirv::Op, Word),
2303 Binary(spirv::Op, Word, Word),
2304 Ternary(spirv::Op, Word, Word, Word),
2305 }
2306 let cast = match (src_scalar.kind, kind, convert) {
2307 (src_kind, kind, convert)
2310 if src_kind == kind
2311 && convert.filter(|&width| width != src_scalar.width).is_none() =>
2312 {
2313 Cast::Identity(expr_id)
2314 }
2315 (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject, expr_id),
2316 (_, _, None) => Cast::Unary(spirv::Op::Bitcast, expr_id),
2317 (_, Sk::Bool, Some(_)) => {
2319 let op = match src_scalar.kind {
2320 Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
2321 Sk::Float => spirv::Op::FUnordNotEqual,
2322 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(),
2323 };
2324 let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?;
2325 let zero_id = match src_size {
2326 Some(size) => {
2327 let ty = LocalType::Numeric(NumericType::Vector {
2328 size,
2329 scalar: src_scalar,
2330 })
2331 .into();
2332
2333 self.temp_list.clear();
2334 self.temp_list.resize(size as _, zero_scalar_id);
2335
2336 self.writer.get_constant_composite(ty, &self.temp_list)
2337 }
2338 None => zero_scalar_id,
2339 };
2340
2341 Cast::Binary(op, expr_id, zero_id)
2342 }
2343 (Sk::Bool, _, Some(dst_width)) => {
2345 let dst_scalar = crate::Scalar {
2346 kind,
2347 width: dst_width,
2348 };
2349 let zero_scalar_id = self.writer.get_constant_scalar_with(0, dst_scalar)?;
2350 let one_scalar_id = self.writer.get_constant_scalar_with(1, dst_scalar)?;
2351 let (accept_id, reject_id) = match src_size {
2352 Some(size) => {
2353 let ty = LocalType::Numeric(NumericType::Vector {
2354 size,
2355 scalar: dst_scalar,
2356 })
2357 .into();
2358
2359 self.temp_list.clear();
2360 self.temp_list.resize(size as _, zero_scalar_id);
2361
2362 let vec0_id = self.writer.get_constant_composite(ty, &self.temp_list);
2363
2364 self.temp_list.fill(one_scalar_id);
2365
2366 let vec1_id = self.writer.get_constant_composite(ty, &self.temp_list);
2367
2368 (vec1_id, vec0_id)
2369 }
2370 None => (one_scalar_id, zero_scalar_id),
2371 };
2372
2373 Cast::Ternary(spirv::Op::Select, expr_id, accept_id, reject_id)
2374 }
2375 (Sk::Float, Sk::Sint | Sk::Uint, Some(width)) => {
2386 let dst_scalar = crate::Scalar { kind, width };
2387 let (min, max) =
2388 crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
2389 let expr_type_id = self.get_expression_type_id(&self.fun_info[expr].ty);
2390
2391 let maybe_splat_const = |writer: &mut Writer, const_id| match src_size {
2392 None => const_id,
2393 Some(size) => {
2394 let constituent_ids = [const_id; crate::VectorSize::MAX];
2395 writer.get_constant_composite(
2396 LookupType::Local(LocalType::Numeric(NumericType::Vector {
2397 size,
2398 scalar: src_scalar,
2399 })),
2400 &constituent_ids[..size as usize],
2401 )
2402 }
2403 };
2404 let min_const_id = self.writer.get_constant_scalar(min);
2405 let min_const_id = maybe_splat_const(self.writer, min_const_id);
2406 let max_const_id = self.writer.get_constant_scalar(max);
2407 let max_const_id = maybe_splat_const(self.writer, max_const_id);
2408
2409 let clamp_id = self.gen_id();
2410 block.body.push(Instruction::ext_inst_gl_op(
2411 self.writer.gl450_ext_inst_id,
2412 spirv::GLOp::FClamp,
2413 expr_type_id,
2414 clamp_id,
2415 &[expr_id, min_const_id, max_const_id],
2416 ));
2417
2418 let op = match dst_scalar.kind {
2419 crate::ScalarKind::Sint => spirv::Op::ConvertFToS,
2420 crate::ScalarKind::Uint => spirv::Op::ConvertFToU,
2421 _ => unreachable!(),
2422 };
2423 Cast::Unary(op, clamp_id)
2424 }
2425 (Sk::Float, Sk::Float, Some(dst_width)) if src_scalar.width != dst_width => {
2426 Cast::Unary(spirv::Op::FConvert, expr_id)
2427 }
2428 (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF, expr_id),
2429 (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2430 Cast::Unary(spirv::Op::SConvert, expr_id)
2431 }
2432 (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF, expr_id),
2433 (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2434 Cast::Unary(spirv::Op::UConvert, expr_id)
2435 }
2436 (Sk::Uint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2437 Cast::Unary(spirv::Op::SConvert, expr_id)
2438 }
2439 (Sk::Sint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2440 Cast::Unary(spirv::Op::UConvert, expr_id)
2441 }
2442 _ => Cast::Unary(spirv::Op::Bitcast, expr_id),
2444 };
2445 Ok(match cast {
2446 Cast::Identity(expr) => expr,
2447 Cast::Unary(op, op1) => {
2448 let id = self.gen_id();
2449 block
2450 .body
2451 .push(Instruction::unary(op, result_type_id, id, op1));
2452 id
2453 }
2454 Cast::Binary(op, op1, op2) => {
2455 let id = self.gen_id();
2456 block
2457 .body
2458 .push(Instruction::binary(op, result_type_id, id, op1, op2));
2459 id
2460 }
2461 Cast::Ternary(op, op1, op2, op3) => {
2462 let id = self.gen_id();
2463 block
2464 .body
2465 .push(Instruction::ternary(op, result_type_id, id, op1, op2, op3));
2466 id
2467 }
2468 })
2469 }
2470
/// Build the access chain for the pointer expression `expr_handle`.
///
/// Starting at `expr_handle`, this walks `Access` / `AccessIndex` expressions
/// back towards their root, pushing one index id per step onto
/// `self.temp_list`. The walk stops at a spilled composite, a global
/// variable, a local variable, or a function argument, whose id becomes the
/// base of the chain.
///
/// `type_adjustment` selects how the result type id of the access chain
/// instruction is derived from the type of `expr_handle`.
///
/// Returns:
/// - `ExpressionPointer::Ready` when the pointer id can be used directly.
///   If any indices were collected, the access chain instruction has already
///   been pushed onto `block`.
/// - `ExpressionPointer::Conditional` when at least one index produced a
///   runtime bounds-check condition. In that case the access chain
///   instruction is *not* pushed; it is handed back together with the
///   combined condition for the caller to emit.
///
/// # Panics
///
/// Panics if the walk reaches an expression kind other than the ones listed
/// above, or if `UseStd140CompatType` is requested for a type that is not a
/// pointer in the uniform address space.
2471 fn write_access_chain(
2482 &mut self,
2483 mut expr_handle: Handle<crate::Expression>,
2484 block: &mut Block,
2485 type_adjustment: AccessTypeAdjustment,
2486 ) -> Result<ExpressionPointer, Error> {
// The result type is computed up front from the outermost expression,
// before `expr_handle` is overwritten by the walk below.
2487 let result_type_id = {
2488 let resolution = &self.fun_info[expr_handle].ty;
2489 match type_adjustment {
2490 AccessTypeAdjustment::None => self.writer.get_expression_type_id(resolution),
2491 AccessTypeAdjustment::IntroducePointer(class) => {
2492 self.writer.get_resolution_pointer_id(resolution, class)
2493 }
2494 AccessTypeAdjustment::UseStd140CompatType => {
2495 match *resolution.inner_with(&self.ir_module.types) {
2496 crate::TypeInner::Pointer {
2497 base,
2498 space: space @ crate::AddressSpace::Uniform,
// Point at the std140-compatible replacement type registered for `base`
// instead of `base` itself.
2499 } => self.writer.get_pointer_type_id(
2500 self.writer.std140_compat_uniform_types[&base].type_id,
2501 map_storage_class(space),
2502 ),
2503 _ => unreachable!(
2504 "`UseStd140CompatType` must only be used with uniform pointer types"
2505 ),
2506 }
2507 }
2508 }
2509 };
2510
// Combined bounds-check condition of every conditionally checked index seen
// so far; stays `None` if no index needed a runtime condition.
2511 let mut accumulated_checks = None;
2515
// Becomes true if any `Access` step indexes a binding-array global with a
// non-uniform index.
2516 let mut is_non_uniform_binding_array = false;
2518
// Index recorded by an `AccessIndex` step for which
// `is_uniform_matcx2_struct_member_access` holds. No access chain index is
// emitted for that step; the value is added to the remapped member index
// of the next struct step instead.
2519 let mut prev_decomposed_matrix_index = None;
2525
// Indices are collected while walking from the outermost access towards the
// root, i.e. in reverse of the order the access chain needs.
2526 self.temp_list.clear();
2527 let root_id = loop {
// A spilled composite terminates the walk: its variable id is the base.
2528 if let Some(spilled) = self.function.spilled_composites.get(&expr_handle) {
2531 break spilled.id;
2534 }
2535
2536 expr_handle = match self.ir_function.expressions[expr_handle] {
2537 crate::Expression::Access { base, index } => {
2538 is_non_uniform_binding_array |=
2539 self.is_nonuniform_binding_array_access(base, index);
2540
// Dynamic index: bounds-check it and record the resulting index id.
2541 let index = GuardedIndex::Expression(index);
2542 let index_id =
2543 self.write_access_chain_index(base, index, &mut accumulated_checks, block)?;
2544 self.temp_list.push(index_id);
2545
2546 base
2547 }
2548 crate::Expression::AccessIndex { base, index } => {
// Look through a pointer so that `base_ty` / `base_ty_handle` describe the
// pointee, remembering the address space of the pointer if there was one.
2549 let mut base_ty = self.fun_info[base].ty.inner_with(&self.ir_module.types);
2552 let mut base_ty_handle = self.fun_info[base].ty.handle();
2553 let mut pointer_space = None;
2554 if let crate::TypeInner::Pointer { base, space } = *base_ty {
2555 base_ty = &self.ir_module.types[base].inner;
2556 base_ty_handle = Some(base);
2557 pointer_space = Some(space);
2558 }
2559 match *base_ty {
// Struct members are addressed with a plain constant and are not passed
// through the bounds-check helper. When the struct has a std140-compatible
// replacement type and is reached through a uniform pointer, the member
// index is remapped and any pending `prev_decomposed_matrix_index` is
// added to it (and consumed).
2560 crate::TypeInner::Struct { .. } => {
2567 let index = match base_ty_handle.and_then(|handle| {
2568 self.writer.std140_compat_uniform_types.get(&handle)
2569 }) {
2570 Some(std140_type_info)
2571 if pointer_space == Some(crate::AddressSpace::Uniform) =>
2572 {
2573 std140_type_info.member_indices[index as usize]
2574 + prev_decomposed_matrix_index.take().unwrap_or(0)
2575 }
2576 _ => index,
2577 };
2578 let index_id = self.get_index_constant(index);
2579 self.temp_list.push(index_id);
2580 }
// NOTE(review): presumably this is a column access into a matCx2 struct
// member that was decomposed into separate vectors in the std140 layout;
// confirm against `is_uniform_matcx2_struct_member_access`. Only one such
// index may be pending at a time.
2581 _ if is_uniform_matcx2_struct_member_access(
2588 self.ir_function,
2589 self.fun_info,
2590 self.ir_module,
2591 base,
2592 ) =>
2593 {
2594 assert!(prev_decomposed_matrix_index.is_none());
2595 prev_decomposed_matrix_index = Some(index);
2596 }
// Any other base type: treat the constant index like a dynamic one and run
// it through the bounds-check helper.
2597 _ => {
2598 let index_id = self.write_access_chain_index(
2605 base,
2606 GuardedIndex::Known(index),
2607 &mut accumulated_checks,
2608 block,
2609 )?;
2610 self.temp_list.push(index_id);
2611 }
2612 }
2613 base
2614 }
// The remaining arms are the possible roots of a pointer expression.
2615 crate::Expression::GlobalVariable(handle) => {
2616 let gv = &self.writer.global_variables[handle];
2617 break gv.access_id;
2618 }
2619 crate::Expression::LocalVariable(variable) => {
2620 let local_var = &self.function.variables[&variable];
2621 break local_var.id;
2622 }
2623 crate::Expression::FunctionArgument(index) => {
2624 break self.function.parameter_id(index);
2625 }
2626 ref other => unimplemented!("Unexpected pointer expression {:?}", other),
2627 }
2628 };
2629
// With no indices collected, the root id already is the requested pointer
// and no instruction is needed.
2630 let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
2631 (
2632 root_id,
2633 ExpressionPointer::Ready {
2634 pointer_id: root_id,
2635 },
2636 )
2637 } else {
// Put the indices into root-to-leaf order before building the instruction.
2638 self.temp_list.reverse();
2639 let pointer_id = self.gen_id();
2640 let access =
2641 Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);
2642
// If a bounds-check condition was accumulated, hand the unemitted access
// back to the caller; otherwise emit it here.
2643 let expr_pointer = match accumulated_checks {
2648 Some(condition) => ExpressionPointer::Conditional { condition, access },
2649 None => {
2650 block.body.push(access);
2651 ExpressionPointer::Ready { pointer_id }
2652 }
2653 };
2654 (pointer_id, expr_pointer)
2655 };
// Decorate the resulting pointer id when a binding array was indexed
// non-uniformly. This applies to both the ready and the conditional case.
2656 if is_non_uniform_binding_array {
2660 self.writer
2661 .decorate_non_uniform_binding_array_access(pointer_id)?;
2662 }
2663
2664 Ok(expr_pointer)
2665 }
2666
2667 fn is_nonuniform_binding_array_access(
2668 &mut self,
2669 base: Handle<crate::Expression>,
2670 index: Handle<crate::Expression>,
2671 ) -> bool {
2672 let crate::Expression::GlobalVariable(var_handle) = self.ir_function.expressions[base]
2673 else {
2674 return false;
2675 };
2676
2677 let gvar = &self.ir_module.global_variables[var_handle];
2680 let crate::TypeInner::BindingArray { .. } = self.ir_module.types[gvar.ty].inner else {
2681 return false;
2682 };
2683
2684 self.fun_info[index].uniformity.non_uniform_result.is_some()
2685 }
2686
2687 fn write_access_chain_index(
2697 &mut self,
2698 base: Handle<crate::Expression>,
2699 index: GuardedIndex,
2700 accumulated_checks: &mut Option<Word>,
2701 block: &mut Block,
2702 ) -> Result<Word, Error> {
2703 match self.write_bounds_check(base, index, block)? {
2704 BoundsCheckResult::KnownInBounds(known_index) => {
2705 let scalar = crate::Literal::U32(known_index);
2708 Ok(self.writer.get_constant_scalar(scalar))
2709 }
2710 BoundsCheckResult::Computed(computed_index_id) => Ok(computed_index_id),
2711 BoundsCheckResult::Conditional {
2712 condition_id: condition,
2713 index_id: index,
2714 } => {
2715 self.extend_bounds_check_condition_chain(accumulated_checks, condition, block);
2716
2717 Ok(index)
2719 }
2720 }
2721 }
2722
2723 fn extend_bounds_check_condition_chain(
2742 &mut self,
2743 chain: &mut Option<Word>,
2744 comparison_id: Word,
2745 block: &mut Block,
2746 ) {
2747 match *chain {
2748 Some(ref mut prior_checks) => {
2749 let combined = self.gen_id();
2750 block.body.push(Instruction::binary(
2751 spirv::Op::LogicalAnd,
2752 self.writer.get_bool_type_id(),
2753 combined,
2754 *prior_checks,
2755 comparison_id,
2756 ));
2757 *prior_checks = combined;
2758 }
2759 None => {
2760 *chain = Some(comparison_id);
2762 }
2763 }
2764 }
2765
2766 fn write_checked_load(
2767 &mut self,
2768 pointer: Handle<crate::Expression>,
2769 block: &mut Block,
2770 access_type_adjustment: AccessTypeAdjustment,
2771 result_type_id: Word,
2772 ) -> Result<Word, Error> {
2773 if let Some(result_id) = self.maybe_write_uniform_matcx2_dynamic_access(pointer, block)? {
2774 Ok(result_id)
2775 } else if let Some(result_id) =
2776 self.maybe_write_load_uniform_matcx2_struct_member(pointer, block)?
2777 {
2778 Ok(result_id)
2779 } else {
2780 struct WrappedLoad {
2786 access_type_adjustment: AccessTypeAdjustment,
2787 r#type: Handle<crate::Type>,
2788 }
2789 let mut wrapped_load = None;
2790 if let crate::TypeInner::Pointer {
2791 base: pointer_base_type,
2792 space: crate::AddressSpace::Uniform,
2793 } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
2794 {
2795 if self
2796 .writer
2797 .std140_compat_uniform_types
2798 .contains_key(&pointer_base_type)
2799 {
2800 wrapped_load = Some(WrappedLoad {
2801 access_type_adjustment: AccessTypeAdjustment::UseStd140CompatType,
2802 r#type: pointer_base_type,
2803 });
2804 };
2805 };
2806
2807 let (load_type_id, access_type_adjustment) = match wrapped_load {
2808 Some(ref wrapped_load) => (
2809 self.writer.std140_compat_uniform_types[&wrapped_load.r#type].type_id,
2810 wrapped_load.access_type_adjustment,
2811 ),
2812 None => (result_type_id, access_type_adjustment),
2813 };
2814
2815 let load_id = match self.write_access_chain(pointer, block, access_type_adjustment)? {
2816 ExpressionPointer::Ready { pointer_id } => {
2817 let id = self.gen_id();
2818 let atomic_space =
2819 match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
2820 crate::TypeInner::Pointer { base, space } => {
2821 match self.ir_module.types[base].inner {
2822 crate::TypeInner::Atomic { .. } => Some(space),
2823 _ => None,
2824 }
2825 }
2826 _ => None,
2827 };
2828 let instruction = if let Some(space) = atomic_space {
2829 let (semantics, scope) = space.to_spirv_semantics_and_scope();
2830 let scope_constant_id = self.get_scope_constant(scope as u32);
2831 let semantics_id = self.get_index_constant(semantics.bits());
2832 Instruction::atomic_load(
2833 result_type_id,
2834 id,
2835 pointer_id,
2836 scope_constant_id,
2837 semantics_id,
2838 )
2839 } else {
2840 Instruction::load(load_type_id, id, pointer_id, None)
2841 };
2842 block.body.push(instruction);
2843 id
2844 }
2845 ExpressionPointer::Conditional { condition, access } => {
2846 self.write_conditional_indexed_load(
2848 load_type_id,
2849 condition,
2850 block,
2851 move |id_gen, block| {
2852 let pointer_id = access.result_id.unwrap();
2854 let value_id = id_gen.next();
2855 block.body.push(access);
2856 block.body.push(Instruction::load(
2857 load_type_id,
2858 value_id,
2859 pointer_id,
2860 None,
2861 ));
2862 value_id
2863 },
2864 )
2865 }
2866 };
2867
2868 match wrapped_load {
2869 Some(ref wrapped_load) => {
2870 let result_id = self.gen_id();
2873 let function_id = self.writer.wrapped_functions
2874 [&WrappedFunction::ConvertFromStd140CompatType {
2875 r#type: wrapped_load.r#type,
2876 }];
2877 block.body.push(Instruction::function_call(
2878 result_type_id,
2879 result_id,
2880 function_id,
2881 &[load_id],
2882 ));
2883 Ok(result_id)
2884 }
2885 None => Ok(load_id),
2886 }
2887 }
2888 }
2889
2890 fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) {
2891 use indexmap::map::Entry;
2892
2893 let spill_variable_id = match self.function.spilled_composites.entry(base) {
2895 Entry::Occupied(preexisting) => preexisting.get().id,
2896 Entry::Vacant(vacant) => {
2897 let pointer_type_id = self.writer.get_resolution_pointer_id(
2900 &self.fun_info[base].ty,
2901 spirv::StorageClass::Function,
2902 );
2903 let id = self.writer.id_gen.next();
2904 vacant.insert(super::LocalVariable {
2905 id,
2906 instruction: Instruction::variable(
2907 pointer_type_id,
2908 id,
2909 spirv::StorageClass::Function,
2910 None,
2911 ),
2912 });
2913 id
2914 }
2915 };
2916
2917 let base_id = self.cached[base];
2942 block
2943 .body
2944 .push(Instruction::store(spill_variable_id, base_id, None));
2945 }
2946
2947 fn maybe_access_spilled_composite(
2964 &mut self,
2965 access: Handle<crate::Expression>,
2966 block: &mut Block,
2967 result_type_id: Word,
2968 ) -> Result<Word, Error> {
2969 let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r);
2970 if access_uses == self.fun_info[access].ref_count {
2971 Ok(0)
2975 } else {
2976 self.write_checked_load(
2981 access,
2982 block,
2983 AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::Function),
2984 result_type_id,
2985 )
2986 }
2987 }
2988
2989 #[allow(clippy::too_many_arguments)]
2991 fn write_matrix_matrix_column_op(
2992 &mut self,
2993 block: &mut Block,
2994 result_id: Word,
2995 result_type_id: Word,
2996 left_id: Word,
2997 right_id: Word,
2998 columns: crate::VectorSize,
2999 rows: crate::VectorSize,
3000 width: u8,
3001 op: spirv::Op,
3002 ) {
3003 self.temp_list.clear();
3004
3005 let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3006 size: rows,
3007 scalar: crate::Scalar::float(width),
3008 });
3009
3010 for index in 0..columns as u32 {
3011 let column_id_left = self.gen_id();
3012 let column_id_right = self.gen_id();
3013 let column_id_res = self.gen_id();
3014
3015 block.body.push(Instruction::composite_extract(
3016 vector_type_id,
3017 column_id_left,
3018 left_id,
3019 &[index],
3020 ));
3021 block.body.push(Instruction::composite_extract(
3022 vector_type_id,
3023 column_id_right,
3024 right_id,
3025 &[index],
3026 ));
3027 block.body.push(Instruction::binary(
3028 op,
3029 vector_type_id,
3030 column_id_res,
3031 column_id_left,
3032 column_id_right,
3033 ));
3034
3035 self.temp_list.push(column_id_res);
3036 }
3037
3038 block.body.push(Instruction::composite_construct(
3039 result_type_id,
3040 result_id,
3041 &self.temp_list,
3042 ));
3043 }
3044
3045 fn write_vector_scalar_mult(
3047 &mut self,
3048 block: &mut Block,
3049 result_id: Word,
3050 result_type_id: Word,
3051 vector_id: Word,
3052 scalar_id: Word,
3053 vector: &crate::TypeInner,
3054 ) {
3055 let (size, kind) = match *vector {
3056 crate::TypeInner::Vector {
3057 size,
3058 scalar: crate::Scalar { kind, .. },
3059 } => (size, kind),
3060 _ => unreachable!(),
3061 };
3062
3063 let (op, operand_id) = match kind {
3064 crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
3065 _ => {
3066 let operand_id = self.gen_id();
3067 self.temp_list.clear();
3068 self.temp_list.resize(size as usize, scalar_id);
3069 block.body.push(Instruction::composite_construct(
3070 result_type_id,
3071 operand_id,
3072 &self.temp_list,
3073 ));
3074 (spirv::Op::IMul, operand_id)
3075 }
3076 };
3077
3078 block.body.push(Instruction::binary(
3079 op,
3080 result_type_id,
3081 result_id,
3082 vector_id,
3083 operand_id,
3084 ));
3085 }
3086
3087 #[expect(clippy::too_many_arguments)]
3094 fn write_dot_product(
3095 &mut self,
3096 result_id: Word,
3097 result_type_id: Word,
3098 arg0_id: Word,
3099 arg1_id: Word,
3100 size: u32,
3101 block: &mut Block,
3102 extractor: impl Fn(Word, Word, Word) -> Instruction,
3103 ) {
3104 let mut partial_sum = self.writer.get_constant_null(result_type_id);
3105 let last_component = size - 1;
3106 for index in 0..=last_component {
3107 let a_id = self.gen_id();
3109 block.body.push(extractor(a_id, arg0_id, index));
3110 let b_id = self.gen_id();
3111 block.body.push(extractor(b_id, arg1_id, index));
3112 let prod_id = self.gen_id();
3113 block.body.push(Instruction::binary(
3114 spirv::Op::IMul,
3115 result_type_id,
3116 prod_id,
3117 a_id,
3118 b_id,
3119 ));
3120
3121 let id = if index == last_component {
3123 result_id
3124 } else {
3125 self.gen_id()
3126 };
3127
3128 block.body.push(Instruction::binary(
3130 spirv::Op::IAdd,
3131 result_type_id,
3132 id,
3133 partial_sum,
3134 prod_id,
3135 ));
3136 partial_sum = id;
3138 }
3139 }
3140
3141 fn write_pack4x8_optimized(
3143 &mut self,
3144 block: &mut Block,
3145 result_type_id: u32,
3146 arg0_id: u32,
3147 id: u32,
3148 is_signed: bool,
3149 should_clamp: bool,
3150 ) -> Instruction {
3151 let int_type = if is_signed {
3152 crate::ScalarKind::Sint
3153 } else {
3154 crate::ScalarKind::Uint
3155 };
3156 let wide_vector_type = NumericType::Vector {
3157 size: crate::VectorSize::Quad,
3158 scalar: crate::Scalar {
3159 kind: int_type,
3160 width: 4,
3161 },
3162 };
3163 let wide_vector_type_id = self.get_numeric_type_id(wide_vector_type);
3164 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3165 size: crate::VectorSize::Quad,
3166 scalar: crate::Scalar {
3167 kind: crate::ScalarKind::Uint,
3168 width: 1,
3169 },
3170 });
3171
3172 let mut wide_vector = arg0_id;
3173 if should_clamp {
3174 let (min, max, clamp_op) = if is_signed {
3175 (
3176 crate::Literal::I32(-128),
3177 crate::Literal::I32(127),
3178 spirv::GLOp::SClamp,
3179 )
3180 } else {
3181 (
3182 crate::Literal::U32(0),
3183 crate::Literal::U32(255),
3184 spirv::GLOp::UClamp,
3185 )
3186 };
3187 let [min, max] = [min, max].map(|lit| {
3188 let scalar = self.writer.get_constant_scalar(lit);
3189 self.writer.get_constant_composite(
3190 LookupType::Local(LocalType::Numeric(wide_vector_type)),
3191 &[scalar; 4],
3192 )
3193 });
3194
3195 let clamp_id = self.gen_id();
3196 block.body.push(Instruction::ext_inst_gl_op(
3197 self.writer.gl450_ext_inst_id,
3198 clamp_op,
3199 wide_vector_type_id,
3200 clamp_id,
3201 &[wide_vector, min, max],
3202 ));
3203
3204 wide_vector = clamp_id;
3205 }
3206
3207 let packed_vector = self.gen_id();
3208 block.body.push(Instruction::unary(
3209 spirv::Op::UConvert, packed_vector_type_id,
3211 packed_vector,
3212 wide_vector,
3213 ));
3214
3215 Instruction::unary(spirv::Op::Bitcast, result_type_id, id, packed_vector)
3220 }
3221
3222 fn write_pack4x8_polyfill(
3224 &mut self,
3225 block: &mut Block,
3226 result_type_id: u32,
3227 arg0_id: u32,
3228 id: u32,
3229 is_signed: bool,
3230 should_clamp: bool,
3231 ) -> Instruction {
3232 let int_type = if is_signed {
3233 crate::ScalarKind::Sint
3234 } else {
3235 crate::ScalarKind::Uint
3236 };
3237 let uint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
3238 let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
3239 kind: int_type,
3240 width: 4,
3241 }));
3242
3243 let mut last_instruction = Instruction::new(spirv::Op::Nop);
3244
3245 let zero = self.writer.get_constant_scalar(crate::Literal::U32(0));
3246 let mut preresult = zero;
3247 block
3248 .body
3249 .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed)));
3250
3251 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
3252 const VEC_LENGTH: u8 = 4;
3253 for i in 0..u32::from(VEC_LENGTH) {
3254 let offset = self.writer.get_constant_scalar(crate::Literal::U32(i * 8));
3255 let mut extracted = self.gen_id();
3256 block.body.push(Instruction::binary(
3257 spirv::Op::CompositeExtract,
3258 int_type_id,
3259 extracted,
3260 arg0_id,
3261 i,
3262 ));
3263 if is_signed {
3264 let casted = self.gen_id();
3265 block.body.push(Instruction::unary(
3266 spirv::Op::Bitcast,
3267 uint_type_id,
3268 casted,
3269 extracted,
3270 ));
3271 extracted = casted;
3272 }
3273 if should_clamp {
3274 let (min, max, clamp_op) = if is_signed {
3275 (
3276 crate::Literal::I32(-128),
3277 crate::Literal::I32(127),
3278 spirv::GLOp::SClamp,
3279 )
3280 } else {
3281 (
3282 crate::Literal::U32(0),
3283 crate::Literal::U32(255),
3284 spirv::GLOp::UClamp,
3285 )
3286 };
3287 let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit));
3288
3289 let clamp_id = self.gen_id();
3290 block.body.push(Instruction::ext_inst_gl_op(
3291 self.writer.gl450_ext_inst_id,
3292 clamp_op,
3293 result_type_id,
3294 clamp_id,
3295 &[extracted, min, max],
3296 ));
3297
3298 extracted = clamp_id;
3299 }
3300 let is_last = i == u32::from(VEC_LENGTH - 1);
3301 if is_last {
3302 last_instruction = Instruction::quaternary(
3303 spirv::Op::BitFieldInsert,
3304 result_type_id,
3305 id,
3306 preresult,
3307 extracted,
3308 offset,
3309 eight,
3310 )
3311 } else {
3312 let new_preresult = self.gen_id();
3313 block.body.push(Instruction::quaternary(
3314 spirv::Op::BitFieldInsert,
3315 result_type_id,
3316 new_preresult,
3317 preresult,
3318 extracted,
3319 offset,
3320 eight,
3321 ));
3322 preresult = new_preresult;
3323 }
3324 }
3325 last_instruction
3326 }
3327
3328 fn write_unpack4x8_optimized(
3330 &mut self,
3331 block: &mut Block,
3332 result_type_id: u32,
3333 arg0_id: u32,
3334 id: u32,
3335 is_signed: bool,
3336 ) -> Instruction {
3337 let (int_type, convert_op) = if is_signed {
3338 (crate::ScalarKind::Sint, spirv::Op::SConvert)
3339 } else {
3340 (crate::ScalarKind::Uint, spirv::Op::UConvert)
3341 };
3342
3343 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3344 size: crate::VectorSize::Quad,
3345 scalar: crate::Scalar {
3346 kind: int_type,
3347 width: 1,
3348 },
3349 });
3350
3351 let packed_vector = self.gen_id();
3356 block.body.push(Instruction::unary(
3357 spirv::Op::Bitcast,
3358 packed_vector_type_id,
3359 packed_vector,
3360 arg0_id,
3361 ));
3362
3363 Instruction::unary(convert_op, result_type_id, id, packed_vector)
3364 }
3365
3366 fn write_unpack4x8_polyfill(
3368 &mut self,
3369 block: &mut Block,
3370 result_type_id: u32,
3371 arg0_id: u32,
3372 id: u32,
3373 is_signed: bool,
3374 ) -> Instruction {
3375 let (int_type, extract_op) = if is_signed {
3376 (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract)
3377 } else {
3378 (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract)
3379 };
3380
3381 let sint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));
3382
3383 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
3384 let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
3385 kind: int_type,
3386 width: 4,
3387 }));
3388 block
3389 .body
3390 .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed));
3391 let arg_id = if is_signed {
3392 let new_arg_id = self.gen_id();
3393 block.body.push(Instruction::unary(
3394 spirv::Op::Bitcast,
3395 sint_type_id,
3396 new_arg_id,
3397 arg0_id,
3398 ));
3399 new_arg_id
3400 } else {
3401 arg0_id
3402 };
3403
3404 const VEC_LENGTH: u8 = 4;
3405 let parts: [_; VEC_LENGTH as usize] = core::array::from_fn(|_| self.gen_id());
3406 for (i, part_id) in parts.into_iter().enumerate() {
3407 let index = self
3408 .writer
3409 .get_constant_scalar(crate::Literal::U32(i as u32 * 8));
3410 block.body.push(Instruction::ternary(
3411 extract_op,
3412 int_type_id,
3413 part_id,
3414 arg_id,
3415 index,
3416 eight,
3417 ));
3418 }
3419
3420 Instruction::composite_construct(result_type_id, id, &parts)
3421 }
3422
3423 fn write_block(
3440 &mut self,
3441 label_id: Word,
3442 naga_block: &crate::Block,
3443 exit: BlockExit,
3444 loop_context: LoopContext,
3445 debug_info: Option<&DebugInfoInner>,
3446 ) -> Result<BlockExitDisposition, Error> {
3447 let mut block = Block::new(label_id);
3448 for (statement, span) in naga_block.span_iter() {
3449 if let (Some(debug_info), false) = (
3450 debug_info,
3451 matches!(
3452 statement,
3453 &(Statement::Block(..)
3454 | Statement::Break
3455 | Statement::Continue
3456 | Statement::Kill
3457 | Statement::Return { .. }
3458 | Statement::Loop { .. })
3459 ),
3460 ) {
3461 let loc: crate::SourceLocation = span.location(debug_info.source_code);
3462 block.body.push(Instruction::line(
3463 debug_info.source_file_id,
3464 loc.line_number,
3465 loc.line_position,
3466 ));
3467 };
3468 match *statement {
3469 Statement::Emit(ref range) => {
3470 for handle in range.clone() {
3471 if !self.expression_constness.is_const(handle) {
3473 self.cache_expression_value(handle, &mut block)?;
3474 }
3475 }
3476 }
3477 Statement::Block(ref block_statements) => {
3478 let scope_id = self.gen_id();
3479 self.function.consume(block, Instruction::branch(scope_id));
3480
3481 let merge_id = self.gen_id();
3482 let merge_used = self.write_block(
3483 scope_id,
3484 block_statements,
3485 BlockExit::Branch { target: merge_id },
3486 loop_context,
3487 debug_info,
3488 )?;
3489
3490 match merge_used {
3491 BlockExitDisposition::Used => {
3492 block = Block::new(merge_id);
3493 }
3494 BlockExitDisposition::Discarded => {
3495 return Ok(BlockExitDisposition::Discarded);
3496 }
3497 }
3498 }
3499 Statement::If {
3500 condition,
3501 ref accept,
3502 ref reject,
3503 } => {
3504 if !(accept.is_empty() && reject.is_empty()) {
3510 let condition_id = self.cached[condition];
3511
3512 let merge_id = self.gen_id();
3513 block.body.push(Instruction::selection_merge(
3514 merge_id,
3515 spirv::SelectionControl::NONE,
3516 ));
3517
3518 let accept_id = if accept.is_empty() {
3519 None
3520 } else {
3521 Some(self.gen_id())
3522 };
3523 let reject_id = if reject.is_empty() {
3524 None
3525 } else {
3526 Some(self.gen_id())
3527 };
3528
3529 self.function.consume(
3530 block,
3531 Instruction::branch_conditional(
3532 condition_id,
3533 accept_id.unwrap_or(merge_id),
3534 reject_id.unwrap_or(merge_id),
3535 ),
3536 );
3537
3538 if let Some(block_id) = accept_id {
3539 let _ = self.write_block(
3544 block_id,
3545 accept,
3546 BlockExit::Branch { target: merge_id },
3547 loop_context,
3548 debug_info,
3549 )?;
3550 }
3551 if let Some(block_id) = reject_id {
3552 let _ = self.write_block(
3557 block_id,
3558 reject,
3559 BlockExit::Branch { target: merge_id },
3560 loop_context,
3561 debug_info,
3562 )?;
3563 }
3564
3565 block = Block::new(merge_id);
3566 }
3567 }
3568 Statement::Switch {
3569 selector,
3570 ref cases,
3571 } => {
3572 let selector_id = self.cached[selector];
3573
3574 let merge_id = self.gen_id();
3575 block.body.push(Instruction::selection_merge(
3576 merge_id,
3577 spirv::SelectionControl::NONE,
3578 ));
3579
3580 let mut default_id = None;
3581 let mut last_id = None;
3583
3584 let mut raw_cases = Vec::with_capacity(cases.len());
3585 let mut case_ids = Vec::with_capacity(cases.len());
3586 for case in cases.iter() {
3587 let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
3589
3590 if case.fall_through && case.body.is_empty() {
3591 last_id = Some(label_id);
3592 }
3593
3594 case_ids.push(label_id);
3595
3596 match case.value {
3597 crate::SwitchValue::I32(value) => {
3598 raw_cases.push(super::instructions::Case {
3599 value: value as Word,
3600 label_id,
3601 });
3602 }
3603 crate::SwitchValue::U32(value) => {
3604 raw_cases.push(super::instructions::Case { value, label_id });
3605 }
3606 crate::SwitchValue::Default => {
3607 default_id = Some(label_id);
3608 }
3609 }
3610 }
3611
3612 let default_id = default_id.unwrap();
3613
3614 self.function.consume(
3615 block,
3616 Instruction::switch(selector_id, default_id, &raw_cases),
3617 );
3618
3619 let inner_context = LoopContext {
3620 break_id: Some(merge_id),
3621 ..loop_context
3622 };
3623
3624 for (i, (case, label_id)) in cases
3625 .iter()
3626 .zip(case_ids.iter())
3627 .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
3628 .enumerate()
3629 {
3630 let case_finish_id = if case.fall_through {
3631 case_ids[i + 1]
3632 } else {
3633 merge_id
3634 };
3635 let _ = self.write_block(
3644 *label_id,
3645 &case.body,
3646 BlockExit::Branch {
3647 target: case_finish_id,
3648 },
3649 inner_context,
3650 debug_info,
3651 )?;
3652 }
3653
3654 block = Block::new(merge_id);
3655 }
3656 Statement::Loop {
3657 ref body,
3658 ref continuing,
3659 break_if,
3660 } => {
3661 let preamble_id = self.gen_id();
3662 self.function
3663 .consume(block, Instruction::branch(preamble_id));
3664
3665 let merge_id = self.gen_id();
3666 let body_id = self.gen_id();
3667 let continuing_id = self.gen_id();
3668
3669 block = Block::new(preamble_id);
3672 if let Some(debug_info) = debug_info {
3675 let loc: crate::SourceLocation = span.location(debug_info.source_code);
3676 block.body.push(Instruction::line(
3677 debug_info.source_file_id,
3678 loc.line_number,
3679 loc.line_position,
3680 ))
3681 }
3682 block.body.push(Instruction::loop_merge(
3683 merge_id,
3684 continuing_id,
3685 spirv::SelectionControl::NONE,
3686 ));
3687
3688 if self.force_loop_bounding {
3689 block = self.write_force_bounded_loop_instructions(block, merge_id);
3690 }
3691 self.function.consume(block, Instruction::branch(body_id));
3692
3693 let _ = self.write_block(
3697 body_id,
3698 body,
3699 BlockExit::Branch {
3700 target: continuing_id,
3701 },
3702 LoopContext {
3703 continuing_id: Some(continuing_id),
3704 break_id: Some(merge_id),
3705 },
3706 debug_info,
3707 )?;
3708
3709 let exit = match break_if {
3710 Some(condition) => BlockExit::BreakIf {
3711 condition,
3712 preamble_id,
3713 },
3714 None => BlockExit::Branch {
3715 target: preamble_id,
3716 },
3717 };
3718
3719 let _ = self.write_block(
3723 continuing_id,
3724 continuing,
3725 exit,
3726 LoopContext {
3727 continuing_id: None,
3728 break_id: Some(merge_id),
3729 },
3730 debug_info,
3731 )?;
3732
3733 block = Block::new(merge_id);
3734 }
3735 Statement::Break => {
3736 self.function
3737 .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
3738 return Ok(BlockExitDisposition::Discarded);
3739 }
3740 Statement::Continue => {
3741 self.function.consume(
3742 block,
3743 Instruction::branch(loop_context.continuing_id.unwrap()),
3744 );
3745 return Ok(BlockExitDisposition::Discarded);
3746 }
3747 Statement::Return { value: Some(value) } => {
3748 let value_id = self.cached[value];
3749 let instruction = match self.function.entry_point_context {
3750 Some(ref context) => self.writer.write_entry_point_return(
3753 value_id,
3754 self.ir_function.result.as_ref().unwrap(),
3755 &context.results,
3756 &mut block.body,
3757 context.task_payload_variable_id,
3758 )?,
3759 None => Instruction::return_value(value_id),
3760 };
3761 self.function.consume(block, instruction);
3762 return Ok(BlockExitDisposition::Discarded);
3763 }
3764 Statement::Return { value: None } => {
3765 if let Some(super::EntryPointContext {
3766 mesh_state: Some(ref mesh_state),
3767 ..
3768 }) = self.function.entry_point_context
3769 {
3770 self.function.consume(
3771 block,
3772 Instruction::branch(mesh_state.entry_point_epilogue_id),
3773 );
3774 } else {
3775 self.function.consume(block, Instruction::return_void());
3776 }
3777 return Ok(BlockExitDisposition::Discarded);
3778 }
3779 Statement::Kill => {
3780 self.function.consume(block, Instruction::kill());
3781 return Ok(BlockExitDisposition::Discarded);
3782 }
3783 Statement::ControlBarrier(flags) => {
3784 self.writer.write_control_barrier(flags, &mut block.body);
3785 }
3786 Statement::MemoryBarrier(flags) => {
3787 self.writer.write_memory_barrier(flags, &mut block);
3788 }
3789 Statement::Store { pointer, value } => {
3790 let value_id = self.cached[value];
3791 match self.write_access_chain(
3792 pointer,
3793 &mut block,
3794 AccessTypeAdjustment::None,
3795 )? {
3796 ExpressionPointer::Ready { pointer_id } => {
3797 let atomic_space = match *self.fun_info[pointer]
3798 .ty
3799 .inner_with(&self.ir_module.types)
3800 {
3801 crate::TypeInner::Pointer { base, space } => {
3802 match self.ir_module.types[base].inner {
3803 crate::TypeInner::Atomic { .. } => Some(space),
3804 _ => None,
3805 }
3806 }
3807 _ => None,
3808 };
3809 let instruction = if let Some(space) = atomic_space {
3810 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3811 let scope_constant_id = self.get_scope_constant(scope as u32);
3812 let semantics_id = self.get_index_constant(semantics.bits());
3813 Instruction::atomic_store(
3814 pointer_id,
3815 scope_constant_id,
3816 semantics_id,
3817 value_id,
3818 )
3819 } else {
3820 Instruction::store(pointer_id, value_id, None)
3821 };
3822 block.body.push(instruction);
3823 }
3824 ExpressionPointer::Conditional { condition, access } => {
3825 let mut selection = Selection::start(&mut block, ());
3826 selection.if_true(self, condition, ());
3827
3828 let pointer_id = access.result_id.unwrap();
3830 selection.block().body.push(access);
3831 selection
3832 .block()
3833 .body
3834 .push(Instruction::store(pointer_id, value_id, None));
3835
3836 selection.finish(self, ());
3839 }
3840 };
3841 }
3842 Statement::ImageStore {
3843 image,
3844 coordinate,
3845 array_index,
3846 value,
3847 } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
3848 Statement::Call {
3849 function: local_function,
3850 ref arguments,
3851 result,
3852 } => {
3853 let id = self.gen_id();
3854 self.temp_list.clear();
3855 for &argument in arguments {
3856 self.temp_list.push(self.cached[argument]);
3857 }
3858
3859 let type_id = match result {
3860 Some(expr) => {
3861 self.cached[expr] = id;
3862 self.get_expression_type_id(&self.fun_info[expr].ty)
3863 }
3864 None => self.writer.void_type,
3865 };
3866
3867 block.body.push(Instruction::function_call(
3868 type_id,
3869 id,
3870 self.writer.lookup_function[&local_function],
3871 &self.temp_list,
3872 ));
3873 }
3874 Statement::Atomic {
3875 pointer,
3876 ref fun,
3877 value,
3878 result,
3879 } => {
3880 let id = self.gen_id();
3881 let result_type_id =
3885 self.get_expression_type_id(&self.fun_info[result.unwrap_or(value)].ty);
3886
3887 if let Some(result) = result {
3888 self.cached[result] = id;
3889 }
3890
3891 let pointer_id = match self.write_access_chain(
3892 pointer,
3893 &mut block,
3894 AccessTypeAdjustment::None,
3895 )? {
3896 ExpressionPointer::Ready { pointer_id } => pointer_id,
3897 ExpressionPointer::Conditional { .. } => {
3898 return Err(Error::FeatureNotImplemented(
3899 "Atomics out-of-bounds handling",
3900 ));
3901 }
3902 };
3903
3904 let space = self.fun_info[pointer]
3905 .ty
3906 .inner_with(&self.ir_module.types)
3907 .pointer_space()
3908 .unwrap();
3909 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3910 let scope_constant_id = self.get_scope_constant(scope as u32);
3911 let semantics_id = self.get_index_constant(semantics.bits());
3912 let value_id = self.cached[value];
3913 let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
3914
3915 let crate::TypeInner::Scalar(scalar) = *value_inner else {
3916 return Err(Error::FeatureNotImplemented(
3917 "Atomics with non-scalar values",
3918 ));
3919 };
3920
3921 let instruction = match *fun {
3922 crate::AtomicFunction::Add => {
3923 let spirv_op = match scalar.kind {
3924 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3925 spirv::Op::AtomicIAdd
3926 }
3927 crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT,
3928 _ => unimplemented!(),
3929 };
3930 Instruction::atomic_binary(
3931 spirv_op,
3932 result_type_id,
3933 id,
3934 pointer_id,
3935 scope_constant_id,
3936 semantics_id,
3937 value_id,
3938 )
3939 }
3940 crate::AtomicFunction::Subtract => {
3941 let (spirv_op, value_id) = match scalar.kind {
3942 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3943 (spirv::Op::AtomicISub, value_id)
3944 }
3945 crate::ScalarKind::Float => {
3946 let neg_result_id = self.gen_id();
3949 block.body.push(Instruction::unary(
3950 spirv::Op::FNegate,
3951 result_type_id,
3952 neg_result_id,
3953 value_id,
3954 ));
3955 (spirv::Op::AtomicFAddEXT, neg_result_id)
3956 }
3957 _ => unimplemented!(),
3958 };
3959 Instruction::atomic_binary(
3960 spirv_op,
3961 result_type_id,
3962 id,
3963 pointer_id,
3964 scope_constant_id,
3965 semantics_id,
3966 value_id,
3967 )
3968 }
3969 crate::AtomicFunction::And => {
3970 let spirv_op = match scalar.kind {
3971 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3972 spirv::Op::AtomicAnd
3973 }
3974 _ => unimplemented!(),
3975 };
3976 Instruction::atomic_binary(
3977 spirv_op,
3978 result_type_id,
3979 id,
3980 pointer_id,
3981 scope_constant_id,
3982 semantics_id,
3983 value_id,
3984 )
3985 }
3986 crate::AtomicFunction::InclusiveOr => {
3987 let spirv_op = match scalar.kind {
3988 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3989 spirv::Op::AtomicOr
3990 }
3991 _ => unimplemented!(),
3992 };
3993 Instruction::atomic_binary(
3994 spirv_op,
3995 result_type_id,
3996 id,
3997 pointer_id,
3998 scope_constant_id,
3999 semantics_id,
4000 value_id,
4001 )
4002 }
4003 crate::AtomicFunction::ExclusiveOr => {
4004 let spirv_op = match scalar.kind {
4005 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
4006 spirv::Op::AtomicXor
4007 }
4008 _ => unimplemented!(),
4009 };
4010 Instruction::atomic_binary(
4011 spirv_op,
4012 result_type_id,
4013 id,
4014 pointer_id,
4015 scope_constant_id,
4016 semantics_id,
4017 value_id,
4018 )
4019 }
4020 crate::AtomicFunction::Min => {
4021 let spirv_op = match scalar.kind {
4022 crate::ScalarKind::Sint => spirv::Op::AtomicSMin,
4023 crate::ScalarKind::Uint => spirv::Op::AtomicUMin,
4024 _ => unimplemented!(),
4025 };
4026 Instruction::atomic_binary(
4027 spirv_op,
4028 result_type_id,
4029 id,
4030 pointer_id,
4031 scope_constant_id,
4032 semantics_id,
4033 value_id,
4034 )
4035 }
4036 crate::AtomicFunction::Max => {
4037 let spirv_op = match scalar.kind {
4038 crate::ScalarKind::Sint => spirv::Op::AtomicSMax,
4039 crate::ScalarKind::Uint => spirv::Op::AtomicUMax,
4040 _ => unimplemented!(),
4041 };
4042 Instruction::atomic_binary(
4043 spirv_op,
4044 result_type_id,
4045 id,
4046 pointer_id,
4047 scope_constant_id,
4048 semantics_id,
4049 value_id,
4050 )
4051 }
4052 crate::AtomicFunction::Exchange { compare: None } => {
4053 Instruction::atomic_binary(
4054 spirv::Op::AtomicExchange,
4055 result_type_id,
4056 id,
4057 pointer_id,
4058 scope_constant_id,
4059 semantics_id,
4060 value_id,
4061 )
4062 }
4063 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4064 let scalar_type_id =
4065 self.get_numeric_type_id(NumericType::Scalar(scalar));
4066 let bool_type_id =
4067 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL));
4068
4069 let cas_result_id = self.gen_id();
4070 let equality_result_id = self.gen_id();
4071 let equality_operator = match scalar.kind {
4072 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
4073 spirv::Op::IEqual
4074 }
4075 _ => unimplemented!(),
4076 };
4077
4078 let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
4079 cas_instr.set_type(scalar_type_id);
4080 cas_instr.set_result(cas_result_id);
4081 cas_instr.add_operand(pointer_id);
4082 cas_instr.add_operand(scope_constant_id);
4083 cas_instr.add_operand(semantics_id); cas_instr.add_operand(semantics_id); cas_instr.add_operand(value_id);
4086 cas_instr.add_operand(self.cached[cmp]);
4087 block.body.push(cas_instr);
4088 block.body.push(Instruction::binary(
4089 equality_operator,
4090 bool_type_id,
4091 equality_result_id,
4092 cas_result_id,
4093 self.cached[cmp],
4094 ));
4095 Instruction::composite_construct(
4096 result_type_id,
4097 id,
4098 &[cas_result_id, equality_result_id],
4099 )
4100 }
4101 };
4102
4103 block.body.push(instruction);
4104 }
4105 Statement::ImageAtomic {
4106 image,
4107 coordinate,
4108 array_index,
4109 fun,
4110 value,
4111 } => {
4112 self.write_image_atomic(
4113 image,
4114 coordinate,
4115 array_index,
4116 fun,
4117 value,
4118 &mut block,
4119 )?;
4120 }
4121 Statement::WorkGroupUniformLoad { pointer, result } => {
4122 self.writer
4123 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4124 let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
4125 let id = self.write_checked_load(
4128 pointer,
4129 &mut block,
4130 AccessTypeAdjustment::None,
4131 result_type_id,
4132 )?;
4133 self.cached[result] = id;
4134 self.writer
4135 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4136 }
4137 Statement::RayQuery { query, ref fun } => {
4138 self.write_ray_query_function(query, fun, &mut block);
4139 }
4140 Statement::SubgroupBallot {
4141 result,
4142 ref predicate,
4143 } => {
4144 self.write_subgroup_ballot(predicate, result, &mut block)?;
4145 }
4146 Statement::SubgroupCollectiveOperation {
4147 ref op,
4148 ref collective_op,
4149 argument,
4150 result,
4151 } => {
4152 self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
4153 }
4154 Statement::SubgroupGather {
4155 ref mode,
4156 argument,
4157 result,
4158 } => {
4159 self.write_subgroup_gather(mode, argument, result, &mut block)?;
4160 }
4161 Statement::CooperativeStore { target, ref data } => {
4162 let target_id = self.cached[target];
4163 let layout = if data.row_major {
4164 spirv::CooperativeMatrixLayout::RowMajorKHR
4165 } else {
4166 spirv::CooperativeMatrixLayout::ColumnMajorKHR
4167 };
4168 let layout_id = self.get_index_constant(layout as u32);
4169 let stride_id = self.cached[data.stride];
4170 match self.write_access_chain(
4171 data.pointer,
4172 &mut block,
4173 AccessTypeAdjustment::None,
4174 )? {
4175 ExpressionPointer::Ready { pointer_id } => {
4176 block.body.push(Instruction::coop_store(
4177 target_id, pointer_id, layout_id, stride_id,
4178 ));
4179 }
4180 ExpressionPointer::Conditional { condition, access } => {
4181 let mut selection = Selection::start(&mut block, ());
4182 selection.if_true(self, condition, ());
4183
4184 let pointer_id = access.result_id.unwrap();
4186 selection.block().body.push(access);
4187 selection.block().body.push(Instruction::coop_store(
4188 target_id, pointer_id, layout_id, stride_id,
4189 ));
4190
4191 selection.finish(self, ());
4194 }
4195 };
4196 }
4197 }
4198 }
4199
4200 let termination = match exit {
4201 BlockExit::Return => match self.ir_function.result {
4204 Some(ref result) if self.function.entry_point_context.is_none() => {
4205 let type_id = self.get_handle_type_id(result.ty);
4206 let null_id = self.writer.get_constant_null(type_id);
4207 Instruction::return_value(null_id)
4208 }
4209 _ => Instruction::return_void(),
4210 },
4211 BlockExit::Branch { target } => Instruction::branch(target),
4212 BlockExit::BreakIf {
4213 condition,
4214 preamble_id,
4215 } => {
4216 let condition_id = self.cached[condition];
4217
4218 Instruction::branch_conditional(
4219 condition_id,
4220 loop_context.break_id.unwrap(),
4221 preamble_id,
4222 )
4223 }
4224 };
4225
4226 self.function.consume(block, termination);
4227 Ok(BlockExitDisposition::Used)
4228 }
4229
4230 pub(super) fn write_function_body(
4231 &mut self,
4232 entry_id: Word,
4233 debug_info: Option<&DebugInfoInner>,
4234 ) -> Result<(), Error> {
4235 let _ = self.write_block(
4238 entry_id,
4239 &self.ir_function.body,
4240 BlockExit::Return,
4241 LoopContext::default(),
4242 debug_info,
4243 )?;
4244 if let Some(super::EntryPointContext {
4245 mesh_state: Some(ref mesh_state),
4246 ..
4247 }) = self.function.entry_point_context
4248 {
4249 let mut block = Block::new(mesh_state.entry_point_epilogue_id);
4250 self.writer
4251 .write_mesh_shader_return(mesh_state, &mut block)?;
4252 self.function.consume(block, Instruction::return_void());
4253 }
4254
4255 Ok(())
4256 }
4257}