1use alloc::vec::Vec;
6
7use arrayvec::ArrayVec;
8use spirv::Word;
9
10use super::{
11 index::BoundsCheckResult, selection::Selection, Block, BlockContext, Dimension, Error,
12 Instruction, LocalType, LookupType, NumericType, ResultMember, WrappedFunction, Writer,
13 WriterFlags,
14};
15use crate::{arena::Handle, proc::index::GuardedIndex, Statement};
16
17fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
18 match *type_inner {
19 crate::TypeInner::Scalar(_) => Dimension::Scalar,
20 crate::TypeInner::Vector { .. } => Dimension::Vector,
21 crate::TypeInner::Matrix { .. } => Dimension::Matrix,
22 _ => unreachable!(),
23 }
24}
25
26enum AccessTypeAdjustment {
38 None,
47
48 IntroducePointer(spirv::StorageClass),
72}
73
74enum ExpressionPointer {
78 Ready { pointer_id: Word },
81
82 Conditional {
88 condition: Word,
89 access: Instruction,
90 },
91}
92
93enum BlockExit {
95 Return,
97 Branch {
99 target: Word,
101 },
102 BreakIf {
108 condition: Handle<crate::Expression>,
110 preamble_id: Word,
112 },
113}
114
/// Reports whether a requested `BlockExit` was actually emitted.
///
/// Marked `#[must_use]` so callers cannot silently drop this result.
#[must_use]
enum BlockExitDisposition {
    /// The requested exit was used to terminate the block.
    Used,

