1use alloc::{
2 borrow::Cow,
3 string::{String, ToString},
4};
5use core::mem;
6
7use hashbrown::HashSet;
8use thiserror::Error;
9
10use super::PipelineConstants;
11use crate::{
12 arena::HandleVec,
13 compact::{compact, KeepUnused},
14 ir,
15 proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
16 valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
17 Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
18 Span, Statement, TypeInner, WithSpan,
19};
20
21#[allow(unused_imports)]
23use num_traits::float::FloatCore as _;
24
25#[derive(Error, Debug, Clone)]
26#[cfg_attr(test, derive(PartialEq))]
27pub enum PipelineConstantError {
28 #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
29 MissingValue(String),
30 #[error(
31 "Source f64 value needs to be finite ({}) for number destinations",
32 "NaNs and Inifinites are not allowed"
33 )]
34 SrcNeedsToBeFinite,
35 #[error("Source f64 value doesn't fit in destination")]
36 DstRangeTooSmall,
37 #[error(transparent)]
38 ConstantEvaluatorError(#[from] ConstantEvaluatorError),
39 #[error(transparent)]
40 ValidationError(#[from] WithSpan<ValidationError>),
41 #[error("workgroup_size override isn't strictly positive")]
42 NegativeWorkgroupSize,
43 #[error("max vertices or max primitives is negative")]
44 NegativeMeshOutputMax,
45}
46
/// Replace every pipeline-overridable constant used by `module` with an
/// ordinary constant and return the specialized module together with fresh
/// validation info.
///
/// If `entry_point` is `Some((stage, name))`, only entry points whose stage
/// and name both match are retained in the returned module.
///
/// When there is nothing to do (no overrides, and either no entry point was
/// requested or the module has at most one entry point), the inputs are
/// returned as `Cow::Borrowed` without cloning. Otherwise the module is
/// cloned, compacted, rewritten, and revalidated, and owned values are
/// returned.
///
/// # Errors
///
/// Returns a [`PipelineConstantError`] if an override has neither a supplied
/// value nor an initializer, if a supplied value cannot be converted to the
/// override's scalar type, if constant evaluation fails, if a workgroup-size
/// or mesh-output override cannot be resolved, or if the rewritten module
/// fails validation.
///
/// # Panics
///
/// Panics (via `unwrap` / `unreachable!`) if an `Expression::Override` refers
/// to an override that is not found among the not-yet-processed overrides, or
/// if an override that must be keyed has neither an id nor a name.
pub fn process_overrides<'a>(
    module: &'a Module,
    module_info: &'a ModuleInfo,
    entry_point: Option<(ir::ShaderStage, &str)>,
    pipeline_constants: &PipelineConstants,
) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
    // Fast path: no entry-point filtering is needed and there are no
    // overrides, so the caller's module can be used unchanged.
    if (entry_point.is_none() || module.entry_points.len() <= 1) && module.overrides.is_empty() {
        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
    }

    // From here on we work on a private copy (shadowing the borrowed input).
    let mut module = module.clone();
    if let Some((ep_stage, ep_name)) = entry_point {
        module
            .entry_points
            .retain(|ep| ep.stage == ep_stage && ep.name == ep_name);
    }

    // NOTE(review): presumably this drops items that are no longer reachable
    // after the entry-point filtering above; confirm against `compact`.
    compact(&mut module, KeepUnused::No);

    // Compaction may have removed every override; then only the validation
    // info needs to be recomputed.
    if module.overrides.is_empty() {
        return revalidate(module);
    }

    // Maps each override handle to the constant that replaces it.
    let mut override_map = HandleVec::with_capacity(module.overrides.len());

    // Maps each old global-expression handle to its handle in the rebuilt
    // `global_expressions` arena.
    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());

    // Constants whose `init` already refers to the rebuilt arena, so that the
    // fix-up pass further down does not remap them a second time.
    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());

    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
    let mut layouter = crate::proc::Layouter::default();

    // Move the overrides out of the module so they can be iterated mutably
    // while `module` itself is borrowed mutably; they are put back at the end.
    let mut overrides = mem::take(&mut module.overrides);
    let mut override_iter = overrides.iter_mut_span();

    // Drain the global expression arena and re-append every expression through
    // the constant evaluator, recording old -> new handles as we go.
    for (old_h, expr, span) in module.global_expressions.drain() {
        let mut expr = match expr {
            Expression::Override(h) => {
                let c_h = if let Some(new_h) = override_map.get(h) {
                    // This override was already turned into a constant.
                    *new_h
                } else {
                    // Advance through the remaining overrides in order,
                    // converting each one, until the referenced override `h`
                    // has been converted.
                    let mut new_h = None;
                    for entry in override_iter.by_ref() {
                        let stop = entry.0 == h;
                        new_h = Some(process_override(
                            entry,
                            pipeline_constants,
                            &mut module,
                            &mut override_map,
                            &adjusted_global_expressions,
                            &mut adjusted_constant_initializers,
                            &mut global_expression_kind_tracker,
                        )?);
                        if stop {
                            break;
                        }
                    }
                    // `None` only if the iterator was already exhausted.
                    new_h.unwrap()
                };
                Expression::Constant(c_h)
            }
            Expression::Constant(c_h) => {
                // `insert` returns `true` only the first time this constant is
                // seen; remap its initializer exactly once.
                if adjusted_constant_initializers.insert(c_h) {
                    let init = &mut module.constants[c_h].init;
                    *init = adjusted_global_expressions[*init];
                }
                expr
            }
            expr => expr,
        };
        let mut evaluator = ConstantEvaluator::for_wgsl_module(
            &mut module,
            &mut global_expression_kind_tracker,
            &mut layouter,
            false,
        );
        // Point the expression's operands at their new handles before it is
        // evaluated and appended.
        adjust_expr(&adjusted_global_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_global_expressions.insert(old_h, h);
    }

    // Handle the overrides that the loop above did not reach.
    for entry in override_iter {
        match *entry.1 {
            // Overrides with a name or an id are converted to constants.
            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
                process_override(
                    entry,
                    pipeline_constants,
                    &mut module,
                    &mut override_map,
                    &adjusted_global_expressions,
                    &mut adjusted_constant_initializers,
                    &mut global_expression_kind_tracker,
                )?;
            }
            // Otherwise only the initializer handle is remapped.
            Override {
                init: Some(ref mut init),
                ..
            } => {
                *init = adjusted_global_expressions[*init];
            }
            _ => {}
        }
    }

    // Remap the initializers of all constants that were not already handled.
    for (_, c) in module
        .constants
        .iter_mut()
        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
    {
        c.init = adjusted_global_expressions[c.init];
    }

    // Global variable initializers live in the global expression arena too.
    for (_, v) in module.global_variables.iter_mut() {
        if let Some(ref mut init) = v.init {
            *init = adjusted_global_expressions[*init];
        }
    }

    // Functions are moved out so that each can be borrowed mutably alongside
    // `&mut module`, then moved back.
    let mut functions = mem::take(&mut module.functions);
    for (_, function) in functions.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, function)?;
    }
    module.functions = functions;

    // Same take/restore dance for entry points, which additionally may carry
    // overridden workgroup sizes and mesh-shader output limits.
    let mut entry_points = mem::take(&mut module.entry_points);
    for ep in entry_points.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
        process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?;
    }
    module.entry_points = entry_points;
    // Put the (now rewritten) overrides back into the module.
    module.overrides = overrides;

    revalidate(module)
}
259
260fn revalidate(
261 module: Module,
262) -> Result<(Cow<'static, Module>, Cow<'static, ModuleInfo>), PipelineConstantError> {
263 let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
264 let module_info = validator.validate_resolved_overrides(&module)?;
265 Ok((Cow::Owned(module), Cow::Owned(module_info)))
266}
267
268fn process_workgroup_size_override(
269 module: &mut Module,
270 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
271 ep: &mut crate::EntryPoint,
272) -> Result<(), PipelineConstantError> {
273 match ep.workgroup_size_overrides {
274 None => {}
275 Some(overrides) => {
276 overrides.iter().enumerate().try_for_each(
277 |(i, overridden)| -> Result<(), PipelineConstantError> {
278 match *overridden {
279 None => Ok(()),
280 Some(h) => {
281 ep.workgroup_size[i] = module
282 .to_ctx()
283 .get_const_val(adjusted_global_expressions[h])
284 .map(|n| {
285 if n == 0 {
286 Err(PipelineConstantError::NegativeWorkgroupSize)
287 } else {
288 Ok(n)
289 }
290 })
291 .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
292 Ok(())
293 }
294 }
295 },
296 )?;
297 ep.workgroup_size_overrides = None;
298 }
299 }
300 Ok(())
301}
302
303fn process_mesh_shader_overrides(
304 module: &mut Module,
305 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
306 ep: &mut crate::EntryPoint,
307) -> Result<(), PipelineConstantError> {
308 if let Some(ref mut mesh_info) = ep.mesh_info {
309 if let Some(r#override) = mesh_info.max_vertices_override {
310 mesh_info.max_vertices = module
311 .to_ctx()
312 .get_const_val(adjusted_global_expressions[r#override])
313 .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
314 }
315 if let Some(r#override) = mesh_info.max_primitives_override {
316 mesh_info.max_primitives = module
317 .to_ctx()
318 .get_const_val(adjusted_global_expressions[r#override])
319 .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
320 }
321 }
322 Ok(())
323}
324
325fn process_override(
329 (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
330 pipeline_constants: &PipelineConstants,
331 module: &mut Module,
332 override_map: &mut HandleVec<Override, Handle<Constant>>,
333 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
334 adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
335 global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
336) -> Result<Handle<Constant>, PipelineConstantError> {
337 let key = if let Some(id) = r#override.id {
339 Cow::Owned(id.to_string())
340 } else if let Some(ref name) = r#override.name {
341 Cow::Borrowed(name)
342 } else {
343 unreachable!();
344 };
345
346 let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
350 let literal = match module.types[r#override.ty].inner {
351 TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
352 _ => unreachable!(),
353 };
354 let expr = module
355 .global_expressions
356 .append(Expression::Literal(literal), Span::UNDEFINED);
357 global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
358 expr
359 } else if let Some(init) = r#override.init {
360 adjusted_global_expressions[init]
361 } else {
362 return Err(PipelineConstantError::MissingValue(key.to_string()));
363 };
364
365 let constant = Constant {
367 name: r#override.name.clone(),
368 ty: r#override.ty,
369 init,
370 };
371 let h = module.constants.append(constant, *span);
372 override_map.insert(old_h, h);
373 adjusted_constant_initializers.insert(h);
374 r#override.init = Some(init);
375 Ok(h)
376}
377
378fn process_function(
388 module: &mut Module,
389 override_map: &HandleVec<Override, Handle<Constant>>,
390 layouter: &mut crate::proc::Layouter,
391 function: &mut Function,
392) -> Result<(), ConstantEvaluatorError> {
393 let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
396
397 let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
398
399 let mut expressions = mem::take(&mut function.expressions);
400
401 let mut emitter = Emitter::default();
409 let mut block = Block::new();
410
411 let mut evaluator = ConstantEvaluator::for_wgsl_function(
412 module,
413 &mut function.expressions,
414 &mut local_expression_kind_tracker,
415 layouter,
416 &mut emitter,
417 &mut block,
418 false,
419 );
420
421 for (old_h, mut expr, span) in expressions.drain() {
422 if let Expression::Override(h) = expr {
423 expr = Expression::Constant(override_map[h]);
424 }
425 adjust_expr(&adjusted_local_expressions, &mut expr);
426 let h = evaluator.try_eval_and_append(expr, span)?;
427 adjusted_local_expressions.insert(old_h, h);
428 }
429
430 adjust_block(&adjusted_local_expressions, &mut function.body);
431
432 filter_emits_in_block(&mut function.body, &function.expressions);
433
434 for (_, local) in function.local_variables.iter_mut() {
436 if let &mut Some(ref mut init) = &mut local.init {
437 *init = adjusted_local_expressions[*init];
438 }
439 }
440
441 let named_expressions = mem::take(&mut function.named_expressions);
444 for (expr_h, name) in named_expressions {
445 function
446 .named_expressions
447 .insert(adjusted_local_expressions[expr_h], name);
448 }
449
450 Ok(())
451}
452
453fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
456 let adjust = |expr: &mut Handle<Expression>| {
457 *expr = new_pos[*expr];
458 };
459 match *expr {
460 Expression::Compose {
461 ref mut components,
462 ty: _,
463 } => {
464 for c in components.iter_mut() {
465 adjust(c);
466 }
467 }
468 Expression::Access {
469 ref mut base,
470 ref mut index,
471 } => {
472 adjust(base);
473 adjust(index);
474 }
475 Expression::AccessIndex {
476 ref mut base,
477 index: _,
478 } => {
479 adjust(base);
480 }
481 Expression::Splat {
482 ref mut value,
483 size: _,
484 } => {
485 adjust(value);
486 }
487 Expression::Swizzle {
488 ref mut vector,
489 size: _,
490 pattern: _,
491 } => {
492 adjust(vector);
493 }
494 Expression::Load { ref mut pointer } => {
495 adjust(pointer);
496 }
497 Expression::ImageSample {
498 ref mut image,
499 ref mut sampler,
500 ref mut coordinate,
501 ref mut array_index,
502 ref mut offset,
503 ref mut level,
504 ref mut depth_ref,
505 gather: _,
506 clamp_to_edge: _,
507 } => {
508 adjust(image);
509 adjust(sampler);
510 adjust(coordinate);
511 if let Some(e) = array_index.as_mut() {
512 adjust(e);
513 }
514 if let Some(e) = offset.as_mut() {
515 adjust(e);
516 }
517 match *level {
518 crate::SampleLevel::Exact(ref mut expr)
519 | crate::SampleLevel::Bias(ref mut expr) => {
520 adjust(expr);
521 }
522 crate::SampleLevel::Gradient {
523 ref mut x,
524 ref mut y,
525 } => {
526 adjust(x);
527 adjust(y);
528 }
529 _ => {}
530 }
531 if let Some(e) = depth_ref.as_mut() {
532 adjust(e);
533 }
534 }
535 Expression::ImageLoad {
536 ref mut image,
537 ref mut coordinate,
538 ref mut array_index,
539 ref mut sample,
540 ref mut level,
541 } => {
542 adjust(image);
543 adjust(coordinate);
544 if let Some(e) = array_index.as_mut() {
545 adjust(e);
546 }
547 if let Some(e) = sample.as_mut() {
548 adjust(e);
549 }
550 if let Some(e) = level.as_mut() {
551 adjust(e);
552 }
553 }
554 Expression::ImageQuery {
555 ref mut image,
556 ref mut query,
557 } => {
558 adjust(image);
559 match *query {
560 crate::ImageQuery::Size { ref mut level } => {
561 if let Some(e) = level.as_mut() {
562 adjust(e);
563 }
564 }
565 crate::ImageQuery::NumLevels
566 | crate::ImageQuery::NumLayers
567 | crate::ImageQuery::NumSamples => {}
568 }
569 }
570 Expression::Unary {
571 ref mut expr,
572 op: _,
573 } => {
574 adjust(expr);
575 }
576 Expression::Binary {
577 ref mut left,
578 ref mut right,
579 op: _,
580 } => {
581 adjust(left);
582 adjust(right);
583 }
584 Expression::Select {
585 ref mut condition,
586 ref mut accept,
587 ref mut reject,
588 } => {
589 adjust(condition);
590 adjust(accept);
591 adjust(reject);
592 }
593 Expression::Derivative {
594 ref mut expr,
595 axis: _,
596 ctrl: _,
597 } => {
598 adjust(expr);
599 }
600 Expression::Relational {
601 ref mut argument,
602 fun: _,
603 } => {
604 adjust(argument);
605 }
606 Expression::Math {
607 ref mut arg,
608 ref mut arg1,
609 ref mut arg2,
610 ref mut arg3,
611 fun: _,
612 } => {
613 adjust(arg);
614 if let Some(e) = arg1.as_mut() {
615 adjust(e);
616 }
617 if let Some(e) = arg2.as_mut() {
618 adjust(e);
619 }
620 if let Some(e) = arg3.as_mut() {
621 adjust(e);
622 }
623 }
624 Expression::As {
625 ref mut expr,
626 kind: _,
627 convert: _,
628 } => {
629 adjust(expr);
630 }
631 Expression::ArrayLength(ref mut expr) => {
632 adjust(expr);
633 }
634 Expression::RayQueryGetIntersection {
635 ref mut query,
636 committed: _,
637 } => {
638 adjust(query);
639 }
640 Expression::Literal(_)
641 | Expression::FunctionArgument(_)
642 | Expression::GlobalVariable(_)
643 | Expression::LocalVariable(_)
644 | Expression::CallResult(_)
645 | Expression::RayQueryProceedResult
646 | Expression::Constant(_)
647 | Expression::Override(_)
648 | Expression::ZeroValue(_)
649 | Expression::AtomicResult {
650 ty: _,
651 comparison: _,
652 }
653 | Expression::WorkGroupUniformLoadResult { ty: _ }
654 | Expression::SubgroupBallotResult
655 | Expression::SubgroupOperationResult { .. } => {}
656 Expression::RayQueryVertexPositions {
657 ref mut query,
658 committed: _,
659 } => {
660 adjust(query);
661 }
662 Expression::CooperativeLoad { ref mut data, .. } => {
663 adjust(&mut data.pointer);
664 adjust(&mut data.stride);
665 }
666 Expression::CooperativeMultiplyAdd {
667 ref mut a,
668 ref mut b,
669 ref mut c,
670 } => {
671 adjust(a);
672 adjust(b);
673 adjust(c);
674 }
675 }
676}
677
678fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
681 for stmt in block.iter_mut() {
682 adjust_stmt(new_pos, stmt);
683 }
684}
685
686fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
689 let adjust = |expr: &mut Handle<Expression>| {
690 *expr = new_pos[*expr];
691 };
692 match *stmt {
693 Statement::Emit(ref mut range) => {
694 if let Some((mut first, mut last)) = range.first_and_last() {
695 adjust(&mut first);
696 adjust(&mut last);
697 *range = Range::new_from_bounds(first, last);
698 }
699 }
700 Statement::Block(ref mut block) => {
701 adjust_block(new_pos, block);
702 }
703 Statement::If {
704 ref mut condition,
705 ref mut accept,
706 ref mut reject,
707 } => {
708 adjust(condition);
709 adjust_block(new_pos, accept);
710 adjust_block(new_pos, reject);
711 }
712 Statement::Switch {
713 ref mut selector,
714 ref mut cases,
715 } => {
716 adjust(selector);
717 for case in cases.iter_mut() {
718 adjust_block(new_pos, &mut case.body);
719 }
720 }
721 Statement::Loop {
722 ref mut body,
723 ref mut continuing,
724 ref mut break_if,
725 } => {
726 adjust_block(new_pos, body);
727 adjust_block(new_pos, continuing);
728 if let Some(e) = break_if.as_mut() {
729 adjust(e);
730 }
731 }
732 Statement::Return { ref mut value } => {
733 if let Some(e) = value.as_mut() {
734 adjust(e);
735 }
736 }
737 Statement::Store {
738 ref mut pointer,
739 ref mut value,
740 } => {
741 adjust(pointer);
742 adjust(value);
743 }
744 Statement::ImageStore {
745 ref mut image,
746 ref mut coordinate,
747 ref mut array_index,
748 ref mut value,
749 } => {
750 adjust(image);
751 adjust(coordinate);
752 if let Some(e) = array_index.as_mut() {
753 adjust(e);
754 }
755 adjust(value);
756 }
757 Statement::Atomic {
758 ref mut pointer,
759 ref mut value,
760 ref mut result,
761 ref mut fun,
762 } => {
763 adjust(pointer);
764 adjust(value);
765 if let Some(ref mut result) = *result {
766 adjust(result);
767 }
768 match *fun {
769 crate::AtomicFunction::Exchange {
770 compare: Some(ref mut compare),
771 } => {
772 adjust(compare);
773 }
774 crate::AtomicFunction::Add
775 | crate::AtomicFunction::Subtract
776 | crate::AtomicFunction::And
777 | crate::AtomicFunction::ExclusiveOr
778 | crate::AtomicFunction::InclusiveOr
779 | crate::AtomicFunction::Min
780 | crate::AtomicFunction::Max
781 | crate::AtomicFunction::Exchange { compare: None } => {}
782 }
783 }
784 Statement::ImageAtomic {
785 ref mut image,
786 ref mut coordinate,
787 ref mut array_index,
788 fun: _,
789 ref mut value,
790 } => {
791 adjust(image);
792 adjust(coordinate);
793 if let Some(ref mut array_index) = *array_index {
794 adjust(array_index);
795 }
796 adjust(value);
797 }
798 Statement::WorkGroupUniformLoad {
799 ref mut pointer,
800 ref mut result,
801 } => {
802 adjust(pointer);
803 adjust(result);
804 }
805 Statement::SubgroupBallot {
806 ref mut result,
807 ref mut predicate,
808 } => {
809 if let Some(ref mut predicate) = *predicate {
810 adjust(predicate);
811 }
812 adjust(result);
813 }
814 Statement::SubgroupCollectiveOperation {
815 ref mut argument,
816 ref mut result,
817 ..
818 } => {
819 adjust(argument);
820 adjust(result);
821 }
822 Statement::SubgroupGather {
823 ref mut mode,
824 ref mut argument,
825 ref mut result,
826 } => {
827 match *mode {
828 crate::GatherMode::BroadcastFirst => {}
829 crate::GatherMode::Broadcast(ref mut index)
830 | crate::GatherMode::Shuffle(ref mut index)
831 | crate::GatherMode::ShuffleDown(ref mut index)
832 | crate::GatherMode::ShuffleUp(ref mut index)
833 | crate::GatherMode::ShuffleXor(ref mut index)
834 | crate::GatherMode::QuadBroadcast(ref mut index) => {
835 adjust(index);
836 }
837 crate::GatherMode::QuadSwap(_) => {}
838 }
839 adjust(argument);
840 adjust(result)
841 }
842 Statement::Call {
843 ref mut arguments,
844 ref mut result,
845 function: _,
846 } => {
847 for argument in arguments.iter_mut() {
848 adjust(argument);
849 }
850 if let Some(e) = result.as_mut() {
851 adjust(e);
852 }
853 }
854 Statement::RayQuery {
855 ref mut query,
856 ref mut fun,
857 } => {
858 adjust(query);
859 match *fun {
860 crate::RayQueryFunction::Initialize {
861 ref mut acceleration_structure,
862 ref mut descriptor,
863 } => {
864 adjust(acceleration_structure);
865 adjust(descriptor);
866 }
867 crate::RayQueryFunction::Proceed { ref mut result } => {
868 adjust(result);
869 }
870 crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
871 adjust(hit_t);
872 }
873 crate::RayQueryFunction::ConfirmIntersection => {}
874 crate::RayQueryFunction::Terminate => {}
875 }
876 }
877 Statement::CooperativeStore {
878 ref mut target,
879 ref mut data,
880 } => {
881 adjust(target);
882 adjust(&mut data.pointer);
883 adjust(&mut data.stride);
884 }
885 Statement::Break
886 | Statement::Continue
887 | Statement::Kill
888 | Statement::ControlBarrier(_)
889 | Statement::MemoryBarrier(_) => {}
890 }
891}
892
893fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
910 let original = mem::replace(block, Block::with_capacity(block.len()));
911 for (stmt, span) in original.span_into_iter() {
912 match stmt {
913 Statement::Emit(range) => {
914 let mut current = None;
915 for expr_h in range {
916 if expressions[expr_h].needs_pre_emit() {
917 if let Some((first, last)) = current {
918 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
919 }
920
921 current = None;
922 } else if let Some((_, ref mut last)) = current {
923 *last = expr_h;
924 } else {
925 current = Some((expr_h, expr_h));
926 }
927 }
928 if let Some((first, last)) = current {
929 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
930 }
931 }
932 Statement::Block(mut child) => {
933 filter_emits_in_block(&mut child, expressions);
934 block.push(Statement::Block(child), span);
935 }
936 Statement::If {
937 condition,
938 mut accept,
939 mut reject,
940 } => {
941 filter_emits_in_block(&mut accept, expressions);
942 filter_emits_in_block(&mut reject, expressions);
943 block.push(
944 Statement::If {
945 condition,
946 accept,
947 reject,
948 },
949 span,
950 );
951 }
952 Statement::Switch {
953 selector,
954 mut cases,
955 } => {
956 for case in &mut cases {
957 filter_emits_in_block(&mut case.body, expressions);
958 }
959 block.push(Statement::Switch { selector, cases }, span);
960 }
961 Statement::Loop {
962 mut body,
963 mut continuing,
964 break_if,
965 } => {
966 filter_emits_in_block(&mut body, expressions);
967 filter_emits_in_block(&mut continuing, expressions);
968 block.push(
969 Statement::Loop {
970 body,
971 continuing,
972 break_if,
973 },
974 span,
975 );
976 }
977 stmt => block.push(stmt.clone(), span),
978 }
979 }
980}
981
982fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
983 match scalar {
985 Scalar::BOOL => {
986 let value = value != 0.0 && !value.is_nan();
988 Ok(Literal::Bool(value))
989 }
990 Scalar::I32 => {
991 if !value.is_finite() {
993 return Err(PipelineConstantError::SrcNeedsToBeFinite);
994 }
995
996 let value = value.trunc();
997 if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
998 return Err(PipelineConstantError::DstRangeTooSmall);
999 }
1000
1001 let value = value as i32;
1002 Ok(Literal::I32(value))
1003 }
1004 Scalar::U32 => {
1005 if !value.is_finite() {
1007 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1008 }
1009
1010 let value = value.trunc();
1011 if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
1012 return Err(PipelineConstantError::DstRangeTooSmall);
1013 }
1014
1015 let value = value as u32;
1016 Ok(Literal::U32(value))
1017 }
1018 Scalar::F16 => {
1019 if !value.is_finite() {
1021 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1022 }
1023
1024 let value = half::f16::from_f64(value);
1025 if !value.is_finite() {
1026 return Err(PipelineConstantError::DstRangeTooSmall);
1027 }
1028
1029 Ok(Literal::F16(value))
1030 }
1031 Scalar::F32 => {
1032 if !value.is_finite() {
1034 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1035 }
1036
1037 let value = value as f32;
1038 if !value.is_finite() {
1039 return Err(PipelineConstantError::DstRangeTooSmall);
1040 }
1041
1042 Ok(Literal::F32(value))
1043 }
1044 Scalar::F64 => {
1045 if !value.is_finite() {
1047 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1048 }
1049
1050 Ok(Literal::F64(value))
1051 }
1052 Scalar::ABSTRACT_FLOAT | Scalar::ABSTRACT_INT => {
1053 unreachable!("abstract values should not be validated out of override processing")
1054 }
1055 _ => unreachable!("unrecognized scalar type for override"),
1056 }
1057}
1058
/// Exercises [`map_value_to_literal`] for every concrete scalar type it
/// supports: boolean conversion, rejection of non-finite inputs, and the exact
/// boundaries of each numeric destination range.
#[test]
fn test_map_value_to_literal() {
    // bool: only zero (of either sign) and NaN map to `false`.
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    // Every numeric destination rejects non-finite sources.
    for scalar in [
        Scalar::I32,
        Scalar::U32,
        Scalar::F16,
        Scalar::F32,
        Scalar::F64,
    ] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f16
    assert_eq!(
        map_value_to_literal(f64::from(half::f16::MIN), Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(half::f16::MAX), Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    // Magnitudes well beyond the largest finite f16 overflow to infinity.
    assert_eq!(
        map_value_to_literal(-1.0e5, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(1.0e5, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f32
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // Largest f64 magnitude that still rounds to a finite f32 ...
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // ... and the next one up, which rounds to infinity.
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}