1use alloc::vec::Vec;
6
7use arrayvec::ArrayVec;
8use spirv::Word;
9
10use super::{
11 helpers::map_storage_class, index::BoundsCheckResult, selection::Selection, Block,
12 BlockContext, Dimension, Error, IdGenerator, Instruction, LocalType, LookupType, NumericType,
13 ResultMember, WrappedFunction, Writer, WriterFlags,
14};
15use crate::{
16 arena::Handle, back::spv::helpers::is_uniform_matcx2_struct_member_access,
17 proc::index::GuardedIndex, Statement,
18};
19
20fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
21 match *type_inner {
22 crate::TypeInner::Scalar(_) => Dimension::Scalar,
23 crate::TypeInner::Vector { .. } => Dimension::Vector,
24 crate::TypeInner::Matrix { .. } => Dimension::Matrix,
25 crate::TypeInner::CooperativeMatrix { .. } => Dimension::CooperativeMatrix,
26 _ => unreachable!(),
27 }
28}
29
30#[derive(Copy, Clone)]
42enum AccessTypeAdjustment {
43 None,
52
53 IntroducePointer(spirv::StorageClass),
77
78 UseStd140CompatType,
88}
89
90enum ExpressionPointer {
94 Ready { pointer_id: Word },
97
98 Conditional {
104 condition: Word,
105 access: Instruction,
106 },
107}
108
109enum BlockExit {
111 Return,
113 Branch {
115 target: Word,
117 },
118 BreakIf {
124 condition: Handle<crate::Expression>,
126 preamble_id: Word,
128 },
129}
130
/// Whether a requested `BlockExit` ended up being emitted.
///
/// Marked `#[must_use]` so callers cannot silently drop the answer.
/// NOTE(review): the semantics are inferred from the variant names; confirm
/// them with the function that returns this type.
#[must_use]
enum BlockExitDisposition {
    /// The requested exit was written.
    Used,

