naga/front/wgsl/parse/conv.rs

1use crate::front::wgsl::parse::directive::enable_extension::{
2    EnableExtensions, ImplementedEnableExtension,
3};
4use crate::front::wgsl::{Error, Result, Scalar};
5use crate::Span;
6
7use alloc::boxed::Box;
8
9pub fn map_address_space<'a>(
10    word: &str,
11    span: Span,
12    enable_extensions: &EnableExtensions,
13) -> Result<'a, crate::AddressSpace> {
14    match word {
15        "private" => Ok(crate::AddressSpace::Private),
16        "workgroup" => Ok(crate::AddressSpace::WorkGroup),
17        "uniform" => Ok(crate::AddressSpace::Uniform),
18        "storage" => Ok(crate::AddressSpace::Storage {
19            access: crate::StorageAccess::default(),
20        }),
21        "immediate" => Ok(crate::AddressSpace::Immediate),
22        "function" => Ok(crate::AddressSpace::Function),
23        "task_payload" => {
24            if enable_extensions.contains(ImplementedEnableExtension::WgpuMeshShader) {
25                Ok(crate::AddressSpace::TaskPayload)
26            } else {
27                Err(Box::new(Error::EnableExtensionNotEnabled {
28                    span,
29                    kind: ImplementedEnableExtension::WgpuMeshShader.into(),
30                }))
31            }
32        }
33        _ => Err(Box::new(Error::UnknownAddressSpace(span))),
34    }
35}
36
37pub fn map_built_in(
38    enable_extensions: &EnableExtensions,
39    word: &str,
40    span: Span,
41) -> Result<'static, crate::BuiltIn> {
42    let built_in = match word {
43        "position" => crate::BuiltIn::Position { invariant: false },
44        // vertex
45        "vertex_index" => crate::BuiltIn::VertexIndex,
46        "instance_index" => crate::BuiltIn::InstanceIndex,
47        "view_index" => crate::BuiltIn::ViewIndex,
48        "clip_distances" => crate::BuiltIn::ClipDistance,
49        // fragment
50        "front_facing" => crate::BuiltIn::FrontFacing,
51        "frag_depth" => crate::BuiltIn::FragDepth,
52        "primitive_index" => crate::BuiltIn::PrimitiveIndex,
53        "barycentric" => crate::BuiltIn::Barycentric,
54        "sample_index" => crate::BuiltIn::SampleIndex,
55        "sample_mask" => crate::BuiltIn::SampleMask,
56        // compute
57        "global_invocation_id" => crate::BuiltIn::GlobalInvocationId,
58        "local_invocation_id" => crate::BuiltIn::LocalInvocationId,
59        "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
60        "workgroup_id" => crate::BuiltIn::WorkGroupId,
61        "num_workgroups" => crate::BuiltIn::NumWorkGroups,
62        // subgroup
63        "num_subgroups" => crate::BuiltIn::NumSubgroups,
64        "subgroup_id" => crate::BuiltIn::SubgroupId,
65        "subgroup_size" => crate::BuiltIn::SubgroupSize,
66        "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId,
67        // mesh
68        "cull_primitive" => crate::BuiltIn::CullPrimitive,
69        "point_index" => crate::BuiltIn::PointIndex,
70        "line_indices" => crate::BuiltIn::LineIndices,
71        "triangle_indices" => crate::BuiltIn::TriangleIndices,
72        "mesh_task_size" => crate::BuiltIn::MeshTaskSize,
73        // mesh global variable
74        "vertex_count" => crate::BuiltIn::VertexCount,
75        "vertices" => crate::BuiltIn::Vertices,
76        "primitive_count" => crate::BuiltIn::PrimitiveCount,
77        "primitives" => crate::BuiltIn::Primitives,
78        _ => return Err(Box::new(Error::UnknownBuiltin(span))),
79    };
80    match built_in {
81        crate::BuiltIn::ClipDistance => {
82            if !enable_extensions.contains(ImplementedEnableExtension::ClipDistances) {
83                return Err(Box::new(Error::EnableExtensionNotEnabled {
84                    span,
85                    kind: ImplementedEnableExtension::ClipDistances.into(),
86                }));
87            }
88        }
89        crate::BuiltIn::CullPrimitive
90        | crate::BuiltIn::PointIndex
91        | crate::BuiltIn::LineIndices
92        | crate::BuiltIn::TriangleIndices
93        | crate::BuiltIn::VertexCount
94        | crate::BuiltIn::Vertices
95        | crate::BuiltIn::PrimitiveCount
96        | crate::BuiltIn::Primitives => {
97            if !enable_extensions.contains(ImplementedEnableExtension::WgpuMeshShader) {
98                return Err(Box::new(Error::EnableExtensionNotEnabled {
99                    span,
100                    kind: ImplementedEnableExtension::WgpuMeshShader.into(),
101                }));
102            }
103        }
104        _ => {}
105    }
106    Ok(built_in)
107}
108
109pub fn map_interpolation(word: &str, span: Span) -> Result<'_, crate::Interpolation> {
110    match word {
111        "linear" => Ok(crate::Interpolation::Linear),
112        "flat" => Ok(crate::Interpolation::Flat),
113        "perspective" => Ok(crate::Interpolation::Perspective),
114        _ => Err(Box::new(Error::UnknownAttribute(span))),
115    }
116}
117
118pub fn map_sampling(word: &str, span: Span) -> Result<'_, crate::Sampling> {
119    match word {
120        "center" => Ok(crate::Sampling::Center),
121        "centroid" => Ok(crate::Sampling::Centroid),
122        "sample" => Ok(crate::Sampling::Sample),
123        "first" => Ok(crate::Sampling::First),
124        "either" => Ok(crate::Sampling::Either),
125        _ => Err(Box::new(Error::UnknownAttribute(span))),
126    }
127}
128
129pub fn map_storage_format(word: &str, span: Span) -> Result<'_, crate::StorageFormat> {
130    use crate::StorageFormat as Sf;
131    Ok(match word {
132        "r8unorm" => Sf::R8Unorm,
133        "r8snorm" => Sf::R8Snorm,
134        "r8uint" => Sf::R8Uint,
135        "r8sint" => Sf::R8Sint,
136        "r16unorm" => Sf::R16Unorm,
137        "r16snorm" => Sf::R16Snorm,
138        "r16uint" => Sf::R16Uint,
139        "r16sint" => Sf::R16Sint,
140        "r16float" => Sf::R16Float,
141        "rg8unorm" => Sf::Rg8Unorm,
142        "rg8snorm" => Sf::Rg8Snorm,
143        "rg8uint" => Sf::Rg8Uint,
144        "rg8sint" => Sf::Rg8Sint,
145        "r32uint" => Sf::R32Uint,
146        "r32sint" => Sf::R32Sint,
147        "r32float" => Sf::R32Float,
148        "rg16unorm" => Sf::Rg16Unorm,
149        "rg16snorm" => Sf::Rg16Snorm,
150        "rg16uint" => Sf::Rg16Uint,
151        "rg16sint" => Sf::Rg16Sint,
152        "rg16float" => Sf::Rg16Float,
153        "rgba8unorm" => Sf::Rgba8Unorm,
154        "rgba8snorm" => Sf::Rgba8Snorm,
155        "rgba8uint" => Sf::Rgba8Uint,
156        "rgba8sint" => Sf::Rgba8Sint,
157        "rgb10a2uint" => Sf::Rgb10a2Uint,
158        "rgb10a2unorm" => Sf::Rgb10a2Unorm,
159        "rg11b10ufloat" => Sf::Rg11b10Ufloat,
160        "r64uint" => Sf::R64Uint,
161        "rg32uint" => Sf::Rg32Uint,
162        "rg32sint" => Sf::Rg32Sint,
163        "rg32float" => Sf::Rg32Float,
164        "rgba16unorm" => Sf::Rgba16Unorm,
165        "rgba16snorm" => Sf::Rgba16Snorm,
166        "rgba16uint" => Sf::Rgba16Uint,
167        "rgba16sint" => Sf::Rgba16Sint,
168        "rgba16float" => Sf::Rgba16Float,
169        "rgba32uint" => Sf::Rgba32Uint,
170        "rgba32sint" => Sf::Rgba32Sint,
171        "rgba32float" => Sf::Rgba32Float,
172        "bgra8unorm" => Sf::Bgra8Unorm,
173        _ => return Err(Box::new(Error::UnknownStorageFormat(span))),
174    })
175}
176
177pub fn get_scalar_type(
178    enable_extensions: &EnableExtensions,
179    span: Span,
180    word: &str,
181) -> Result<'static, Option<Scalar>> {
182    use crate::ScalarKind as Sk;
183    let scalar = match word {
184        "f16" => Some(Scalar {
185            kind: Sk::Float,
186            width: 2,
187        }),
188        "f32" => Some(Scalar {
189            kind: Sk::Float,
190            width: 4,
191        }),
192        "f64" => Some(Scalar {
193            kind: Sk::Float,
194            width: 8,
195        }),
196        "i32" => Some(Scalar {
197            kind: Sk::Sint,
198            width: 4,
199        }),
200        "u32" => Some(Scalar {
201            kind: Sk::Uint,
202            width: 4,
203        }),
204        "i64" => Some(Scalar {
205            kind: Sk::Sint,
206            width: 8,
207        }),
208        "u64" => Some(Scalar {
209            kind: Sk::Uint,
210            width: 8,
211        }),
212        "bool" => Some(Scalar {
213            kind: Sk::Bool,
214            width: crate::BOOL_WIDTH,
215        }),
216        _ => None,
217    };
218
219    if matches!(scalar, Some(Scalar::F16))
220        && !enable_extensions.contains(ImplementedEnableExtension::F16)
221    {
222        return Err(Box::new(Error::EnableExtensionNotEnabled {
223            span,
224            kind: ImplementedEnableExtension::F16.into(),
225        }));
226    }
227
228    Ok(scalar)
229}
230
231pub fn map_derivative(word: &str) -> Option<(crate::DerivativeAxis, crate::DerivativeControl)> {
232    use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
233    match word {
234        "dpdxCoarse" => Some((Axis::X, Ctrl::Coarse)),
235        "dpdyCoarse" => Some((Axis::Y, Ctrl::Coarse)),
236        "fwidthCoarse" => Some((Axis::Width, Ctrl::Coarse)),
237        "dpdxFine" => Some((Axis::X, Ctrl::Fine)),
238        "dpdyFine" => Some((Axis::Y, Ctrl::Fine)),
239        "fwidthFine" => Some((Axis::Width, Ctrl::Fine)),
240        "dpdx" => Some((Axis::X, Ctrl::None)),
241        "dpdy" => Some((Axis::Y, Ctrl::None)),
242        "fwidth" => Some((Axis::Width, Ctrl::None)),
243        _ => None,
244    }
245}
246
247pub fn map_relational_fun(word: &str) -> Option<crate::RelationalFunction> {
248    match word {
249        "any" => Some(crate::RelationalFunction::Any),
250        "all" => Some(crate::RelationalFunction::All),
251        _ => None,
252    }
253}
254
255pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
256    use crate::MathFunction as Mf;
257    Some(match word {
258        // comparison
259        "abs" => Mf::Abs,
260        "min" => Mf::Min,
261        "max" => Mf::Max,
262        "clamp" => Mf::Clamp,
263        "saturate" => Mf::Saturate,
264        // trigonometry
265        "cos" => Mf::Cos,
266        "cosh" => Mf::Cosh,
267        "sin" => Mf::Sin,
268        "sinh" => Mf::Sinh,
269        "tan" => Mf::Tan,
270        "tanh" => Mf::Tanh,
271        "acos" => Mf::Acos,
272        "acosh" => Mf::Acosh,
273        "asin" => Mf::Asin,
274        "asinh" => Mf::Asinh,
275        "atan" => Mf::Atan,
276        "atanh" => Mf::Atanh,
277        "atan2" => Mf::Atan2,
278        "radians" => Mf::Radians,
279        "degrees" => Mf::Degrees,
280        // decomposition
281        "ceil" => Mf::Ceil,
282        "floor" => Mf::Floor,
283        "round" => Mf::Round,
284        "fract" => Mf::Fract,
285        "trunc" => Mf::Trunc,
286        "modf" => Mf::Modf,
287        "frexp" => Mf::Frexp,
288        "ldexp" => Mf::Ldexp,
289        // exponent
290        "exp" => Mf::Exp,
291        "exp2" => Mf::Exp2,
292        "log" => Mf::Log,
293        "log2" => Mf::Log2,
294        "pow" => Mf::Pow,
295        // geometry
296        "dot" => Mf::Dot,
297        "dot4I8Packed" => Mf::Dot4I8Packed,
298        "dot4U8Packed" => Mf::Dot4U8Packed,
299        "cross" => Mf::Cross,
300        "distance" => Mf::Distance,
301        "length" => Mf::Length,
302        "normalize" => Mf::Normalize,
303        "faceForward" => Mf::FaceForward,
304        "reflect" => Mf::Reflect,
305        "refract" => Mf::Refract,
306        // computational
307        "sign" => Mf::Sign,
308        "fma" => Mf::Fma,
309        "mix" => Mf::Mix,
310        "step" => Mf::Step,
311        "smoothstep" => Mf::SmoothStep,
312        "sqrt" => Mf::Sqrt,
313        "inverseSqrt" => Mf::InverseSqrt,
314        "transpose" => Mf::Transpose,
315        "determinant" => Mf::Determinant,
316        "quantizeToF16" => Mf::QuantizeToF16,
317        // bits
318        "countTrailingZeros" => Mf::CountTrailingZeros,
319        "countLeadingZeros" => Mf::CountLeadingZeros,
320        "countOneBits" => Mf::CountOneBits,
321        "reverseBits" => Mf::ReverseBits,
322        "extractBits" => Mf::ExtractBits,
323        "insertBits" => Mf::InsertBits,
324        "firstTrailingBit" => Mf::FirstTrailingBit,
325        "firstLeadingBit" => Mf::FirstLeadingBit,
326        // data packing
327        "pack4x8snorm" => Mf::Pack4x8snorm,
328        "pack4x8unorm" => Mf::Pack4x8unorm,
329        "pack2x16snorm" => Mf::Pack2x16snorm,
330        "pack2x16unorm" => Mf::Pack2x16unorm,
331        "pack2x16float" => Mf::Pack2x16float,
332        "pack4xI8" => Mf::Pack4xI8,
333        "pack4xU8" => Mf::Pack4xU8,
334        "pack4xI8Clamp" => Mf::Pack4xI8Clamp,
335        "pack4xU8Clamp" => Mf::Pack4xU8Clamp,
336        // data unpacking
337        "unpack4x8snorm" => Mf::Unpack4x8snorm,
338        "unpack4x8unorm" => Mf::Unpack4x8unorm,
339        "unpack2x16snorm" => Mf::Unpack2x16snorm,
340        "unpack2x16unorm" => Mf::Unpack2x16unorm,
341        "unpack2x16float" => Mf::Unpack2x16float,
342        "unpack4xI8" => Mf::Unpack4xI8,
343        "unpack4xU8" => Mf::Unpack4xU8,
344        _ => return None,
345    })
346}
347
348pub fn map_conservative_depth(word: &str, span: Span) -> Result<'_, crate::ConservativeDepth> {
349    use crate::ConservativeDepth as Cd;
350    match word {
351        "greater_equal" => Ok(Cd::GreaterEqual),
352        "less_equal" => Ok(Cd::LessEqual),
353        "unchanged" => Ok(Cd::Unchanged),
354        _ => Err(Box::new(Error::UnknownConservativeDepth(span))),
355    }
356}
357
358pub fn map_subgroup_operation(
359    word: &str,
360) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> {
361    use crate::CollectiveOperation as co;
362    use crate::SubgroupOperation as sg;
363    Some(match word {
364        "subgroupAll" => (sg::All, co::Reduce),
365        "subgroupAny" => (sg::Any, co::Reduce),
366        "subgroupAdd" => (sg::Add, co::Reduce),
367        "subgroupMul" => (sg::Mul, co::Reduce),
368        "subgroupMin" => (sg::Min, co::Reduce),
369        "subgroupMax" => (sg::Max, co::Reduce),
370        "subgroupAnd" => (sg::And, co::Reduce),
371        "subgroupOr" => (sg::Or, co::Reduce),
372        "subgroupXor" => (sg::Xor, co::Reduce),
373        "subgroupExclusiveAdd" => (sg::Add, co::ExclusiveScan),
374        "subgroupExclusiveMul" => (sg::Mul, co::ExclusiveScan),
375        "subgroupInclusiveAdd" => (sg::Add, co::InclusiveScan),
376        "subgroupInclusiveMul" => (sg::Mul, co::InclusiveScan),
377        _ => return None,
378    })
379}