1use alloc::vec::Vec;
6
7use arrayvec::ArrayVec;
8use spirv::Word;
9
10use super::{
11 index::BoundsCheckResult, selection::Selection, Block, BlockContext, Dimension, Error,
12 Instruction, LocalType, LookupType, NumericType, ResultMember, WrappedFunction, Writer,
13 WriterFlags,
14};
15use crate::{arena::Handle, proc::index::GuardedIndex, Statement};
16
17fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
18 match *type_inner {
19 crate::TypeInner::Scalar(_) => Dimension::Scalar,
20 crate::TypeInner::Vector { .. } => Dimension::Vector,
21 crate::TypeInner::Matrix { .. } => Dimension::Matrix,
22 _ => unreachable!(),
23 }
24}
25
/// How the result type of an access chain should be adjusted when it is emitted.
///
/// NOTE(review): the variant meanings below are inferred from the single use
/// visible in this chunk (binding-array `Access`/`AccessIndex`); confirm against
/// `write_access_chain`.
enum AccessTypeAdjustment {
    /// Leave the access-chain result type unchanged.
    None,

    /// Introduce a pointer with the given storage class around the accessed type.
    /// Binding-array accesses in this file pass `StorageClass::UniformConstant`.
    IntroducePointer(spirv::StorageClass),
}
73
/// The outcome of emitting an access chain for a pointer expression.
enum ExpressionPointer {
    /// The pointer is available right away as `pointer_id`; it can be used
    /// directly as the operand of a load.
    Ready { pointer_id: Word },

    /// The access may only be performed when `condition` holds.
    ///
    /// NOTE(review): this appears to be the bounds-checked path. Binding-array
    /// accesses in this file reject it with
    /// "Texture array out-of-bounds handling"; presumably `access` is the
    /// not-yet-emitted access instruction to guard with `condition` — confirm.
    Conditional {
        condition: Word,
        access: Instruction,
    },
}
92
/// How control should leave a block once its statements have been written.
///
/// NOTE(review): variant meanings are inferred from names and field types;
/// confirm against the code that consumes `BlockExit`.
enum BlockExit {
    /// Leave via a function return.
    Return,
    /// Branch unconditionally to the block labelled `target` (a SPIR-V id).
    Branch {
        target: Word,
    },
    /// Conditional exit driven by an IR expression.
    ///
    /// `condition` is a handle to an IR expression, not a SPIR-V id.
    /// Presumably this models a loop's `break if` — TODO confirm what
    /// `preamble_id` labels.
    BreakIf {
        condition: Handle<crate::Expression>,
        preamble_id: Word,
    },
}
114
/// Reports what happened to a requested [`BlockExit`].
///
/// The type is `#[must_use]`, so callers must consciously handle or ignore it.
/// NOTE(review): presumably `Discarded` means the exit was not emitted (e.g. the
/// block already ended in a terminator) — confirm against the producer.
#[must_use]
enum BlockExitDisposition {
    /// The requested exit was emitted.
    Used,