    /// The requested exit was not written. Presumably this happens because the
    /// block was already terminated by one of its own statements.
    Discarded,
}
154
155#[derive(Clone, Copy, Default)]
156struct LoopContext {
157 continuing_id: Option<Word>,
158 break_id: Option<Word>,
159}
160
161#[derive(Debug)]
162pub(crate) struct DebugInfoInner<'a> {
163 pub source_code: &'a str,
164 pub source_file_id: Word,
165}
166
167impl Writer {
168 fn write_epilogue_position_y_flip(
173 &mut self,
174 position_id: Word,
175 body: &mut Vec<Instruction>,
176 ) -> Result<(), Error> {
177 let float_ptr_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Output);
178 let index_y_id = self.get_index_constant(1);
179 let access_id = self.id_gen.next();
180 body.push(Instruction::access_chain(
181 float_ptr_type_id,
182 access_id,
183 position_id,
184 &[index_y_id],
185 ));
186
187 let float_type_id = self.get_f32_type_id();
188 let load_id = self.id_gen.next();
189 body.push(Instruction::load(float_type_id, load_id, access_id, None));
190
191 let neg_id = self.id_gen.next();
192 body.push(Instruction::unary(
193 spirv::Op::FNegate,
194 float_type_id,
195 neg_id,
196 load_id,
197 ));
198
199 body.push(Instruction::store(access_id, neg_id, None));
200 Ok(())
201 }
202
203 fn write_epilogue_frag_depth_clamp(
205 &mut self,
206 frag_depth_id: Word,
207 body: &mut Vec<Instruction>,
208 ) -> Result<(), Error> {
209 let float_type_id = self.get_f32_type_id();
210 let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
211 let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));
212
213 let original_id = self.id_gen.next();
214 body.push(Instruction::load(
215 float_type_id,
216 original_id,
217 frag_depth_id,
218 None,
219 ));
220
221 let clamp_id = self.id_gen.next();
222 body.push(Instruction::ext_inst_gl_op(
223 self.gl450_ext_inst_id,
224 spirv::GLOp::FClamp,
225 float_type_id,
226 clamp_id,
227 &[original_id, zero_scalar_id, one_scalar_id],
228 ));
229
230 body.push(Instruction::store(frag_depth_id, clamp_id, None));
231 Ok(())
232 }
233
234 fn write_entry_point_return(
235 &mut self,
236 value_id: Word,
237 ir_result: &crate::FunctionResult,
238 result_members: &[ResultMember],
239 body: &mut Vec<Instruction>,
240 ) -> Result<Instruction, Error> {
241 for (index, res_member) in result_members.iter().enumerate() {
242 if res_member.built_in == Some(crate::BuiltIn::MeshTaskSize) {
244 return Ok(Instruction::return_value(value_id));
245 }
246 let member_value_id = match ir_result.binding {
247 Some(_) => value_id,
248 None => {
249 let member_value_id = self.id_gen.next();
250 body.push(Instruction::composite_extract(
251 res_member.type_id,
252 member_value_id,
253 value_id,
254 &[index as u32],
255 ));
256 member_value_id
257 }
258 };
259
260 self.store_io_with_f16_polyfill(body, res_member.id, member_value_id);
261
262 match res_member.built_in {
263 Some(crate::BuiltIn::Position { .. })
264 if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
265 {
266 self.write_epilogue_position_y_flip(res_member.id, body)?;
267 }
268 Some(crate::BuiltIn::FragDepth)
269 if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
270 {
271 self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
272 }
273 _ => {}
274 }
275 }
276 Ok(Instruction::return_void())
277 }
278}
279
280impl BlockContext<'_> {
281 fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
294 let uint_type_id = self.writer.get_u32_type_id();
295 let uint2_type_id = self.writer.get_vec2u_type_id();
296 let uint2_ptr_type_id = self
297 .writer
298 .get_vec2u_pointer_type_id(spirv::StorageClass::Function);
299 let bool_type_id = self.writer.get_bool_type_id();
300 let bool2_type_id = self.writer.get_vec2_bool_type_id();
301 let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
302 let zero_uint2_const_id = self.writer.get_constant_composite(
303 LookupType::Local(LocalType::Numeric(NumericType::Vector {
304 size: crate::VectorSize::Bi,
305 scalar: crate::Scalar::U32,
306 })),
307 &[zero_uint_const_id, zero_uint_const_id],
308 );
309 let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
310 let max_uint_const_id = self
311 .writer
312 .get_constant_scalar(crate::Literal::U32(u32::MAX));
313 let max_uint2_const_id = self.writer.get_constant_composite(
314 LookupType::Local(LocalType::Numeric(NumericType::Vector {
315 size: crate::VectorSize::Bi,
316 scalar: crate::Scalar::U32,
317 })),
318 &[max_uint_const_id, max_uint_const_id],
319 );
320
321 let loop_counter_var_id = self.gen_id();
322 if self.writer.flags.contains(WriterFlags::DEBUG) {
323 self.writer
324 .debugs
325 .push(Instruction::name(loop_counter_var_id, "loop_bound"));
326 }
327 let var = super::LocalVariable {
328 id: loop_counter_var_id,
329 instruction: Instruction::variable(
330 uint2_ptr_type_id,
331 loop_counter_var_id,
332 spirv::StorageClass::Function,
333 Some(max_uint2_const_id),
334 ),
335 };
336 self.function.force_loop_bounding_vars.push(var);
337
338 let break_if_block = self.gen_id();
339
340 self.function
341 .consume(block, Instruction::branch(break_if_block));
342 block = Block::new(break_if_block);
343
344 let load_id = self.gen_id();
347 block.body.push(Instruction::load(
348 uint2_type_id,
349 load_id,
350 loop_counter_var_id,
351 None,
352 ));
353
354 let eq_id = self.gen_id();
357 block.body.push(Instruction::binary(
358 spirv::Op::IEqual,
359 bool2_type_id,
360 eq_id,
361 zero_uint2_const_id,
362 load_id,
363 ));
364 let all_eq_id = self.gen_id();
365 block.body.push(Instruction::relational(
366 spirv::Op::All,
367 bool_type_id,
368 all_eq_id,
369 eq_id,
370 ));
371
372 let inc_counter_block_id = self.gen_id();
373 block.body.push(Instruction::selection_merge(
374 inc_counter_block_id,
375 spirv::SelectionControl::empty(),
376 ));
377 self.function.consume(
378 block,
379 Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
380 );
381 block = Block::new(inc_counter_block_id);
382
383 let low_id = self.gen_id();
389 block.body.push(Instruction::composite_extract(
390 uint_type_id,
391 low_id,
392 load_id,
393 &[1],
394 ));
395 let low_overflow_id = self.gen_id();
396 block.body.push(Instruction::binary(
397 spirv::Op::IEqual,
398 bool_type_id,
399 low_overflow_id,
400 low_id,
401 zero_uint_const_id,
402 ));
403 let carry_bit_id = self.gen_id();
404 block.body.push(Instruction::select(
405 uint_type_id,
406 carry_bit_id,
407 low_overflow_id,
408 one_uint_const_id,
409 zero_uint_const_id,
410 ));
411 let decrement_id = self.gen_id();
412 block.body.push(Instruction::composite_construct(
413 uint2_type_id,
414 decrement_id,
415 &[carry_bit_id, one_uint_const_id],
416 ));
417 let result_id = self.gen_id();
418 block.body.push(Instruction::binary(
419 spirv::Op::ISub,
420 uint2_type_id,
421 result_id,
422 load_id,
423 decrement_id,
424 ));
425 block
426 .body
427 .push(Instruction::store(loop_counter_var_id, result_id, None));
428
429 block
430 }
431
432 fn maybe_write_uniform_matcx2_dynamic_access(
449 &mut self,
450 pointer: Handle<crate::Expression>,
451 block: &mut Block,
452 ) -> Result<Option<Word>, Error> {
453 let (column_pointer, component_index) = match self.fun_info[pointer]
459 .ty
460 .inner_with(&self.ir_module.types)
461 .pointer_base_type()
462 {
463 Some(resolution) => match *resolution.inner_with(&self.ir_module.types) {
464 crate::TypeInner::Scalar(_) => match self.ir_function.expressions[pointer] {
465 crate::Expression::Access { base, index } => {
466 (base, Some(GuardedIndex::Expression(index)))
467 }
468 crate::Expression::AccessIndex { base, index } => {
469 (base, Some(GuardedIndex::Known(index)))
470 }
471 _ => return Ok(None),
472 },
473 crate::TypeInner::Vector { .. } => (pointer, None),
474 _ => return Ok(None),
475 },
476 None => return Ok(None),
477 };
478
479 let crate::Expression::Access {
482 base: matrix_pointer,
483 index: column_index,
484 } = self.ir_function.expressions[column_pointer]
485 else {
486 return Ok(None);
487 };
488
489 let crate::TypeInner::Pointer {
491 base: matrix_pointer_base_type,
492 space: crate::AddressSpace::Uniform,
493 } = *self.fun_info[matrix_pointer]
494 .ty
495 .inner_with(&self.ir_module.types)
496 else {
497 return Ok(None);
498 };
499
500 let crate::TypeInner::Matrix {
502 columns,
503 rows: rows @ crate::VectorSize::Bi,
504 scalar,
505 } = self.ir_module.types[matrix_pointer_base_type].inner
506 else {
507 return Ok(None);
508 };
509
510 let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
511 columns,
512 rows,
513 scalar,
514 });
515 let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
516 let component_type_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
517 let get_column_function_id = self.writer.wrapped_functions
518 [&WrappedFunction::MatCx2GetColumn {
519 r#type: matrix_pointer_base_type,
520 }];
521
522 let matrix_load_id = self.write_checked_load(
523 matrix_pointer,
524 block,
525 AccessTypeAdjustment::None,
526 matrix_type_id,
527 )?;
528
529 let column_index_id = match *self.fun_info[column_index]
532 .ty
533 .inner_with(&self.ir_module.types)
534 {
535 crate::TypeInner::Scalar(crate::Scalar {
536 kind: crate::ScalarKind::Uint,
537 ..
538 }) => self.cached[column_index],
539 crate::TypeInner::Scalar(crate::Scalar {
540 kind: crate::ScalarKind::Sint,
541 ..
542 }) => {
543 let cast_id = self.gen_id();
544 let u32_type_id = self.writer.get_u32_type_id();
545 block.body.push(Instruction::unary(
546 spirv::Op::Bitcast,
547 u32_type_id,
548 cast_id,
549 self.cached[column_index],
550 ));
551 cast_id
552 }
553 _ => return Err(Error::Validation("Matrix access index must be u32 or i32")),
554 };
555 let column_id = self.gen_id();
556 block.body.push(Instruction::function_call(
557 column_type_id,
558 column_id,
559 get_column_function_id,
560 &[matrix_load_id, column_index_id],
561 ));
562 let result_id = match component_index {
563 Some(index) => self.write_vector_access(
564 component_type_id,
565 column_pointer,
566 Some(column_id),
567 index,
568 block,
569 )?,
570 None => column_id,
571 };
572
573 Ok(Some(result_id))
574 }
575
576 fn maybe_write_load_uniform_matcx2_struct_member(
588 &mut self,
589 pointer: Handle<crate::Expression>,
590 block: &mut Block,
591 ) -> Result<Option<Word>, Error> {
592 let crate::TypeInner::Pointer {
594 base: matrix_type,
595 space: space @ crate::AddressSpace::Uniform,
596 } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
597 else {
598 return Ok(None);
599 };
600
601 let crate::TypeInner::Matrix {
602 columns,
603 rows: rows @ crate::VectorSize::Bi,
604 scalar,
605 } = self.ir_module.types[matrix_type].inner
606 else {
607 return Ok(None);
608 };
609
610 let crate::Expression::AccessIndex {
613 base: struct_pointer,
614 index: member_index,
615 } = self.ir_function.expressions[pointer]
616 else {
617 return Ok(None);
618 };
619
620 let crate::TypeInner::Pointer {
621 base: struct_type, ..
622 } = *self.fun_info[struct_pointer]
623 .ty
624 .inner_with(&self.ir_module.types)
625 else {
626 return Ok(None);
627 };
628
629 let crate::TypeInner::Struct { .. } = self.ir_module.types[struct_type].inner else {
630 return Ok(None);
631 };
632
633 let matrix_type_id = self.get_numeric_type_id(NumericType::Matrix {
634 columns,
635 rows,
636 scalar,
637 });
638 let column_type_id = self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
639 let column_pointer_type_id =
640 self.get_pointer_type_id(column_type_id, map_storage_class(space));
641 let column0_index = self.writer.std140_compat_uniform_types[&struct_type].member_indices
642 [member_index as usize];
643 let column_indices = (0..columns as u32)
644 .map(|c| self.get_index_constant(column0_index + c))
645 .collect::<ArrayVec<_, 4>>();
646
647 let load_mat_from_struct =
650 |struct_pointer_id: Word, id_gen: &mut IdGenerator, block: &mut Block| -> Word {
651 let mut column_ids: ArrayVec<Word, 4> = ArrayVec::new();
652 for index in &column_indices {
653 let column_pointer_id = id_gen.next();
654 block.body.push(Instruction::access_chain(
655 column_pointer_type_id,
656 column_pointer_id,
657 struct_pointer_id,
658 &[*index],
659 ));
660 let column_id = id_gen.next();
661 block.body.push(Instruction::load(
662 column_type_id,
663 column_id,
664 column_pointer_id,
665 None,
666 ));
667 column_ids.push(column_id);
668 }
669 let result_id = id_gen.next();
670 block.body.push(Instruction::composite_construct(
671 matrix_type_id,
672 result_id,
673 &column_ids,
674 ));
675 result_id
676 };
677
678 let result_id = match self.write_access_chain(
679 struct_pointer,
680 block,
681 AccessTypeAdjustment::UseStd140CompatType,
682 )? {
683 ExpressionPointer::Ready { pointer_id } => {
684 load_mat_from_struct(pointer_id, &mut self.writer.id_gen, block)
685 }
686 ExpressionPointer::Conditional { condition, access } => self
687 .write_conditional_indexed_load(
688 matrix_type_id,
689 condition,
690 block,
691 |id_gen, block| {
692 let pointer_id = access.result_id.unwrap();
693 block.body.push(access);
694 load_mat_from_struct(pointer_id, id_gen, block)
695 },
696 ),
697 };
698
699 Ok(Some(result_id))
700 }
701
702 pub(super) fn cache_expression_value(
704 &mut self,
705 expr_handle: Handle<crate::Expression>,
706 block: &mut Block,
707 ) -> Result<(), Error> {
708 let is_named_expression = self
709 .ir_function
710 .named_expressions
711 .contains_key(&expr_handle);
712
713 if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression {
714 return Ok(());
715 }
716
717 let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
718 let id = match self.ir_function.expressions[expr_handle] {
719 crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
720 crate::Expression::Constant(handle) => {
721 let init = self.ir_module.constants[handle].init;
722 self.writer.constant_ids[init]
723 }
724 crate::Expression::Override(_) => return Err(Error::Override),
725 crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
726 crate::Expression::Compose { ty, ref components } => {
727 self.temp_list.clear();
728 if self.expression_constness.is_const(expr_handle) {
729 self.temp_list.extend(
730 crate::proc::flatten_compose(
731 ty,
732 components,
733 &self.ir_function.expressions,
734 &self.ir_module.types,
735 )
736 .map(|component| self.cached[component]),
737 );
738 self.writer
739 .get_constant_composite(LookupType::Handle(ty), &self.temp_list)
740 } else {
741 self.temp_list
742 .extend(components.iter().map(|&component| self.cached[component]));
743
744 let id = self.gen_id();
745 block.body.push(Instruction::composite_construct(
746 result_type_id,
747 id,
748 &self.temp_list,
749 ));
750 id
751 }
752 }
753 crate::Expression::Splat { size, value } => {
754 let value_id = self.cached[value];
755 let components = &[value_id; 4][..size as usize];
756
757 if self.expression_constness.is_const(expr_handle) {
758 let ty = self
759 .writer
760 .get_expression_lookup_type(&self.fun_info[expr_handle].ty);
761 self.writer.get_constant_composite(ty, components)
762 } else {
763 let id = self.gen_id();
764 block.body.push(Instruction::composite_construct(
765 result_type_id,
766 id,
767 components,
768 ));
769 id
770 }
771 }
772 crate::Expression::Access { base, index } => {
773 let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
774 match *base_ty_inner {
775 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
776 0
782 }
783 _ if self.function.spilled_accesses.contains(base) => {
784 self.function.spilled_accesses.insert(expr_handle);
792 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
793 }
794 crate::TypeInner::Vector { .. } => self.write_vector_access(
795 result_type_id,
796 base,
797 None,
798 GuardedIndex::Expression(index),
799 block,
800 )?,
801 crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => {
802 match GuardedIndex::from_expression(
804 index,
805 &self.ir_function.expressions,
806 self.ir_module,
807 ) {
808 GuardedIndex::Known(value) => {
809 let id = self.gen_id();
819 let base_id = self.cached[base];
820 block.body.push(Instruction::composite_extract(
821 result_type_id,
822 id,
823 base_id,
824 &[value],
825 ));
826 id
827 }
828 GuardedIndex::Expression(_) => {
829 self.spill_to_internal_variable(base, block);
836
837 self.function.spilled_accesses.insert(expr_handle);
840 self.maybe_access_spilled_composite(
841 expr_handle,
842 block,
843 result_type_id,
844 )?
845 }
846 }
847 }
848 crate::TypeInner::BindingArray {
849 base: binding_type, ..
850 } => {
851 let result_id = match self.write_access_chain(
854 expr_handle,
855 block,
856 AccessTypeAdjustment::IntroducePointer(
857 spirv::StorageClass::UniformConstant,
858 ),
859 )? {
860 ExpressionPointer::Ready { pointer_id } => pointer_id,
861 ExpressionPointer::Conditional { .. } => {
862 return Err(Error::FeatureNotImplemented(
863 "Texture array out-of-bounds handling",
864 ));
865 }
866 };
867
868 let binding_type_id = self.get_handle_type_id(binding_type);
869
870 let load_id = self.gen_id();
871 block.body.push(Instruction::load(
872 binding_type_id,
873 load_id,
874 result_id,
875 None,
876 ));
877
878 if self.fun_info[index].uniformity.non_uniform_result.is_some() {
882 self.writer
883 .decorate_non_uniform_binding_array_access(load_id)?;
884 }
885
886 load_id
887 }
888 ref other => {
889 log::error!(
890 "Unable to access base {:?} of type {:?}",
891 self.ir_function.expressions[base],
892 other
893 );
894 return Err(Error::Validation(
895 "only vectors and arrays may be dynamically indexed by value",
896 ));
897 }
898 }
899 }
900 crate::Expression::AccessIndex { base, index } => {
901 match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
902 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
903 0
909 }
910 _ if self.function.spilled_accesses.contains(base) => {
911 self.function.spilled_accesses.insert(expr_handle);
919 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
920 }
921 crate::TypeInner::Vector { .. }
922 | crate::TypeInner::Matrix { .. }
923 | crate::TypeInner::Array { .. }
924 | crate::TypeInner::Struct { .. } => {
925 let id = self.gen_id();
930 let base_id = self.cached[base];
931 block.body.push(Instruction::composite_extract(
932 result_type_id,
933 id,
934 base_id,
935 &[index],
936 ));
937 id
938 }
939 crate::TypeInner::BindingArray {
940 base: binding_type, ..
941 } => {
942 let result_id = match self.write_access_chain(
945 expr_handle,
946 block,
947 AccessTypeAdjustment::IntroducePointer(
948 spirv::StorageClass::UniformConstant,
949 ),
950 )? {
951 ExpressionPointer::Ready { pointer_id } => pointer_id,
952 ExpressionPointer::Conditional { .. } => {
953 return Err(Error::FeatureNotImplemented(
954 "Texture array out-of-bounds handling",
955 ));
956 }
957 };
958
959 let binding_type_id = self.get_handle_type_id(binding_type);
960
961 let load_id = self.gen_id();
962 block.body.push(Instruction::load(
963 binding_type_id,
964 load_id,
965 result_id,
966 None,
967 ));
968
969 load_id
970 }
971 ref other => {
972 log::error!("Unable to access index of {other:?}");
973 return Err(Error::FeatureNotImplemented("access index for type"));
974 }
975 }
976 }
977 crate::Expression::GlobalVariable(handle) => {
978 self.writer.global_variables[handle].access_id
979 }
980 crate::Expression::Swizzle {
981 size,
982 vector,
983 pattern,
984 } => {
985 let vector_id = self.cached[vector];
986 self.temp_list.clear();
987 for &sc in pattern[..size as usize].iter() {
988 self.temp_list.push(sc as Word);
989 }
990 let id = self.gen_id();
991 block.body.push(Instruction::vector_shuffle(
992 result_type_id,
993 id,
994 vector_id,
995 vector_id,
996 &self.temp_list,
997 ));
998 id
999 }
1000 crate::Expression::Unary { op, expr } => {
1001 let id = self.gen_id();
1002 let expr_id = self.cached[expr];
1003 let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
1004
1005 let spirv_op = match op {
1006 crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
1007 Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
1008 Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
1009 _ => return Err(Error::Validation("Unexpected kind for negation")),
1010 },
1011 crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot,
1012 crate::UnaryOperator::BitwiseNot => spirv::Op::Not,
1013 };
1014
1015 block
1016 .body
1017 .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
1018 id
1019 }
1020 crate::Expression::Binary { op, left, right } => {
1021 let id = self.gen_id();
1022 let left_id = self.cached[left];
1023 let right_id = self.cached[right];
1024 let left_type_id = self.get_expression_type_id(&self.fun_info[left].ty);
1025 let right_type_id = self.get_expression_type_id(&self.fun_info[right].ty);
1026
1027 if let Some(function_id) =
1028 self.writer
1029 .wrapped_functions
1030 .get(&WrappedFunction::BinaryOp {
1031 op,
1032 left_type_id,
1033 right_type_id,
1034 })
1035 {
1036 block.body.push(Instruction::function_call(
1037 result_type_id,
1038 id,
1039 *function_id,
1040 &[left_id, right_id],
1041 ));
1042 } else {
1043 let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
1044 let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
1045
1046 let left_dimension = get_dimension(left_ty_inner);
1047 let right_dimension = get_dimension(right_ty_inner);
1048
1049 let mut reverse_operands = false;
1050
1051 let spirv_op = match op {
1052 crate::BinaryOperator::Add => match *left_ty_inner {
1053 crate::TypeInner::Scalar(scalar)
1054 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1055 crate::ScalarKind::Float => spirv::Op::FAdd,
1056 _ => spirv::Op::IAdd,
1057 },
1058 crate::TypeInner::Matrix {
1059 columns,
1060 rows,
1061 scalar,
1062 } => {
1063 self.write_matrix_matrix_column_op(
1065 block,
1066 id,
1067 result_type_id,
1068 left_id,
1069 right_id,
1070 columns,
1071 rows,
1072 scalar.width,
1073 spirv::Op::FAdd,
1074 );
1075
1076 self.cached[expr_handle] = id;
1077 return Ok(());
1078 }
1079 crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FAdd,
1080 _ => unimplemented!(),
1081 },
1082 crate::BinaryOperator::Subtract => match *left_ty_inner {
1083 crate::TypeInner::Scalar(scalar)
1084 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
1085 crate::ScalarKind::Float => spirv::Op::FSub,
1086 _ => spirv::Op::ISub,
1087 },
1088 crate::TypeInner::Matrix {
1089 columns,
1090 rows,
1091 scalar,
1092 } => {
1093 self.write_matrix_matrix_column_op(
1094 block,
1095 id,
1096 result_type_id,
1097 left_id,
1098 right_id,
1099 columns,
1100 rows,
1101 scalar.width,
1102 spirv::Op::FSub,
1103 );
1104
1105 self.cached[expr_handle] = id;
1106 return Ok(());
1107 }
1108 crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FSub,
1109 _ => unimplemented!(),
1110 },
1111 crate::BinaryOperator::Multiply => {
1112 match (left_dimension, right_dimension) {
1113 (Dimension::Scalar, Dimension::Vector) => {
1114 self.write_vector_scalar_mult(
1115 block,
1116 id,
1117 result_type_id,
1118 right_id,
1119 left_id,
1120 right_ty_inner,
1121 );
1122
1123 self.cached[expr_handle] = id;
1124 return Ok(());
1125 }
1126 (Dimension::Vector, Dimension::Scalar) => {
1127 self.write_vector_scalar_mult(
1128 block,
1129 id,
1130 result_type_id,
1131 left_id,
1132 right_id,
1133 left_ty_inner,
1134 );
1135
1136 self.cached[expr_handle] = id;
1137 return Ok(());
1138 }
1139 (Dimension::Vector, Dimension::Matrix) => {
1140 spirv::Op::VectorTimesMatrix
1141 }
1142 (Dimension::Matrix, Dimension::Scalar)
1143 | (Dimension::CooperativeMatrix, Dimension::Scalar) => {
1144 spirv::Op::MatrixTimesScalar
1145 }
1146 (Dimension::Scalar, Dimension::Matrix)
1147 | (Dimension::Scalar, Dimension::CooperativeMatrix) => {
1148 reverse_operands = true;
1149 spirv::Op::MatrixTimesScalar
1150 }
1151 (Dimension::Matrix, Dimension::Vector) => {
1152 spirv::Op::MatrixTimesVector
1153 }
1154 (Dimension::Matrix, Dimension::Matrix) => {
1155 spirv::Op::MatrixTimesMatrix
1156 }
1157 (Dimension::Vector, Dimension::Vector)
1158 | (Dimension::Scalar, Dimension::Scalar)
1159 if left_ty_inner.scalar_kind()
1160 == Some(crate::ScalarKind::Float) =>
1161 {
1162 spirv::Op::FMul
1163 }
1164 (Dimension::Vector, Dimension::Vector)
1165 | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
1166 (Dimension::CooperativeMatrix, Dimension::CooperativeMatrix)
1167 | (Dimension::CooperativeMatrix, _)
1169 | (_, Dimension::CooperativeMatrix) => {
1170 unimplemented!()
1171 }
1172 }
1173 }
1174 crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
1175 Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
1176 Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
1177 Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
1178 _ => unimplemented!(),
1179 },
1180 crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
1181 Some(crate::ScalarKind::Float) => spirv::Op::FRem,
1184 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1185 unreachable!("Should have been handled by wrapped function")
1186 }
1187 _ => unimplemented!(),
1188 },
1189 crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
1190 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1191 spirv::Op::IEqual
1192 }
1193 Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
1194 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
1195 _ => unimplemented!(),
1196 },
1197 crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
1198 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
1199 spirv::Op::INotEqual
1200 }
1201 Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
1202 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
1203 _ => unimplemented!(),
1204 },
1205 crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
1206 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
1207 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
1208 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
1209 _ => unimplemented!(),
1210 },
1211 crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
1212 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
1213 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
1214 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
1215 _ => unimplemented!(),
1216 },
1217 crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
1218 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
1219 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
1220 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
1221 _ => unimplemented!(),
1222 },
1223 crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
1224 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
1225 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
1226 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
1227 _ => unimplemented!(),
1228 },
1229 crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
1230 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
1231 _ => spirv::Op::BitwiseAnd,
1232 },
1233 crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
1234 crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
1235 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
1236 _ => spirv::Op::BitwiseOr,
1237 },
1238 crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
1239 crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
1240 crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
1241 crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
1242 Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
1243 Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
1244 _ => unimplemented!(),
1245 },
1246 };
1247
1248 block.body.push(Instruction::binary(
1249 spirv_op,
1250 result_type_id,
1251 id,
1252 if reverse_operands { right_id } else { left_id },
1253 if reverse_operands { left_id } else { right_id },
1254 ));
1255 }
1256 id
1257 }
1258 crate::Expression::Math {
1259 fun,
1260 arg,
1261 arg1,
1262 arg2,
1263 arg3,
1264 } => {
1265 use crate::MathFunction as Mf;
1266 enum MathOp {
1267 Ext(spirv::GLOp),
1268 Custom(Instruction),
1269 }
1270
1271 let arg0_id = self.cached[arg];
1272 let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
1273 let arg_scalar_kind = arg_ty.scalar_kind();
1274 let arg1_id = match arg1 {
1275 Some(handle) => self.cached[handle],
1276 None => 0,
1277 };
1278 let arg2_id = match arg2 {
1279 Some(handle) => self.cached[handle],
1280 None => 0,
1281 };
1282 let arg3_id = match arg3 {
1283 Some(handle) => self.cached[handle],
1284 None => 0,
1285 };
1286
1287 let id = self.gen_id();
1288 let math_op = match fun {
1289 Mf::Abs => {
1291 match arg_scalar_kind {
1292 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs),
1293 Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs),
1294 Some(crate::ScalarKind::Uint) => {
1295 MathOp::Custom(Instruction::unary(
1296 spirv::Op::CopyObject, result_type_id,
1298 id,
1299 arg0_id,
1300 ))
1301 }
1302 other => unimplemented!("Unexpected abs({:?})", other),
1303 }
1304 }
1305 Mf::Min => MathOp::Ext(match arg_scalar_kind {
1306 Some(crate::ScalarKind::Float) => spirv::GLOp::FMin,
1307 Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin,
1308 Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin,
1309 other => unimplemented!("Unexpected min({:?})", other),
1310 }),
1311 Mf::Max => MathOp::Ext(match arg_scalar_kind {
1312 Some(crate::ScalarKind::Float) => spirv::GLOp::FMax,
1313 Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax,
1314 Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax,
1315 other => unimplemented!("Unexpected max({:?})", other),
1316 }),
1317 Mf::Clamp => match arg_scalar_kind {
1318 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FClamp),
1322 Some(_) => {
1323 let (min_op, max_op) = match arg_scalar_kind {
1324 Some(crate::ScalarKind::Sint) => {
1325 (spirv::GLOp::SMin, spirv::GLOp::SMax)
1326 }
1327 Some(crate::ScalarKind::Uint) => {
1328 (spirv::GLOp::UMin, spirv::GLOp::UMax)
1329 }
1330 _ => unreachable!(),
1331 };
1332
1333 let max_id = self.gen_id();
1334 block.body.push(Instruction::ext_inst_gl_op(
1335 self.writer.gl450_ext_inst_id,
1336 max_op,
1337 result_type_id,
1338 max_id,
1339 &[arg0_id, arg1_id],
1340 ));
1341
1342 MathOp::Custom(Instruction::ext_inst_gl_op(
1343 self.writer.gl450_ext_inst_id,
1344 min_op,
1345 result_type_id,
1346 id,
1347 &[max_id, arg2_id],
1348 ))
1349 }
1350 other => unimplemented!("Unexpected max({:?})", other),
1351 },
1352 Mf::Saturate => {
1353 let (maybe_size, scalar) = match *arg_ty {
1354 crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
1355 crate::TypeInner::Scalar(scalar) => (None, scalar),
1356 ref other => unimplemented!("Unexpected saturate({:?})", other),
1357 };
1358 let scalar = crate::Scalar::float(scalar.width);
1359 let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
1360 let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
1361
1362 if let Some(size) = maybe_size {
1363 let ty =
1364 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1365
1366 self.temp_list.clear();
1367 self.temp_list.resize(size as _, arg1_id);
1368
1369 arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
1370
1371 self.temp_list.fill(arg2_id);
1372
1373 arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
1374 }
1375
1376 MathOp::Custom(Instruction::ext_inst_gl_op(
1377 self.writer.gl450_ext_inst_id,
1378 spirv::GLOp::FClamp,
1379 result_type_id,
1380 id,
1381 &[arg0_id, arg1_id, arg2_id],
1382 ))
1383 }
1384 Mf::Sin => MathOp::Ext(spirv::GLOp::Sin),
1386 Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh),
1387 Mf::Asin => MathOp::Ext(spirv::GLOp::Asin),
1388 Mf::Cos => MathOp::Ext(spirv::GLOp::Cos),
1389 Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh),
1390 Mf::Acos => MathOp::Ext(spirv::GLOp::Acos),
1391 Mf::Tan => MathOp::Ext(spirv::GLOp::Tan),
1392 Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh),
1393 Mf::Atan => MathOp::Ext(spirv::GLOp::Atan),
1394 Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2),
1395 Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh),
1396 Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh),
1397 Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh),
1398 Mf::Radians => MathOp::Ext(spirv::GLOp::Radians),
1399 Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees),
1400 Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil),
1402 Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven),
1403 Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
1404 Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
1405 Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
1406 Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct),
1407 Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct),
1408 Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
1409 Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
1411 crate::TypeInner::Vector {
1412 scalar:
1413 crate::Scalar {
1414 kind: crate::ScalarKind::Float,
1415 ..
1416 },
1417 ..
1418 } => MathOp::Custom(Instruction::binary(
1419 spirv::Op::Dot,
1420 result_type_id,
1421 id,
1422 arg0_id,
1423 arg1_id,
1424 )),
1425 crate::TypeInner::Vector { size, .. } => {
1427 self.write_dot_product(
1428 id,
1429 result_type_id,
1430 arg0_id,
1431 arg1_id,
1432 size as u32,
1433 block,
1434 |result_id, composite_id, index| {
1435 Instruction::composite_extract(
1436 result_type_id,
1437 result_id,
1438 composite_id,
1439 &[index],
1440 )
1441 },
1442 );
1443 self.cached[expr_handle] = id;
1444 return Ok(());
1445 }
1446 _ => unreachable!(
1447 "Correct TypeInner for dot product should be already validated"
1448 ),
1449 },
1450 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
1451 if self
1452 .writer
1453 .require_all(&[
1454 spirv::Capability::DotProduct,
1455 spirv::Capability::DotProductInput4x8BitPacked,
1456 ])
1457 .is_ok()
1458 {
1459 if self.writer.lang_version() < (1, 6) {
1461 self.writer.use_extension("SPV_KHR_integer_dot_product");
1465 }
1466
1467 let op = match fun {
1468 Mf::Dot4I8Packed => spirv::Op::SDot,
1469 Mf::Dot4U8Packed => spirv::Op::UDot,
1470 _ => unreachable!(),
1471 };
1472
1473 block.body.push(Instruction::ternary(
1474 op,
1475 result_type_id,
1476 id,
1477 arg0_id,
1478 arg1_id,
1479 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word,
1480 ));
1481 } else {
1482 let (extract_op, arg0_id, arg1_id) = match fun {
1484 Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
1485 Mf::Dot4I8Packed => {
1486 let new_arg0_id = self.gen_id();
1489 block.body.push(Instruction::unary(
1490 spirv::Op::Bitcast,
1491 result_type_id,
1492 new_arg0_id,
1493 arg0_id,
1494 ));
1495
1496 let new_arg1_id = self.gen_id();
1497 block.body.push(Instruction::unary(
1498 spirv::Op::Bitcast,
1499 result_type_id,
1500 new_arg1_id,
1501 arg1_id,
1502 ));
1503
1504 (spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
1505 }
1506 _ => unreachable!(),
1507 };
1508
1509 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1510
1511 const VEC_LENGTH: u8 = 4;
1512 let bit_shifts: [_; VEC_LENGTH as usize] =
1513 core::array::from_fn(|index| {
1514 self.writer
1515 .get_constant_scalar(crate::Literal::U32(index as u32 * 8))
1516 });
1517
1518 self.write_dot_product(
1519 id,
1520 result_type_id,
1521 arg0_id,
1522 arg1_id,
1523 VEC_LENGTH as Word,
1524 block,
1525 |result_id, composite_id, index| {
1526 Instruction::ternary(
1527 extract_op,
1528 result_type_id,
1529 result_id,
1530 composite_id,
1531 bit_shifts[index as usize],
1532 eight,
1533 )
1534 },
1535 );
1536 }
1537
1538 self.cached[expr_handle] = id;
1539 return Ok(());
1540 }
1541 Mf::Outer => MathOp::Custom(Instruction::binary(
1542 spirv::Op::OuterProduct,
1543 result_type_id,
1544 id,
1545 arg0_id,
1546 arg1_id,
1547 )),
1548 Mf::Cross => MathOp::Ext(spirv::GLOp::Cross),
1549 Mf::Distance => MathOp::Ext(spirv::GLOp::Distance),
1550 Mf::Length => MathOp::Ext(spirv::GLOp::Length),
1551 Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize),
1552 Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward),
1553 Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect),
1554 Mf::Refract => MathOp::Ext(spirv::GLOp::Refract),
1555 Mf::Exp => MathOp::Ext(spirv::GLOp::Exp),
1557 Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2),
1558 Mf::Log => MathOp::Ext(spirv::GLOp::Log),
1559 Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2),
1560 Mf::Pow => MathOp::Ext(spirv::GLOp::Pow),
1561 Mf::Sign => MathOp::Ext(match arg_scalar_kind {
1563 Some(crate::ScalarKind::Float) => spirv::GLOp::FSign,
1564 Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign,
1565 other => unimplemented!("Unexpected sign({:?})", other),
1566 }),
1567 Mf::Fma => MathOp::Ext(spirv::GLOp::Fma),
1568 Mf::Mix => {
1569 let selector = arg2.unwrap();
1570 let selector_ty =
1571 self.fun_info[selector].ty.inner_with(&self.ir_module.types);
1572 match (arg_ty, selector_ty) {
1573 (
1575 &crate::TypeInner::Vector { size, .. },
1576 &crate::TypeInner::Scalar(scalar),
1577 ) => {
1578 let selector_type_id =
1579 self.get_numeric_type_id(NumericType::Vector { size, scalar });
1580 self.temp_list.clear();
1581 self.temp_list.resize(size as usize, arg2_id);
1582
1583 let selector_id = self.gen_id();
1584 block.body.push(Instruction::composite_construct(
1585 selector_type_id,
1586 selector_id,
1587 &self.temp_list,
1588 ));
1589
1590 MathOp::Custom(Instruction::ext_inst_gl_op(
1591 self.writer.gl450_ext_inst_id,
1592 spirv::GLOp::FMix,
1593 result_type_id,
1594 id,
1595 &[arg0_id, arg1_id, selector_id],
1596 ))
1597 }
1598 _ => MathOp::Ext(spirv::GLOp::FMix),
1599 }
1600 }
1601 Mf::Step => MathOp::Ext(spirv::GLOp::Step),
1602 Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep),
1603 Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt),
1604 Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt),
1605 Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse),
1606 Mf::Transpose => MathOp::Custom(Instruction::unary(
1607 spirv::Op::Transpose,
1608 result_type_id,
1609 id,
1610 arg0_id,
1611 )),
1612 Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
1613 Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
1614 spirv::Op::QuantizeToF16,
1615 result_type_id,
1616 id,
1617 arg0_id,
1618 )),
1619 Mf::ReverseBits => MathOp::Custom(Instruction::unary(
1620 spirv::Op::BitReverse,
1621 result_type_id,
1622 id,
1623 arg0_id,
1624 )),
1625 Mf::CountTrailingZeros => {
1626 let uint_id = match *arg_ty {
1627 crate::TypeInner::Vector { size, scalar } => {
1628 let ty =
1629 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1630
1631 self.temp_list.clear();
1632 self.temp_list.resize(
1633 size as _,
1634 self.writer
1635 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1636 );
1637
1638 self.writer.get_constant_composite(ty, &self.temp_list)
1639 }
1640 crate::TypeInner::Scalar(scalar) => self
1641 .writer
1642 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1643 _ => unreachable!(),
1644 };
1645
1646 let lsb_id = self.gen_id();
1647 block.body.push(Instruction::ext_inst_gl_op(
1648 self.writer.gl450_ext_inst_id,
1649 spirv::GLOp::FindILsb,
1650 result_type_id,
1651 lsb_id,
1652 &[arg0_id],
1653 ));
1654
1655 MathOp::Custom(Instruction::ext_inst_gl_op(
1656 self.writer.gl450_ext_inst_id,
1657 spirv::GLOp::UMin,
1658 result_type_id,
1659 id,
1660 &[uint_id, lsb_id],
1661 ))
1662 }
1663 Mf::CountLeadingZeros => {
1664 let (int_type_id, int_id, width) = match *arg_ty {
1665 crate::TypeInner::Vector { size, scalar } => {
1666 let ty =
1667 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1668
1669 self.temp_list.clear();
1670 self.temp_list.resize(
1671 size as _,
1672 self.writer
1673 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1674 );
1675
1676 (
1677 self.get_type_id(ty),
1678 self.writer.get_constant_composite(ty, &self.temp_list),
1679 scalar.width,
1680 )
1681 }
1682 crate::TypeInner::Scalar(scalar) => (
1683 self.get_numeric_type_id(NumericType::Scalar(scalar)),
1684 self.writer
1685 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1686 scalar.width,
1687 ),
1688 _ => unreachable!(),
1689 };
1690
1691 if width != 4 {
1692 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1693 };
1694
1695 let msb_id = self.gen_id();
1696 block.body.push(Instruction::ext_inst_gl_op(
1697 self.writer.gl450_ext_inst_id,
1698 if width != 4 {
1699 spirv::GLOp::FindILsb
1700 } else {
1701 spirv::GLOp::FindUMsb
1702 },
1703 int_type_id,
1704 msb_id,
1705 &[arg0_id],
1706 ));
1707
1708 MathOp::Custom(Instruction::binary(
1709 spirv::Op::ISub,
1710 result_type_id,
1711 id,
1712 int_id,
1713 msb_id,
1714 ))
1715 }
1716 Mf::CountOneBits => MathOp::Custom(Instruction::unary(
1717 spirv::Op::BitCount,
1718 result_type_id,
1719 id,
1720 arg0_id,
1721 )),
1722 Mf::ExtractBits => {
1723 let op = match arg_scalar_kind {
1724 Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
1725 Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
1726 other => unimplemented!("Unexpected sign({:?})", other),
1727 };
1728
1729 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1744 let width_constant = self
1745 .writer
1746 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1747
1748 let u32_type =
1749 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1750
1751 let offset_id = self.gen_id();
1753 block.body.push(Instruction::ext_inst_gl_op(
1754 self.writer.gl450_ext_inst_id,
1755 spirv::GLOp::UMin,
1756 u32_type,
1757 offset_id,
1758 &[arg1_id, width_constant],
1759 ));
1760
1761 let max_count_id = self.gen_id();
1763 block.body.push(Instruction::binary(
1764 spirv::Op::ISub,
1765 u32_type,
1766 max_count_id,
1767 width_constant,
1768 offset_id,
1769 ));
1770
1771 let count_id = self.gen_id();
1773 block.body.push(Instruction::ext_inst_gl_op(
1774 self.writer.gl450_ext_inst_id,
1775 spirv::GLOp::UMin,
1776 u32_type,
1777 count_id,
1778 &[arg2_id, max_count_id],
1779 ));
1780
1781 MathOp::Custom(Instruction::ternary(
1782 op,
1783 result_type_id,
1784 id,
1785 arg0_id,
1786 offset_id,
1787 count_id,
1788 ))
1789 }
1790 Mf::InsertBits => {
1791 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1794 let width_constant = self
1795 .writer
1796 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1797
1798 let u32_type =
1799 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1800
1801 let offset_id = self.gen_id();
1803 block.body.push(Instruction::ext_inst_gl_op(
1804 self.writer.gl450_ext_inst_id,
1805 spirv::GLOp::UMin,
1806 u32_type,
1807 offset_id,
1808 &[arg2_id, width_constant],
1809 ));
1810
1811 let max_count_id = self.gen_id();
1813 block.body.push(Instruction::binary(
1814 spirv::Op::ISub,
1815 u32_type,
1816 max_count_id,
1817 width_constant,
1818 offset_id,
1819 ));
1820
1821 let count_id = self.gen_id();
1823 block.body.push(Instruction::ext_inst_gl_op(
1824 self.writer.gl450_ext_inst_id,
1825 spirv::GLOp::UMin,
1826 u32_type,
1827 count_id,
1828 &[arg3_id, max_count_id],
1829 ));
1830
1831 MathOp::Custom(Instruction::quaternary(
1832 spirv::Op::BitFieldInsert,
1833 result_type_id,
1834 id,
1835 arg0_id,
1836 arg1_id,
1837 offset_id,
1838 count_id,
1839 ))
1840 }
1841 Mf::FirstTrailingBit => MathOp::Ext(spirv::GLOp::FindILsb),
1842 Mf::FirstLeadingBit => {
1843 if arg_ty.scalar_width() == Some(4) {
1844 let thing = match arg_scalar_kind {
1845 Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
1846 Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
1847 other => unimplemented!("Unexpected firstLeadingBit({:?})", other),
1848 };
1849 MathOp::Ext(thing)
1850 } else {
1851 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1852 }
1853 }
1854 Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
1855 Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
1856 Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
1857 Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16),
1858 Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
1859 fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => {
1860 let is_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp);
1861 let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp);
1862
1863 let last_instruction =
1864 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1865 self.write_pack4x8_optimized(
1866 block,
1867 result_type_id,
1868 arg0_id,
1869 id,
1870 is_signed,
1871 should_clamp,
1872 )
1873 } else {
1874 self.write_pack4x8_polyfill(
1875 block,
1876 result_type_id,
1877 arg0_id,
1878 id,
1879 is_signed,
1880 should_clamp,
1881 )
1882 };
1883
1884 MathOp::Custom(last_instruction)
1885 }
1886 Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
1887 Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
1888 Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
1889 Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16),
1890 Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
1891 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
1892 let is_signed = matches!(fun, Mf::Unpack4xI8);
1893
1894 let last_instruction =
1895 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1896 self.write_unpack4x8_optimized(
1897 block,
1898 result_type_id,
1899 arg0_id,
1900 id,
1901 is_signed,
1902 )
1903 } else {
1904 self.write_unpack4x8_polyfill(
1905 block,
1906 result_type_id,
1907 arg0_id,
1908 id,
1909 is_signed,
1910 )
1911 };
1912
1913 MathOp::Custom(last_instruction)
1914 }
1915 };
1916
1917 block.body.push(match math_op {
1918 MathOp::Ext(op) => Instruction::ext_inst_gl_op(
1919 self.writer.gl450_ext_inst_id,
1920 op,
1921 result_type_id,
1922 id,
1923 &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
1924 ),
1925 MathOp::Custom(inst) => inst,
1926 });
1927 id
1928 }
1929 crate::Expression::LocalVariable(variable) => {
1930 if let Some(rq_tracker) = self
1931 .function
1932 .ray_query_initialization_tracker_variables
1933 .get(&variable)
1934 {
1935 self.ray_query_tracker_expr.insert(
1936 expr_handle,
1937 super::RayQueryTrackers {
1938 initialized_tracker: rq_tracker.id,
1939 t_max_tracker: self
1940 .function
1941 .ray_query_t_max_tracker_variables
1942 .get(&variable)
1943 .expect("Both trackers are set at the same time.")
1944 .id,
1945 },
1946 );
1947 }
1948 self.function.variables[&variable].id
1949 }
1950 crate::Expression::Load { pointer } => {
1951 self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)?
1952 }
1953 crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
1954 crate::Expression::CallResult(_)
1955 | crate::Expression::AtomicResult { .. }
1956 | crate::Expression::WorkGroupUniformLoadResult { .. }
1957 | crate::Expression::RayQueryProceedResult
1958 | crate::Expression::SubgroupBallotResult
1959 | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
1960 crate::Expression::As {
1961 expr,
1962 kind,
1963 convert,
1964 } => self.write_as_expression(expr, convert, kind, block, result_type_id)?,
1965 crate::Expression::ImageLoad {
1966 image,
1967 coordinate,
1968 array_index,
1969 sample,
1970 level,
1971 } => self.write_image_load(
1972 result_type_id,
1973 image,
1974 coordinate,
1975 array_index,
1976 level,
1977 sample,
1978 block,
1979 )?,
1980 crate::Expression::ImageSample {
1981 image,
1982 sampler,
1983 gather,
1984 coordinate,
1985 array_index,
1986 offset,
1987 level,
1988 depth_ref,
1989 clamp_to_edge,
1990 } => self.write_image_sample(
1991 result_type_id,
1992 image,
1993 sampler,
1994 gather,
1995 coordinate,
1996 array_index,
1997 offset,
1998 level,
1999 depth_ref,
2000 clamp_to_edge,
2001 block,
2002 )?,
2003 crate::Expression::Select {
2004 condition,
2005 accept,
2006 reject,
2007 } => {
2008 let id = self.gen_id();
2009 let mut condition_id = self.cached[condition];
2010 let accept_id = self.cached[accept];
2011 let reject_id = self.cached[reject];
2012
2013 let condition_ty = self.fun_info[condition]
2014 .ty
2015 .inner_with(&self.ir_module.types);
2016 let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
2017
2018 if let (
2019 &crate::TypeInner::Scalar(
2020 condition_scalar @ crate::Scalar {
2021 kind: crate::ScalarKind::Bool,
2022 ..
2023 },
2024 ),
2025 &crate::TypeInner::Vector { size, .. },
2026 ) = (condition_ty, object_ty)
2027 {
2028 self.temp_list.clear();
2029 self.temp_list.resize(size as usize, condition_id);
2030
2031 let bool_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2032 size,
2033 scalar: condition_scalar,
2034 });
2035
2036 let id = self.gen_id();
2037 block.body.push(Instruction::composite_construct(
2038 bool_vector_type_id,
2039 id,
2040 &self.temp_list,
2041 ));
2042 condition_id = id
2043 }
2044
2045 let instruction =
2046 Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
2047 block.body.push(instruction);
2048 id
2049 }
2050 crate::Expression::Derivative { axis, ctrl, expr } => {
2051 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
2052 match ctrl {
2053 Ctrl::Coarse | Ctrl::Fine => {
2054 self.writer.require_any(
2055 "DerivativeControl",
2056 &[spirv::Capability::DerivativeControl],
2057 )?;
2058 }
2059 Ctrl::None => {}
2060 }
2061 let id = self.gen_id();
2062 let expr_id = self.cached[expr];
2063 let op = match (axis, ctrl) {
2064 (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
2065 (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
2066 (Axis::X, Ctrl::None) => spirv::Op::DPdx,
2067 (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
2068 (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
2069 (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
2070 (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
2071 (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
2072 (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
2073 };
2074 block
2075 .body
2076 .push(Instruction::derivative(op, result_type_id, id, expr_id));
2077 id
2078 }
2079 crate::Expression::ImageQuery { image, query } => {
2080 self.write_image_query(result_type_id, image, query, block)?
2081 }
2082 crate::Expression::Relational { fun, argument } => {
2083 use crate::RelationalFunction as Rf;
2084 let arg_id = self.cached[argument];
2085 let op = match fun {
2086 Rf::All => spirv::Op::All,
2087 Rf::Any => spirv::Op::Any,
2088 Rf::IsNan => spirv::Op::IsNan,
2089 Rf::IsInf => spirv::Op::IsInf,
2090 };
2091 let id = self.gen_id();
2092 block
2093 .body
2094 .push(Instruction::relational(op, result_type_id, id, arg_id));
2095 id
2096 }
2097 crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
2098 crate::Expression::RayQueryGetIntersection { query, committed } => {
2099 let query_id = self.cached[query];
2100 let init_tracker_id = *self
2101 .ray_query_tracker_expr
2102 .get(&query)
2103 .expect("not a cached ray query");
2104 let func_id = self
2105 .writer
2106 .write_ray_query_get_intersection_function(committed, self.ir_module);
2107 let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap();
2108 let intersection_type_id = self.get_handle_type_id(ray_intersection);
2109 let id = self.gen_id();
2110 block.body.push(Instruction::function_call(
2111 intersection_type_id,
2112 id,
2113 func_id,
2114 &[query_id, init_tracker_id.initialized_tracker],
2115 ));
2116 id
2117 }
2118 crate::Expression::RayQueryVertexPositions { query, committed } => {
2119 self.writer.require_any(
2120 "RayQueryVertexPositions",
2121 &[spirv::Capability::RayQueryPositionFetchKHR],
2122 )?;
2123 self.write_ray_query_return_vertex_position(query, block, committed)
2124 }
2125 crate::Expression::CooperativeLoad { ref data, .. } => {
2126 self.writer.require_any(
2127 "CooperativeMatrix",
2128 &[spirv::Capability::CooperativeMatrixKHR],
2129 )?;
2130 let layout = if data.row_major {
2131 spirv::CooperativeMatrixLayout::RowMajorKHR
2132 } else {
2133 spirv::CooperativeMatrixLayout::ColumnMajorKHR
2134 };
2135 let layout_id = self.get_index_constant(layout as u32);
2136 let stride_id = self.cached[data.stride];
2137 match self.write_access_chain(data.pointer, block, AccessTypeAdjustment::None)? {
2138 ExpressionPointer::Ready { pointer_id } => {
2139 let id = self.gen_id();
2140 block.body.push(Instruction::coop_load(
2141 result_type_id,
2142 id,
2143 pointer_id,
2144 layout_id,
2145 stride_id,
2146 ));
2147 id
2148 }
2149 ExpressionPointer::Conditional { condition, access } => self
2150 .write_conditional_indexed_load(
2151 result_type_id,
2152 condition,
2153 block,
2154 |id_gen, block| {
2155 let pointer_id = access.result_id.unwrap();
2156 block.body.push(access);
2157 let id = id_gen.next();
2158 block.body.push(Instruction::coop_load(
2159 result_type_id,
2160 id,
2161 pointer_id,
2162 layout_id,
2163 stride_id,
2164 ));
2165 id
2166 },
2167 ),
2168 }
2169 }
2170 crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
2171 self.writer.require_any(
2172 "CooperativeMatrix",
2173 &[spirv::Capability::CooperativeMatrixKHR],
2174 )?;
2175 let a_id = self.cached[a];
2176 let b_id = self.cached[b];
2177 let c_id = self.cached[c];
2178 let id = self.gen_id();
2179 block.body.push(Instruction::coop_mul_add(
2180 result_type_id,
2181 id,
2182 a_id,
2183 b_id,
2184 c_id,
2185 ));
2186 id
2187 }
2188 };
2189
2190 self.cached[expr_handle] = id;
2191 Ok(())
2192 }
2193
2194 fn write_as_expression(
2197 &mut self,
2198 expr: Handle<crate::Expression>,
2199 convert: Option<u8>,
2200 kind: crate::ScalarKind,
2201
2202 block: &mut Block,
2203 result_type_id: u32,
2204 ) -> Result<u32, Error> {
2205 use crate::ScalarKind as Sk;
2206 let expr_id = self.cached[expr];
2207 let ty = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
2208
2209 if let crate::TypeInner::Matrix {
2214 columns,
2215 rows,
2216 scalar,
2217 } = *ty
2218 {
2219 let Some(convert) = convert else {
2220 return Ok(expr_id);
2222 };
2223
2224 if convert == scalar.width {
2225 return Ok(expr_id);
2227 }
2228
2229 if kind != Sk::Float {
2230 return Err(Error::Validation("Matrices must be floats"));
2232 }
2233
2234 let column_src_ty =
2236 self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
2237 size: rows,
2238 scalar,
2239 })));
2240
2241 let column_dst_ty =
2243 self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
2244 size: rows,
2245 scalar: crate::Scalar {
2246 kind,
2247 width: convert,
2248 },
2249 })));
2250
2251 let mut components = ArrayVec::<Word, 4>::new();
2252
2253 for column in 0..columns as usize {
2254 let column_id = self.gen_id();
2255 block.body.push(Instruction::composite_extract(
2256 column_src_ty,
2257 column_id,
2258 expr_id,
2259 &[column as u32],
2260 ));
2261
2262 let column_conv_id = self.gen_id();
2263 block.body.push(Instruction::unary(
2264 spirv::Op::FConvert,
2265 column_dst_ty,
2266 column_conv_id,
2267 column_id,
2268 ));
2269
2270 components.push(column_conv_id);
2271 }
2272
2273 let construct_id = self.gen_id();
2274
2275 block.body.push(Instruction::composite_construct(
2276 result_type_id,
2277 construct_id,
2278 &components,
2279 ));
2280
2281 return Ok(construct_id);
2282 }
2283
2284 let (src_scalar, src_size) = match *ty {
2285 crate::TypeInner::Scalar(scalar) => (scalar, None),
2286 crate::TypeInner::Vector { scalar, size } => (scalar, Some(size)),
2287 ref other => {
2288 log::error!("As source {other:?}");
2289 return Err(Error::Validation("Unexpected Expression::As source"));
2290 }
2291 };
2292
2293 enum Cast {
2294 Identity(Word),
2295 Unary(spirv::Op, Word),
2296 Binary(spirv::Op, Word, Word),
2297 Ternary(spirv::Op, Word, Word, Word),
2298 }
2299 let cast = match (src_scalar.kind, kind, convert) {
2300 (src_kind, kind, convert)
2303 if src_kind == kind
2304 && convert.filter(|&width| width != src_scalar.width).is_none() =>
2305 {
2306 Cast::Identity(expr_id)
2307 }
2308 (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject, expr_id),
2309 (_, _, None) => Cast::Unary(spirv::Op::Bitcast, expr_id),
2310 (_, Sk::Bool, Some(_)) => {
2312 let op = match src_scalar.kind {
2313 Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
2314 Sk::Float => spirv::Op::FUnordNotEqual,
2315 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(),
2316 };
2317 let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?;
2318 let zero_id = match src_size {
2319 Some(size) => {
2320 let ty = LocalType::Numeric(NumericType::Vector {
2321 size,
2322 scalar: src_scalar,
2323 })
2324 .into();
2325
2326 self.temp_list.clear();
2327 self.temp_list.resize(size as _, zero_scalar_id);
2328
2329 self.writer.get_constant_composite(ty, &self.temp_list)
2330 }
2331 None => zero_scalar_id,
2332 };
2333
2334 Cast::Binary(op, expr_id, zero_id)
2335 }
2336 (Sk::Bool, _, Some(dst_width)) => {
2338 let dst_scalar = crate::Scalar {
2339 kind,
2340 width: dst_width,
2341 };
2342 let zero_scalar_id = self.writer.get_constant_scalar_with(0, dst_scalar)?;
2343 let one_scalar_id = self.writer.get_constant_scalar_with(1, dst_scalar)?;
2344 let (accept_id, reject_id) = match src_size {
2345 Some(size) => {
2346 let ty = LocalType::Numeric(NumericType::Vector {
2347 size,
2348 scalar: dst_scalar,
2349 })
2350 .into();
2351
2352 self.temp_list.clear();
2353 self.temp_list.resize(size as _, zero_scalar_id);
2354
2355 let vec0_id = self.writer.get_constant_composite(ty, &self.temp_list);
2356
2357 self.temp_list.fill(one_scalar_id);
2358
2359 let vec1_id = self.writer.get_constant_composite(ty, &self.temp_list);
2360
2361 (vec1_id, vec0_id)
2362 }
2363 None => (one_scalar_id, zero_scalar_id),
2364 };
2365
2366 Cast::Ternary(spirv::Op::Select, expr_id, accept_id, reject_id)
2367 }
2368 (Sk::Float, Sk::Sint | Sk::Uint, Some(width)) => {
2379 let dst_scalar = crate::Scalar { kind, width };
2380 let (min, max) =
2381 crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
2382 let expr_type_id = self.get_expression_type_id(&self.fun_info[expr].ty);
2383
2384 let maybe_splat_const = |writer: &mut Writer, const_id| match src_size {
2385 None => const_id,
2386 Some(size) => {
2387 let constituent_ids = [const_id; crate::VectorSize::MAX];
2388 writer.get_constant_composite(
2389 LookupType::Local(LocalType::Numeric(NumericType::Vector {
2390 size,
2391 scalar: src_scalar,
2392 })),
2393 &constituent_ids[..size as usize],
2394 )
2395 }
2396 };
2397 let min_const_id = self.writer.get_constant_scalar(min);
2398 let min_const_id = maybe_splat_const(self.writer, min_const_id);
2399 let max_const_id = self.writer.get_constant_scalar(max);
2400 let max_const_id = maybe_splat_const(self.writer, max_const_id);
2401
2402 let clamp_id = self.gen_id();
2403 block.body.push(Instruction::ext_inst_gl_op(
2404 self.writer.gl450_ext_inst_id,
2405 spirv::GLOp::FClamp,
2406 expr_type_id,
2407 clamp_id,
2408 &[expr_id, min_const_id, max_const_id],
2409 ));
2410
2411 let op = match dst_scalar.kind {
2412 crate::ScalarKind::Sint => spirv::Op::ConvertFToS,
2413 crate::ScalarKind::Uint => spirv::Op::ConvertFToU,
2414 _ => unreachable!(),
2415 };
2416 Cast::Unary(op, clamp_id)
2417 }
2418 (Sk::Float, Sk::Float, Some(dst_width)) if src_scalar.width != dst_width => {
2419 Cast::Unary(spirv::Op::FConvert, expr_id)
2420 }
2421 (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF, expr_id),
2422 (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2423 Cast::Unary(spirv::Op::SConvert, expr_id)
2424 }
2425 (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF, expr_id),
2426 (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2427 Cast::Unary(spirv::Op::UConvert, expr_id)
2428 }
2429 (Sk::Uint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
2430 Cast::Unary(spirv::Op::SConvert, expr_id)
2431 }
2432 (Sk::Sint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
2433 Cast::Unary(spirv::Op::UConvert, expr_id)
2434 }
2435 _ => Cast::Unary(spirv::Op::Bitcast, expr_id),
2437 };
2438 Ok(match cast {
2439 Cast::Identity(expr) => expr,
2440 Cast::Unary(op, op1) => {
2441 let id = self.gen_id();
2442 block
2443 .body
2444 .push(Instruction::unary(op, result_type_id, id, op1));
2445 id
2446 }
2447 Cast::Binary(op, op1, op2) => {
2448 let id = self.gen_id();
2449 block
2450 .body
2451 .push(Instruction::binary(op, result_type_id, id, op1, op2));
2452 id
2453 }
2454 Cast::Ternary(op, op1, op2, op3) => {
2455 let id = self.gen_id();
2456 block
2457 .body
2458 .push(Instruction::ternary(op, result_type_id, id, op1, op2, op3));
2459 id
2460 }
2461 })
2462 }
2463
    /// Build the pointer computed by the pointer expression `expr_handle`.
    ///
    /// Starting at `expr_handle`, this walks down through `Access` and
    /// `AccessIndex` expressions until it reaches a root: a composite that was
    /// spilled to a temporary variable, a global variable, a local variable or
    /// a function argument. One index id per step is collected in
    /// `self.temp_list` (which this clobbers), and any bounds-check code is
    /// emitted into `block` via `write_access_chain_index`.
    ///
    /// `type_adjustment` chooses the result type of the access chain:
    /// - `None`: the expression's own resolved type;
    /// - `IntroducePointer(class)`: a pointer to that type in storage class
    ///   `class`;
    /// - `UseStd140CompatType`: a pointer to the std140-compatible variant of
    ///   the pointee. Only valid for `Uniform` pointers; panics otherwise.
    ///
    /// Returns:
    /// - `ExpressionPointer::Ready` if no runtime bounds-check condition was
    ///   produced. Any `OpAccessChain` has already been pushed onto `block`.
    ///   With no indices at all, the root id itself is the pointer.
    /// - `ExpressionPointer::Conditional` if bounds checks produced a runtime
    ///   condition. The access instruction has *not* been emitted; the caller
    ///   must emit it guarded by `condition`.
    ///
    /// Panics (`unimplemented!`) if the walk reaches an expression that is not
    /// one of the supported roots.
    fn write_access_chain(
        &mut self,
        mut expr_handle: Handle<crate::Expression>,
        block: &mut Block,
        type_adjustment: AccessTypeAdjustment,
    ) -> Result<ExpressionPointer, Error> {
        let result_type_id = {
            let resolution = &self.fun_info[expr_handle].ty;
            match type_adjustment {
                AccessTypeAdjustment::None => self.writer.get_expression_type_id(resolution),
                AccessTypeAdjustment::IntroducePointer(class) => {
                    self.writer.get_resolution_pointer_id(resolution, class)
                }
                AccessTypeAdjustment::UseStd140CompatType => {
                    match *resolution.inner_with(&self.ir_module.types) {
                        crate::TypeInner::Pointer {
                            base,
                            space: space @ crate::AddressSpace::Uniform,
                        } => self.writer.get_pointer_type_id(
                            self.writer.std140_compat_uniform_types[&base].type_id,
                            map_storage_class(space),
                        ),
                        _ => unreachable!(
                            "`UseStd140CompatType` must only be used with uniform pointer types"
                        ),
                    }
                }
            }
        };

        // Id of the boolean conjunction of every runtime bounds-check condition
        // produced so far; `None` until the first conditional check.
        let mut accumulated_checks = None;

        // Set if any `Access` step indexes a binding array with a non-uniform
        // index. In that case the final pointer is decorated below.
        let mut is_non_uniform_binding_array = false;

        // Set when an `AccessIndex` selects a column of a uniform matCx2 struct
        // member (per `is_uniform_matcx2_struct_member_access`). No index is
        // pushed for that step. Instead, the column number is added to the
        // std140 member index when the walk reaches the containing struct.
        let mut prev_decomposed_matrix_index = None;

        self.temp_list.clear();
        let root_id = loop {
            // A spilled composite's temporary variable is already a pointer to
            // the value we want, so it serves as the root.
            if let Some(spilled) = self.function.spilled_composites.get(&expr_handle) {
                break spilled.id;
            }

            expr_handle = match self.ir_function.expressions[expr_handle] {
                crate::Expression::Access { base, index } => {
                    is_non_uniform_binding_array |=
                        self.is_nonuniform_binding_array_access(base, index);

                    let index = GuardedIndex::Expression(index);
                    let index_id =
                        self.write_access_chain_index(base, index, &mut accumulated_checks, block)?;
                    self.temp_list.push(index_id);

                    base
                }
                crate::Expression::AccessIndex { base, index } => {
                    // Look through a pointer to the type actually being
                    // indexed, remembering its handle and address space.
                    let mut base_ty = self.fun_info[base].ty.inner_with(&self.ir_module.types);
                    let mut base_ty_handle = self.fun_info[base].ty.handle();
                    let mut pointer_space = None;
                    if let crate::TypeInner::Pointer { base, space } = *base_ty {
                        base_ty = &self.ir_module.types[base].inner;
                        base_ty_handle = Some(base);
                        pointer_space = Some(space);
                    }
                    match *base_ty {
                        // Struct member indices are emitted directly as
                        // constants, with no bounds check. For a struct reached
                        // through a `Uniform` pointer whose type has a
                        // std140-compatible variant, the member index is
                        // remapped through `member_indices`. Any pending
                        // decomposed-matrix column is added to that index.
                        crate::TypeInner::Struct { .. } => {
                            let index = match base_ty_handle.and_then(|handle| {
                                self.writer.std140_compat_uniform_types.get(&handle)
                            }) {
                                Some(std140_type_info)
                                    if pointer_space == Some(crate::AddressSpace::Uniform) =>
                                {
                                    std140_type_info.member_indices[index as usize]
                                        + prev_decomposed_matrix_index.take().unwrap_or(0)
                                }
                                _ => index,
                            };
                            let index_id = self.get_index_constant(index);
                            self.temp_list.push(index_id);
                        }
                        // Column of a decomposed uniform matCx2 struct member:
                        // defer the index until the containing struct step.
                        _ if is_uniform_matcx2_struct_member_access(
                            self.ir_function,
                            self.fun_info,
                            self.ir_module,
                            base,
                        ) =>
                        {
                            assert!(prev_decomposed_matrix_index.is_none());
                            prev_decomposed_matrix_index = Some(index);
                        }
                        // Other bases still go through the bounds-check path
                        // even though the index is a constant. Presumably this
                        // is because `base` may be runtime-sized; confirm
                        // against `write_bounds_check`.
                        _ => {
                            let index_id = self.write_access_chain_index(
                                base,
                                GuardedIndex::Known(index),
                                &mut accumulated_checks,
                                block,
                            )?;
                            self.temp_list.push(index_id);
                        }
                    }
                    base
                }
                crate::Expression::GlobalVariable(handle) => {
                    let gv = &self.writer.global_variables[handle];
                    break gv.access_id;
                }
                crate::Expression::LocalVariable(variable) => {
                    let local_var = &self.function.variables[&variable];
                    break local_var.id;
                }
                crate::Expression::FunctionArgument(index) => {
                    break self.function.parameter_id(index);
                }
                ref other => unimplemented!("Unexpected pointer expression {:?}", other),
            }
        };

        let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
            (
                root_id,
                ExpressionPointer::Ready {
                    pointer_id: root_id,
                },
            )
        } else {
            // Indices were collected from the outermost access inward, but
            // `OpAccessChain` lists them starting from the root.
            self.temp_list.reverse();
            let pointer_id = self.gen_id();
            let access =
                Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);

            // With runtime checks, hand the unemitted access to the caller so
            // it can guard it. Otherwise emit it now.
            let expr_pointer = match accumulated_checks {
                Some(condition) => ExpressionPointer::Conditional { condition, access },
                None => {
                    block.body.push(access);
                    ExpressionPointer::Ready { pointer_id }
                }
            };
            (pointer_id, expr_pointer)
        };
        // A pointer derived from a non-uniformly indexed binding array gets
        // the non-uniform decoration.
        if is_non_uniform_binding_array {
            self.writer
                .decorate_non_uniform_binding_array_access(pointer_id)?;
        }

        Ok(expr_pointer)
    }
