1use alloc::{
2 borrow::Cow,
3 string::{String, ToString},
4};
5use core::mem;
6
7use hashbrown::HashSet;
8use thiserror::Error;
9
10use super::PipelineConstants;
11use crate::{
12 arena::HandleVec,
13 compact::{compact, KeepUnused},
14 ir,
15 proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
16 valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
17 Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
18 Span, Statement, TypeInner, WithSpan,
19};
20
21#[cfg(no_std)]
22use num_traits::float::FloatCore as _;
23
24#[derive(Error, Debug, Clone)]
25#[cfg_attr(test, derive(PartialEq))]
26pub enum PipelineConstantError {
27 #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
28 MissingValue(String),
29 #[error(
30 "Source f64 value needs to be finite ({}) for number destinations",
31 "NaNs and Inifinites are not allowed"
32 )]
33 SrcNeedsToBeFinite,
34 #[error("Source f64 value doesn't fit in destination")]
35 DstRangeTooSmall,
36 #[error(transparent)]
37 ConstantEvaluatorError(#[from] ConstantEvaluatorError),
38 #[error(transparent)]
39 ValidationError(#[from] WithSpan<ValidationError>),
40 #[error("workgroup_size override isn't strictly positive")]
41 NegativeWorkgroupSize,
42 #[error("max vertices or max primitives is negative")]
43 NegativeMeshOutputMax,
44}
45
/// Specialize `module` for a pipeline by replacing every override with a constant.
///
/// If `entry_point` is `None` (or the module has at most one entry point) and
/// the module has no overrides, nothing needs to change and borrowed references
/// to `module` and `module_info` are returned as-is.
///
/// Otherwise the module is cloned and, if `entry_point` is given, every other
/// entry point is dropped. The module is then compacted, each override is
/// replaced by a new constant, and the global expressions, constant and global
/// initializers, and function bodies are rebuilt against that. The result is
/// revalidated, and owned values are returned.
///
/// An override's value is looked up in `pipeline_constants` by its numeric id
/// when it has one, and by its name otherwise. If no value is supplied, its
/// initializer is used.
///
/// # Errors
///
/// - [`PipelineConstantError::MissingValue`] if an override has neither a
///   supplied value nor an initializer.
/// - [`PipelineConstantError::SrcNeedsToBeFinite`] /
///   [`PipelineConstantError::DstRangeTooSmall`] if a supplied value cannot be
///   converted to the override's type.
/// - [`PipelineConstantError::ConstantEvaluatorError`] if re-evaluating an
///   expression fails.
/// - [`PipelineConstantError::NegativeWorkgroupSize`] /
///   [`PipelineConstantError::NegativeMeshOutputMax`] for bad entry point sizes.
/// - [`PipelineConstantError::ValidationError`] if the result fails validation.
pub fn process_overrides<'a>(
    module: &'a Module,
    module_info: &'a ModuleInfo,
    entry_point: Option<(ir::ShaderStage, &str)>,
    pipeline_constants: &PipelineConstants,
) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
    // No overrides to resolve and no other entry points to drop: reuse the input.
    if (entry_point.is_none() || module.entry_points.len() <= 1) && module.overrides.is_empty() {
        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
    }

    let mut module = module.clone();
    if let Some((ep_stage, ep_name)) = entry_point {
        module
            .entry_points
            .retain(|ep| ep.stage == ep_stage && ep.name == ep_name);
    }

    // Remove items that are no longer used by the remaining entry points.
    // NOTE(review): presumably this is what lets overrides that only the
    // dropped entry points used go without a supplied value — confirm.
    compact(&mut module, KeepUnused::No);

    // If no overrides survived compaction, revalidating is all that is left.
    if module.overrides.is_empty() {
        return revalidate(module);
    }

    // Override handle -> handle of the constant that replaces it.
    let mut override_map = HandleVec::with_capacity(module.overrides.len());

    // Original global expression handle -> handle in the rebuilt
    // `module.global_expressions` arena.
    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());

    // Constants whose `init` already points into the rebuilt arena, so it
    // must not be remapped again.
    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());

    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
    let mut layouter = crate::proc::Layouter::default();

    // Move the overrides out of the module so they can be walked (and their
    // `init` fields updated) while `module` is being mutated. They are put
    // back before revalidation.
    let mut overrides = mem::take(&mut module.overrides);
    let mut override_iter = overrides.iter_mut_span();

    // Rebuild the global expression arena in order, turning override
    // references into constant references and evaluating each expression.
    for (old_h, expr, span) in module.global_expressions.drain() {
        let mut expr = match expr {
            Expression::Override(h) => {
                let c_h = if let Some(new_h) = override_map.get(h) {
                    *new_h
                } else {
                    // `h` has not been processed yet. Process overrides in
                    // handle order up to and including `h`. Every override
                    // consumed from `override_iter` is recorded in
                    // `override_map`, so `h` is still ahead in the iterator
                    // and `new_h` is guaranteed to be `Some` afterwards.
                    let mut new_h = None;
                    for entry in override_iter.by_ref() {
                        let stop = entry.0 == h;
                        new_h = Some(process_override(
                            entry,
                            pipeline_constants,
                            &mut module,
                            &mut override_map,
                            &adjusted_global_expressions,
                            &mut adjusted_constant_initializers,
                            &mut global_expression_kind_tracker,
                        )?);
                        if stop {
                            break;
                        }
                    }
                    new_h.unwrap()
                };
                Expression::Constant(c_h)
            }
            Expression::Constant(c_h) => {
                // Remap a pre-existing constant's initializer the first time
                // the constant is referenced.
                if adjusted_constant_initializers.insert(c_h) {
                    let init = &mut module.constants[c_h].init;
                    *init = adjusted_global_expressions[*init];
                }
                expr
            }
            expr => expr,
        };
        let mut evaluator = ConstantEvaluator::for_wgsl_module(
            &mut module,
            &mut global_expression_kind_tracker,
            &mut layouter,
            false,
        );
        adjust_expr(&adjusted_global_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_global_expressions.insert(old_h, h);
    }

    // Handle the overrides that no global expression referenced. Those with a
    // name or id become constants. For the rest, only the initializer handle
    // is remapped.
    for entry in override_iter {
        match *entry.1 {
            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
                process_override(
                    entry,
                    pipeline_constants,
                    &mut module,
                    &mut override_map,
                    &adjusted_global_expressions,
                    &mut adjusted_constant_initializers,
                    &mut global_expression_kind_tracker,
                )?;
            }
            Override {
                init: Some(ref mut init),
                ..
            } => {
                *init = adjusted_global_expressions[*init];
            }
            _ => {}
        }
    }

    // Remap the initializers of all constants not already handled above.
    for (_, c) in module
        .constants
        .iter_mut()
        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
    {
        c.init = adjusted_global_expressions[c.init];
    }

    // Remap global variable initializers into the rebuilt arena.
    for (_, v) in module.global_variables.iter_mut() {
        if let Some(ref mut init) = v.init {
            *init = adjusted_global_expressions[*init];
        }
    }

    // Rewrite function bodies. The functions are taken out of the module so
    // `process_function` can borrow `module` mutably at the same time.
    let mut functions = mem::take(&mut module.functions);
    for (_, function) in functions.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, function)?;
    }
    module.functions = functions;

    // Same for entry points, which also resolve their size overrides.
    let mut entry_points = mem::take(&mut module.entry_points);
    for ep in entry_points.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
        process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?;
    }
    module.entry_points = entry_points;
    // Restore the overrides; their `init` handles now refer to the rebuilt arena.
    module.overrides = overrides;

    revalidate(module)
}
258
259fn revalidate(
260 module: Module,
261) -> Result<(Cow<'static, Module>, Cow<'static, ModuleInfo>), PipelineConstantError> {
262 let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
263 let module_info = validator.validate_resolved_overrides(&module)?;
264 Ok((Cow::Owned(module), Cow::Owned(module_info)))
265}
266
267fn process_workgroup_size_override(
268 module: &mut Module,
269 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
270 ep: &mut crate::EntryPoint,
271) -> Result<(), PipelineConstantError> {
272 match ep.workgroup_size_overrides {
273 None => {}
274 Some(overrides) => {
275 overrides.iter().enumerate().try_for_each(
276 |(i, overridden)| -> Result<(), PipelineConstantError> {
277 match *overridden {
278 None => Ok(()),
279 Some(h) => {
280 ep.workgroup_size[i] = module
281 .to_ctx()
282 .eval_expr_to_u32(adjusted_global_expressions[h])
283 .map(|n| {
284 if n == 0 {
285 Err(PipelineConstantError::NegativeWorkgroupSize)
286 } else {
287 Ok(n)
288 }
289 })
290 .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
291 Ok(())
292 }
293 }
294 },
295 )?;
296 ep.workgroup_size_overrides = None;
297 }
298 }
299 Ok(())
300}
301
302fn process_mesh_shader_overrides(
303 module: &mut Module,
304 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
305 ep: &mut crate::EntryPoint,
306) -> Result<(), PipelineConstantError> {
307 if let Some(ref mut mesh_info) = ep.mesh_info {
308 if let Some(r#override) = mesh_info.max_vertices_override {
309 mesh_info.max_vertices = module
310 .to_ctx()
311 .eval_expr_to_u32(adjusted_global_expressions[r#override])
312 .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
313 }
314 if let Some(r#override) = mesh_info.max_primitives_override {
315 mesh_info.max_primitives = module
316 .to_ctx()
317 .eval_expr_to_u32(adjusted_global_expressions[r#override])
318 .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
319 }
320 }
321 Ok(())
322}
323
324fn process_override(
328 (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
329 pipeline_constants: &PipelineConstants,
330 module: &mut Module,
331 override_map: &mut HandleVec<Override, Handle<Constant>>,
332 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
333 adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
334 global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
335) -> Result<Handle<Constant>, PipelineConstantError> {
336 let key = if let Some(id) = r#override.id {
338 Cow::Owned(id.to_string())
339 } else if let Some(ref name) = r#override.name {
340 Cow::Borrowed(name)
341 } else {
342 unreachable!();
343 };
344
345 let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
349 let literal = match module.types[r#override.ty].inner {
350 TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
351 _ => unreachable!(),
352 };
353 let expr = module
354 .global_expressions
355 .append(Expression::Literal(literal), Span::UNDEFINED);
356 global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
357 expr
358 } else if let Some(init) = r#override.init {
359 adjusted_global_expressions[init]
360 } else {
361 return Err(PipelineConstantError::MissingValue(key.to_string()));
362 };
363
364 let constant = Constant {
366 name: r#override.name.clone(),
367 ty: r#override.ty,
368 init,
369 };
370 let h = module.constants.append(constant, *span);
371 override_map.insert(old_h, h);
372 adjusted_constant_initializers.insert(h);
373 r#override.init = Some(init);
374 Ok(h)
375}
376
377fn process_function(
387 module: &mut Module,
388 override_map: &HandleVec<Override, Handle<Constant>>,
389 layouter: &mut crate::proc::Layouter,
390 function: &mut Function,
391) -> Result<(), ConstantEvaluatorError> {
392 let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
395
396 let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
397
398 let mut expressions = mem::take(&mut function.expressions);
399
400 let mut emitter = Emitter::default();
408 let mut block = Block::new();
409
410 let mut evaluator = ConstantEvaluator::for_wgsl_function(
411 module,
412 &mut function.expressions,
413 &mut local_expression_kind_tracker,
414 layouter,
415 &mut emitter,
416 &mut block,
417 false,
418 );
419
420 for (old_h, mut expr, span) in expressions.drain() {
421 if let Expression::Override(h) = expr {
422 expr = Expression::Constant(override_map[h]);
423 }
424 adjust_expr(&adjusted_local_expressions, &mut expr);
425 let h = evaluator.try_eval_and_append(expr, span)?;
426 adjusted_local_expressions.insert(old_h, h);
427 }
428
429 adjust_block(&adjusted_local_expressions, &mut function.body);
430
431 filter_emits_in_block(&mut function.body, &function.expressions);
432
433 for (_, local) in function.local_variables.iter_mut() {
435 if let &mut Some(ref mut init) = &mut local.init {
436 *init = adjusted_local_expressions[*init];
437 }
438 }
439
440 let named_expressions = mem::take(&mut function.named_expressions);
443 for (expr_h, name) in named_expressions {
444 function
445 .named_expressions
446 .insert(adjusted_local_expressions[expr_h], name);
447 }
448
449 Ok(())
450}
451
/// Replace every expression handle that `expr` refers to with its new location.
///
/// `new_pos` maps handles in the original expression arena to handles in the
/// rebuilt arena. Callers invoke this before inserting `expr`'s own mapping,
/// so every operand of `expr` must already have an entry in `new_pos`.
///
/// Fields that hold no expression handles are listed explicitly as `_` rather
/// than elided with `..`. Presumably this is so that adding a handle-carrying
/// field to a variant forces this match to be revisited.
fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
    let adjust = |expr: &mut Handle<Expression>| {
        *expr = new_pos[*expr];
    };
    match *expr {
        Expression::Compose {
            ref mut components,
            ty: _,
        } => {
            for c in components.iter_mut() {
                adjust(c);
            }
        }
        Expression::Access {
            ref mut base,
            ref mut index,
        } => {
            adjust(base);
            adjust(index);
        }
        Expression::AccessIndex {
            ref mut base,
            index: _,
        } => {
            adjust(base);
        }
        Expression::Splat {
            ref mut value,
            size: _,
        } => {
            adjust(value);
        }
        Expression::Swizzle {
            ref mut vector,
            size: _,
            pattern: _,
        } => {
            adjust(vector);
        }
        Expression::Load { ref mut pointer } => {
            adjust(pointer);
        }
        Expression::ImageSample {
            ref mut image,
            ref mut sampler,
            ref mut coordinate,
            ref mut array_index,
            ref mut offset,
            ref mut level,
            ref mut depth_ref,
            gather: _,
            clamp_to_edge: _,
        } => {
            adjust(image);
            adjust(sampler);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            if let Some(e) = offset.as_mut() {
                adjust(e);
            }
            match *level {
                crate::SampleLevel::Exact(ref mut expr)
                | crate::SampleLevel::Bias(ref mut expr) => {
                    adjust(expr);
                }
                crate::SampleLevel::Gradient {
                    ref mut x,
                    ref mut y,
                } => {
                    adjust(x);
                    adjust(y);
                }
                // The remaining sample levels carry no expression handles.
                _ => {}
            }
            if let Some(e) = depth_ref.as_mut() {
                adjust(e);
            }
        }
        Expression::ImageLoad {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            ref mut sample,
            ref mut level,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            if let Some(e) = sample.as_mut() {
                adjust(e);
            }
            if let Some(e) = level.as_mut() {
                adjust(e);
            }
        }
        Expression::ImageQuery {
            ref mut image,
            ref mut query,
        } => {
            adjust(image);
            match *query {
                crate::ImageQuery::Size { ref mut level } => {
                    if let Some(e) = level.as_mut() {
                        adjust(e);
                    }
                }
                crate::ImageQuery::NumLevels
                | crate::ImageQuery::NumLayers
                | crate::ImageQuery::NumSamples => {}
            }
        }
        Expression::Unary {
            ref mut expr,
            op: _,
        } => {
            adjust(expr);
        }
        Expression::Binary {
            ref mut left,
            ref mut right,
            op: _,
        } => {
            adjust(left);
            adjust(right);
        }
        Expression::Select {
            ref mut condition,
            ref mut accept,
            ref mut reject,
        } => {
            adjust(condition);
            adjust(accept);
            adjust(reject);
        }
        Expression::Derivative {
            ref mut expr,
            axis: _,
            ctrl: _,
        } => {
            adjust(expr);
        }
        Expression::Relational {
            ref mut argument,
            fun: _,
        } => {
            adjust(argument);
        }
        Expression::Math {
            ref mut arg,
            ref mut arg1,
            ref mut arg2,
            ref mut arg3,
            fun: _,
        } => {
            adjust(arg);
            if let Some(e) = arg1.as_mut() {
                adjust(e);
            }
            if let Some(e) = arg2.as_mut() {
                adjust(e);
            }
            if let Some(e) = arg3.as_mut() {
                adjust(e);
            }
        }
        Expression::As {
            ref mut expr,
            kind: _,
            convert: _,
        } => {
            adjust(expr);
        }
        Expression::ArrayLength(ref mut expr) => {
            adjust(expr);
        }
        Expression::RayQueryGetIntersection {
            ref mut query,
            committed: _,
        } => {
            adjust(query);
        }
        // Leaf expressions: no expression handles to remap.
        Expression::Literal(_)
        | Expression::FunctionArgument(_)
        | Expression::GlobalVariable(_)
        | Expression::LocalVariable(_)
        | Expression::CallResult(_)
        | Expression::RayQueryProceedResult
        | Expression::Constant(_)
        | Expression::Override(_)
        | Expression::ZeroValue(_)
        | Expression::AtomicResult {
            ty: _,
            comparison: _,
        }
        | Expression::WorkGroupUniformLoadResult { ty: _ }
        | Expression::SubgroupBallotResult
        | Expression::SubgroupOperationResult { .. } => {}
        Expression::RayQueryVertexPositions {
            ref mut query,
            committed: _,
        } => {
            adjust(query);
        }
    }
}
663
664fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
667 for stmt in block.iter_mut() {
668 adjust_stmt(new_pos, stmt);
669 }
670}
671
/// Replace every expression handle that `stmt` refers to with its new location
/// from `new_pos`, recursing into nested blocks via `adjust_block`.
///
/// Statements that reference no expressions (`Break`, `Continue`, `Kill`,
/// barriers) are left untouched.
fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
    let adjust = |expr: &mut Handle<Expression>| {
        *expr = new_pos[*expr];
    };
    match *stmt {
        Statement::Emit(ref mut range) => {
            // Only the endpoints are remapped. The new range covers everything
            // between them in the rebuilt arena.
            if let Some((mut first, mut last)) = range.first_and_last() {
                adjust(&mut first);
                adjust(&mut last);
                *range = Range::new_from_bounds(first, last);
            }
        }
        Statement::Block(ref mut block) => {
            adjust_block(new_pos, block);
        }
        Statement::If {
            ref mut condition,
            ref mut accept,
            ref mut reject,
        } => {
            adjust(condition);
            adjust_block(new_pos, accept);
            adjust_block(new_pos, reject);
        }
        Statement::Switch {
            ref mut selector,
            ref mut cases,
        } => {
            adjust(selector);
            for case in cases.iter_mut() {
                adjust_block(new_pos, &mut case.body);
            }
        }
        Statement::Loop {
            ref mut body,
            ref mut continuing,
            ref mut break_if,
        } => {
            adjust_block(new_pos, body);
            adjust_block(new_pos, continuing);
            if let Some(e) = break_if.as_mut() {
                adjust(e);
            }
        }
        Statement::Return { ref mut value } => {
            if let Some(e) = value.as_mut() {
                adjust(e);
            }
        }
        Statement::Store {
            ref mut pointer,
            ref mut value,
        } => {
            adjust(pointer);
            adjust(value);
        }
        Statement::ImageStore {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            ref mut value,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            adjust(value);
        }
        Statement::Atomic {
            ref mut pointer,
            ref mut value,
            ref mut result,
            ref mut fun,
        } => {
            adjust(pointer);
            adjust(value);
            if let Some(ref mut result) = *result {
                adjust(result);
            }
            // Only compare-exchange carries an extra expression operand.
            match *fun {
                crate::AtomicFunction::Exchange {
                    compare: Some(ref mut compare),
                } => {
                    adjust(compare);
                }
                crate::AtomicFunction::Add
                | crate::AtomicFunction::Subtract
                | crate::AtomicFunction::And
                | crate::AtomicFunction::ExclusiveOr
                | crate::AtomicFunction::InclusiveOr
                | crate::AtomicFunction::Min
                | crate::AtomicFunction::Max
                | crate::AtomicFunction::Exchange { compare: None } => {}
            }
        }
        Statement::ImageAtomic {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            fun: _,
            ref mut value,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(ref mut array_index) = *array_index {
                adjust(array_index);
            }
            adjust(value);
        }
        Statement::WorkGroupUniformLoad {
            ref mut pointer,
            ref mut result,
        } => {
            adjust(pointer);
            adjust(result);
        }
        Statement::SubgroupBallot {
            ref mut result,
            ref mut predicate,
        } => {
            if let Some(ref mut predicate) = *predicate {
                adjust(predicate);
            }
            adjust(result);
        }
        Statement::SubgroupCollectiveOperation {
            ref mut argument,
            ref mut result,
            ..
        } => {
            adjust(argument);
            adjust(result);
        }
        Statement::SubgroupGather {
            ref mut mode,
            ref mut argument,
            ref mut result,
        } => {
            // Only the index-taking gather modes reference an expression.
            match *mode {
                crate::GatherMode::BroadcastFirst => {}
                crate::GatherMode::Broadcast(ref mut index)
                | crate::GatherMode::Shuffle(ref mut index)
                | crate::GatherMode::ShuffleDown(ref mut index)
                | crate::GatherMode::ShuffleUp(ref mut index)
                | crate::GatherMode::ShuffleXor(ref mut index)
                | crate::GatherMode::QuadBroadcast(ref mut index) => {
                    adjust(index);
                }
                crate::GatherMode::QuadSwap(_) => {}
            }
            adjust(argument);
            adjust(result)
        }
        Statement::Call {
            ref mut arguments,
            ref mut result,
            function: _,
        } => {
            for argument in arguments.iter_mut() {
                adjust(argument);
            }
            if let Some(e) = result.as_mut() {
                adjust(e);
            }
        }
        Statement::RayQuery {
            ref mut query,
            ref mut fun,
        } => {
            adjust(query);
            match *fun {
                crate::RayQueryFunction::Initialize {
                    ref mut acceleration_structure,
                    ref mut descriptor,
                } => {
                    adjust(acceleration_structure);
                    adjust(descriptor);
                }
                crate::RayQueryFunction::Proceed { ref mut result } => {
                    adjust(result);
                }
                crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
                    adjust(hit_t);
                }
                crate::RayQueryFunction::ConfirmIntersection => {}
                crate::RayQueryFunction::Terminate => {}
            }
        }
        Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs {
            ref mut vertex_count,
            ref mut primitive_count,
        }) => {
            adjust(vertex_count);
            adjust(primitive_count);
        }
        Statement::MeshFunction(
            crate::MeshFunction::SetVertex {
                ref mut index,
                ref mut value,
            }
            | crate::MeshFunction::SetPrimitive {
                ref mut index,
                ref mut value,
            },
        ) => {
            adjust(index);
            adjust(value);
        }
        // No expression operands.
        Statement::Break
        | Statement::Continue
        | Statement::Kill
        | Statement::ControlBarrier(_)
        | Statement::MemoryBarrier(_) => {}
    }
}
890
891fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
908 let original = mem::replace(block, Block::with_capacity(block.len()));
909 for (stmt, span) in original.span_into_iter() {
910 match stmt {
911 Statement::Emit(range) => {
912 let mut current = None;
913 for expr_h in range {
914 if expressions[expr_h].needs_pre_emit() {
915 if let Some((first, last)) = current {
916 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
917 }
918
919 current = None;
920 } else if let Some((_, ref mut last)) = current {
921 *last = expr_h;
922 } else {
923 current = Some((expr_h, expr_h));
924 }
925 }
926 if let Some((first, last)) = current {
927 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
928 }
929 }
930 Statement::Block(mut child) => {
931 filter_emits_in_block(&mut child, expressions);
932 block.push(Statement::Block(child), span);
933 }
934 Statement::If {
935 condition,
936 mut accept,
937 mut reject,
938 } => {
939 filter_emits_in_block(&mut accept, expressions);
940 filter_emits_in_block(&mut reject, expressions);
941 block.push(
942 Statement::If {
943 condition,
944 accept,
945 reject,
946 },
947 span,
948 );
949 }
950 Statement::Switch {
951 selector,
952 mut cases,
953 } => {
954 for case in &mut cases {
955 filter_emits_in_block(&mut case.body, expressions);
956 }
957 block.push(Statement::Switch { selector, cases }, span);
958 }
959 Statement::Loop {
960 mut body,
961 mut continuing,
962 break_if,
963 } => {
964 filter_emits_in_block(&mut body, expressions);
965 filter_emits_in_block(&mut continuing, expressions);
966 block.push(
967 Statement::Loop {
968 body,
969 continuing,
970 break_if,
971 },
972 span,
973 );
974 }
975 stmt => block.push(stmt.clone(), span),
976 }
977 }
978}
979
980fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
981 match scalar {
983 Scalar::BOOL => {
984 let value = value != 0.0 && !value.is_nan();
986 Ok(Literal::Bool(value))
987 }
988 Scalar::I32 => {
989 if !value.is_finite() {
991 return Err(PipelineConstantError::SrcNeedsToBeFinite);
992 }
993
994 let value = value.trunc();
995 if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
996 return Err(PipelineConstantError::DstRangeTooSmall);
997 }
998
999 let value = value as i32;
1000 Ok(Literal::I32(value))
1001 }
1002 Scalar::U32 => {
1003 if !value.is_finite() {
1005 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1006 }
1007
1008 let value = value.trunc();
1009 if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
1010 return Err(PipelineConstantError::DstRangeTooSmall);
1011 }
1012
1013 let value = value as u32;
1014 Ok(Literal::U32(value))
1015 }
1016 Scalar::F16 => {
1017 if !value.is_finite() {
1019 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1020 }
1021
1022 let value = half::f16::from_f64(value);
1023 if !value.is_finite() {
1024 return Err(PipelineConstantError::DstRangeTooSmall);
1025 }
1026
1027 Ok(Literal::F16(value))
1028 }
1029 Scalar::F32 => {
1030 if !value.is_finite() {
1032 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1033 }
1034
1035 let value = value as f32;
1036 if !value.is_finite() {
1037 return Err(PipelineConstantError::DstRangeTooSmall);
1038 }
1039
1040 Ok(Literal::F32(value))
1041 }
1042 Scalar::F64 => {
1043 if !value.is_finite() {
1045 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1046 }
1047
1048 Ok(Literal::F64(value))
1049 }
1050 Scalar::ABSTRACT_FLOAT | Scalar::ABSTRACT_INT => {
1051 unreachable!("abstract values should not be validated out of override processing")
1052 }
1053 _ => unreachable!("unrecognized scalar type for override"),
1054 }
1055}
1056
/// Exercises `map_value_to_literal` for every supported scalar type,
/// including non-finite inputs, truncation, and range boundaries.
#[test]
fn test_map_value_to_literal() {
    // bool
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    // Every numeric destination rejects non-finite sources.
    for scalar in [
        Scalar::I32,
        Scalar::U32,
        Scalar::F16,
        Scalar::F32,
        Scalar::F64,
    ] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Fractional values truncate toward zero.
    assert_eq!(map_value_to_literal(1.9, Scalar::I32), Ok(Literal::I32(1)));
    assert_eq!(
        map_value_to_literal(-1.9, Scalar::I32),
        Ok(Literal::I32(-1))
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Small negative fractions truncate to (negative) zero, which is in range.
    assert_eq!(map_value_to_literal(-0.5, Scalar::U32), Ok(Literal::U32(0)));

    // f16
    assert_eq!(
        map_value_to_literal(-65504.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(65504.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    assert_eq!(
        map_value_to_literal(-1.0e6, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(1.0e6, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f32
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // Just below the rounding midpoint to the next power of two: rounds down
    // to the largest finite f32.
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // Exactly at the midpoint: rounds (to even) up to infinity, so it is rejected.
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}