naga/front/wgsl/parse/
number.rs

1use alloc::format;
2
3use crate::front::wgsl::error::NumberError;
4use crate::front::wgsl::parse::directive::enable_extension::ImplementedEnableExtension;
5use crate::front::wgsl::parse::lexer::Token;
6use half::f16;
7
8/// When using this type assume no Abstract Int/Float for now
9#[derive(Copy, Clone, Debug, PartialEq)]
10pub enum Number {
11    /// Abstract Int (-2^63 ≤ i < 2^63)
12    AbstractInt(i64),
13    /// Abstract Float (IEEE-754 binary64)
14    AbstractFloat(f64),
15    /// Concrete i32
16    I32(i32),
17    /// Concrete u32
18    U32(u32),
19    /// Concrete i64
20    I64(i64),
21    /// Concrete u64
22    U64(u64),
23    /// Concrete f16
24    F16(f16),
25    /// Concrete f32
26    F32(f32),
27    /// Concrete f64
28    F64(f64),
29}
30
31impl Number {
32    pub(super) const fn requires_enable_extension(&self) -> Option<ImplementedEnableExtension> {
33        match *self {
34            Number::F16(_) => Some(ImplementedEnableExtension::F16),
35            _ => None,
36        }
37    }
38}
39
40pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) {
41    let (result, rest) = parse(input);
42    (Token::Number(result), rest)
43}
44
45enum Kind {
46    Int(IntKind),
47    Float(FloatKind),
48}
49
/// Concrete integer type selected by a literal's suffix.
///
/// Derives the usual traits for a fieldless tag enum so it can be copied,
/// compared and printed like its sibling `FloatKind`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum IntKind {
    /// Suffix `i`.
    I32,
    /// Suffix `u`.
    U32,
    /// Suffix `li`.
    I64,
    /// Suffix `lu`.
    U64,
}
56
/// Concrete floating-point type selected by a literal's suffix.
///
/// Besides `Debug`, the enum derives `Clone`, `Copy`, `PartialEq` and `Eq`,
/// which are free for a fieldless tag enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FloatKind {
    /// Suffix `h`.
    F16,
    /// Suffix `f`.
    F32,
    /// Suffix `lf`.
    F64,
}
63
64// The following regexes (from the WGSL spec) will be matched:
65
66// int_literal:
67// | / 0                                                                [iu]?   /
68// | / [1-9][0-9]*                                                      [iu]?   /
69// | / 0[xX][0-9a-fA-F]+                                                [iu]?   /
70
71// decimal_float_literal:
72// | / 0                                                                [fh]    /
73// | / [1-9][0-9]*                                                      [fh]    /
74// | / [0-9]*               \.[0-9]+            ([eE][+-]?[0-9]+)?      [fh]?   /
75// | / [0-9]+               \.[0-9]*            ([eE][+-]?[0-9]+)?      [fh]?   /
76// | / [0-9]+                                    [eE][+-]?[0-9]+        [fh]?   /
77
78// hex_float_literal:
79// | / 0[xX][0-9a-fA-F]*    \.[0-9a-fA-F]+      ([pP][+-]?[0-9]+        [fh]?)? /
80// | / 0[xX][0-9a-fA-F]+    \.[0-9a-fA-F]*      ([pP][+-]?[0-9]+        [fh]?)? /
81// | / 0[xX][0-9a-fA-F]+                         [pP][+-]?[0-9]+        [fh]?   /
82
83// You could visualize the regex below via https://debuggex.com to get a rough idea what `parse` is doing
84// (?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?))
85
86// Leading signs are handled as unary operators.
87
88fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
89    /// returns `true` and consumes `X` bytes from the given byte buffer
90    /// if the given `X` nr of patterns are found at the start of the buffer
91    macro_rules! consume {
92        ($bytes:ident, $($pattern:pat),*) => {
93            match $bytes {
94                &[$($pattern),*, ref rest @ ..] => { $bytes = rest; true },
95                _ => false,
96            }
97        };
98    }
99
100    /// consumes one byte from the given byte buffer
101    /// if one of the given patterns are found at the start of the buffer
102    /// returning the corresponding expr for the matched pattern
103    macro_rules! consume_map {
104        ($bytes:ident, [$( $($pattern:pat_param),* => $to:expr),* $(,)?]) => {
105            match $bytes {
106                $( &[ $($pattern),*, ref rest @ ..] => { $bytes = rest; Some($to) }, )*
107                _ => None,
108            }
109        };
110    }
111
112    /// consumes all consecutive bytes matched by the `0-9` pattern from the given byte buffer
113    /// returning the number of consumed bytes
114    macro_rules! consume_dec_digits {
115        ($bytes:ident) => {{
116            let start_len = $bytes.len();
117            while let &[b'0'..=b'9', ref rest @ ..] = $bytes {
118                $bytes = rest;
119            }
120            start_len - $bytes.len()
121        }};
122    }
123
124    /// consumes all consecutive bytes matched by the `0-9 | a-f | A-F` pattern from the given byte buffer
125    /// returning the number of consumed bytes
126    macro_rules! consume_hex_digits {
127        ($bytes:ident) => {{
128            let start_len = $bytes.len();
129            while let &[b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F', ref rest @ ..] = $bytes {
130                $bytes = rest;
131            }
132            start_len - $bytes.len()
133        }};
134    }
135
136    macro_rules! consume_float_suffix {
137        ($bytes:ident) => {
138            consume_map!($bytes, [
139                b'h' => FloatKind::F16,
140                b'f' => FloatKind::F32,
141                b'l', b'f' => FloatKind::F64,
142            ])
143        };
144    }
145
146    /// maps the given `&[u8]` (tail of the initial `input: &str`) to a `&str`
147    macro_rules! rest_to_str {
148        ($bytes:ident) => {
149            &input[input.len() - $bytes.len()..]
150        };
151    }
152
153    struct ExtractSubStr<'a>(&'a str);
154
155    impl<'a> ExtractSubStr<'a> {
156        /// given an `input` and a `start` (tail of the `input`)
157        /// creates a new [`ExtractSubStr`](`Self`)
158        fn start(input: &'a str, start: &'a [u8]) -> Self {
159            let start = input.len() - start.len();
160            Self(&input[start..])
161        }
162        /// given an `end` (tail of the initial `input`)
163        /// returns a substring of `input`
164        fn end(&self, end: &'a [u8]) -> &'a str {
165            let end = self.0.len() - end.len();
166            &self.0[..end]
167        }
168    }
169
170    let mut bytes = input.as_bytes();
171
172    let general_extract = ExtractSubStr::start(input, bytes);
173
174    if consume!(bytes, b'0', b'x' | b'X') {
175        let digits_extract = ExtractSubStr::start(input, bytes);
176
177        let consumed = consume_hex_digits!(bytes);
178
179        if consume!(bytes, b'.') {
180            let consumed_after_period = consume_hex_digits!(bytes);
181
182            if consumed + consumed_after_period == 0 {
183                return (Err(NumberError::Invalid), rest_to_str!(bytes));
184            }
185
186            let significand = general_extract.end(bytes);
187
188            if consume!(bytes, b'p' | b'P') {
189                consume!(bytes, b'+' | b'-');
190                let consumed = consume_dec_digits!(bytes);
191
192                if consumed == 0 {
193                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
194                }
195
196                let number = general_extract.end(bytes);
197
198                let kind = consume_float_suffix!(bytes);
199
200                (parse_hex_float(number, kind), rest_to_str!(bytes))
201            } else {
202                (
203                    parse_hex_float_missing_exponent(significand, None),
204                    rest_to_str!(bytes),
205                )
206            }
207        } else {
208            if consumed == 0 {
209                return (Err(NumberError::Invalid), rest_to_str!(bytes));
210            }
211
212            let significand = general_extract.end(bytes);
213            let digits = digits_extract.end(bytes);
214
215            let exp_extract = ExtractSubStr::start(input, bytes);
216
217            if consume!(bytes, b'p' | b'P') {
218                consume!(bytes, b'+' | b'-');
219                let consumed = consume_dec_digits!(bytes);
220
221                if consumed == 0 {
222                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
223                }
224
225                let exponent = exp_extract.end(bytes);
226
227                let kind = consume_float_suffix!(bytes);
228
229                (
230                    parse_hex_float_missing_period(significand, exponent, kind),
231                    rest_to_str!(bytes),
232                )
233            } else {
234                let kind = consume_map!(bytes, [
235                    b'i' => IntKind::I32,
236                    b'u' => IntKind::U32,
237                    b'l', b'i' => IntKind::I64,
238                    b'l', b'u' => IntKind::U64,
239                ]);
240
241                (parse_hex_int(digits, kind), rest_to_str!(bytes))
242            }
243        }
244    } else {
245        let is_first_zero = bytes.first() == Some(&b'0');
246
247        let consumed = consume_dec_digits!(bytes);
248
249        if consume!(bytes, b'.') {
250            let consumed_after_period = consume_dec_digits!(bytes);
251
252            if consumed + consumed_after_period == 0 {
253                return (Err(NumberError::Invalid), rest_to_str!(bytes));
254            }
255
256            if consume!(bytes, b'e' | b'E') {
257                consume!(bytes, b'+' | b'-');
258                let consumed = consume_dec_digits!(bytes);
259
260                if consumed == 0 {
261                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
262                }
263            }
264
265            let number = general_extract.end(bytes);
266
267            let kind = consume_float_suffix!(bytes);
268
269            (parse_dec_float(number, kind), rest_to_str!(bytes))
270        } else {
271            if consumed == 0 {
272                return (Err(NumberError::Invalid), rest_to_str!(bytes));
273            }
274
275            if consume!(bytes, b'e' | b'E') {
276                consume!(bytes, b'+' | b'-');
277                let consumed = consume_dec_digits!(bytes);
278
279                if consumed == 0 {
280                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
281                }
282
283                let number = general_extract.end(bytes);
284
285                let kind = consume_float_suffix!(bytes);
286
287                (parse_dec_float(number, kind), rest_to_str!(bytes))
288            } else {
289                // make sure the multi-digit numbers don't start with zero
290                if consumed > 1 && is_first_zero {
291                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
292                }
293
294                let digits = general_extract.end(bytes);
295
296                let kind = consume_map!(bytes, [
297                    b'i' => Kind::Int(IntKind::I32),
298                    b'u' => Kind::Int(IntKind::U32),
299                    b'l', b'i' => Kind::Int(IntKind::I64),
300                    b'l', b'u' => Kind::Int(IntKind::U64),
301                    b'h' => Kind::Float(FloatKind::F16),
302                    b'f' => Kind::Float(FloatKind::F32),
303                    b'l', b'f' => Kind::Float(FloatKind::F64),
304                ]);
305
306                (parse_dec(digits, kind), rest_to_str!(bytes))
307            }
308        }
309    }
310}
311
312fn parse_hex_float_missing_exponent(
313    // format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ )
314    significand: &str,
315    kind: Option<FloatKind>,
316) -> Result<Number, NumberError> {
317    let hexf_input = format!("{}{}", significand, "p0");
318    parse_hex_float(&hexf_input, kind)
319}
320
321fn parse_hex_float_missing_period(
322    // format: 0[xX] [0-9a-fA-F]+
323    significand: &str,
324    // format: [pP][+-]?[0-9]+
325    exponent: &str,
326    kind: Option<FloatKind>,
327) -> Result<Number, NumberError> {
328    let hexf_input = format!("{significand}.{exponent}");
329    parse_hex_float(&hexf_input, kind)
330}
331
332fn parse_hex_int(
333    // format: [0-9a-fA-F]+
334    digits: &str,
335    kind: Option<IntKind>,
336) -> Result<Number, NumberError> {
337    parse_int(digits, kind, 16)
338}
339
340fn parse_dec(
341    // format: ( [0-9] | [1-9][0-9]+ )
342    digits: &str,
343    kind: Option<Kind>,
344) -> Result<Number, NumberError> {
345    match kind {
346        None => parse_int(digits, None, 10),
347        Some(Kind::Int(kind)) => parse_int(digits, Some(kind), 10),
348        Some(Kind::Float(kind)) => parse_dec_float(digits, Some(kind)),
349    }
350}
351
352// Float parsing notes
353
354// The following chapters of IEEE 754-2019 are relevant:
355//
356// 7.4 Overflow (largest finite number is exceeded by what would have been
357//     the rounded floating-point result were the exponent range unbounded)
358//
// 7.5 Underflow (tiny non-zero result is detected;
//     for decimal formats tininess is detected before rounding, i.e. when a non-zero
//     result computed as though both the exponent range and the precision were unbounded
//     would lie strictly between ±b^emin, which is ±2^−126 for binary32)
363//
364// 7.6 Inexact (rounded result differs from what would have been computed
365//     were both exponent range and precision unbounded)
366
367// The WGSL spec requires us to error:
368//   on overflow for decimal floating point literals
369//   on overflow and inexact for hexadecimal floating point literals
370// (underflow is not mentioned)
371
372// hexf_parse errors on overflow, underflow, inexact
373// rust std lib float from str handles overflow, underflow, inexact transparently (rounds and will not error)
374
375// Therefore we only check for overflow manually for decimal floating point literals
376
377// input format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+
378fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
379    match kind {
380        None => match hexf_parse::parse_hexf64(input, false) {
381            Ok(num) => Ok(Number::AbstractFloat(num)),
382            // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
383            _ => Err(NumberError::NotRepresentable),
384        },
385        // TODO: f16 is not supported by hexf_parse
386        Some(FloatKind::F16) => Err(NumberError::NotRepresentable),
387        Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) {
388            Ok(num) => Ok(Number::F32(num)),
389            // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
390            _ => Err(NumberError::NotRepresentable),
391        },
392        Some(FloatKind::F64) => match hexf_parse::parse_hexf64(input, false) {
393            Ok(num) => Ok(Number::F64(num)),
394            // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
395            _ => Err(NumberError::NotRepresentable),
396        },
397    }
398}
399
400// input format: ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)?
401//             | [0-9]+ [eE][+-]?[0-9]+
402fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
403    match kind {
404        None => {
405            let num = input.parse::<f64>().unwrap(); // will never fail
406            num.is_finite()
407                .then_some(Number::AbstractFloat(num))
408                .ok_or(NumberError::NotRepresentable)
409        }
410        Some(FloatKind::F32) => {
411            let num = input.parse::<f32>().unwrap(); // will never fail
412            num.is_finite()
413                .then_some(Number::F32(num))
414                .ok_or(NumberError::NotRepresentable)
415        }
416        Some(FloatKind::F64) => {
417            let num = input.parse::<f64>().unwrap(); // will never fail
418            num.is_finite()
419                .then_some(Number::F64(num))
420                .ok_or(NumberError::NotRepresentable)
421        }
422        Some(FloatKind::F16) => {
423            let num = input.parse::<f16>().unwrap(); // will never fail
424            num.is_finite()
425                .then_some(Number::F16(num))
426                .ok_or(NumberError::NotRepresentable)
427        }
428    }
429}
430
431fn parse_int(input: &str, kind: Option<IntKind>, radix: u32) -> Result<Number, NumberError> {
432    fn map_err(e: core::num::ParseIntError) -> NumberError {
433        match *e.kind() {
434            core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => {
435                NumberError::NotRepresentable
436            }
437            _ => unreachable!(),
438        }
439    }
440    match kind {
441        None => match i64::from_str_radix(input, radix) {
442            Ok(num) => Ok(Number::AbstractInt(num)),
443            Err(e) => Err(map_err(e)),
444        },
445        Some(IntKind::I32) => match i32::from_str_radix(input, radix) {
446            Ok(num) => Ok(Number::I32(num)),
447            Err(e) => Err(map_err(e)),
448        },
449        Some(IntKind::U32) => match u32::from_str_radix(input, radix) {
450            Ok(num) => Ok(Number::U32(num)),
451            Err(e) => Err(map_err(e)),
452        },
453        Some(IntKind::I64) => match i64::from_str_radix(input, radix) {
454            Ok(num) => Ok(Number::I64(num)),
455            Err(e) => Err(map_err(e)),
456        },
457        Some(IntKind::U64) => match u64::from_str_radix(input, radix) {
458            Ok(num) => Ok(Number::U64(num)),
459            Err(e) => Err(map_err(e)),
460        },
461    }
462}