2659
2660 fn is_nonuniform_binding_array_access(
2661 &mut self,
2662 base: Handle<crate::Expression>,
2663 index: Handle<crate::Expression>,
2664 ) -> bool {
2665 let crate::Expression::GlobalVariable(var_handle) = self.ir_function.expressions[base]
2666 else {
2667 return false;
2668 };
2669
2670 let gvar = &self.ir_module.global_variables[var_handle];
2673 let crate::TypeInner::BindingArray { .. } = self.ir_module.types[gvar.ty].inner else {
2674 return false;
2675 };
2676
2677 self.fun_info[index].uniformity.non_uniform_result.is_some()
2678 }
2679
2680 fn write_access_chain_index(
2690 &mut self,
2691 base: Handle<crate::Expression>,
2692 index: GuardedIndex,
2693 accumulated_checks: &mut Option<Word>,
2694 block: &mut Block,
2695 ) -> Result<Word, Error> {
2696 match self.write_bounds_check(base, index, block)? {
2697 BoundsCheckResult::KnownInBounds(known_index) => {
2698 let scalar = crate::Literal::U32(known_index);
2701 Ok(self.writer.get_constant_scalar(scalar))
2702 }
2703 BoundsCheckResult::Computed(computed_index_id) => Ok(computed_index_id),
2704 BoundsCheckResult::Conditional {
2705 condition_id: condition,
2706 index_id: index,
2707 } => {
2708 self.extend_bounds_check_condition_chain(accumulated_checks, condition, block);
2709
2710 Ok(index)
2712 }
2713 }
2714 }
2715
2716 fn extend_bounds_check_condition_chain(
2735 &mut self,
2736 chain: &mut Option<Word>,
2737 comparison_id: Word,
2738 block: &mut Block,
2739 ) {
2740 match *chain {
2741 Some(ref mut prior_checks) => {
2742 let combined = self.gen_id();
2743 block.body.push(Instruction::binary(
2744 spirv::Op::LogicalAnd,
2745 self.writer.get_bool_type_id(),
2746 combined,
2747 *prior_checks,
2748 comparison_id,
2749 ));
2750 *prior_checks = combined;
2751 }
2752 None => {
2753 *chain = Some(comparison_id);
2755 }
2756 }
2757 }
2758
    /// Load the value that `pointer` points to and return the loaded value's id.
    ///
    /// Two special cases for two-row matrices in uniform buffers are tried
    /// first: `maybe_write_uniform_matcx2_dynamic_access`, then
    /// `maybe_write_load_uniform_matcx2_struct_member`. If either handles the
    /// load, its result id is returned.
    ///
    /// Otherwise the pointer is built with `write_access_chain`:
    /// - If `pointer` is a `Uniform` pointer to a type with a std140-compatible
    ///   variant, that variant is loaded and converted back through the
    ///   `ConvertFromStd140CompatType` wrapped function. The caller's
    ///   `access_type_adjustment` is replaced in that case.
    /// - Pointers to `Atomic` types are read with `OpAtomicLoad`.
    /// - If the access chain depends on a runtime bounds-check condition, the
    ///   load goes through `write_conditional_indexed_load`.
    ///
    /// The final result has type `result_type_id`.
    fn write_checked_load(
        &mut self,
        pointer: Handle<crate::Expression>,
        block: &mut Block,
        access_type_adjustment: AccessTypeAdjustment,
        result_type_id: Word,
    ) -> Result<Word, Error> {
        if let Some(result_id) = self.maybe_write_uniform_matcx2_dynamic_access(pointer, block)? {
            Ok(result_id)
        } else if let Some(result_id) =
            self.maybe_write_load_uniform_matcx2_struct_member(pointer, block)?
        {
            Ok(result_id)
        } else {
            // Describes a load that must go through the std140-compatible
            // variant of a uniform type, followed by a conversion back.
            struct WrappedLoad {
                access_type_adjustment: AccessTypeAdjustment,
                r#type: Handle<crate::Type>,
            }
            let mut wrapped_load = None;
            if let crate::TypeInner::Pointer {
                base: pointer_base_type,
                space: crate::AddressSpace::Uniform,
            } = *self.fun_info[pointer].ty.inner_with(&self.ir_module.types)
            {
                if self
                    .writer
                    .std140_compat_uniform_types
                    .contains_key(&pointer_base_type)
                {
                    wrapped_load = Some(WrappedLoad {
                        access_type_adjustment: AccessTypeAdjustment::UseStd140CompatType,
                        r#type: pointer_base_type,
                    });
                };
            };

            // A wrapped load reads the std140 type through a pointer to it.
            // Otherwise, read `result_type_id` using the caller's adjustment.
            let (load_type_id, access_type_adjustment) = match wrapped_load {
                Some(ref wrapped_load) => (
                    self.writer.std140_compat_uniform_types[&wrapped_load.r#type].type_id,
                    wrapped_load.access_type_adjustment,
                ),
                None => (result_type_id, access_type_adjustment),
            };

            let load_id = match self.write_access_chain(pointer, block, access_type_adjustment)? {
                ExpressionPointer::Ready { pointer_id } => {
                    let id = self.gen_id();
                    // Pointers to `Atomic` types are read with `OpAtomicLoad`,
                    // using the scope and semantics of their address space.
                    let atomic_space =
                        match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
                            crate::TypeInner::Pointer { base, space } => {
                                match self.ir_module.types[base].inner {
                                    crate::TypeInner::Atomic { .. } => Some(space),
                                    _ => None,
                                }
                            }
                            _ => None,
                        };
                    let instruction = if let Some(space) = atomic_space {
                        let (semantics, scope) = space.to_spirv_semantics_and_scope();
                        let scope_constant_id = self.get_scope_constant(scope as u32);
                        let semantics_id = self.get_index_constant(semantics.bits());
                        Instruction::atomic_load(
                            result_type_id,
                            id,
                            pointer_id,
                            scope_constant_id,
                            semantics_id,
                        )
                    } else {
                        Instruction::load(load_type_id, id, pointer_id, None)
                    };
                    block.body.push(instruction);
                    id
                }
                ExpressionPointer::Conditional { condition, access } => {
                    // NOTE(review): this path always emits a plain `OpLoad`,
                    // even when the pointee is an `Atomic` type. Confirm
                    // whether atomics can reach here.
                    self.write_conditional_indexed_load(
                        load_type_id,
                        condition,
                        block,
                        move |id_gen, block| {
                            // Presumably the in-bounds branch: emit the deferred
                            // access chain, then load through it.
                            let pointer_id = access.result_id.unwrap();
                            let value_id = id_gen.next();
                            block.body.push(access);
                            block.body.push(Instruction::load(
                                load_type_id,
                                value_id,
                                pointer_id,
                                None,
                            ));
                            value_id
                        },
                    )
                }
            };

            match wrapped_load {
                Some(ref wrapped_load) => {
                    // Convert the loaded std140 value back to the declared type.
                    let result_id = self.gen_id();
                    let function_id = self.writer.wrapped_functions
                        [&WrappedFunction::ConvertFromStd140CompatType {
                            r#type: wrapped_load.r#type,
                        }];
                    block.body.push(Instruction::function_call(
                        result_type_id,
                        result_id,
                        function_id,
                        &[load_id],
                    ));
                    Ok(result_id)
                }
                None => Ok(load_id),
            }
        }
    }