    /// The requested exit was not emitted.
    Discarded,
}
138
/// Label ids of the innermost enclosing loop, if any.
///
/// `Default` leaves both ids as `None`.
/// NOTE(review): presumably these are the targets for `continue` and `break`
/// statements respectively — confirm.
#[derive(Clone, Copy, Default)]
struct LoopContext {
    continuing_id: Option<Word>,
    break_id: Option<Word>,
}
144
/// Source information carried along when emitting debug info.
#[derive(Debug)]
pub(crate) struct DebugInfoInner<'a> {
    /// Borrowed text of the original shader source.
    pub source_code: &'a str,
    /// SPIR-V id associated with the source file.
    /// NOTE(review): presumably the result id of the `OpString`/source entry
    /// for the file — confirm.
    pub source_file_id: Word,
}
150
151impl Writer {
152 fn write_epilogue_position_y_flip(
157 &mut self,
158 position_id: Word,
159 body: &mut Vec<Instruction>,
160 ) -> Result<(), Error> {
161 let float_ptr_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Output);
162 let index_y_id = self.get_index_constant(1);
163 let access_id = self.id_gen.next();
164 body.push(Instruction::access_chain(
165 float_ptr_type_id,
166 access_id,
167 position_id,
168 &[index_y_id],
169 ));
170
171 let float_type_id = self.get_f32_type_id();
172 let load_id = self.id_gen.next();
173 body.push(Instruction::load(float_type_id, load_id, access_id, None));
174
175 let neg_id = self.id_gen.next();
176 body.push(Instruction::unary(
177 spirv::Op::FNegate,
178 float_type_id,
179 neg_id,
180 load_id,
181 ));
182
183 body.push(Instruction::store(access_id, neg_id, None));
184 Ok(())
185 }
186
187 fn write_epilogue_frag_depth_clamp(
189 &mut self,
190 frag_depth_id: Word,
191 body: &mut Vec<Instruction>,
192 ) -> Result<(), Error> {
193 let float_type_id = self.get_f32_type_id();
194 let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
195 let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));
196
197 let original_id = self.id_gen.next();
198 body.push(Instruction::load(
199 float_type_id,
200 original_id,
201 frag_depth_id,
202 None,
203 ));
204
205 let clamp_id = self.id_gen.next();
206 body.push(Instruction::ext_inst(
207 self.gl450_ext_inst_id,
208 spirv::GLOp::FClamp,
209 float_type_id,
210 clamp_id,
211 &[original_id, zero_scalar_id, one_scalar_id],
212 ));
213
214 body.push(Instruction::store(frag_depth_id, clamp_id, None));
215 Ok(())
216 }
217
218 fn write_entry_point_return(
219 &mut self,
220 value_id: Word,
221 ir_result: &crate::FunctionResult,
222 result_members: &[ResultMember],
223 body: &mut Vec<Instruction>,
224 ) -> Result<(), Error> {
225 for (index, res_member) in result_members.iter().enumerate() {
226 let member_value_id = match ir_result.binding {
227 Some(_) => value_id,
228 None => {
229 let member_value_id = self.id_gen.next();
230 body.push(Instruction::composite_extract(
231 res_member.type_id,
232 member_value_id,
233 value_id,
234 &[index as u32],
235 ));
236 member_value_id
237 }
238 };
239
240 body.push(Instruction::store(res_member.id, member_value_id, None));
241
242 match res_member.built_in {
243 Some(crate::BuiltIn::Position { .. })
244 if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
245 {
246 self.write_epilogue_position_y_flip(res_member.id, body)?;
247 }
248 Some(crate::BuiltIn::FragDepth)
249 if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
250 {
251 self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
252 }
253 _ => {}
254 }
255 }
256 Ok(())
257 }
258}
259
260impl BlockContext<'_> {
261 fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
274 let uint_type_id = self.writer.get_u32_type_id();
275 let uint2_type_id = self.writer.get_vec2u_type_id();
276 let uint2_ptr_type_id = self
277 .writer
278 .get_vec2u_pointer_type_id(spirv::StorageClass::Function);
279 let bool_type_id = self.writer.get_bool_type_id();
280 let bool2_type_id = self.writer.get_vec2_bool_type_id();
281 let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
282 let zero_uint2_const_id = self.writer.get_constant_composite(
283 LookupType::Local(LocalType::Numeric(NumericType::Vector {
284 size: crate::VectorSize::Bi,
285 scalar: crate::Scalar::U32,
286 })),
287 &[zero_uint_const_id, zero_uint_const_id],
288 );
289 let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
290 let max_uint_const_id = self
291 .writer
292 .get_constant_scalar(crate::Literal::U32(u32::MAX));
293 let max_uint2_const_id = self.writer.get_constant_composite(
294 LookupType::Local(LocalType::Numeric(NumericType::Vector {
295 size: crate::VectorSize::Bi,
296 scalar: crate::Scalar::U32,
297 })),
298 &[max_uint_const_id, max_uint_const_id],
299 );
300
301 let loop_counter_var_id = self.gen_id();
302 if self.writer.flags.contains(WriterFlags::DEBUG) {
303 self.writer
304 .debugs
305 .push(Instruction::name(loop_counter_var_id, "loop_bound"));
306 }
307 let var = super::LocalVariable {
308 id: loop_counter_var_id,
309 instruction: Instruction::variable(
310 uint2_ptr_type_id,
311 loop_counter_var_id,
312 spirv::StorageClass::Function,
313 Some(max_uint2_const_id),
314 ),
315 };
316 self.function.force_loop_bounding_vars.push(var);
317
318 let break_if_block = self.gen_id();
319
320 self.function
321 .consume(block, Instruction::branch(break_if_block));
322 block = Block::new(break_if_block);
323
324 let load_id = self.gen_id();
327 block.body.push(Instruction::load(
328 uint2_type_id,
329 load_id,
330 loop_counter_var_id,
331 None,
332 ));
333
334 let eq_id = self.gen_id();
337 block.body.push(Instruction::binary(
338 spirv::Op::IEqual,
339 bool2_type_id,
340 eq_id,
341 zero_uint2_const_id,
342 load_id,
343 ));
344 let all_eq_id = self.gen_id();
345 block.body.push(Instruction::relational(
346 spirv::Op::All,
347 bool_type_id,
348 all_eq_id,
349 eq_id,
350 ));
351
352 let inc_counter_block_id = self.gen_id();
353 block.body.push(Instruction::selection_merge(
354 inc_counter_block_id,
355 spirv::SelectionControl::empty(),
356 ));
357 self.function.consume(
358 block,
359 Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
360 );
361 block = Block::new(inc_counter_block_id);
362
363 let low_id = self.gen_id();
369 block.body.push(Instruction::composite_extract(
370 uint_type_id,
371 low_id,
372 load_id,
373 &[1],
374 ));
375 let low_overflow_id = self.gen_id();
376 block.body.push(Instruction::binary(
377 spirv::Op::IEqual,
378 bool_type_id,
379 low_overflow_id,
380 low_id,
381 zero_uint_const_id,
382 ));
383 let carry_bit_id = self.gen_id();
384 block.body.push(Instruction::select(
385 uint_type_id,
386 carry_bit_id,
387 low_overflow_id,
388 one_uint_const_id,
389 zero_uint_const_id,
390 ));
391 let decrement_id = self.gen_id();
392 block.body.push(Instruction::composite_construct(
393 uint2_type_id,
394 decrement_id,
395 &[carry_bit_id, one_uint_const_id],
396 ));
397 let result_id = self.gen_id();
398 block.body.push(Instruction::binary(
399 spirv::Op::ISub,
400 uint2_type_id,
401 result_id,
402 load_id,
403 decrement_id,
404 ));
405 block
406 .body
407 .push(Instruction::store(loop_counter_var_id, result_id, None));
408
409 block
410 }
411
412 pub(super) fn cache_expression_value(
414 &mut self,
415 expr_handle: Handle<crate::Expression>,
416 block: &mut Block,
417 ) -> Result<(), Error> {
418 let is_named_expression = self
419 .ir_function
420 .named_expressions
421 .contains_key(&expr_handle);
422
423 if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression {
424 return Ok(());
425 }
426
427 let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
428 let id = match self.ir_function.expressions[expr_handle] {
429 crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
430 crate::Expression::Constant(handle) => {
431 let init = self.ir_module.constants[handle].init;
432 self.writer.constant_ids[init]
433 }
434 crate::Expression::Override(_) => return Err(Error::Override),
435 crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
436 crate::Expression::Compose { ty, ref components } => {
437 self.temp_list.clear();
438 if self.expression_constness.is_const(expr_handle) {
439 self.temp_list.extend(
440 crate::proc::flatten_compose(
441 ty,
442 components,
443 &self.ir_function.expressions,
444 &self.ir_module.types,
445 )
446 .map(|component| self.cached[component]),
447 );
448 self.writer
449 .get_constant_composite(LookupType::Handle(ty), &self.temp_list)
450 } else {
451 self.temp_list
452 .extend(components.iter().map(|&component| self.cached[component]));
453
454 let id = self.gen_id();
455 block.body.push(Instruction::composite_construct(
456 result_type_id,
457 id,
458 &self.temp_list,
459 ));
460 id
461 }
462 }
463 crate::Expression::Splat { size, value } => {
464 let value_id = self.cached[value];
465 let components = &[value_id; 4][..size as usize];
466
467 if self.expression_constness.is_const(expr_handle) {
468 let ty = self
469 .writer
470 .get_expression_lookup_type(&self.fun_info[expr_handle].ty);
471 self.writer.get_constant_composite(ty, components)
472 } else {
473 let id = self.gen_id();
474 block.body.push(Instruction::composite_construct(
475 result_type_id,
476 id,
477 components,
478 ));
479 id
480 }
481 }
482 crate::Expression::Access { base, index } => {
483 let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
484 match *base_ty_inner {
485 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
486 0
492 }
493 _ if self.function.spilled_accesses.contains(base) => {
494 self.function.spilled_accesses.insert(expr_handle);
502 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
503 }
504 crate::TypeInner::Vector { .. } => {
505 self.write_vector_access(expr_handle, base, index, block)?
506 }
507 crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => {
508 match GuardedIndex::from_expression(
510 index,
511 &self.ir_function.expressions,
512 self.ir_module,
513 ) {
514 GuardedIndex::Known(value) => {
515 let id = self.gen_id();
525 let base_id = self.cached[base];
526 block.body.push(Instruction::composite_extract(
527 result_type_id,
528 id,
529 base_id,
530 &[value],
531 ));
532 id
533 }
534 GuardedIndex::Expression(_) => {
535 self.spill_to_internal_variable(base, block);
542
543 self.function.spilled_accesses.insert(expr_handle);
546 self.maybe_access_spilled_composite(
547 expr_handle,
548 block,
549 result_type_id,
550 )?
551 }
552 }
553 }
554 crate::TypeInner::BindingArray {
555 base: binding_type, ..
556 } => {
557 let result_id = match self.write_access_chain(
560 expr_handle,
561 block,
562 AccessTypeAdjustment::IntroducePointer(
563 spirv::StorageClass::UniformConstant,
564 ),
565 )? {
566 ExpressionPointer::Ready { pointer_id } => pointer_id,
567 ExpressionPointer::Conditional { .. } => {
568 return Err(Error::FeatureNotImplemented(
569 "Texture array out-of-bounds handling",
570 ));
571 }
572 };
573
574 let binding_type_id = self.get_handle_type_id(binding_type);
575
576 let load_id = self.gen_id();
577 block.body.push(Instruction::load(
578 binding_type_id,
579 load_id,
580 result_id,
581 None,
582 ));
583
584 if self.fun_info[index].uniformity.non_uniform_result.is_some() {
588 self.writer
589 .decorate_non_uniform_binding_array_access(load_id)?;
590 }
591
592 load_id
593 }
594 ref other => {
595 log::error!(
596 "Unable to access base {:?} of type {:?}",
597 self.ir_function.expressions[base],
598 other
599 );
600 return Err(Error::Validation(
601 "only vectors and arrays may be dynamically indexed by value",
602 ));
603 }
604 }
605 }
606 crate::Expression::AccessIndex { base, index } => {
607 match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
608 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
609 0
615 }
616 _ if self.function.spilled_accesses.contains(base) => {
617 self.function.spilled_accesses.insert(expr_handle);
625 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
626 }
627 crate::TypeInner::Vector { .. }
628 | crate::TypeInner::Matrix { .. }
629 | crate::TypeInner::Array { .. }
630 | crate::TypeInner::Struct { .. } => {
631 let id = self.gen_id();
636 let base_id = self.cached[base];
637 block.body.push(Instruction::composite_extract(
638 result_type_id,
639 id,
640 base_id,
641 &[index],
642 ));
643 id
644 }
645 crate::TypeInner::BindingArray {
646 base: binding_type, ..
647 } => {
648 let result_id = match self.write_access_chain(
651 expr_handle,
652 block,
653 AccessTypeAdjustment::IntroducePointer(
654 spirv::StorageClass::UniformConstant,
655 ),
656 )? {
657 ExpressionPointer::Ready { pointer_id } => pointer_id,
658 ExpressionPointer::Conditional { .. } => {
659 return Err(Error::FeatureNotImplemented(
660 "Texture array out-of-bounds handling",
661 ));
662 }
663 };
664
665 let binding_type_id = self.get_handle_type_id(binding_type);
666
667 let load_id = self.gen_id();
668 block.body.push(Instruction::load(
669 binding_type_id,
670 load_id,
671 result_id,
672 None,
673 ));
674
675 load_id
676 }
677 ref other => {
678 log::error!("Unable to access index of {other:?}");
679 return Err(Error::FeatureNotImplemented("access index for type"));
680 }
681 }
682 }
683 crate::Expression::GlobalVariable(handle) => {
684 self.writer.global_variables[handle].access_id
685 }
686 crate::Expression::Swizzle {
687 size,
688 vector,
689 pattern,
690 } => {
691 let vector_id = self.cached[vector];
692 self.temp_list.clear();
693 for &sc in pattern[..size as usize].iter() {
694 self.temp_list.push(sc as Word);
695 }
696 let id = self.gen_id();
697 block.body.push(Instruction::vector_shuffle(
698 result_type_id,
699 id,
700 vector_id,
701 vector_id,
702 &self.temp_list,
703 ));
704 id
705 }
706 crate::Expression::Unary { op, expr } => {
707 let id = self.gen_id();
708 let expr_id = self.cached[expr];
709 let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
710
711 let spirv_op = match op {
712 crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
713 Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
714 Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
715 _ => return Err(Error::Validation("Unexpected kind for negation")),
716 },
717 crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot,
718 crate::UnaryOperator::BitwiseNot => spirv::Op::Not,
719 };
720
721 block
722 .body
723 .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
724 id
725 }
726 crate::Expression::Binary { op, left, right } => {
727 let id = self.gen_id();
728 let left_id = self.cached[left];
729 let right_id = self.cached[right];
730 let left_type_id = self.get_expression_type_id(&self.fun_info[left].ty);
731 let right_type_id = self.get_expression_type_id(&self.fun_info[right].ty);
732
733 if let Some(function_id) =
734 self.writer
735 .wrapped_functions
736 .get(&WrappedFunction::BinaryOp {
737 op,
738 left_type_id,
739 right_type_id,
740 })
741 {
742 block.body.push(Instruction::function_call(
743 result_type_id,
744 id,
745 *function_id,
746 &[left_id, right_id],
747 ));
748 } else {
749 let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
750 let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
751
752 let left_dimension = get_dimension(left_ty_inner);
753 let right_dimension = get_dimension(right_ty_inner);
754
755 let mut reverse_operands = false;
756
757 let spirv_op = match op {
758 crate::BinaryOperator::Add => match *left_ty_inner {
759 crate::TypeInner::Scalar(scalar)
760 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
761 crate::ScalarKind::Float => spirv::Op::FAdd,
762 _ => spirv::Op::IAdd,
763 },
764 crate::TypeInner::Matrix {
765 columns,
766 rows,
767 scalar,
768 } => {
769 self.write_matrix_matrix_column_op(
770 block,
771 id,
772 result_type_id,
773 left_id,
774 right_id,
775 columns,
776 rows,
777 scalar.width,
778 spirv::Op::FAdd,
779 );
780
781 self.cached[expr_handle] = id;
782 return Ok(());
783 }
784 _ => unimplemented!(),
785 },
786 crate::BinaryOperator::Subtract => match *left_ty_inner {
787 crate::TypeInner::Scalar(scalar)
788 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
789 crate::ScalarKind::Float => spirv::Op::FSub,
790 _ => spirv::Op::ISub,
791 },
792 crate::TypeInner::Matrix {
793 columns,
794 rows,
795 scalar,
796 } => {
797 self.write_matrix_matrix_column_op(
798 block,
799 id,
800 result_type_id,
801 left_id,
802 right_id,
803 columns,
804 rows,
805 scalar.width,
806 spirv::Op::FSub,
807 );
808
809 self.cached[expr_handle] = id;
810 return Ok(());
811 }
812 _ => unimplemented!(),
813 },
814 crate::BinaryOperator::Multiply => {
815 match (left_dimension, right_dimension) {
816 (Dimension::Scalar, Dimension::Vector) => {
817 self.write_vector_scalar_mult(
818 block,
819 id,
820 result_type_id,
821 right_id,
822 left_id,
823 right_ty_inner,
824 );
825
826 self.cached[expr_handle] = id;
827 return Ok(());
828 }
829 (Dimension::Vector, Dimension::Scalar) => {
830 self.write_vector_scalar_mult(
831 block,
832 id,
833 result_type_id,
834 left_id,
835 right_id,
836 left_ty_inner,
837 );
838
839 self.cached[expr_handle] = id;
840 return Ok(());
841 }
842 (Dimension::Vector, Dimension::Matrix) => {
843 spirv::Op::VectorTimesMatrix
844 }
845 (Dimension::Matrix, Dimension::Scalar) => {
846 spirv::Op::MatrixTimesScalar
847 }
848 (Dimension::Scalar, Dimension::Matrix) => {
849 reverse_operands = true;
850 spirv::Op::MatrixTimesScalar
851 }
852 (Dimension::Matrix, Dimension::Vector) => {
853 spirv::Op::MatrixTimesVector
854 }
855 (Dimension::Matrix, Dimension::Matrix) => {
856 spirv::Op::MatrixTimesMatrix
857 }
858 (Dimension::Vector, Dimension::Vector)
859 | (Dimension::Scalar, Dimension::Scalar)
860 if left_ty_inner.scalar_kind()
861 == Some(crate::ScalarKind::Float) =>
862 {
863 spirv::Op::FMul
864 }
865 (Dimension::Vector, Dimension::Vector)
866 | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
867 }
868 }
869 crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
870 Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
871 Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
872 Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
873 _ => unimplemented!(),
874 },
875 crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
876 Some(crate::ScalarKind::Float) => spirv::Op::FRem,
879 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
880 unreachable!("Should have been handled by wrapped function")
881 }
882 _ => unimplemented!(),
883 },
884 crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
885 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
886 spirv::Op::IEqual
887 }
888 Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
889 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
890 _ => unimplemented!(),
891 },
892 crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
893 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
894 spirv::Op::INotEqual
895 }
896 Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
897 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
898 _ => unimplemented!(),
899 },
900 crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
901 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
902 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
903 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
904 _ => unimplemented!(),
905 },
906 crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
907 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
908 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
909 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
910 _ => unimplemented!(),
911 },
912 crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
913 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
914 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
915 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
916 _ => unimplemented!(),
917 },
918 crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
919 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
920 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
921 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
922 _ => unimplemented!(),
923 },
924 crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
925 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
926 _ => spirv::Op::BitwiseAnd,
927 },
928 crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
929 crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
930 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
931 _ => spirv::Op::BitwiseOr,
932 },
933 crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
934 crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
935 crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
936 crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
937 Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
938 Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
939 _ => unimplemented!(),
940 },
941 };
942
943 block.body.push(Instruction::binary(
944 spirv_op,
945 result_type_id,
946 id,
947 if reverse_operands { right_id } else { left_id },
948 if reverse_operands { left_id } else { right_id },
949 ));
950 }
951 id
952 }
953 crate::Expression::Math {
954 fun,
955 arg,
956 arg1,
957 arg2,
958 arg3,
959 } => {
960 use crate::MathFunction as Mf;
961 enum MathOp {
962 Ext(spirv::GLOp),
963 Custom(Instruction),
964 }
965
966 let arg0_id = self.cached[arg];
967 let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
968 let arg_scalar_kind = arg_ty.scalar_kind();
969 let arg1_id = match arg1 {
970 Some(handle) => self.cached[handle],
971 None => 0,
972 };
973 let arg2_id = match arg2 {
974 Some(handle) => self.cached[handle],
975 None => 0,
976 };
977 let arg3_id = match arg3 {
978 Some(handle) => self.cached[handle],
979 None => 0,
980 };
981
982 let id = self.gen_id();
983 let math_op = match fun {
984 Mf::Abs => {
986 match arg_scalar_kind {
987 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs),
988 Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs),
989 Some(crate::ScalarKind::Uint) => {
990 MathOp::Custom(Instruction::unary(
991 spirv::Op::CopyObject, result_type_id,
993 id,
994 arg0_id,
995 ))
996 }
997 other => unimplemented!("Unexpected abs({:?})", other),
998 }
999 }
1000 Mf::Min => MathOp::Ext(match arg_scalar_kind {
1001 Some(crate::ScalarKind::Float) => spirv::GLOp::FMin,
1002 Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin,
1003 Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin,
1004 other => unimplemented!("Unexpected min({:?})", other),
1005 }),
1006 Mf::Max => MathOp::Ext(match arg_scalar_kind {
1007 Some(crate::ScalarKind::Float) => spirv::GLOp::FMax,
1008 Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax,
1009 Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax,
1010 other => unimplemented!("Unexpected max({:?})", other),
1011 }),
1012 Mf::Clamp => match arg_scalar_kind {
1013 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FClamp),
1017 Some(_) => {
1018 let (min_op, max_op) = match arg_scalar_kind {
1019 Some(crate::ScalarKind::Sint) => {
1020 (spirv::GLOp::SMin, spirv::GLOp::SMax)
1021 }
1022 Some(crate::ScalarKind::Uint) => {
1023 (spirv::GLOp::UMin, spirv::GLOp::UMax)
1024 }
1025 _ => unreachable!(),
1026 };
1027
1028 let max_id = self.gen_id();
1029 block.body.push(Instruction::ext_inst(
1030 self.writer.gl450_ext_inst_id,
1031 max_op,
1032 result_type_id,
1033 max_id,
1034 &[arg0_id, arg1_id],
1035 ));
1036
1037 MathOp::Custom(Instruction::ext_inst(
1038 self.writer.gl450_ext_inst_id,
1039 min_op,
1040 result_type_id,
1041 id,
1042 &[max_id, arg2_id],
1043 ))
1044 }
1045 other => unimplemented!("Unexpected max({:?})", other),
1046 },
1047 Mf::Saturate => {
1048 let (maybe_size, scalar) = match *arg_ty {
1049 crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
1050 crate::TypeInner::Scalar(scalar) => (None, scalar),
1051 ref other => unimplemented!("Unexpected saturate({:?})", other),
1052 };
1053 let scalar = crate::Scalar::float(scalar.width);
1054 let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
1055 let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
1056
1057 if let Some(size) = maybe_size {
1058 let ty =
1059 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1060
1061 self.temp_list.clear();
1062 self.temp_list.resize(size as _, arg1_id);
1063
1064 arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
1065
1066 self.temp_list.fill(arg2_id);
1067
1068 arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
1069 }
1070
1071 MathOp::Custom(Instruction::ext_inst(
1072 self.writer.gl450_ext_inst_id,
1073 spirv::GLOp::FClamp,
1074 result_type_id,
1075 id,
1076 &[arg0_id, arg1_id, arg2_id],
1077 ))
1078 }
1079 Mf::Sin => MathOp::Ext(spirv::GLOp::Sin),
1081 Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh),
1082 Mf::Asin => MathOp::Ext(spirv::GLOp::Asin),
1083 Mf::Cos => MathOp::Ext(spirv::GLOp::Cos),
1084 Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh),
1085 Mf::Acos => MathOp::Ext(spirv::GLOp::Acos),
1086 Mf::Tan => MathOp::Ext(spirv::GLOp::Tan),
1087 Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh),
1088 Mf::Atan => MathOp::Ext(spirv::GLOp::Atan),
1089 Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2),
1090 Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh),
1091 Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh),
1092 Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh),
1093 Mf::Radians => MathOp::Ext(spirv::GLOp::Radians),
1094 Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees),
1095 Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil),
1097 Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven),
1098 Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
1099 Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
1100 Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
1101 Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct),
1102 Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct),
1103 Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
1104 Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
1106 crate::TypeInner::Vector {
1107 scalar:
1108 crate::Scalar {
1109 kind: crate::ScalarKind::Float,
1110 ..
1111 },
1112 ..
1113 } => MathOp::Custom(Instruction::binary(
1114 spirv::Op::Dot,
1115 result_type_id,
1116 id,
1117 arg0_id,
1118 arg1_id,
1119 )),
1120 crate::TypeInner::Vector { size, .. } => {
1122 self.write_dot_product(
1123 id,
1124 result_type_id,
1125 arg0_id,
1126 arg1_id,
1127 size as u32,
1128 block,
1129 |result_id, composite_id, index| {
1130 Instruction::composite_extract(
1131 result_type_id,
1132 result_id,
1133 composite_id,
1134 &[index],
1135 )
1136 },
1137 );
1138 self.cached[expr_handle] = id;
1139 return Ok(());
1140 }
1141 _ => unreachable!(
1142 "Correct TypeInner for dot product should be already validated"
1143 ),
1144 },
1145 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
1146 if self
1147 .writer
1148 .require_all(&[
1149 spirv::Capability::DotProduct,
1150 spirv::Capability::DotProductInput4x8BitPacked,
1151 ])
1152 .is_ok()
1153 {
1154 if self.writer.lang_version() < (1, 6) {
1156 self.writer.use_extension("SPV_KHR_integer_dot_product");
1160 }
1161
1162 let op = match fun {
1163 Mf::Dot4I8Packed => spirv::Op::SDot,
1164 Mf::Dot4U8Packed => spirv::Op::UDot,
1165 _ => unreachable!(),
1166 };
1167
1168 block.body.push(Instruction::ternary(
1169 op,
1170 result_type_id,
1171 id,
1172 arg0_id,
1173 arg1_id,
1174 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word,
1175 ));
1176 } else {
1177 let (extract_op, arg0_id, arg1_id) = match fun {
1179 Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
1180 Mf::Dot4I8Packed => {
1181 let new_arg0_id = self.gen_id();
1184 block.body.push(Instruction::unary(
1185 spirv::Op::Bitcast,
1186 result_type_id,
1187 new_arg0_id,
1188 arg0_id,
1189 ));
1190
1191 let new_arg1_id = self.gen_id();
1192 block.body.push(Instruction::unary(
1193 spirv::Op::Bitcast,
1194 result_type_id,
1195 new_arg1_id,
1196 arg1_id,
1197 ));
1198
1199 (spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
1200 }
1201 _ => unreachable!(),
1202 };
1203
1204 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1205
1206 const VEC_LENGTH: u8 = 4;
1207 let bit_shifts: [_; VEC_LENGTH as usize] =
1208 core::array::from_fn(|index| {
1209 self.writer
1210 .get_constant_scalar(crate::Literal::U32(index as u32 * 8))
1211 });
1212
1213 self.write_dot_product(
1214 id,
1215 result_type_id,
1216 arg0_id,
1217 arg1_id,
1218 VEC_LENGTH as Word,
1219 block,
1220 |result_id, composite_id, index| {
1221 Instruction::ternary(
1222 extract_op,
1223 result_type_id,
1224 result_id,
1225 composite_id,
1226 bit_shifts[index as usize],
1227 eight,
1228 )
1229 },
1230 );
1231 }
1232
1233 self.cached[expr_handle] = id;
1234 return Ok(());
1235 }
1236 Mf::Outer => MathOp::Custom(Instruction::binary(
1237 spirv::Op::OuterProduct,
1238 result_type_id,
1239 id,
1240 arg0_id,
1241 arg1_id,
1242 )),
1243 Mf::Cross => MathOp::Ext(spirv::GLOp::Cross),
1244 Mf::Distance => MathOp::Ext(spirv::GLOp::Distance),
1245 Mf::Length => MathOp::Ext(spirv::GLOp::Length),
1246 Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize),
1247 Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward),
1248 Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect),
1249 Mf::Refract => MathOp::Ext(spirv::GLOp::Refract),
1250 Mf::Exp => MathOp::Ext(spirv::GLOp::Exp),
1252 Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2),
1253 Mf::Log => MathOp::Ext(spirv::GLOp::Log),
1254 Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2),
1255 Mf::Pow => MathOp::Ext(spirv::GLOp::Pow),
1256 Mf::Sign => MathOp::Ext(match arg_scalar_kind {
1258 Some(crate::ScalarKind::Float) => spirv::GLOp::FSign,
1259 Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign,
1260 other => unimplemented!("Unexpected sign({:?})", other),
1261 }),
1262 Mf::Fma => MathOp::Ext(spirv::GLOp::Fma),
1263 Mf::Mix => {
1264 let selector = arg2.unwrap();
1265 let selector_ty =
1266 self.fun_info[selector].ty.inner_with(&self.ir_module.types);
1267 match (arg_ty, selector_ty) {
1268 (
1270 &crate::TypeInner::Vector { size, .. },
1271 &crate::TypeInner::Scalar(scalar),
1272 ) => {
1273 let selector_type_id =
1274 self.get_numeric_type_id(NumericType::Vector { size, scalar });
1275 self.temp_list.clear();
1276 self.temp_list.resize(size as usize, arg2_id);
1277
1278 let selector_id = self.gen_id();
1279 block.body.push(Instruction::composite_construct(
1280 selector_type_id,
1281 selector_id,
1282 &self.temp_list,
1283 ));
1284
1285 MathOp::Custom(Instruction::ext_inst(
1286 self.writer.gl450_ext_inst_id,
1287 spirv::GLOp::FMix,
1288 result_type_id,
1289 id,
1290 &[arg0_id, arg1_id, selector_id],
1291 ))
1292 }
1293 _ => MathOp::Ext(spirv::GLOp::FMix),
1294 }
1295 }
1296 Mf::Step => MathOp::Ext(spirv::GLOp::Step),
1297 Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep),
1298 Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt),
1299 Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt),
1300 Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse),
1301 Mf::Transpose => MathOp::Custom(Instruction::unary(
1302 spirv::Op::Transpose,
1303 result_type_id,
1304 id,
1305 arg0_id,
1306 )),
1307 Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
1308 Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
1309 spirv::Op::QuantizeToF16,
1310 result_type_id,
1311 id,
1312 arg0_id,
1313 )),
1314 Mf::ReverseBits => MathOp::Custom(Instruction::unary(
1315 spirv::Op::BitReverse,
1316 result_type_id,
1317 id,
1318 arg0_id,
1319 )),
1320 Mf::CountTrailingZeros => {
1321 let uint_id = match *arg_ty {
1322 crate::TypeInner::Vector { size, scalar } => {
1323 let ty =
1324 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1325
1326 self.temp_list.clear();
1327 self.temp_list.resize(
1328 size as _,
1329 self.writer
1330 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1331 );
1332
1333 self.writer.get_constant_composite(ty, &self.temp_list)
1334 }
1335 crate::TypeInner::Scalar(scalar) => self
1336 .writer
1337 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1338 _ => unreachable!(),
1339 };
1340
1341 let lsb_id = self.gen_id();
1342 block.body.push(Instruction::ext_inst(
1343 self.writer.gl450_ext_inst_id,
1344 spirv::GLOp::FindILsb,
1345 result_type_id,
1346 lsb_id,
1347 &[arg0_id],
1348 ));
1349
1350 MathOp::Custom(Instruction::ext_inst(
1351 self.writer.gl450_ext_inst_id,
1352 spirv::GLOp::UMin,
1353 result_type_id,
1354 id,
1355 &[uint_id, lsb_id],
1356 ))
1357 }
1358 Mf::CountLeadingZeros => {
1359 let (int_type_id, int_id, width) = match *arg_ty {
1360 crate::TypeInner::Vector { size, scalar } => {
1361 let ty =
1362 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1363
1364 self.temp_list.clear();
1365 self.temp_list.resize(
1366 size as _,
1367 self.writer
1368 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1369 );
1370
1371 (
1372 self.get_type_id(ty),
1373 self.writer.get_constant_composite(ty, &self.temp_list),
1374 scalar.width,
1375 )
1376 }
1377 crate::TypeInner::Scalar(scalar) => (
1378 self.get_numeric_type_id(NumericType::Scalar(scalar)),
1379 self.writer
1380 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1381 scalar.width,
1382 ),
1383 _ => unreachable!(),
1384 };
1385
1386 if width != 4 {
1387 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1388 };
1389
1390 let msb_id = self.gen_id();
1391 block.body.push(Instruction::ext_inst(
1392 self.writer.gl450_ext_inst_id,
1393 if width != 4 {
1394 spirv::GLOp::FindILsb
1395 } else {
1396 spirv::GLOp::FindUMsb
1397 },
1398 int_type_id,
1399 msb_id,
1400 &[arg0_id],
1401 ));
1402
1403 MathOp::Custom(Instruction::binary(
1404 spirv::Op::ISub,
1405 result_type_id,
1406 id,
1407 int_id,
1408 msb_id,
1409 ))
1410 }
1411 Mf::CountOneBits => MathOp::Custom(Instruction::unary(
1412 spirv::Op::BitCount,
1413 result_type_id,
1414 id,
1415 arg0_id,
1416 )),
1417 Mf::ExtractBits => {
1418 let op = match arg_scalar_kind {
1419 Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
1420 Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
1421 other => unimplemented!("Unexpected sign({:?})", other),
1422 };
1423
1424 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1439 let width_constant = self
1440 .writer
1441 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1442
1443 let u32_type =
1444 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1445
1446 let offset_id = self.gen_id();
1448 block.body.push(Instruction::ext_inst(
1449 self.writer.gl450_ext_inst_id,
1450 spirv::GLOp::UMin,
1451 u32_type,
1452 offset_id,
1453 &[arg1_id, width_constant],
1454 ));
1455
1456 let max_count_id = self.gen_id();
1458 block.body.push(Instruction::binary(
1459 spirv::Op::ISub,
1460 u32_type,
1461 max_count_id,
1462 width_constant,
1463 offset_id,
1464 ));
1465
1466 let count_id = self.gen_id();
1468 block.body.push(Instruction::ext_inst(
1469 self.writer.gl450_ext_inst_id,
1470 spirv::GLOp::UMin,
1471 u32_type,
1472 count_id,
1473 &[arg2_id, max_count_id],
1474 ));
1475
1476 MathOp::Custom(Instruction::ternary(
1477 op,
1478 result_type_id,
1479 id,
1480 arg0_id,
1481 offset_id,
1482 count_id,
1483 ))
1484 }
1485 Mf::InsertBits => {
1486 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1489 let width_constant = self
1490 .writer
1491 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1492
1493 let u32_type =
1494 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1495
1496 let offset_id = self.gen_id();
1498 block.body.push(Instruction::ext_inst(
1499 self.writer.gl450_ext_inst_id,
1500 spirv::GLOp::UMin,
1501 u32_type,
1502 offset_id,
1503 &[arg2_id, width_constant],
1504 ));
1505
1506 let max_count_id = self.gen_id();
1508 block.body.push(Instruction::binary(
1509 spirv::Op::ISub,
1510 u32_type,
1511 max_count_id,
1512 width_constant,
1513 offset_id,
1514 ));
1515
1516 let count_id = self.gen_id();
1518 block.body.push(Instruction::ext_inst(
1519 self.writer.gl450_ext_inst_id,
1520 spirv::GLOp::UMin,
1521 u32_type,
1522 count_id,
1523 &[arg3_id, max_count_id],
1524 ));
1525
1526 MathOp::Custom(Instruction::quaternary(
1527 spirv::Op::BitFieldInsert,
1528 result_type_id,
1529 id,
1530 arg0_id,
1531 arg1_id,
1532 offset_id,
1533 count_id,
1534 ))
1535 }
1536 Mf::FirstTrailingBit => MathOp::Ext(spirv::GLOp::FindILsb),
1537 Mf::FirstLeadingBit => {
1538 if arg_ty.scalar_width() == Some(4) {
1539 let thing = match arg_scalar_kind {
1540 Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
1541 Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
1542 other => unimplemented!("Unexpected firstLeadingBit({:?})", other),
1543 };
1544 MathOp::Ext(thing)
1545 } else {
1546 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1547 }
1548 }
1549 Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
1550 Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
1551 Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
1552 Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16),
1553 Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
1554 fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => {
1555 let is_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp);
1556 let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp);
1557
1558 let last_instruction =
1559 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1560 self.write_pack4x8_optimized(
1561 block,
1562 result_type_id,
1563 arg0_id,
1564 id,
1565 is_signed,
1566 should_clamp,
1567 )
1568 } else {
1569 self.write_pack4x8_polyfill(
1570 block,
1571 result_type_id,
1572 arg0_id,
1573 id,
1574 is_signed,
1575 should_clamp,
1576 )
1577 };
1578
1579 MathOp::Custom(last_instruction)
1580 }
1581 Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
1582 Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
1583 Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
1584 Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16),
1585 Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
1586 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
1587 let is_signed = matches!(fun, Mf::Unpack4xI8);
1588
1589 let last_instruction =
1590 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1591 self.write_unpack4x8_optimized(
1592 block,
1593 result_type_id,
1594 arg0_id,
1595 id,
1596 is_signed,
1597 )
1598 } else {
1599 self.write_unpack4x8_polyfill(
1600 block,
1601 result_type_id,
1602 arg0_id,
1603 id,
1604 is_signed,
1605 )
1606 };
1607
1608 MathOp::Custom(last_instruction)
1609 }
1610 };
1611
1612 block.body.push(match math_op {
1613 MathOp::Ext(op) => Instruction::ext_inst(
1614 self.writer.gl450_ext_inst_id,
1615 op,
1616 result_type_id,
1617 id,
1618 &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
1619 ),
1620 MathOp::Custom(inst) => inst,
1621 });
1622 id
1623 }
1624 crate::Expression::LocalVariable(variable) => self.function.variables[&variable].id,
1625 crate::Expression::Load { pointer } => {
1626 self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)?
1627 }
1628 crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
1629 crate::Expression::CallResult(_)
1630 | crate::Expression::AtomicResult { .. }
1631 | crate::Expression::WorkGroupUniformLoadResult { .. }
1632 | crate::Expression::RayQueryProceedResult
1633 | crate::Expression::SubgroupBallotResult
1634 | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
1635 crate::Expression::As {
1636 expr,
1637 kind,
1638 convert,
1639 } => self.write_as_expression(expr, convert, kind, block, result_type_id)?,
1640 crate::Expression::ImageLoad {
1641 image,
1642 coordinate,
1643 array_index,
1644 sample,
1645 level,
1646 } => self.write_image_load(
1647 result_type_id,
1648 image,
1649 coordinate,
1650 array_index,
1651 level,
1652 sample,
1653 block,
1654 )?,
1655 crate::Expression::ImageSample {
1656 image,
1657 sampler,
1658 gather,
1659 coordinate,
1660 array_index,
1661 offset,
1662 level,
1663 depth_ref,
1664 clamp_to_edge,
1665 } => self.write_image_sample(
1666 result_type_id,
1667 image,
1668 sampler,
1669 gather,
1670 coordinate,
1671 array_index,
1672 offset,
1673 level,
1674 depth_ref,
1675 clamp_to_edge,
1676 block,
1677 )?,
1678 crate::Expression::Select {
1679 condition,
1680 accept,
1681 reject,
1682 } => {
1683 let id = self.gen_id();
1684 let mut condition_id = self.cached[condition];
1685 let accept_id = self.cached[accept];
1686 let reject_id = self.cached[reject];
1687
1688 let condition_ty = self.fun_info[condition]
1689 .ty
1690 .inner_with(&self.ir_module.types);
1691 let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
1692
1693 if let (
1694 &crate::TypeInner::Scalar(
1695 condition_scalar @ crate::Scalar {
1696 kind: crate::ScalarKind::Bool,
1697 ..
1698 },
1699 ),
1700 &crate::TypeInner::Vector { size, .. },
1701 ) = (condition_ty, object_ty)
1702 {
1703 self.temp_list.clear();
1704 self.temp_list.resize(size as usize, condition_id);
1705
1706 let bool_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
1707 size,
1708 scalar: condition_scalar,
1709 });
1710
1711 let id = self.gen_id();
1712 block.body.push(Instruction::composite_construct(
1713 bool_vector_type_id,
1714 id,
1715 &self.temp_list,
1716 ));
1717 condition_id = id
1718 }
1719
1720 let instruction =
1721 Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
1722 block.body.push(instruction);
1723 id
1724 }
1725 crate::Expression::Derivative { axis, ctrl, expr } => {
1726 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
1727 match ctrl {
1728 Ctrl::Coarse | Ctrl::Fine => {
1729 self.writer.require_any(
1730 "DerivativeControl",
1731 &[spirv::Capability::DerivativeControl],
1732 )?;
1733 }
1734 Ctrl::None => {}
1735 }
1736 let id = self.gen_id();
1737 let expr_id = self.cached[expr];
1738 let op = match (axis, ctrl) {
1739 (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
1740 (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
1741 (Axis::X, Ctrl::None) => spirv::Op::DPdx,
1742 (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
1743 (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
1744 (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
1745 (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
1746 (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
1747 (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
1748 };
1749 block
1750 .body
1751 .push(Instruction::derivative(op, result_type_id, id, expr_id));
1752 id
1753 }
1754 crate::Expression::ImageQuery { image, query } => {
1755 self.write_image_query(result_type_id, image, query, block)?
1756 }
1757 crate::Expression::Relational { fun, argument } => {
1758 use crate::RelationalFunction as Rf;
1759 let arg_id = self.cached[argument];
1760 let op = match fun {
1761 Rf::All => spirv::Op::All,
1762 Rf::Any => spirv::Op::Any,
1763 Rf::IsNan => spirv::Op::IsNan,
1764 Rf::IsInf => spirv::Op::IsInf,
1765 };
1766 let id = self.gen_id();
1767 block
1768 .body
1769 .push(Instruction::relational(op, result_type_id, id, arg_id));
1770 id
1771 }
1772 crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
1773 crate::Expression::RayQueryGetIntersection { query, committed } => {
1774 let query_id = self.cached[query];
1775 let func_id = self
1776 .writer
1777 .write_ray_query_get_intersection_function(committed, self.ir_module);
1778 let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap();
1779 let intersection_type_id = self.get_handle_type_id(ray_intersection);
1780 let id = self.gen_id();
1781 block.body.push(Instruction::function_call(
1782 intersection_type_id,
1783 id,
1784 func_id,
1785 &[query_id],
1786 ));
1787 id
1788 }
1789 crate::Expression::RayQueryVertexPositions { query, committed } => {
1790 self.writer.require_any(
1791 "RayQueryVertexPositions",
1792 &[spirv::Capability::RayQueryPositionFetchKHR],
1793 )?;
1794 self.write_ray_query_return_vertex_position(query, block, committed)
1795 }
1796 };
1797
1798 self.cached[expr_handle] = id;
1799 Ok(())
1800 }
1801
    /// Emit the SPIR-V for an `Expression::As` cast of `expr` and return the
    /// id of the cast's result.
    ///
    /// `kind` is the destination scalar kind. `convert` is `None` for a bit
    /// reinterpretation, or `Some(width)` for a numeric conversion to a scalar
    /// of `width` bytes. `result_type_id` is the SPIR-V type of the result.
    ///
    /// When no instruction is needed (same kind and unchanged width, or a
    /// matrix whose width does not change), the id of `expr` itself is
    /// returned and nothing is emitted.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Validation`] when a matrix is converted to a non-float
    /// kind, or when the source is not a scalar, vector, or matrix.
    fn write_as_expression(
        &mut self,
        expr: Handle<crate::Expression>,
        convert: Option<u8>,
        kind: crate::ScalarKind,

        block: &mut Block,
        result_type_id: u32,
    ) -> Result<u32, Error> {
        use crate::ScalarKind as Sk;
        let expr_id = self.cached[expr];
        let ty = self.fun_info[expr].ty.inner_with(&self.ir_module.types);

        // SPIR-V conversion instructions operate on scalars and vectors, not
        // matrices, so a matrix is converted one column at a time and the
        // converted columns are reassembled into a new matrix.
        if let crate::TypeInner::Matrix {
            columns,
            rows,
            scalar,
        } = *ty
        {
            let Some(convert) = convert else {
                // Matrix bitcast: the value passes through unchanged.
                return Ok(expr_id);
            };

            if convert == scalar.width {
                // Width is unchanged: the value passes through unchanged.
                return Ok(expr_id);
            }

            if kind != Sk::Float {
                // Only float-to-float matrix conversions are supported.
                return Err(Error::Validation("Matrices must be floats"));
            }

            // Type of each column as extracted from the source matrix.
            let column_src_ty =
                self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
                    size: rows,
                    scalar,
                })));

            // Type of each column after `OpFConvert`.
            let column_dst_ty =
                self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
                    size: rows,
                    scalar: crate::Scalar {
                        kind,
                        width: convert,
                    },
                })));

            let mut components = ArrayVec::<Word, 4>::new();

            // Extract, convert and collect each column in order.
            for column in 0..columns as usize {
                let column_id = self.gen_id();
                block.body.push(Instruction::composite_extract(
                    column_src_ty,
                    column_id,
                    expr_id,
                    &[column as u32],
                ));

                let column_conv_id = self.gen_id();
                block.body.push(Instruction::unary(
                    spirv::Op::FConvert,
                    column_dst_ty,
                    column_conv_id,
                    column_id,
                ));

                components.push(column_conv_id);
            }

            let construct_id = self.gen_id();

            block.body.push(Instruction::composite_construct(
                result_type_id,
                construct_id,
                &components,
            ));

            return Ok(construct_id);
        }

        // From here on the source is a scalar (`src_size == None`) or a vector.
        let (src_scalar, src_size) = match *ty {
            crate::TypeInner::Scalar(scalar) => (scalar, None),
            crate::TypeInner::Vector { scalar, size } => (scalar, Some(size)),
            ref other => {
                log::error!("As source {other:?}");
                return Err(Error::Validation("Unexpected Expression::As source"));
            }
        };

        // How the cast is realized: reuse the source id unchanged, or emit a
        // one-, two- or three-operand instruction of type `result_type_id`.
        enum Cast {
            Identity(Word),
            Unary(spirv::Op, Word),
            Binary(spirv::Op, Word, Word),
            Ternary(spirv::Op, Word, Word, Word),
        }
        let cast = match (src_scalar.kind, kind, convert) {
            // Same kind and no width change: emit nothing at all rather than
            // a no-op instruction.
            (src_kind, kind, convert)
                if src_kind == kind
                    && convert.filter(|&width| width != src_scalar.width).is_none() =>
            {
                Cast::Identity(expr_id)
            }
            (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject, expr_id),
            // No target width: reinterpret the bits.
            (_, _, None) => Cast::Unary(spirv::Op::Bitcast, expr_id),
            // Casting to bool: compare against a zero of the source type.
            // `FUnordNotEqual` is also true when the operand is NaN.
            (_, Sk::Bool, Some(_)) => {
                let op = match src_scalar.kind {
                    Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
                    Sk::Float => spirv::Op::FUnordNotEqual,
                    Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(),
                };
                let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?;
                let zero_id = match src_size {
                    Some(size) => {
                        let ty = LocalType::Numeric(NumericType::Vector {
                            size,
                            scalar: src_scalar,
                        })
                        .into();

                        self.temp_list.clear();
                        self.temp_list.resize(size as _, zero_scalar_id);

                        self.writer.get_constant_composite(ty, &self.temp_list)
                    }
                    None => zero_scalar_id,
                };

                Cast::Binary(op, expr_id, zero_id)
            }
            // Casting from bool: select between (splatted) 1 and 0 constants
            // of the destination type.
            (Sk::Bool, _, Some(dst_width)) => {
                let dst_scalar = crate::Scalar {
                    kind,
                    width: dst_width,
                };
                let zero_scalar_id = self.writer.get_constant_scalar_with(0, dst_scalar)?;
                let one_scalar_id = self.writer.get_constant_scalar_with(1, dst_scalar)?;
                let (accept_id, reject_id) = match src_size {
                    Some(size) => {
                        let ty = LocalType::Numeric(NumericType::Vector {
                            size,
                            scalar: dst_scalar,
                        })
                        .into();

                        self.temp_list.clear();
                        self.temp_list.resize(size as _, zero_scalar_id);

                        let vec0_id = self.writer.get_constant_composite(ty, &self.temp_list);

                        self.temp_list.fill(one_scalar_id);

                        let vec1_id = self.writer.get_constant_composite(ty, &self.temp_list);

                        (vec1_id, vec0_id)
                    }
                    None => (one_scalar_id, zero_scalar_id),
                };

                Cast::Ternary(spirv::Op::Select, expr_id, accept_id, reject_id)
            }
            // Float to integer: clamp (in the source float type) to the bounds
            // returned by `min_max_float_representable_by` before converting,
            // so out-of-range inputs saturate instead of reaching the
            // conversion instruction unclamped.
            (Sk::Float, Sk::Sint | Sk::Uint, Some(width)) => {
                let dst_scalar = crate::Scalar { kind, width };
                let (min, max) =
                    crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
                let expr_type_id = self.get_expression_type_id(&self.fun_info[expr].ty);

                // Splat a scalar constant to the source vector shape, if any.
                let maybe_splat_const = |writer: &mut Writer, const_id| match src_size {
                    None => const_id,
                    Some(size) => {
                        let constituent_ids = [const_id; crate::VectorSize::MAX];
                        writer.get_constant_composite(
                            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                                size,
                                scalar: src_scalar,
                            })),
                            &constituent_ids[..size as usize],
                        )
                    }
                };
                let min_const_id = self.writer.get_constant_scalar(min);
                let min_const_id = maybe_splat_const(self.writer, min_const_id);
                let max_const_id = self.writer.get_constant_scalar(max);
                let max_const_id = maybe_splat_const(self.writer, max_const_id);

                let clamp_id = self.gen_id();
                block.body.push(Instruction::ext_inst(
                    self.writer.gl450_ext_inst_id,
                    spirv::GLOp::FClamp,
                    expr_type_id,
                    clamp_id,
                    &[expr_id, min_const_id, max_const_id],
                ));

                let op = match dst_scalar.kind {
                    crate::ScalarKind::Sint => spirv::Op::ConvertFToS,
                    crate::ScalarKind::Uint => spirv::Op::ConvertFToU,
                    _ => unreachable!(),
                };
                Cast::Unary(op, clamp_id)
            }
            (Sk::Float, Sk::Float, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::FConvert, expr_id)
            }
            (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF, expr_id),
            // Width-changing integer conversions: the op follows the
            // *destination* kind (SConvert for sint, UConvert for uint).
            // NOTE(review): `OpUConvert` requires an unsigned result type,
            // which presumably motivates this; it means widening uint->sint
            // sign-extends. Confirm that is the intended semantics.
            (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::SConvert, expr_id)
            }
            (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF, expr_id),
            (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::UConvert, expr_id)
            }
            (Sk::Uint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::SConvert, expr_id)
            }
            (Sk::Sint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::UConvert, expr_id)
            }
            // Everything else (e.g. same-width sint <-> uint) is treated as a
            // bit reinterpretation.
            _ => Cast::Unary(spirv::Op::Bitcast, expr_id),
        };
        // Emit the chosen instruction, if any, and return the result id.
        Ok(match cast {
            Cast::Identity(expr) => expr,
            Cast::Unary(op, op1) => {
                let id = self.gen_id();
                block
                    .body
                    .push(Instruction::unary(op, result_type_id, id, op1));
                id
            }
            Cast::Binary(op, op1, op2) => {
                let id = self.gen_id();
                block
                    .body
                    .push(Instruction::binary(op, result_type_id, id, op1, op2));
                id
            }
            Cast::Ternary(op, op1, op2, op3) => {
                let id = self.gen_id();
                block
                    .body
                    .push(Instruction::ternary(op, result_type_id, id, op1, op2, op3));
                id
            }
        })
    }
