1use alloc::format;
2
3use crate::front::wgsl::error::NumberError;
4use crate::front::wgsl::parse::directive::enable_extension::ImplementedEnableExtension;
5use crate::front::wgsl::parse::lexer::Token;
6use half::f16;
7
8#[derive(Copy, Clone, Debug, PartialEq)]
10pub enum Number {
11 AbstractInt(i64),
13 AbstractFloat(f64),
15 I32(i32),
17 U32(u32),
19 I64(i64),
21 U64(u64),
23 F16(f16),
25 F32(f32),
27 F64(f64),
29}
30
31impl Number {
32 pub(super) const fn requires_enable_extension(&self) -> Option<ImplementedEnableExtension> {
33 match *self {
34 Number::F16(_) => Some(ImplementedEnableExtension::F16),
35 _ => None,
36 }
37 }
38}
39
40pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) {
41 let (result, rest) = parse(input);
42 (Token::Number(result), rest)
43}
44
45enum Kind {
46 Int(IntKind),
47 Float(FloatKind),
48}
49
/// Concrete integer type selected by an integer literal suffix.
enum IntKind {
    /// `i` suffix.
    I32,
    /// `u` suffix.
    U32,
    /// `li` suffix.
    I64,
    /// `lu` suffix.
    U64,
}
56
/// Concrete float type selected by a float literal suffix.
#[derive(Debug)]
enum FloatKind {
    /// `h` suffix.
    F16,
    /// `f` suffix.
    F32,
    /// `lf` suffix.
    F64,
}
63
/// Lexes a single numeric literal from the start of `input`.
///
/// Returns the parse result together with the part of `input` that follows the
/// consumed characters. Recognized forms:
///
/// - hexadecimal (`0x`/`0X`) integers with an optional `i`/`u`/`li`/`lu` suffix;
/// - hexadecimal floats containing a `.` and/or a `p`/`P` exponent (a float
///   suffix is only consumed after an exponent);
/// - decimal integers and floats, with an optional `e`/`E` exponent and suffix.
///
/// On a syntax error, `NumberError::Invalid` is returned together with the input
/// remaining at the point where lexing stopped.
fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
    // `consume!(bytes, p1, p2, ...)`: if `bytes` starts with bytes matching the
    // given patterns, advance `bytes` past them and yield `true`; otherwise leave
    // `bytes` untouched and yield `false`.
    macro_rules! consume {
        ($bytes:ident, $($pattern:pat),*) => {
            match $bytes {
                &[$($pattern),*, ref rest @ ..] => { $bytes = rest; true },
                _ => false,
            }
        };
    }

    // `consume_map!(bytes, [p.. => value, ...])`: try each byte-pattern sequence
    // in order. On the first match, advance `bytes` past it and yield
    // `Some(value)`; otherwise yield `None`.
    macro_rules! consume_map {
        ($bytes:ident, [$( $($pattern:pat_param),* => $to:expr),* $(,)?]) => {
            match $bytes {
                $( &[ $($pattern),*, ref rest @ ..] => { $bytes = rest; Some($to) }, )*
                _ => None,
            }
        };
    }

    // Advance `bytes` past a run of decimal digits and yield how many were consumed.
    macro_rules! consume_dec_digits {
        ($bytes:ident) => {{
            let start_len = $bytes.len();
            while let &[b'0'..=b'9', ref rest @ ..] = $bytes {
                $bytes = rest;
            }
            start_len - $bytes.len()
        }};
    }

    // Advance `bytes` past a run of hex digits (either case) and yield how many
    // were consumed.
    macro_rules! consume_hex_digits {
        ($bytes:ident) => {{
            let start_len = $bytes.len();
            while let &[b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F', ref rest @ ..] = $bytes {
                $bytes = rest;
            }
            start_len - $bytes.len()
        }};
    }

    // Consume an optional float suffix: `h` (f16), `f` (f32) or `lf` (f64).
    macro_rules! consume_float_suffix {
        ($bytes:ident) => {
            consume_map!($bytes, [
                b'h' => FloatKind::F16,
                b'f' => FloatKind::F32,
                b'l', b'f' => FloatKind::F64,
            ])
        };
    }

    // The unconsumed tail of `input` as a `&str`. `bytes` only ever shrinks
    // from the front of `input.as_bytes()`, so its length locates the tail.
    macro_rules! rest_to_str {
        ($bytes:ident) => {
            &input[input.len() - $bytes.len()..]
        };
    }

    // Records a starting position of the byte cursor, so the text between it and
    // a later cursor position can be sliced out of `input`.
    struct ExtractSubStr<'a>(&'a str);

    impl<'a> ExtractSubStr<'a> {
        // Starts a span at the current cursor position `start` (a suffix of
        // `input`'s bytes).
        fn start(input: &'a str, start: &'a [u8]) -> Self {
            let start = input.len() - start.len();
            Self(&input[start..])
        }
        // Ends the span at the cursor position `end`, returning the text
        // consumed since `start`.
        fn end(&self, end: &'a [u8]) -> &'a str {
            let end = self.0.len() - end.len();
            &self.0[..end]
        }
    }

    let mut bytes = input.as_bytes();

    // Span over the whole literal, including any `0x` prefix.
    let general_extract = ExtractSubStr::start(input, bytes);

    if consume!(bytes, b'0', b'x' | b'X') {
        // Span over only the hex digits that follow the prefix.
        let digits_extract = ExtractSubStr::start(input, bytes);

        let consumed = consume_hex_digits!(bytes);

        if consume!(bytes, b'.') {
            // Hex float with a period, e.g. `0x1.8`, `0x.8p2`, `0x1.p-3f`.
            let consumed_after_period = consume_hex_digits!(bytes);

            if consumed + consumed_after_period == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            let significand = general_extract.end(bytes);

            if consume!(bytes, b'p' | b'P') {
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let number = general_extract.end(bytes);

                let kind = consume_float_suffix!(bytes);

                (parse_hex_float(number, kind), rest_to_str!(bytes))
            } else {
                // No exponent, so no suffix is consumed here and the kind is `None`.
                (
                    parse_hex_float_missing_exponent(significand, None),
                    rest_to_str!(bytes),
                )
            }
        } else {
            if consumed == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            let significand = general_extract.end(bytes);
            let digits = digits_extract.end(bytes);

            // Span over the exponent, starting at the `p`/`P` (if present).
            let exp_extract = ExtractSubStr::start(input, bytes);

            if consume!(bytes, b'p' | b'P') {
                // Hex float without a period, e.g. `0x1p4`, `0xAp-2f`.
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let exponent = exp_extract.end(bytes);

                let kind = consume_float_suffix!(bytes);

                (
                    parse_hex_float_missing_period(significand, exponent, kind),
                    rest_to_str!(bytes),
                )
            } else {
                // Hex integer with an optional integer suffix.
                let kind = consume_map!(bytes, [
                    b'i' => IntKind::I32,
                    b'u' => IntKind::U32,
                    b'l', b'i' => IntKind::I64,
                    b'l', b'u' => IntKind::U64,
                ]);

                (parse_hex_int(digits, kind), rest_to_str!(bytes))
            }
        }
    } else {
        let is_first_zero = bytes.first() == Some(&b'0');

        let consumed = consume_dec_digits!(bytes);

        if consume!(bytes, b'.') {
            // Decimal float with a period, e.g. `1.`, `.5`, `1.5e3f`.
            let consumed_after_period = consume_dec_digits!(bytes);

            if consumed + consumed_after_period == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            if consume!(bytes, b'e' | b'E') {
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }
            }

            let number = general_extract.end(bytes);

            let kind = consume_float_suffix!(bytes);

            (parse_dec_float(number, kind), rest_to_str!(bytes))
        } else {
            if consumed == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            if consume!(bytes, b'e' | b'E') {
                // Decimal float without a period, e.g. `1e3`, `2E-4f`.
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let number = general_extract.end(bytes);

                let kind = consume_float_suffix!(bytes);

                (parse_dec_float(number, kind), rest_to_str!(bytes))
            } else {
                // A multi-digit decimal integer must not start with `0` (e.g. `01`).
                if consumed > 1 && is_first_zero {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let digits = general_extract.end(bytes);

                // Either an integer suffix, or a float suffix that turns the
                // digits into a float literal (e.g. `1f`).
                let kind = consume_map!(bytes, [
                    b'i' => Kind::Int(IntKind::I32),
                    b'u' => Kind::Int(IntKind::U32),
                    b'l', b'i' => Kind::Int(IntKind::I64),
                    b'l', b'u' => Kind::Int(IntKind::U64),
                    b'h' => Kind::Float(FloatKind::F16),
                    b'f' => Kind::Float(FloatKind::F32),
                    b'l', b'f' => Kind::Float(FloatKind::F64),
                ]);

                (parse_dec(digits, kind), rest_to_str!(bytes))
            }
        }
    }
}
311
312fn parse_hex_float_missing_exponent(
313 significand: &str,
315 kind: Option<FloatKind>,
316) -> Result<Number, NumberError> {
317 let hexf_input = format!("{}{}", significand, "p0");
318 parse_hex_float(&hexf_input, kind)
319}
320
321fn parse_hex_float_missing_period(
322 significand: &str,
324 exponent: &str,
326 kind: Option<FloatKind>,
327) -> Result<Number, NumberError> {
328 let hexf_input = format!("{significand}.{exponent}");
329 parse_hex_float(&hexf_input, kind)
330}
331
332fn parse_hex_int(
333 digits: &str,
335 kind: Option<IntKind>,
336) -> Result<Number, NumberError> {
337 parse_int(digits, kind, 16)
338}
339
340fn parse_dec(
341 digits: &str,
343 kind: Option<Kind>,
344) -> Result<Number, NumberError> {
345 match kind {
346 None => parse_int(digits, None, 10),
347 Some(Kind::Int(kind)) => parse_int(digits, Some(kind), 10),
348 Some(Kind::Float(kind)) => parse_dec_float(digits, Some(kind)),
349 }
350}
351
352fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
378 match kind {
379 None => {
380 let (neg, mant, exp) = parse_hex_float_parts(input.as_bytes())?;
381 let bits = convert_hex_float(neg, mant, exp, F64)?;
382 let num = f64::from_bits(bits);
383
384 Ok(Number::AbstractFloat(num))
385 }
386 Some(FloatKind::F16) => Err(NumberError::NotRepresentable),
388 Some(FloatKind::F32) => {
389 let (neg, mant, exp) = parse_hex_float_parts(input.as_bytes())?;
390 let bits = convert_hex_float(neg, mant, exp, F32)?;
391 let num = f32::from_bits(bits as u32);
392
393 Ok(Number::F32(num))
394 }
395 Some(FloatKind::F64) => {
396 let (neg, mant, exp) = parse_hex_float_parts(input.as_bytes())?;
397 let bits = convert_hex_float(neg, mant, exp, F64)?;
398 let num = f64::from_bits(bits);
399
400 Ok(Number::F64(num))
401 }
402 }
403}
404
/// Bit layout of an IEEE-754 binary format, used by `convert_hex_float` to
/// encode an exact `mantissa * 2^exponent` value.
struct HexFloatFormat {
    /// Number of stored fraction bits (the implicit leading one is not stored).
    mant_bits: usize,
    /// Significand precision in bits, including the implicit leading one.
    precision: usize,
    /// Bias added to the unbiased exponent to form the stored exponent field.
    bias: i32,
    /// Largest unbiased exponent of a finite value.
    max_exp: i32,
    /// Width of the stored exponent field in bits.
    exp_bits: usize,
    /// Smallest unbiased exponent of a normal (non-subnormal) value.
    min_norm_exp: i32,
}
414
415const F32: HexFloatFormat = HexFloatFormat {
416 mant_bits: 23,
417 precision: 24,
418 bias: 127,
419 max_exp: 127,
420 exp_bits: 8,
421 min_norm_exp: -126,
422};
423
424const F64: HexFloatFormat = HexFloatFormat {
425 mant_bits: 52,
426 precision: 53,
427 bias: 1023,
428 max_exp: 1023,
429 exp_bits: 11,
430 min_norm_exp: -1022,
431};
432
433fn parse_hex_float_parts(s: &[u8]) -> Result<(bool, u64, i32), NumberError> {
437 let (s, negative) = match s.split_first() {
438 Some((&b'+', s)) => (s, false),
439 Some((&b'-', s)) => (s, true),
440 Some(_) => (s, false),
441 None => return Err(NumberError::Invalid),
443 };
444
445 if !(s.starts_with(b"0x") || s.starts_with(b"0X")) {
446 return Err(NumberError::Invalid);
447 }
448
449 let mut s = &s[2..];
450 let mut acc: u128 = 0;
451 let mut digit_seen = false;
452
453 loop {
455 let (rest, digit) = match s.split_first() {
456 Some((&c @ b'0'..=b'9', s)) => (s, c - b'0'),
457 Some((&c @ b'a'..=b'f', s)) => (s, c - b'a' + 10),
458 Some((&c @ b'A'..=b'F', s)) => (s, c - b'A' + 10),
459 _ => break,
460 };
461 s = rest;
462 digit_seen = true;
463 acc = acc.checked_shl(4).ok_or(NumberError::NotRepresentable)? | digit as u128;
464 }
465
466 let mut nfracs: i32 = 0;
468 let mut frac_digit_seen = false;
469 if s.starts_with(b".") {
470 s = &s[1..];
471 loop {
472 let (rest, digit) = match s.split_first() {
473 Some((&c @ b'0'..=b'9', s)) => (s, c - b'0'),
474 Some((&c @ b'a'..=b'f', s)) => (s, c - b'a' + 10),
475 Some((&c @ b'A'..=b'F', s)) => (s, c - b'A' + 10),
476 _ => break,
477 };
478 s = rest;
479 frac_digit_seen = true;
480 acc = acc.checked_shl(4).ok_or(NumberError::NotRepresentable)? | digit as u128;
481 nfracs = nfracs.checked_add(1).ok_or(NumberError::NotRepresentable)?;
482 }
483 }
484
485 if !(digit_seen || frac_digit_seen) {
486 return Err(NumberError::Invalid);
487 }
488
489 let s = match s.split_first() {
491 Some((&b'P', s)) | Some((&b'p', s)) => s,
492 _ => return Err(NumberError::Invalid),
493 };
494
495 let (mut s, negative_exponent) = match s.split_first() {
497 Some((&b'+', s)) => (s, false),
498 Some((&b'-', s)) => (s, true),
499 Some(_) => (s, false),
500 None => return Err(NumberError::Invalid),
501 };
502
503 let mut digit_seen = false;
505 let mut exponent: i32 = 0;
506 loop {
507 let (rest, digit) = match s.split_first() {
508 Some((&c @ b'0'..=b'9', s)) => (s, c - b'0'),
509 None if digit_seen => break,
510 _ => return Err(NumberError::Invalid),
511 };
512 s = rest;
513 digit_seen = true;
514
515 if acc != 0 {
517 exponent = exponent
518 .checked_mul(10)
519 .and_then(|v| v.checked_add(digit as i32))
520 .ok_or(NumberError::NotRepresentable)?;
521 }
522 }
523
524 if negative_exponent {
525 exponent = -exponent;
526 }
527
528 if acc == 0 {
529 return Ok((negative, 0, 0));
530 }
531
532 let exp_adj = nfracs.checked_mul(4).ok_or(NumberError::NotRepresentable)?;
534 let exponent = exponent
535 .checked_sub(exp_adj)
536 .ok_or(NumberError::NotRepresentable)?;
537
538 let mut mant = acc;
540 let mut extra_shift = 0i32;
541 while mant > 0 && (mant & 0xF) == 0 {
542 mant >>= 4;
543 extra_shift = extra_shift
544 .checked_add(4)
545 .ok_or(NumberError::NotRepresentable)?;
546 }
547
548 if mant > u64::MAX as u128 {
550 return Err(NumberError::NotRepresentable);
551 }
552
553 let exponent = exponent
554 .checked_add(extra_shift)
555 .ok_or(NumberError::NotRepresentable)?;
556
557 Ok((negative, mant as u64, exponent))
558}
559
560fn convert_hex_float(
561 negative: bool,
562 mant: u64,
563 exp: i32,
564 fmt: HexFloatFormat,
565) -> Result<u64, NumberError> {
566 let sign_shift = fmt.mant_bits + fmt.exp_bits;
567 let sign = (negative as u64) << sign_shift;
568
569 if mant == 0 {
570 return Ok(sign);
571 }
572
573 let k = 63usize - mant.leading_zeros() as usize;
574 let normalexp = exp
575 .checked_add(k as i32)
576 .ok_or(NumberError::NotRepresentable)?;
577
578 if normalexp > fmt.max_exp {
579 return Err(NumberError::NotRepresentable);
580 }
581
582 let shift = k as i32 - ((fmt.precision as i32) - 1);
584 let mut mant_field: u64;
585
586 if normalexp >= fmt.min_norm_exp {
587 if shift > 0 {
589 if shift >= 64 || (mant & ((1u64 << shift) - 1)) != 0 {
590 return Err(NumberError::NotRepresentable);
591 }
592 mant_field = mant >> shift;
593 } else {
594 mant_field = mant << -shift;
595 }
596
597 mant_field &= (1u64 << fmt.mant_bits) - 1;
598 let expo_field = (normalexp + fmt.bias) as u64;
599
600 Ok(sign | (expo_field << fmt.mant_bits) | mant_field)
601 } else {
602 let shift_sub = exp - (fmt.min_norm_exp - ((fmt.precision as i32) - 1));
604 if shift_sub < 0 {
605 let rs = (-shift_sub) as usize;
606 if rs >= 64 || (mant & ((1u64 << rs) - 1)) != 0 {
607 return Err(NumberError::NotRepresentable);
608 }
609 mant_field = mant >> rs;
610 } else {
611 mant_field = mant << shift_sub as u32;
612 if (mant_field >> fmt.mant_bits) != 0 {
613 return Err(NumberError::NotRepresentable);
614 }
615 }
616
617 if mant_field == 0 {
618 return Err(NumberError::NotRepresentable);
619 }
620
621 Ok(sign | (mant_field & ((1u64 << fmt.mant_bits) - 1)))
622 }
623}
624
625fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
628 match kind {
629 None => {
630 let num = input.parse::<f64>().unwrap(); num.is_finite()
632 .then_some(Number::AbstractFloat(num))
633 .ok_or(NumberError::NotRepresentable)
634 }
635 Some(FloatKind::F32) => {
636 let num = input.parse::<f32>().unwrap(); num.is_finite()
638 .then_some(Number::F32(num))
639 .ok_or(NumberError::NotRepresentable)
640 }
641 Some(FloatKind::F64) => {
642 let num = input.parse::<f64>().unwrap(); num.is_finite()
644 .then_some(Number::F64(num))
645 .ok_or(NumberError::NotRepresentable)
646 }
647 Some(FloatKind::F16) => {
648 let num = input.parse::<f16>().unwrap(); num.is_finite()
650 .then_some(Number::F16(num))
651 .ok_or(NumberError::NotRepresentable)
652 }
653 }
654}
655
656fn parse_int(input: &str, kind: Option<IntKind>, radix: u32) -> Result<Number, NumberError> {
657 fn map_err(e: core::num::ParseIntError) -> NumberError {
658 match *e.kind() {
659 core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => {
660 NumberError::NotRepresentable
661 }
662 _ => unreachable!(),
663 }
664 }
665 match kind {
666 None => match i64::from_str_radix(input, radix) {
667 Ok(num) => Ok(Number::AbstractInt(num)),
668 Err(e) => Err(map_err(e)),
669 },
670 Some(IntKind::I32) => match i32::from_str_radix(input, radix) {
671 Ok(num) => Ok(Number::I32(num)),
672 Err(e) => Err(map_err(e)),
673 },
674 Some(IntKind::U32) => match u32::from_str_radix(input, radix) {
675 Ok(num) => Ok(Number::U32(num)),
676 Err(e) => Err(map_err(e)),
677 },
678 Some(IntKind::I64) => match i64::from_str_radix(input, radix) {
679 Ok(num) => Ok(Number::I64(num)),
680 Err(e) => Err(map_err(e)),
681 },
682 Some(IntKind::U64) => match u64::from_str_radix(input, radix) {
683 Ok(num) => Ok(Number::U64(num)),
684 Err(e) => Err(map_err(e)),
685 },
686 }
687}