    /// The requested exit was not emitted. Presumably the block was already
    /// terminated by its own statements (TODO confirm at the producer).
    Discarded,
}
138
139#[derive(Clone, Copy, Default)]
140struct LoopContext {
141 continuing_id: Option<Word>,
142 break_id: Option<Word>,
143}
144
145#[derive(Debug)]
146pub(crate) struct DebugInfoInner<'a> {
147 pub source_code: &'a str,
148 pub source_file_id: Word,
149}
150
151impl Writer {
152 fn write_epilogue_position_y_flip(
157 &mut self,
158 position_id: Word,
159 body: &mut Vec<Instruction>,
160 ) -> Result<(), Error> {
161 let float_ptr_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Output);
162 let index_y_id = self.get_index_constant(1);
163 let access_id = self.id_gen.next();
164 body.push(Instruction::access_chain(
165 float_ptr_type_id,
166 access_id,
167 position_id,
168 &[index_y_id],
169 ));
170
171 let float_type_id = self.get_f32_type_id();
172 let load_id = self.id_gen.next();
173 body.push(Instruction::load(float_type_id, load_id, access_id, None));
174
175 let neg_id = self.id_gen.next();
176 body.push(Instruction::unary(
177 spirv::Op::FNegate,
178 float_type_id,
179 neg_id,
180 load_id,
181 ));
182
183 body.push(Instruction::store(access_id, neg_id, None));
184 Ok(())
185 }
186
187 fn write_epilogue_frag_depth_clamp(
189 &mut self,
190 frag_depth_id: Word,
191 body: &mut Vec<Instruction>,
192 ) -> Result<(), Error> {
193 let float_type_id = self.get_f32_type_id();
194 let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
195 let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));
196
197 let original_id = self.id_gen.next();
198 body.push(Instruction::load(
199 float_type_id,
200 original_id,
201 frag_depth_id,
202 None,
203 ));
204
205 let clamp_id = self.id_gen.next();
206 body.push(Instruction::ext_inst_gl_op(
207 self.gl450_ext_inst_id,
208 spirv::GLOp::FClamp,
209 float_type_id,
210 clamp_id,
211 &[original_id, zero_scalar_id, one_scalar_id],
212 ));
213
214 body.push(Instruction::store(frag_depth_id, clamp_id, None));
215 Ok(())
216 }
217
218 fn write_entry_point_return(
219 &mut self,
220 value_id: Word,
221 ir_result: &crate::FunctionResult,
222 result_members: &[ResultMember],
223 body: &mut Vec<Instruction>,
224 ) -> Result<(), Error> {
225 for (index, res_member) in result_members.iter().enumerate() {
226 let member_value_id = match ir_result.binding {
227 Some(_) => value_id,
228 None => {
229 let member_value_id = self.id_gen.next();
230 body.push(Instruction::composite_extract(
231 res_member.type_id,
232 member_value_id,
233 value_id,
234 &[index as u32],
235 ));
236 member_value_id
237 }
238 };
239
240 self.store_io_with_f16_polyfill(body, res_member.id, member_value_id);
241
242 match res_member.built_in {
243 Some(crate::BuiltIn::Position { .. })
244 if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
245 {
246 self.write_epilogue_position_y_flip(res_member.id, body)?;
247 }
248 Some(crate::BuiltIn::FragDepth)
249 if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
250 {
251 self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
252 }
253 _ => {}
254 }
255 }
256 Ok(())
257 }
258}
259
260impl BlockContext<'_> {
261 fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
274 let uint_type_id = self.writer.get_u32_type_id();
275 let uint2_type_id = self.writer.get_vec2u_type_id();
276 let uint2_ptr_type_id = self
277 .writer
278 .get_vec2u_pointer_type_id(spirv::StorageClass::Function);
279 let bool_type_id = self.writer.get_bool_type_id();
280 let bool2_type_id = self.writer.get_vec2_bool_type_id();
281 let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
282 let zero_uint2_const_id = self.writer.get_constant_composite(
283 LookupType::Local(LocalType::Numeric(NumericType::Vector {
284 size: crate::VectorSize::Bi,
285 scalar: crate::Scalar::U32,
286 })),
287 &[zero_uint_const_id, zero_uint_const_id],
288 );
289 let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
290 let max_uint_const_id = self
291 .writer
292 .get_constant_scalar(crate::Literal::U32(u32::MAX));
293 let max_uint2_const_id = self.writer.get_constant_composite(
294 LookupType::Local(LocalType::Numeric(NumericType::Vector {
295 size: crate::VectorSize::Bi,
296 scalar: crate::Scalar::U32,
297 })),
298 &[max_uint_const_id, max_uint_const_id],
299 );
300
301 let loop_counter_var_id = self.gen_id();
302 if self.writer.flags.contains(WriterFlags::DEBUG) {
303 self.writer
304 .debugs
305 .push(Instruction::name(loop_counter_var_id, "loop_bound"));
306 }
307 let var = super::LocalVariable {
308 id: loop_counter_var_id,
309 instruction: Instruction::variable(
310 uint2_ptr_type_id,
311 loop_counter_var_id,
312 spirv::StorageClass::Function,
313 Some(max_uint2_const_id),
314 ),
315 };
316 self.function.force_loop_bounding_vars.push(var);
317
318 let break_if_block = self.gen_id();
319
320 self.function
321 .consume(block, Instruction::branch(break_if_block));
322 block = Block::new(break_if_block);
323
324 let load_id = self.gen_id();
327 block.body.push(Instruction::load(
328 uint2_type_id,
329 load_id,
330 loop_counter_var_id,
331 None,
332 ));
333
334 let eq_id = self.gen_id();
337 block.body.push(Instruction::binary(
338 spirv::Op::IEqual,
339 bool2_type_id,
340 eq_id,
341 zero_uint2_const_id,
342 load_id,
343 ));
344 let all_eq_id = self.gen_id();
345 block.body.push(Instruction::relational(
346 spirv::Op::All,
347 bool_type_id,
348 all_eq_id,
349 eq_id,
350 ));
351
352 let inc_counter_block_id = self.gen_id();
353 block.body.push(Instruction::selection_merge(
354 inc_counter_block_id,
355 spirv::SelectionControl::empty(),
356 ));
357 self.function.consume(
358 block,
359 Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
360 );
361 block = Block::new(inc_counter_block_id);
362
363 let low_id = self.gen_id();
369 block.body.push(Instruction::composite_extract(
370 uint_type_id,
371 low_id,
372 load_id,
373 &[1],
374 ));
375 let low_overflow_id = self.gen_id();
376 block.body.push(Instruction::binary(
377 spirv::Op::IEqual,
378 bool_type_id,
379 low_overflow_id,
380 low_id,
381 zero_uint_const_id,
382 ));
383 let carry_bit_id = self.gen_id();
384 block.body.push(Instruction::select(
385 uint_type_id,
386 carry_bit_id,
387 low_overflow_id,
388 one_uint_const_id,
389 zero_uint_const_id,
390 ));
391 let decrement_id = self.gen_id();
392 block.body.push(Instruction::composite_construct(
393 uint2_type_id,
394 decrement_id,
395 &[carry_bit_id, one_uint_const_id],
396 ));
397 let result_id = self.gen_id();
398 block.body.push(Instruction::binary(
399 spirv::Op::ISub,
400 uint2_type_id,
401 result_id,
402 load_id,
403 decrement_id,
404 ));
405 block
406 .body
407 .push(Instruction::store(loop_counter_var_id, result_id, None));
408
409 block
410 }
411
412 pub(super) fn cache_expression_value(
414 &mut self,
415 expr_handle: Handle<crate::Expression>,
416 block: &mut Block,
417 ) -> Result<(), Error> {
418 let is_named_expression = self
419 .ir_function
420 .named_expressions
421 .contains_key(&expr_handle);
422
423 if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression {
424 return Ok(());
425 }
426
427 let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
428 let id = match self.ir_function.expressions[expr_handle] {
429 crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
430 crate::Expression::Constant(handle) => {
431 let init = self.ir_module.constants[handle].init;
432 self.writer.constant_ids[init]
433 }
434 crate::Expression::Override(_) => return Err(Error::Override),
435 crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
436 crate::Expression::Compose { ty, ref components } => {
437 self.temp_list.clear();
438 if self.expression_constness.is_const(expr_handle) {
439 self.temp_list.extend(
440 crate::proc::flatten_compose(
441 ty,
442 components,
443 &self.ir_function.expressions,
444 &self.ir_module.types,
445 )
446 .map(|component| self.cached[component]),
447 );
448 self.writer
449 .get_constant_composite(LookupType::Handle(ty), &self.temp_list)
450 } else {
451 self.temp_list
452 .extend(components.iter().map(|&component| self.cached[component]));
453
454 let id = self.gen_id();
455 block.body.push(Instruction::composite_construct(
456 result_type_id,
457 id,
458 &self.temp_list,
459 ));
460 id
461 }
462 }
463 crate::Expression::Splat { size, value } => {
464 let value_id = self.cached[value];
465 let components = &[value_id; 4][..size as usize];
466
467 if self.expression_constness.is_const(expr_handle) {
468 let ty = self
469 .writer
470 .get_expression_lookup_type(&self.fun_info[expr_handle].ty);
471 self.writer.get_constant_composite(ty, components)
472 } else {
473 let id = self.gen_id();
474 block.body.push(Instruction::composite_construct(
475 result_type_id,
476 id,
477 components,
478 ));
479 id
480 }
481 }
482 crate::Expression::Access { base, index } => {
483 let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
484 match *base_ty_inner {
485 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
486 0
492 }
493 _ if self.function.spilled_accesses.contains(base) => {
494 self.function.spilled_accesses.insert(expr_handle);
502 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
503 }
504 crate::TypeInner::Vector { .. } => {
505 self.write_vector_access(expr_handle, base, index, block)?
506 }
507 crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => {
508 match GuardedIndex::from_expression(
510 index,
511 &self.ir_function.expressions,
512 self.ir_module,
513 ) {
514 GuardedIndex::Known(value) => {
515 let id = self.gen_id();
525 let base_id = self.cached[base];
526 block.body.push(Instruction::composite_extract(
527 result_type_id,
528 id,
529 base_id,
530 &[value],
531 ));
532 id
533 }
534 GuardedIndex::Expression(_) => {
535 self.spill_to_internal_variable(base, block);
542
543 self.function.spilled_accesses.insert(expr_handle);
546 self.maybe_access_spilled_composite(
547 expr_handle,
548 block,
549 result_type_id,
550 )?
551 }
552 }
553 }
554 crate::TypeInner::BindingArray {
555 base: binding_type, ..
556 } => {
557 let result_id = match self.write_access_chain(
560 expr_handle,
561 block,
562 AccessTypeAdjustment::IntroducePointer(
563 spirv::StorageClass::UniformConstant,
564 ),
565 )? {
566 ExpressionPointer::Ready { pointer_id } => pointer_id,
567 ExpressionPointer::Conditional { .. } => {
568 return Err(Error::FeatureNotImplemented(
569 "Texture array out-of-bounds handling",
570 ));
571 }
572 };
573
574 let binding_type_id = self.get_handle_type_id(binding_type);
575
576 let load_id = self.gen_id();
577 block.body.push(Instruction::load(
578 binding_type_id,
579 load_id,
580 result_id,
581 None,
582 ));
583
584 if self.fun_info[index].uniformity.non_uniform_result.is_some() {
588 self.writer
589 .decorate_non_uniform_binding_array_access(load_id)?;
590 }
591
592 load_id
593 }
594 ref other => {
595 log::error!(
596 "Unable to access base {:?} of type {:?}",
597 self.ir_function.expressions[base],
598 other
599 );
600 return Err(Error::Validation(
601 "only vectors and arrays may be dynamically indexed by value",
602 ));
603 }
604 }
605 }
606 crate::Expression::AccessIndex { base, index } => {
607 match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
608 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
609 0
615 }
616 _ if self.function.spilled_accesses.contains(base) => {
617 self.function.spilled_accesses.insert(expr_handle);
625 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
626 }
627 crate::TypeInner::Vector { .. }
628 | crate::TypeInner::Matrix { .. }
629 | crate::TypeInner::Array { .. }
630 | crate::TypeInner::Struct { .. } => {
631 let id = self.gen_id();
636 let base_id = self.cached[base];
637 block.body.push(Instruction::composite_extract(
638 result_type_id,
639 id,
640 base_id,
641 &[index],
642 ));
643 id
644 }
645 crate::TypeInner::BindingArray {
646 base: binding_type, ..
647 } => {
648 let result_id = match self.write_access_chain(
651 expr_handle,
652 block,
653 AccessTypeAdjustment::IntroducePointer(
654 spirv::StorageClass::UniformConstant,
655 ),
656 )? {
657 ExpressionPointer::Ready { pointer_id } => pointer_id,
658 ExpressionPointer::Conditional { .. } => {
659 return Err(Error::FeatureNotImplemented(
660 "Texture array out-of-bounds handling",
661 ));
662 }
663 };
664
665 let binding_type_id = self.get_handle_type_id(binding_type);
666
667 let load_id = self.gen_id();
668 block.body.push(Instruction::load(
669 binding_type_id,
670 load_id,
671 result_id,
672 None,
673 ));
674
675 load_id
676 }
677 ref other => {
678 log::error!("Unable to access index of {other:?}");
679 return Err(Error::FeatureNotImplemented("access index for type"));
680 }
681 }
682 }
683 crate::Expression::GlobalVariable(handle) => {
684 self.writer.global_variables[handle].access_id
685 }
686 crate::Expression::Swizzle {
687 size,
688 vector,
689 pattern,
690 } => {
691 let vector_id = self.cached[vector];
692 self.temp_list.clear();
693 for &sc in pattern[..size as usize].iter() {
694 self.temp_list.push(sc as Word);
695 }
696 let id = self.gen_id();
697 block.body.push(Instruction::vector_shuffle(
698 result_type_id,
699 id,
700 vector_id,
701 vector_id,
702 &self.temp_list,
703 ));
704 id
705 }
706 crate::Expression::Unary { op, expr } => {
707 let id = self.gen_id();
708 let expr_id = self.cached[expr];
709 let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
710
711 let spirv_op = match op {
712 crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
713 Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
714 Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
715 _ => return Err(Error::Validation("Unexpected kind for negation")),
716 },
717 crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot,
718 crate::UnaryOperator::BitwiseNot => spirv::Op::Not,
719 };
720
721 block
722 .body
723 .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
724 id
725 }
726 crate::Expression::Binary { op, left, right } => {
727 let id = self.gen_id();
728 let left_id = self.cached[left];
729 let right_id = self.cached[right];
730 let left_type_id = self.get_expression_type_id(&self.fun_info[left].ty);
731 let right_type_id = self.get_expression_type_id(&self.fun_info[right].ty);
732
733 if let Some(function_id) =
734 self.writer
735 .wrapped_functions
736 .get(&WrappedFunction::BinaryOp {
737 op,
738 left_type_id,
739 right_type_id,
740 })
741 {
742 block.body.push(Instruction::function_call(
743 result_type_id,
744 id,
745 *function_id,
746 &[left_id, right_id],
747 ));
748 } else {
749 let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
750 let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
751
752 let left_dimension = get_dimension(left_ty_inner);
753 let right_dimension = get_dimension(right_ty_inner);
754
755 let mut reverse_operands = false;
756
757 let spirv_op = match op {
758 crate::BinaryOperator::Add => match *left_ty_inner {
759 crate::TypeInner::Scalar(scalar)
760 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
761 crate::ScalarKind::Float => spirv::Op::FAdd,
762 _ => spirv::Op::IAdd,
763 },
764 crate::TypeInner::Matrix {
765 columns,
766 rows,
767 scalar,
768 } => {
769 self.write_matrix_matrix_column_op(
770 block,
771 id,
772 result_type_id,
773 left_id,
774 right_id,
775 columns,
776 rows,
777 scalar.width,
778 spirv::Op::FAdd,
779 );
780
781 self.cached[expr_handle] = id;
782 return Ok(());
783 }
784 _ => unimplemented!(),
785 },
786 crate::BinaryOperator::Subtract => match *left_ty_inner {
787 crate::TypeInner::Scalar(scalar)
788 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
789 crate::ScalarKind::Float => spirv::Op::FSub,
790 _ => spirv::Op::ISub,
791 },
792 crate::TypeInner::Matrix {
793 columns,
794 rows,
795 scalar,
796 } => {
797 self.write_matrix_matrix_column_op(
798 block,
799 id,
800 result_type_id,
801 left_id,
802 right_id,
803 columns,
804 rows,
805 scalar.width,
806 spirv::Op::FSub,
807 );
808
809 self.cached[expr_handle] = id;
810 return Ok(());
811 }
812 _ => unimplemented!(),
813 },
814 crate::BinaryOperator::Multiply => {
815 match (left_dimension, right_dimension) {
816 (Dimension::Scalar, Dimension::Vector) => {
817 self.write_vector_scalar_mult(
818 block,
819 id,
820 result_type_id,
821 right_id,
822 left_id,
823 right_ty_inner,
824 );
825
826 self.cached[expr_handle] = id;
827 return Ok(());
828 }
829 (Dimension::Vector, Dimension::Scalar) => {
830 self.write_vector_scalar_mult(
831 block,
832 id,
833 result_type_id,
834 left_id,
835 right_id,
836 left_ty_inner,
837 );
838
839 self.cached[expr_handle] = id;
840 return Ok(());
841 }
842 (Dimension::Vector, Dimension::Matrix) => {
843 spirv::Op::VectorTimesMatrix
844 }
845 (Dimension::Matrix, Dimension::Scalar) => {
846 spirv::Op::MatrixTimesScalar
847 }
848 (Dimension::Scalar, Dimension::Matrix) => {
849 reverse_operands = true;
850 spirv::Op::MatrixTimesScalar
851 }
852 (Dimension::Matrix, Dimension::Vector) => {
853 spirv::Op::MatrixTimesVector
854 }
855 (Dimension::Matrix, Dimension::Matrix) => {
856 spirv::Op::MatrixTimesMatrix
857 }
858 (Dimension::Vector, Dimension::Vector)
859 | (Dimension::Scalar, Dimension::Scalar)
860 if left_ty_inner.scalar_kind()
861 == Some(crate::ScalarKind::Float) =>
862 {
863 spirv::Op::FMul
864 }
865 (Dimension::Vector, Dimension::Vector)
866 | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
867 }
868 }
869 crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
870 Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
871 Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
872 Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
873 _ => unimplemented!(),
874 },
875 crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
876 Some(crate::ScalarKind::Float) => spirv::Op::FRem,
879 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
880 unreachable!("Should have been handled by wrapped function")
881 }
882 _ => unimplemented!(),
883 },
884 crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
885 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
886 spirv::Op::IEqual
887 }
888 Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
889 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
890 _ => unimplemented!(),
891 },
892 crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
893 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
894 spirv::Op::INotEqual
895 }
896 Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
897 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
898 _ => unimplemented!(),
899 },
900 crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
901 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
902 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
903 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
904 _ => unimplemented!(),
905 },
906 crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
907 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
908 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
909 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
910 _ => unimplemented!(),
911 },
912 crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
913 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
914 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
915 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
916 _ => unimplemented!(),
917 },
918 crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
919 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
920 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
921 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
922 _ => unimplemented!(),
923 },
924 crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
925 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
926 _ => spirv::Op::BitwiseAnd,
927 },
928 crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
929 crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
930 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
931 _ => spirv::Op::BitwiseOr,
932 },
933 crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
934 crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
935 crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
936 crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
937 Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
938 Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
939 _ => unimplemented!(),
940 },
941 };
942
943 block.body.push(Instruction::binary(
944 spirv_op,
945 result_type_id,
946 id,
947 if reverse_operands { right_id } else { left_id },
948 if reverse_operands { left_id } else { right_id },
949 ));
950 }
951 id
952 }
953 crate::Expression::Math {
954 fun,
955 arg,
956 arg1,
957 arg2,
958 arg3,
959 } => {
960 use crate::MathFunction as Mf;
961 enum MathOp {
962 Ext(spirv::GLOp),
963 Custom(Instruction),
964 }
965
966 let arg0_id = self.cached[arg];
967 let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
968 let arg_scalar_kind = arg_ty.scalar_kind();
969 let arg1_id = match arg1 {
970 Some(handle) => self.cached[handle],
971 None => 0,
972 };
973 let arg2_id = match arg2 {
974 Some(handle) => self.cached[handle],
975 None => 0,
976 };
977 let arg3_id = match arg3 {
978 Some(handle) => self.cached[handle],
979 None => 0,
980 };
981
982 let id = self.gen_id();
983 let math_op = match fun {
984 Mf::Abs => {
986 match arg_scalar_kind {
987 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs),
988 Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs),
989 Some(crate::ScalarKind::Uint) => {
990 MathOp::Custom(Instruction::unary(
991 spirv::Op::CopyObject, result_type_id,
993 id,
994 arg0_id,
995 ))
996 }
997 other => unimplemented!("Unexpected abs({:?})", other),
998 }
999 }
1000 Mf::Min => MathOp::Ext(match arg_scalar_kind {
1001 Some(crate::ScalarKind::Float) => spirv::GLOp::FMin,
1002 Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin,
1003 Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin,
1004 other => unimplemented!("Unexpected min({:?})", other),
1005 }),
1006 Mf::Max => MathOp::Ext(match arg_scalar_kind {
1007 Some(crate::ScalarKind::Float) => spirv::GLOp::FMax,
1008 Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax,
1009 Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax,
1010 other => unimplemented!("Unexpected max({:?})", other),
1011 }),
1012 Mf::Clamp => match arg_scalar_kind {
1013 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FClamp),
1017 Some(_) => {
1018 let (min_op, max_op) = match arg_scalar_kind {
1019 Some(crate::ScalarKind::Sint) => {
1020 (spirv::GLOp::SMin, spirv::GLOp::SMax)
1021 }
1022 Some(crate::ScalarKind::Uint) => {
1023 (spirv::GLOp::UMin, spirv::GLOp::UMax)
1024 }
1025 _ => unreachable!(),
1026 };
1027
1028 let max_id = self.gen_id();
1029 block.body.push(Instruction::ext_inst_gl_op(
1030 self.writer.gl450_ext_inst_id,
1031 max_op,
1032 result_type_id,
1033 max_id,
1034 &[arg0_id, arg1_id],
1035 ));
1036
1037 MathOp::Custom(Instruction::ext_inst_gl_op(
1038 self.writer.gl450_ext_inst_id,
1039 min_op,
1040 result_type_id,
1041 id,
1042 &[max_id, arg2_id],
1043 ))
1044 }
1045 other => unimplemented!("Unexpected max({:?})", other),
1046 },
1047 Mf::Saturate => {
1048 let (maybe_size, scalar) = match *arg_ty {
1049 crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
1050 crate::TypeInner::Scalar(scalar) => (None, scalar),
1051 ref other => unimplemented!("Unexpected saturate({:?})", other),
1052 };
1053 let scalar = crate::Scalar::float(scalar.width);
1054 let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
1055 let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
1056
1057 if let Some(size) = maybe_size {
1058 let ty =
1059 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1060
1061 self.temp_list.clear();
1062 self.temp_list.resize(size as _, arg1_id);
1063
1064 arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
1065
1066 self.temp_list.fill(arg2_id);
1067
1068 arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
1069 }
1070
1071 MathOp::Custom(Instruction::ext_inst_gl_op(
1072 self.writer.gl450_ext_inst_id,
1073 spirv::GLOp::FClamp,
1074 result_type_id,
1075 id,
1076 &[arg0_id, arg1_id, arg2_id],
1077 ))
1078 }
1079 Mf::Sin => MathOp::Ext(spirv::GLOp::Sin),
1081 Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh),
1082 Mf::Asin => MathOp::Ext(spirv::GLOp::Asin),
1083 Mf::Cos => MathOp::Ext(spirv::GLOp::Cos),
1084 Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh),
1085 Mf::Acos => MathOp::Ext(spirv::GLOp::Acos),
1086 Mf::Tan => MathOp::Ext(spirv::GLOp::Tan),
1087 Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh),
1088 Mf::Atan => MathOp::Ext(spirv::GLOp::Atan),
1089 Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2),
1090 Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh),
1091 Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh),
1092 Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh),
1093 Mf::Radians => MathOp::Ext(spirv::GLOp::Radians),
1094 Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees),
1095 Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil),
1097 Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven),
1098 Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
1099 Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
1100 Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
1101 Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct),
1102 Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct),
1103 Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
1104 Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
1106 crate::TypeInner::Vector {
1107 scalar:
1108 crate::Scalar {
1109 kind: crate::ScalarKind::Float,
1110 ..
1111 },
1112 ..
1113 } => MathOp::Custom(Instruction::binary(
1114 spirv::Op::Dot,
1115 result_type_id,
1116 id,
1117 arg0_id,
1118 arg1_id,
1119 )),
1120 crate::TypeInner::Vector { size, .. } => {
1122 self.write_dot_product(
1123 id,
1124 result_type_id,
1125 arg0_id,
1126 arg1_id,
1127 size as u32,
1128 block,
1129 |result_id, composite_id, index| {
1130 Instruction::composite_extract(
1131 result_type_id,
1132 result_id,
1133 composite_id,
1134 &[index],
1135 )
1136 },
1137 );
1138 self.cached[expr_handle] = id;
1139 return Ok(());
1140 }
1141 _ => unreachable!(
1142 "Correct TypeInner for dot product should be already validated"
1143 ),
1144 },
1145 fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
1146 if self
1147 .writer
1148 .require_all(&[
1149 spirv::Capability::DotProduct,
1150 spirv::Capability::DotProductInput4x8BitPacked,
1151 ])
1152 .is_ok()
1153 {
1154 if self.writer.lang_version() < (1, 6) {
1156 self.writer.use_extension("SPV_KHR_integer_dot_product");
1160 }
1161
1162 let op = match fun {
1163 Mf::Dot4I8Packed => spirv::Op::SDot,
1164 Mf::Dot4U8Packed => spirv::Op::UDot,
1165 _ => unreachable!(),
1166 };
1167
1168 block.body.push(Instruction::ternary(
1169 op,
1170 result_type_id,
1171 id,
1172 arg0_id,
1173 arg1_id,
1174 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word,
1175 ));
1176 } else {
1177 let (extract_op, arg0_id, arg1_id) = match fun {
1179 Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
1180 Mf::Dot4I8Packed => {
1181 let new_arg0_id = self.gen_id();
1184 block.body.push(Instruction::unary(
1185 spirv::Op::Bitcast,
1186 result_type_id,
1187 new_arg0_id,
1188 arg0_id,
1189 ));
1190
1191 let new_arg1_id = self.gen_id();
1192 block.body.push(Instruction::unary(
1193 spirv::Op::Bitcast,
1194 result_type_id,
1195 new_arg1_id,
1196 arg1_id,
1197 ));
1198
1199 (spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
1200 }
1201 _ => unreachable!(),
1202 };
1203
1204 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1205
1206 const VEC_LENGTH: u8 = 4;
1207 let bit_shifts: [_; VEC_LENGTH as usize] =
1208 core::array::from_fn(|index| {
1209 self.writer
1210 .get_constant_scalar(crate::Literal::U32(index as u32 * 8))
1211 });
1212
1213 self.write_dot_product(
1214 id,
1215 result_type_id,
1216 arg0_id,
1217 arg1_id,
1218 VEC_LENGTH as Word,
1219 block,
1220 |result_id, composite_id, index| {
1221 Instruction::ternary(
1222 extract_op,
1223 result_type_id,
1224 result_id,
1225 composite_id,
1226 bit_shifts[index as usize],
1227 eight,
1228 )
1229 },
1230 );
1231 }
1232
1233 self.cached[expr_handle] = id;
1234 return Ok(());
1235 }
1236 Mf::Outer => MathOp::Custom(Instruction::binary(
1237 spirv::Op::OuterProduct,
1238 result_type_id,
1239 id,
1240 arg0_id,
1241 arg1_id,
1242 )),
1243 Mf::Cross => MathOp::Ext(spirv::GLOp::Cross),
1244 Mf::Distance => MathOp::Ext(spirv::GLOp::Distance),
1245 Mf::Length => MathOp::Ext(spirv::GLOp::Length),
1246 Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize),
1247 Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward),
1248 Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect),
1249 Mf::Refract => MathOp::Ext(spirv::GLOp::Refract),
1250 Mf::Exp => MathOp::Ext(spirv::GLOp::Exp),
1252 Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2),
1253 Mf::Log => MathOp::Ext(spirv::GLOp::Log),
1254 Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2),
1255 Mf::Pow => MathOp::Ext(spirv::GLOp::Pow),
1256 Mf::Sign => MathOp::Ext(match arg_scalar_kind {
1258 Some(crate::ScalarKind::Float) => spirv::GLOp::FSign,
1259 Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign,
1260 other => unimplemented!("Unexpected sign({:?})", other),
1261 }),
1262 Mf::Fma => MathOp::Ext(spirv::GLOp::Fma),
1263 Mf::Mix => {
1264 let selector = arg2.unwrap();
1265 let selector_ty =
1266 self.fun_info[selector].ty.inner_with(&self.ir_module.types);
1267 match (arg_ty, selector_ty) {
1268 (
1270 &crate::TypeInner::Vector { size, .. },
1271 &crate::TypeInner::Scalar(scalar),
1272 ) => {
1273 let selector_type_id =
1274 self.get_numeric_type_id(NumericType::Vector { size, scalar });
1275 self.temp_list.clear();
1276 self.temp_list.resize(size as usize, arg2_id);
1277
1278 let selector_id = self.gen_id();
1279 block.body.push(Instruction::composite_construct(
1280 selector_type_id,
1281 selector_id,
1282 &self.temp_list,
1283 ));
1284
1285 MathOp::Custom(Instruction::ext_inst_gl_op(
1286 self.writer.gl450_ext_inst_id,
1287 spirv::GLOp::FMix,
1288 result_type_id,
1289 id,
1290 &[arg0_id, arg1_id, selector_id],
1291 ))
1292 }
1293 _ => MathOp::Ext(spirv::GLOp::FMix),
1294 }
1295 }
1296 Mf::Step => MathOp::Ext(spirv::GLOp::Step),
1297 Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep),
1298 Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt),
1299 Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt),
1300 Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse),
1301 Mf::Transpose => MathOp::Custom(Instruction::unary(
1302 spirv::Op::Transpose,
1303 result_type_id,
1304 id,
1305 arg0_id,
1306 )),
1307 Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
1308 Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
1309 spirv::Op::QuantizeToF16,
1310 result_type_id,
1311 id,
1312 arg0_id,
1313 )),
1314 Mf::ReverseBits => MathOp::Custom(Instruction::unary(
1315 spirv::Op::BitReverse,
1316 result_type_id,
1317 id,
1318 arg0_id,
1319 )),
1320 Mf::CountTrailingZeros => {
1321 let uint_id = match *arg_ty {
1322 crate::TypeInner::Vector { size, scalar } => {
1323 let ty =
1324 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1325
1326 self.temp_list.clear();
1327 self.temp_list.resize(
1328 size as _,
1329 self.writer
1330 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1331 );
1332
1333 self.writer.get_constant_composite(ty, &self.temp_list)
1334 }
1335 crate::TypeInner::Scalar(scalar) => self
1336 .writer
1337 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1338 _ => unreachable!(),
1339 };
1340
1341 let lsb_id = self.gen_id();
1342 block.body.push(Instruction::ext_inst_gl_op(
1343 self.writer.gl450_ext_inst_id,
1344 spirv::GLOp::FindILsb,
1345 result_type_id,
1346 lsb_id,
1347 &[arg0_id],
1348 ));
1349
1350 MathOp::Custom(Instruction::ext_inst_gl_op(
1351 self.writer.gl450_ext_inst_id,
1352 spirv::GLOp::UMin,
1353 result_type_id,
1354 id,
1355 &[uint_id, lsb_id],
1356 ))
1357 }
1358 Mf::CountLeadingZeros => {
1359 let (int_type_id, int_id, width) = match *arg_ty {
1360 crate::TypeInner::Vector { size, scalar } => {
1361 let ty =
1362 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1363
1364 self.temp_list.clear();
1365 self.temp_list.resize(
1366 size as _,
1367 self.writer
1368 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1369 );
1370
1371 (
1372 self.get_type_id(ty),
1373 self.writer.get_constant_composite(ty, &self.temp_list),
1374 scalar.width,
1375 )
1376 }
1377 crate::TypeInner::Scalar(scalar) => (
1378 self.get_numeric_type_id(NumericType::Scalar(scalar)),
1379 self.writer
1380 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1381 scalar.width,
1382 ),
1383 _ => unreachable!(),
1384 };
1385
1386 if width != 4 {
1387 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1388 };
1389
1390 let msb_id = self.gen_id();
1391 block.body.push(Instruction::ext_inst_gl_op(
1392 self.writer.gl450_ext_inst_id,
1393 if width != 4 {
1394 spirv::GLOp::FindILsb
1395 } else {
1396 spirv::GLOp::FindUMsb
1397 },
1398 int_type_id,
1399 msb_id,
1400 &[arg0_id],
1401 ));
1402
1403 MathOp::Custom(Instruction::binary(
1404 spirv::Op::ISub,
1405 result_type_id,
1406 id,
1407 int_id,
1408 msb_id,
1409 ))
1410 }
1411 Mf::CountOneBits => MathOp::Custom(Instruction::unary(
1412 spirv::Op::BitCount,
1413 result_type_id,
1414 id,
1415 arg0_id,
1416 )),
1417 Mf::ExtractBits => {
1418 let op = match arg_scalar_kind {
1419 Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
1420 Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
1421 other => unimplemented!("Unexpected sign({:?})", other),
1422 };
1423
1424 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1439 let width_constant = self
1440 .writer
1441 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1442
1443 let u32_type =
1444 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1445
1446 let offset_id = self.gen_id();
1448 block.body.push(Instruction::ext_inst_gl_op(
1449 self.writer.gl450_ext_inst_id,
1450 spirv::GLOp::UMin,
1451 u32_type,
1452 offset_id,
1453 &[arg1_id, width_constant],
1454 ));
1455
1456 let max_count_id = self.gen_id();
1458 block.body.push(Instruction::binary(
1459 spirv::Op::ISub,
1460 u32_type,
1461 max_count_id,
1462 width_constant,
1463 offset_id,
1464 ));
1465
1466 let count_id = self.gen_id();
1468 block.body.push(Instruction::ext_inst_gl_op(
1469 self.writer.gl450_ext_inst_id,
1470 spirv::GLOp::UMin,
1471 u32_type,
1472 count_id,
1473 &[arg2_id, max_count_id],
1474 ));
1475
1476 MathOp::Custom(Instruction::ternary(
1477 op,
1478 result_type_id,
1479 id,
1480 arg0_id,
1481 offset_id,
1482 count_id,
1483 ))
1484 }
1485 Mf::InsertBits => {
1486 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1489 let width_constant = self
1490 .writer
1491 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1492
1493 let u32_type =
1494 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1495
1496 let offset_id = self.gen_id();
1498 block.body.push(Instruction::ext_inst_gl_op(
1499 self.writer.gl450_ext_inst_id,
1500 spirv::GLOp::UMin,
1501 u32_type,
1502 offset_id,
1503 &[arg2_id, width_constant],
1504 ));
1505
1506 let max_count_id = self.gen_id();
1508 block.body.push(Instruction::binary(
1509 spirv::Op::ISub,
1510 u32_type,
1511 max_count_id,
1512 width_constant,
1513 offset_id,
1514 ));
1515
1516 let count_id = self.gen_id();
1518 block.body.push(Instruction::ext_inst_gl_op(
1519 self.writer.gl450_ext_inst_id,
1520 spirv::GLOp::UMin,
1521 u32_type,
1522 count_id,
1523 &[arg3_id, max_count_id],
1524 ));
1525
1526 MathOp::Custom(Instruction::quaternary(
1527 spirv::Op::BitFieldInsert,
1528 result_type_id,
1529 id,
1530 arg0_id,
1531 arg1_id,
1532 offset_id,
1533 count_id,
1534 ))
1535 }
1536 Mf::FirstTrailingBit => MathOp::Ext(spirv::GLOp::FindILsb),
1537 Mf::FirstLeadingBit => {
1538 if arg_ty.scalar_width() == Some(4) {
1539 let thing = match arg_scalar_kind {
1540 Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
1541 Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
1542 other => unimplemented!("Unexpected firstLeadingBit({:?})", other),
1543 };
1544 MathOp::Ext(thing)
1545 } else {
1546 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1547 }
1548 }
1549 Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
1550 Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
1551 Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
1552 Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16),
1553 Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
1554 fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => {
1555 let is_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp);
1556 let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp);
1557
1558 let last_instruction =
1559 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1560 self.write_pack4x8_optimized(
1561 block,
1562 result_type_id,
1563 arg0_id,
1564 id,
1565 is_signed,
1566 should_clamp,
1567 )
1568 } else {
1569 self.write_pack4x8_polyfill(
1570 block,
1571 result_type_id,
1572 arg0_id,
1573 id,
1574 is_signed,
1575 should_clamp,
1576 )
1577 };
1578
1579 MathOp::Custom(last_instruction)
1580 }
1581 Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
1582 Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
1583 Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
1584 Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16),
1585 Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
1586 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
1587 let is_signed = matches!(fun, Mf::Unpack4xI8);
1588
1589 let last_instruction =
1590 if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() {
1591 self.write_unpack4x8_optimized(
1592 block,
1593 result_type_id,
1594 arg0_id,
1595 id,
1596 is_signed,
1597 )
1598 } else {
1599 self.write_unpack4x8_polyfill(
1600 block,
1601 result_type_id,
1602 arg0_id,
1603 id,
1604 is_signed,
1605 )
1606 };
1607
1608 MathOp::Custom(last_instruction)
1609 }
1610 };
1611
1612 block.body.push(match math_op {
1613 MathOp::Ext(op) => Instruction::ext_inst_gl_op(
1614 self.writer.gl450_ext_inst_id,
1615 op,
1616 result_type_id,
1617 id,
1618 &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
1619 ),
1620 MathOp::Custom(inst) => inst,
1621 });
1622 id
1623 }
1624 crate::Expression::LocalVariable(variable) => {
1625 if let Some(rq_tracker) = self
1626 .function
1627 .ray_query_initialization_tracker_variables
1628 .get(&variable)
1629 {
1630 self.ray_query_tracker_expr.insert(
1631 expr_handle,
1632 super::RayQueryTrackers {
1633 initialized_tracker: rq_tracker.id,
1634 t_max_tracker: self
1635 .function
1636 .ray_query_t_max_tracker_variables
1637 .get(&variable)
1638 .expect("Both trackers are set at the same time.")
1639 .id,
1640 },
1641 );
1642 }
1643 self.function.variables[&variable].id
1644 }
1645 crate::Expression::Load { pointer } => {
1646 self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)?
1647 }
1648 crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
1649 crate::Expression::CallResult(_)
1650 | crate::Expression::AtomicResult { .. }
1651 | crate::Expression::WorkGroupUniformLoadResult { .. }
1652 | crate::Expression::RayQueryProceedResult
1653 | crate::Expression::SubgroupBallotResult
1654 | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
1655 crate::Expression::As {
1656 expr,
1657 kind,
1658 convert,
1659 } => self.write_as_expression(expr, convert, kind, block, result_type_id)?,
1660 crate::Expression::ImageLoad {
1661 image,
1662 coordinate,
1663 array_index,
1664 sample,
1665 level,
1666 } => self.write_image_load(
1667 result_type_id,
1668 image,
1669 coordinate,
1670 array_index,
1671 level,
1672 sample,
1673 block,
1674 )?,
1675 crate::Expression::ImageSample {
1676 image,
1677 sampler,
1678 gather,
1679 coordinate,
1680 array_index,
1681 offset,
1682 level,
1683 depth_ref,
1684 clamp_to_edge,
1685 } => self.write_image_sample(
1686 result_type_id,
1687 image,
1688 sampler,
1689 gather,
1690 coordinate,
1691 array_index,
1692 offset,
1693 level,
1694 depth_ref,
1695 clamp_to_edge,
1696 block,
1697 )?,
1698 crate::Expression::Select {
1699 condition,
1700 accept,
1701 reject,
1702 } => {
1703 let id = self.gen_id();
1704 let mut condition_id = self.cached[condition];
1705 let accept_id = self.cached[accept];
1706 let reject_id = self.cached[reject];
1707
1708 let condition_ty = self.fun_info[condition]
1709 .ty
1710 .inner_with(&self.ir_module.types);
1711 let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
1712
1713 if let (
1714 &crate::TypeInner::Scalar(
1715 condition_scalar @ crate::Scalar {
1716 kind: crate::ScalarKind::Bool,
1717 ..
1718 },
1719 ),
1720 &crate::TypeInner::Vector { size, .. },
1721 ) = (condition_ty, object_ty)
1722 {
1723 self.temp_list.clear();
1724 self.temp_list.resize(size as usize, condition_id);
1725
1726 let bool_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
1727 size,
1728 scalar: condition_scalar,
1729 });
1730
1731 let id = self.gen_id();
1732 block.body.push(Instruction::composite_construct(
1733 bool_vector_type_id,
1734 id,
1735 &self.temp_list,
1736 ));
1737 condition_id = id
1738 }
1739
1740 let instruction =
1741 Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
1742 block.body.push(instruction);
1743 id
1744 }
1745 crate::Expression::Derivative { axis, ctrl, expr } => {
1746 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
1747 match ctrl {
1748 Ctrl::Coarse | Ctrl::Fine => {
1749 self.writer.require_any(
1750 "DerivativeControl",
1751 &[spirv::Capability::DerivativeControl],
1752 )?;
1753 }
1754 Ctrl::None => {}
1755 }
1756 let id = self.gen_id();
1757 let expr_id = self.cached[expr];
1758 let op = match (axis, ctrl) {
1759 (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
1760 (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
1761 (Axis::X, Ctrl::None) => spirv::Op::DPdx,
1762 (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
1763 (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
1764 (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
1765 (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
1766 (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
1767 (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
1768 };
1769 block
1770 .body
1771 .push(Instruction::derivative(op, result_type_id, id, expr_id));
1772 id
1773 }
1774 crate::Expression::ImageQuery { image, query } => {
1775 self.write_image_query(result_type_id, image, query, block)?
1776 }
1777 crate::Expression::Relational { fun, argument } => {
1778 use crate::RelationalFunction as Rf;
1779 let arg_id = self.cached[argument];
1780 let op = match fun {
1781 Rf::All => spirv::Op::All,
1782 Rf::Any => spirv::Op::Any,
1783 Rf::IsNan => spirv::Op::IsNan,
1784 Rf::IsInf => spirv::Op::IsInf,
1785 };
1786 let id = self.gen_id();
1787 block
1788 .body
1789 .push(Instruction::relational(op, result_type_id, id, arg_id));
1790 id
1791 }
1792 crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
1793 crate::Expression::RayQueryGetIntersection { query, committed } => {
1794 let query_id = self.cached[query];
1795 let init_tracker_id = *self
1796 .ray_query_tracker_expr
1797 .get(&query)
1798 .expect("not a cached ray query");
1799 let func_id = self
1800 .writer
1801 .write_ray_query_get_intersection_function(committed, self.ir_module);
1802 let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap();
1803 let intersection_type_id = self.get_handle_type_id(ray_intersection);
1804 let id = self.gen_id();
1805 block.body.push(Instruction::function_call(
1806 intersection_type_id,
1807 id,
1808 func_id,
1809 &[query_id, init_tracker_id.initialized_tracker],
1810 ));
1811 id
1812 }
1813 crate::Expression::RayQueryVertexPositions { query, committed } => {
1814 self.writer.require_any(
1815 "RayQueryVertexPositions",
1816 &[spirv::Capability::RayQueryPositionFetchKHR],
1817 )?;
1818 self.write_ray_query_return_vertex_position(query, block, committed)
1819 }
1820 };
1821
1822 self.cached[expr_handle] = id;
1823 Ok(())
1824 }
1825
    /// Emit the instructions for an `Expression::As` cast of `expr` to `kind`
    /// and return the id of the resulting value.
    ///
    /// - `convert`: `None` requests a reinterpretation of the bits (bitcast);
    ///   `Some(width)` requests a value conversion to a `kind` scalar of
    ///   `width` bytes.
    /// - `result_type_id`: SPIR-V type id of the cast's result.
    ///
    /// Casts that change neither the kind nor the width emit nothing and return
    /// `expr`'s own id. For matrices, a cast with `convert == None` or with an
    /// unchanged width is likewise passed through untouched.
    ///
    /// # Errors
    ///
    /// Returns `Error::Validation` when a matrix would be converted to a
    /// non-float kind, or when the source type is not a scalar, vector, or
    /// matrix.
    fn write_as_expression(
        &mut self,
        expr: Handle<crate::Expression>,
        convert: Option<u8>,
        kind: crate::ScalarKind,

        block: &mut Block,
        result_type_id: u32,
    ) -> Result<u32, Error> {
        use crate::ScalarKind as Sk;
        let expr_id = self.cached[expr];
        let ty = self.fun_info[expr].ty.inner_with(&self.ir_module.types);

        // Matrices cannot be converted by a single instruction here; they are
        // converted column by column below.
        if let crate::TypeInner::Matrix {
            columns,
            rows,
            scalar,
        } = *ty
        {
            // No width conversion requested: pass the matrix through as-is.
            let Some(convert) = convert else {
                return Ok(expr_id);
            };

            // Same width as the source: nothing to convert.
            if convert == scalar.width {
                return Ok(expr_id);
            }

            if kind != Sk::Float {
                return Err(Error::Validation("Matrices must be floats"));
            }

            // Column vector type of the source matrix.
            let column_src_ty =
                self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
                    size: rows,
                    scalar,
                })));

            // Column vector type after conversion to the requested width.
            let column_dst_ty =
                self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
                    size: rows,
                    scalar: crate::Scalar {
                        kind,
                        width: convert,
                    },
                })));

            // `columns` is a `VectorSize`; capacity 4 presumably covers the
            // largest size (pushing past it would panic).
            let mut components = ArrayVec::<Word, 4>::new();

            // Extract each column, FConvert it, and collect the converted ids.
            for column in 0..columns as usize {
                let column_id = self.gen_id();
                block.body.push(Instruction::composite_extract(
                    column_src_ty,
                    column_id,
                    expr_id,
                    &[column as u32],
                ));

                let column_conv_id = self.gen_id();
                block.body.push(Instruction::unary(
                    spirv::Op::FConvert,
                    column_dst_ty,
                    column_conv_id,
                    column_id,
                ));

                components.push(column_conv_id);
            }

            // Reassemble the converted columns into the result matrix.
            let construct_id = self.gen_id();

            block.body.push(Instruction::composite_construct(
                result_type_id,
                construct_id,
                &components,
            ));

            return Ok(construct_id);
        }

        // Everything else must be a scalar or a vector; remember the vector size
        // so that splatted constants can be built below.
        let (src_scalar, src_size) = match *ty {
            crate::TypeInner::Scalar(scalar) => (scalar, None),
            crate::TypeInner::Vector { scalar, size } => (scalar, Some(size)),
            ref other => {
                log::error!("As source {other:?}");
                return Err(Error::Validation("Unexpected Expression::As source"));
            }
        };

        // The instruction shape the cast lowers to, with its operand ids.
        enum Cast {
            Identity(Word),
            Unary(spirv::Op, Word),
            Binary(spirv::Op, Word, Word),
            Ternary(spirv::Op, Word, Word, Word),
        }
        // Arm order matters: earlier arms (and their guards) filter out cases
        // that later, more general arms would otherwise also match.
        let cast = match (src_scalar.kind, kind, convert) {
            // Same kind, and either no width given or the same width: no-op.
            (src_kind, kind, convert)
                if src_kind == kind
                    && convert.filter(|&width| width != src_scalar.width).is_none() =>
            {
                Cast::Identity(expr_id)
            }
            (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject, expr_id),
            // No width given and kinds differ: reinterpret the bits.
            (_, _, None) => Cast::Unary(spirv::Op::Bitcast, expr_id),
            // To bool: compare the source against zero (splatted for vectors).
            (_, Sk::Bool, Some(_)) => {
                let op = match src_scalar.kind {
                    Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
                    Sk::Float => spirv::Op::FUnordNotEqual,
                    // Bool sources were taken by the `(Sk::Bool, Sk::Bool, _)`
                    // arm; abstract kinds presumably never reach this backend.
                    Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(),
                };
                let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?;
                let zero_id = match src_size {
                    Some(size) => {
                        let ty = LocalType::Numeric(NumericType::Vector {
                            size,
                            scalar: src_scalar,
                        })
                        .into();

                        self.temp_list.clear();
                        self.temp_list.resize(size as _, zero_scalar_id);

                        self.writer.get_constant_composite(ty, &self.temp_list)
                    }
                    None => zero_scalar_id,
                };

                Cast::Binary(op, expr_id, zero_id)
            }
            // From bool: `OpSelect` between one (true) and zero (false) of the
            // destination type, splatted for vectors.
            (Sk::Bool, _, Some(dst_width)) => {
                let dst_scalar = crate::Scalar {
                    kind,
                    width: dst_width,
                };
                let zero_scalar_id = self.writer.get_constant_scalar_with(0, dst_scalar)?;
                let one_scalar_id = self.writer.get_constant_scalar_with(1, dst_scalar)?;
                let (accept_id, reject_id) = match src_size {
                    Some(size) => {
                        let ty = LocalType::Numeric(NumericType::Vector {
                            size,
                            scalar: dst_scalar,
                        })
                        .into();

                        self.temp_list.clear();
                        self.temp_list.resize(size as _, zero_scalar_id);

                        let vec0_id = self.writer.get_constant_composite(ty, &self.temp_list);

                        self.temp_list.fill(one_scalar_id);

                        let vec1_id = self.writer.get_constant_composite(ty, &self.temp_list);

                        (vec1_id, vec0_id)
                    }
                    None => (one_scalar_id, zero_scalar_id),
                };

                Cast::Ternary(spirv::Op::Select, expr_id, accept_id, reject_id)
            }
            // Float to integer: FClamp the value to the bounds returned by
            // `min_max_float_representable_by`, then convert the clamped value.
            (Sk::Float, Sk::Sint | Sk::Uint, Some(width)) => {
                let dst_scalar = crate::Scalar { kind, width };
                let (min, max) =
                    crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
                let expr_type_id = self.get_expression_type_id(&self.fun_info[expr].ty);

                // Broadcast a scalar constant to the source vector type, if any.
                let maybe_splat_const = |writer: &mut Writer, const_id| match src_size {
                    None => const_id,
                    Some(size) => {
                        let constituent_ids = [const_id; crate::VectorSize::MAX];
                        writer.get_constant_composite(
                            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                                size,
                                scalar: src_scalar,
                            })),
                            &constituent_ids[..size as usize],
                        )
                    }
                };
                let min_const_id = self.writer.get_constant_scalar(min);
                let min_const_id = maybe_splat_const(self.writer, min_const_id);
                let max_const_id = self.writer.get_constant_scalar(max);
                let max_const_id = maybe_splat_const(self.writer, max_const_id);

                let clamp_id = self.gen_id();
                block.body.push(Instruction::ext_inst_gl_op(
                    self.writer.gl450_ext_inst_id,
                    spirv::GLOp::FClamp,
                    expr_type_id,
                    clamp_id,
                    &[expr_id, min_const_id, max_const_id],
                ));

                let op = match dst_scalar.kind {
                    crate::ScalarKind::Sint => spirv::Op::ConvertFToS,
                    crate::ScalarKind::Uint => spirv::Op::ConvertFToU,
                    _ => unreachable!(),
                };
                Cast::Unary(op, clamp_id)
            }
            (Sk::Float, Sk::Float, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::FConvert, expr_id)
            }
            (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF, expr_id),
            (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::SConvert, expr_id)
            }
            (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF, expr_id),
            (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::UConvert, expr_id)
            }
            (Sk::Uint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::SConvert, expr_id)
            }
            (Sk::Sint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::UConvert, expr_id)
            }
            // Remaining cases, e.g. a same-width signed <-> unsigned integer
            // cast, are plain bit reinterpretations.
            _ => Cast::Unary(spirv::Op::Bitcast, expr_id),
        };
        // Materialize the chosen cast as a single instruction (or nothing).
        Ok(match cast {
            Cast::Identity(expr) => expr,
            Cast::Unary(op, op1) => {
                let id = self.gen_id();
                block
                    .body
                    .push(Instruction::unary(op, result_type_id, id, op1));
                id
            }
            Cast::Binary(op, op1, op2) => {
                let id = self.gen_id();
                block
                    .body
                    .push(Instruction::binary(op, result_type_id, id, op1, op2));
                id
            }
            Cast::Ternary(op, op1, op2, op3) => {
                let id = self.gen_id();
                block
                    .body
                    .push(Instruction::ternary(op, result_type_id, id, op1, op2, op3));
                id
            }
        })
    }
