1use crate::front::wgsl::parse::directive::enable_extension::{
2 EnableExtensions, ImplementedEnableExtension,
3};
4use crate::front::wgsl::{Error, Result, Scalar};
5use crate::Span;
6
7use alloc::boxed::Box;
8
9pub fn map_address_space<'a>(
10 word: &str,
11 span: Span,
12 enable_extensions: &EnableExtensions,
13) -> Result<'a, crate::AddressSpace> {
14 match word {
15 "private" => Ok(crate::AddressSpace::Private),
16 "workgroup" => Ok(crate::AddressSpace::WorkGroup),
17 "uniform" => Ok(crate::AddressSpace::Uniform),
18 "storage" => Ok(crate::AddressSpace::Storage {
19 access: crate::StorageAccess::default(),
20 }),
21 "immediate" => Ok(crate::AddressSpace::Immediate),
22 "function" => Ok(crate::AddressSpace::Function),
23 "task_payload" => {
24 if enable_extensions.contains(ImplementedEnableExtension::WgpuMeshShader) {
25 Ok(crate::AddressSpace::TaskPayload)
26 } else {
27 Err(Box::new(Error::EnableExtensionNotEnabled {
28 span,
29 kind: ImplementedEnableExtension::WgpuMeshShader.into(),
30 }))
31 }
32 }
33 _ => Err(Box::new(Error::UnknownAddressSpace(span))),
34 }
35}
36
37pub fn map_built_in(
38 enable_extensions: &EnableExtensions,
39 word: &str,
40 span: Span,
41) -> Result<'static, crate::BuiltIn> {
42 let built_in = match word {
43 "position" => crate::BuiltIn::Position { invariant: false },
44 "vertex_index" => crate::BuiltIn::VertexIndex,
46 "instance_index" => crate::BuiltIn::InstanceIndex,
47 "view_index" => crate::BuiltIn::ViewIndex,
48 "clip_distances" => crate::BuiltIn::ClipDistance,
49 "front_facing" => crate::BuiltIn::FrontFacing,
51 "frag_depth" => crate::BuiltIn::FragDepth,
52 "primitive_index" => crate::BuiltIn::PrimitiveIndex,
53 "barycentric" => crate::BuiltIn::Barycentric,
54 "sample_index" => crate::BuiltIn::SampleIndex,
55 "sample_mask" => crate::BuiltIn::SampleMask,
56 "global_invocation_id" => crate::BuiltIn::GlobalInvocationId,
58 "local_invocation_id" => crate::BuiltIn::LocalInvocationId,
59 "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
60 "workgroup_id" => crate::BuiltIn::WorkGroupId,
61 "num_workgroups" => crate::BuiltIn::NumWorkGroups,
62 "num_subgroups" => crate::BuiltIn::NumSubgroups,
64 "subgroup_id" => crate::BuiltIn::SubgroupId,
65 "subgroup_size" => crate::BuiltIn::SubgroupSize,
66 "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId,
67 "cull_primitive" => crate::BuiltIn::CullPrimitive,
69 "point_index" => crate::BuiltIn::PointIndex,
70 "line_indices" => crate::BuiltIn::LineIndices,
71 "triangle_indices" => crate::BuiltIn::TriangleIndices,
72 "mesh_task_size" => crate::BuiltIn::MeshTaskSize,
73 "vertex_count" => crate::BuiltIn::VertexCount,
75 "vertices" => crate::BuiltIn::Vertices,
76 "primitive_count" => crate::BuiltIn::PrimitiveCount,
77 "primitives" => crate::BuiltIn::Primitives,
78 _ => return Err(Box::new(Error::UnknownBuiltin(span))),
79 };
80 match built_in {
81 crate::BuiltIn::ClipDistance => {
82 if !enable_extensions.contains(ImplementedEnableExtension::ClipDistances) {
83 return Err(Box::new(Error::EnableExtensionNotEnabled {
84 span,
85 kind: ImplementedEnableExtension::ClipDistances.into(),
86 }));
87 }
88 }
89 crate::BuiltIn::CullPrimitive
90 | crate::BuiltIn::PointIndex
91 | crate::BuiltIn::LineIndices
92 | crate::BuiltIn::TriangleIndices
93 | crate::BuiltIn::VertexCount
94 | crate::BuiltIn::Vertices
95 | crate::BuiltIn::PrimitiveCount
96 | crate::BuiltIn::Primitives => {
97 if !enable_extensions.contains(ImplementedEnableExtension::WgpuMeshShader) {
98 return Err(Box::new(Error::EnableExtensionNotEnabled {
99 span,
100 kind: ImplementedEnableExtension::WgpuMeshShader.into(),
101 }));
102 }
103 }
104 _ => {}
105 }
106 Ok(built_in)
107}
108
109pub fn map_interpolation(word: &str, span: Span) -> Result<'_, crate::Interpolation> {
110 match word {
111 "linear" => Ok(crate::Interpolation::Linear),
112 "flat" => Ok(crate::Interpolation::Flat),
113 "perspective" => Ok(crate::Interpolation::Perspective),
114 _ => Err(Box::new(Error::UnknownAttribute(span))),
115 }
116}
117
118pub fn map_sampling(word: &str, span: Span) -> Result<'_, crate::Sampling> {
119 match word {
120 "center" => Ok(crate::Sampling::Center),
121 "centroid" => Ok(crate::Sampling::Centroid),
122 "sample" => Ok(crate::Sampling::Sample),
123 "first" => Ok(crate::Sampling::First),
124 "either" => Ok(crate::Sampling::Either),
125 _ => Err(Box::new(Error::UnknownAttribute(span))),
126 }
127}
128
129pub fn map_storage_format(word: &str, span: Span) -> Result<'_, crate::StorageFormat> {
130 use crate::StorageFormat as Sf;
131 Ok(match word {
132 "r8unorm" => Sf::R8Unorm,
133 "r8snorm" => Sf::R8Snorm,
134 "r8uint" => Sf::R8Uint,
135 "r8sint" => Sf::R8Sint,
136 "r16unorm" => Sf::R16Unorm,
137 "r16snorm" => Sf::R16Snorm,
138 "r16uint" => Sf::R16Uint,
139 "r16sint" => Sf::R16Sint,
140 "r16float" => Sf::R16Float,
141 "rg8unorm" => Sf::Rg8Unorm,
142 "rg8snorm" => Sf::Rg8Snorm,
143 "rg8uint" => Sf::Rg8Uint,
144 "rg8sint" => Sf::Rg8Sint,
145 "r32uint" => Sf::R32Uint,
146 "r32sint" => Sf::R32Sint,
147 "r32float" => Sf::R32Float,
148 "rg16unorm" => Sf::Rg16Unorm,
149 "rg16snorm" => Sf::Rg16Snorm,
150 "rg16uint" => Sf::Rg16Uint,
151 "rg16sint" => Sf::Rg16Sint,
152 "rg16float" => Sf::Rg16Float,
153 "rgba8unorm" => Sf::Rgba8Unorm,
154 "rgba8snorm" => Sf::Rgba8Snorm,
155 "rgba8uint" => Sf::Rgba8Uint,
156 "rgba8sint" => Sf::Rgba8Sint,
157 "rgb10a2uint" => Sf::Rgb10a2Uint,
158 "rgb10a2unorm" => Sf::Rgb10a2Unorm,
159 "rg11b10ufloat" => Sf::Rg11b10Ufloat,
160 "r64uint" => Sf::R64Uint,
161 "rg32uint" => Sf::Rg32Uint,
162 "rg32sint" => Sf::Rg32Sint,
163 "rg32float" => Sf::Rg32Float,
164 "rgba16unorm" => Sf::Rgba16Unorm,
165 "rgba16snorm" => Sf::Rgba16Snorm,
166 "rgba16uint" => Sf::Rgba16Uint,
167 "rgba16sint" => Sf::Rgba16Sint,
168 "rgba16float" => Sf::Rgba16Float,
169 "rgba32uint" => Sf::Rgba32Uint,
170 "rgba32sint" => Sf::Rgba32Sint,
171 "rgba32float" => Sf::Rgba32Float,
172 "bgra8unorm" => Sf::Bgra8Unorm,
173 _ => return Err(Box::new(Error::UnknownStorageFormat(span))),
174 })
175}
176
177pub fn get_scalar_type(
178 enable_extensions: &EnableExtensions,
179 span: Span,
180 word: &str,
181) -> Result<'static, Option<Scalar>> {
182 use crate::ScalarKind as Sk;
183 let scalar = match word {
184 "f16" => Some(Scalar {
185 kind: Sk::Float,
186 width: 2,
187 }),
188 "f32" => Some(Scalar {
189 kind: Sk::Float,
190 width: 4,
191 }),
192 "f64" => Some(Scalar {
193 kind: Sk::Float,
194 width: 8,
195 }),
196 "i32" => Some(Scalar {
197 kind: Sk::Sint,
198 width: 4,
199 }),
200 "u32" => Some(Scalar {
201 kind: Sk::Uint,
202 width: 4,
203 }),
204 "i64" => Some(Scalar {
205 kind: Sk::Sint,
206 width: 8,
207 }),
208 "u64" => Some(Scalar {
209 kind: Sk::Uint,
210 width: 8,
211 }),
212 "bool" => Some(Scalar {
213 kind: Sk::Bool,
214 width: crate::BOOL_WIDTH,
215 }),
216 _ => None,
217 };
218
219 if matches!(scalar, Some(Scalar::F16))
220 && !enable_extensions.contains(ImplementedEnableExtension::F16)
221 {
222 return Err(Box::new(Error::EnableExtensionNotEnabled {
223 span,
224 kind: ImplementedEnableExtension::F16.into(),
225 }));
226 }
227
228 Ok(scalar)
229}
230
231pub fn map_derivative(word: &str) -> Option<(crate::DerivativeAxis, crate::DerivativeControl)> {
232 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
233 match word {
234 "dpdxCoarse" => Some((Axis::X, Ctrl::Coarse)),
235 "dpdyCoarse" => Some((Axis::Y, Ctrl::Coarse)),
236 "fwidthCoarse" => Some((Axis::Width, Ctrl::Coarse)),
237 "dpdxFine" => Some((Axis::X, Ctrl::Fine)),
238 "dpdyFine" => Some((Axis::Y, Ctrl::Fine)),
239 "fwidthFine" => Some((Axis::Width, Ctrl::Fine)),
240 "dpdx" => Some((Axis::X, Ctrl::None)),
241 "dpdy" => Some((Axis::Y, Ctrl::None)),
242 "fwidth" => Some((Axis::Width, Ctrl::None)),
243 _ => None,
244 }
245}
246
247pub fn map_relational_fun(word: &str) -> Option<crate::RelationalFunction> {
248 match word {
249 "any" => Some(crate::RelationalFunction::Any),
250 "all" => Some(crate::RelationalFunction::All),
251 _ => None,
252 }
253}
254
255pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
256 use crate::MathFunction as Mf;
257 Some(match word {
258 "abs" => Mf::Abs,
260 "min" => Mf::Min,
261 "max" => Mf::Max,
262 "clamp" => Mf::Clamp,
263 "saturate" => Mf::Saturate,
264 "cos" => Mf::Cos,
266 "cosh" => Mf::Cosh,
267 "sin" => Mf::Sin,
268 "sinh" => Mf::Sinh,
269 "tan" => Mf::Tan,
270 "tanh" => Mf::Tanh,
271 "acos" => Mf::Acos,
272 "acosh" => Mf::Acosh,
273 "asin" => Mf::Asin,
274 "asinh" => Mf::Asinh,
275 "atan" => Mf::Atan,
276 "atanh" => Mf::Atanh,
277 "atan2" => Mf::Atan2,
278 "radians" => Mf::Radians,
279 "degrees" => Mf::Degrees,
280 "ceil" => Mf::Ceil,
282 "floor" => Mf::Floor,
283 "round" => Mf::Round,
284 "fract" => Mf::Fract,
285 "trunc" => Mf::Trunc,
286 "modf" => Mf::Modf,
287 "frexp" => Mf::Frexp,
288 "ldexp" => Mf::Ldexp,
289 "exp" => Mf::Exp,
291 "exp2" => Mf::Exp2,
292 "log" => Mf::Log,
293 "log2" => Mf::Log2,
294 "pow" => Mf::Pow,
295 "dot" => Mf::Dot,
297 "dot4I8Packed" => Mf::Dot4I8Packed,
298 "dot4U8Packed" => Mf::Dot4U8Packed,
299 "cross" => Mf::Cross,
300 "distance" => Mf::Distance,
301 "length" => Mf::Length,
302 "normalize" => Mf::Normalize,
303 "faceForward" => Mf::FaceForward,
304 "reflect" => Mf::Reflect,
305 "refract" => Mf::Refract,
306 "sign" => Mf::Sign,
308 "fma" => Mf::Fma,
309 "mix" => Mf::Mix,
310 "step" => Mf::Step,
311 "smoothstep" => Mf::SmoothStep,
312 "sqrt" => Mf::Sqrt,
313 "inverseSqrt" => Mf::InverseSqrt,
314 "transpose" => Mf::Transpose,
315 "determinant" => Mf::Determinant,
316 "quantizeToF16" => Mf::QuantizeToF16,
317 "countTrailingZeros" => Mf::CountTrailingZeros,
319 "countLeadingZeros" => Mf::CountLeadingZeros,
320 "countOneBits" => Mf::CountOneBits,
321 "reverseBits" => Mf::ReverseBits,
322 "extractBits" => Mf::ExtractBits,
323 "insertBits" => Mf::InsertBits,
324 "firstTrailingBit" => Mf::FirstTrailingBit,
325 "firstLeadingBit" => Mf::FirstLeadingBit,
326 "pack4x8snorm" => Mf::Pack4x8snorm,
328 "pack4x8unorm" => Mf::Pack4x8unorm,
329 "pack2x16snorm" => Mf::Pack2x16snorm,
330 "pack2x16unorm" => Mf::Pack2x16unorm,
331 "pack2x16float" => Mf::Pack2x16float,
332 "pack4xI8" => Mf::Pack4xI8,
333 "pack4xU8" => Mf::Pack4xU8,
334 "pack4xI8Clamp" => Mf::Pack4xI8Clamp,
335 "pack4xU8Clamp" => Mf::Pack4xU8Clamp,
336 "unpack4x8snorm" => Mf::Unpack4x8snorm,
338 "unpack4x8unorm" => Mf::Unpack4x8unorm,
339 "unpack2x16snorm" => Mf::Unpack2x16snorm,
340 "unpack2x16unorm" => Mf::Unpack2x16unorm,
341 "unpack2x16float" => Mf::Unpack2x16float,
342 "unpack4xI8" => Mf::Unpack4xI8,
343 "unpack4xU8" => Mf::Unpack4xU8,
344 _ => return None,
345 })
346}
347
348pub fn map_conservative_depth(word: &str, span: Span) -> Result<'_, crate::ConservativeDepth> {
349 use crate::ConservativeDepth as Cd;
350 match word {
351 "greater_equal" => Ok(Cd::GreaterEqual),
352 "less_equal" => Ok(Cd::LessEqual),
353 "unchanged" => Ok(Cd::Unchanged),
354 _ => Err(Box::new(Error::UnknownConservativeDepth(span))),
355 }
356}
357
358pub fn map_subgroup_operation(
359 word: &str,
360) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> {
361 use crate::CollectiveOperation as co;
362 use crate::SubgroupOperation as sg;
363 Some(match word {
364 "subgroupAll" => (sg::All, co::Reduce),
365 "subgroupAny" => (sg::Any, co::Reduce),
366 "subgroupAdd" => (sg::Add, co::Reduce),
367 "subgroupMul" => (sg::Mul, co::Reduce),
368 "subgroupMin" => (sg::Min, co::Reduce),
369 "subgroupMax" => (sg::Max, co::Reduce),
370 "subgroupAnd" => (sg::And, co::Reduce),
371 "subgroupOr" => (sg::Or, co::Reduce),
372 "subgroupXor" => (sg::Xor, co::Reduce),
373 "subgroupExclusiveAdd" => (sg::Add, co::ExclusiveScan),
374 "subgroupExclusiveMul" => (sg::Mul, co::ExclusiveScan),
375 "subgroupInclusiveAdd" => (sg::Add, co::InclusiveScan),
376 "subgroupInclusiveMul" => (sg::Mul, co::InclusiveScan),
377 _ => return None,
378 })
379}