1use alloc::{
2 borrow::Cow,
3 string::{String, ToString},
4 vec::Vec,
5};
6use core::mem;
7
8use hashbrown::HashSet;
9use thiserror::Error;
10
11use super::PipelineConstants;
12use crate::{
13 arena::HandleVec,
14 compact::{compact, KeepUnused},
15 ir,
16 proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
17 valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
18 Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
19 Span, Statement, TypeInner, WithSpan,
20};
21
22#[allow(unused_imports)]
24use num_traits::float::FloatCore as _;
25
26#[derive(Error, Debug, Clone)]
27#[cfg_attr(test, derive(PartialEq))]
28pub enum PipelineConstantError {
29 #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
30 MissingValue(String),
31 #[error("pipeline-overridable constant '{0}' not found in the shader")]
32 NotFound(String),
33 #[error(
34 "Source f64 value needs to be finite ({}) for number destinations",
35 "NaNs and Inifinites are not allowed"
36 )]
37 SrcNeedsToBeFinite,
38 #[error("Source f64 value doesn't fit in destination")]
39 DstRangeTooSmall,
40 #[error(transparent)]
41 ConstantEvaluatorError(#[from] ConstantEvaluatorError),
42 #[error(transparent)]
43 ValidationError(#[from] WithSpan<ValidationError>),
44 #[error("workgroup_size override isn't strictly positive")]
45 NegativeWorkgroupSize,
46 #[error("max vertices or max primitives is negative")]
47 NegativeMeshOutputMax,
48}
49
/// Replace the pipeline-overridable constants (`override`s) of `module` with
/// ordinary constants, using the values in `pipeline_constants`.
///
/// A key in `pipeline_constants` matches an override by its numeric `id` when
/// the override has one, and by its `name` otherwise. An override that has no
/// supplied value falls back to its initializer.
///
/// If `entry_point` is `Some((stage, name))`, every other entry point is
/// removed, and the module is then compacted with `KeepUnused::No`.
///
/// Returns `module` and `module_info` borrowed and unchanged when the module
/// has no overrides and either `entry_point` is `None` or the module has at
/// most one entry point. Otherwise returns a rewritten, revalidated module
/// together with its freshly computed `ModuleInfo`.
///
/// # Errors
///
/// - [`PipelineConstantError::NotFound`] if a key in `pipeline_constants`
///   matches no override.
/// - [`PipelineConstantError::MissingValue`] if an override has neither a
///   supplied value nor an initializer.
/// - Value-conversion, constant-evaluation, workgroup-size, mesh-output and
///   validation errors from the helpers below.
pub fn process_overrides<'a>(
    module: &'a Module,
    module_info: &'a ModuleInfo,
    entry_point: Option<(ir::ShaderStage, &str)>,
    pipeline_constants: &PipelineConstants,
) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
    // Check that every supplied key names some override. Each override can be
    // claimed by at most one key, so matched handles are removed from the
    // candidate list.
    let mut handles = module
        .overrides
        .iter()
        .map(|(handle, _)| handle)
        .collect::<Vec<_>>();
    for c in pipeline_constants.keys() {
        // Keys that do not parse as an id yield `None` and can then only match
        // by name.
        let c_id = c.parse().ok();
        if let Some((i, _)) = handles.iter().enumerate().find(|&(_, handle)| {
            let o = &module.overrides[*handle];
            if o.id.is_some() {
                o.id == c_id
            } else {
                o.name.as_deref() == Some(c.as_str())
            }
        }) {
            handles.swap_remove(i);
        } else {
            return Err(PipelineConstantError::NotFound(c.clone()));
        }
    }

    // Nothing to substitute and no entry-point filtering to do: hand back the
    // inputs without cloning.
    if (entry_point.is_none() || module.entry_points.len() <= 1) && module.overrides.is_empty() {
        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
    }

    let mut module = module.clone();
    if let Some((ep_stage, ep_name)) = entry_point {
        module
            .entry_points
            .retain(|ep| ep.stage == ep_stage && ep.name == ep_name);
    }

    // Drop whatever the remaining entry points no longer use.
    compact(&mut module, KeepUnused::No);

    // Compaction may have removed every override; the module still has to be
    // revalidated because it changed.
    if module.overrides.is_empty() {
        return revalidate(module);
    }

    // Old override handle -> new `Constant` handle.
    let mut override_map = HandleVec::with_capacity(module.overrides.len());

    // Old global expression handle -> handle in the rebuilt arena.
    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());

    // Constants whose `init` already points into the rebuilt arena.
    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());

    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
    let mut layouter = crate::proc::Layouter::default();

    // Take the overrides out of the module so they can be walked mutably while
    // `module` is also borrowed mutably. They are put back at the end.
    let mut overrides = mem::take(&mut module.overrides);
    // Shared cursor over the overrides, consumed in step with the global
    // expression rebuild below.
    let mut override_iter = overrides.iter_mut_span();

    // Rebuild the global expression arena: each old expression is remapped,
    // constant-evaluated and appended, and its new handle is recorded.
    // `Override` expressions become `Constant` expressions.
    for (old_h, expr, span) in module.global_expressions.drain() {
        let mut expr = match expr {
            Expression::Override(h) => {
                let c_h = if let Some(new_h) = override_map.get(h) {
                    *new_h
                } else {
                    // Overrides are consumed in arena order: every override up
                    // to and including `h` is converted to a constant here.
                    let mut new_h = None;
                    for entry in override_iter.by_ref() {
                        let stop = entry.0 == h;
                        new_h = Some(process_override(
                            entry,
                            pipeline_constants,
                            &mut module,
                            &mut override_map,
                            &adjusted_global_expressions,
                            &mut adjusted_constant_initializers,
                            &mut global_expression_kind_tracker,
                        )?);
                        if stop {
                            break;
                        }
                    }
                    // NOTE(review): assumes `h` is always found among the
                    // not-yet-consumed overrides; panics otherwise.
                    new_h.unwrap()
                };
                Expression::Constant(c_h)
            }
            Expression::Constant(c_h) => {
                // Remap this constant's initializer the first time the
                // constant is referenced.
                if adjusted_constant_initializers.insert(c_h) {
                    let init = &mut module.constants[c_h].init;
                    *init = adjusted_global_expressions[*init];
                }
                expr
            }
            expr => expr,
        };
        let mut evaluator = ConstantEvaluator::for_wgsl_module(
            &mut module,
            &mut global_expression_kind_tracker,
            &mut layouter,
            false,
        );
        adjust_expr(&adjusted_global_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_global_expressions.insert(old_h, h);
    }

    // Overrides that no global expression referenced. Named or id'd ones still
    // become constants. Any others only get their initializer remapped.
    for entry in override_iter {
        match *entry.1 {
            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
                process_override(
                    entry,
                    pipeline_constants,
                    &mut module,
                    &mut override_map,
                    &adjusted_global_expressions,
                    &mut adjusted_constant_initializers,
                    &mut global_expression_kind_tracker,
                )?;
            }
            Override {
                init: Some(ref mut init),
                ..
            } => {
                *init = adjusted_global_expressions[*init];
            }
            _ => {}
        }
    }

    // Remap initializers of constants that were never referenced above.
    for (_, c) in module
        .constants
        .iter_mut()
        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
    {
        c.init = adjusted_global_expressions[c.init];
    }

    for (_, v) in module.global_variables.iter_mut() {
        if let Some(ref mut init) = v.init {
            *init = adjusted_global_expressions[*init];
        }
    }

    // Functions and entry points are taken out temporarily so they can be
    // rewritten while `module` is borrowed mutably.
    let mut functions = mem::take(&mut module.functions);
    for (_, function) in functions.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, function)?;
    }
    module.functions = functions;

    let mut entry_points = mem::take(&mut module.entry_points);
    for ep in entry_points.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
        process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?;
    }
    module.entry_points = entry_points;
    module.overrides = overrides;

    revalidate(module)
}
287
288fn revalidate(
289 module: Module,
290) -> Result<(Cow<'static, Module>, Cow<'static, ModuleInfo>), PipelineConstantError> {
291 let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
292 let module_info = validator.validate_resolved_overrides(&module)?;
293 Ok((Cow::Owned(module), Cow::Owned(module_info)))
294}
295
296fn process_workgroup_size_override(
297 module: &mut Module,
298 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
299 ep: &mut crate::EntryPoint,
300) -> Result<(), PipelineConstantError> {
301 match ep.workgroup_size_overrides {
302 None => {}
303 Some(overrides) => {
304 overrides.iter().enumerate().try_for_each(
305 |(i, overridden)| -> Result<(), PipelineConstantError> {
306 match *overridden {
307 None => Ok(()),
308 Some(h) => {
309 ep.workgroup_size[i] = module
310 .to_ctx()
311 .get_const_val(adjusted_global_expressions[h])
312 .map(|n| {
313 if n == 0 {
314 Err(PipelineConstantError::NegativeWorkgroupSize)
315 } else {
316 Ok(n)
317 }
318 })
319 .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
320 Ok(())
321 }
322 }
323 },
324 )?;
325 ep.workgroup_size_overrides = None;
326 }
327 }
328 Ok(())
329}
330
331fn process_mesh_shader_overrides(
332 module: &mut Module,
333 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
334 ep: &mut crate::EntryPoint,
335) -> Result<(), PipelineConstantError> {
336 if let Some(ref mut mesh_info) = ep.mesh_info {
337 if let Some(r#override) = mesh_info.max_vertices_override {
338 mesh_info.max_vertices = module
339 .to_ctx()
340 .get_const_val(adjusted_global_expressions[r#override])
341 .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
342 }
343 if let Some(r#override) = mesh_info.max_primitives_override {
344 mesh_info.max_primitives = module
345 .to_ctx()
346 .get_const_val(adjusted_global_expressions[r#override])
347 .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
348 }
349 }
350 Ok(())
351}
352
353fn process_override(
357 (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
358 pipeline_constants: &PipelineConstants,
359 module: &mut Module,
360 override_map: &mut HandleVec<Override, Handle<Constant>>,
361 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
362 adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
363 global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
364) -> Result<Handle<Constant>, PipelineConstantError> {
365 let key = if let Some(id) = r#override.id {
367 Cow::Owned(id.to_string())
368 } else if let Some(ref name) = r#override.name {
369 Cow::Borrowed(name)
370 } else {
371 unreachable!();
372 };
373
374 let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
378 let literal = match module.types[r#override.ty].inner {
379 TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
380 _ => unreachable!(),
381 };
382 let expr = module
383 .global_expressions
384 .append(Expression::Literal(literal), Span::UNDEFINED);
385 global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
386 expr
387 } else if let Some(init) = r#override.init {
388 adjusted_global_expressions[init]
389 } else {
390 return Err(PipelineConstantError::MissingValue(key.to_string()));
391 };
392
393 let constant = Constant {
395 name: r#override.name.clone(),
396 ty: r#override.ty,
397 init,
398 };
399 let h = module.constants.append(constant, *span);
400 override_map.insert(old_h, h);
401 adjusted_constant_initializers.insert(h);
402 r#override.init = Some(init);
403 Ok(h)
404}
405
406fn process_function(
416 module: &mut Module,
417 override_map: &HandleVec<Override, Handle<Constant>>,
418 layouter: &mut crate::proc::Layouter,
419 function: &mut Function,
420) -> Result<(), ConstantEvaluatorError> {
421 let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
424
425 let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
426
427 let mut expressions = mem::take(&mut function.expressions);
428
429 let mut emitter = Emitter::default();
437 let mut block = Block::new();
438
439 let mut evaluator = ConstantEvaluator::for_wgsl_function(
440 module,
441 &mut function.expressions,
442 &mut local_expression_kind_tracker,
443 layouter,
444 &mut emitter,
445 &mut block,
446 false,
447 );
448
449 for (old_h, mut expr, span) in expressions.drain() {
450 if let Expression::Override(h) = expr {
451 expr = Expression::Constant(override_map[h]);
452 }
453 adjust_expr(&adjusted_local_expressions, &mut expr);
454 let h = evaluator.try_eval_and_append(expr, span)?;
455 adjusted_local_expressions.insert(old_h, h);
456 }
457
458 adjust_block(&adjusted_local_expressions, &mut function.body);
459
460 filter_emits_in_block(&mut function.body, &function.expressions);
461
462 for (_, local) in function.local_variables.iter_mut() {
464 if let &mut Some(ref mut init) = &mut local.init {
465 *init = adjusted_local_expressions[*init];
466 }
467 }
468
469 let named_expressions = mem::take(&mut function.named_expressions);
472 for (expr_h, name) in named_expressions {
473 function
474 .named_expressions
475 .insert(adjusted_local_expressions[expr_h], name);
476 }
477
478 Ok(())
479}
480
/// Rewrite, in place, every expression handle stored inside `expr` using
/// `new_pos`, which maps old arena handles to their rebuilt counterparts.
///
/// Variants that hold no expression handles are left untouched. The match is
/// exhaustive, so a new `Expression` variant must be handled here before this
/// compiles.
///
/// `new_pos` is indexed directly, so it must already hold an entry for every
/// handle `expr` refers to. The callers fill it in arena order, which
/// presumably guarantees that operands are mapped before their users.
fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
    let adjust = |expr: &mut Handle<Expression>| {
        *expr = new_pos[*expr];
    };
    match *expr {
        Expression::Compose {
            ref mut components,
            ty: _,
        } => {
            for c in components.iter_mut() {
                adjust(c);
            }
        }
        Expression::Access {
            ref mut base,
            ref mut index,
        } => {
            adjust(base);
            adjust(index);
        }
        Expression::AccessIndex {
            ref mut base,
            index: _,
        } => {
            adjust(base);
        }
        Expression::Splat {
            ref mut value,
            size: _,
        } => {
            adjust(value);
        }
        Expression::Swizzle {
            ref mut vector,
            size: _,
            pattern: _,
        } => {
            adjust(vector);
        }
        Expression::Load { ref mut pointer } => {
            adjust(pointer);
        }
        Expression::ImageSample {
            ref mut image,
            ref mut sampler,
            ref mut coordinate,
            ref mut array_index,
            ref mut offset,
            ref mut level,
            ref mut depth_ref,
            gather: _,
            clamp_to_edge: _,
        } => {
            adjust(image);
            adjust(sampler);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            if let Some(e) = offset.as_mut() {
                adjust(e);
            }
            // Only the level variants that carry expressions need remapping.
            match *level {
                crate::SampleLevel::Exact(ref mut expr)
                | crate::SampleLevel::Bias(ref mut expr) => {
                    adjust(expr);
                }
                crate::SampleLevel::Gradient {
                    ref mut x,
                    ref mut y,
                } => {
                    adjust(x);
                    adjust(y);
                }
                _ => {}
            }
            if let Some(e) = depth_ref.as_mut() {
                adjust(e);
            }
        }
        Expression::ImageLoad {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            ref mut sample,
            ref mut level,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            if let Some(e) = sample.as_mut() {
                adjust(e);
            }
            if let Some(e) = level.as_mut() {
                adjust(e);
            }
        }
        Expression::ImageQuery {
            ref mut image,
            ref mut query,
        } => {
            adjust(image);
            match *query {
                crate::ImageQuery::Size { ref mut level } => {
                    if let Some(e) = level.as_mut() {
                        adjust(e);
                    }
                }
                crate::ImageQuery::NumLevels
                | crate::ImageQuery::NumLayers
                | crate::ImageQuery::NumSamples => {}
            }
        }
        Expression::Unary {
            ref mut expr,
            op: _,
        } => {
            adjust(expr);
        }
        Expression::Binary {
            ref mut left,
            ref mut right,
            op: _,
        } => {
            adjust(left);
            adjust(right);
        }
        Expression::Select {
            ref mut condition,
            ref mut accept,
            ref mut reject,
        } => {
            adjust(condition);
            adjust(accept);
            adjust(reject);
        }
        Expression::Derivative {
            ref mut expr,
            axis: _,
            ctrl: _,
        } => {
            adjust(expr);
        }
        Expression::Relational {
            ref mut argument,
            fun: _,
        } => {
            adjust(argument);
        }
        Expression::Math {
            ref mut arg,
            ref mut arg1,
            ref mut arg2,
            ref mut arg3,
            fun: _,
        } => {
            adjust(arg);
            if let Some(e) = arg1.as_mut() {
                adjust(e);
            }
            if let Some(e) = arg2.as_mut() {
                adjust(e);
            }
            if let Some(e) = arg3.as_mut() {
                adjust(e);
            }
        }
        Expression::As {
            ref mut expr,
            kind: _,
            convert: _,
        } => {
            adjust(expr);
        }
        Expression::ArrayLength(ref mut expr) => {
            adjust(expr);
        }
        Expression::RayQueryGetIntersection {
            ref mut query,
            committed: _,
        } => {
            adjust(query);
        }
        // Leaf expressions and result placeholders: no expression handles.
        Expression::Literal(_)
        | Expression::FunctionArgument(_)
        | Expression::GlobalVariable(_)
        | Expression::LocalVariable(_)
        | Expression::CallResult(_)
        | Expression::RayQueryProceedResult
        | Expression::Constant(_)
        | Expression::Override(_)
        | Expression::ZeroValue(_)
        | Expression::AtomicResult {
            ty: _,
            comparison: _,
        }
        | Expression::WorkGroupUniformLoadResult { ty: _ }
        | Expression::SubgroupBallotResult
        | Expression::SubgroupOperationResult { .. } => {}
        Expression::RayQueryVertexPositions {
            ref mut query,
            committed: _,
        } => {
            adjust(query);
        }
        Expression::CooperativeLoad { ref mut data, .. } => {
            adjust(&mut data.pointer);
            adjust(&mut data.stride);
        }
        Expression::CooperativeMultiplyAdd {
            ref mut a,
            ref mut b,
            ref mut c,
        } => {
            adjust(a);
            adjust(b);
            adjust(c);
        }
    }
}
705
706fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
709 for stmt in block.iter_mut() {
710 adjust_stmt(new_pos, stmt);
711 }
712}
713
/// Rewrite, in place, every expression handle referenced by `stmt` using
/// `new_pos`. Nested blocks (`Block`, `If`, `Switch` cases, `Loop`) are
/// handled recursively through [`adjust_block`].
///
/// For `Emit`, only the first and last handles of the range are remapped, and
/// the range is rebuilt from them. The new range can therefore also cover any
/// expressions the evaluator appended in between. `process_function` runs
/// `filter_emits_in_block` afterwards to drop entries that must not be emitted.
fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
    let adjust = |expr: &mut Handle<Expression>| {
        *expr = new_pos[*expr];
    };
    match *stmt {
        Statement::Emit(ref mut range) => {
            // An empty range has no bounds to remap and is left unchanged.
            if let Some((mut first, mut last)) = range.first_and_last() {
                adjust(&mut first);
                adjust(&mut last);
                *range = Range::new_from_bounds(first, last);
            }
        }
        Statement::Block(ref mut block) => {
            adjust_block(new_pos, block);
        }
        Statement::If {
            ref mut condition,
            ref mut accept,
            ref mut reject,
        } => {
            adjust(condition);
            adjust_block(new_pos, accept);
            adjust_block(new_pos, reject);
        }
        Statement::Switch {
            ref mut selector,
            ref mut cases,
        } => {
            adjust(selector);
            for case in cases.iter_mut() {
                adjust_block(new_pos, &mut case.body);
            }
        }
        Statement::Loop {
            ref mut body,
            ref mut continuing,
            ref mut break_if,
        } => {
            adjust_block(new_pos, body);
            adjust_block(new_pos, continuing);
            if let Some(e) = break_if.as_mut() {
                adjust(e);
            }
        }
        Statement::Return { ref mut value } => {
            if let Some(e) = value.as_mut() {
                adjust(e);
            }
        }
        Statement::Store {
            ref mut pointer,
            ref mut value,
        } => {
            adjust(pointer);
            adjust(value);
        }
        Statement::ImageStore {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            ref mut value,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            adjust(value);
        }
        Statement::Atomic {
            ref mut pointer,
            ref mut value,
            ref mut result,
            ref mut fun,
        } => {
            adjust(pointer);
            adjust(value);
            if let Some(ref mut result) = *result {
                adjust(result);
            }
            // Only compare-exchange carries an extra operand.
            match *fun {
                crate::AtomicFunction::Exchange {
                    compare: Some(ref mut compare),
                } => {
                    adjust(compare);
                }
                crate::AtomicFunction::Add
                | crate::AtomicFunction::Subtract
                | crate::AtomicFunction::And
                | crate::AtomicFunction::ExclusiveOr
                | crate::AtomicFunction::InclusiveOr
                | crate::AtomicFunction::Min
                | crate::AtomicFunction::Max
                | crate::AtomicFunction::Exchange { compare: None } => {}
            }
        }
        Statement::ImageAtomic {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            fun: _,
            ref mut value,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(ref mut array_index) = *array_index {
                adjust(array_index);
            }
            adjust(value);
        }
        Statement::WorkGroupUniformLoad {
            ref mut pointer,
            ref mut result,
        } => {
            adjust(pointer);
            adjust(result);
        }
        Statement::SubgroupBallot {
            ref mut result,
            ref mut predicate,
        } => {
            if let Some(ref mut predicate) = *predicate {
                adjust(predicate);
            }
            adjust(result);
        }
        Statement::SubgroupCollectiveOperation {
            ref mut argument,
            ref mut result,
            ..
        } => {
            adjust(argument);
            adjust(result);
        }
        Statement::SubgroupGather {
            ref mut mode,
            ref mut argument,
            ref mut result,
        } => {
            // Modes that take a lane index carry an expression handle.
            match *mode {
                crate::GatherMode::BroadcastFirst => {}
                crate::GatherMode::Broadcast(ref mut index)
                | crate::GatherMode::Shuffle(ref mut index)
                | crate::GatherMode::ShuffleDown(ref mut index)
                | crate::GatherMode::ShuffleUp(ref mut index)
                | crate::GatherMode::ShuffleXor(ref mut index)
                | crate::GatherMode::QuadBroadcast(ref mut index) => {
                    adjust(index);
                }
                crate::GatherMode::QuadSwap(_) => {}
            }
            adjust(argument);
            adjust(result)
        }
        Statement::Call {
            ref mut arguments,
            ref mut result,
            function: _,
        } => {
            for argument in arguments.iter_mut() {
                adjust(argument);
            }
            if let Some(e) = result.as_mut() {
                adjust(e);
            }
        }
        Statement::RayQuery {
            ref mut query,
            ref mut fun,
        } => {
            adjust(query);
            match *fun {
                crate::RayQueryFunction::Initialize {
                    ref mut acceleration_structure,
                    ref mut descriptor,
                } => {
                    adjust(acceleration_structure);
                    adjust(descriptor);
                }
                crate::RayQueryFunction::Proceed { ref mut result } => {
                    adjust(result);
                }
                crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
                    adjust(hit_t);
                }
                crate::RayQueryFunction::ConfirmIntersection => {}
                crate::RayQueryFunction::Terminate => {}
            }
        }
        Statement::CooperativeStore {
            ref mut target,
            ref mut data,
        } => {
            adjust(target);
            adjust(&mut data.pointer);
            adjust(&mut data.stride);
        }
        Statement::RayPipelineFunction(ref mut func) => match *func {
            crate::RayPipelineFunction::TraceRay {
                ref mut acceleration_structure,
                ref mut descriptor,
                ref mut payload,
            } => {
                adjust(acceleration_structure);
                adjust(descriptor);
                adjust(payload);
            }
        },
        // Statements that reference no expressions.
        Statement::Break
        | Statement::Continue
        | Statement::Kill
        | Statement::ControlBarrier(_)
        | Statement::MemoryBarrier(_) => {}
    }
}
931
932fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
949 let original = mem::replace(block, Block::with_capacity(block.len()));
950 for (stmt, span) in original.span_into_iter() {
951 match stmt {
952 Statement::Emit(range) => {
953 let mut current = None;
954 for expr_h in range {
955 if expressions[expr_h].needs_pre_emit() {
956 if let Some((first, last)) = current {
957 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
958 }
959
960 current = None;
961 } else if let Some((_, ref mut last)) = current {
962 *last = expr_h;
963 } else {
964 current = Some((expr_h, expr_h));
965 }
966 }
967 if let Some((first, last)) = current {
968 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
969 }
970 }
971 Statement::Block(mut child) => {
972 filter_emits_in_block(&mut child, expressions);
973 block.push(Statement::Block(child), span);
974 }
975 Statement::If {
976 condition,
977 mut accept,
978 mut reject,
979 } => {
980 filter_emits_in_block(&mut accept, expressions);
981 filter_emits_in_block(&mut reject, expressions);
982 block.push(
983 Statement::If {
984 condition,
985 accept,
986 reject,
987 },
988 span,
989 );
990 }
991 Statement::Switch {
992 selector,
993 mut cases,
994 } => {
995 for case in &mut cases {
996 filter_emits_in_block(&mut case.body, expressions);
997 }
998 block.push(Statement::Switch { selector, cases }, span);
999 }
1000 Statement::Loop {
1001 mut body,
1002 mut continuing,
1003 break_if,
1004 } => {
1005 filter_emits_in_block(&mut body, expressions);
1006 filter_emits_in_block(&mut continuing, expressions);
1007 block.push(
1008 Statement::Loop {
1009 body,
1010 continuing,
1011 break_if,
1012 },
1013 span,
1014 );
1015 }
1016 stmt => block.push(stmt.clone(), span),
1017 }
1018 }
1019}
1020
1021fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
1022 match scalar {
1024 Scalar::BOOL => {
1025 let value = value != 0.0 && !value.is_nan();
1027 Ok(Literal::Bool(value))
1028 }
1029 Scalar::I32 => {
1030 if !value.is_finite() {
1032 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1033 }
1034
1035 let value = value.trunc();
1036 if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
1037 return Err(PipelineConstantError::DstRangeTooSmall);
1038 }
1039
1040 let value = value as i32;
1041 Ok(Literal::I32(value))
1042 }
1043 Scalar::U32 => {
1044 if !value.is_finite() {
1046 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1047 }
1048
1049 let value = value.trunc();
1050 if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
1051 return Err(PipelineConstantError::DstRangeTooSmall);
1052 }
1053
1054 let value = value as u32;
1055 Ok(Literal::U32(value))
1056 }
1057 Scalar::F16 => {
1058 if !value.is_finite() {
1060 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1061 }
1062
1063 let value = half::f16::from_f64(value);
1064 if !value.is_finite() {
1065 return Err(PipelineConstantError::DstRangeTooSmall);
1066 }
1067
1068 Ok(Literal::F16(value))
1069 }
1070 Scalar::F32 => {
1071 if !value.is_finite() {
1073 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1074 }
1075
1076 let value = value as f32;
1077 if !value.is_finite() {
1078 return Err(PipelineConstantError::DstRangeTooSmall);
1079 }
1080
1081 Ok(Literal::F32(value))
1082 }
1083 Scalar::F64 => {
1084 if !value.is_finite() {
1086 return Err(PipelineConstantError::SrcNeedsToBeFinite);
1087 }
1088
1089 Ok(Literal::F64(value))
1090 }
1091 Scalar::ABSTRACT_FLOAT | Scalar::ABSTRACT_INT => {
1092 unreachable!("abstract values should not be validated out of override processing")
1093 }
1094 _ => unreachable!("unrecognized scalar type for override"),
1095 }
1096}
1097
/// Checks `map_value_to_literal` for every supported destination scalar:
/// bool truthiness, rejection of non-finite input, integer truncation and
/// range limits, and float range limits, including `f16`.
#[test]
fn test_map_value_to_literal() {
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    for scalar in [
        Scalar::I32,
        Scalar::U32,
        Scalar::F16,
        Scalar::F32,
        Scalar::F64,
    ] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Fractional values are truncated toward zero.
    assert_eq!(map_value_to_literal(1.9, Scalar::I32), Ok(Literal::I32(1)));
    assert_eq!(
        map_value_to_literal(-1.9, Scalar::I32),
        Ok(Literal::I32(-1))
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // A small negative fraction truncates to (negative) zero, which fits.
    assert_eq!(
        map_value_to_literal(-0.5, Scalar::U32),
        Ok(Literal::U32(0))
    );

    // f16
    assert_eq!(
        map_value_to_literal(65504.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    assert_eq!(
        map_value_to_literal(-65504.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(1.0e6, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(-1.0e6, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f32
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // Just below the rounding midpoint above f32::MAX: rounds down to MAX.
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // At the midpoint: rounds to infinity, which is rejected.
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}