naga/back/spv/
subgroup.rs

1use super::{Block, BlockContext, Error, Instruction, NumericType};
2use crate::{arena::Handle, TypeInner};
3
4impl BlockContext<'_> {
5    pub(super) fn write_subgroup_ballot(
6        &mut self,
7        predicate: &Option<Handle<crate::Expression>>,
8        result: Handle<crate::Expression>,
9        block: &mut Block,
10    ) -> Result<(), Error> {
11        self.writer.require_any(
12            "GroupNonUniformBallot",
13            &[spirv::Capability::GroupNonUniformBallot],
14        )?;
15        let vec4_u32_type_id = self.get_numeric_type_id(NumericType::Vector {
16            size: crate::VectorSize::Quad,
17            scalar: crate::Scalar::U32,
18        });
19        let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
20        let predicate = if let Some(predicate) = *predicate {
21            self.cached[predicate]
22        } else {
23            self.writer.get_constant_scalar(crate::Literal::Bool(true))
24        };
25        let id = self.gen_id();
26        block.body.push(Instruction::group_non_uniform_ballot(
27            vec4_u32_type_id,
28            id,
29            exec_scope_id,
30            predicate,
31        ));
32        self.cached[result] = id;
33        Ok(())
34    }
35    pub(super) fn write_subgroup_operation(
36        &mut self,
37        op: &crate::SubgroupOperation,
38        collective_op: &crate::CollectiveOperation,
39        argument: Handle<crate::Expression>,
40        result: Handle<crate::Expression>,
41        block: &mut Block,
42    ) -> Result<(), Error> {
43        use crate::SubgroupOperation as sg;
44        match *op {
45            sg::All | sg::Any => {
46                self.writer.require_any(
47                    "GroupNonUniformVote",
48                    &[spirv::Capability::GroupNonUniformVote],
49                )?;
50            }
51            _ => {
52                self.writer.require_any(
53                    "GroupNonUniformArithmetic",
54                    &[spirv::Capability::GroupNonUniformArithmetic],
55                )?;
56            }
57        }
58
59        let id = self.gen_id();
60        let result_ty = &self.fun_info[result].ty;
61        let result_type_id = self.get_expression_type_id(result_ty);
62        let result_ty_inner = result_ty.inner_with(&self.ir_module.types);
63
64        let (is_scalar, scalar) = match *result_ty_inner {
65            TypeInner::Scalar(kind) => (true, kind),
66            TypeInner::Vector { scalar: kind, .. } => (false, kind),
67            _ => unimplemented!(),
68        };
69
70        use crate::ScalarKind as sk;
71        let spirv_op = match (scalar.kind, *op) {
72            (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll,
73            (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny,
74            (_, sg::All | sg::Any) => unimplemented!(),
75
76            (sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd,
77            (sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd,
78            (sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul,
79            (sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul,
80            (sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax,
81            (sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax,
82            (sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax,
83            (sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin,
84            (sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin,
85            (sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin,
86            (_, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(),
87
88            (sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd,
89            (sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr,
90            (sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor,
91            (sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd,
92            (sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr,
93            (sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor,
94            (_, sg::And | sg::Or | sg::Xor) => unimplemented!(),
95        };
96
97        let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
98
99        use crate::CollectiveOperation as c;
100        let group_op = match *op {
101            sg::All | sg::Any => None,
102            _ => Some(match *collective_op {
103                c::Reduce => spirv::GroupOperation::Reduce,
104                c::InclusiveScan => spirv::GroupOperation::InclusiveScan,
105                c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan,
106            }),
107        };
108
109        let arg_id = self.cached[argument];
110        block.body.push(Instruction::group_non_uniform_arithmetic(
111            spirv_op,
112            result_type_id,
113            id,
114            exec_scope_id,
115            group_op,
116            arg_id,
117        ));
118        self.cached[result] = id;
119        Ok(())
120    }
121    pub(super) fn write_subgroup_gather(
122        &mut self,
123        mode: &crate::GatherMode,
124        argument: Handle<crate::Expression>,
125        result: Handle<crate::Expression>,
126        block: &mut Block,
127    ) -> Result<(), Error> {
128        match *mode {
129            crate::GatherMode::BroadcastFirst => {
130                self.writer.require_any(
131                    "GroupNonUniformBallot",
132                    &[spirv::Capability::GroupNonUniformBallot],
133                )?;
134            }
135            crate::GatherMode::Shuffle(_)
136            | crate::GatherMode::ShuffleXor(_)
137            | crate::GatherMode::Broadcast(_) => {
138                self.writer.require_any(
139                    "GroupNonUniformShuffle",
140                    &[spirv::Capability::GroupNonUniformShuffle],
141                )?;
142            }
143            crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => {
144                self.writer.require_any(
145                    "GroupNonUniformShuffleRelative",
146                    &[spirv::Capability::GroupNonUniformShuffleRelative],
147                )?;
148            }
149            crate::GatherMode::QuadBroadcast(_) | crate::GatherMode::QuadSwap(_) => {
150                self.writer.require_any(
151                    "GroupNonUniformQuad",
152                    &[spirv::Capability::GroupNonUniformQuad],
153                )?;
154            }
155        }
156
157        let id = self.gen_id();
158        let result_ty = &self.fun_info[result].ty;
159        let result_type_id = self.get_expression_type_id(result_ty);
160
161        let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
162
163        let arg_id = self.cached[argument];
164        match *mode {
165            crate::GatherMode::BroadcastFirst => {
166                block
167                    .body
168                    .push(Instruction::group_non_uniform_broadcast_first(
169                        result_type_id,
170                        id,
171                        exec_scope_id,
172                        arg_id,
173                    ));
174            }
175            crate::GatherMode::Broadcast(index)
176            | crate::GatherMode::Shuffle(index)
177            | crate::GatherMode::ShuffleDown(index)
178            | crate::GatherMode::ShuffleUp(index)
179            | crate::GatherMode::ShuffleXor(index)
180            | crate::GatherMode::QuadBroadcast(index) => {
181                let index_id = self.cached[index];
182                let op = match *mode {
183                    crate::GatherMode::BroadcastFirst => unreachable!(),
184                    // Use shuffle to emit broadcast to allow the index to
185                    // be dynamically uniform on Vulkan 1.1. The argument to
186                    // OpGroupNonUniformBroadcast must be a constant pre SPIR-V
187                    // 1.5 (vulkan 1.2)
188                    crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformShuffle,
189                    crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle,
190                    crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown,
191                    crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp,
192                    crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor,
193                    crate::GatherMode::QuadBroadcast(_) => spirv::Op::GroupNonUniformQuadBroadcast,
194                    crate::GatherMode::QuadSwap(_) => unreachable!(),
195                };
196                block.body.push(Instruction::group_non_uniform_gather(
197                    op,
198                    result_type_id,
199                    id,
200                    exec_scope_id,
201                    arg_id,
202                    index_id,
203                ));
204            }
205            crate::GatherMode::QuadSwap(direction) => {
206                let direction = self.get_index_constant(match direction {
207                    crate::Direction::X => 0,
208                    crate::Direction::Y => 1,
209                    crate::Direction::Diagonal => 2,
210                });
211                block.body.push(Instruction::group_non_uniform_quad_swap(
212                    result_type_id,
213                    id,
214                    exec_scope_id,
215                    arg_id,
216                    direction,
217                ));
218            }
219        }
220        self.cached[result] = id;
221        Ok(())
222    }
223}