naga/back/
pipeline_constants.rs

1use alloc::{
2    borrow::Cow,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::mem;
7
8use hashbrown::HashSet;
9use thiserror::Error;
10
11use super::PipelineConstants;
12use crate::{
13    arena::HandleVec,
14    compact::{compact, KeepUnused},
15    ir,
16    proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
17    valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
18    Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
19    Span, Statement, TypeInner, WithSpan,
20};
21
22// Possibly unused if not compiled with no_std
23#[allow(unused_imports)]
24use num_traits::float::FloatCore as _;
25
26#[derive(Error, Debug, Clone)]
27#[cfg_attr(test, derive(PartialEq))]
28pub enum PipelineConstantError {
29    #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
30    MissingValue(String),
31    #[error("pipeline-overridable constant '{0}' not found in the shader")]
32    NotFound(String),
33    #[error(
34        "Source f64 value needs to be finite ({}) for number destinations",
35        "NaNs and Inifinites are not allowed"
36    )]
37    SrcNeedsToBeFinite,
38    #[error("Source f64 value doesn't fit in destination")]
39    DstRangeTooSmall,
40    #[error(transparent)]
41    ConstantEvaluatorError(#[from] ConstantEvaluatorError),
42    #[error(transparent)]
43    ValidationError(#[from] WithSpan<ValidationError>),
44    #[error("workgroup_size override isn't strictly positive")]
45    NegativeWorkgroupSize,
46    #[error("max vertices or max primitives is negative")]
47    NegativeMeshOutputMax,
48}
49
50/// Compact `module` and replace all overrides with constants.
51///
52/// `module` must be valid. Both compaction and constant evaluation may produce
53/// invalid results (e.g. replace an invalid expression with a constant) for
54/// invalid modules.
55///
56/// If no changes are needed, this just returns `Cow::Borrowed` references to
57/// `module` and `module_info`. Otherwise, it clones `module`, retains only the
58/// selected entry point, compacts the module, edits its [`global_expressions`]
59/// arena to contain only fully-evaluated expressions, and returns the
60/// simplified module and its validation results.
61///
62/// The module returned has an empty `overrides` arena, and the
63/// `global_expressions` arena contains only fully-evaluated expressions.
64///
65/// [`global_expressions`]: Module::global_expressions
66pub fn process_overrides<'a>(
67    module: &'a Module,
68    module_info: &'a ModuleInfo,
69    entry_point: Option<(ir::ShaderStage, &str)>,
70    pipeline_constants: &PipelineConstants,
71) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
72    let mut handles = module
73        .overrides
74        .iter()
75        .map(|(handle, _)| handle)
76        .collect::<Vec<_>>();
77    for c in pipeline_constants.keys() {
78        let c_id = c.parse().ok();
79        if let Some((i, _)) = handles.iter().enumerate().find(|&(_, handle)| {
80            let o = &module.overrides[*handle];
81            if o.id.is_some() {
82                o.id == c_id
83            } else {
84                o.name.as_deref() == Some(c.as_str())
85            }
86        }) {
87            handles.swap_remove(i);
88        } else {
89            return Err(PipelineConstantError::NotFound(c.clone()));
90        }
91    }
92
93    if (entry_point.is_none() || module.entry_points.len() <= 1) && module.overrides.is_empty() {
94        // We skip compacting the module here mostly to reduce the risk of
95        // hitting corner cases like https://github.com/gfx-rs/wgpu/issues/7793.
96        // Compaction doesn't cost very much [1], so it would also be reasonable
97        // to do it unconditionally. Even when there is a single entry point or
98        // when no entry point is specified, it is still possible that there
99        // are unreferenced items in the module that would be removed by this
100        // compaction.
101        //
102        // [1]: https://github.com/gfx-rs/wgpu/pull/7703#issuecomment-2902153760
103        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
104    }
105
106    let mut module = module.clone();
107    if let Some((ep_stage, ep_name)) = entry_point {
108        module
109            .entry_points
110            .retain(|ep| ep.stage == ep_stage && ep.name == ep_name);
111    }
112
113    // Compact the module to remove anything not reachable from an entry point.
114    // This is necessary because we may not have values for overrides that are
115    // not reachable from the/an entry point.
116    compact(&mut module, KeepUnused::No);
117
118    // If there are no overrides in the module, then we can skip the rest.
119    if module.overrides.is_empty() {
120        return revalidate(module);
121    }
122
123    // A map from override handles to the handles of the constants
124    // we've replaced them with.
125    let mut override_map = HandleVec::with_capacity(module.overrides.len());
126
127    // A map from `module`'s original global expression handles to
128    // handles in the new, simplified global expression arena.
129    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());
130
131    // The set of constants whose initializer handles we've already
132    // updated to refer to the newly built global expression arena.
133    //
134    // All constants in `module` must have their `init` handles
135    // updated to point into the new, simplified global expression
136    // arena. Some of these we can most easily handle as a side effect
137    // during the simplification process, but we must handle the rest
138    // in a final fixup pass, guided by `adjusted_global_expressions`. We
139    // add their handles to this set, so that the final fixup step can
140    // leave them alone.
141    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());
142
143    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
144    let mut layouter = crate::proc::Layouter::default();
145
146    // An iterator through the original overrides table, consumed in
147    // approximate tandem with the global expressions.
148    let mut overrides = mem::take(&mut module.overrides);
149    let mut override_iter = overrides.iter_mut_span();
150
151    // Do two things in tandem:
152    //
153    // - Rebuild the global expression arena from scratch, fully
154    //   evaluating all expressions, and replacing each `Override`
155    //   expression in `module.global_expressions` with a `Constant`
156    //   expression.
157    //
158    // - Build a new `Constant` in `module.constants` to take the
159    //   place of each `Override`.
160    //
161    // Build a map from old global expression handles to their
162    // fully-evaluated counterparts in `adjusted_global_expressions` as we
163    // go.
164    //
165    // Why in tandem? Overrides refer to expressions, and expressions
166    // refer to overrides, so we can't disentangle the two into
167    // separate phases. However, we can take advantage of the fact
168    // that the overrides and expressions must form a DAG, and work
169    // our way from the leaves to the roots, replacing and evaluating
170    // as we go.
171    //
172    // Although the two loops are nested, this is really two
173    // alternating phases: we adjust and evaluate constant expressions
174    // until we hit an `Override` expression, at which point we switch
175    // to building `Constant`s for `Overrides` until we've handled the
176    // one used by the expression. Then we switch back to processing
177    // expressions. Because we know they form a DAG, we know the
178    // `Override` expressions we encounter can only have initializers
179    // referring to global expressions we've already simplified.
180    for (old_h, expr, span) in module.global_expressions.drain() {
181        let mut expr = match expr {
182            Expression::Override(h) => {
183                let c_h = if let Some(new_h) = override_map.get(h) {
184                    *new_h
185                } else {
186                    let mut new_h = None;
187                    for entry in override_iter.by_ref() {
188                        let stop = entry.0 == h;
189                        new_h = Some(process_override(
190                            entry,
191                            pipeline_constants,
192                            &mut module,
193                            &mut override_map,
194                            &adjusted_global_expressions,
195                            &mut adjusted_constant_initializers,
196                            &mut global_expression_kind_tracker,
197                        )?);
198                        if stop {
199                            break;
200                        }
201                    }
202                    new_h.unwrap()
203                };
204                Expression::Constant(c_h)
205            }
206            Expression::Constant(c_h) => {
207                if adjusted_constant_initializers.insert(c_h) {
208                    let init = &mut module.constants[c_h].init;
209                    *init = adjusted_global_expressions[*init];
210                }
211                expr
212            }
213            expr => expr,
214        };
215        let mut evaluator = ConstantEvaluator::for_wgsl_module(
216            &mut module,
217            &mut global_expression_kind_tracker,
218            &mut layouter,
219            false,
220        );
221        adjust_expr(&adjusted_global_expressions, &mut expr);
222        let h = evaluator.try_eval_and_append(expr, span)?;
223        adjusted_global_expressions.insert(old_h, h);
224    }
225
226    // Finish processing any overrides we didn't visit in the loop above.
227    for entry in override_iter {
228        match *entry.1 {
229            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
230                process_override(
231                    entry,
232                    pipeline_constants,
233                    &mut module,
234                    &mut override_map,
235                    &adjusted_global_expressions,
236                    &mut adjusted_constant_initializers,
237                    &mut global_expression_kind_tracker,
238                )?;
239            }
240            Override {
241                init: Some(ref mut init),
242                ..
243            } => {
244                *init = adjusted_global_expressions[*init];
245            }
246            _ => {}
247        }
248    }
249
250    // Update the initialization expression handles of all `Constant`s
251    // and `GlobalVariable`s. Skip `Constant`s we'd already updated en
252    // passant.
253    for (_, c) in module
254        .constants
255        .iter_mut()
256        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
257    {
258        c.init = adjusted_global_expressions[c.init];
259    }
260
261    for (_, v) in module.global_variables.iter_mut() {
262        if let Some(ref mut init) = v.init {
263            *init = adjusted_global_expressions[*init];
264        }
265    }
266
267    let mut functions = mem::take(&mut module.functions);
268    for (_, function) in functions.iter_mut() {
269        process_function(&mut module, &override_map, &mut layouter, function)?;
270    }
271    module.functions = functions;
272
273    let mut entry_points = mem::take(&mut module.entry_points);
274    for ep in entry_points.iter_mut() {
275        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
276        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
277        process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?;
278    }
279    module.entry_points = entry_points;
280    module.overrides = overrides;
281
282    // Now that we've rewritten all the expressions, we need to
283    // recompute their types and other metadata. For the time being,
284    // do a full re-validation.
285    revalidate(module)
286}
287
288fn revalidate(
289    module: Module,
290) -> Result<(Cow<'static, Module>, Cow<'static, ModuleInfo>), PipelineConstantError> {
291    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
292    let module_info = validator.validate_resolved_overrides(&module)?;
293    Ok((Cow::Owned(module), Cow::Owned(module_info)))
294}
295
296fn process_workgroup_size_override(
297    module: &mut Module,
298    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
299    ep: &mut crate::EntryPoint,
300) -> Result<(), PipelineConstantError> {
301    match ep.workgroup_size_overrides {
302        None => {}
303        Some(overrides) => {
304            overrides.iter().enumerate().try_for_each(
305                |(i, overridden)| -> Result<(), PipelineConstantError> {
306                    match *overridden {
307                        None => Ok(()),
308                        Some(h) => {
309                            ep.workgroup_size[i] = module
310                                .to_ctx()
311                                .get_const_val(adjusted_global_expressions[h])
312                                .map(|n| {
313                                    if n == 0 {
314                                        Err(PipelineConstantError::NegativeWorkgroupSize)
315                                    } else {
316                                        Ok(n)
317                                    }
318                                })
319                                .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
320                            Ok(())
321                        }
322                    }
323                },
324            )?;
325            ep.workgroup_size_overrides = None;
326        }
327    }
328    Ok(())
329}
330
331fn process_mesh_shader_overrides(
332    module: &mut Module,
333    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
334    ep: &mut crate::EntryPoint,
335) -> Result<(), PipelineConstantError> {
336    if let Some(ref mut mesh_info) = ep.mesh_info {
337        if let Some(r#override) = mesh_info.max_vertices_override {
338            mesh_info.max_vertices = module
339                .to_ctx()
340                .get_const_val(adjusted_global_expressions[r#override])
341                .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
342        }
343        if let Some(r#override) = mesh_info.max_primitives_override {
344            mesh_info.max_primitives = module
345                .to_ctx()
346                .get_const_val(adjusted_global_expressions[r#override])
347                .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
348        }
349    }
350    Ok(())
351}
352
353/// Add a [`Constant`] to `module` for the override `old_h`.
354///
355/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
356fn process_override(
357    (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
358    pipeline_constants: &PipelineConstants,
359    module: &mut Module,
360    override_map: &mut HandleVec<Override, Handle<Constant>>,
361    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
362    adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
363    global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
364) -> Result<Handle<Constant>, PipelineConstantError> {
365    // Determine which key to use for `r#override` in `pipeline_constants`.
366    let key = if let Some(id) = r#override.id {
367        Cow::Owned(id.to_string())
368    } else if let Some(ref name) = r#override.name {
369        Cow::Borrowed(name)
370    } else {
371        unreachable!();
372    };
373
374    // Generate a global expression for `r#override`'s value, either
375    // from the provided `pipeline_constants` table or its initializer
376    // in the module.
377    let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
378        let literal = match module.types[r#override.ty].inner {
379            TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
380            _ => unreachable!(),
381        };
382        let expr = module
383            .global_expressions
384            .append(Expression::Literal(literal), Span::UNDEFINED);
385        global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
386        expr
387    } else if let Some(init) = r#override.init {
388        adjusted_global_expressions[init]
389    } else {
390        return Err(PipelineConstantError::MissingValue(key.to_string()));
391    };
392
393    // Generate a new `Constant` to represent the override's value.
394    let constant = Constant {
395        name: r#override.name.clone(),
396        ty: r#override.ty,
397        init,
398    };
399    let h = module.constants.append(constant, *span);
400    override_map.insert(old_h, h);
401    adjusted_constant_initializers.insert(h);
402    r#override.init = Some(init);
403    Ok(h)
404}
405
406/// Replace all override expressions in `function` with fully-evaluated constants.
407///
408/// Replace all `Expression::Override`s in `function`'s expression arena with
409/// the corresponding `Expression::Constant`s, as given in `override_map`.
410/// Replace any expressions whose values are now known with their fully
411/// evaluated form.
412///
413/// If `h` is a `Handle<Override>`, then `override_map[h]` is the
414/// `Handle<Constant>` for the override's final value.
415fn process_function(
416    module: &mut Module,
417    override_map: &HandleVec<Override, Handle<Constant>>,
418    layouter: &mut crate::proc::Layouter,
419    function: &mut Function,
420) -> Result<(), ConstantEvaluatorError> {
421    // A map from original local expression handles to
422    // handles in the new, local expression arena.
423    let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
424
425    let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
426
427    let mut expressions = mem::take(&mut function.expressions);
428
429    // Dummy `emitter` and `block` for the constant evaluator.
430    // We can ignore the concept of emitting expressions here since
431    // expressions have already been covered by a `Statement::Emit`
432    // in the frontend.
433    // The only thing we might have to do is remove some expressions
434    // that have been covered by a `Statement::Emit`. See the docs of
435    // `filter_emits_in_block` for the reasoning.
436    let mut emitter = Emitter::default();
437    let mut block = Block::new();
438
439    let mut evaluator = ConstantEvaluator::for_wgsl_function(
440        module,
441        &mut function.expressions,
442        &mut local_expression_kind_tracker,
443        layouter,
444        &mut emitter,
445        &mut block,
446        false,
447    );
448
449    for (old_h, mut expr, span) in expressions.drain() {
450        if let Expression::Override(h) = expr {
451            expr = Expression::Constant(override_map[h]);
452        }
453        adjust_expr(&adjusted_local_expressions, &mut expr);
454        let h = evaluator.try_eval_and_append(expr, span)?;
455        adjusted_local_expressions.insert(old_h, h);
456    }
457
458    adjust_block(&adjusted_local_expressions, &mut function.body);
459
460    filter_emits_in_block(&mut function.body, &function.expressions);
461
462    // Update local expression initializers.
463    for (_, local) in function.local_variables.iter_mut() {
464        if let &mut Some(ref mut init) = &mut local.init {
465            *init = adjusted_local_expressions[*init];
466        }
467    }
468
469    // We've changed the keys of `function.named_expression`, so we have to
470    // rebuild it from scratch.
471    let named_expressions = mem::take(&mut function.named_expressions);
472    for (expr_h, name) in named_expressions {
473        function
474            .named_expressions
475            .insert(adjusted_local_expressions[expr_h], name);
476    }
477
478    Ok(())
479}
480
481/// Replace every expression handle in `expr` with its counterpart
482/// given by `new_pos`.
483fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
484    let adjust = |expr: &mut Handle<Expression>| {
485        *expr = new_pos[*expr];
486    };
487    match *expr {
488        Expression::Compose {
489            ref mut components,
490            ty: _,
491        } => {
492            for c in components.iter_mut() {
493                adjust(c);
494            }
495        }
496        Expression::Access {
497            ref mut base,
498            ref mut index,
499        } => {
500            adjust(base);
501            adjust(index);
502        }
503        Expression::AccessIndex {
504            ref mut base,
505            index: _,
506        } => {
507            adjust(base);
508        }
509        Expression::Splat {
510            ref mut value,
511            size: _,
512        } => {
513            adjust(value);
514        }
515        Expression::Swizzle {
516            ref mut vector,
517            size: _,
518            pattern: _,
519        } => {
520            adjust(vector);
521        }
522        Expression::Load { ref mut pointer } => {
523            adjust(pointer);
524        }
525        Expression::ImageSample {
526            ref mut image,
527            ref mut sampler,
528            ref mut coordinate,
529            ref mut array_index,
530            ref mut offset,
531            ref mut level,
532            ref mut depth_ref,
533            gather: _,
534            clamp_to_edge: _,
535        } => {
536            adjust(image);
537            adjust(sampler);
538            adjust(coordinate);
539            if let Some(e) = array_index.as_mut() {
540                adjust(e);
541            }
542            if let Some(e) = offset.as_mut() {
543                adjust(e);
544            }
545            match *level {
546                crate::SampleLevel::Exact(ref mut expr)
547                | crate::SampleLevel::Bias(ref mut expr) => {
548                    adjust(expr);
549                }
550                crate::SampleLevel::Gradient {
551                    ref mut x,
552                    ref mut y,
553                } => {
554                    adjust(x);
555                    adjust(y);
556                }
557                _ => {}
558            }
559            if let Some(e) = depth_ref.as_mut() {
560                adjust(e);
561            }
562        }
563        Expression::ImageLoad {
564            ref mut image,
565            ref mut coordinate,
566            ref mut array_index,
567            ref mut sample,
568            ref mut level,
569        } => {
570            adjust(image);
571            adjust(coordinate);
572            if let Some(e) = array_index.as_mut() {
573                adjust(e);
574            }
575            if let Some(e) = sample.as_mut() {
576                adjust(e);
577            }
578            if let Some(e) = level.as_mut() {
579                adjust(e);
580            }
581        }
582        Expression::ImageQuery {
583            ref mut image,
584            ref mut query,
585        } => {
586            adjust(image);
587            match *query {
588                crate::ImageQuery::Size { ref mut level } => {
589                    if let Some(e) = level.as_mut() {
590                        adjust(e);
591                    }
592                }
593                crate::ImageQuery::NumLevels
594                | crate::ImageQuery::NumLayers
595                | crate::ImageQuery::NumSamples => {}
596            }
597        }
598        Expression::Unary {
599            ref mut expr,
600            op: _,
601        } => {
602            adjust(expr);
603        }
604        Expression::Binary {
605            ref mut left,
606            ref mut right,
607            op: _,
608        } => {
609            adjust(left);
610            adjust(right);
611        }
612        Expression::Select {
613            ref mut condition,
614            ref mut accept,
615            ref mut reject,
616        } => {
617            adjust(condition);
618            adjust(accept);
619            adjust(reject);
620        }
621        Expression::Derivative {
622            ref mut expr,
623            axis: _,
624            ctrl: _,
625        } => {
626            adjust(expr);
627        }
628        Expression::Relational {
629            ref mut argument,
630            fun: _,
631        } => {
632            adjust(argument);
633        }
634        Expression::Math {
635            ref mut arg,
636            ref mut arg1,
637            ref mut arg2,
638            ref mut arg3,
639            fun: _,
640        } => {
641            adjust(arg);
642            if let Some(e) = arg1.as_mut() {
643                adjust(e);
644            }
645            if let Some(e) = arg2.as_mut() {
646                adjust(e);
647            }
648            if let Some(e) = arg3.as_mut() {
649                adjust(e);
650            }
651        }
652        Expression::As {
653            ref mut expr,
654            kind: _,
655            convert: _,
656        } => {
657            adjust(expr);
658        }
659        Expression::ArrayLength(ref mut expr) => {
660            adjust(expr);
661        }
662        Expression::RayQueryGetIntersection {
663            ref mut query,
664            committed: _,
665        } => {
666            adjust(query);
667        }
668        Expression::Literal(_)
669        | Expression::FunctionArgument(_)
670        | Expression::GlobalVariable(_)
671        | Expression::LocalVariable(_)
672        | Expression::CallResult(_)
673        | Expression::RayQueryProceedResult
674        | Expression::Constant(_)
675        | Expression::Override(_)
676        | Expression::ZeroValue(_)
677        | Expression::AtomicResult {
678            ty: _,
679            comparison: _,
680        }
681        | Expression::WorkGroupUniformLoadResult { ty: _ }
682        | Expression::SubgroupBallotResult
683        | Expression::SubgroupOperationResult { .. } => {}
684        Expression::RayQueryVertexPositions {
685            ref mut query,
686            committed: _,
687        } => {
688            adjust(query);
689        }
690        Expression::CooperativeLoad { ref mut data, .. } => {
691            adjust(&mut data.pointer);
692            adjust(&mut data.stride);
693        }
694        Expression::CooperativeMultiplyAdd {
695            ref mut a,
696            ref mut b,
697            ref mut c,
698        } => {
699            adjust(a);
700            adjust(b);
701            adjust(c);
702        }
703    }
704}
705
706/// Replace every expression handle in `block` with its counterpart
707/// given by `new_pos`.
708fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
709    for stmt in block.iter_mut() {
710        adjust_stmt(new_pos, stmt);
711    }
712}
713
714/// Replace every expression handle in `stmt` with its counterpart
715/// given by `new_pos`.
716fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
717    let adjust = |expr: &mut Handle<Expression>| {
718        *expr = new_pos[*expr];
719    };
720    match *stmt {
721        Statement::Emit(ref mut range) => {
722            if let Some((mut first, mut last)) = range.first_and_last() {
723                adjust(&mut first);
724                adjust(&mut last);
725                *range = Range::new_from_bounds(first, last);
726            }
727        }
728        Statement::Block(ref mut block) => {
729            adjust_block(new_pos, block);
730        }
731        Statement::If {
732            ref mut condition,
733            ref mut accept,
734            ref mut reject,
735        } => {
736            adjust(condition);
737            adjust_block(new_pos, accept);
738            adjust_block(new_pos, reject);
739        }
740        Statement::Switch {
741            ref mut selector,
742            ref mut cases,
743        } => {
744            adjust(selector);
745            for case in cases.iter_mut() {
746                adjust_block(new_pos, &mut case.body);
747            }
748        }
749        Statement::Loop {
750            ref mut body,
751            ref mut continuing,
752            ref mut break_if,
753        } => {
754            adjust_block(new_pos, body);
755            adjust_block(new_pos, continuing);
756            if let Some(e) = break_if.as_mut() {
757                adjust(e);
758            }
759        }
760        Statement::Return { ref mut value } => {
761            if let Some(e) = value.as_mut() {
762                adjust(e);
763            }
764        }
765        Statement::Store {
766            ref mut pointer,
767            ref mut value,
768        } => {
769            adjust(pointer);
770            adjust(value);
771        }
772        Statement::ImageStore {
773            ref mut image,
774            ref mut coordinate,
775            ref mut array_index,
776            ref mut value,
777        } => {
778            adjust(image);
779            adjust(coordinate);
780            if let Some(e) = array_index.as_mut() {
781                adjust(e);
782            }
783            adjust(value);
784        }
785        Statement::Atomic {
786            ref mut pointer,
787            ref mut value,
788            ref mut result,
789            ref mut fun,
790        } => {
791            adjust(pointer);
792            adjust(value);
793            if let Some(ref mut result) = *result {
794                adjust(result);
795            }
796            match *fun {
797                crate::AtomicFunction::Exchange {
798                    compare: Some(ref mut compare),
799                } => {
800                    adjust(compare);
801                }
802                crate::AtomicFunction::Add
803                | crate::AtomicFunction::Subtract
804                | crate::AtomicFunction::And
805                | crate::AtomicFunction::ExclusiveOr
806                | crate::AtomicFunction::InclusiveOr
807                | crate::AtomicFunction::Min
808                | crate::AtomicFunction::Max
809                | crate::AtomicFunction::Exchange { compare: None } => {}
810            }
811        }
812        Statement::ImageAtomic {
813            ref mut image,
814            ref mut coordinate,
815            ref mut array_index,
816            fun: _,
817            ref mut value,
818        } => {
819            adjust(image);
820            adjust(coordinate);
821            if let Some(ref mut array_index) = *array_index {
822                adjust(array_index);
823            }
824            adjust(value);
825        }
826        Statement::WorkGroupUniformLoad {
827            ref mut pointer,
828            ref mut result,
829        } => {
830            adjust(pointer);
831            adjust(result);
832        }
833        Statement::SubgroupBallot {
834            ref mut result,
835            ref mut predicate,
836        } => {
837            if let Some(ref mut predicate) = *predicate {
838                adjust(predicate);
839            }
840            adjust(result);
841        }
842        Statement::SubgroupCollectiveOperation {
843            ref mut argument,
844            ref mut result,
845            ..
846        } => {
847            adjust(argument);
848            adjust(result);
849        }
850        Statement::SubgroupGather {
851            ref mut mode,
852            ref mut argument,
853            ref mut result,
854        } => {
855            match *mode {
856                crate::GatherMode::BroadcastFirst => {}
857                crate::GatherMode::Broadcast(ref mut index)
858                | crate::GatherMode::Shuffle(ref mut index)
859                | crate::GatherMode::ShuffleDown(ref mut index)
860                | crate::GatherMode::ShuffleUp(ref mut index)
861                | crate::GatherMode::ShuffleXor(ref mut index)
862                | crate::GatherMode::QuadBroadcast(ref mut index) => {
863                    adjust(index);
864                }
865                crate::GatherMode::QuadSwap(_) => {}
866            }
867            adjust(argument);
868            adjust(result)
869        }
870        Statement::Call {
871            ref mut arguments,
872            ref mut result,
873            function: _,
874        } => {
875            for argument in arguments.iter_mut() {
876                adjust(argument);
877            }
878            if let Some(e) = result.as_mut() {
879                adjust(e);
880            }
881        }
882        Statement::RayQuery {
883            ref mut query,
884            ref mut fun,
885        } => {
886            adjust(query);
887            match *fun {
888                crate::RayQueryFunction::Initialize {
889                    ref mut acceleration_structure,
890                    ref mut descriptor,
891                } => {
892                    adjust(acceleration_structure);
893                    adjust(descriptor);
894                }
895                crate::RayQueryFunction::Proceed { ref mut result } => {
896                    adjust(result);
897                }
898                crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
899                    adjust(hit_t);
900                }
901                crate::RayQueryFunction::ConfirmIntersection => {}
902                crate::RayQueryFunction::Terminate => {}
903            }
904        }
905        Statement::CooperativeStore {
906            ref mut target,
907            ref mut data,
908        } => {
909            adjust(target);
910            adjust(&mut data.pointer);
911            adjust(&mut data.stride);
912        }
913        Statement::RayPipelineFunction(ref mut func) => match *func {
914            crate::RayPipelineFunction::TraceRay {
915                ref mut acceleration_structure,
916                ref mut descriptor,
917                ref mut payload,
918            } => {
919                adjust(acceleration_structure);
920                adjust(descriptor);
921                adjust(payload);
922            }
923        },
924        Statement::Break
925        | Statement::Continue
926        | Statement::Kill
927        | Statement::ControlBarrier(_)
928        | Statement::MemoryBarrier(_) => {}
929    }
930}
931
932/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced.
933///
934/// According to validation, [`Emit`] statements must not cover any expressions
935/// for which [`Expression::needs_pre_emit`] returns true. All expressions built
936/// by successful constant evaluation fall into that category, meaning that
937/// `process_function` will usually rewrite [`Override`] expressions and those
938/// that use their values into pre-emitted expressions, leaving any [`Emit`]
939/// statements that cover them invalid.
940///
941/// This function rewrites all [`Emit`] statements into zero or more new
942/// [`Emit`] statements covering only those expressions in the original range
943/// that are not pre-emitted.
944///
945/// [`Emit`]: Statement::Emit
946/// [`needs_pre_emit`]: Expression::needs_pre_emit
947/// [`Override`]: Expression::Override
948fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
949    let original = mem::replace(block, Block::with_capacity(block.len()));
950    for (stmt, span) in original.span_into_iter() {
951        match stmt {
952            Statement::Emit(range) => {
953                let mut current = None;
954                for expr_h in range {
955                    if expressions[expr_h].needs_pre_emit() {
956                        if let Some((first, last)) = current {
957                            block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
958                        }
959
960                        current = None;
961                    } else if let Some((_, ref mut last)) = current {
962                        *last = expr_h;
963                    } else {
964                        current = Some((expr_h, expr_h));
965                    }
966                }
967                if let Some((first, last)) = current {
968                    block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
969                }
970            }
971            Statement::Block(mut child) => {
972                filter_emits_in_block(&mut child, expressions);
973                block.push(Statement::Block(child), span);
974            }
975            Statement::If {
976                condition,
977                mut accept,
978                mut reject,
979            } => {
980                filter_emits_in_block(&mut accept, expressions);
981                filter_emits_in_block(&mut reject, expressions);
982                block.push(
983                    Statement::If {
984                        condition,
985                        accept,
986                        reject,
987                    },
988                    span,
989                );
990            }
991            Statement::Switch {
992                selector,
993                mut cases,
994            } => {
995                for case in &mut cases {
996                    filter_emits_in_block(&mut case.body, expressions);
997                }
998                block.push(Statement::Switch { selector, cases }, span);
999            }
1000            Statement::Loop {
1001                mut body,
1002                mut continuing,
1003                break_if,
1004            } => {
1005                filter_emits_in_block(&mut body, expressions);
1006                filter_emits_in_block(&mut continuing, expressions);
1007                block.push(
1008                    Statement::Loop {
1009                        body,
1010                        continuing,
1011                        break_if,
1012                    },
1013                    span,
1014                );
1015            }
1016            stmt => block.push(stmt.clone(), span),
1017        }
1018    }
1019}
1020
1021fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
1022    // note that in rust 0.0 == -0.0
1023    match scalar {
1024        Scalar::BOOL => {
1025            // https://webidl.spec.whatwg.org/#js-boolean
1026            let value = value != 0.0 && !value.is_nan();
1027            Ok(Literal::Bool(value))
1028        }
1029        Scalar::I32 => {
1030            // https://webidl.spec.whatwg.org/#js-long
1031            if !value.is_finite() {
1032                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1033            }
1034
1035            let value = value.trunc();
1036            if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
1037                return Err(PipelineConstantError::DstRangeTooSmall);
1038            }
1039
1040            let value = value as i32;
1041            Ok(Literal::I32(value))
1042        }
1043        Scalar::U32 => {
1044            // https://webidl.spec.whatwg.org/#js-unsigned-long
1045            if !value.is_finite() {
1046                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1047            }
1048
1049            let value = value.trunc();
1050            if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
1051                return Err(PipelineConstantError::DstRangeTooSmall);
1052            }
1053
1054            let value = value as u32;
1055            Ok(Literal::U32(value))
1056        }
1057        Scalar::F16 => {
1058            // https://webidl.spec.whatwg.org/#js-float
1059            if !value.is_finite() {
1060                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1061            }
1062
1063            let value = half::f16::from_f64(value);
1064            if !value.is_finite() {
1065                return Err(PipelineConstantError::DstRangeTooSmall);
1066            }
1067
1068            Ok(Literal::F16(value))
1069        }
1070        Scalar::F32 => {
1071            // https://webidl.spec.whatwg.org/#js-float
1072            if !value.is_finite() {
1073                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1074            }
1075
1076            let value = value as f32;
1077            if !value.is_finite() {
1078                return Err(PipelineConstantError::DstRangeTooSmall);
1079            }
1080
1081            Ok(Literal::F32(value))
1082        }
1083        Scalar::F64 => {
1084            // https://webidl.spec.whatwg.org/#js-double
1085            if !value.is_finite() {
1086                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1087            }
1088
1089            Ok(Literal::F64(value))
1090        }
1091        Scalar::ABSTRACT_FLOAT | Scalar::ABSTRACT_INT => {
1092            unreachable!("abstract values should not be validated out of override processing")
1093        }
1094        _ => unreachable!("unrecognized scalar type for override"),
1095    }
1096}
1097
#[test]
fn test_map_value_to_literal() {
    // bool: WebIDL truthiness — only ±0.0 and NaN map to `false`.
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    // Every numeric destination, including f16, rejects NaN and the infinities.
    for scalar in [
        Scalar::I32,
        Scalar::U32,
        Scalar::F16,
        Scalar::F32,
        Scalar::F64,
    ] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Fractional inputs are truncated toward zero before the range check.
    assert_eq!(map_value_to_literal(1.9, Scalar::I32), Ok(Literal::I32(1)));
    assert_eq!(map_value_to_literal(-1.9, Scalar::I32), Ok(Literal::I32(-1)));
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // `-0.5` truncates to `-0.0`, which compares equal to `0.0`, so it is in range.
    assert_eq!(map_value_to_literal(-0.5, Scalar::U32), Ok(Literal::U32(0)));
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 0.5, Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );

    // f16
    assert_eq!(
        map_value_to_literal(f64::from(half::f16::MIN), Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(half::f16::MAX), Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    // ±2^16 lies beyond the largest finite f16 and rounds to an infinity.
    assert_eq!(
        map_value_to_literal(-65536.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(65536.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f32
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // 0x47efffffefffffff is the largest f64 below the midpoint between f32::MAX
    // and 2^128, so it still rounds down to f32::MAX.
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // 0x47effffff0000000 is exactly that midpoint; ties-to-even rounds it up to
    // infinity, which is out of range.
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}