naga/front/wgsl/parse/
conv.rs

1use crate::front::wgsl::parse::directive::enable_extension::{
2    EnableExtensions, ImplementedEnableExtension,
3};
4use crate::front::wgsl::{Error, Result, Scalar};
5use crate::Span;
6
7use alloc::boxed::Box;
8
9pub fn map_address_space<'a>(
10    word: &str,
11    span: Span,
12    enable_extensions: &EnableExtensions,
13) -> Result<'a, crate::AddressSpace> {
14    match word {
15        "private" => Ok(crate::AddressSpace::Private),
16        "workgroup" => Ok(crate::AddressSpace::WorkGroup),
17        "uniform" => Ok(crate::AddressSpace::Uniform),
18        "storage" => Ok(crate::AddressSpace::Storage {
19            access: crate::StorageAccess::default(),
20        }),
21        "immediate" => Ok(crate::AddressSpace::Immediate),
22        "function" => Ok(crate::AddressSpace::Function),
23        "task_payload" => {
24            if enable_extensions.contains(ImplementedEnableExtension::WgpuMeshShader) {
25                Ok(crate::AddressSpace::TaskPayload)
26            } else {
27                Err(Box::new(Error::EnableExtensionNotEnabled {
28                    span,
29                    kind: ImplementedEnableExtension::WgpuMeshShader.into(),
30                }))
31            }
32        }
33        _ => Err(Box::new(Error::UnknownAddressSpace(span))),
34    }
35}
36
37pub fn map_built_in(
38    enable_extensions: &EnableExtensions,
39    word: &str,
40    span: Span,
41) -> Result<'static, crate::BuiltIn> {
42    let built_in = match word {
43        "position" => crate::BuiltIn::Position { invariant: false },
44        // vertex
45        "vertex_index" => crate::BuiltIn::VertexIndex,
46        "instance_index" => crate::BuiltIn::InstanceIndex,
47        "view_index" => crate::BuiltIn::ViewIndex,
48        "clip_distances" => crate::BuiltIn::ClipDistance,
49        // fragment
50        "front_facing" => crate::BuiltIn::FrontFacing,
51        "frag_depth" => crate::BuiltIn::FragDepth,
52        "primitive_index" => crate::BuiltIn::PrimitiveIndex,
53        "barycentric" => crate::BuiltIn::Barycentric { perspective: true },
54        "barycentric_no_perspective" => crate::BuiltIn::Barycentric { perspective: false },
55        "sample_index" => crate::BuiltIn::SampleIndex,
56        "sample_mask" => crate::BuiltIn::SampleMask,
57        // compute
58        "global_invocation_id" => crate::BuiltIn::GlobalInvocationId,
59        "local_invocation_id" => crate::BuiltIn::LocalInvocationId,
60        "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
61        "workgroup_id" => crate::BuiltIn::WorkGroupId,
62        "num_workgroups" => crate::BuiltIn::NumWorkGroups,
63        // subgroup
64        "num_subgroups" => crate::BuiltIn::NumSubgroups,
65        "subgroup_id" => crate::BuiltIn::SubgroupId,
66        "subgroup_size" => crate::BuiltIn::SubgroupSize,
67        "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId,
68        // mesh
69        "cull_primitive" => crate::BuiltIn::CullPrimitive,
70        "point_index" => crate::BuiltIn::PointIndex,
71        "line_indices" => crate::BuiltIn::LineIndices,
72        "triangle_indices" => crate::BuiltIn::TriangleIndices,
73        "mesh_task_size" => crate::BuiltIn::MeshTaskSize,
74        // mesh global variable
75        "vertex_count" => crate::BuiltIn::VertexCount,
76        "vertices" => crate::BuiltIn::Vertices,
77        "primitive_count" => crate::BuiltIn::PrimitiveCount,
78        "primitives" => crate::BuiltIn::Primitives,
79        _ => return Err(Box::new(Error::UnknownBuiltin(span))),
80    };
81    match built_in {
82        crate::BuiltIn::ClipDistance => {
83            if !enable_extensions.contains(ImplementedEnableExtension::ClipDistances) {
84                return Err(Box::new(Error::EnableExtensionNotEnabled {
85                    span,
86                    kind: ImplementedEnableExtension::ClipDistances.into(),
87                }));
88            }
89        }
90        crate::BuiltIn::CullPrimitive
91        | crate::BuiltIn::PointIndex
92        | crate::BuiltIn::LineIndices
93        | crate::BuiltIn::TriangleIndices
94        | crate::BuiltIn::VertexCount
95        | crate::BuiltIn::Vertices
96        | crate::BuiltIn::PrimitiveCount
97        | crate::BuiltIn::Primitives => {
98            if !enable_extensions.contains(ImplementedEnableExtension::WgpuMeshShader) {
99                return Err(Box::new(Error::EnableExtensionNotEnabled {
100                    span,
101                    kind: ImplementedEnableExtension::WgpuMeshShader.into(),
102                }));
103            }
104        }
105        _ => {}
106    }
107    Ok(built_in)
108}
109
110pub fn map_interpolation(word: &str, span: Span) -> Result<'_, crate::Interpolation> {
111    match word {
112        "linear" => Ok(crate::Interpolation::Linear),
113        "flat" => Ok(crate::Interpolation::Flat),
114        "perspective" => Ok(crate::Interpolation::Perspective),
115        "per_vertex" => Ok(crate::Interpolation::PerVertex),
116        _ => Err(Box::new(Error::UnknownAttribute(span))),
117    }
118}
119
120pub fn map_sampling(word: &str, span: Span) -> Result<'_, crate::Sampling> {
121    match word {
122        "center" => Ok(crate::Sampling::Center),
123        "centroid" => Ok(crate::Sampling::Centroid),
124        "sample" => Ok(crate::Sampling::Sample),
125        "first" => Ok(crate::Sampling::First),
126        "either" => Ok(crate::Sampling::Either),
127        _ => Err(Box::new(Error::UnknownAttribute(span))),
128    }
129}
130
131pub fn map_storage_format(word: &str, span: Span) -> Result<'_, crate::StorageFormat> {
132    use crate::StorageFormat as Sf;
133    Ok(match word {
134        "r8unorm" => Sf::R8Unorm,
135        "r8snorm" => Sf::R8Snorm,
136        "r8uint" => Sf::R8Uint,
137        "r8sint" => Sf::R8Sint,
138        "r16unorm" => Sf::R16Unorm,
139        "r16snorm" => Sf::R16Snorm,
140        "r16uint" => Sf::R16Uint,
141        "r16sint" => Sf::R16Sint,
142        "r16float" => Sf::R16Float,
143        "rg8unorm" => Sf::Rg8Unorm,
144        "rg8snorm" => Sf::Rg8Snorm,
145        "rg8uint" => Sf::Rg8Uint,
146        "rg8sint" => Sf::Rg8Sint,
147        "r32uint" => Sf::R32Uint,
148        "r32sint" => Sf::R32Sint,
149        "r32float" => Sf::R32Float,
150        "rg16unorm" => Sf::Rg16Unorm,
151        "rg16snorm" => Sf::Rg16Snorm,
152        "rg16uint" => Sf::Rg16Uint,
153        "rg16sint" => Sf::Rg16Sint,
154        "rg16float" => Sf::Rg16Float,
155        "rgba8unorm" => Sf::Rgba8Unorm,
156        "rgba8snorm" => Sf::Rgba8Snorm,
157        "rgba8uint" => Sf::Rgba8Uint,
158        "rgba8sint" => Sf::Rgba8Sint,
159        "rgb10a2uint" => Sf::Rgb10a2Uint,
160        "rgb10a2unorm" => Sf::Rgb10a2Unorm,
161        "rg11b10ufloat" => Sf::Rg11b10Ufloat,
162        "r64uint" => Sf::R64Uint,
163        "rg32uint" => Sf::Rg32Uint,
164        "rg32sint" => Sf::Rg32Sint,
165        "rg32float" => Sf::Rg32Float,
166        "rgba16unorm" => Sf::Rgba16Unorm,
167        "rgba16snorm" => Sf::Rgba16Snorm,
168        "rgba16uint" => Sf::Rgba16Uint,
169        "rgba16sint" => Sf::Rgba16Sint,
170        "rgba16float" => Sf::Rgba16Float,
171        "rgba32uint" => Sf::Rgba32Uint,
172        "rgba32sint" => Sf::Rgba32Sint,
173        "rgba32float" => Sf::Rgba32Float,
174        "bgra8unorm" => Sf::Bgra8Unorm,
175        _ => return Err(Box::new(Error::UnknownStorageFormat(span))),
176    })
177}
178
179pub fn get_scalar_type(
180    enable_extensions: &EnableExtensions,
181    span: Span,
182    word: &str,
183) -> Result<'static, Option<Scalar>> {
184    use crate::ScalarKind as Sk;
185    let scalar = match word {
186        "f16" => Some(Scalar {
187            kind: Sk::Float,
188            width: 2,
189        }),
190        "f32" => Some(Scalar {
191            kind: Sk::Float,
192            width: 4,
193        }),
194        "f64" => Some(Scalar {
195            kind: Sk::Float,
196            width: 8,
197        }),
198        "i32" => Some(Scalar {
199            kind: Sk::Sint,
200            width: 4,
201        }),
202        "u32" => Some(Scalar {
203            kind: Sk::Uint,
204            width: 4,
205        }),
206        "i64" => Some(Scalar {
207            kind: Sk::Sint,
208            width: 8,
209        }),
210        "u64" => Some(Scalar {
211            kind: Sk::Uint,
212            width: 8,
213        }),
214        "bool" => Some(Scalar {
215            kind: Sk::Bool,
216            width: crate::BOOL_WIDTH,
217        }),
218        _ => None,
219    };
220
221    if matches!(scalar, Some(Scalar::F16))
222        && !enable_extensions.contains(ImplementedEnableExtension::F16)
223    {
224        return Err(Box::new(Error::EnableExtensionNotEnabled {
225            span,
226            kind: ImplementedEnableExtension::F16.into(),
227        }));
228    }
229
230    Ok(scalar)
231}
232
233pub fn map_derivative(word: &str) -> Option<(crate::DerivativeAxis, crate::DerivativeControl)> {
234    use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
235    match word {
236        "dpdxCoarse" => Some((Axis::X, Ctrl::Coarse)),
237        "dpdyCoarse" => Some((Axis::Y, Ctrl::Coarse)),
238        "fwidthCoarse" => Some((Axis::Width, Ctrl::Coarse)),
239        "dpdxFine" => Some((Axis::X, Ctrl::Fine)),
240        "dpdyFine" => Some((Axis::Y, Ctrl::Fine)),
241        "fwidthFine" => Some((Axis::Width, Ctrl::Fine)),
242        "dpdx" => Some((Axis::X, Ctrl::None)),
243        "dpdy" => Some((Axis::Y, Ctrl::None)),
244        "fwidth" => Some((Axis::Width, Ctrl::None)),
245        _ => None,
246    }
247}
248
249pub fn map_relational_fun(word: &str) -> Option<crate::RelationalFunction> {
250    match word {
251        "any" => Some(crate::RelationalFunction::Any),
252        "all" => Some(crate::RelationalFunction::All),
253        _ => None,
254    }
255}
256
257pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
258    use crate::MathFunction as Mf;
259    Some(match word {
260        // comparison
261        "abs" => Mf::Abs,
262        "min" => Mf::Min,
263        "max" => Mf::Max,
264        "clamp" => Mf::Clamp,
265        "saturate" => Mf::Saturate,
266        // trigonometry
267        "cos" => Mf::Cos,
268        "cosh" => Mf::Cosh,
269        "sin" => Mf::Sin,
270        "sinh" => Mf::Sinh,
271        "tan" => Mf::Tan,
272        "tanh" => Mf::Tanh,
273        "acos" => Mf::Acos,
274        "acosh" => Mf::Acosh,
275        "asin" => Mf::Asin,
276        "asinh" => Mf::Asinh,
277        "atan" => Mf::Atan,
278        "atanh" => Mf::Atanh,
279        "atan2" => Mf::Atan2,
280        "radians" => Mf::Radians,
281        "degrees" => Mf::Degrees,
282        // decomposition
283        "ceil" => Mf::Ceil,
284        "floor" => Mf::Floor,
285        "round" => Mf::Round,
286        "fract" => Mf::Fract,
287        "trunc" => Mf::Trunc,
288        "modf" => Mf::Modf,
289        "frexp" => Mf::Frexp,
290        "ldexp" => Mf::Ldexp,
291        // exponent
292        "exp" => Mf::Exp,
293        "exp2" => Mf::Exp2,
294        "log" => Mf::Log,
295        "log2" => Mf::Log2,
296        "pow" => Mf::Pow,
297        // geometry
298        "dot" => Mf::Dot,
299        "dot4I8Packed" => Mf::Dot4I8Packed,
300        "dot4U8Packed" => Mf::Dot4U8Packed,
301        "cross" => Mf::Cross,
302        "distance" => Mf::Distance,
303        "length" => Mf::Length,
304        "normalize" => Mf::Normalize,
305        "faceForward" => Mf::FaceForward,
306        "reflect" => Mf::Reflect,
307        "refract" => Mf::Refract,
308        // computational
309        "sign" => Mf::Sign,
310        "fma" => Mf::Fma,
311        "mix" => Mf::Mix,
312        "step" => Mf::Step,
313        "smoothstep" => Mf::SmoothStep,
314        "sqrt" => Mf::Sqrt,
315        "inverseSqrt" => Mf::InverseSqrt,
316        "transpose" => Mf::Transpose,
317        "determinant" => Mf::Determinant,
318        "quantizeToF16" => Mf::QuantizeToF16,
319        // bits
320        "countTrailingZeros" => Mf::CountTrailingZeros,
321        "countLeadingZeros" => Mf::CountLeadingZeros,
322        "countOneBits" => Mf::CountOneBits,
323        "reverseBits" => Mf::ReverseBits,
324        "extractBits" => Mf::ExtractBits,
325        "insertBits" => Mf::InsertBits,
326        "firstTrailingBit" => Mf::FirstTrailingBit,
327        "firstLeadingBit" => Mf::FirstLeadingBit,
328        // data packing
329        "pack4x8snorm" => Mf::Pack4x8snorm,
330        "pack4x8unorm" => Mf::Pack4x8unorm,
331        "pack2x16snorm" => Mf::Pack2x16snorm,
332        "pack2x16unorm" => Mf::Pack2x16unorm,
333        "pack2x16float" => Mf::Pack2x16float,
334        "pack4xI8" => Mf::Pack4xI8,
335        "pack4xU8" => Mf::Pack4xU8,
336        "pack4xI8Clamp" => Mf::Pack4xI8Clamp,
337        "pack4xU8Clamp" => Mf::Pack4xU8Clamp,
338        // data unpacking
339        "unpack4x8snorm" => Mf::Unpack4x8snorm,
340        "unpack4x8unorm" => Mf::Unpack4x8unorm,
341        "unpack2x16snorm" => Mf::Unpack2x16snorm,
342        "unpack2x16unorm" => Mf::Unpack2x16unorm,
343        "unpack2x16float" => Mf::Unpack2x16float,
344        "unpack4xI8" => Mf::Unpack4xI8,
345        "unpack4xU8" => Mf::Unpack4xU8,
346        _ => return None,
347    })
348}
349
350pub fn map_conservative_depth(word: &str, span: Span) -> Result<'_, crate::ConservativeDepth> {
351    use crate::ConservativeDepth as Cd;
352    match word {
353        "greater_equal" => Ok(Cd::GreaterEqual),
354        "less_equal" => Ok(Cd::LessEqual),
355        "unchanged" => Ok(Cd::Unchanged),
356        _ => Err(Box::new(Error::UnknownConservativeDepth(span))),
357    }
358}
359
360pub fn map_subgroup_operation(
361    word: &str,
362) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> {
363    use crate::CollectiveOperation as co;
364    use crate::SubgroupOperation as sg;
365    Some(match word {
366        "subgroupAll" => (sg::All, co::Reduce),
367        "subgroupAny" => (sg::Any, co::Reduce),
368        "subgroupAdd" => (sg::Add, co::Reduce),
369        "subgroupMul" => (sg::Mul, co::Reduce),
370        "subgroupMin" => (sg::Min, co::Reduce),
371        "subgroupMax" => (sg::Max, co::Reduce),
372        "subgroupAnd" => (sg::And, co::Reduce),
373        "subgroupOr" => (sg::Or, co::Reduce),
374        "subgroupXor" => (sg::Xor, co::Reduce),
375        "subgroupExclusiveAdd" => (sg::Add, co::ExclusiveScan),
376        "subgroupExclusiveMul" => (sg::Mul, co::ExclusiveScan),
377        "subgroupInclusiveAdd" => (sg::Add, co::InclusiveScan),
378        "subgroupInclusiveMul" => (sg::Mul, co::InclusiveScan),
379        _ => return None,
380    })
381}