1use alloc::format;
2
3use crate::front::wgsl::error::NumberError;
4use crate::front::wgsl::parse::directive::enable_extension::ImplementedEnableExtension;
5use crate::front::wgsl::parse::lexer::Token;
6use half::f16;
7
/// The value of a WGSL numeric literal as produced by the lexer.
///
/// Literals without a type suffix become [`Number::AbstractInt`] or
/// [`Number::AbstractFloat`]; suffixed literals become the matching
/// concrete variant.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Number {
    /// Unsuffixed integer literal, stored as `i64`.
    AbstractInt(i64),
    /// Unsuffixed float literal, stored as `f64`.
    AbstractFloat(f64),
    /// Integer literal with the `i` suffix.
    I32(i32),
    /// Integer literal with the `u` suffix.
    U32(u32),
    /// Integer literal with the `li` suffix.
    I64(i64),
    /// Integer literal with the `lu` suffix.
    U64(u64),
    /// Float literal with the `h` suffix; needs the `F16` enable extension
    /// (see `requires_enable_extension`).
    F16(f16),
    /// Float literal with the `f` suffix.
    F32(f32),
    /// Float literal with the `lf` suffix.
    F64(f64),
}
30
31impl Number {
32 pub(super) const fn requires_enable_extension(&self) -> Option<ImplementedEnableExtension> {
33 match *self {
34 Number::F16(_) => Some(ImplementedEnableExtension::F16),
35 _ => None,
36 }
37 }
38}
39
40pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) {
41 let (result, rest) = parse(input);
42 (Token::Number(result), rest)
43}
44
/// Suffix read after a decimal literal that has neither a period nor an
/// exponent. Such digits can still become either an integer or a float,
/// depending on the suffix.
enum Kind {
    /// `i`, `u`, `li` or `lu` suffix.
    Int(IntKind),
    /// `h`, `f` or `lf` suffix.
    Float(FloatKind),
}
49
/// Concrete integer type selected by a literal suffix.
enum IntKind {
    /// `i` suffix.
    I32,
    /// `u` suffix.
    U32,
    /// `li` suffix.
    I64,
    /// `lu` suffix.
    U64,
}
56
/// Concrete float type selected by a literal suffix.
#[derive(Debug)]
enum FloatKind {
    /// `h` suffix.
    F16,
    /// `f` suffix.
    F32,
    /// `lf` suffix.
    F64,
}
63
/// Lexes one WGSL numeric literal from the start of `input`.
///
/// Handles hexadecimal literals (`0x`/`0X` prefix) and decimal literals, in
/// both integer and float forms. It also reads their optional type suffixes:
/// `i`, `u`, `li` and `lu` for integers, and `h`, `f` and `lf` for floats.
/// The matched text is handed to the `parse_*` helpers that follow this
/// function.
///
/// Returns the parse result together with the unconsumed tail of `input`.
/// For a malformed literal, `Err(NumberError::Invalid)` is returned and the
/// tail starts right after the bytes examined so far.
fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
    // If the leading bytes of `$bytes` match the given patterns (one pattern
    // per byte), advances `$bytes` past them and evaluates to `true`.
    // Otherwise it leaves `$bytes` unchanged and evaluates to `false`.
    macro_rules! consume {
        ($bytes:ident, $($pattern:pat),*) => {
            match $bytes {
                &[$($pattern),*, ref rest @ ..] => { $bytes = rest; true },
                _ => false,
            }
        };
    }

    // Tries each `patterns => value` arm, in order, against the leading bytes
    // of `$bytes`. On the first match it advances `$bytes` past the matched
    // bytes and evaluates to `Some(value)`. If no arm matches, it evaluates
    // to `None` and leaves `$bytes` unchanged.
    macro_rules! consume_map {
        ($bytes:ident, [$( $($pattern:pat_param),* => $to:expr),* $(,)?]) => {
            match $bytes {
                $( &[ $($pattern),*, ref rest @ ..] => { $bytes = rest; Some($to) }, )*
                _ => None,
            }
        };
    }

    // Advances `$bytes` past all leading ASCII decimal digits (`0-9`) and
    // evaluates to the number of bytes consumed.
    macro_rules! consume_dec_digits {
        ($bytes:ident) => {{
            let start_len = $bytes.len();
            while let &[b'0'..=b'9', ref rest @ ..] = $bytes {
                $bytes = rest;
            }
            start_len - $bytes.len()
        }};
    }

    // Advances `$bytes` past all leading ASCII hex digits (`0-9`, `a-f`,
    // `A-F`) and evaluates to the number of bytes consumed.
    macro_rules! consume_hex_digits {
        ($bytes:ident) => {{
            let start_len = $bytes.len();
            while let &[b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F', ref rest @ ..] = $bytes {
                $bytes = rest;
            }
            start_len - $bytes.len()
        }};
    }

    // Consumes an optional float suffix: `h` (f16), `f` (f32) or `lf` (f64).
    macro_rules! consume_float_suffix {
        ($bytes:ident) => {
            consume_map!($bytes, [
                b'h' => FloatKind::F16,
                b'f' => FloatKind::F32,
                b'l', b'f' => FloatKind::F64,
            ])
        };
    }

    // Maps the remaining byte tail `$bytes` back to the matching tail of
    // `input`. Only ASCII bytes are ever consumed, so the split point is
    // always a valid `char` boundary.
    macro_rules! rest_to_str {
        ($bytes:ident) => {
            &input[input.len() - $bytes.len()..]
        };
    }

    // Captures the text between two cursor positions. Each position is
    // passed as a tail of `input.as_bytes()`, i.e. the value of `bytes` at
    // that moment.
    struct ExtractSubStr<'a>(&'a str);

    impl<'a> ExtractSubStr<'a> {
        // Starts the substring where the tail `start` begins within `input`.
        fn start(input: &'a str, start: &'a [u8]) -> Self {
            let start = input.len() - start.len();
            Self(&input[start..])
        }
        // Ends the substring where the tail `end` begins. The cursor only
        // moves forward, so `end` is never longer than the stored tail.
        fn end(&self, end: &'a [u8]) -> &'a str {
            let end = self.0.len() - end.len();
            &self.0[..end]
        }
    }

    // Byte cursor over `input`; the macros above advance it.
    let mut bytes = input.as_bytes();

    // Spans the whole literal, including any `0x` prefix.
    let general_extract = ExtractSubStr::start(input, bytes);

    if consume!(bytes, b'0', b'x' | b'X') {
        // Hexadecimal literal. `digits_extract` starts after the `0x` prefix
        // because integer parsing only needs the digits themselves.
        let digits_extract = ExtractSubStr::start(input, bytes);

        let consumed = consume_hex_digits!(bytes);

        if consume!(bytes, b'.') {
            // Hex float with a period: at least one hex digit is required in
            // total, on either side of the period.
            let consumed_after_period = consume_hex_digits!(bytes);

            if consumed + consumed_after_period == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            // `0x` prefix, digits and period, without any exponent.
            let significand = general_extract.end(bytes);

            if consume!(bytes, b'p' | b'P') {
                // Binary exponent: an optional sign, then at least one decimal digit.
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let number = general_extract.end(bytes);

                let kind = consume_float_suffix!(bytes);

                (parse_hex_float(number, kind), rest_to_str!(bytes))
            } else {
                // No exponent, so no suffix is consumed here. A trailing `f`
                // would already have been taken as a hex digit. The literal
                // is therefore parsed as an abstract float.
                (
                    parse_hex_float_missing_exponent(significand, None),
                    rest_to_str!(bytes),
                )
            }
        } else {
            // No period: either a hex float with an exponent, or a hex integer.
            if consumed == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            let significand = general_extract.end(bytes);
            let digits = digits_extract.end(bytes);

            // Captures the exponent, including its leading `p`/`P`.
            let exp_extract = ExtractSubStr::start(input, bytes);

            if consume!(bytes, b'p' | b'P') {
                // Binary exponent: an optional sign, then at least one decimal digit.
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let exponent = exp_extract.end(bytes);

                let kind = consume_float_suffix!(bytes);

                (
                    parse_hex_float_missing_period(significand, exponent, kind),
                    rest_to_str!(bytes),
                )
            } else {
                // Hex integer with an optional integer suffix.
                let kind = consume_map!(bytes, [
                    b'i' => IntKind::I32,
                    b'u' => IntKind::U32,
                    b'l', b'i' => IntKind::I64,
                    b'l', b'u' => IntKind::U64,
                ]);

                (parse_hex_int(digits, kind), rest_to_str!(bytes))
            }
        }
    } else {
        // Decimal literal. Remember whether it starts with `0`, so that
        // multi-digit literals with a leading zero can be rejected below.
        let is_first_zero = bytes.first() == Some(&b'0');

        let consumed = consume_dec_digits!(bytes);

        if consume!(bytes, b'.') {
            // Decimal float with a period: at least one digit is required in
            // total, on either side of the period.
            let consumed_after_period = consume_dec_digits!(bytes);

            if consumed + consumed_after_period == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            // Optional exponent: an optional sign, then at least one digit.
            if consume!(bytes, b'e' | b'E') {
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }
            }

            let number = general_extract.end(bytes);

            let kind = consume_float_suffix!(bytes);

            (parse_dec_float(number, kind), rest_to_str!(bytes))
        } else {
            if consumed == 0 {
                return (Err(NumberError::Invalid), rest_to_str!(bytes));
            }

            if consume!(bytes, b'e' | b'E') {
                // Decimal float without a period; the exponent makes it a float.
                consume!(bytes, b'+' | b'-');
                let consumed = consume_dec_digits!(bytes);

                if consumed == 0 {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let number = general_extract.end(bytes);

                let kind = consume_float_suffix!(bytes);

                (parse_dec_float(number, kind), rest_to_str!(bytes))
            } else {
                // Multi-digit literals without a period or exponent must not
                // start with `0`. This check runs before the suffix is read,
                // so it also rejects `h`/`f`/`lf`-suffixed forms such as `01f`.
                if consumed > 1 && is_first_zero {
                    return (Err(NumberError::Invalid), rest_to_str!(bytes));
                }

                let digits = general_extract.end(bytes);

                // The suffix decides whether these digits form an integer or a float.
                let kind = consume_map!(bytes, [
                    b'i' => Kind::Int(IntKind::I32),
                    b'u' => Kind::Int(IntKind::U32),
                    b'l', b'i' => Kind::Int(IntKind::I64),
                    b'l', b'u' => Kind::Int(IntKind::U64),
                    b'h' => Kind::Float(FloatKind::F16),
                    b'f' => Kind::Float(FloatKind::F32),
                    b'l', b'f' => Kind::Float(FloatKind::F64),
                ]);

                (parse_dec(digits, kind), rest_to_str!(bytes))
            }
        }
    }
}
311
312fn parse_hex_float_missing_exponent(
313 significand: &str,
315 kind: Option<FloatKind>,
316) -> Result<Number, NumberError> {
317 let hexf_input = format!("{}{}", significand, "p0");
318 parse_hex_float(&hexf_input, kind)
319}
320
321fn parse_hex_float_missing_period(
322 significand: &str,
324 exponent: &str,
326 kind: Option<FloatKind>,
327) -> Result<Number, NumberError> {
328 let hexf_input = format!("{significand}.{exponent}");
329 parse_hex_float(&hexf_input, kind)
330}
331
332fn parse_hex_int(
333 digits: &str,
335 kind: Option<IntKind>,
336) -> Result<Number, NumberError> {
337 parse_int(digits, kind, 16)
338}
339
340fn parse_dec(
341 digits: &str,
343 kind: Option<Kind>,
344) -> Result<Number, NumberError> {
345 match kind {
346 None => parse_int(digits, None, 10),
347 Some(Kind::Int(kind)) => parse_int(digits, Some(kind), 10),
348 Some(Kind::Float(kind)) => parse_dec_float(digits, Some(kind)),
349 }
350}
351
352fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
379 match kind {
380 None => match hexf_parse::parse_hexf64(input, false) {
381 Ok(num) => Ok(Number::AbstractFloat(num)),
382 _ => Err(NumberError::NotRepresentable),
384 },
385 Some(FloatKind::F16) => Err(NumberError::NotRepresentable),
387 Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) {
388 Ok(num) => Ok(Number::F32(num)),
389 _ => Err(NumberError::NotRepresentable),
391 },
392 Some(FloatKind::F64) => match hexf_parse::parse_hexf64(input, false) {
393 Ok(num) => Ok(Number::F64(num)),
394 _ => Err(NumberError::NotRepresentable),
396 },
397 }
398}
399
400fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
403 match kind {
404 None => {
405 let num = input.parse::<f64>().unwrap(); num.is_finite()
407 .then_some(Number::AbstractFloat(num))
408 .ok_or(NumberError::NotRepresentable)
409 }
410 Some(FloatKind::F32) => {
411 let num = input.parse::<f32>().unwrap(); num.is_finite()
413 .then_some(Number::F32(num))
414 .ok_or(NumberError::NotRepresentable)
415 }
416 Some(FloatKind::F64) => {
417 let num = input.parse::<f64>().unwrap(); num.is_finite()
419 .then_some(Number::F64(num))
420 .ok_or(NumberError::NotRepresentable)
421 }
422 Some(FloatKind::F16) => {
423 let num = input.parse::<f16>().unwrap(); num.is_finite()
425 .then_some(Number::F16(num))
426 .ok_or(NumberError::NotRepresentable)
427 }
428 }
429}
430
431fn parse_int(input: &str, kind: Option<IntKind>, radix: u32) -> Result<Number, NumberError> {
432 fn map_err(e: core::num::ParseIntError) -> NumberError {
433 match *e.kind() {
434 core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => {
435 NumberError::NotRepresentable
436 }
437 _ => unreachable!(),
438 }
439 }
440 match kind {
441 None => match i64::from_str_radix(input, radix) {
442 Ok(num) => Ok(Number::AbstractInt(num)),
443 Err(e) => Err(map_err(e)),
444 },
445 Some(IntKind::I32) => match i32::from_str_radix(input, radix) {
446 Ok(num) => Ok(Number::I32(num)),
447 Err(e) => Err(map_err(e)),
448 },
449 Some(IntKind::U32) => match u32::from_str_radix(input, radix) {
450 Ok(num) => Ok(Number::U32(num)),
451 Err(e) => Err(map_err(e)),
452 },
453 Some(IntKind::I64) => match i64::from_str_radix(input, radix) {
454 Ok(num) => Ok(Number::I64(num)),
455 Err(e) => Err(map_err(e)),
456 },
457 Some(IntKind::U64) => match u64::from_str_radix(input, radix) {
458 Ok(num) => Ok(Number::U64(num)),
459 Err(e) => Err(map_err(e)),
460 },
461 }
462}