1use super::{Block, BlockContext, Error, Instruction, NumericType};
2use crate::{arena::Handle, TypeInner};
3
4impl BlockContext<'_> {
5 pub(super) fn write_subgroup_ballot(
6 &mut self,
7 predicate: &Option<Handle<crate::Expression>>,
8 result: Handle<crate::Expression>,
9 block: &mut Block,
10 ) -> Result<(), Error> {
11 self.writer.require_any(
12 "GroupNonUniformBallot",
13 &[spirv::Capability::GroupNonUniformBallot],
14 )?;
15 let vec4_u32_type_id = self.get_numeric_type_id(NumericType::Vector {
16 size: crate::VectorSize::Quad,
17 scalar: crate::Scalar::U32,
18 });
19 let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
20 let predicate = if let Some(predicate) = *predicate {
21 self.cached[predicate]
22 } else {
23 self.writer.get_constant_scalar(crate::Literal::Bool(true))
24 };
25 let id = self.gen_id();
26 block.body.push(Instruction::group_non_uniform_ballot(
27 vec4_u32_type_id,
28 id,
29 exec_scope_id,
30 predicate,
31 ));
32 self.cached[result] = id;
33 Ok(())
34 }
35 pub(super) fn write_subgroup_operation(
36 &mut self,
37 op: &crate::SubgroupOperation,
38 collective_op: &crate::CollectiveOperation,
39 argument: Handle<crate::Expression>,
40 result: Handle<crate::Expression>,
41 block: &mut Block,
42 ) -> Result<(), Error> {
43 use crate::SubgroupOperation as sg;
44 match *op {
45 sg::All | sg::Any => {
46 self.writer.require_any(
47 "GroupNonUniformVote",
48 &[spirv::Capability::GroupNonUniformVote],
49 )?;
50 }
51 _ => {
52 self.writer.require_any(
53 "GroupNonUniformArithmetic",
54 &[spirv::Capability::GroupNonUniformArithmetic],
55 )?;
56 }
57 }
58
59 let id = self.gen_id();
60 let result_ty = &self.fun_info[result].ty;
61 let result_type_id = self.get_expression_type_id(result_ty);
62 let result_ty_inner = result_ty.inner_with(&self.ir_module.types);
63
64 let (is_scalar, scalar) = match *result_ty_inner {
65 TypeInner::Scalar(kind) => (true, kind),
66 TypeInner::Vector { scalar: kind, .. } => (false, kind),
67 _ => unimplemented!(),
68 };
69
70 use crate::ScalarKind as sk;
71 let spirv_op = match (scalar.kind, *op) {
72 (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll,
73 (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny,
74 (_, sg::All | sg::Any) => unimplemented!(),
75
76 (sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd,
77 (sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd,
78 (sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul,
79 (sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul,
80 (sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax,
81 (sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax,
82 (sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax,
83 (sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin,
84 (sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin,
85 (sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin,
86 (_, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(),
87
88 (sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd,
89 (sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr,
90 (sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor,
91 (sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd,
92 (sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr,
93 (sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor,
94 (_, sg::And | sg::Or | sg::Xor) => unimplemented!(),
95 };
96
97 let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
98
99 use crate::CollectiveOperation as c;
100 let group_op = match *op {
101 sg::All | sg::Any => None,
102 _ => Some(match *collective_op {
103 c::Reduce => spirv::GroupOperation::Reduce,
104 c::InclusiveScan => spirv::GroupOperation::InclusiveScan,
105 c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan,
106 }),
107 };
108
109 let arg_id = self.cached[argument];
110 block.body.push(Instruction::group_non_uniform_arithmetic(
111 spirv_op,
112 result_type_id,
113 id,
114 exec_scope_id,
115 group_op,
116 arg_id,
117 ));
118 self.cached[result] = id;
119 Ok(())
120 }
121 pub(super) fn write_subgroup_gather(
122 &mut self,
123 mode: &crate::GatherMode,
124 argument: Handle<crate::Expression>,
125 result: Handle<crate::Expression>,
126 block: &mut Block,
127 ) -> Result<(), Error> {
128 match *mode {
129 crate::GatherMode::BroadcastFirst => {
130 self.writer.require_any(
131 "GroupNonUniformBallot",
132 &[spirv::Capability::GroupNonUniformBallot],
133 )?;
134 }
135 crate::GatherMode::Shuffle(_)
136 | crate::GatherMode::ShuffleXor(_)
137 | crate::GatherMode::Broadcast(_) => {
138 self.writer.require_any(
139 "GroupNonUniformShuffle",
140 &[spirv::Capability::GroupNonUniformShuffle],
141 )?;
142 }
143 crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => {
144 self.writer.require_any(
145 "GroupNonUniformShuffleRelative",
146 &[spirv::Capability::GroupNonUniformShuffleRelative],
147 )?;
148 }
149 crate::GatherMode::QuadBroadcast(_) | crate::GatherMode::QuadSwap(_) => {
150 self.writer.require_any(
151 "GroupNonUniformQuad",
152 &[spirv::Capability::GroupNonUniformQuad],
153 )?;
154 }
155 }
156
157 let id = self.gen_id();
158 let result_ty = &self.fun_info[result].ty;
159 let result_type_id = self.get_expression_type_id(result_ty);
160
161 let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
162
163 let arg_id = self.cached[argument];
164 match *mode {
165 crate::GatherMode::BroadcastFirst => {
166 block
167 .body
168 .push(Instruction::group_non_uniform_broadcast_first(
169 result_type_id,
170 id,
171 exec_scope_id,
172 arg_id,
173 ));
174 }
175 crate::GatherMode::Broadcast(index)
176 | crate::GatherMode::Shuffle(index)
177 | crate::GatherMode::ShuffleDown(index)
178 | crate::GatherMode::ShuffleUp(index)
179 | crate::GatherMode::ShuffleXor(index)
180 | crate::GatherMode::QuadBroadcast(index) => {
181 let index_id = self.cached[index];
182 let op = match *mode {
183 crate::GatherMode::BroadcastFirst => unreachable!(),
184 crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformShuffle,
189 crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle,
190 crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown,
191 crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp,
192 crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor,
193 crate::GatherMode::QuadBroadcast(_) => spirv::Op::GroupNonUniformQuadBroadcast,
194 crate::GatherMode::QuadSwap(_) => unreachable!(),
195 };
196 block.body.push(Instruction::group_non_uniform_gather(
197 op,
198 result_type_id,
199 id,
200 exec_scope_id,
201 arg_id,
202 index_id,
203 ));
204 }
205 crate::GatherMode::QuadSwap(direction) => {
206 let direction = self.get_index_constant(match direction {
207 crate::Direction::X => 0,
208 crate::Direction::Y => 1,
209 crate::Direction::Diagonal => 2,
210 });
211 block.body.push(Instruction::group_non_uniform_quad_swap(
212 result_type_id,
213 id,
214 exec_scope_id,
215 arg_id,
216 direction,
217 ));
218 }
219 }
220 self.cached[result] = id;
221 Ok(())
222 }
223}