2095
    /// Build an `OpAccessChain` for the pointer expression `expr_handle`.
    ///
    /// The method walks from `expr_handle` back through its chain of `Access` /
    /// `AccessIndex` expressions until it reaches a root. A root is a spilled
    /// composite, a global variable, a local variable, or a function argument.
    /// An index id is collected for each step. Non-struct indices go through
    /// `write_access_chain_index`, which applies the bounds-check policy.
    ///
    /// `type_adjustment` selects the result type: either the expression's
    /// resolved type as-is, or a pointer to it in the given storage class.
    ///
    /// Return values:
    /// - `ExpressionPointer::Ready` when there were no indices (the root id is
    ///   the pointer), or when the access chain was pushed to `block`.
    /// - `ExpressionPointer::Conditional` when some bounds check produced a
    ///   runtime condition. The `OpAccessChain` is then returned un-emitted,
    ///   together with the combined condition, for the caller to place.
    ///
    /// Clobbers `self.temp_list`.
    ///
    /// # Panics
    ///
    /// Panics via `unimplemented!` on any other kind of pointer expression.
    fn write_access_chain(
        &mut self,
        mut expr_handle: Handle<crate::Expression>,
        block: &mut Block,
        type_adjustment: AccessTypeAdjustment,
    ) -> Result<ExpressionPointer, Error> {
        let result_type_id = {
            let resolution = &self.fun_info[expr_handle].ty;
            match type_adjustment {
                AccessTypeAdjustment::None => self.writer.get_expression_type_id(resolution),
                AccessTypeAdjustment::IntroducePointer(class) => {
                    self.writer.get_resolution_pointer_id(resolution, class)
                }
            }
        };

        // Id of the `OpLogicalAnd` of every runtime bounds-check condition seen
        // so far (see `extend_bounds_check_condition_chain`), if any.
        let mut accumulated_checks = None;

        // Set if any step indexes a binding array with a non-uniform index.
        let mut is_non_uniform_binding_array = false;

        // Indices are pushed from the outermost access inward; they are
        // reversed below before building the access chain.
        self.temp_list.clear();
        let root_id = loop {
            // A spilled composite's temporary variable holds the composite's
            // value, so the chain can start from that variable.
            if let Some(spilled) = self.function.spilled_composites.get(&expr_handle) {
                break spilled.id;
            }

            expr_handle = match self.ir_function.expressions[expr_handle] {
                crate::Expression::Access { base, index } => {
                    is_non_uniform_binding_array |=
                        self.is_nonuniform_binding_array_access(base, index);

                    let index = GuardedIndex::Expression(index);
                    let index_id =
                        self.write_access_chain_index(base, index, &mut accumulated_checks, block)?;
                    self.temp_list.push(index_id);

                    base
                }
                crate::Expression::AccessIndex { base, index } => {
                    // Look through a pointer to find the type actually indexed.
                    let mut base_ty = self.fun_info[base].ty.inner_with(&self.ir_module.types);
                    if let crate::TypeInner::Pointer { base, .. } = *base_ty {
                        base_ty = &self.ir_module.types[base].inner;
                    }
                    // Struct members are selected by a plain constant and get no
                    // bounds check; other types go through the bounds-check policy.
                    let index_id = if let crate::TypeInner::Struct { .. } = *base_ty {
                        self.get_index_constant(index)
                    } else {
                        self.write_access_chain_index(
                            base,
                            GuardedIndex::Known(index),
                            &mut accumulated_checks,
                            block,
                        )?
                    };

                    self.temp_list.push(index_id);
                    base
                }
                crate::Expression::GlobalVariable(handle) => {
                    let gv = &self.writer.global_variables[handle];
                    break gv.access_id;
                }
                crate::Expression::LocalVariable(variable) => {
                    let local_var = &self.function.variables[&variable];
                    break local_var.id;
                }
                crate::Expression::FunctionArgument(index) => {
                    break self.function.parameter_id(index);
                }
                ref other => unimplemented!("Unexpected pointer expression {:?}", other),
            }
        };

        let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
            // No indexing at all: the root itself is the pointer.
            (
                root_id,
                ExpressionPointer::Ready {
                    pointer_id: root_id,
                },
            )
        } else {
            // Put indices in root-to-leaf order as `OpAccessChain` expects.
            self.temp_list.reverse();
            let pointer_id = self.gen_id();
            let access =
                Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);

            // With a runtime condition, leave emitting the access to the caller;
            // otherwise emit it now.
            let expr_pointer = match accumulated_checks {
                Some(condition) => ExpressionPointer::Conditional { condition, access },
                None => {
                    block.body.push(access);
                    ExpressionPointer::Ready { pointer_id }
                }
            };
            (pointer_id, expr_pointer)
        };
        // Binding-array access with a non-uniform index: decorate the pointer
        // via `decorate_non_uniform_binding_array_access`.
        if is_non_uniform_binding_array {
            self.writer
                .decorate_non_uniform_binding_array_access(pointer_id)?;
        }

        Ok(expr_pointer)
    }