2071
    /// Emit an `OpAccessChain` for the pointer expression `expr_handle`.
    ///
    /// Starting at `expr_handle`, this walks `Access`/`AccessIndex`
    /// expressions back toward their root: the temporary of a spilled
    /// composite, a global variable, a local variable, or a function argument.
    /// Each step's index id is collected into `self.temp_list` (leaf first),
    /// and the list is reversed before the access instruction is built.
    ///
    /// `type_adjustment` selects the access's result type: the expression's
    /// own resolved type, or a pointer to that type in the given storage
    /// class.
    ///
    /// If any bounds check produced a runtime condition, the access
    /// instruction is *not* pushed to `block`. It is returned in
    /// [`ExpressionPointer::Conditional`] together with the combined
    /// condition, and the caller must emit it on the in-bounds path.
    /// Otherwise the access (if any index was collected) is pushed and
    /// [`ExpressionPointer::Ready`] is returned. With no indices at all, the
    /// root id itself is the pointer.
    ///
    /// # Panics
    ///
    /// Panics (`unimplemented!`) if the walk reaches an expression that is not
    /// one of the roots listed above.
    fn write_access_chain(
        &mut self,
        mut expr_handle: Handle<crate::Expression>,
        block: &mut Block,
        type_adjustment: AccessTypeAdjustment,
    ) -> Result<ExpressionPointer, Error> {
        let result_type_id = {
            let resolution = &self.fun_info[expr_handle].ty;
            match type_adjustment {
                AccessTypeAdjustment::None => self.writer.get_expression_type_id(resolution),
                AccessTypeAdjustment::IntroducePointer(class) => {
                    self.writer.get_resolution_pointer_id(resolution, class)
                }
            }
        };

        // Conjunction (via `OpLogicalAnd`) of the runtime bounds-check
        // conditions generated so far, if any.
        let mut accumulated_checks = None;

        // Set when some `Access` indexes a binding array with a non-uniform index.
        let mut is_non_uniform_binding_array = false;

        self.temp_list.clear();
        let root_id = loop {
            // A spilled composite's temporary variable already holds the
            // value, so the access chain starts from that variable.
            if let Some(spilled) = self.function.spilled_composites.get(&expr_handle) {
                break spilled.id;
            }

            expr_handle = match self.ir_function.expressions[expr_handle] {
                crate::Expression::Access { base, index } => {
                    is_non_uniform_binding_array |=
                        self.is_nonuniform_binding_array_access(base, index);

                    let index = GuardedIndex::Expression(index);
                    let index_id =
                        self.write_access_chain_index(base, index, &mut accumulated_checks, block)?;
                    self.temp_list.push(index_id);

                    base
                }
                crate::Expression::AccessIndex { base, index } => {
                    // Look through a pointer to find what is being indexed.
                    let mut base_ty = self.fun_info[base].ty.inner_with(&self.ir_module.types);
                    if let crate::TypeInner::Pointer { base, .. } = *base_ty {
                        base_ty = &self.ir_module.types[base].inner;
                    }
                    // Struct member indices become plain constants with no
                    // bounds check. Any other base still goes through the
                    // bounds-check machinery even though the index is known
                    // (presumably because the base may be runtime-sized).
                    let index_id = if let crate::TypeInner::Struct { .. } = *base_ty {
                        self.get_index_constant(index)
                    } else {
                        self.write_access_chain_index(
                            base,
                            GuardedIndex::Known(index),
                            &mut accumulated_checks,
                            block,
                        )?
                    };

                    self.temp_list.push(index_id);
                    base
                }
                crate::Expression::GlobalVariable(handle) => {
                    let gv = &self.writer.global_variables[handle];
                    break gv.access_id;
                }
                crate::Expression::LocalVariable(variable) => {
                    let local_var = &self.function.variables[&variable];
                    break local_var.id;
                }
                crate::Expression::FunctionArgument(index) => {
                    break self.function.parameter_id(index);
                }
                ref other => unimplemented!("Unexpected pointer expression {:?}", other),
            }
        };

        let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
            (
                root_id,
                ExpressionPointer::Ready {
                    pointer_id: root_id,
                },
            )
        } else {
            // Indices were collected leaf-to-root; `OpAccessChain` wants them
            // root-to-leaf.
            self.temp_list.reverse();
            let pointer_id = self.gen_id();
            let access =
                Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);

            // With runtime checks, the caller emits the branch, the access and
            // the load/store. Without them, the access is emitted right here.
            let expr_pointer = match accumulated_checks {
                Some(condition) => ExpressionPointer::Conditional { condition, access },
                None => {
                    block.body.push(access);
                    ExpressionPointer::Ready { pointer_id }
                }
            };
            (pointer_id, expr_pointer)
        };
        // Mark the resulting pointer when a binding array was indexed with a
        // non-uniform index.
        if is_non_uniform_binding_array {
            self.writer
                .decorate_non_uniform_binding_array_access(pointer_id)?;
        }

        Ok(expr_pointer)
    }