2882
2883 fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) {
2884 use indexmap::map::Entry;
2885
2886 let spill_variable_id = match self.function.spilled_composites.entry(base) {
2888 Entry::Occupied(preexisting) => preexisting.get().id,
2889 Entry::Vacant(vacant) => {
2890 let pointer_type_id = self.writer.get_resolution_pointer_id(
2893 &self.fun_info[base].ty,
2894 spirv::StorageClass::Function,
2895 );
2896 let id = self.writer.id_gen.next();
2897 vacant.insert(super::LocalVariable {
2898 id,
2899 instruction: Instruction::variable(
2900 pointer_type_id,
2901 id,
2902 spirv::StorageClass::Function,
2903 None,
2904 ),
2905 });
2906 id
2907 }
2908 };
2909
2910 let base_id = self.cached[base];
2935 block
2936 .body
2937 .push(Instruction::store(spill_variable_id, base_id, None));
2938 }
2939
2940 fn maybe_access_spilled_composite(
2957 &mut self,
2958 access: Handle<crate::Expression>,
2959 block: &mut Block,
2960 result_type_id: Word,
2961 ) -> Result<Word, Error> {
2962 let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r);
2963 if access_uses == self.fun_info[access].ref_count {
2964 Ok(0)
2968 } else {
2969 self.write_checked_load(
2974 access,
2975 block,
2976 AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::Function),
2977 result_type_id,
2978 )
2979 }
2980 }
2981
2982 #[allow(clippy::too_many_arguments)]
2984 fn write_matrix_matrix_column_op(
2985 &mut self,
2986 block: &mut Block,
2987 result_id: Word,
2988 result_type_id: Word,
2989 left_id: Word,
2990 right_id: Word,
2991 columns: crate::VectorSize,
2992 rows: crate::VectorSize,
2993 width: u8,
2994 op: spirv::Op,
2995 ) {
2996 self.temp_list.clear();
2997
2998 let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2999 size: rows,
3000 scalar: crate::Scalar::float(width),
3001 });
3002
3003 for index in 0..columns as u32 {
3004 let column_id_left = self.gen_id();
3005 let column_id_right = self.gen_id();
3006 let column_id_res = self.gen_id();
3007
3008 block.body.push(Instruction::composite_extract(
3009 vector_type_id,
3010 column_id_left,
3011 left_id,
3012 &[index],
3013 ));
3014 block.body.push(Instruction::composite_extract(
3015 vector_type_id,
3016 column_id_right,
3017 right_id,
3018 &[index],
3019 ));
3020 block.body.push(Instruction::binary(
3021 op,
3022 vector_type_id,
3023 column_id_res,
3024 column_id_left,
3025 column_id_right,
3026 ));
3027
3028 self.temp_list.push(column_id_res);
3029 }
3030
3031 block.body.push(Instruction::composite_construct(
3032 result_type_id,
3033 result_id,
3034 &self.temp_list,
3035 ));
3036 }
3037
3038 fn write_vector_scalar_mult(
3040 &mut self,
3041 block: &mut Block,
3042 result_id: Word,
3043 result_type_id: Word,
3044 vector_id: Word,
3045 scalar_id: Word,
3046 vector: &crate::TypeInner,
3047 ) {
3048 let (size, kind) = match *vector {
3049 crate::TypeInner::Vector {
3050 size,
3051 scalar: crate::Scalar { kind, .. },
3052 } => (size, kind),
3053 _ => unreachable!(),
3054 };
3055
3056 let (op, operand_id) = match kind {
3057 crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
3058 _ => {
3059 let operand_id = self.gen_id();
3060 self.temp_list.clear();
3061 self.temp_list.resize(size as usize, scalar_id);
3062 block.body.push(Instruction::composite_construct(
3063 result_type_id,
3064 operand_id,
3065 &self.temp_list,
3066 ));
3067 (spirv::Op::IMul, operand_id)
3068 }
3069 };
3070
3071 block.body.push(Instruction::binary(
3072 op,
3073 result_type_id,
3074 result_id,
3075 vector_id,
3076 operand_id,
3077 ));
3078 }
3079
3080 #[expect(clippy::too_many_arguments)]
3087 fn write_dot_product(
3088 &mut self,
3089 result_id: Word,
3090 result_type_id: Word,
3091 arg0_id: Word,
3092 arg1_id: Word,
3093 size: u32,
3094 block: &mut Block,
3095 extractor: impl Fn(Word, Word, Word) -> Instruction,
3096 ) {
3097 let mut partial_sum = self.writer.get_constant_null(result_type_id);
3098 let last_component = size - 1;
3099 for index in 0..=last_component {
3100 let a_id = self.gen_id();
3102 block.body.push(extractor(a_id, arg0_id, index));
3103 let b_id = self.gen_id();
3104 block.body.push(extractor(b_id, arg1_id, index));
3105 let prod_id = self.gen_id();
3106 block.body.push(Instruction::binary(
3107 spirv::Op::IMul,
3108 result_type_id,
3109 prod_id,
3110 a_id,
3111 b_id,
3112 ));
3113
3114 let id = if index == last_component {
3116 result_id
3117 } else {
3118 self.gen_id()
3119 };
3120
3121 block.body.push(Instruction::binary(
3123 spirv::Op::IAdd,
3124 result_type_id,
3125 id,
3126 partial_sum,
3127 prod_id,
3128 ));
3129 partial_sum = id;
3131 }
3132 }
3133
3134 fn write_pack4x8_optimized(
3136 &mut self,
3137 block: &mut Block,
3138 result_type_id: u32,
3139 arg0_id: u32,
3140 id: u32,
3141 is_signed: bool,
3142 should_clamp: bool,
3143 ) -> Instruction {
3144 let int_type = if is_signed {
3145 crate::ScalarKind::Sint
3146 } else {
3147 crate::ScalarKind::Uint
3148 };
3149 let wide_vector_type = NumericType::Vector {
3150 size: crate::VectorSize::Quad,
3151 scalar: crate::Scalar {
3152 kind: int_type,
3153 width: 4,
3154 },
3155 };
3156 let wide_vector_type_id = self.get_numeric_type_id(wide_vector_type);
3157 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3158 size: crate::VectorSize::Quad,
3159 scalar: crate::Scalar {
3160 kind: crate::ScalarKind::Uint,
3161 width: 1,
3162 },
3163 });
3164
3165 let mut wide_vector = arg0_id;
3166 if should_clamp {
3167 let (min, max, clamp_op) = if is_signed {
3168 (
3169 crate::Literal::I32(-128),
3170 crate::Literal::I32(127),
3171 spirv::GLOp::SClamp,
3172 )
3173 } else {
3174 (
3175 crate::Literal::U32(0),
3176 crate::Literal::U32(255),
3177 spirv::GLOp::UClamp,
3178 )
3179 };
3180 let [min, max] = [min, max].map(|lit| {
3181 let scalar = self.writer.get_constant_scalar(lit);
3182 self.writer.get_constant_composite(
3183 LookupType::Local(LocalType::Numeric(wide_vector_type)),
3184 &[scalar; 4],
3185 )
3186 });
3187
3188 let clamp_id = self.gen_id();
3189 block.body.push(Instruction::ext_inst_gl_op(
3190 self.writer.gl450_ext_inst_id,
3191 clamp_op,
3192 wide_vector_type_id,
3193 clamp_id,
3194 &[wide_vector, min, max],
3195 ));
3196
3197 wide_vector = clamp_id;
3198 }
3199
3200 let packed_vector = self.gen_id();
3201 block.body.push(Instruction::unary(
3202 spirv::Op::UConvert, packed_vector_type_id,
3204 packed_vector,
3205 wide_vector,
3206 ));
3207
3208 Instruction::unary(spirv::Op::Bitcast, result_type_id, id, packed_vector)
3213 }
3214
    /// Emit `pack4x{I,U}8[Clamp]` without 8-bit integer types.
    ///
    /// Each of the four components of `arg0_id` is processed in turn:
    /// 1. Extract it as `i32` if `is_signed`, otherwise as `u32`.
    /// 2. If signed, bitcast it to `u32`.
    /// 3. If `should_clamp`, clamp it to the 8-bit range.
    /// 4. Insert it into bits `8*i .. 8*i + 8` of an accumulator that starts at
    ///    the constant `0u32`, using `OpBitFieldInsert` with a count of 8.
    ///
    /// Every instruction except the final `OpBitFieldInsert` is pushed onto
    /// `block`. That final instruction, which produces `id` of type
    /// `result_type_id`, is returned for the caller to emit.
    fn write_pack4x8_polyfill(
        &mut self,
        block: &mut Block,
        result_type_id: u32,
        arg0_id: u32,
        id: u32,
        is_signed: bool,
        should_clamp: bool,
    ) -> Instruction {
        let int_type = if is_signed {
            crate::ScalarKind::Sint
        } else {
            crate::ScalarKind::Uint
        };
        let uint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
        let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
            kind: int_type,
            width: 4,
        }));

        // Placeholder; always overwritten on the last loop iteration.
        let mut last_instruction = Instruction::new(spirv::Op::Nop);

        // Accumulator for the packed bits, starting from zero.
        let zero = self.writer.get_constant_scalar(crate::Literal::U32(0));
        let mut preresult = zero;
        // Capacity hint only. It does not count the clamp instructions.
        // `VEC_LENGTH` is declared below; item declarations are visible
        // throughout the enclosing block.
        block
            .body
            .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed)));

        // Bit-field width of each packed component.
        let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
        const VEC_LENGTH: u8 = 4;
        for i in 0..u32::from(VEC_LENGTH) {
            // Bit offset of component `i` within the packed word.
            let offset = self.writer.get_constant_scalar(crate::Literal::U32(i * 8));
            let mut extracted = self.gen_id();
            block.body.push(Instruction::binary(
                spirv::Op::CompositeExtract,
                int_type_id,
                extracted,
                arg0_id,
                i,
            ));
            if is_signed {
                let casted = self.gen_id();
                block.body.push(Instruction::unary(
                    spirv::Op::Bitcast,
                    uint_type_id,
                    casted,
                    extracted,
                ));
                extracted = casted;
            }
            if should_clamp {
                let (min, max, clamp_op) = if is_signed {
                    (
                        crate::Literal::I32(-128),
                        crate::Literal::I32(127),
                        spirv::GLOp::SClamp,
                    )
                } else {
                    (
                        crate::Literal::U32(0),
                        crate::Literal::U32(255),
                        spirv::GLOp::UClamp,
                    )
                };
                let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit));

                // NOTE(review): in the signed case this clamps the bitcast
                // `u32` value with `i32` bounds and a `result_type_id` result
                // type, relying on `SClamp`'s signed interpretation. Confirm
                // that validators accept the mixed-signedness operands.
                let clamp_id = self.gen_id();
                block.body.push(Instruction::ext_inst_gl_op(
                    self.writer.gl450_ext_inst_id,
                    clamp_op,
                    result_type_id,
                    clamp_id,
                    &[extracted, min, max],
                ));

                extracted = clamp_id;
            }
            let is_last = i == u32::from(VEC_LENGTH - 1);
            if is_last {
                // The last insertion produces the caller's `id` and is returned
                // rather than pushed.
                last_instruction = Instruction::quaternary(
                    spirv::Op::BitFieldInsert,
                    result_type_id,
                    id,
                    preresult,
                    extracted,
                    offset,
                    eight,
                )
            } else {
                let new_preresult = self.gen_id();
                block.body.push(Instruction::quaternary(
                    spirv::Op::BitFieldInsert,
                    result_type_id,
                    new_preresult,
                    preresult,
                    extracted,
                    offset,
                    eight,
                ));
                preresult = new_preresult;
            }
        }
        last_instruction
    }
