1use alloc::{
2 borrow::Cow,
3 string::{String, ToString},
4};
5use core::mem;
6
7use hashbrown::HashSet;
8use thiserror::Error;
9
10use super::PipelineConstants;
11use crate::{
12 arena::HandleVec,
13 compact::{compact, KeepUnused},
14 ir,
15 proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
16 valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
17 Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
18 Span, Statement, TypeInner, WithSpan,
19};
20
21#[cfg(no_std)]
22use num_traits::float::FloatCore as _;
23
/// Errors produced while replacing pipeline-overridable constants
/// (`override` declarations) with concrete values.
#[derive(Error, Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub enum PipelineConstantError {
    /// An override had no entry in the supplied pipeline constants and no
    /// default initializer. Carries the lookup key: the override's `@id` as a
    /// decimal string if it has one, otherwise its name.
    #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
    MissingValue(String),
    /// A NaN or infinite `f64` was supplied for a numeric (non-`bool`)
    /// destination type.
    #[error(
        "Source f64 value needs to be finite ({}) for number destinations",
        "NaNs and Inifinites are not allowed"
    )]
    SrcNeedsToBeFinite,
    /// A finite value cannot be represented in the destination type: it is
    /// outside the integer range, or it becomes infinite when narrowed to
    /// `f16`/`f32`.
    #[error("Source f64 value doesn't fit in destination")]
    DstRangeTooSmall,
    /// Constant evaluation of a rewritten expression failed.
    #[error(transparent)]
    ConstantEvaluatorError(#[from] ConstantEvaluatorError),
    /// Re-validation of the module after override substitution failed.
    #[error(transparent)]
    ValidationError(#[from] WithSpan<ValidationError>),
    /// An overridden `workgroup_size` component evaluated to zero, or could
    /// not be evaluated as a `u32` at all.
    #[error("workgroup_size override isn't strictly positive")]
    NegativeWorkgroupSize,
    /// An overridden mesh-shader `max_vertices`/`max_primitives` could not be
    /// evaluated as a `u32`.
    #[error("max vertices or max primitives is negative")]
    NegativeMeshOutputMax,
}
45
/// Compact `module` and replace its `override` declarations with concrete
/// values taken from `pipeline_constants`.
///
/// If `entry_point` is given, every other entry point is removed before
/// compaction, so only items reachable from the selected entry point remain.
///
/// Each `Override` becomes a new `Constant`. Its initializer is either a
/// literal built from the matching pipeline constant or the override's own
/// (remapped) default initializer. Every `Expression::Override` in the global
/// expression arena and in function/entry-point bodies is rewritten to refer
/// to that constant. All expressions are then passed through the constant
/// evaluator, which may fold them. Overridden workgroup sizes and mesh-shader
/// output limits are also resolved to plain numbers.
///
/// Returns the inputs unchanged (borrowed) when there is nothing to do.
/// Otherwise returns the rewritten module together with a `ModuleInfo` from a
/// full re-validation.
///
/// # Errors
///
/// See [`PipelineConstantError`]. Possible causes:
/// - an override has no supplied value and no default;
/// - a supplied value does not fit the override's type;
/// - constant evaluation fails;
/// - a workgroup size or mesh output limit is invalid;
/// - the rewritten module fails validation.
pub fn process_overrides<'a>(
    module: &'a Module,
    module_info: &'a ModuleInfo,
    entry_point: Option<(ir::ShaderStage, &str)>,
    pipeline_constants: &PipelineConstants,
) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
    // Fast path: no entry points to strip and no overrides to resolve, so the
    // caller's module and info are returned as-is.
    if (entry_point.is_none() || module.entry_points.len() <= 1) && module.overrides.is_empty() {
        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
    }

    let mut module = module.clone();
    // Keep only the requested entry point so compaction drops everything
    // unreachable from it.
    if let Some((ep_stage, ep_name)) = entry_point {
        module
            .entry_points
            .retain(|ep| ep.stage == ep_stage && ep.name == ep_name);
    }

    compact(&mut module, KeepUnused::No);

    // Compaction may have removed every override; then only re-validation
    // remains to be done.
    if module.overrides.is_empty() {
        return revalidate(module);
    }

    // Old override handle -> constant that replaces it.
    let mut override_map = HandleVec::with_capacity(module.overrides.len());

    // Old global expression handle -> handle in the rebuilt global arena.
    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());

    // Constants whose `init` already points into the rebuilt arena and must
    // therefore not be remapped a second time.
    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());

    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
    let mut layouter = crate::proc::Layouter::default();

    // Move the overrides out of `module` so they can be iterated mutably while
    // `module` is borrowed mutably below. They are put back at the end.
    let mut overrides = mem::take(&mut module.overrides);
    let mut override_iter = overrides.iter_mut_span();

    // Rebuild the global expression arena in order. Each expression gets its
    // operands remapped and is passed through the constant evaluator.
    for (old_h, expr, span) in module.global_expressions.drain() {
        let mut expr = match expr {
            Expression::Override(h) => {
                let c_h = if let Some(new_h) = override_map.get(h) {
                    *new_h
                } else {
                    // Convert overrides to constants in declaration order, up to
                    // and including `h`. `process_override` indexes
                    // `adjusted_global_expressions` with the override's
                    // initializer, so this relies on that initializer having
                    // been rebuilt already.
                    let mut new_h = None;
                    for entry in override_iter.by_ref() {
                        let stop = entry.0 == h;
                        new_h = Some(process_override(
                            entry,
                            pipeline_constants,
                            &mut module,
                            &mut override_map,
                            &adjusted_global_expressions,
                            &mut adjusted_constant_initializers,
                            &mut global_expression_kind_tracker,
                        )?);
                        if stop {
                            break;
                        }
                    }
                    new_h.unwrap()
                };
                Expression::Constant(c_h)
            }
            Expression::Constant(c_h) => {
                // The first time a constant is referenced, point its initializer
                // into the rebuilt arena.
                if adjusted_constant_initializers.insert(c_h) {
                    let init = &mut module.constants[c_h].init;
                    *init = adjusted_global_expressions[*init];
                }
                expr
            }
            expr => expr,
        };
        let mut evaluator = ConstantEvaluator::for_wgsl_module(
            &mut module,
            &mut global_expression_kind_tracker,
            &mut layouter,
            false,
        );
        adjust_expr(&adjusted_global_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_global_expressions.insert(old_h, h);
    }

    // Overrides that were never referenced from a global expression.
    // - Those with a name or `@id` are still converted into constants, which
    //   also reports a missing value.
    // - The rest only have their initializer remapped.
    for entry in override_iter {
        match *entry.1 {
            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
                process_override(
                    entry,
                    pipeline_constants,
                    &mut module,
                    &mut override_map,
                    &adjusted_global_expressions,
                    &mut adjusted_constant_initializers,
                    &mut global_expression_kind_tracker,
                )?;
            }
            Override {
                init: Some(ref mut init),
                ..
            } => {
                *init = adjusted_global_expressions[*init];
            }
            _ => {}
        }
    }

    // Remap initializers of constants that were not already handled above.
    for (_, c) in module
        .constants
        .iter_mut()
        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
    {
        c.init = adjusted_global_expressions[c.init];
    }

    for (_, v) in module.global_variables.iter_mut() {
        if let Some(ref mut init) = v.init {
            *init = adjusted_global_expressions[*init];
        }
    }

    // Functions and entry points are taken out temporarily so `module` can be
    // passed mutably while each body is rewritten.
    let mut functions = mem::take(&mut module.functions);
    for (_, function) in functions.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, function)?;
    }
    module.functions = functions;

    let mut entry_points = mem::take(&mut module.entry_points);
    for ep in entry_points.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
        process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?;
    }
    module.entry_points = entry_points;
    module.overrides = overrides;

    // The module has been rewritten; recompute its `ModuleInfo` from scratch.
    revalidate(module)
}
258
259fn revalidate(
260 module: Module,
261) -> Result<(Cow<'static, Module>, Cow<'static, ModuleInfo>), PipelineConstantError> {
262 let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
263 let module_info = validator.validate_resolved_overrides(&module)?;
264 Ok((Cow::Owned(module), Cow::Owned(module_info)))
265}
266
267fn process_workgroup_size_override(
268 module: &mut Module,
269 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
270 ep: &mut crate::EntryPoint,
271) -> Result<(), PipelineConstantError> {
272 match ep.workgroup_size_overrides {
273 None => {}
274 Some(overrides) => {
275 overrides.iter().enumerate().try_for_each(
276 |(i, overridden)| -> Result<(), PipelineConstantError> {
277 match *overridden {
278 None => Ok(()),
279 Some(h) => {
280 ep.workgroup_size[i] = module
281 .to_ctx()
282 .eval_expr_to_u32(adjusted_global_expressions[h])
283 .map(|n| {
284 if n == 0 {
285 Err(PipelineConstantError::NegativeWorkgroupSize)
286 } else {
287 Ok(n)
288 }
289 })
290 .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
291 Ok(())
292 }
293 }
294 },
295 )?;
296 ep.workgroup_size_overrides = None;
297 }
298 }
299 Ok(())
300}
301
302fn process_mesh_shader_overrides(
303 module: &mut Module,
304 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
305 ep: &mut crate::EntryPoint,
306) -> Result<(), PipelineConstantError> {
307 if let Some(ref mut mesh_info) = ep.mesh_info {
308 if let Some(r#override) = mesh_info.max_vertices_override {
309 mesh_info.max_vertices = module
310 .to_ctx()
311 .eval_expr_to_u32(adjusted_global_expressions[r#override])
312 .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
313 }
314 if let Some(r#override) = mesh_info.max_primitives_override {
315 mesh_info.max_primitives = module
316 .to_ctx()
317 .eval_expr_to_u32(adjusted_global_expressions[r#override])
318 .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
319 }
320 }
321 Ok(())
322}
323
324fn process_override(
328 (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
329 pipeline_constants: &PipelineConstants,
330 module: &mut Module,
331 override_map: &mut HandleVec<Override, Handle<Constant>>,
332 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
333 adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
334 global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
335) -> Result<Handle<Constant>, PipelineConstantError> {
336 let key = if let Some(id) = r#override.id {
338 Cow::Owned(id.to_string())
339 } else if let Some(ref name) = r#override.name {
340 Cow::Borrowed(name)
341 } else {
342 unreachable!();
343 };
344
345 let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
349 let literal = match module.types[r#override.ty].inner {
350 TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
351 _ => unreachable!(),
352 };
353 let expr = module
354 .global_expressions
355 .append(Expression::Literal(literal), Span::UNDEFINED);
356 global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
357 expr
358 } else if let Some(init) = r#override.init {
359 adjusted_global_expressions[init]
360 } else {
361 return Err(PipelineConstantError::MissingValue(key.to_string()));
362 };
363
364 let constant = Constant {
366 name: r#override.name.clone(),
367 ty: r#override.ty,
368 init,
369 };
370 let h = module.constants.append(constant, *span);
371 override_map.insert(old_h, h);
372 adjusted_constant_initializers.insert(h);
373 r#override.init = Some(init);
374 Ok(h)
375}
376
/// Rewrite `function`'s expression arena after override processing.
///
/// For each expression:
/// - an `Expression::Override` is replaced by the `Constant` that
///   `override_map` assigns to it;
/// - operand handles are remapped;
/// - the expression is passed through the constant evaluator, which may fold it.
///
/// The statement body, `Emit` ranges, local-variable initializers and named
/// expressions are then updated to refer to the new expression handles.
fn process_function(
    module: &mut Module,
    override_map: &HandleVec<Override, Handle<Constant>>,
    layouter: &mut crate::proc::Layouter,
    function: &mut Function,
) -> Result<(), ConstantEvaluatorError> {
    // Old local expression handle -> handle in the rebuilt arena.
    let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());

    let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();

    // Move the old expressions out; `function.expressions` is rebuilt below.
    let mut expressions = mem::take(&mut function.expressions);

    // Scratch emitter/block required by `ConstantEvaluator::for_wgsl_function`.
    // `block` is not read after the loop.
    let mut emitter = Emitter::default();
    let mut block = Block::new();

    let mut evaluator = ConstantEvaluator::for_wgsl_function(
        module,
        &mut function.expressions,
        &mut local_expression_kind_tracker,
        layouter,
        &mut emitter,
        &mut block,
        false,
    );

    for (old_h, mut expr, span) in expressions.drain() {
        if let Expression::Override(h) = expr {
            expr = Expression::Constant(override_map[h]);
        }
        adjust_expr(&adjusted_local_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_local_expressions.insert(old_h, h);
    }

    adjust_block(&adjusted_local_expressions, &mut function.body);

    // Take expressions that must not be emitted (`needs_pre_emit`) out of
    // `Emit` ranges, since the rebuilt arena may now contain them.
    filter_emits_in_block(&mut function.body, &function.expressions);

    for (_, local) in function.local_variables.iter_mut() {
        if let &mut Some(ref mut init) = &mut local.init {
            *init = adjusted_local_expressions[*init];
        }
    }

    // Re-key named expressions by their new handles.
    let named_expressions = mem::take(&mut function.named_expressions);
    for (expr_h, name) in named_expressions {
        function
            .named_expressions
            .insert(adjusted_local_expressions[expr_h], name);
    }

    Ok(())
}
451
/// Rewrite every expression-handle operand of `expr` through `new_pos`, which
/// maps handles of the old arena to handles of the rebuilt one.
///
/// Variants without expression operands are left unchanged. Most variants
/// list their non-handle fields explicitly (`field: _`) instead of `..`, so
/// adding a field to one of them makes this match fail to compile until the
/// field is considered.
fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
    let adjust = |expr: &mut Handle<Expression>| {
        *expr = new_pos[*expr];
    };
    match *expr {
        Expression::Compose {
            ref mut components,
            ty: _,
        } => {
            for c in components.iter_mut() {
                adjust(c);
            }
        }
        Expression::Access {
            ref mut base,
            ref mut index,
        } => {
            adjust(base);
            adjust(index);
        }
        Expression::AccessIndex {
            ref mut base,
            index: _,
        } => {
            adjust(base);
        }
        Expression::Splat {
            ref mut value,
            size: _,
        } => {
            adjust(value);
        }
        Expression::Swizzle {
            ref mut vector,
            size: _,
            pattern: _,
        } => {
            adjust(vector);
        }
        Expression::Load { ref mut pointer } => {
            adjust(pointer);
        }
        Expression::ImageSample {
            ref mut image,
            ref mut sampler,
            ref mut coordinate,
            ref mut array_index,
            ref mut offset,
            ref mut level,
            ref mut depth_ref,
            gather: _,
            clamp_to_edge: _,
        } => {
            adjust(image);
            adjust(sampler);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            if let Some(e) = offset.as_mut() {
                adjust(e);
            }
            // Only some sample levels carry expression operands.
            match *level {
                crate::SampleLevel::Exact(ref mut expr)
                | crate::SampleLevel::Bias(ref mut expr) => {
                    adjust(expr);
                }
                crate::SampleLevel::Gradient {
                    ref mut x,
                    ref mut y,
                } => {
                    adjust(x);
                    adjust(y);
                }
                _ => {}
            }
            if let Some(e) = depth_ref.as_mut() {
                adjust(e);
            }
        }
        Expression::ImageLoad {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            ref mut sample,
            ref mut level,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            if let Some(e) = sample.as_mut() {
                adjust(e);
            }
            if let Some(e) = level.as_mut() {
                adjust(e);
            }
        }
        Expression::ImageQuery {
            ref mut image,
            ref mut query,
        } => {
            adjust(image);
            match *query {
                crate::ImageQuery::Size { ref mut level } => {
                    if let Some(e) = level.as_mut() {
                        adjust(e);
                    }
                }
                crate::ImageQuery::NumLevels
                | crate::ImageQuery::NumLayers
                | crate::ImageQuery::NumSamples => {}
            }
        }
        Expression::Unary {
            ref mut expr,
            op: _,
        } => {
            adjust(expr);
        }
        Expression::Binary {
            ref mut left,
            ref mut right,
            op: _,
        } => {
            adjust(left);
            adjust(right);
        }
        Expression::Select {
            ref mut condition,
            ref mut accept,
            ref mut reject,
        } => {
            adjust(condition);
            adjust(accept);
            adjust(reject);
        }
        Expression::Derivative {
            ref mut expr,
            axis: _,
            ctrl: _,
        } => {
            adjust(expr);
        }
        Expression::Relational {
            ref mut argument,
            fun: _,
        } => {
            adjust(argument);
        }
        Expression::Math {
            ref mut arg,
            ref mut arg1,
            ref mut arg2,
            ref mut arg3,
            fun: _,
        } => {
            adjust(arg);
            if let Some(e) = arg1.as_mut() {
                adjust(e);
            }
            if let Some(e) = arg2.as_mut() {
                adjust(e);
            }
            if let Some(e) = arg3.as_mut() {
                adjust(e);
            }
        }
        Expression::As {
            ref mut expr,
            kind: _,
            convert: _,
        } => {
            adjust(expr);
        }
        Expression::ArrayLength(ref mut expr) => {
            adjust(expr);
        }
        Expression::RayQueryGetIntersection {
            ref mut query,
            committed: _,
        } => {
            adjust(query);
        }
        // Leaf expressions: no expression-handle operands to remap.
        Expression::Literal(_)
        | Expression::FunctionArgument(_)
        | Expression::GlobalVariable(_)
        | Expression::LocalVariable(_)
        | Expression::CallResult(_)
        | Expression::RayQueryProceedResult
        | Expression::Constant(_)
        | Expression::Override(_)
        | Expression::ZeroValue(_)
        | Expression::AtomicResult {
            ty: _,
            comparison: _,
        }
        | Expression::WorkGroupUniformLoadResult { ty: _ }
        | Expression::SubgroupBallotResult
        | Expression::SubgroupOperationResult { .. } => {}
        Expression::RayQueryVertexPositions {
            ref mut query,
            committed: _,
        } => {
            adjust(query);
        }
    }
}
663
664fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
667 for stmt in block.iter_mut() {
668 adjust_stmt(new_pos, stmt);
669 }
670}
671
/// Rewrite every expression handle referenced by `stmt` through `new_pos`,
/// recursing into nested blocks via `adjust_block`.
///
/// `Emit` ranges are remapped by their first and last handles only. The new
/// range spans everything between those two handles in the rebuilt arena.
fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
    let adjust = |expr: &mut Handle<Expression>| {
        *expr = new_pos[*expr];
    };
    match *stmt {
        Statement::Emit(ref mut range) => {
            if let Some((mut first, mut last)) = range.first_and_last() {
                adjust(&mut first);
                adjust(&mut last);
                *range = Range::new_from_bounds(first, last);
            }
        }
        Statement::Block(ref mut block) => {
            adjust_block(new_pos, block);
        }
        Statement::If {
            ref mut condition,
            ref mut accept,
            ref mut reject,
        } => {
            adjust(condition);
            adjust_block(new_pos, accept);
            adjust_block(new_pos, reject);
        }
        Statement::Switch {
            ref mut selector,
            ref mut cases,
        } => {
            adjust(selector);
            for case in cases.iter_mut() {
                adjust_block(new_pos, &mut case.body);
            }
        }
        Statement::Loop {
            ref mut body,
            ref mut continuing,
            ref mut break_if,
        } => {
            adjust_block(new_pos, body);
            adjust_block(new_pos, continuing);
            if let Some(e) = break_if.as_mut() {
                adjust(e);
            }
        }
        Statement::Return { ref mut value } => {
            if let Some(e) = value.as_mut() {
                adjust(e);
            }
        }
        Statement::Store {
            ref mut pointer,
            ref mut value,
        } => {
            adjust(pointer);
            adjust(value);
        }
        Statement::ImageStore {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            ref mut value,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            adjust(value);
        }
        Statement::Atomic {
            ref mut pointer,
            ref mut value,
            ref mut result,
            ref mut fun,
        } => {
            adjust(pointer);
            adjust(value);
            if let Some(ref mut result) = *result {
                adjust(result);
            }
            // Only compare-exchange carries an extra expression operand.
            match *fun {
                crate::AtomicFunction::Exchange {
                    compare: Some(ref mut compare),
                } => {
                    adjust(compare);
                }
                crate::AtomicFunction::Add
                | crate::AtomicFunction::Subtract
                | crate::AtomicFunction::And
                | crate::AtomicFunction::ExclusiveOr
                | crate::AtomicFunction::InclusiveOr
                | crate::AtomicFunction::Min
                | crate::AtomicFunction::Max
                | crate::AtomicFunction::Exchange { compare: None } => {}
            }
        }
        Statement::ImageAtomic {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            fun: _,
            ref mut value,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(ref mut array_index) = *array_index {
                adjust(array_index);
            }
            adjust(value);
        }
        Statement::WorkGroupUniformLoad {
            ref mut pointer,
            ref mut result,
        } => {
            adjust(pointer);
            adjust(result);
        }
        Statement::SubgroupBallot {
            ref mut result,
            ref mut predicate,
        } => {
            if let Some(ref mut predicate) = *predicate {
                adjust(predicate);
            }
            adjust(result);
        }
        Statement::SubgroupCollectiveOperation {
            ref mut argument,
            ref mut result,
            ..
        } => {
            adjust(argument);
            adjust(result);
        }
        Statement::SubgroupGather {
            ref mut mode,
            ref mut argument,
            ref mut result,
        } => {
            match *mode {
                crate::GatherMode::BroadcastFirst => {}
                crate::GatherMode::Broadcast(ref mut index)
                | crate::GatherMode::Shuffle(ref mut index)
                | crate::GatherMode::ShuffleDown(ref mut index)
                | crate::GatherMode::ShuffleUp(ref mut index)
                | crate::GatherMode::ShuffleXor(ref mut index)
                | crate::GatherMode::QuadBroadcast(ref mut index) => {
                    adjust(index);
                }
                crate::GatherMode::QuadSwap(_) => {}
            }
            adjust(argument);
            adjust(result)
        }
        Statement::Call {
            ref mut arguments,
            ref mut result,
            function: _,
        } => {
            for argument in arguments.iter_mut() {
                adjust(argument);
            }
            if let Some(e) = result.as_mut() {
                adjust(e);
            }
        }
        Statement::RayQuery {
            ref mut query,
            ref mut fun,
        } => {
            adjust(query);
            match *fun {
                crate::RayQueryFunction::Initialize {
                    ref mut acceleration_structure,
                    ref mut descriptor,
                } => {
                    adjust(acceleration_structure);
                    adjust(descriptor);
                }
                crate::RayQueryFunction::Proceed { ref mut result } => {
                    adjust(result);
                }
                crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
                    adjust(hit_t);
                }
                crate::RayQueryFunction::ConfirmIntersection => {}
                crate::RayQueryFunction::Terminate => {}
            }
        }
        // Statements that reference no expressions.
        Statement::Break
        | Statement::Continue
        | Statement::Kill
        | Statement::ControlBarrier(_)
        | Statement::MemoryBarrier(_) => {}
    }
}
870
871fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
888 let original = mem::replace(block, Block::with_capacity(block.len()));
889 for (stmt, span) in original.span_into_iter() {
890 match stmt {
891 Statement::Emit(range) => {
892 let mut current = None;
893 for expr_h in range {
894 if expressions[expr_h].needs_pre_emit() {
895 if let Some((first, last)) = current {
896 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
897 }
898
899 current = None;
900 } else if let Some((_, ref mut last)) = current {
901 *last = expr_h;
902 } else {
903 current = Some((expr_h, expr_h));
904 }
905 }
906 if let Some((first, last)) = current {
907 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
908 }
909 }
910 Statement::Block(mut child) => {
911 filter_emits_in_block(&mut child, expressions);
912 block.push(Statement::Block(child), span);
913 }
914 Statement::If {
915 condition,
916 mut accept,
917 mut reject,
918 } => {
919 filter_emits_in_block(&mut accept, expressions);
920 filter_emits_in_block(&mut reject, expressions);
921 block.push(
922 Statement::If {
923 condition,
924 accept,
925 reject,
926 },
927 span,
928 );
929 }
930 Statement::Switch {
931 selector,
932 mut cases,
933 } => {
934 for case in &mut cases {
935 filter_emits_in_block(&mut case.body, expressions);
936 }
937 block.push(Statement::Switch { selector, cases }, span);
938 }
939 Statement::Loop {
940 mut body,
941 mut continuing,
942 break_if,
943 } => {
944 filter_emits_in_block(&mut body, expressions);
945 filter_emits_in_block(&mut continuing, expressions);
946 block.push(
947 Statement::Loop {
948 body,
949 continuing,
950 break_if,
951 },
952 span,
953 );
954 }
955 stmt => block.push(stmt.clone(), span),
956 }
957 }
958}
959
960fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
961 match scalar {
963 Scalar::BOOL => {
964 let value = value != 0.0 && !value.is_nan();
966 Ok(Literal::Bool(value))
967 }
968 Scalar::I32 => {
969 if !value.is_finite() {
971 return Err(PipelineConstantError::SrcNeedsToBeFinite);
972 }
973
974 let value = value.trunc();
975 if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
976 return Err(PipelineConstantError::DstRangeTooSmall);
977 }
978
979 let value = value as i32;
980 Ok(Literal::I32(value))
981 }
982 Scalar::U32 => {
983 if !value.is_finite() {
985 return Err(PipelineConstantError::SrcNeedsToBeFinite);
986 }
987
988 let value = value.trunc();
989 if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
990 return Err(PipelineConstantError::DstRangeTooSmall);
991 }
992
993 let value = value as u32;
994 Ok(Literal::U32(value))
995 }
996 Scalar::F16 => {
997 if !value.is_finite() {
999 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1000 }
1001
1002 let value = half::f16::from_f64(value);
1003 if !value.is_finite() {
1004 return Err(PipelineConstantError::DstRangeTooSmall);
1005 }
1006
1007 Ok(Literal::F16(value))
1008 }
1009 Scalar::F32 => {
1010 if !value.is_finite() {
1012 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1013 }
1014
1015 let value = value as f32;
1016 if !value.is_finite() {
1017 return Err(PipelineConstantError::DstRangeTooSmall);
1018 }
1019
1020 Ok(Literal::F32(value))
1021 }
1022 Scalar::F64 => {
1023 if !value.is_finite() {
1025 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1026 }
1027
1028 Ok(Literal::F64(value))
1029 }
1030 Scalar::ABSTRACT_FLOAT | Scalar::ABSTRACT_INT => {
1031 unreachable!("abstract values should not be validated out of override processing")
1032 }
1033 _ => unreachable!("unrecognized scalar type for override"),
1034 }
1035}
1036
#[test]
fn test_map_value_to_literal() {
    // bool: zero, negative zero and NaN are false; everything else is true.
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    // Every numeric destination, f16 included, rejects non-finite input.
    for scalar in [
        Scalar::I32,
        Scalar::U32,
        Scalar::F16,
        Scalar::F32,
        Scalar::F64,
    ] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f16: the finite extremes (+/-65504) are accepted exactly; values that
    // overflow to infinity when narrowed are rejected.
    assert_eq!(
        map_value_to_literal(half::f16::MIN.to_f64(), Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(half::f16::MAX.to_f64(), Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    assert_eq!(
        map_value_to_literal(-65536.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(65536.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f32: values that round to f32::MIN/MAX are accepted; values that round
    // to infinity are rejected.
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}