1use alloc::{
2 borrow::Cow,
3 string::{String, ToString},
4};
5use core::mem;
6
7use hashbrown::HashSet;
8use thiserror::Error;
9
10use super::PipelineConstants;
11use crate::{
12 arena::HandleVec,
13 compact::{compact, KeepUnused},
14 ir,
15 proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
16 valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
17 Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
18 Span, Statement, TypeInner, WithSpan,
19};
20
21#[cfg(no_std)]
22use num_traits::float::FloatCore as _;
23
24#[derive(Error, Debug, Clone)]
25#[cfg_attr(test, derive(PartialEq))]
26pub enum PipelineConstantError {
27 #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
28 MissingValue(String),
29 #[error(
30 "Source f64 value needs to be finite ({}) for number destinations",
31 "NaNs and Inifinites are not allowed"
32 )]
33 SrcNeedsToBeFinite,
34 #[error("Source f64 value doesn't fit in destination")]
35 DstRangeTooSmall,
36 #[error(transparent)]
37 ConstantEvaluatorError(#[from] ConstantEvaluatorError),
38 #[error(transparent)]
39 ValidationError(#[from] WithSpan<ValidationError>),
40 #[error("workgroup_size override isn't strictly positive")]
41 NegativeWorkgroupSize,
42}
43
/// Replaces the overrides of `module` with constants and revalidates the
/// result.
///
/// When `entry_point` is given, only the entry point with that stage and name
/// is retained in the returned module. Values for overrides are looked up in
/// `pipeline_constants`; an override without a supplied value falls back to
/// its initializer.
///
/// If there is nothing to do (no overrides, and no entry point to strip), the
/// inputs are returned as borrows. Otherwise an owned, modified copy of the
/// module is returned together with freshly computed [`ModuleInfo`].
///
/// # Errors
///
/// Returns a [`PipelineConstantError`] if an override cannot be resolved, if
/// constant evaluation fails, if a workgroup size override is not strictly
/// positive, or if the resulting module fails validation.
pub fn process_overrides<'a>(
    module: &'a Module,
    module_info: &'a ModuleInfo,
    entry_point: Option<(ir::ShaderStage, &str)>,
    pipeline_constants: &PipelineConstants,
) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
    // Fast path: no entry points would be removed and there are no overrides,
    // so the caller's module and info can be handed back untouched.
    if (entry_point.is_none() || module.entry_points.len() <= 1) && module.overrides.is_empty() {
        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
    }

    // Everything below works on a private copy of the module.
    let mut module = module.clone();
    if let Some((ep_stage, ep_name)) = entry_point {
        module
            .entry_points
            .retain(|ep| ep.stage == ep_stage && ep.name == ep_name);
    }

    // Compact the module without keeping unused items.
    compact(&mut module, KeepUnused::No);

    // If no overrides remain after compaction, only revalidation is left.
    if module.overrides.is_empty() {
        return revalidate(module);
    }

    // Maps each old override handle to the constant created for it.
    let mut override_map = HandleVec::with_capacity(module.overrides.len());

    // Maps each old global expression handle to its handle in the rebuilt
    // global expression arena.
    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());

    // Constants whose `init` already refers to the rebuilt arena.
    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());

    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
    let mut layouter = crate::proc::Layouter::default();

    // Take the overrides out of the module so they can be mutated while the
    // module itself is mutably borrowed elsewhere. They are put back at the
    // end of this function.
    let mut overrides = mem::take(&mut module.overrides);
    let mut override_iter = overrides.iter_mut_span();

    // Drain the global expression arena and re-append every expression
    // through the constant evaluator, recording old -> new handles.
    for (old_h, expr, span) in module.global_expressions.drain() {
        let mut expr = match expr {
            Expression::Override(h) => {
                let c_h = if let Some(new_h) = override_map.get(h) {
                    // This override was already turned into a constant.
                    *new_h
                } else {
                    // Process overrides in iteration order up to and
                    // including `h`.
                    let mut new_h = None;
                    for entry in override_iter.by_ref() {
                        let stop = entry.0 == h;
                        new_h = Some(process_override(
                            entry,
                            pipeline_constants,
                            &mut module,
                            &mut override_map,
                            &adjusted_global_expressions,
                            &mut adjusted_constant_initializers,
                            &mut global_expression_kind_tracker,
                        )?);
                        if stop {
                            break;
                        }
                    }
                    // Relies on `h` still being ahead in `override_iter`:
                    // every override already consumed from the iterator has
                    // an entry in `override_map`, so the lookup above would
                    // have succeeded for it.
                    new_h.unwrap()
                };
                Expression::Constant(c_h)
            }
            Expression::Constant(c_h) => {
                // `insert` reports whether the handle was newly added, so the
                // initializer is remapped at most once.
                if adjusted_constant_initializers.insert(c_h) {
                    let init = &mut module.constants[c_h].init;
                    *init = adjusted_global_expressions[*init];
                }
                expr
            }
            expr => expr,
        };
        let mut evaluator = ConstantEvaluator::for_wgsl_module(
            &mut module,
            &mut global_expression_kind_tracker,
            &mut layouter,
            false,
        );
        // Point the operands at their new handles before evaluating.
        adjust_expr(&adjusted_global_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_global_expressions.insert(old_h, h);
    }

    // Handle the overrides that the loop above did not reach.
    for entry in override_iter {
        match *entry.1 {
            // Overrides with a name or an id are resolved like the others.
            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
                process_override(
                    entry,
                    pipeline_constants,
                    &mut module,
                    &mut override_map,
                    &adjusted_global_expressions,
                    &mut adjusted_constant_initializers,
                    &mut global_expression_kind_tracker,
                )?;
            }
            // Overrides with neither only get their initializer remapped.
            Override {
                init: Some(ref mut init),
                ..
            } => {
                *init = adjusted_global_expressions[*init];
            }
            _ => {}
        }
    }

    // Remap the initializers of all constants not already handled above.
    for (_, c) in module
        .constants
        .iter_mut()
        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
    {
        c.init = adjusted_global_expressions[c.init];
    }

    // Remap global variable initializers.
    for (_, v) in module.global_variables.iter_mut() {
        if let Some(ref mut init) = v.init {
            *init = adjusted_global_expressions[*init];
        }
    }

    // Functions are taken out of the module while they are processed because
    // `process_function` needs `&mut module` at the same time.
    let mut functions = mem::take(&mut module.functions);
    for (_, function) in functions.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, function)?;
    }
    module.functions = functions;

    // Same for entry points, which additionally get their workgroup size
    // overrides resolved.
    let mut entry_points = mem::take(&mut module.entry_points);
    for ep in entry_points.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
    }
    module.entry_points = entry_points;
    module.overrides = overrides;

    revalidate(module)
}
255
256fn revalidate(
257 module: Module,
258) -> Result<(Cow<'static, Module>, Cow<'static, ModuleInfo>), PipelineConstantError> {
259 let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
260 let module_info = validator.validate_resolved_overrides(&module)?;
261 Ok((Cow::Owned(module), Cow::Owned(module_info)))
262}
263
264fn process_workgroup_size_override(
265 module: &mut Module,
266 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
267 ep: &mut crate::EntryPoint,
268) -> Result<(), PipelineConstantError> {
269 match ep.workgroup_size_overrides {
270 None => {}
271 Some(overrides) => {
272 overrides.iter().enumerate().try_for_each(
273 |(i, overridden)| -> Result<(), PipelineConstantError> {
274 match *overridden {
275 None => Ok(()),
276 Some(h) => {
277 ep.workgroup_size[i] = module
278 .to_ctx()
279 .eval_expr_to_u32(adjusted_global_expressions[h])
280 .map(|n| {
281 if n == 0 {
282 Err(PipelineConstantError::NegativeWorkgroupSize)
283 } else {
284 Ok(n)
285 }
286 })
287 .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
288 Ok(())
289 }
290 }
291 },
292 )?;
293 ep.workgroup_size_overrides = None;
294 }
295 }
296 Ok(())
297}
298
299fn process_override(
303 (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
304 pipeline_constants: &PipelineConstants,
305 module: &mut Module,
306 override_map: &mut HandleVec<Override, Handle<Constant>>,
307 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
308 adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
309 global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
310) -> Result<Handle<Constant>, PipelineConstantError> {
311 let key = if let Some(id) = r#override.id {
313 Cow::Owned(id.to_string())
314 } else if let Some(ref name) = r#override.name {
315 Cow::Borrowed(name)
316 } else {
317 unreachable!();
318 };
319
320 let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
324 let literal = match module.types[r#override.ty].inner {
325 TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
326 _ => unreachable!(),
327 };
328 let expr = module
329 .global_expressions
330 .append(Expression::Literal(literal), Span::UNDEFINED);
331 global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
332 expr
333 } else if let Some(init) = r#override.init {
334 adjusted_global_expressions[init]
335 } else {
336 return Err(PipelineConstantError::MissingValue(key.to_string()));
337 };
338
339 let constant = Constant {
341 name: r#override.name.clone(),
342 ty: r#override.ty,
343 init,
344 };
345 let h = module.constants.append(constant, *span);
346 override_map.insert(old_h, h);
347 adjusted_constant_initializers.insert(h);
348 r#override.init = Some(init);
349 Ok(h)
350}
351
352fn process_function(
362 module: &mut Module,
363 override_map: &HandleVec<Override, Handle<Constant>>,
364 layouter: &mut crate::proc::Layouter,
365 function: &mut Function,
366) -> Result<(), ConstantEvaluatorError> {
367 let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
370
371 let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
372
373 let mut expressions = mem::take(&mut function.expressions);
374
375 let mut emitter = Emitter::default();
383 let mut block = Block::new();
384
385 let mut evaluator = ConstantEvaluator::for_wgsl_function(
386 module,
387 &mut function.expressions,
388 &mut local_expression_kind_tracker,
389 layouter,
390 &mut emitter,
391 &mut block,
392 false,
393 );
394
395 for (old_h, mut expr, span) in expressions.drain() {
396 if let Expression::Override(h) = expr {
397 expr = Expression::Constant(override_map[h]);
398 }
399 adjust_expr(&adjusted_local_expressions, &mut expr);
400 let h = evaluator.try_eval_and_append(expr, span)?;
401 adjusted_local_expressions.insert(old_h, h);
402 }
403
404 adjust_block(&adjusted_local_expressions, &mut function.body);
405
406 filter_emits_in_block(&mut function.body, &function.expressions);
407
408 for (_, local) in function.local_variables.iter_mut() {
410 if let &mut Some(ref mut init) = &mut local.init {
411 *init = adjusted_local_expressions[*init];
412 }
413 }
414
415 let named_expressions = mem::take(&mut function.named_expressions);
418 for (expr_h, name) in named_expressions {
419 function
420 .named_expressions
421 .insert(adjusted_local_expressions[expr_h], name);
422 }
423
424 Ok(())
425}
426
427fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
430 let adjust = |expr: &mut Handle<Expression>| {
431 *expr = new_pos[*expr];
432 };
433 match *expr {
434 Expression::Compose {
435 ref mut components,
436 ty: _,
437 } => {
438 for c in components.iter_mut() {
439 adjust(c);
440 }
441 }
442 Expression::Access {
443 ref mut base,
444 ref mut index,
445 } => {
446 adjust(base);
447 adjust(index);
448 }
449 Expression::AccessIndex {
450 ref mut base,
451 index: _,
452 } => {
453 adjust(base);
454 }
455 Expression::Splat {
456 ref mut value,
457 size: _,
458 } => {
459 adjust(value);
460 }
461 Expression::Swizzle {
462 ref mut vector,
463 size: _,
464 pattern: _,
465 } => {
466 adjust(vector);
467 }
468 Expression::Load { ref mut pointer } => {
469 adjust(pointer);
470 }
471 Expression::ImageSample {
472 ref mut image,
473 ref mut sampler,
474 ref mut coordinate,
475 ref mut array_index,
476 ref mut offset,
477 ref mut level,
478 ref mut depth_ref,
479 gather: _,
480 clamp_to_edge: _,
481 } => {
482 adjust(image);
483 adjust(sampler);
484 adjust(coordinate);
485 if let Some(e) = array_index.as_mut() {
486 adjust(e);
487 }
488 if let Some(e) = offset.as_mut() {
489 adjust(e);
490 }
491 match *level {
492 crate::SampleLevel::Exact(ref mut expr)
493 | crate::SampleLevel::Bias(ref mut expr) => {
494 adjust(expr);
495 }
496 crate::SampleLevel::Gradient {
497 ref mut x,
498 ref mut y,
499 } => {
500 adjust(x);
501 adjust(y);
502 }
503 _ => {}
504 }
505 if let Some(e) = depth_ref.as_mut() {
506 adjust(e);
507 }
508 }
509 Expression::ImageLoad {
510 ref mut image,
511 ref mut coordinate,
512 ref mut array_index,
513 ref mut sample,
514 ref mut level,
515 } => {
516 adjust(image);
517 adjust(coordinate);
518 if let Some(e) = array_index.as_mut() {
519 adjust(e);
520 }
521 if let Some(e) = sample.as_mut() {
522 adjust(e);
523 }
524 if let Some(e) = level.as_mut() {
525 adjust(e);
526 }
527 }
528 Expression::ImageQuery {
529 ref mut image,
530 ref mut query,
531 } => {
532 adjust(image);
533 match *query {
534 crate::ImageQuery::Size { ref mut level } => {
535 if let Some(e) = level.as_mut() {
536 adjust(e);
537 }
538 }
539 crate::ImageQuery::NumLevels
540 | crate::ImageQuery::NumLayers
541 | crate::ImageQuery::NumSamples => {}
542 }
543 }
544 Expression::Unary {
545 ref mut expr,
546 op: _,
547 } => {
548 adjust(expr);
549 }
550 Expression::Binary {
551 ref mut left,
552 ref mut right,
553 op: _,
554 } => {
555 adjust(left);
556 adjust(right);
557 }
558 Expression::Select {
559 ref mut condition,
560 ref mut accept,
561 ref mut reject,
562 } => {
563 adjust(condition);
564 adjust(accept);
565 adjust(reject);
566 }
567 Expression::Derivative {
568 ref mut expr,
569 axis: _,
570 ctrl: _,
571 } => {
572 adjust(expr);
573 }
574 Expression::Relational {
575 ref mut argument,
576 fun: _,
577 } => {
578 adjust(argument);
579 }
580 Expression::Math {
581 ref mut arg,
582 ref mut arg1,
583 ref mut arg2,
584 ref mut arg3,
585 fun: _,
586 } => {
587 adjust(arg);
588 if let Some(e) = arg1.as_mut() {
589 adjust(e);
590 }
591 if let Some(e) = arg2.as_mut() {
592 adjust(e);
593 }
594 if let Some(e) = arg3.as_mut() {
595 adjust(e);
596 }
597 }
598 Expression::As {
599 ref mut expr,
600 kind: _,
601 convert: _,
602 } => {
603 adjust(expr);
604 }
605 Expression::ArrayLength(ref mut expr) => {
606 adjust(expr);
607 }
608 Expression::RayQueryGetIntersection {
609 ref mut query,
610 committed: _,
611 } => {
612 adjust(query);
613 }
614 Expression::Literal(_)
615 | Expression::FunctionArgument(_)
616 | Expression::GlobalVariable(_)
617 | Expression::LocalVariable(_)
618 | Expression::CallResult(_)
619 | Expression::RayQueryProceedResult
620 | Expression::Constant(_)
621 | Expression::Override(_)
622 | Expression::ZeroValue(_)
623 | Expression::AtomicResult {
624 ty: _,
625 comparison: _,
626 }
627 | Expression::WorkGroupUniformLoadResult { ty: _ }
628 | Expression::SubgroupBallotResult
629 | Expression::SubgroupOperationResult { .. } => {}
630 Expression::RayQueryVertexPositions {
631 ref mut query,
632 committed: _,
633 } => {
634 adjust(query);
635 }
636 }
637}
638
639fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
642 for stmt in block.iter_mut() {
643 adjust_stmt(new_pos, stmt);
644 }
645}
646
647fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
650 let adjust = |expr: &mut Handle<Expression>| {
651 *expr = new_pos[*expr];
652 };
653 match *stmt {
654 Statement::Emit(ref mut range) => {
655 if let Some((mut first, mut last)) = range.first_and_last() {
656 adjust(&mut first);
657 adjust(&mut last);
658 *range = Range::new_from_bounds(first, last);
659 }
660 }
661 Statement::Block(ref mut block) => {
662 adjust_block(new_pos, block);
663 }
664 Statement::If {
665 ref mut condition,
666 ref mut accept,
667 ref mut reject,
668 } => {
669 adjust(condition);
670 adjust_block(new_pos, accept);
671 adjust_block(new_pos, reject);
672 }
673 Statement::Switch {
674 ref mut selector,
675 ref mut cases,
676 } => {
677 adjust(selector);
678 for case in cases.iter_mut() {
679 adjust_block(new_pos, &mut case.body);
680 }
681 }
682 Statement::Loop {
683 ref mut body,
684 ref mut continuing,
685 ref mut break_if,
686 } => {
687 adjust_block(new_pos, body);
688 adjust_block(new_pos, continuing);
689 if let Some(e) = break_if.as_mut() {
690 adjust(e);
691 }
692 }
693 Statement::Return { ref mut value } => {
694 if let Some(e) = value.as_mut() {
695 adjust(e);
696 }
697 }
698 Statement::Store {
699 ref mut pointer,
700 ref mut value,
701 } => {
702 adjust(pointer);
703 adjust(value);
704 }
705 Statement::ImageStore {
706 ref mut image,
707 ref mut coordinate,
708 ref mut array_index,
709 ref mut value,
710 } => {
711 adjust(image);
712 adjust(coordinate);
713 if let Some(e) = array_index.as_mut() {
714 adjust(e);
715 }
716 adjust(value);
717 }
718 Statement::Atomic {
719 ref mut pointer,
720 ref mut value,
721 ref mut result,
722 ref mut fun,
723 } => {
724 adjust(pointer);
725 adjust(value);
726 if let Some(ref mut result) = *result {
727 adjust(result);
728 }
729 match *fun {
730 crate::AtomicFunction::Exchange {
731 compare: Some(ref mut compare),
732 } => {
733 adjust(compare);
734 }
735 crate::AtomicFunction::Add
736 | crate::AtomicFunction::Subtract
737 | crate::AtomicFunction::And
738 | crate::AtomicFunction::ExclusiveOr
739 | crate::AtomicFunction::InclusiveOr
740 | crate::AtomicFunction::Min
741 | crate::AtomicFunction::Max
742 | crate::AtomicFunction::Exchange { compare: None } => {}
743 }
744 }
745 Statement::ImageAtomic {
746 ref mut image,
747 ref mut coordinate,
748 ref mut array_index,
749 fun: _,
750 ref mut value,
751 } => {
752 adjust(image);
753 adjust(coordinate);
754 if let Some(ref mut array_index) = *array_index {
755 adjust(array_index);
756 }
757 adjust(value);
758 }
759 Statement::WorkGroupUniformLoad {
760 ref mut pointer,
761 ref mut result,
762 } => {
763 adjust(pointer);
764 adjust(result);
765 }
766 Statement::SubgroupBallot {
767 ref mut result,
768 ref mut predicate,
769 } => {
770 if let Some(ref mut predicate) = *predicate {
771 adjust(predicate);
772 }
773 adjust(result);
774 }
775 Statement::SubgroupCollectiveOperation {
776 ref mut argument,
777 ref mut result,
778 ..
779 } => {
780 adjust(argument);
781 adjust(result);
782 }
783 Statement::SubgroupGather {
784 ref mut mode,
785 ref mut argument,
786 ref mut result,
787 } => {
788 match *mode {
789 crate::GatherMode::BroadcastFirst => {}
790 crate::GatherMode::Broadcast(ref mut index)
791 | crate::GatherMode::Shuffle(ref mut index)
792 | crate::GatherMode::ShuffleDown(ref mut index)
793 | crate::GatherMode::ShuffleUp(ref mut index)
794 | crate::GatherMode::ShuffleXor(ref mut index)
795 | crate::GatherMode::QuadBroadcast(ref mut index) => {
796 adjust(index);
797 }
798 crate::GatherMode::QuadSwap(_) => {}
799 }
800 adjust(argument);
801 adjust(result)
802 }
803 Statement::Call {
804 ref mut arguments,
805 ref mut result,
806 function: _,
807 } => {
808 for argument in arguments.iter_mut() {
809 adjust(argument);
810 }
811 if let Some(e) = result.as_mut() {
812 adjust(e);
813 }
814 }
815 Statement::RayQuery {
816 ref mut query,
817 ref mut fun,
818 } => {
819 adjust(query);
820 match *fun {
821 crate::RayQueryFunction::Initialize {
822 ref mut acceleration_structure,
823 ref mut descriptor,
824 } => {
825 adjust(acceleration_structure);
826 adjust(descriptor);
827 }
828 crate::RayQueryFunction::Proceed { ref mut result } => {
829 adjust(result);
830 }
831 crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
832 adjust(hit_t);
833 }
834 crate::RayQueryFunction::ConfirmIntersection => {}
835 crate::RayQueryFunction::Terminate => {}
836 }
837 }
838 Statement::Break
839 | Statement::Continue
840 | Statement::Kill
841 | Statement::ControlBarrier(_)
842 | Statement::MemoryBarrier(_) => {}
843 }
844}
845
846fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
863 let original = mem::replace(block, Block::with_capacity(block.len()));
864 for (stmt, span) in original.span_into_iter() {
865 match stmt {
866 Statement::Emit(range) => {
867 let mut current = None;
868 for expr_h in range {
869 if expressions[expr_h].needs_pre_emit() {
870 if let Some((first, last)) = current {
871 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
872 }
873
874 current = None;
875 } else if let Some((_, ref mut last)) = current {
876 *last = expr_h;
877 } else {
878 current = Some((expr_h, expr_h));
879 }
880 }
881 if let Some((first, last)) = current {
882 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
883 }
884 }
885 Statement::Block(mut child) => {
886 filter_emits_in_block(&mut child, expressions);
887 block.push(Statement::Block(child), span);
888 }
889 Statement::If {
890 condition,
891 mut accept,
892 mut reject,
893 } => {
894 filter_emits_in_block(&mut accept, expressions);
895 filter_emits_in_block(&mut reject, expressions);
896 block.push(
897 Statement::If {
898 condition,
899 accept,
900 reject,
901 },
902 span,
903 );
904 }
905 Statement::Switch {
906 selector,
907 mut cases,
908 } => {
909 for case in &mut cases {
910 filter_emits_in_block(&mut case.body, expressions);
911 }
912 block.push(Statement::Switch { selector, cases }, span);
913 }
914 Statement::Loop {
915 mut body,
916 mut continuing,
917 break_if,
918 } => {
919 filter_emits_in_block(&mut body, expressions);
920 filter_emits_in_block(&mut continuing, expressions);
921 block.push(
922 Statement::Loop {
923 body,
924 continuing,
925 break_if,
926 },
927 span,
928 );
929 }
930 stmt => block.push(stmt.clone(), span),
931 }
932 }
933}
934
935fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
936 match scalar {
938 Scalar::BOOL => {
939 let value = value != 0.0 && !value.is_nan();
941 Ok(Literal::Bool(value))
942 }
943 Scalar::I32 => {
944 if !value.is_finite() {
946 return Err(PipelineConstantError::SrcNeedsToBeFinite);
947 }
948
949 let value = value.trunc();
950 if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
951 return Err(PipelineConstantError::DstRangeTooSmall);
952 }
953
954 let value = value as i32;
955 Ok(Literal::I32(value))
956 }
957 Scalar::U32 => {
958 if !value.is_finite() {
960 return Err(PipelineConstantError::SrcNeedsToBeFinite);
961 }
962
963 let value = value.trunc();
964 if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
965 return Err(PipelineConstantError::DstRangeTooSmall);
966 }
967
968 let value = value as u32;
969 Ok(Literal::U32(value))
970 }
971 Scalar::F16 => {
972 if !value.is_finite() {
974 return Err(PipelineConstantError::SrcNeedsToBeFinite);
975 }
976
977 let value = half::f16::from_f64(value);
978 if !value.is_finite() {
979 return Err(PipelineConstantError::DstRangeTooSmall);
980 }
981
982 Ok(Literal::F16(value))
983 }
984 Scalar::F32 => {
985 if !value.is_finite() {
987 return Err(PipelineConstantError::SrcNeedsToBeFinite);
988 }
989
990 let value = value as f32;
991 if !value.is_finite() {
992 return Err(PipelineConstantError::DstRangeTooSmall);
993 }
994
995 Ok(Literal::F32(value))
996 }
997 Scalar::F64 => {
998 if !value.is_finite() {
1000 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1001 }
1002
1003 Ok(Literal::F64(value))
1004 }
1005 Scalar::ABSTRACT_FLOAT | Scalar::ABSTRACT_INT => {
1006 unreachable!("abstract values should not be validated out of override processing")
1007 }
1008 _ => unreachable!("unrecognized scalar type for override"),
1009 }
1010}
1011
/// Checks `map_value_to_literal` for every concrete scalar destination:
/// bool truthiness, rejection of non-finite inputs, and the exact boundaries
/// at which values stop fitting the destination type.
#[test]
fn test_map_value_to_literal() {
    // bool: only zero and NaN map to `false`; infinities are `true`.
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    // Every numeric destination rejects NaN and infinities.
    for scalar in [
        Scalar::I32,
        Scalar::U32,
        Scalar::F16,
        Scalar::F32,
        Scalar::F64,
    ] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f16: ±65504 is the largest finite magnitude; 1.0e5 is far beyond it.
    assert_eq!(
        map_value_to_literal(-65504.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(65504.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    assert_eq!(
        map_value_to_literal(-1.0e5, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(1.0e5, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f32
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // The largest f64 magnitude that still converts to f32::MAX ...
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // ... and the next f64 magnitude up, which no longer fits.
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}