3320
3321 fn write_unpack4x8_optimized(
3323 &mut self,
3324 block: &mut Block,
3325 result_type_id: u32,
3326 arg0_id: u32,
3327 id: u32,
3328 is_signed: bool,
3329 ) -> Instruction {
3330 let (int_type, convert_op) = if is_signed {
3331 (crate::ScalarKind::Sint, spirv::Op::SConvert)
3332 } else {
3333 (crate::ScalarKind::Uint, spirv::Op::UConvert)
3334 };
3335
3336 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
3337 size: crate::VectorSize::Quad,
3338 scalar: crate::Scalar {
3339 kind: int_type,
3340 width: 1,
3341 },
3342 });
3343
3344 let packed_vector = self.gen_id();
3349 block.body.push(Instruction::unary(
3350 spirv::Op::Bitcast,
3351 packed_vector_type_id,
3352 packed_vector,
3353 arg0_id,
3354 ));
3355
3356 Instruction::unary(convert_op, result_type_id, id, packed_vector)
3357 }
3358
3359 fn write_unpack4x8_polyfill(
3361 &mut self,
3362 block: &mut Block,
3363 result_type_id: u32,
3364 arg0_id: u32,
3365 id: u32,
3366 is_signed: bool,
3367 ) -> Instruction {
3368 let (int_type, extract_op) = if is_signed {
3369 (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract)
3370 } else {
3371 (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract)
3372 };
3373
3374 let sint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));
3375
3376 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
3377 let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
3378 kind: int_type,
3379 width: 4,
3380 }));
3381 block
3382 .body
3383 .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed));
3384 let arg_id = if is_signed {
3385 let new_arg_id = self.gen_id();
3386 block.body.push(Instruction::unary(
3387 spirv::Op::Bitcast,
3388 sint_type_id,
3389 new_arg_id,
3390 arg0_id,
3391 ));
3392 new_arg_id
3393 } else {
3394 arg0_id
3395 };
3396
3397 const VEC_LENGTH: u8 = 4;
3398 let parts: [_; VEC_LENGTH as usize] = core::array::from_fn(|_| self.gen_id());
3399 for (i, part_id) in parts.into_iter().enumerate() {
3400 let index = self
3401 .writer
3402 .get_constant_scalar(crate::Literal::U32(i as u32 * 8));
3403 block.body.push(Instruction::ternary(
3404 extract_op,
3405 int_type_id,
3406 part_id,
3407 arg_id,
3408 index,
3409 eight,
3410 ));
3411 }
3412
3413 Instruction::composite_construct(result_type_id, id, &parts)
3414 }
3415
3416 fn write_block(
3433 &mut self,
3434 label_id: Word,
3435 naga_block: &crate::Block,
3436 exit: BlockExit,
3437 loop_context: LoopContext,
3438 debug_info: Option<&DebugInfoInner>,
3439 ) -> Result<BlockExitDisposition, Error> {
3440 let mut block = Block::new(label_id);
3441 for (statement, span) in naga_block.span_iter() {
3442 if let (Some(debug_info), false) = (
3443 debug_info,
3444 matches!(
3445 statement,
3446 &(Statement::Block(..)
3447 | Statement::Break
3448 | Statement::Continue
3449 | Statement::Kill
3450 | Statement::Return { .. }
3451 | Statement::Loop { .. })
3452 ),
3453 ) {
3454 let loc: crate::SourceLocation = span.location(debug_info.source_code);
3455 block.body.push(Instruction::line(
3456 debug_info.source_file_id,
3457 loc.line_number,
3458 loc.line_position,
3459 ));
3460 };
3461 match *statement {
3462 Statement::Emit(ref range) => {
3463 for handle in range.clone() {
3464 if !self.expression_constness.is_const(handle) {
3466 self.cache_expression_value(handle, &mut block)?;
3467 }
3468 }
3469 }
3470 Statement::Block(ref block_statements) => {
3471 let scope_id = self.gen_id();
3472 self.function.consume(block, Instruction::branch(scope_id));
3473
3474 let merge_id = self.gen_id();
3475 let merge_used = self.write_block(
3476 scope_id,
3477 block_statements,
3478 BlockExit::Branch { target: merge_id },
3479 loop_context,
3480 debug_info,
3481 )?;
3482
3483 match merge_used {
3484 BlockExitDisposition::Used => {
3485 block = Block::new(merge_id);
3486 }
3487 BlockExitDisposition::Discarded => {
3488 return Ok(BlockExitDisposition::Discarded);
3489 }
3490 }
3491 }
3492 Statement::If {
3493 condition,
3494 ref accept,
3495 ref reject,
3496 } => {
3497 if !(accept.is_empty() && reject.is_empty()) {
3503 let condition_id = self.cached[condition];
3504
3505 let merge_id = self.gen_id();
3506 block.body.push(Instruction::selection_merge(
3507 merge_id,
3508 spirv::SelectionControl::NONE,
3509 ));
3510
3511 let accept_id = if accept.is_empty() {
3512 None
3513 } else {
3514 Some(self.gen_id())
3515 };
3516 let reject_id = if reject.is_empty() {
3517 None
3518 } else {
3519 Some(self.gen_id())
3520 };
3521
3522 self.function.consume(
3523 block,
3524 Instruction::branch_conditional(
3525 condition_id,
3526 accept_id.unwrap_or(merge_id),
3527 reject_id.unwrap_or(merge_id),
3528 ),
3529 );
3530
3531 if let Some(block_id) = accept_id {
3532 let _ = self.write_block(
3537 block_id,
3538 accept,
3539 BlockExit::Branch { target: merge_id },
3540 loop_context,
3541 debug_info,
3542 )?;
3543 }
3544 if let Some(block_id) = reject_id {
3545 let _ = self.write_block(
3550 block_id,
3551 reject,
3552 BlockExit::Branch { target: merge_id },
3553 loop_context,
3554 debug_info,
3555 )?;
3556 }
3557
3558 block = Block::new(merge_id);
3559 }
3560 }
3561 Statement::Switch {
3562 selector,
3563 ref cases,
3564 } => {
3565 let selector_id = self.cached[selector];
3566
3567 let merge_id = self.gen_id();
3568 block.body.push(Instruction::selection_merge(
3569 merge_id,
3570 spirv::SelectionControl::NONE,
3571 ));
3572
3573 let mut default_id = None;
3574 let mut last_id = None;
3576
3577 let mut raw_cases = Vec::with_capacity(cases.len());
3578 let mut case_ids = Vec::with_capacity(cases.len());
3579 for case in cases.iter() {
3580 let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
3582
3583 if case.fall_through && case.body.is_empty() {
3584 last_id = Some(label_id);
3585 }
3586
3587 case_ids.push(label_id);
3588
3589 match case.value {
3590 crate::SwitchValue::I32(value) => {
3591 raw_cases.push(super::instructions::Case {
3592 value: value as Word,
3593 label_id,
3594 });
3595 }
3596 crate::SwitchValue::U32(value) => {
3597 raw_cases.push(super::instructions::Case { value, label_id });
3598 }
3599 crate::SwitchValue::Default => {
3600 default_id = Some(label_id);
3601 }
3602 }
3603 }
3604
3605 let default_id = default_id.unwrap();
3606
3607 self.function.consume(
3608 block,
3609 Instruction::switch(selector_id, default_id, &raw_cases),
3610 );
3611
3612 let inner_context = LoopContext {
3613 break_id: Some(merge_id),
3614 ..loop_context
3615 };
3616
3617 for (i, (case, label_id)) in cases
3618 .iter()
3619 .zip(case_ids.iter())
3620 .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
3621 .enumerate()
3622 {
3623 let case_finish_id = if case.fall_through {
3624 case_ids[i + 1]
3625 } else {
3626 merge_id
3627 };
3628 let _ = self.write_block(
3637 *label_id,
3638 &case.body,
3639 BlockExit::Branch {
3640 target: case_finish_id,
3641 },
3642 inner_context,
3643 debug_info,
3644 )?;
3645 }
3646
3647 block = Block::new(merge_id);
3648 }
3649 Statement::Loop {
3650 ref body,
3651 ref continuing,
3652 break_if,
3653 } => {
3654 let preamble_id = self.gen_id();
3655 self.function
3656 .consume(block, Instruction::branch(preamble_id));
3657
3658 let merge_id = self.gen_id();
3659 let body_id = self.gen_id();
3660 let continuing_id = self.gen_id();
3661
3662 block = Block::new(preamble_id);
3665 if let Some(debug_info) = debug_info {
3668 let loc: crate::SourceLocation = span.location(debug_info.source_code);
3669 block.body.push(Instruction::line(
3670 debug_info.source_file_id,
3671 loc.line_number,
3672 loc.line_position,
3673 ))
3674 }
3675 block.body.push(Instruction::loop_merge(
3676 merge_id,
3677 continuing_id,
3678 spirv::SelectionControl::NONE,
3679 ));
3680
3681 if self.force_loop_bounding {
3682 block = self.write_force_bounded_loop_instructions(block, merge_id);
3683 }
3684 self.function.consume(block, Instruction::branch(body_id));
3685
3686 let _ = self.write_block(
3690 body_id,
3691 body,
3692 BlockExit::Branch {
3693 target: continuing_id,
3694 },
3695 LoopContext {
3696 continuing_id: Some(continuing_id),
3697 break_id: Some(merge_id),
3698 },
3699 debug_info,
3700 )?;
3701
3702 let exit = match break_if {
3703 Some(condition) => BlockExit::BreakIf {
3704 condition,
3705 preamble_id,
3706 },
3707 None => BlockExit::Branch {
3708 target: preamble_id,
3709 },
3710 };
3711
3712 let _ = self.write_block(
3716 continuing_id,
3717 continuing,
3718 exit,
3719 LoopContext {
3720 continuing_id: None,
3721 break_id: Some(merge_id),
3722 },
3723 debug_info,
3724 )?;
3725
3726 block = Block::new(merge_id);
3727 }
3728 Statement::Break => {
3729 self.function
3730 .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
3731 return Ok(BlockExitDisposition::Discarded);
3732 }
3733 Statement::Continue => {
3734 self.function.consume(
3735 block,
3736 Instruction::branch(loop_context.continuing_id.unwrap()),
3737 );
3738 return Ok(BlockExitDisposition::Discarded);
3739 }
3740 Statement::Return { value: Some(value) } => {
3741 let value_id = self.cached[value];
3742 let instruction = match self.function.entry_point_context {
3743 Some(ref context) => self.writer.write_entry_point_return(
3746 value_id,
3747 self.ir_function.result.as_ref().unwrap(),
3748 &context.results,
3749 &mut block.body,
3750 )?,
3751 None => Instruction::return_value(value_id),
3752 };
3753 self.function.consume(block, instruction);
3754 return Ok(BlockExitDisposition::Discarded);
3755 }
3756 Statement::Return { value: None } => {
3757 self.function.consume(block, Instruction::return_void());
3758 return Ok(BlockExitDisposition::Discarded);
3759 }
3760 Statement::Kill => {
3761 self.function.consume(block, Instruction::kill());
3762 return Ok(BlockExitDisposition::Discarded);
3763 }
3764 Statement::ControlBarrier(flags) => {
3765 self.writer.write_control_barrier(flags, &mut block.body);
3766 }
3767 Statement::MemoryBarrier(flags) => {
3768 self.writer.write_memory_barrier(flags, &mut block);
3769 }
3770 Statement::Store { pointer, value } => {
3771 let value_id = self.cached[value];
3772 match self.write_access_chain(
3773 pointer,
3774 &mut block,
3775 AccessTypeAdjustment::None,
3776 )? {
3777 ExpressionPointer::Ready { pointer_id } => {
3778 let atomic_space = match *self.fun_info[pointer]
3779 .ty
3780 .inner_with(&self.ir_module.types)
3781 {
3782 crate::TypeInner::Pointer { base, space } => {
3783 match self.ir_module.types[base].inner {
3784 crate::TypeInner::Atomic { .. } => Some(space),
3785 _ => None,
3786 }
3787 }
3788 _ => None,
3789 };
3790 let instruction = if let Some(space) = atomic_space {
3791 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3792 let scope_constant_id = self.get_scope_constant(scope as u32);
3793 let semantics_id = self.get_index_constant(semantics.bits());
3794 Instruction::atomic_store(
3795 pointer_id,
3796 scope_constant_id,
3797 semantics_id,
3798 value_id,
3799 )
3800 } else {
3801 Instruction::store(pointer_id, value_id, None)
3802 };
3803 block.body.push(instruction);
3804 }
3805 ExpressionPointer::Conditional { condition, access } => {
3806 let mut selection = Selection::start(&mut block, ());
3807 selection.if_true(self, condition, ());
3808
3809 let pointer_id = access.result_id.unwrap();
3811 selection.block().body.push(access);
3812 selection
3813 .block()
3814 .body
3815 .push(Instruction::store(pointer_id, value_id, None));
3816
3817 selection.finish(self, ());
3820 }
3821 };
3822 }
3823 Statement::ImageStore {
3824 image,
3825 coordinate,
3826 array_index,
3827 value,
3828 } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
3829 Statement::Call {
3830 function: local_function,
3831 ref arguments,
3832 result,
3833 } => {
3834 let id = self.gen_id();
3835 self.temp_list.clear();
3836 for &argument in arguments {
3837 self.temp_list.push(self.cached[argument]);
3838 }
3839
3840 let type_id = match result {
3841 Some(expr) => {
3842 self.cached[expr] = id;
3843 self.get_expression_type_id(&self.fun_info[expr].ty)
3844 }
3845 None => self.writer.void_type,
3846 };
3847
3848 block.body.push(Instruction::function_call(
3849 type_id,
3850 id,
3851 self.writer.lookup_function[&local_function],
3852 &self.temp_list,
3853 ));
3854 }
3855 Statement::Atomic {
3856 pointer,
3857 ref fun,
3858 value,
3859 result,
3860 } => {
3861 let id = self.gen_id();
3862 let result_type_id =
3866 self.get_expression_type_id(&self.fun_info[result.unwrap_or(value)].ty);
3867
3868 if let Some(result) = result {
3869 self.cached[result] = id;
3870 }
3871
3872 let pointer_id = match self.write_access_chain(
3873 pointer,
3874 &mut block,
3875 AccessTypeAdjustment::None,
3876 )? {
3877 ExpressionPointer::Ready { pointer_id } => pointer_id,
3878 ExpressionPointer::Conditional { .. } => {
3879 return Err(Error::FeatureNotImplemented(
3880 "Atomics out-of-bounds handling",
3881 ));
3882 }
3883 };
3884
3885 let space = self.fun_info[pointer]
3886 .ty
3887 .inner_with(&self.ir_module.types)
3888 .pointer_space()
3889 .unwrap();
3890 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3891 let scope_constant_id = self.get_scope_constant(scope as u32);
3892 let semantics_id = self.get_index_constant(semantics.bits());
3893 let value_id = self.cached[value];
3894 let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
3895
3896 let crate::TypeInner::Scalar(scalar) = *value_inner else {
3897 return Err(Error::FeatureNotImplemented(
3898 "Atomics with non-scalar values",
3899 ));
3900 };
3901
3902 let instruction = match *fun {
3903 crate::AtomicFunction::Add => {
3904 let spirv_op = match scalar.kind {
3905 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3906 spirv::Op::AtomicIAdd
3907 }
3908 crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT,
3909 _ => unimplemented!(),
3910 };
3911 Instruction::atomic_binary(
3912 spirv_op,
3913 result_type_id,
3914 id,
3915 pointer_id,
3916 scope_constant_id,
3917 semantics_id,
3918 value_id,
3919 )
3920 }
3921 crate::AtomicFunction::Subtract => {
3922 let (spirv_op, value_id) = match scalar.kind {
3923 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3924 (spirv::Op::AtomicISub, value_id)
3925 }
3926 crate::ScalarKind::Float => {
3927 let neg_result_id = self.gen_id();
3930 block.body.push(Instruction::unary(
3931 spirv::Op::FNegate,
3932 result_type_id,
3933 neg_result_id,
3934 value_id,
3935 ));
3936 (spirv::Op::AtomicFAddEXT, neg_result_id)
3937 }
3938 _ => unimplemented!(),
3939 };
3940 Instruction::atomic_binary(
3941 spirv_op,
3942 result_type_id,
3943 id,
3944 pointer_id,
3945 scope_constant_id,
3946 semantics_id,
3947 value_id,
3948 )
3949 }
3950 crate::AtomicFunction::And => {
3951 let spirv_op = match scalar.kind {
3952 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3953 spirv::Op::AtomicAnd
3954 }
3955 _ => unimplemented!(),
3956 };
3957 Instruction::atomic_binary(
3958 spirv_op,
3959 result_type_id,
3960 id,
3961 pointer_id,
3962 scope_constant_id,
3963 semantics_id,
3964 value_id,
3965 )
3966 }
3967 crate::AtomicFunction::InclusiveOr => {
3968 let spirv_op = match scalar.kind {
3969 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3970 spirv::Op::AtomicOr
3971 }
3972 _ => unimplemented!(),
3973 };
3974 Instruction::atomic_binary(
3975 spirv_op,
3976 result_type_id,
3977 id,
3978 pointer_id,
3979 scope_constant_id,
3980 semantics_id,
3981 value_id,
3982 )
3983 }
3984 crate::AtomicFunction::ExclusiveOr => {
3985 let spirv_op = match scalar.kind {
3986 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3987 spirv::Op::AtomicXor
3988 }
3989 _ => unimplemented!(),
3990 };
3991 Instruction::atomic_binary(
3992 spirv_op,
3993 result_type_id,
3994 id,
3995 pointer_id,
3996 scope_constant_id,
3997 semantics_id,
3998 value_id,
3999 )
4000 }
4001 crate::AtomicFunction::Min => {
4002 let spirv_op = match scalar.kind {
4003 crate::ScalarKind::Sint => spirv::Op::AtomicSMin,
4004 crate::ScalarKind::Uint => spirv::Op::AtomicUMin,
4005 _ => unimplemented!(),
4006 };
4007 Instruction::atomic_binary(
4008 spirv_op,
4009 result_type_id,
4010 id,
4011 pointer_id,
4012 scope_constant_id,
4013 semantics_id,
4014 value_id,
4015 )
4016 }
4017 crate::AtomicFunction::Max => {
4018 let spirv_op = match scalar.kind {
4019 crate::ScalarKind::Sint => spirv::Op::AtomicSMax,
4020 crate::ScalarKind::Uint => spirv::Op::AtomicUMax,
4021 _ => unimplemented!(),
4022 };
4023 Instruction::atomic_binary(
4024 spirv_op,
4025 result_type_id,
4026 id,
4027 pointer_id,
4028 scope_constant_id,
4029 semantics_id,
4030 value_id,
4031 )
4032 }
4033 crate::AtomicFunction::Exchange { compare: None } => {
4034 Instruction::atomic_binary(
4035 spirv::Op::AtomicExchange,
4036 result_type_id,
4037 id,
4038 pointer_id,
4039 scope_constant_id,
4040 semantics_id,
4041 value_id,
4042 )
4043 }
4044 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
4045 let scalar_type_id =
4046 self.get_numeric_type_id(NumericType::Scalar(scalar));
4047 let bool_type_id =
4048 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL));
4049
4050 let cas_result_id = self.gen_id();
4051 let equality_result_id = self.gen_id();
4052 let equality_operator = match scalar.kind {
4053 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
4054 spirv::Op::IEqual
4055 }
4056 _ => unimplemented!(),
4057 };
4058
4059 let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
4060 cas_instr.set_type(scalar_type_id);
4061 cas_instr.set_result(cas_result_id);
4062 cas_instr.add_operand(pointer_id);
4063 cas_instr.add_operand(scope_constant_id);
4064 cas_instr.add_operand(semantics_id); cas_instr.add_operand(semantics_id); cas_instr.add_operand(value_id);
4067 cas_instr.add_operand(self.cached[cmp]);
4068 block.body.push(cas_instr);
4069 block.body.push(Instruction::binary(
4070 equality_operator,
4071 bool_type_id,
4072 equality_result_id,
4073 cas_result_id,
4074 self.cached[cmp],
4075 ));
4076 Instruction::composite_construct(
4077 result_type_id,
4078 id,
4079 &[cas_result_id, equality_result_id],
4080 )
4081 }
4082 };
4083
4084 block.body.push(instruction);
4085 }
4086 Statement::ImageAtomic {
4087 image,
4088 coordinate,
4089 array_index,
4090 fun,
4091 value,
4092 } => {
4093 self.write_image_atomic(
4094 image,
4095 coordinate,
4096 array_index,
4097 fun,
4098 value,
4099 &mut block,
4100 )?;
4101 }
4102 Statement::WorkGroupUniformLoad { pointer, result } => {
4103 self.writer
4104 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4105 let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
4106 let id = self.write_checked_load(
4109 pointer,
4110 &mut block,
4111 AccessTypeAdjustment::None,
4112 result_type_id,
4113 )?;
4114 self.cached[result] = id;
4115 self.writer
4116 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
4117 }
4118 Statement::RayQuery { query, ref fun } => {
4119 self.write_ray_query_function(query, fun, &mut block);
4120 }
4121 Statement::SubgroupBallot {
4122 result,
4123 ref predicate,
4124 } => {
4125 self.write_subgroup_ballot(predicate, result, &mut block)?;
4126 }
4127 Statement::SubgroupCollectiveOperation {
4128 ref op,
4129 ref collective_op,
4130 argument,
4131 result,
4132 } => {
4133 self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
4134 }
4135 Statement::SubgroupGather {
4136 ref mode,
4137 argument,
4138 result,
4139 } => {
4140 self.write_subgroup_gather(mode, argument, result, &mut block)?;
4141 }
4142 Statement::CooperativeStore { target, ref data } => {
4143 let target_id = self.cached[target];
4144 let layout = if data.row_major {
4145 spirv::CooperativeMatrixLayout::RowMajorKHR
4146 } else {
4147 spirv::CooperativeMatrixLayout::ColumnMajorKHR
4148 };
4149 let layout_id = self.get_index_constant(layout as u32);
4150 let stride_id = self.cached[data.stride];
4151 match self.write_access_chain(
4152 data.pointer,
4153 &mut block,
4154 AccessTypeAdjustment::None,
4155 )? {
4156 ExpressionPointer::Ready { pointer_id } => {
4157 block.body.push(Instruction::coop_store(
4158 target_id, pointer_id, layout_id, stride_id,
4159 ));
4160 }
4161 ExpressionPointer::Conditional { condition, access } => {
4162 let mut selection = Selection::start(&mut block, ());
4163 selection.if_true(self, condition, ());
4164
4165 let pointer_id = access.result_id.unwrap();
4167 selection.block().body.push(access);
4168 selection.block().body.push(Instruction::coop_store(
4169 target_id, pointer_id, layout_id, stride_id,
4170 ));
4171
4172 selection.finish(self, ());
4175 }
4176 };
4177 }
4178 Statement::RayPipelineFunction(_) => unreachable!(),
4179 }
4180 }
4181
4182 let termination = match exit {
4183 BlockExit::Return => match self.ir_function.result {
4186 Some(ref result) if self.function.entry_point_context.is_none() => {
4187 let type_id = self.get_handle_type_id(result.ty);
4188 let null_id = self.writer.get_constant_null(type_id);
4189 Instruction::return_value(null_id)
4190 }
4191 _ => Instruction::return_void(),
4192 },
4193 BlockExit::Branch { target } => Instruction::branch(target),
4194 BlockExit::BreakIf {
4195 condition,
4196 preamble_id,
4197 } => {
4198 let condition_id = self.cached[condition];
4199
4200 Instruction::branch_conditional(
4201 condition_id,
4202 loop_context.break_id.unwrap(),
4203 preamble_id,
4204 )
4205 }
4206 };
4207
4208 self.function.consume(block, termination);
4209 Ok(BlockExitDisposition::Used)
4210 }
4211
4212 pub(super) fn write_function_body(
4213 &mut self,
4214 entry_id: Word,
4215 debug_info: Option<&DebugInfoInner>,
4216 ) -> Result<(), Error> {
4217 let _ = self.write_block(
4220 entry_id,
4221 &self.ir_function.body,
4222 BlockExit::Return,
4223 LoopContext::default(),
4224 debug_info,
4225 )?;
4226
4227 Ok(())
4228 }
4229}