2206
2207 fn is_nonuniform_binding_array_access(
2208 &mut self,
2209 base: Handle<crate::Expression>,
2210 index: Handle<crate::Expression>,
2211 ) -> bool {
2212 let crate::Expression::GlobalVariable(var_handle) = self.ir_function.expressions[base]
2213 else {
2214 return false;
2215 };
2216
2217 let gvar = &self.ir_module.global_variables[var_handle];
2220 let crate::TypeInner::BindingArray { .. } = self.ir_module.types[gvar.ty].inner else {
2221 return false;
2222 };
2223
2224 self.fun_info[index].uniformity.non_uniform_result.is_some()
2225 }
2226
2227 fn write_access_chain_index(
2237 &mut self,
2238 base: Handle<crate::Expression>,
2239 index: GuardedIndex,
2240 accumulated_checks: &mut Option<Word>,
2241 block: &mut Block,
2242 ) -> Result<Word, Error> {
2243 match self.write_bounds_check(base, index, block)? {
2244 BoundsCheckResult::KnownInBounds(known_index) => {
2245 let scalar = crate::Literal::U32(known_index);
2248 Ok(self.writer.get_constant_scalar(scalar))
2249 }
2250 BoundsCheckResult::Computed(computed_index_id) => Ok(computed_index_id),
2251 BoundsCheckResult::Conditional {
2252 condition_id: condition,
2253 index_id: index,
2254 } => {
2255 self.extend_bounds_check_condition_chain(accumulated_checks, condition, block);
2256
2257 Ok(index)
2259 }
2260 }
2261 }
2262
2263 fn extend_bounds_check_condition_chain(
2282 &mut self,
2283 chain: &mut Option<Word>,
2284 comparison_id: Word,
2285 block: &mut Block,
2286 ) {
2287 match *chain {
2288 Some(ref mut prior_checks) => {
2289 let combined = self.gen_id();
2290 block.body.push(Instruction::binary(
2291 spirv::Op::LogicalAnd,
2292 self.writer.get_bool_type_id(),
2293 combined,
2294 *prior_checks,
2295 comparison_id,
2296 ));
2297 *prior_checks = combined;
2298 }
2299 None => {
2300 *chain = Some(comparison_id);
2302 }
2303 }
2304 }
2305
2306 fn write_checked_load(
2307 &mut self,
2308 pointer: Handle<crate::Expression>,
2309 block: &mut Block,
2310 access_type_adjustment: AccessTypeAdjustment,
2311 result_type_id: Word,
2312 ) -> Result<Word, Error> {
2313 match self.write_access_chain(pointer, block, access_type_adjustment)? {
2314 ExpressionPointer::Ready { pointer_id } => {
2315 let id = self.gen_id();
2316 let atomic_space =
2317 match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
2318 crate::TypeInner::Pointer { base, space } => {
2319 match self.ir_module.types[base].inner {
2320 crate::TypeInner::Atomic { .. } => Some(space),
2321 _ => None,
2322 }
2323 }
2324 _ => None,
2325 };
2326 let instruction = if let Some(space) = atomic_space {
2327 let (semantics, scope) = space.to_spirv_semantics_and_scope();
2328 let scope_constant_id = self.get_scope_constant(scope as u32);
2329 let semantics_id = self.get_index_constant(semantics.bits());
2330 Instruction::atomic_load(
2331 result_type_id,
2332 id,
2333 pointer_id,
2334 scope_constant_id,
2335 semantics_id,
2336 )
2337 } else {
2338 Instruction::load(result_type_id, id, pointer_id, None)
2339 };
2340 block.body.push(instruction);
2341 Ok(id)
2342 }
2343 ExpressionPointer::Conditional { condition, access } => {
2344 let value = self.write_conditional_indexed_load(
2346 result_type_id,
2347 condition,
2348 block,
2349 move |id_gen, block| {
2350 let pointer_id = access.result_id.unwrap();
2352 let value_id = id_gen.next();
2353 block.body.push(access);
2354 block.body.push(Instruction::load(
2355 result_type_id,
2356 value_id,
2357 pointer_id,
2358 None,
2359 ));
2360 value_id
2361 },
2362 );
2363 Ok(value)
2364 }
2365 }
2366 }
2367
2368 fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) {
2369 use indexmap::map::Entry;
2370
2371 let spill_variable_id = match self.function.spilled_composites.entry(base) {
2373 Entry::Occupied(preexisting) => preexisting.get().id,
2374 Entry::Vacant(vacant) => {
2375 let pointer_type_id = self.writer.get_resolution_pointer_id(
2378 &self.fun_info[base].ty,
2379 spirv::StorageClass::Function,
2380 );
2381 let id = self.writer.id_gen.next();
2382 vacant.insert(super::LocalVariable {
2383 id,
2384 instruction: Instruction::variable(
2385 pointer_type_id,
2386 id,
2387 spirv::StorageClass::Function,
2388 None,
2389 ),
2390 });
2391 id
2392 }
2393 };
2394
2395 let base_id = self.cached[base];
2420 block
2421 .body
2422 .push(Instruction::store(spill_variable_id, base_id, None));
2423 }
2424
2425 fn maybe_access_spilled_composite(
2442 &mut self,
2443 access: Handle<crate::Expression>,
2444 block: &mut Block,
2445 result_type_id: Word,
2446 ) -> Result<Word, Error> {
2447 let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r);
2448 if access_uses == self.fun_info[access].ref_count {
2449 Ok(0)
2453 } else {
2454 self.write_checked_load(
2459 access,
2460 block,
2461 AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::Function),
2462 result_type_id,
2463 )
2464 }
2465 }
2466
2467 #[allow(clippy::too_many_arguments)]
2469 fn write_matrix_matrix_column_op(
2470 &mut self,
2471 block: &mut Block,
2472 result_id: Word,
2473 result_type_id: Word,
2474 left_id: Word,
2475 right_id: Word,
2476 columns: crate::VectorSize,
2477 rows: crate::VectorSize,
2478 width: u8,
2479 op: spirv::Op,
2480 ) {
2481 self.temp_list.clear();
2482
2483 let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2484 size: rows,
2485 scalar: crate::Scalar::float(width),
2486 });
2487
2488 for index in 0..columns as u32 {
2489 let column_id_left = self.gen_id();
2490 let column_id_right = self.gen_id();
2491 let column_id_res = self.gen_id();
2492
2493 block.body.push(Instruction::composite_extract(
2494 vector_type_id,
2495 column_id_left,
2496 left_id,
2497 &[index],
2498 ));
2499 block.body.push(Instruction::composite_extract(
2500 vector_type_id,
2501 column_id_right,
2502 right_id,
2503 &[index],
2504 ));
2505 block.body.push(Instruction::binary(
2506 op,
2507 vector_type_id,
2508 column_id_res,
2509 column_id_left,
2510 column_id_right,
2511 ));
2512
2513 self.temp_list.push(column_id_res);
2514 }
2515
2516 block.body.push(Instruction::composite_construct(
2517 result_type_id,
2518 result_id,
2519 &self.temp_list,
2520 ));
2521 }
2522
2523 fn write_vector_scalar_mult(
2525 &mut self,
2526 block: &mut Block,
2527 result_id: Word,
2528 result_type_id: Word,
2529 vector_id: Word,
2530 scalar_id: Word,
2531 vector: &crate::TypeInner,
2532 ) {
2533 let (size, kind) = match *vector {
2534 crate::TypeInner::Vector {
2535 size,
2536 scalar: crate::Scalar { kind, .. },
2537 } => (size, kind),
2538 _ => unreachable!(),
2539 };
2540
2541 let (op, operand_id) = match kind {
2542 crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
2543 _ => {
2544 let operand_id = self.gen_id();
2545 self.temp_list.clear();
2546 self.temp_list.resize(size as usize, scalar_id);
2547 block.body.push(Instruction::composite_construct(
2548 result_type_id,
2549 operand_id,
2550 &self.temp_list,
2551 ));
2552 (spirv::Op::IMul, operand_id)
2553 }
2554 };
2555
2556 block.body.push(Instruction::binary(
2557 op,
2558 result_type_id,
2559 result_id,
2560 vector_id,
2561 operand_id,
2562 ));
2563 }
2564
2565 #[expect(clippy::too_many_arguments)]
2572 fn write_dot_product(
2573 &mut self,
2574 result_id: Word,
2575 result_type_id: Word,
2576 arg0_id: Word,
2577 arg1_id: Word,
2578 size: u32,
2579 block: &mut Block,
2580 extractor: impl Fn(Word, Word, Word) -> Instruction,
2581 ) {
2582 let mut partial_sum = self.writer.get_constant_null(result_type_id);
2583 let last_component = size - 1;
2584 for index in 0..=last_component {
2585 let a_id = self.gen_id();
2587 block.body.push(extractor(a_id, arg0_id, index));
2588 let b_id = self.gen_id();
2589 block.body.push(extractor(b_id, arg1_id, index));
2590 let prod_id = self.gen_id();
2591 block.body.push(Instruction::binary(
2592 spirv::Op::IMul,
2593 result_type_id,
2594 prod_id,
2595 a_id,
2596 b_id,
2597 ));
2598
2599 let id = if index == last_component {
2601 result_id
2602 } else {
2603 self.gen_id()
2604 };
2605
2606 block.body.push(Instruction::binary(
2608 spirv::Op::IAdd,
2609 result_type_id,
2610 id,
2611 partial_sum,
2612 prod_id,
2613 ));
2614 partial_sum = id;
2616 }
2617 }
2618
2619 fn write_pack4x8_optimized(
2621 &mut self,
2622 block: &mut Block,
2623 result_type_id: u32,
2624 arg0_id: u32,
2625 id: u32,
2626 is_signed: bool,
2627 should_clamp: bool,
2628 ) -> Instruction {
2629 let int_type = if is_signed {
2630 crate::ScalarKind::Sint
2631 } else {
2632 crate::ScalarKind::Uint
2633 };
2634 let wide_vector_type = NumericType::Vector {
2635 size: crate::VectorSize::Quad,
2636 scalar: crate::Scalar {
2637 kind: int_type,
2638 width: 4,
2639 },
2640 };
2641 let wide_vector_type_id = self.get_numeric_type_id(wide_vector_type);
2642 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2643 size: crate::VectorSize::Quad,
2644 scalar: crate::Scalar {
2645 kind: crate::ScalarKind::Uint,
2646 width: 1,
2647 },
2648 });
2649
2650 let mut wide_vector = arg0_id;
2651 if should_clamp {
2652 let (min, max, clamp_op) = if is_signed {
2653 (
2654 crate::Literal::I32(-128),
2655 crate::Literal::I32(127),
2656 spirv::GLOp::SClamp,
2657 )
2658 } else {
2659 (
2660 crate::Literal::U32(0),
2661 crate::Literal::U32(255),
2662 spirv::GLOp::UClamp,
2663 )
2664 };
2665 let [min, max] = [min, max].map(|lit| {
2666 let scalar = self.writer.get_constant_scalar(lit);
2667 self.writer.get_constant_composite(
2668 LookupType::Local(LocalType::Numeric(wide_vector_type)),
2669 &[scalar; 4],
2670 )
2671 });
2672
2673 let clamp_id = self.gen_id();
2674 block.body.push(Instruction::ext_inst(
2675 self.writer.gl450_ext_inst_id,
2676 clamp_op,
2677 wide_vector_type_id,
2678 clamp_id,
2679 &[wide_vector, min, max],
2680 ));
2681
2682 wide_vector = clamp_id;
2683 }
2684
2685 let packed_vector = self.gen_id();
2686 block.body.push(Instruction::unary(
2687 spirv::Op::UConvert, packed_vector_type_id,
2689 packed_vector,
2690 wide_vector,
2691 ));
2692
2693 Instruction::unary(spirv::Op::Bitcast, result_type_id, id, packed_vector)
2698 }
2699
2700 fn write_pack4x8_polyfill(
2702 &mut self,
2703 block: &mut Block,
2704 result_type_id: u32,
2705 arg0_id: u32,
2706 id: u32,
2707 is_signed: bool,
2708 should_clamp: bool,
2709 ) -> Instruction {
2710 let int_type = if is_signed {
2711 crate::ScalarKind::Sint
2712 } else {
2713 crate::ScalarKind::Uint
2714 };
2715 let uint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
2716 let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
2717 kind: int_type,
2718 width: 4,
2719 }));
2720
2721 let mut last_instruction = Instruction::new(spirv::Op::Nop);
2722
2723 let zero = self.writer.get_constant_scalar(crate::Literal::U32(0));
2724 let mut preresult = zero;
2725 block
2726 .body
2727 .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed)));
2728
2729 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
2730 const VEC_LENGTH: u8 = 4;
2731 for i in 0..u32::from(VEC_LENGTH) {
2732 let offset = self.writer.get_constant_scalar(crate::Literal::U32(i * 8));
2733 let mut extracted = self.gen_id();
2734 block.body.push(Instruction::binary(
2735 spirv::Op::CompositeExtract,
2736 int_type_id,
2737 extracted,
2738 arg0_id,
2739 i,
2740 ));
2741 if is_signed {
2742 let casted = self.gen_id();
2743 block.body.push(Instruction::unary(
2744 spirv::Op::Bitcast,
2745 uint_type_id,
2746 casted,
2747 extracted,
2748 ));
2749 extracted = casted;
2750 }
2751 if should_clamp {
2752 let (min, max, clamp_op) = if is_signed {
2753 (
2754 crate::Literal::I32(-128),
2755 crate::Literal::I32(127),
2756 spirv::GLOp::SClamp,
2757 )
2758 } else {
2759 (
2760 crate::Literal::U32(0),
2761 crate::Literal::U32(255),
2762 spirv::GLOp::UClamp,
2763 )
2764 };
2765 let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit));
2766
2767 let clamp_id = self.gen_id();
2768 block.body.push(Instruction::ext_inst(
2769 self.writer.gl450_ext_inst_id,
2770 clamp_op,
2771 result_type_id,
2772 clamp_id,
2773 &[extracted, min, max],
2774 ));
2775
2776 extracted = clamp_id;
2777 }
2778 let is_last = i == u32::from(VEC_LENGTH - 1);
2779 if is_last {
2780 last_instruction = Instruction::quaternary(
2781 spirv::Op::BitFieldInsert,
2782 result_type_id,
2783 id,
2784 preresult,
2785 extracted,
2786 offset,
2787 eight,
2788 )
2789 } else {
2790 let new_preresult = self.gen_id();
2791 block.body.push(Instruction::quaternary(
2792 spirv::Op::BitFieldInsert,
2793 result_type_id,
2794 new_preresult,
2795 preresult,
2796 extracted,
2797 offset,
2798 eight,
2799 ));
2800 preresult = new_preresult;
2801 }
2802 }
2803 last_instruction
2804 }
2805
2806 fn write_unpack4x8_optimized(
2808 &mut self,
2809 block: &mut Block,
2810 result_type_id: u32,
2811 arg0_id: u32,
2812 id: u32,
2813 is_signed: bool,
2814 ) -> Instruction {
2815 let (int_type, convert_op) = if is_signed {
2816 (crate::ScalarKind::Sint, spirv::Op::SConvert)
2817 } else {
2818 (crate::ScalarKind::Uint, spirv::Op::UConvert)
2819 };
2820
2821 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2822 size: crate::VectorSize::Quad,
2823 scalar: crate::Scalar {
2824 kind: int_type,
2825 width: 1,
2826 },
2827 });
2828
2829 let packed_vector = self.gen_id();
2834 block.body.push(Instruction::unary(
2835 spirv::Op::Bitcast,
2836 packed_vector_type_id,
2837 packed_vector,
2838 arg0_id,
2839 ));
2840
2841 Instruction::unary(convert_op, result_type_id, id, packed_vector)
2842 }
2843
2844 fn write_unpack4x8_polyfill(
2846 &mut self,
2847 block: &mut Block,
2848 result_type_id: u32,
2849 arg0_id: u32,
2850 id: u32,
2851 is_signed: bool,
2852 ) -> Instruction {
2853 let (int_type, extract_op) = if is_signed {
2854 (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract)
2855 } else {
2856 (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract)
2857 };
2858
2859 let sint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));
2860
2861 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
2862 let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
2863 kind: int_type,
2864 width: 4,
2865 }));
2866 block
2867 .body
2868 .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed));
2869 let arg_id = if is_signed {
2870 let new_arg_id = self.gen_id();
2871 block.body.push(Instruction::unary(
2872 spirv::Op::Bitcast,
2873 sint_type_id,
2874 new_arg_id,
2875 arg0_id,
2876 ));
2877 new_arg_id
2878 } else {
2879 arg0_id
2880 };
2881
2882 const VEC_LENGTH: u8 = 4;
2883 let parts: [_; VEC_LENGTH as usize] = core::array::from_fn(|_| self.gen_id());
2884 for (i, part_id) in parts.into_iter().enumerate() {
2885 let index = self
2886 .writer
2887 .get_constant_scalar(crate::Literal::U32(i as u32 * 8));
2888 block.body.push(Instruction::ternary(
2889 extract_op,
2890 int_type_id,
2891 part_id,
2892 arg_id,
2893 index,
2894 eight,
2895 ));
2896 }
2897
2898 Instruction::composite_construct(result_type_id, id, &parts)
2899 }
2900
2901 fn write_block(
2918 &mut self,
2919 label_id: Word,
2920 naga_block: &crate::Block,
2921 exit: BlockExit,
2922 loop_context: LoopContext,
2923 debug_info: Option<&DebugInfoInner>,
2924 ) -> Result<BlockExitDisposition, Error> {
2925 let mut block = Block::new(label_id);
2926 for (statement, span) in naga_block.span_iter() {
2927 if let (Some(debug_info), false) = (
2928 debug_info,
2929 matches!(
2930 statement,
2931 &(Statement::Block(..)
2932 | Statement::Break
2933 | Statement::Continue
2934 | Statement::Kill
2935 | Statement::Return { .. }
2936 | Statement::Loop { .. })
2937 ),
2938 ) {
2939 let loc: crate::SourceLocation = span.location(debug_info.source_code);
2940 block.body.push(Instruction::line(
2941 debug_info.source_file_id,
2942 loc.line_number,
2943 loc.line_position,
2944 ));
2945 };
2946 match *statement {
2947 Statement::Emit(ref range) => {
2948 for handle in range.clone() {
2949 if !self.expression_constness.is_const(handle) {
2951 self.cache_expression_value(handle, &mut block)?;
2952 }
2953 }
2954 }
2955 Statement::Block(ref block_statements) => {
2956 let scope_id = self.gen_id();
2957 self.function.consume(block, Instruction::branch(scope_id));
2958
2959 let merge_id = self.gen_id();
2960 let merge_used = self.write_block(
2961 scope_id,
2962 block_statements,
2963 BlockExit::Branch { target: merge_id },
2964 loop_context,
2965 debug_info,
2966 )?;
2967
2968 match merge_used {
2969 BlockExitDisposition::Used => {
2970 block = Block::new(merge_id);
2971 }
2972 BlockExitDisposition::Discarded => {
2973 return Ok(BlockExitDisposition::Discarded);
2974 }
2975 }
2976 }
2977 Statement::If {
2978 condition,
2979 ref accept,
2980 ref reject,
2981 } => {
2982 let condition_id = self.cached[condition];
2983
2984 let merge_id = self.gen_id();
2985 block.body.push(Instruction::selection_merge(
2986 merge_id,
2987 spirv::SelectionControl::NONE,
2988 ));
2989
2990 let accept_id = if accept.is_empty() {
2991 None
2992 } else {
2993 Some(self.gen_id())
2994 };
2995 let reject_id = if reject.is_empty() {
2996 None
2997 } else {
2998 Some(self.gen_id())
2999 };
3000
3001 self.function.consume(
3002 block,
3003 Instruction::branch_conditional(
3004 condition_id,
3005 accept_id.unwrap_or(merge_id),
3006 reject_id.unwrap_or(merge_id),
3007 ),
3008 );
3009
3010 if let Some(block_id) = accept_id {
3011 let _ = self.write_block(
3016 block_id,
3017 accept,
3018 BlockExit::Branch { target: merge_id },
3019 loop_context,
3020 debug_info,
3021 )?;
3022 }
3023 if let Some(block_id) = reject_id {
3024 let _ = self.write_block(
3029 block_id,
3030 reject,
3031 BlockExit::Branch { target: merge_id },
3032 loop_context,
3033 debug_info,
3034 )?;
3035 }
3036
3037 block = Block::new(merge_id);
3038 }
3039 Statement::Switch {
3040 selector,
3041 ref cases,
3042 } => {
3043 let selector_id = self.cached[selector];
3044
3045 let merge_id = self.gen_id();
3046 block.body.push(Instruction::selection_merge(
3047 merge_id,
3048 spirv::SelectionControl::NONE,
3049 ));
3050
3051 let mut default_id = None;
3052 let mut last_id = None;
3054
3055 let mut raw_cases = Vec::with_capacity(cases.len());
3056 let mut case_ids = Vec::with_capacity(cases.len());
3057 for case in cases.iter() {
3058 let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
3060
3061 if case.fall_through && case.body.is_empty() {
3062 last_id = Some(label_id);
3063 }
3064
3065 case_ids.push(label_id);
3066
3067 match case.value {
3068 crate::SwitchValue::I32(value) => {
3069 raw_cases.push(super::instructions::Case {
3070 value: value as Word,
3071 label_id,
3072 });
3073 }
3074 crate::SwitchValue::U32(value) => {
3075 raw_cases.push(super::instructions::Case { value, label_id });
3076 }
3077 crate::SwitchValue::Default => {
3078 default_id = Some(label_id);
3079 }
3080 }
3081 }
3082
3083 let default_id = default_id.unwrap();
3084
3085 self.function.consume(
3086 block,
3087 Instruction::switch(selector_id, default_id, &raw_cases),
3088 );
3089
3090 let inner_context = LoopContext {
3091 break_id: Some(merge_id),
3092 ..loop_context
3093 };
3094
3095 for (i, (case, label_id)) in cases
3096 .iter()
3097 .zip(case_ids.iter())
3098 .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
3099 .enumerate()
3100 {
3101 let case_finish_id = if case.fall_through {
3102 case_ids[i + 1]
3103 } else {
3104 merge_id
3105 };
3106 let _ = self.write_block(
3115 *label_id,
3116 &case.body,
3117 BlockExit::Branch {
3118 target: case_finish_id,
3119 },
3120 inner_context,
3121 debug_info,
3122 )?;
3123 }
3124
3125 block = Block::new(merge_id);
3126 }
3127 Statement::Loop {
3128 ref body,
3129 ref continuing,
3130 break_if,
3131 } => {
3132 let preamble_id = self.gen_id();
3133 self.function
3134 .consume(block, Instruction::branch(preamble_id));
3135
3136 let merge_id = self.gen_id();
3137 let body_id = self.gen_id();
3138 let continuing_id = self.gen_id();
3139
3140 block = Block::new(preamble_id);
3143 if let Some(debug_info) = debug_info {
3146 let loc: crate::SourceLocation = span.location(debug_info.source_code);
3147 block.body.push(Instruction::line(
3148 debug_info.source_file_id,
3149 loc.line_number,
3150 loc.line_position,
3151 ))
3152 }
3153 block.body.push(Instruction::loop_merge(
3154 merge_id,
3155 continuing_id,
3156 spirv::SelectionControl::NONE,
3157 ));
3158
3159 if self.force_loop_bounding {
3160 block = self.write_force_bounded_loop_instructions(block, merge_id);
3161 }
3162 self.function.consume(block, Instruction::branch(body_id));
3163
3164 let _ = self.write_block(
3168 body_id,
3169 body,
3170 BlockExit::Branch {
3171 target: continuing_id,
3172 },
3173 LoopContext {
3174 continuing_id: Some(continuing_id),
3175 break_id: Some(merge_id),
3176 },
3177 debug_info,
3178 )?;
3179
3180 let exit = match break_if {
3181 Some(condition) => BlockExit::BreakIf {
3182 condition,
3183 preamble_id,
3184 },
3185 None => BlockExit::Branch {
3186 target: preamble_id,
3187 },
3188 };
3189
3190 let _ = self.write_block(
3194 continuing_id,
3195 continuing,
3196 exit,
3197 LoopContext {
3198 continuing_id: None,
3199 break_id: Some(merge_id),
3200 },
3201 debug_info,
3202 )?;
3203
3204 block = Block::new(merge_id);
3205 }
3206 Statement::Break => {
3207 self.function
3208 .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
3209 return Ok(BlockExitDisposition::Discarded);
3210 }
3211 Statement::Continue => {
3212 self.function.consume(
3213 block,
3214 Instruction::branch(loop_context.continuing_id.unwrap()),
3215 );
3216 return Ok(BlockExitDisposition::Discarded);
3217 }
3218 Statement::Return { value: Some(value) } => {
3219 let value_id = self.cached[value];
3220 let instruction = match self.function.entry_point_context {
3221 Some(ref context) => {
3224 self.writer.write_entry_point_return(
3225 value_id,
3226 self.ir_function.result.as_ref().unwrap(),
3227 &context.results,
3228 &mut block.body,
3229 )?;
3230 Instruction::return_void()
3231 }
3232 None => Instruction::return_value(value_id),
3233 };
3234 self.function.consume(block, instruction);
3235 return Ok(BlockExitDisposition::Discarded);
3236 }
3237 Statement::Return { value: None } => {
3238 self.function.consume(block, Instruction::return_void());
3239 return Ok(BlockExitDisposition::Discarded);
3240 }
3241 Statement::Kill => {
3242 self.function.consume(block, Instruction::kill());
3243 return Ok(BlockExitDisposition::Discarded);
3244 }
3245 Statement::ControlBarrier(flags) => {
3246 self.writer.write_control_barrier(flags, &mut block);
3247 }
3248 Statement::MemoryBarrier(flags) => {
3249 self.writer.write_memory_barrier(flags, &mut block);
3250 }
3251 Statement::Store { pointer, value } => {
3252 let value_id = self.cached[value];
3253 match self.write_access_chain(
3254 pointer,
3255 &mut block,
3256 AccessTypeAdjustment::None,
3257 )? {
3258 ExpressionPointer::Ready { pointer_id } => {
3259 let atomic_space = match *self.fun_info[pointer]
3260 .ty
3261 .inner_with(&self.ir_module.types)
3262 {
3263 crate::TypeInner::Pointer { base, space } => {
3264 match self.ir_module.types[base].inner {
3265 crate::TypeInner::Atomic { .. } => Some(space),
3266 _ => None,
3267 }
3268 }
3269 _ => None,
3270 };
3271 let instruction = if let Some(space) = atomic_space {
3272 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3273 let scope_constant_id = self.get_scope_constant(scope as u32);
3274 let semantics_id = self.get_index_constant(semantics.bits());
3275 Instruction::atomic_store(
3276 pointer_id,
3277 scope_constant_id,
3278 semantics_id,
3279 value_id,
3280 )
3281 } else {
3282 Instruction::store(pointer_id, value_id, None)
3283 };
3284 block.body.push(instruction);
3285 }
3286 ExpressionPointer::Conditional { condition, access } => {
3287 let mut selection = Selection::start(&mut block, ());
3288 selection.if_true(self, condition, ());
3289
3290 let pointer_id = access.result_id.unwrap();
3292 selection.block().body.push(access);
3293 selection
3294 .block()
3295 .body
3296 .push(Instruction::store(pointer_id, value_id, None));
3297
3298 selection.finish(self, ());
3301 }
3302 };
3303 }
3304 Statement::ImageStore {
3305 image,
3306 coordinate,
3307 array_index,
3308 value,
3309 } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
3310 Statement::Call {
3311 function: local_function,
3312 ref arguments,
3313 result,
3314 } => {
3315 let id = self.gen_id();
3316 self.temp_list.clear();
3317 for &argument in arguments {
3318 self.temp_list.push(self.cached[argument]);
3319 }
3320
3321 let type_id = match result {
3322 Some(expr) => {
3323 self.cached[expr] = id;
3324 self.get_expression_type_id(&self.fun_info[expr].ty)
3325 }
3326 None => self.writer.void_type,
3327 };
3328
3329 block.body.push(Instruction::function_call(
3330 type_id,
3331 id,
3332 self.writer.lookup_function[&local_function],
3333 &self.temp_list,
3334 ));
3335 }
3336 Statement::Atomic {
3337 pointer,
3338 ref fun,
3339 value,
3340 result,
3341 } => {
3342 let id = self.gen_id();
3343 let result_type_id =
3347 self.get_expression_type_id(&self.fun_info[result.unwrap_or(value)].ty);
3348
3349 if let Some(result) = result {
3350 self.cached[result] = id;
3351 }
3352
3353 let pointer_id = match self.write_access_chain(
3354 pointer,
3355 &mut block,
3356 AccessTypeAdjustment::None,
3357 )? {
3358 ExpressionPointer::Ready { pointer_id } => pointer_id,
3359 ExpressionPointer::Conditional { .. } => {
3360 return Err(Error::FeatureNotImplemented(
3361 "Atomics out-of-bounds handling",
3362 ));
3363 }
3364 };
3365
3366 let space = self.fun_info[pointer]
3367 .ty
3368 .inner_with(&self.ir_module.types)
3369 .pointer_space()
3370 .unwrap();
3371 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3372 let scope_constant_id = self.get_scope_constant(scope as u32);
3373 let semantics_id = self.get_index_constant(semantics.bits());
3374 let value_id = self.cached[value];
3375 let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
3376
3377 let crate::TypeInner::Scalar(scalar) = *value_inner else {
3378 return Err(Error::FeatureNotImplemented(
3379 "Atomics with non-scalar values",
3380 ));
3381 };
3382
3383 let instruction = match *fun {
3384 crate::AtomicFunction::Add => {
3385 let spirv_op = match scalar.kind {
3386 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3387 spirv::Op::AtomicIAdd
3388 }
3389 crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT,
3390 _ => unimplemented!(),
3391 };
3392 Instruction::atomic_binary(
3393 spirv_op,
3394 result_type_id,
3395 id,
3396 pointer_id,
3397 scope_constant_id,
3398 semantics_id,
3399 value_id,
3400 )
3401 }
3402 crate::AtomicFunction::Subtract => {
3403 let (spirv_op, value_id) = match scalar.kind {
3404 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3405 (spirv::Op::AtomicISub, value_id)
3406 }
3407 crate::ScalarKind::Float => {
3408 let neg_result_id = self.gen_id();
3411 block.body.push(Instruction::unary(
3412 spirv::Op::FNegate,
3413 result_type_id,
3414 neg_result_id,
3415 value_id,
3416 ));
3417 (spirv::Op::AtomicFAddEXT, neg_result_id)
3418 }
3419 _ => unimplemented!(),
3420 };
3421 Instruction::atomic_binary(
3422 spirv_op,
3423 result_type_id,
3424 id,
3425 pointer_id,
3426 scope_constant_id,
3427 semantics_id,
3428 value_id,
3429 )
3430 }
3431 crate::AtomicFunction::And => {
3432 let spirv_op = match scalar.kind {
3433 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3434 spirv::Op::AtomicAnd
3435 }
3436 _ => unimplemented!(),
3437 };
3438 Instruction::atomic_binary(
3439 spirv_op,
3440 result_type_id,
3441 id,
3442 pointer_id,
3443 scope_constant_id,
3444 semantics_id,
3445 value_id,
3446 )
3447 }
3448 crate::AtomicFunction::InclusiveOr => {
3449 let spirv_op = match scalar.kind {
3450 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3451 spirv::Op::AtomicOr
3452 }
3453 _ => unimplemented!(),
3454 };
3455 Instruction::atomic_binary(
3456 spirv_op,
3457 result_type_id,
3458 id,
3459 pointer_id,
3460 scope_constant_id,
3461 semantics_id,
3462 value_id,
3463 )
3464 }
3465 crate::AtomicFunction::ExclusiveOr => {
3466 let spirv_op = match scalar.kind {
3467 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3468 spirv::Op::AtomicXor
3469 }
3470 _ => unimplemented!(),
3471 };
3472 Instruction::atomic_binary(
3473 spirv_op,
3474 result_type_id,
3475 id,
3476 pointer_id,
3477 scope_constant_id,
3478 semantics_id,
3479 value_id,
3480 )
3481 }
3482 crate::AtomicFunction::Min => {
3483 let spirv_op = match scalar.kind {
3484 crate::ScalarKind::Sint => spirv::Op::AtomicSMin,
3485 crate::ScalarKind::Uint => spirv::Op::AtomicUMin,
3486 _ => unimplemented!(),
3487 };
3488 Instruction::atomic_binary(
3489 spirv_op,
3490 result_type_id,
3491 id,
3492 pointer_id,
3493 scope_constant_id,
3494 semantics_id,
3495 value_id,
3496 )
3497 }
3498 crate::AtomicFunction::Max => {
3499 let spirv_op = match scalar.kind {
3500 crate::ScalarKind::Sint => spirv::Op::AtomicSMax,
3501 crate::ScalarKind::Uint => spirv::Op::AtomicUMax,
3502 _ => unimplemented!(),
3503 };
3504 Instruction::atomic_binary(
3505 spirv_op,
3506 result_type_id,
3507 id,
3508 pointer_id,
3509 scope_constant_id,
3510 semantics_id,
3511 value_id,
3512 )
3513 }
3514 crate::AtomicFunction::Exchange { compare: None } => {
3515 Instruction::atomic_binary(
3516 spirv::Op::AtomicExchange,
3517 result_type_id,
3518 id,
3519 pointer_id,
3520 scope_constant_id,
3521 semantics_id,
3522 value_id,
3523 )
3524 }
3525 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
3526 let scalar_type_id =
3527 self.get_numeric_type_id(NumericType::Scalar(scalar));
3528 let bool_type_id =
3529 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL));
3530
3531 let cas_result_id = self.gen_id();
3532 let equality_result_id = self.gen_id();
3533 let equality_operator = match scalar.kind {
3534 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3535 spirv::Op::IEqual
3536 }
3537 _ => unimplemented!(),
3538 };
3539 let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
3540 cas_instr.set_type(scalar_type_id);
3541 cas_instr.set_result(cas_result_id);
3542 cas_instr.add_operand(pointer_id);
3543 cas_instr.add_operand(scope_constant_id);
3544 cas_instr.add_operand(semantics_id); cas_instr.add_operand(semantics_id); cas_instr.add_operand(value_id);
3547 cas_instr.add_operand(self.cached[cmp]);
3548 block.body.push(cas_instr);
3549 block.body.push(Instruction::binary(
3550 equality_operator,
3551 bool_type_id,
3552 equality_result_id,
3553 cas_result_id,
3554 self.cached[cmp],
3555 ));
3556 Instruction::composite_construct(
3557 result_type_id,
3558 id,
3559 &[cas_result_id, equality_result_id],
3560 )
3561 }
3562 };
3563
3564 block.body.push(instruction);
3565 }
3566 Statement::ImageAtomic {
3567 image,
3568 coordinate,
3569 array_index,
3570 fun,
3571 value,
3572 } => {
3573 self.write_image_atomic(
3574 image,
3575 coordinate,
3576 array_index,
3577 fun,
3578 value,
3579 &mut block,
3580 )?;
3581 }
3582 Statement::WorkGroupUniformLoad { pointer, result } => {
3583 self.writer
3584 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block);
3585 let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
3586 match self.write_access_chain(
3588 pointer,
3589 &mut block,
3590 AccessTypeAdjustment::None,
3591 )? {
3592 ExpressionPointer::Ready { pointer_id } => {
3593 let id = self.gen_id();
3594 block.body.push(Instruction::load(
3595 result_type_id,
3596 id,
3597 pointer_id,
3598 None,
3599 ));
3600 self.cached[result] = id;
3601 }
3602 ExpressionPointer::Conditional { condition, access } => {
3603 self.cached[result] = self.write_conditional_indexed_load(
3604 result_type_id,
3605 condition,
3606 &mut block,
3607 move |id_gen, block| {
3608 let pointer_id = access.result_id.unwrap();
3610 let value_id = id_gen.next();
3611 block.body.push(access);
3612 block.body.push(Instruction::load(
3613 result_type_id,
3614 value_id,
3615 pointer_id,
3616 None,
3617 ));
3618 value_id
3619 },
3620 )
3621 }
3622 }
3623 self.writer
3624 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block);
3625 }
3626 Statement::RayQuery { query, ref fun } => {
3627 self.write_ray_query_function(query, fun, &mut block);
3628 }
3629 Statement::SubgroupBallot {
3630 result,
3631 ref predicate,
3632 } => {
3633 self.write_subgroup_ballot(predicate, result, &mut block)?;
3634 }
3635 Statement::SubgroupCollectiveOperation {
3636 ref op,
3637 ref collective_op,
3638 argument,
3639 result,
3640 } => {
3641 self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
3642 }
3643 Statement::SubgroupGather {
3644 ref mode,
3645 argument,
3646 result,
3647 } => {
3648 self.write_subgroup_gather(mode, argument, result, &mut block)?;
3649 }
3650 }
3651 }
3652
3653 let termination = match exit {
3654 BlockExit::Return => match self.ir_function.result {
3657 Some(ref result) if self.function.entry_point_context.is_none() => {
3658 let type_id = self.get_handle_type_id(result.ty);
3659 let null_id = self.writer.get_constant_null(type_id);
3660 Instruction::return_value(null_id)
3661 }
3662 _ => Instruction::return_void(),
3663 },
3664 BlockExit::Branch { target } => Instruction::branch(target),
3665 BlockExit::BreakIf {
3666 condition,
3667 preamble_id,
3668 } => {
3669 let condition_id = self.cached[condition];
3670
3671 Instruction::branch_conditional(
3672 condition_id,
3673 loop_context.break_id.unwrap(),
3674 preamble_id,
3675 )
3676 }
3677 };
3678
3679 self.function.consume(block, termination);
3680 Ok(BlockExitDisposition::Used)
3681 }
3682
3683 pub(super) fn write_function_body(
3684 &mut self,
3685 entry_id: Word,
3686 debug_info: Option<&DebugInfoInner>,
3687 ) -> Result<(), Error> {
3688 let _ = self.write_block(
3691 entry_id,
3692 &self.ir_function.body,
3693 BlockExit::Return,
3694 LoopContext::default(),
3695 debug_info,
3696 )?;
3697
3698 Ok(())
3699 }
3700}