naga/front/wgsl/parse/
conv.rs

1use crate::front::wgsl::parse::directive::enable_extension::{
2    EnableExtensions, ImplementedEnableExtension,
3};
4use crate::front::wgsl::{Error, Result, Scalar};
5use crate::Span;
6
7use alloc::boxed::Box;
8
9pub fn map_address_space(word: &str, span: Span) -> Result<'_, crate::AddressSpace> {
10    match word {
11        "private" => Ok(crate::AddressSpace::Private),
12        "workgroup" => Ok(crate::AddressSpace::WorkGroup),
13        "uniform" => Ok(crate::AddressSpace::Uniform),
14        "storage" => Ok(crate::AddressSpace::Storage {
15            access: crate::StorageAccess::default(),
16        }),
17        "push_constant" => Ok(crate::AddressSpace::PushConstant),
18        "function" => Ok(crate::AddressSpace::Function),
19        _ => Err(Box::new(Error::UnknownAddressSpace(span))),
20    }
21}
22
23pub fn map_built_in(
24    enable_extensions: &EnableExtensions,
25    word: &str,
26    span: Span,
27) -> Result<'static, crate::BuiltIn> {
28    let built_in = match word {
29        "position" => crate::BuiltIn::Position { invariant: false },
30        // vertex
31        "vertex_index" => crate::BuiltIn::VertexIndex,
32        "instance_index" => crate::BuiltIn::InstanceIndex,
33        "view_index" => crate::BuiltIn::ViewIndex,
34        "clip_distances" => crate::BuiltIn::ClipDistance,
35        // fragment
36        "front_facing" => crate::BuiltIn::FrontFacing,
37        "frag_depth" => crate::BuiltIn::FragDepth,
38        "primitive_index" => crate::BuiltIn::PrimitiveIndex,
39        "sample_index" => crate::BuiltIn::SampleIndex,
40        "sample_mask" => crate::BuiltIn::SampleMask,
41        // compute
42        "global_invocation_id" => crate::BuiltIn::GlobalInvocationId,
43        "local_invocation_id" => crate::BuiltIn::LocalInvocationId,
44        "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
45        "workgroup_id" => crate::BuiltIn::WorkGroupId,
46        "num_workgroups" => crate::BuiltIn::NumWorkGroups,
47        // subgroup
48        "num_subgroups" => crate::BuiltIn::NumSubgroups,
49        "subgroup_id" => crate::BuiltIn::SubgroupId,
50        "subgroup_size" => crate::BuiltIn::SubgroupSize,
51        "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId,
52        _ => return Err(Box::new(Error::UnknownBuiltin(span))),
53    };
54    match built_in {
55        crate::BuiltIn::ClipDistance => {
56            if !enable_extensions.contains(ImplementedEnableExtension::ClipDistances) {
57                return Err(Box::new(Error::EnableExtensionNotEnabled {
58                    span,
59                    kind: ImplementedEnableExtension::ClipDistances.into(),
60                }));
61            }
62        }
63        _ => {}
64    }
65    Ok(built_in)
66}
67
68pub fn map_interpolation(word: &str, span: Span) -> Result<'_, crate::Interpolation> {
69    match word {
70        "linear" => Ok(crate::Interpolation::Linear),
71        "flat" => Ok(crate::Interpolation::Flat),
72        "perspective" => Ok(crate::Interpolation::Perspective),
73        _ => Err(Box::new(Error::UnknownAttribute(span))),
74    }
75}
76
77pub fn map_sampling(word: &str, span: Span) -> Result<'_, crate::Sampling> {
78    match word {
79        "center" => Ok(crate::Sampling::Center),
80        "centroid" => Ok(crate::Sampling::Centroid),
81        "sample" => Ok(crate::Sampling::Sample),
82        "first" => Ok(crate::Sampling::First),
83        "either" => Ok(crate::Sampling::Either),
84        _ => Err(Box::new(Error::UnknownAttribute(span))),
85    }
86}
87
88pub fn map_storage_format(word: &str, span: Span) -> Result<'_, crate::StorageFormat> {
89    use crate::StorageFormat as Sf;
90    Ok(match word {
91        "r8unorm" => Sf::R8Unorm,
92        "r8snorm" => Sf::R8Snorm,
93        "r8uint" => Sf::R8Uint,
94        "r8sint" => Sf::R8Sint,
95        "r16unorm" => Sf::R16Unorm,
96        "r16snorm" => Sf::R16Snorm,
97        "r16uint" => Sf::R16Uint,
98        "r16sint" => Sf::R16Sint,
99        "r16float" => Sf::R16Float,
100        "rg8unorm" => Sf::Rg8Unorm,
101        "rg8snorm" => Sf::Rg8Snorm,
102        "rg8uint" => Sf::Rg8Uint,
103        "rg8sint" => Sf::Rg8Sint,
104        "r32uint" => Sf::R32Uint,
105        "r32sint" => Sf::R32Sint,
106        "r32float" => Sf::R32Float,
107        "rg16unorm" => Sf::Rg16Unorm,
108        "rg16snorm" => Sf::Rg16Snorm,
109        "rg16uint" => Sf::Rg16Uint,
110        "rg16sint" => Sf::Rg16Sint,
111        "rg16float" => Sf::Rg16Float,
112        "rgba8unorm" => Sf::Rgba8Unorm,
113        "rgba8snorm" => Sf::Rgba8Snorm,
114        "rgba8uint" => Sf::Rgba8Uint,
115        "rgba8sint" => Sf::Rgba8Sint,
116        "rgb10a2uint" => Sf::Rgb10a2Uint,
117        "rgb10a2unorm" => Sf::Rgb10a2Unorm,
118        "rg11b10float" => Sf::Rg11b10Ufloat,
119        "r64uint" => Sf::R64Uint,
120        "rg32uint" => Sf::Rg32Uint,
121        "rg32sint" => Sf::Rg32Sint,
122        "rg32float" => Sf::Rg32Float,
123        "rgba16unorm" => Sf::Rgba16Unorm,
124        "rgba16snorm" => Sf::Rgba16Snorm,
125        "rgba16uint" => Sf::Rgba16Uint,
126        "rgba16sint" => Sf::Rgba16Sint,
127        "rgba16float" => Sf::Rgba16Float,
128        "rgba32uint" => Sf::Rgba32Uint,
129        "rgba32sint" => Sf::Rgba32Sint,
130        "rgba32float" => Sf::Rgba32Float,
131        "bgra8unorm" => Sf::Bgra8Unorm,
132        _ => return Err(Box::new(Error::UnknownStorageFormat(span))),
133    })
134}
135
136pub fn get_scalar_type(
137    enable_extensions: &EnableExtensions,
138    span: Span,
139    word: &str,
140) -> Result<'static, Option<Scalar>> {
141    use crate::ScalarKind as Sk;
142    let scalar = match word {
143        "f16" => Some(Scalar {
144            kind: Sk::Float,
145            width: 2,
146        }),
147        "f32" => Some(Scalar {
148            kind: Sk::Float,
149            width: 4,
150        }),
151        "f64" => Some(Scalar {
152            kind: Sk::Float,
153            width: 8,
154        }),
155        "i32" => Some(Scalar {
156            kind: Sk::Sint,
157            width: 4,
158        }),
159        "u32" => Some(Scalar {
160            kind: Sk::Uint,
161            width: 4,
162        }),
163        "i64" => Some(Scalar {
164            kind: Sk::Sint,
165            width: 8,
166        }),
167        "u64" => Some(Scalar {
168            kind: Sk::Uint,
169            width: 8,
170        }),
171        "bool" => Some(Scalar {
172            kind: Sk::Bool,
173            width: crate::BOOL_WIDTH,
174        }),
175        _ => None,
176    };
177
178    if matches!(scalar, Some(Scalar::F16))
179        && !enable_extensions.contains(ImplementedEnableExtension::F16)
180    {
181        return Err(Box::new(Error::EnableExtensionNotEnabled {
182            span,
183            kind: ImplementedEnableExtension::F16.into(),
184        }));
185    }
186
187    Ok(scalar)
188}
189
190pub fn map_derivative(word: &str) -> Option<(crate::DerivativeAxis, crate::DerivativeControl)> {
191    use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
192    match word {
193        "dpdxCoarse" => Some((Axis::X, Ctrl::Coarse)),
194        "dpdyCoarse" => Some((Axis::Y, Ctrl::Coarse)),
195        "fwidthCoarse" => Some((Axis::Width, Ctrl::Coarse)),
196        "dpdxFine" => Some((Axis::X, Ctrl::Fine)),
197        "dpdyFine" => Some((Axis::Y, Ctrl::Fine)),
198        "fwidthFine" => Some((Axis::Width, Ctrl::Fine)),
199        "dpdx" => Some((Axis::X, Ctrl::None)),
200        "dpdy" => Some((Axis::Y, Ctrl::None)),
201        "fwidth" => Some((Axis::Width, Ctrl::None)),
202        _ => None,
203    }
204}
205
206pub fn map_relational_fun(word: &str) -> Option<crate::RelationalFunction> {
207    match word {
208        "any" => Some(crate::RelationalFunction::Any),
209        "all" => Some(crate::RelationalFunction::All),
210        _ => None,
211    }
212}
213
214pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
215    use crate::MathFunction as Mf;
216    Some(match word {
217        // comparison
218        "abs" => Mf::Abs,
219        "min" => Mf::Min,
220        "max" => Mf::Max,
221        "clamp" => Mf::Clamp,
222        "saturate" => Mf::Saturate,
223        // trigonometry
224        "cos" => Mf::Cos,
225        "cosh" => Mf::Cosh,
226        "sin" => Mf::Sin,
227        "sinh" => Mf::Sinh,
228        "tan" => Mf::Tan,
229        "tanh" => Mf::Tanh,
230        "acos" => Mf::Acos,
231        "acosh" => Mf::Acosh,
232        "asin" => Mf::Asin,
233        "asinh" => Mf::Asinh,
234        "atan" => Mf::Atan,
235        "atanh" => Mf::Atanh,
236        "atan2" => Mf::Atan2,
237        "radians" => Mf::Radians,
238        "degrees" => Mf::Degrees,
239        // decomposition
240        "ceil" => Mf::Ceil,
241        "floor" => Mf::Floor,
242        "round" => Mf::Round,
243        "fract" => Mf::Fract,
244        "trunc" => Mf::Trunc,
245        "modf" => Mf::Modf,
246        "frexp" => Mf::Frexp,
247        "ldexp" => Mf::Ldexp,
248        // exponent
249        "exp" => Mf::Exp,
250        "exp2" => Mf::Exp2,
251        "log" => Mf::Log,
252        "log2" => Mf::Log2,
253        "pow" => Mf::Pow,
254        // geometry
255        "dot" => Mf::Dot,
256        "dot4I8Packed" => Mf::Dot4I8Packed,
257        "dot4U8Packed" => Mf::Dot4U8Packed,
258        "cross" => Mf::Cross,
259        "distance" => Mf::Distance,
260        "length" => Mf::Length,
261        "normalize" => Mf::Normalize,
262        "faceForward" => Mf::FaceForward,
263        "reflect" => Mf::Reflect,
264        "refract" => Mf::Refract,
265        // computational
266        "sign" => Mf::Sign,
267        "fma" => Mf::Fma,
268        "mix" => Mf::Mix,
269        "step" => Mf::Step,
270        "smoothstep" => Mf::SmoothStep,
271        "sqrt" => Mf::Sqrt,
272        "inverseSqrt" => Mf::InverseSqrt,
273        "transpose" => Mf::Transpose,
274        "determinant" => Mf::Determinant,
275        "quantizeToF16" => Mf::QuantizeToF16,
276        // bits
277        "countTrailingZeros" => Mf::CountTrailingZeros,
278        "countLeadingZeros" => Mf::CountLeadingZeros,
279        "countOneBits" => Mf::CountOneBits,
280        "reverseBits" => Mf::ReverseBits,
281        "extractBits" => Mf::ExtractBits,
282        "insertBits" => Mf::InsertBits,
283        "firstTrailingBit" => Mf::FirstTrailingBit,
284        "firstLeadingBit" => Mf::FirstLeadingBit,
285        // data packing
286        "pack4x8snorm" => Mf::Pack4x8snorm,
287        "pack4x8unorm" => Mf::Pack4x8unorm,
288        "pack2x16snorm" => Mf::Pack2x16snorm,
289        "pack2x16unorm" => Mf::Pack2x16unorm,
290        "pack2x16float" => Mf::Pack2x16float,
291        "pack4xI8" => Mf::Pack4xI8,
292        "pack4xU8" => Mf::Pack4xU8,
293        "pack4xI8Clamp" => Mf::Pack4xI8Clamp,
294        "pack4xU8Clamp" => Mf::Pack4xU8Clamp,
295        // data unpacking
296        "unpack4x8snorm" => Mf::Unpack4x8snorm,
297        "unpack4x8unorm" => Mf::Unpack4x8unorm,
298        "unpack2x16snorm" => Mf::Unpack2x16snorm,
299        "unpack2x16unorm" => Mf::Unpack2x16unorm,
300        "unpack2x16float" => Mf::Unpack2x16float,
301        "unpack4xI8" => Mf::Unpack4xI8,
302        "unpack4xU8" => Mf::Unpack4xU8,
303        _ => return None,
304    })
305}
306
307pub fn map_conservative_depth(word: &str, span: Span) -> Result<'_, crate::ConservativeDepth> {
308    use crate::ConservativeDepth as Cd;
309    match word {
310        "greater_equal" => Ok(Cd::GreaterEqual),
311        "less_equal" => Ok(Cd::LessEqual),
312        "unchanged" => Ok(Cd::Unchanged),
313        _ => Err(Box::new(Error::UnknownConservativeDepth(span))),
314    }
315}
316
317pub fn map_subgroup_operation(
318    word: &str,
319) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> {
320    use crate::CollectiveOperation as co;
321    use crate::SubgroupOperation as sg;
322    Some(match word {
323        "subgroupAll" => (sg::All, co::Reduce),
324        "subgroupAny" => (sg::Any, co::Reduce),
325        "subgroupAdd" => (sg::Add, co::Reduce),
326        "subgroupMul" => (sg::Mul, co::Reduce),
327        "subgroupMin" => (sg::Min, co::Reduce),
328        "subgroupMax" => (sg::Max, co::Reduce),
329        "subgroupAnd" => (sg::And, co::Reduce),
330        "subgroupOr" => (sg::Or, co::Reduce),
331        "subgroupXor" => (sg::Xor, co::Reduce),
332        "subgroupExclusiveAdd" => (sg::Add, co::ExclusiveScan),
333        "subgroupExclusiveMul" => (sg::Mul, co::ExclusiveScan),
334        "subgroupInclusiveAdd" => (sg::Add, co::InclusiveScan),
335        "subgroupInclusiveMul" => (sg::Mul, co::InclusiveScan),
336        _ => return None,
337    })
338}