1use alloc::{
2 borrow::Cow,
3 string::{String, ToString},
4};
5use core::mem;
6
7use hashbrown::HashSet;
8use thiserror::Error;
9
10use super::PipelineConstants;
11use crate::{
12 arena::HandleVec,
13 compact::{compact, KeepUnused},
14 ir,
15 proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
16 valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
17 Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
18 Span, Statement, TypeInner, WithSpan,
19};
20
21#[allow(unused_imports)]
23use num_traits::float::FloatCore as _;
24
25#[derive(Error, Debug, Clone)]
26#[cfg_attr(test, derive(PartialEq))]
27pub enum PipelineConstantError {
28 #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
29 MissingValue(String),
30 #[error(
31 "Source f64 value needs to be finite ({}) for number destinations",
32 "NaNs and Inifinites are not allowed"
33 )]
34 SrcNeedsToBeFinite,
35 #[error("Source f64 value doesn't fit in destination")]
36 DstRangeTooSmall,
37 #[error(transparent)]
38 ConstantEvaluatorError(#[from] ConstantEvaluatorError),
39 #[error(transparent)]
40 ValidationError(#[from] WithSpan<ValidationError>),
41 #[error("workgroup_size override isn't strictly positive")]
42 NegativeWorkgroupSize,
43 #[error("max vertices or max primitives is negative")]
44 NegativeMeshOutputMax,
45}
46
/// Compact `module` and replace every override with a plain constant.
///
/// If `entry_point` is given, a clone of `module` keeps only the entry point
/// with that stage and name. The clone is then compacted with
/// `KeepUnused::No`.
///
/// Every override that survives becomes a new [`Constant`]. Its initializer
/// is taken from `pipeline_constants`, keyed by the override's numeric id if
/// it has one and otherwise by its name. If no value is supplied, the
/// override's own initializer is used instead.
///
/// The rewrite then proceeds as follows:
/// - Every `Override` expression is replaced with a `Constant` expression.
/// - All global and function expressions are re-evaluated with the constant
///   evaluator.
/// - Workgroup-size and mesh-output overrides on entry points are resolved to
///   concrete values.
/// - The result is revalidated.
///
/// Returns the inputs unchanged (borrowed) when there is nothing to do: the
/// module has no overrides, and either no entry point was selected or the
/// module has at most one entry point. Otherwise returns the owned, rewritten
/// module together with a freshly computed [`ModuleInfo`].
///
/// # Errors
///
/// - [`PipelineConstantError::MissingValue`] if an override has neither a
///   pipeline-constant value nor an initializer.
/// - [`PipelineConstantError::SrcNeedsToBeFinite`] or
///   [`PipelineConstantError::DstRangeTooSmall`] if a supplied value cannot
///   be represented in the override's type.
/// - [`PipelineConstantError::ConstantEvaluatorError`] if re-evaluating an
///   expression fails.
/// - [`PipelineConstantError::NegativeWorkgroupSize`] or
///   [`PipelineConstantError::NegativeMeshOutputMax`] for invalid
///   entry-point overrides.
/// - [`PipelineConstantError::ValidationError`] if the rewritten module fails
///   revalidation.
pub fn process_overrides<'a>(
    module: &'a Module,
    module_info: &'a ModuleInfo,
    entry_point: Option<(ir::ShaderStage, &str)>,
    pipeline_constants: &PipelineConstants,
) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
    if (entry_point.is_none() || module.entry_points.len() <= 1) && module.overrides.is_empty() {
        // No entry-point filtering would change anything and there are no
        // overrides to resolve: return the inputs as-is without cloning.
        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
    }

    let mut module = module.clone();
    if let Some((ep_stage, ep_name)) = entry_point {
        module
            .entry_points
            .retain(|ep| ep.stage == ep_stage && ep.name == ep_name);
    }

    // Remove items that are not kept by compaction. Presumably this drops
    // overrides used only by the filtered-out entry points, so they need no
    // pipeline-constant value; confirm against `compact`.
    compact(&mut module, KeepUnused::No);

    // After compaction there may be no overrides left; just revalidate.
    if module.overrides.is_empty() {
        return revalidate(module);
    }

    // Maps each original override handle to the `Constant` that replaces it.
    let mut override_map = HandleVec::with_capacity(module.overrides.len());

    // Maps each original global expression handle to its fully evaluated
    // counterpart in the rebuilt `module.global_expressions` arena.
    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());

    // Constants whose `init` handle has already been translated into the
    // rebuilt arena. Used to avoid translating any constant twice.
    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());

    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
    let mut layouter = crate::proc::Layouter::default();

    // The overrides table is taken out of the module so its entries can be
    // mutated while `module` is also borrowed mutably. It is put back at the
    // end. `override_iter` is consumed lazily, in step with the global
    // expressions below.
    let mut overrides = mem::take(&mut module.overrides);
    let mut override_iter = overrides.iter_mut_span();

    // Rebuild the global expression arena from scratch, in handle order,
    // evaluating each expression as it is appended. Overrides and
    // expressions refer to each other, so the overrides are converted to
    // constants in the same pass.
    for (old_h, expr, span) in module.global_expressions.drain() {
        let mut expr = match expr {
            Expression::Override(h) => {
                let c_h = if let Some(new_h) = override_map.get(h) {
                    *new_h
                } else {
                    // `h` has not been converted yet. Advance through the
                    // overrides table, converting every override up to and
                    // including `h`, in table order. This presumably relies
                    // on earlier overrides' initializers already having been
                    // rebuilt at this point; verify if the arena ordering
                    // guarantees change.
                    let mut new_h = None;
                    for entry in override_iter.by_ref() {
                        let stop = entry.0 == h;
                        new_h = Some(process_override(
                            entry,
                            pipeline_constants,
                            &mut module,
                            &mut override_map,
                            &adjusted_global_expressions,
                            &mut adjusted_constant_initializers,
                            &mut global_expression_kind_tracker,
                        )?);
                        if stop {
                            break;
                        }
                    }
                    new_h.unwrap()
                };
                Expression::Constant(c_h)
            }
            Expression::Constant(c_h) => {
                // On the first reference to this constant, translate its
                // initializer into the rebuilt arena. `insert` returns false
                // if that was already done.
                if adjusted_constant_initializers.insert(c_h) {
                    let init = &mut module.constants[c_h].init;
                    *init = adjusted_global_expressions[*init];
                }
                expr
            }
            expr => expr,
        };
        let mut evaluator = ConstantEvaluator::for_wgsl_module(
            &mut module,
            &mut global_expression_kind_tracker,
            &mut layouter,
            false,
        );
        // Point operands at their rebuilt counterparts, then evaluate and append.
        adjust_expr(&adjusted_global_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_global_expressions.insert(old_h, h);
    }

    // Handle overrides that no global expression referenced.
    // - Named or id-carrying overrides still become constants (and may fail
    //   with `MissingValue`).
    // - Any others only have their initializer handle translated.
    for entry in override_iter {
        match *entry.1 {
            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
                process_override(
                    entry,
                    pipeline_constants,
                    &mut module,
                    &mut override_map,
                    &adjusted_global_expressions,
                    &mut adjusted_constant_initializers,
                    &mut global_expression_kind_tracker,
                )?;
            }
            Override {
                init: Some(ref mut init),
                ..
            } => {
                *init = adjusted_global_expressions[*init];
            }
            _ => {}
        }
    }

    // Translate initializers of constants not already handled above.
    for (_, c) in module
        .constants
        .iter_mut()
        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
    {
        c.init = adjusted_global_expressions[c.init];
    }

    // Global variable initializers also live in the global expression arena.
    for (_, v) in module.global_variables.iter_mut() {
        if let Some(ref mut init) = v.init {
            *init = adjusted_global_expressions[*init];
        }
    }

    // Functions and entry points are taken out temporarily so that
    // `process_function` can borrow `module` mutably alongside them.
    let mut functions = mem::take(&mut module.functions);
    for (_, function) in functions.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, function)?;
    }
    module.functions = functions;

    let mut entry_points = mem::take(&mut module.entry_points);
    for ep in entry_points.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
        process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?;
    }
    module.entry_points = entry_points;
    module.overrides = overrides;

    // Expressions were rewritten wholesale, so recompute the module info by
    // validating again.
    revalidate(module)
}
263
264fn revalidate(
265 module: Module,
266) -> Result<(Cow<'static, Module>, Cow<'static, ModuleInfo>), PipelineConstantError> {
267 let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
268 let module_info = validator.validate_resolved_overrides(&module)?;
269 Ok((Cow::Owned(module), Cow::Owned(module_info)))
270}
271
272fn process_workgroup_size_override(
273 module: &mut Module,
274 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
275 ep: &mut crate::EntryPoint,
276) -> Result<(), PipelineConstantError> {
277 match ep.workgroup_size_overrides {
278 None => {}
279 Some(overrides) => {
280 overrides.iter().enumerate().try_for_each(
281 |(i, overridden)| -> Result<(), PipelineConstantError> {
282 match *overridden {
283 None => Ok(()),
284 Some(h) => {
285 ep.workgroup_size[i] = module
286 .to_ctx()
287 .get_const_val(adjusted_global_expressions[h])
288 .map(|n| {
289 if n == 0 {
290 Err(PipelineConstantError::NegativeWorkgroupSize)
291 } else {
292 Ok(n)
293 }
294 })
295 .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
296 Ok(())
297 }
298 }
299 },
300 )?;
301 ep.workgroup_size_overrides = None;
302 }
303 }
304 Ok(())
305}
306
307fn process_mesh_shader_overrides(
308 module: &mut Module,
309 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
310 ep: &mut crate::EntryPoint,
311) -> Result<(), PipelineConstantError> {
312 if let Some(ref mut mesh_info) = ep.mesh_info {
313 if let Some(r#override) = mesh_info.max_vertices_override {
314 mesh_info.max_vertices = module
315 .to_ctx()
316 .get_const_val(adjusted_global_expressions[r#override])
317 .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
318 }
319 if let Some(r#override) = mesh_info.max_primitives_override {
320 mesh_info.max_primitives = module
321 .to_ctx()
322 .get_const_val(adjusted_global_expressions[r#override])
323 .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
324 }
325 }
326 Ok(())
327}
328
329fn process_override(
333 (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
334 pipeline_constants: &PipelineConstants,
335 module: &mut Module,
336 override_map: &mut HandleVec<Override, Handle<Constant>>,
337 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
338 adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
339 global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
340) -> Result<Handle<Constant>, PipelineConstantError> {
341 let key = if let Some(id) = r#override.id {
343 Cow::Owned(id.to_string())
344 } else if let Some(ref name) = r#override.name {
345 Cow::Borrowed(name)
346 } else {
347 unreachable!();
348 };
349
350 let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
354 let literal = match module.types[r#override.ty].inner {
355 TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
356 _ => unreachable!(),
357 };
358 let expr = module
359 .global_expressions
360 .append(Expression::Literal(literal), Span::UNDEFINED);
361 global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
362 expr
363 } else if let Some(init) = r#override.init {
364 adjusted_global_expressions[init]
365 } else {
366 return Err(PipelineConstantError::MissingValue(key.to_string()));
367 };
368
369 let constant = Constant {
371 name: r#override.name.clone(),
372 ty: r#override.ty,
373 init,
374 };
375 let h = module.constants.append(constant, *span);
376 override_map.insert(old_h, h);
377 adjusted_constant_initializers.insert(h);
378 r#override.init = Some(init);
379 Ok(h)
380}
381
/// Rewrite `function`'s expressions after overrides have been resolved.
///
/// The function's expression arena is rebuilt from scratch, in handle order:
/// - Each `Override(h)` expression becomes `Constant(override_map[h])`.
/// - Operands are remapped to their rebuilt handles.
/// - Each expression is run through a function-mode [`ConstantEvaluator`]
///   and appended.
///
/// After the rebuild, the following are all translated to the new handles:
/// - the function body's statements, with `Emit` ranges re-filtered by
///   `filter_emits_in_block`;
/// - local variable initializers;
/// - the named-expression map.
///
/// # Errors
///
/// Returns a [`ConstantEvaluatorError`] if evaluating any expression fails.
fn process_function(
    module: &mut Module,
    override_map: &HandleVec<Override, Handle<Constant>>,
    layouter: &mut crate::proc::Layouter,
    function: &mut Function,
) -> Result<(), ConstantEvaluatorError> {
    // Maps each original local expression handle to its rebuilt counterpart.
    let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());

    let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();

    // Take the old arena out. The evaluator below appends into the now-empty
    // `function.expressions`.
    let mut expressions = mem::take(&mut function.expressions);

    // Scratch emitter and block that the function-mode evaluator requires.
    // Nothing reads `block` after the loop, so anything the evaluator might
    // push into it is discarded.
    let mut emitter = Emitter::default();
    let mut block = Block::new();

    let mut evaluator = ConstantEvaluator::for_wgsl_function(
        module,
        &mut function.expressions,
        &mut local_expression_kind_tracker,
        layouter,
        &mut emitter,
        &mut block,
        false,
    );

    for (old_h, mut expr, span) in expressions.drain() {
        if let Expression::Override(h) = expr {
            expr = Expression::Constant(override_map[h]);
        }
        adjust_expr(&adjusted_local_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_local_expressions.insert(old_h, h);
    }

    adjust_block(&adjusted_local_expressions, &mut function.body);

    // Evaluation may have turned emitted expressions into ones that report
    // `needs_pre_emit()`. Split `Emit` ranges around those.
    filter_emits_in_block(&mut function.body, &function.expressions);

    // Local variable initializers are handles into the function's arena too.
    for (_, local) in function.local_variables.iter_mut() {
        if let &mut Some(ref mut init) = &mut local.init {
            *init = adjusted_local_expressions[*init];
        }
    }

    // Re-key the named-expression map with the rebuilt handles.
    let named_expressions = mem::take(&mut function.named_expressions);
    for (expr_h, name) in named_expressions {
        function
            .named_expressions
            .insert(adjusted_local_expressions[expr_h], name);
    }

    Ok(())
}
456
457fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
460 let adjust = |expr: &mut Handle<Expression>| {
461 *expr = new_pos[*expr];
462 };
463 match *expr {
464 Expression::Compose {
465 ref mut components,
466 ty: _,
467 } => {
468 for c in components.iter_mut() {
469 adjust(c);
470 }
471 }
472 Expression::Access {
473 ref mut base,
474 ref mut index,
475 } => {
476 adjust(base);
477 adjust(index);
478 }
479 Expression::AccessIndex {
480 ref mut base,
481 index: _,
482 } => {
483 adjust(base);
484 }
485 Expression::Splat {
486 ref mut value,
487 size: _,
488 } => {
489 adjust(value);
490 }
491 Expression::Swizzle {
492 ref mut vector,
493 size: _,
494 pattern: _,
495 } => {
496 adjust(vector);
497 }
498 Expression::Load { ref mut pointer } => {
499 adjust(pointer);
500 }
501 Expression::ImageSample {
502 ref mut image,
503 ref mut sampler,
504 ref mut coordinate,
505 ref mut array_index,
506 ref mut offset,
507 ref mut level,
508 ref mut depth_ref,
509 gather: _,
510 clamp_to_edge: _,
511 } => {
512 adjust(image);
513 adjust(sampler);
514 adjust(coordinate);
515 if let Some(e) = array_index.as_mut() {
516 adjust(e);
517 }
518 if let Some(e) = offset.as_mut() {
519 adjust(e);
520 }
521 match *level {
522 crate::SampleLevel::Exact(ref mut expr)
523 | crate::SampleLevel::Bias(ref mut expr) => {
524 adjust(expr);
525 }
526 crate::SampleLevel::Gradient {
527 ref mut x,
528 ref mut y,
529 } => {
530 adjust(x);
531 adjust(y);
532 }
533 _ => {}
534 }
535 if let Some(e) = depth_ref.as_mut() {
536 adjust(e);
537 }
538 }
539 Expression::ImageLoad {
540 ref mut image,
541 ref mut coordinate,
542 ref mut array_index,
543 ref mut sample,
544 ref mut level,
545 } => {
546 adjust(image);
547 adjust(coordinate);
548 if let Some(e) = array_index.as_mut() {
549 adjust(e);
550 }
551 if let Some(e) = sample.as_mut() {
552 adjust(e);
553 }
554 if let Some(e) = level.as_mut() {
555 adjust(e);
556 }
557 }
558 Expression::ImageQuery {
559 ref mut image,
560 ref mut query,
561 } => {
562 adjust(image);
563 match *query {
564 crate::ImageQuery::Size { ref mut level } => {
565 if let Some(e) = level.as_mut() {
566 adjust(e);
567 }
568 }
569 crate::ImageQuery::NumLevels
570 | crate::ImageQuery::NumLayers
571 | crate::ImageQuery::NumSamples => {}
572 }
573 }
574 Expression::Unary {
575 ref mut expr,
576 op: _,
577 } => {
578 adjust(expr);
579 }
580 Expression::Binary {
581 ref mut left,
582 ref mut right,
583 op: _,
584 } => {
585 adjust(left);
586 adjust(right);
587 }
588 Expression::Select {
589 ref mut condition,
590 ref mut accept,
591 ref mut reject,
592 } => {
593 adjust(condition);
594 adjust(accept);
595 adjust(reject);
596 }
597 Expression::Derivative {
598 ref mut expr,
599 axis: _,
600 ctrl: _,
601 } => {
602 adjust(expr);
603 }
604 Expression::Relational {
605 ref mut argument,
606 fun: _,
607 } => {
608 adjust(argument);
609 }
610 Expression::Math {
611 ref mut arg,
612 ref mut arg1,
613 ref mut arg2,
614 ref mut arg3,
615 fun: _,
616 } => {
617 adjust(arg);
618 if let Some(e) = arg1.as_mut() {
619 adjust(e);
620 }
621 if let Some(e) = arg2.as_mut() {
622 adjust(e);
623 }
624 if let Some(e) = arg3.as_mut() {
625 adjust(e);
626 }
627 }
628 Expression::As {
629 ref mut expr,
630 kind: _,
631 convert: _,
632 } => {
633 adjust(expr);
634 }
635 Expression::ArrayLength(ref mut expr) => {
636 adjust(expr);
637 }
638 Expression::RayQueryGetIntersection {
639 ref mut query,
640 committed: _,
641 } => {
642 adjust(query);
643 }
644 Expression::Literal(_)
645 | Expression::FunctionArgument(_)
646 | Expression::GlobalVariable(_)
647 | Expression::LocalVariable(_)
648 | Expression::CallResult(_)
649 | Expression::RayQueryProceedResult
650 | Expression::Constant(_)
651 | Expression::Override(_)
652 | Expression::ZeroValue(_)
653 | Expression::AtomicResult {
654 ty: _,
655 comparison: _,
656 }
657 | Expression::WorkGroupUniformLoadResult { ty: _ }
658 | Expression::SubgroupBallotResult
659 | Expression::SubgroupOperationResult { .. } => {}
660 Expression::RayQueryVertexPositions {
661 ref mut query,
662 committed: _,
663 } => {
664 adjust(query);
665 }
666 Expression::CooperativeLoad { ref mut data, .. } => {
667 adjust(&mut data.pointer);
668 adjust(&mut data.stride);
669 }
670 Expression::CooperativeMultiplyAdd {
671 ref mut a,
672 ref mut b,
673 ref mut c,
674 } => {
675 adjust(a);
676 adjust(b);
677 adjust(c);
678 }
679 }
680}
681
682fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
685 for stmt in block.iter_mut() {
686 adjust_stmt(new_pos, stmt);
687 }
688}
689
690fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
693 let adjust = |expr: &mut Handle<Expression>| {
694 *expr = new_pos[*expr];
695 };
696 match *stmt {
697 Statement::Emit(ref mut range) => {
698 if let Some((mut first, mut last)) = range.first_and_last() {
699 adjust(&mut first);
700 adjust(&mut last);
701 *range = Range::new_from_bounds(first, last);
702 }
703 }
704 Statement::Block(ref mut block) => {
705 adjust_block(new_pos, block);
706 }
707 Statement::If {
708 ref mut condition,
709 ref mut accept,
710 ref mut reject,
711 } => {
712 adjust(condition);
713 adjust_block(new_pos, accept);
714 adjust_block(new_pos, reject);
715 }
716 Statement::Switch {
717 ref mut selector,
718 ref mut cases,
719 } => {
720 adjust(selector);
721 for case in cases.iter_mut() {
722 adjust_block(new_pos, &mut case.body);
723 }
724 }
725 Statement::Loop {
726 ref mut body,
727 ref mut continuing,
728 ref mut break_if,
729 } => {
730 adjust_block(new_pos, body);
731 adjust_block(new_pos, continuing);
732 if let Some(e) = break_if.as_mut() {
733 adjust(e);
734 }
735 }
736 Statement::Return { ref mut value } => {
737 if let Some(e) = value.as_mut() {
738 adjust(e);
739 }
740 }
741 Statement::Store {
742 ref mut pointer,
743 ref mut value,
744 } => {
745 adjust(pointer);
746 adjust(value);
747 }
748 Statement::ImageStore {
749 ref mut image,
750 ref mut coordinate,
751 ref mut array_index,
752 ref mut value,
753 } => {
754 adjust(image);
755 adjust(coordinate);
756 if let Some(e) = array_index.as_mut() {
757 adjust(e);
758 }
759 adjust(value);
760 }
761 Statement::Atomic {
762 ref mut pointer,
763 ref mut value,
764 ref mut result,
765 ref mut fun,
766 } => {
767 adjust(pointer);
768 adjust(value);
769 if let Some(ref mut result) = *result {
770 adjust(result);
771 }
772 match *fun {
773 crate::AtomicFunction::Exchange {
774 compare: Some(ref mut compare),
775 } => {
776 adjust(compare);
777 }
778 crate::AtomicFunction::Add
779 | crate::AtomicFunction::Subtract
780 | crate::AtomicFunction::And
781 | crate::AtomicFunction::ExclusiveOr
782 | crate::AtomicFunction::InclusiveOr
783 | crate::AtomicFunction::Min
784 | crate::AtomicFunction::Max
785 | crate::AtomicFunction::Exchange { compare: None } => {}
786 }
787 }
788 Statement::ImageAtomic {
789 ref mut image,
790 ref mut coordinate,
791 ref mut array_index,
792 fun: _,
793 ref mut value,
794 } => {
795 adjust(image);
796 adjust(coordinate);
797 if let Some(ref mut array_index) = *array_index {
798 adjust(array_index);
799 }
800 adjust(value);
801 }
802 Statement::WorkGroupUniformLoad {
803 ref mut pointer,
804 ref mut result,
805 } => {
806 adjust(pointer);
807 adjust(result);
808 }
809 Statement::SubgroupBallot {
810 ref mut result,
811 ref mut predicate,
812 } => {
813 if let Some(ref mut predicate) = *predicate {
814 adjust(predicate);
815 }
816 adjust(result);
817 }
818 Statement::SubgroupCollectiveOperation {
819 ref mut argument,
820 ref mut result,
821 ..
822 } => {
823 adjust(argument);
824 adjust(result);
825 }
826 Statement::SubgroupGather {
827 ref mut mode,
828 ref mut argument,
829 ref mut result,
830 } => {
831 match *mode {
832 crate::GatherMode::BroadcastFirst => {}
833 crate::GatherMode::Broadcast(ref mut index)
834 | crate::GatherMode::Shuffle(ref mut index)
835 | crate::GatherMode::ShuffleDown(ref mut index)
836 | crate::GatherMode::ShuffleUp(ref mut index)
837 | crate::GatherMode::ShuffleXor(ref mut index)
838 | crate::GatherMode::QuadBroadcast(ref mut index) => {
839 adjust(index);
840 }
841 crate::GatherMode::QuadSwap(_) => {}
842 }
843 adjust(argument);
844 adjust(result)
845 }
846 Statement::Call {
847 ref mut arguments,
848 ref mut result,
849 function: _,
850 } => {
851 for argument in arguments.iter_mut() {
852 adjust(argument);
853 }
854 if let Some(e) = result.as_mut() {
855 adjust(e);
856 }
857 }
858 Statement::RayQuery {
859 ref mut query,
860 ref mut fun,
861 } => {
862 adjust(query);
863 match *fun {
864 crate::RayQueryFunction::Initialize {
865 ref mut acceleration_structure,
866 ref mut descriptor,
867 } => {
868 adjust(acceleration_structure);
869 adjust(descriptor);
870 }
871 crate::RayQueryFunction::Proceed { ref mut result } => {
872 adjust(result);
873 }
874 crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
875 adjust(hit_t);
876 }
877 crate::RayQueryFunction::ConfirmIntersection => {}
878 crate::RayQueryFunction::Terminate => {}
879 }
880 }
881 Statement::CooperativeStore {
882 ref mut target,
883 ref mut data,
884 } => {
885 adjust(target);
886 adjust(&mut data.pointer);
887 adjust(&mut data.stride);
888 }
889 Statement::RayPipelineFunction(ref mut func) => match *func {
890 crate::RayPipelineFunction::TraceRay {
891 ref mut acceleration_structure,
892 ref mut descriptor,
893 ref mut payload,
894 } => {
895 adjust(acceleration_structure);
896 adjust(descriptor);
897 adjust(payload);
898 }
899 },
900 Statement::Break
901 | Statement::Continue
902 | Statement::Kill
903 | Statement::ControlBarrier(_)
904 | Statement::MemoryBarrier(_) => {}
905 }
906}
907
908fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
925 let original = mem::replace(block, Block::with_capacity(block.len()));
926 for (stmt, span) in original.span_into_iter() {
927 match stmt {
928 Statement::Emit(range) => {
929 let mut current = None;
930 for expr_h in range {
931 if expressions[expr_h].needs_pre_emit() {
932 if let Some((first, last)) = current {
933 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
934 }
935
936 current = None;
937 } else if let Some((_, ref mut last)) = current {
938 *last = expr_h;
939 } else {
940 current = Some((expr_h, expr_h));
941 }
942 }
943 if let Some((first, last)) = current {
944 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
945 }
946 }
947 Statement::Block(mut child) => {
948 filter_emits_in_block(&mut child, expressions);
949 block.push(Statement::Block(child), span);
950 }
951 Statement::If {
952 condition,
953 mut accept,
954 mut reject,
955 } => {
956 filter_emits_in_block(&mut accept, expressions);
957 filter_emits_in_block(&mut reject, expressions);
958 block.push(
959 Statement::If {
960 condition,
961 accept,
962 reject,
963 },
964 span,
965 );
966 }
967 Statement::Switch {
968 selector,
969 mut cases,
970 } => {
971 for case in &mut cases {
972 filter_emits_in_block(&mut case.body, expressions);
973 }
974 block.push(Statement::Switch { selector, cases }, span);
975 }
976 Statement::Loop {
977 mut body,
978 mut continuing,
979 break_if,
980 } => {
981 filter_emits_in_block(&mut body, expressions);
982 filter_emits_in_block(&mut continuing, expressions);
983 block.push(
984 Statement::Loop {
985 body,
986 continuing,
987 break_if,
988 },
989 span,
990 );
991 }
992 stmt => block.push(stmt.clone(), span),
993 }
994 }
995}
996
997fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
998 match scalar {
1000 Scalar::BOOL => {
1001 let value = value != 0.0 && !value.is_nan();
1003 Ok(Literal::Bool(value))
1004 }
1005 Scalar::I32 => {
1006 if !value.is_finite() {
1008 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1009 }
1010
1011 let value = value.trunc();
1012 if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
1013 return Err(PipelineConstantError::DstRangeTooSmall);
1014 }
1015
1016 let value = value as i32;
1017 Ok(Literal::I32(value))
1018 }
1019 Scalar::U32 => {
1020 if !value.is_finite() {
1022 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1023 }
1024
1025 let value = value.trunc();
1026 if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
1027 return Err(PipelineConstantError::DstRangeTooSmall);
1028 }
1029
1030 let value = value as u32;
1031 Ok(Literal::U32(value))
1032 }
1033 Scalar::F16 => {
1034 if !value.is_finite() {
1036 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1037 }
1038
1039 let value = half::f16::from_f64(value);
1040 if !value.is_finite() {
1041 return Err(PipelineConstantError::DstRangeTooSmall);
1042 }
1043
1044 Ok(Literal::F16(value))
1045 }
1046 Scalar::F32 => {
1047 if !value.is_finite() {
1049 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1050 }
1051
1052 let value = value as f32;
1053 if !value.is_finite() {
1054 return Err(PipelineConstantError::DstRangeTooSmall);
1055 }
1056
1057 Ok(Literal::F32(value))
1058 }
1059 Scalar::F64 => {
1060 if !value.is_finite() {
1062 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1063 }
1064
1065 Ok(Literal::F64(value))
1066 }
1067 Scalar::ABSTRACT_FLOAT | Scalar::ABSTRACT_INT => {
1068 unreachable!("abstract values should not be validated out of override processing")
1069 }
1070 _ => unreachable!("unrecognized scalar type for override"),
1071 }
1072}
1073
/// Exercise `map_value_to_literal` for every supported scalar type:
/// - boolean truthiness;
/// - rejection of non-finite input for numeric types;
/// - range boundaries for each integer and float destination, including
///   rounding-to-overflow for narrowing float conversions.
#[test]
fn test_map_value_to_literal() {
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    // Every numeric destination, including f16, rejects NaN and infinities.
    for scalar in [
        Scalar::I32,
        Scalar::U32,
        Scalar::F16,
        Scalar::F32,
        Scalar::F64,
    ] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f16: the largest finite magnitude is 65504. Values far beyond it
    // overflow to infinity and must be rejected.
    assert_eq!(
        map_value_to_literal(65504.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    assert_eq!(
        map_value_to_literal(-65504.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(1.0e6, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(-1.0e6, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f32: values just below the rounding midpoint above f32::MAX round down
    // to MAX. The midpoint itself rounds (ties-to-even) to infinity.
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}