2230
2231 fn is_nonuniform_binding_array_access(
2232 &mut self,
2233 base: Handle<crate::Expression>,
2234 index: Handle<crate::Expression>,
2235 ) -> bool {
2236 let crate::Expression::GlobalVariable(var_handle) = self.ir_function.expressions[base]
2237 else {
2238 return false;
2239 };
2240
2241 let gvar = &self.ir_module.global_variables[var_handle];
2244 let crate::TypeInner::BindingArray { .. } = self.ir_module.types[gvar.ty].inner else {
2245 return false;
2246 };
2247
2248 self.fun_info[index].uniformity.non_uniform_result.is_some()
2249 }
2250
2251 fn write_access_chain_index(
2261 &mut self,
2262 base: Handle<crate::Expression>,
2263 index: GuardedIndex,
2264 accumulated_checks: &mut Option<Word>,
2265 block: &mut Block,
2266 ) -> Result<Word, Error> {
2267 match self.write_bounds_check(base, index, block)? {
2268 BoundsCheckResult::KnownInBounds(known_index) => {
2269 let scalar = crate::Literal::U32(known_index);
2272 Ok(self.writer.get_constant_scalar(scalar))
2273 }
2274 BoundsCheckResult::Computed(computed_index_id) => Ok(computed_index_id),
2275 BoundsCheckResult::Conditional {
2276 condition_id: condition,
2277 index_id: index,
2278 } => {
2279 self.extend_bounds_check_condition_chain(accumulated_checks, condition, block);
2280
2281 Ok(index)
2283 }
2284 }
2285 }
2286
2287 fn extend_bounds_check_condition_chain(
2306 &mut self,
2307 chain: &mut Option<Word>,
2308 comparison_id: Word,
2309 block: &mut Block,
2310 ) {
2311 match *chain {
2312 Some(ref mut prior_checks) => {
2313 let combined = self.gen_id();
2314 block.body.push(Instruction::binary(
2315 spirv::Op::LogicalAnd,
2316 self.writer.get_bool_type_id(),
2317 combined,
2318 *prior_checks,
2319 comparison_id,
2320 ));
2321 *prior_checks = combined;
2322 }
2323 None => {
2324 *chain = Some(comparison_id);
2326 }
2327 }
2328 }
2329
2330 fn write_checked_load(
2331 &mut self,
2332 pointer: Handle<crate::Expression>,
2333 block: &mut Block,
2334 access_type_adjustment: AccessTypeAdjustment,
2335 result_type_id: Word,
2336 ) -> Result<Word, Error> {
2337 match self.write_access_chain(pointer, block, access_type_adjustment)? {
2338 ExpressionPointer::Ready { pointer_id } => {
2339 let id = self.gen_id();
2340 let atomic_space =
2341 match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
2342 crate::TypeInner::Pointer { base, space } => {
2343 match self.ir_module.types[base].inner {
2344 crate::TypeInner::Atomic { .. } => Some(space),
2345 _ => None,
2346 }
2347 }
2348 _ => None,
2349 };
2350 let instruction = if let Some(space) = atomic_space {
2351 let (semantics, scope) = space.to_spirv_semantics_and_scope();
2352 let scope_constant_id = self.get_scope_constant(scope as u32);
2353 let semantics_id = self.get_index_constant(semantics.bits());
2354 Instruction::atomic_load(
2355 result_type_id,
2356 id,
2357 pointer_id,
2358 scope_constant_id,
2359 semantics_id,
2360 )
2361 } else {
2362 Instruction::load(result_type_id, id, pointer_id, None)
2363 };
2364 block.body.push(instruction);
2365 Ok(id)
2366 }
2367 ExpressionPointer::Conditional { condition, access } => {
2368 let value = self.write_conditional_indexed_load(
2370 result_type_id,
2371 condition,
2372 block,
2373 move |id_gen, block| {
2374 let pointer_id = access.result_id.unwrap();
2376 let value_id = id_gen.next();
2377 block.body.push(access);
2378 block.body.push(Instruction::load(
2379 result_type_id,
2380 value_id,
2381 pointer_id,
2382 None,
2383 ));
2384 value_id
2385 },
2386 );
2387 Ok(value)
2388 }
2389 }
2390 }
2391
2392 fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) {
2393 use indexmap::map::Entry;
2394
2395 let spill_variable_id = match self.function.spilled_composites.entry(base) {
2397 Entry::Occupied(preexisting) => preexisting.get().id,
2398 Entry::Vacant(vacant) => {
2399 let pointer_type_id = self.writer.get_resolution_pointer_id(
2402 &self.fun_info[base].ty,
2403 spirv::StorageClass::Function,
2404 );
2405 let id = self.writer.id_gen.next();
2406 vacant.insert(super::LocalVariable {
2407 id,
2408 instruction: Instruction::variable(
2409 pointer_type_id,
2410 id,
2411 spirv::StorageClass::Function,
2412 None,
2413 ),
2414 });
2415 id
2416 }
2417 };
2418
2419 let base_id = self.cached[base];
2444 block
2445 .body
2446 .push(Instruction::store(spill_variable_id, base_id, None));
2447 }
2448
2449 fn maybe_access_spilled_composite(
2466 &mut self,
2467 access: Handle<crate::Expression>,
2468 block: &mut Block,
2469 result_type_id: Word,
2470 ) -> Result<Word, Error> {
2471 let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r);
2472 if access_uses == self.fun_info[access].ref_count {
2473 Ok(0)
2477 } else {
2478 self.write_checked_load(
2483 access,
2484 block,
2485 AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::Function),
2486 result_type_id,
2487 )
2488 }
2489 }
2490
2491 #[allow(clippy::too_many_arguments)]
2493 fn write_matrix_matrix_column_op(
2494 &mut self,
2495 block: &mut Block,
2496 result_id: Word,
2497 result_type_id: Word,
2498 left_id: Word,
2499 right_id: Word,
2500 columns: crate::VectorSize,
2501 rows: crate::VectorSize,
2502 width: u8,
2503 op: spirv::Op,
2504 ) {
2505 self.temp_list.clear();
2506
2507 let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2508 size: rows,
2509 scalar: crate::Scalar::float(width),
2510 });
2511
2512 for index in 0..columns as u32 {
2513 let column_id_left = self.gen_id();
2514 let column_id_right = self.gen_id();
2515 let column_id_res = self.gen_id();
2516
2517 block.body.push(Instruction::composite_extract(
2518 vector_type_id,
2519 column_id_left,
2520 left_id,
2521 &[index],
2522 ));
2523 block.body.push(Instruction::composite_extract(
2524 vector_type_id,
2525 column_id_right,
2526 right_id,
2527 &[index],
2528 ));
2529 block.body.push(Instruction::binary(
2530 op,
2531 vector_type_id,
2532 column_id_res,
2533 column_id_left,
2534 column_id_right,
2535 ));
2536
2537 self.temp_list.push(column_id_res);
2538 }
2539
2540 block.body.push(Instruction::composite_construct(
2541 result_type_id,
2542 result_id,
2543 &self.temp_list,
2544 ));
2545 }
2546
2547 fn write_vector_scalar_mult(
2549 &mut self,
2550 block: &mut Block,
2551 result_id: Word,
2552 result_type_id: Word,
2553 vector_id: Word,
2554 scalar_id: Word,
2555 vector: &crate::TypeInner,
2556 ) {
2557 let (size, kind) = match *vector {
2558 crate::TypeInner::Vector {
2559 size,
2560 scalar: crate::Scalar { kind, .. },
2561 } => (size, kind),
2562 _ => unreachable!(),
2563 };
2564
2565 let (op, operand_id) = match kind {
2566 crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
2567 _ => {
2568 let operand_id = self.gen_id();
2569 self.temp_list.clear();
2570 self.temp_list.resize(size as usize, scalar_id);
2571 block.body.push(Instruction::composite_construct(
2572 result_type_id,
2573 operand_id,
2574 &self.temp_list,
2575 ));
2576 (spirv::Op::IMul, operand_id)
2577 }
2578 };
2579
2580 block.body.push(Instruction::binary(
2581 op,
2582 result_type_id,
2583 result_id,
2584 vector_id,
2585 operand_id,
2586 ));
2587 }
2588
2589 #[expect(clippy::too_many_arguments)]
2596 fn write_dot_product(
2597 &mut self,
2598 result_id: Word,
2599 result_type_id: Word,
2600 arg0_id: Word,
2601 arg1_id: Word,
2602 size: u32,
2603 block: &mut Block,
2604 extractor: impl Fn(Word, Word, Word) -> Instruction,
2605 ) {
2606 let mut partial_sum = self.writer.get_constant_null(result_type_id);
2607 let last_component = size - 1;
2608 for index in 0..=last_component {
2609 let a_id = self.gen_id();
2611 block.body.push(extractor(a_id, arg0_id, index));
2612 let b_id = self.gen_id();
2613 block.body.push(extractor(b_id, arg1_id, index));
2614 let prod_id = self.gen_id();
2615 block.body.push(Instruction::binary(
2616 spirv::Op::IMul,
2617 result_type_id,
2618 prod_id,
2619 a_id,
2620 b_id,
2621 ));
2622
2623 let id = if index == last_component {
2625 result_id
2626 } else {
2627 self.gen_id()
2628 };
2629
2630 block.body.push(Instruction::binary(
2632 spirv::Op::IAdd,
2633 result_type_id,
2634 id,
2635 partial_sum,
2636 prod_id,
2637 ));
2638 partial_sum = id;
2640 }
2641 }
2642
2643 fn write_pack4x8_optimized(
2645 &mut self,
2646 block: &mut Block,
2647 result_type_id: u32,
2648 arg0_id: u32,
2649 id: u32,
2650 is_signed: bool,
2651 should_clamp: bool,
2652 ) -> Instruction {
2653 let int_type = if is_signed {
2654 crate::ScalarKind::Sint
2655 } else {
2656 crate::ScalarKind::Uint
2657 };
2658 let wide_vector_type = NumericType::Vector {
2659 size: crate::VectorSize::Quad,
2660 scalar: crate::Scalar {
2661 kind: int_type,
2662 width: 4,
2663 },
2664 };
2665 let wide_vector_type_id = self.get_numeric_type_id(wide_vector_type);
2666 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2667 size: crate::VectorSize::Quad,
2668 scalar: crate::Scalar {
2669 kind: crate::ScalarKind::Uint,
2670 width: 1,
2671 },
2672 });
2673
2674 let mut wide_vector = arg0_id;
2675 if should_clamp {
2676 let (min, max, clamp_op) = if is_signed {
2677 (
2678 crate::Literal::I32(-128),
2679 crate::Literal::I32(127),
2680 spirv::GLOp::SClamp,
2681 )
2682 } else {
2683 (
2684 crate::Literal::U32(0),
2685 crate::Literal::U32(255),
2686 spirv::GLOp::UClamp,
2687 )
2688 };
2689 let [min, max] = [min, max].map(|lit| {
2690 let scalar = self.writer.get_constant_scalar(lit);
2691 self.writer.get_constant_composite(
2692 LookupType::Local(LocalType::Numeric(wide_vector_type)),
2693 &[scalar; 4],
2694 )
2695 });
2696
2697 let clamp_id = self.gen_id();
2698 block.body.push(Instruction::ext_inst_gl_op(
2699 self.writer.gl450_ext_inst_id,
2700 clamp_op,
2701 wide_vector_type_id,
2702 clamp_id,
2703 &[wide_vector, min, max],
2704 ));
2705
2706 wide_vector = clamp_id;
2707 }
2708
2709 let packed_vector = self.gen_id();
2710 block.body.push(Instruction::unary(
2711 spirv::Op::UConvert, packed_vector_type_id,
2713 packed_vector,
2714 wide_vector,
2715 ));
2716
2717 Instruction::unary(spirv::Op::Bitcast, result_type_id, id, packed_vector)
2722 }
2723
2724 fn write_pack4x8_polyfill(
2726 &mut self,
2727 block: &mut Block,
2728 result_type_id: u32,
2729 arg0_id: u32,
2730 id: u32,
2731 is_signed: bool,
2732 should_clamp: bool,
2733 ) -> Instruction {
2734 let int_type = if is_signed {
2735 crate::ScalarKind::Sint
2736 } else {
2737 crate::ScalarKind::Uint
2738 };
2739 let uint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
2740 let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
2741 kind: int_type,
2742 width: 4,
2743 }));
2744
2745 let mut last_instruction = Instruction::new(spirv::Op::Nop);
2746
2747 let zero = self.writer.get_constant_scalar(crate::Literal::U32(0));
2748 let mut preresult = zero;
2749 block
2750 .body
2751 .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed)));
2752
2753 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
2754 const VEC_LENGTH: u8 = 4;
2755 for i in 0..u32::from(VEC_LENGTH) {
2756 let offset = self.writer.get_constant_scalar(crate::Literal::U32(i * 8));
2757 let mut extracted = self.gen_id();
2758 block.body.push(Instruction::binary(
2759 spirv::Op::CompositeExtract,
2760 int_type_id,
2761 extracted,
2762 arg0_id,
2763 i,
2764 ));
2765 if is_signed {
2766 let casted = self.gen_id();
2767 block.body.push(Instruction::unary(
2768 spirv::Op::Bitcast,
2769 uint_type_id,
2770 casted,
2771 extracted,
2772 ));
2773 extracted = casted;
2774 }
2775 if should_clamp {
2776 let (min, max, clamp_op) = if is_signed {
2777 (
2778 crate::Literal::I32(-128),
2779 crate::Literal::I32(127),
2780 spirv::GLOp::SClamp,
2781 )
2782 } else {
2783 (
2784 crate::Literal::U32(0),
2785 crate::Literal::U32(255),
2786 spirv::GLOp::UClamp,
2787 )
2788 };
2789 let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit));
2790
2791 let clamp_id = self.gen_id();
2792 block.body.push(Instruction::ext_inst_gl_op(
2793 self.writer.gl450_ext_inst_id,
2794 clamp_op,
2795 result_type_id,
2796 clamp_id,
2797 &[extracted, min, max],
2798 ));
2799
2800 extracted = clamp_id;
2801 }
2802 let is_last = i == u32::from(VEC_LENGTH - 1);
2803 if is_last {
2804 last_instruction = Instruction::quaternary(
2805 spirv::Op::BitFieldInsert,
2806 result_type_id,
2807 id,
2808 preresult,
2809 extracted,
2810 offset,
2811 eight,
2812 )
2813 } else {
2814 let new_preresult = self.gen_id();
2815 block.body.push(Instruction::quaternary(
2816 spirv::Op::BitFieldInsert,
2817 result_type_id,
2818 new_preresult,
2819 preresult,
2820 extracted,
2821 offset,
2822 eight,
2823 ));
2824 preresult = new_preresult;
2825 }
2826 }
2827 last_instruction
2828 }
2829
2830 fn write_unpack4x8_optimized(
2832 &mut self,
2833 block: &mut Block,
2834 result_type_id: u32,
2835 arg0_id: u32,
2836 id: u32,
2837 is_signed: bool,
2838 ) -> Instruction {
2839 let (int_type, convert_op) = if is_signed {
2840 (crate::ScalarKind::Sint, spirv::Op::SConvert)
2841 } else {
2842 (crate::ScalarKind::Uint, spirv::Op::UConvert)
2843 };
2844
2845 let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2846 size: crate::VectorSize::Quad,
2847 scalar: crate::Scalar {
2848 kind: int_type,
2849 width: 1,
2850 },
2851 });
2852
2853 let packed_vector = self.gen_id();
2858 block.body.push(Instruction::unary(
2859 spirv::Op::Bitcast,
2860 packed_vector_type_id,
2861 packed_vector,
2862 arg0_id,
2863 ));
2864
2865 Instruction::unary(convert_op, result_type_id, id, packed_vector)
2866 }
2867
2868 fn write_unpack4x8_polyfill(
2870 &mut self,
2871 block: &mut Block,
2872 result_type_id: u32,
2873 arg0_id: u32,
2874 id: u32,
2875 is_signed: bool,
2876 ) -> Instruction {
2877 let (int_type, extract_op) = if is_signed {
2878 (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract)
2879 } else {
2880 (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract)
2881 };
2882
2883 let sint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));
2884
2885 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
2886 let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
2887 kind: int_type,
2888 width: 4,
2889 }));
2890 block
2891 .body
2892 .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed));
2893 let arg_id = if is_signed {
2894 let new_arg_id = self.gen_id();
2895 block.body.push(Instruction::unary(
2896 spirv::Op::Bitcast,
2897 sint_type_id,
2898 new_arg_id,
2899 arg0_id,
2900 ));
2901 new_arg_id
2902 } else {
2903 arg0_id
2904 };
2905
2906 const VEC_LENGTH: u8 = 4;
2907 let parts: [_; VEC_LENGTH as usize] = core::array::from_fn(|_| self.gen_id());
2908 for (i, part_id) in parts.into_iter().enumerate() {
2909 let index = self
2910 .writer
2911 .get_constant_scalar(crate::Literal::U32(i as u32 * 8));
2912 block.body.push(Instruction::ternary(
2913 extract_op,
2914 int_type_id,
2915 part_id,
2916 arg_id,
2917 index,
2918 eight,
2919 ));
2920 }
2921
2922 Instruction::composite_construct(result_type_id, id, &parts)
2923 }
2924
2925 fn write_block(
2942 &mut self,
2943 label_id: Word,
2944 naga_block: &crate::Block,
2945 exit: BlockExit,
2946 loop_context: LoopContext,
2947 debug_info: Option<&DebugInfoInner>,
2948 ) -> Result<BlockExitDisposition, Error> {
2949 let mut block = Block::new(label_id);
2950 for (statement, span) in naga_block.span_iter() {
2951 if let (Some(debug_info), false) = (
2952 debug_info,
2953 matches!(
2954 statement,
2955 &(Statement::Block(..)
2956 | Statement::Break
2957 | Statement::Continue
2958 | Statement::Kill
2959 | Statement::Return { .. }
2960 | Statement::Loop { .. })
2961 ),
2962 ) {
2963 let loc: crate::SourceLocation = span.location(debug_info.source_code);
2964 block.body.push(Instruction::line(
2965 debug_info.source_file_id,
2966 loc.line_number,
2967 loc.line_position,
2968 ));
2969 };
2970 match *statement {
2971 Statement::Emit(ref range) => {
2972 for handle in range.clone() {
2973 if !self.expression_constness.is_const(handle) {
2975 self.cache_expression_value(handle, &mut block)?;
2976 }
2977 }
2978 }
2979 Statement::Block(ref block_statements) => {
2980 let scope_id = self.gen_id();
2981 self.function.consume(block, Instruction::branch(scope_id));
2982
2983 let merge_id = self.gen_id();
2984 let merge_used = self.write_block(
2985 scope_id,
2986 block_statements,
2987 BlockExit::Branch { target: merge_id },
2988 loop_context,
2989 debug_info,
2990 )?;
2991
2992 match merge_used {
2993 BlockExitDisposition::Used => {
2994 block = Block::new(merge_id);
2995 }
2996 BlockExitDisposition::Discarded => {
2997 return Ok(BlockExitDisposition::Discarded);
2998 }
2999 }
3000 }
3001 Statement::If {
3002 condition,
3003 ref accept,
3004 ref reject,
3005 } => {
3006 if !(accept.is_empty() && reject.is_empty()) {
3012 let condition_id = self.cached[condition];
3013
3014 let merge_id = self.gen_id();
3015 block.body.push(Instruction::selection_merge(
3016 merge_id,
3017 spirv::SelectionControl::NONE,
3018 ));
3019
3020 let accept_id = if accept.is_empty() {
3021 None
3022 } else {
3023 Some(self.gen_id())
3024 };
3025 let reject_id = if reject.is_empty() {
3026 None
3027 } else {
3028 Some(self.gen_id())
3029 };
3030
3031 self.function.consume(
3032 block,
3033 Instruction::branch_conditional(
3034 condition_id,
3035 accept_id.unwrap_or(merge_id),
3036 reject_id.unwrap_or(merge_id),
3037 ),
3038 );
3039
3040 if let Some(block_id) = accept_id {
3041 let _ = self.write_block(
3046 block_id,
3047 accept,
3048 BlockExit::Branch { target: merge_id },
3049 loop_context,
3050 debug_info,
3051 )?;
3052 }
3053 if let Some(block_id) = reject_id {
3054 let _ = self.write_block(
3059 block_id,
3060 reject,
3061 BlockExit::Branch { target: merge_id },
3062 loop_context,
3063 debug_info,
3064 )?;
3065 }
3066
3067 block = Block::new(merge_id);
3068 }
3069 }
3070 Statement::Switch {
3071 selector,
3072 ref cases,
3073 } => {
3074 let selector_id = self.cached[selector];
3075
3076 let merge_id = self.gen_id();
3077 block.body.push(Instruction::selection_merge(
3078 merge_id,
3079 spirv::SelectionControl::NONE,
3080 ));
3081
3082 let mut default_id = None;
3083 let mut last_id = None;
3085
3086 let mut raw_cases = Vec::with_capacity(cases.len());
3087 let mut case_ids = Vec::with_capacity(cases.len());
3088 for case in cases.iter() {
3089 let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
3091
3092 if case.fall_through && case.body.is_empty() {
3093 last_id = Some(label_id);
3094 }
3095
3096 case_ids.push(label_id);
3097
3098 match case.value {
3099 crate::SwitchValue::I32(value) => {
3100 raw_cases.push(super::instructions::Case {
3101 value: value as Word,
3102 label_id,
3103 });
3104 }
3105 crate::SwitchValue::U32(value) => {
3106 raw_cases.push(super::instructions::Case { value, label_id });
3107 }
3108 crate::SwitchValue::Default => {
3109 default_id = Some(label_id);
3110 }
3111 }
3112 }
3113
3114 let default_id = default_id.unwrap();
3115
3116 self.function.consume(
3117 block,
3118 Instruction::switch(selector_id, default_id, &raw_cases),
3119 );
3120
3121 let inner_context = LoopContext {
3122 break_id: Some(merge_id),
3123 ..loop_context
3124 };
3125
3126 for (i, (case, label_id)) in cases
3127 .iter()
3128 .zip(case_ids.iter())
3129 .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
3130 .enumerate()
3131 {
3132 let case_finish_id = if case.fall_through {
3133 case_ids[i + 1]
3134 } else {
3135 merge_id
3136 };
3137 let _ = self.write_block(
3146 *label_id,
3147 &case.body,
3148 BlockExit::Branch {
3149 target: case_finish_id,
3150 },
3151 inner_context,
3152 debug_info,
3153 )?;
3154 }
3155
3156 block = Block::new(merge_id);
3157 }
3158 Statement::Loop {
3159 ref body,
3160 ref continuing,
3161 break_if,
3162 } => {
3163 let preamble_id = self.gen_id();
3164 self.function
3165 .consume(block, Instruction::branch(preamble_id));
3166
3167 let merge_id = self.gen_id();
3168 let body_id = self.gen_id();
3169 let continuing_id = self.gen_id();
3170
3171 block = Block::new(preamble_id);
3174 if let Some(debug_info) = debug_info {
3177 let loc: crate::SourceLocation = span.location(debug_info.source_code);
3178 block.body.push(Instruction::line(
3179 debug_info.source_file_id,
3180 loc.line_number,
3181 loc.line_position,
3182 ))
3183 }
3184 block.body.push(Instruction::loop_merge(
3185 merge_id,
3186 continuing_id,
3187 spirv::SelectionControl::NONE,
3188 ));
3189
3190 if self.force_loop_bounding {
3191 block = self.write_force_bounded_loop_instructions(block, merge_id);
3192 }
3193 self.function.consume(block, Instruction::branch(body_id));
3194
3195 let _ = self.write_block(
3199 body_id,
3200 body,
3201 BlockExit::Branch {
3202 target: continuing_id,
3203 },
3204 LoopContext {
3205 continuing_id: Some(continuing_id),
3206 break_id: Some(merge_id),
3207 },
3208 debug_info,
3209 )?;
3210
3211 let exit = match break_if {
3212 Some(condition) => BlockExit::BreakIf {
3213 condition,
3214 preamble_id,
3215 },
3216 None => BlockExit::Branch {
3217 target: preamble_id,
3218 },
3219 };
3220
3221 let _ = self.write_block(
3225 continuing_id,
3226 continuing,
3227 exit,
3228 LoopContext {
3229 continuing_id: None,
3230 break_id: Some(merge_id),
3231 },
3232 debug_info,
3233 )?;
3234
3235 block = Block::new(merge_id);
3236 }
3237 Statement::Break => {
3238 self.function
3239 .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
3240 return Ok(BlockExitDisposition::Discarded);
3241 }
3242 Statement::Continue => {
3243 self.function.consume(
3244 block,
3245 Instruction::branch(loop_context.continuing_id.unwrap()),
3246 );
3247 return Ok(BlockExitDisposition::Discarded);
3248 }
3249 Statement::Return { value: Some(value) } => {
3250 let value_id = self.cached[value];
3251 let instruction = match self.function.entry_point_context {
3252 Some(ref context) => {
3255 self.writer.write_entry_point_return(
3256 value_id,
3257 self.ir_function.result.as_ref().unwrap(),
3258 &context.results,
3259 &mut block.body,
3260 )?;
3261 Instruction::return_void()
3262 }
3263 None => Instruction::return_value(value_id),
3264 };
3265 self.function.consume(block, instruction);
3266 return Ok(BlockExitDisposition::Discarded);
3267 }
3268 Statement::Return { value: None } => {
3269 self.function.consume(block, Instruction::return_void());
3270 return Ok(BlockExitDisposition::Discarded);
3271 }
3272 Statement::Kill => {
3273 self.function.consume(block, Instruction::kill());
3274 return Ok(BlockExitDisposition::Discarded);
3275 }
3276 Statement::ControlBarrier(flags) => {
3277 self.writer.write_control_barrier(flags, &mut block);
3278 }
3279 Statement::MemoryBarrier(flags) => {
3280 self.writer.write_memory_barrier(flags, &mut block);
3281 }
3282 Statement::Store { pointer, value } => {
3283 let value_id = self.cached[value];
3284 match self.write_access_chain(
3285 pointer,
3286 &mut block,
3287 AccessTypeAdjustment::None,
3288 )? {
3289 ExpressionPointer::Ready { pointer_id } => {
3290 let atomic_space = match *self.fun_info[pointer]
3291 .ty
3292 .inner_with(&self.ir_module.types)
3293 {
3294 crate::TypeInner::Pointer { base, space } => {
3295 match self.ir_module.types[base].inner {
3296 crate::TypeInner::Atomic { .. } => Some(space),
3297 _ => None,
3298 }
3299 }
3300 _ => None,
3301 };
3302 let instruction = if let Some(space) = atomic_space {
3303 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3304 let scope_constant_id = self.get_scope_constant(scope as u32);
3305 let semantics_id = self.get_index_constant(semantics.bits());
3306 Instruction::atomic_store(
3307 pointer_id,
3308 scope_constant_id,
3309 semantics_id,
3310 value_id,
3311 )
3312 } else {
3313 Instruction::store(pointer_id, value_id, None)
3314 };
3315 block.body.push(instruction);
3316 }
3317 ExpressionPointer::Conditional { condition, access } => {
3318 let mut selection = Selection::start(&mut block, ());
3319 selection.if_true(self, condition, ());
3320
3321 let pointer_id = access.result_id.unwrap();
3323 selection.block().body.push(access);
3324 selection
3325 .block()
3326 .body
3327 .push(Instruction::store(pointer_id, value_id, None));
3328
3329 selection.finish(self, ());
3332 }
3333 };
3334 }
3335 Statement::ImageStore {
3336 image,
3337 coordinate,
3338 array_index,
3339 value,
3340 } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
3341 Statement::Call {
3342 function: local_function,
3343 ref arguments,
3344 result,
3345 } => {
3346 let id = self.gen_id();
3347 self.temp_list.clear();
3348 for &argument in arguments {
3349 self.temp_list.push(self.cached[argument]);
3350 }
3351
3352 let type_id = match result {
3353 Some(expr) => {
3354 self.cached[expr] = id;
3355 self.get_expression_type_id(&self.fun_info[expr].ty)
3356 }
3357 None => self.writer.void_type,
3358 };
3359
3360 block.body.push(Instruction::function_call(
3361 type_id,
3362 id,
3363 self.writer.lookup_function[&local_function],
3364 &self.temp_list,
3365 ));
3366 }
3367 Statement::Atomic {
3368 pointer,
3369 ref fun,
3370 value,
3371 result,
3372 } => {
3373 let id = self.gen_id();
3374 let result_type_id =
3378 self.get_expression_type_id(&self.fun_info[result.unwrap_or(value)].ty);
3379
3380 if let Some(result) = result {
3381 self.cached[result] = id;
3382 }
3383
3384 let pointer_id = match self.write_access_chain(
3385 pointer,
3386 &mut block,
3387 AccessTypeAdjustment::None,
3388 )? {
3389 ExpressionPointer::Ready { pointer_id } => pointer_id,
3390 ExpressionPointer::Conditional { .. } => {
3391 return Err(Error::FeatureNotImplemented(
3392 "Atomics out-of-bounds handling",
3393 ));
3394 }
3395 };
3396
3397 let space = self.fun_info[pointer]
3398 .ty
3399 .inner_with(&self.ir_module.types)
3400 .pointer_space()
3401 .unwrap();
3402 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3403 let scope_constant_id = self.get_scope_constant(scope as u32);
3404 let semantics_id = self.get_index_constant(semantics.bits());
3405 let value_id = self.cached[value];
3406 let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
3407
3408 let crate::TypeInner::Scalar(scalar) = *value_inner else {
3409 return Err(Error::FeatureNotImplemented(
3410 "Atomics with non-scalar values",
3411 ));
3412 };
3413
3414 let instruction = match *fun {
3415 crate::AtomicFunction::Add => {
3416 let spirv_op = match scalar.kind {
3417 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3418 spirv::Op::AtomicIAdd
3419 }
3420 crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT,
3421 _ => unimplemented!(),
3422 };
3423 Instruction::atomic_binary(
3424 spirv_op,
3425 result_type_id,
3426 id,
3427 pointer_id,
3428 scope_constant_id,
3429 semantics_id,
3430 value_id,
3431 )
3432 }
3433 crate::AtomicFunction::Subtract => {
3434 let (spirv_op, value_id) = match scalar.kind {
3435 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3436 (spirv::Op::AtomicISub, value_id)
3437 }
3438 crate::ScalarKind::Float => {
3439 let neg_result_id = self.gen_id();
3442 block.body.push(Instruction::unary(
3443 spirv::Op::FNegate,
3444 result_type_id,
3445 neg_result_id,
3446 value_id,
3447 ));
3448 (spirv::Op::AtomicFAddEXT, neg_result_id)
3449 }
3450 _ => unimplemented!(),
3451 };
3452 Instruction::atomic_binary(
3453 spirv_op,
3454 result_type_id,
3455 id,
3456 pointer_id,
3457 scope_constant_id,
3458 semantics_id,
3459 value_id,
3460 )
3461 }
3462 crate::AtomicFunction::And => {
3463 let spirv_op = match scalar.kind {
3464 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3465 spirv::Op::AtomicAnd
3466 }
3467 _ => unimplemented!(),
3468 };
3469 Instruction::atomic_binary(
3470 spirv_op,
3471 result_type_id,
3472 id,
3473 pointer_id,
3474 scope_constant_id,
3475 semantics_id,
3476 value_id,
3477 )
3478 }
3479 crate::AtomicFunction::InclusiveOr => {
3480 let spirv_op = match scalar.kind {
3481 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3482 spirv::Op::AtomicOr
3483 }
3484 _ => unimplemented!(),
3485 };
3486 Instruction::atomic_binary(
3487 spirv_op,
3488 result_type_id,
3489 id,
3490 pointer_id,
3491 scope_constant_id,
3492 semantics_id,
3493 value_id,
3494 )
3495 }
3496 crate::AtomicFunction::ExclusiveOr => {
3497 let spirv_op = match scalar.kind {
3498 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3499 spirv::Op::AtomicXor
3500 }
3501 _ => unimplemented!(),
3502 };
3503 Instruction::atomic_binary(
3504 spirv_op,
3505 result_type_id,
3506 id,
3507 pointer_id,
3508 scope_constant_id,
3509 semantics_id,
3510 value_id,
3511 )
3512 }
3513 crate::AtomicFunction::Min => {
3514 let spirv_op = match scalar.kind {
3515 crate::ScalarKind::Sint => spirv::Op::AtomicSMin,
3516 crate::ScalarKind::Uint => spirv::Op::AtomicUMin,
3517 _ => unimplemented!(),
3518 };
3519 Instruction::atomic_binary(
3520 spirv_op,
3521 result_type_id,
3522 id,
3523 pointer_id,
3524 scope_constant_id,
3525 semantics_id,
3526 value_id,
3527 )
3528 }
3529 crate::AtomicFunction::Max => {
3530 let spirv_op = match scalar.kind {
3531 crate::ScalarKind::Sint => spirv::Op::AtomicSMax,
3532 crate::ScalarKind::Uint => spirv::Op::AtomicUMax,
3533 _ => unimplemented!(),
3534 };
3535 Instruction::atomic_binary(
3536 spirv_op,
3537 result_type_id,
3538 id,
3539 pointer_id,
3540 scope_constant_id,
3541 semantics_id,
3542 value_id,
3543 )
3544 }
3545 crate::AtomicFunction::Exchange { compare: None } => {
3546 Instruction::atomic_binary(
3547 spirv::Op::AtomicExchange,
3548 result_type_id,
3549 id,
3550 pointer_id,
3551 scope_constant_id,
3552 semantics_id,
3553 value_id,
3554 )
3555 }
3556 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
3557 let scalar_type_id =
3558 self.get_numeric_type_id(NumericType::Scalar(scalar));
3559 let bool_type_id =
3560 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL));
3561
3562 let cas_result_id = self.gen_id();
3563 let equality_result_id = self.gen_id();
3564 let equality_operator = match scalar.kind {
3565 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3566 spirv::Op::IEqual
3567 }
3568 _ => unimplemented!(),
3569 };
3570
3571 let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
3572 cas_instr.set_type(scalar_type_id);
3573 cas_instr.set_result(cas_result_id);
3574 cas_instr.add_operand(pointer_id);
3575 cas_instr.add_operand(scope_constant_id);
3576 cas_instr.add_operand(semantics_id); cas_instr.add_operand(semantics_id); cas_instr.add_operand(value_id);
3579 cas_instr.add_operand(self.cached[cmp]);
3580 block.body.push(cas_instr);
3581 block.body.push(Instruction::binary(
3582 equality_operator,
3583 bool_type_id,
3584 equality_result_id,
3585 cas_result_id,
3586 self.cached[cmp],
3587 ));
3588 Instruction::composite_construct(
3589 result_type_id,
3590 id,
3591 &[cas_result_id, equality_result_id],
3592 )
3593 }
3594 };
3595
3596 block.body.push(instruction);
3597 }
3598 Statement::ImageAtomic {
3599 image,
3600 coordinate,
3601 array_index,
3602 fun,
3603 value,
3604 } => {
3605 self.write_image_atomic(
3606 image,
3607 coordinate,
3608 array_index,
3609 fun,
3610 value,
3611 &mut block,
3612 )?;
3613 }
3614 Statement::WorkGroupUniformLoad { pointer, result } => {
3615 self.writer
3616 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block);
3617 let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
3618 match self.write_access_chain(
3620 pointer,
3621 &mut block,
3622 AccessTypeAdjustment::None,
3623 )? {
3624 ExpressionPointer::Ready { pointer_id } => {
3625 let id = self.gen_id();
3626 block.body.push(Instruction::load(
3627 result_type_id,
3628 id,
3629 pointer_id,
3630 None,
3631 ));
3632 self.cached[result] = id;
3633 }
3634 ExpressionPointer::Conditional { condition, access } => {
3635 self.cached[result] = self.write_conditional_indexed_load(
3636 result_type_id,
3637 condition,
3638 &mut block,
3639 move |id_gen, block| {
3640 let pointer_id = access.result_id.unwrap();
3642 let value_id = id_gen.next();
3643 block.body.push(access);
3644 block.body.push(Instruction::load(
3645 result_type_id,
3646 value_id,
3647 pointer_id,
3648 None,
3649 ));
3650 value_id
3651 },
3652 )
3653 }
3654 }
3655 self.writer
3656 .write_control_barrier(crate::Barrier::WORK_GROUP, &mut block);
3657 }
3658 Statement::RayQuery { query, ref fun } => {
3659 self.write_ray_query_function(query, fun, &mut block);
3660 }
3661 Statement::SubgroupBallot {
3662 result,
3663 ref predicate,
3664 } => {
3665 self.write_subgroup_ballot(predicate, result, &mut block)?;
3666 }
3667 Statement::SubgroupCollectiveOperation {
3668 ref op,
3669 ref collective_op,
3670 argument,
3671 result,
3672 } => {
3673 self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
3674 }
3675 Statement::SubgroupGather {
3676 ref mode,
3677 argument,
3678 result,
3679 } => {
3680 self.write_subgroup_gather(mode, argument, result, &mut block)?;
3681 }
3682 }
3683 }
3684
3685 let termination = match exit {
3686 BlockExit::Return => match self.ir_function.result {
3689 Some(ref result) if self.function.entry_point_context.is_none() => {
3690 let type_id = self.get_handle_type_id(result.ty);
3691 let null_id = self.writer.get_constant_null(type_id);
3692 Instruction::return_value(null_id)
3693 }
3694 _ => Instruction::return_void(),
3695 },
3696 BlockExit::Branch { target } => Instruction::branch(target),
3697 BlockExit::BreakIf {
3698 condition,
3699 preamble_id,
3700 } => {
3701 let condition_id = self.cached[condition];
3702
3703 Instruction::branch_conditional(
3704 condition_id,
3705 loop_context.break_id.unwrap(),
3706 preamble_id,
3707 )
3708 }
3709 };
3710
3711 self.function.consume(block, termination);
3712 Ok(BlockExitDisposition::Used)
3713 }
3714
3715 pub(super) fn write_function_body(
3716 &mut self,
3717 entry_id: Word,
3718 debug_info: Option<&DebugInfoInner>,
3719 ) -> Result<(), Error> {
3720 let _ = self.write_block(
3723 entry_id,
3724 &self.ir_function.body,
3725 BlockExit::Return,
3726 LoopContext::default(),
3727 debug_info,
3728 )?;
3729
3730 Ok(())
3731 }
3732}