// naga/back/pipeline_constants.rs

1use alloc::{
2    borrow::Cow,
3    string::{String, ToString},
4};
5use core::mem;
6
7use hashbrown::HashSet;
8use thiserror::Error;
9
10use super::PipelineConstants;
11use crate::{
12    arena::HandleVec,
13    compact::{compact, KeepUnused},
14    ir,
15    proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
16    valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
17    Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
18    Span, Statement, TypeInner, WithSpan,
19};
20
21// Possibly unused if not compiled with no_std
22#[allow(unused_imports)]
23use num_traits::float::FloatCore as _;
24
25#[derive(Error, Debug, Clone)]
26#[cfg_attr(test, derive(PartialEq))]
27pub enum PipelineConstantError {
28    #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
29    MissingValue(String),
30    #[error(
31        "Source f64 value needs to be finite ({}) for number destinations",
32        "NaNs and Inifinites are not allowed"
33    )]
34    SrcNeedsToBeFinite,
35    #[error("Source f64 value doesn't fit in destination")]
36    DstRangeTooSmall,
37    #[error(transparent)]
38    ConstantEvaluatorError(#[from] ConstantEvaluatorError),
39    #[error(transparent)]
40    ValidationError(#[from] WithSpan<ValidationError>),
41    #[error("workgroup_size override isn't strictly positive")]
42    NegativeWorkgroupSize,
43    #[error("max vertices or max primitives is negative")]
44    NegativeMeshOutputMax,
45}
46
/// Compact `module` and replace all overrides with constants.
///
/// If no changes are needed, this just returns `Cow::Borrowed` references to
/// `module` and `module_info`. Otherwise, it clones `module`, retains only the
/// selected entry point, compacts the module, edits its [`global_expressions`]
/// arena to contain only fully-evaluated expressions, and returns the
/// simplified module and its validation results.
///
/// In the returned module, no [`Expression::Override`] remains in the
/// `global_expressions` arena or in any function or entry point body. Each
/// one has been replaced with an [`Expression::Constant`] that refers to a new
/// [`Constant`] holding the override's resolved value.
///
/// The `overrides` arena itself is *not* emptied. It is taken out while
/// processing and put back at the end, with each processed override's `init`
/// updated to refer into the new `global_expressions` arena.
///
/// Entry point workgroup sizes and mesh-shader output limits that are driven
/// by overrides are resolved to concrete values.
///
/// Values are looked up in `pipeline_constants`. The key is an override's
/// numeric `id` (as a string) if it has one, and its name otherwise.
/// Overrides without an entry there fall back to their initializer.
///
/// # Errors
///
/// - [`PipelineConstantError::MissingValue`] if an override that survives
///   compaction has neither a pipeline constant nor an initializer.
/// - Constant-evaluation, workgroup-size, mesh-output, and validation errors
///   encountered while rewriting and re-validating the module.
///
/// [`global_expressions`]: Module::global_expressions
pub fn process_overrides<'a>(
    module: &'a Module,
    module_info: &'a ModuleInfo,
    entry_point: Option<(ir::ShaderStage, &str)>,
    pipeline_constants: &PipelineConstants,
) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
    if (entry_point.is_none() || module.entry_points.len() <= 1) && module.overrides.is_empty() {
        // We skip compacting the module here mostly to reduce the risk of
        // hitting corner cases like https://github.com/gfx-rs/wgpu/issues/7793.
        // Compaction doesn't cost very much [1], so it would also be reasonable
        // to do it unconditionally. Even when there is a single entry point or
        // when no entry point is specified, it is still possible that there
        // are unreferenced items in the module that would be removed by this
        // compaction.
        //
        // [1]: https://github.com/gfx-rs/wgpu/pull/7703#issuecomment-2902153760
        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
    }

    let mut module = module.clone();
    if let Some((ep_stage, ep_name)) = entry_point {
        module
            .entry_points
            .retain(|ep| ep.stage == ep_stage && ep.name == ep_name);
    }

    // Compact the module to remove anything not reachable from an entry point.
    // This is necessary because we may not have values for overrides that are
    // not reachable from the/an entry point.
    compact(&mut module, KeepUnused::No);

    // If there are no overrides in the module, then we can skip the rest.
    if module.overrides.is_empty() {
        return revalidate(module);
    }

    // A map from override handles to the handles of the constants
    // we've replaced them with.
    let mut override_map = HandleVec::with_capacity(module.overrides.len());

    // A map from `module`'s original global expression handles to
    // handles in the new, simplified global expression arena.
    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());

    // The set of constants whose initializer handles we've already
    // updated to refer to the newly built global expression arena.
    //
    // All constants in `module` must have their `init` handles
    // updated to point into the new, simplified global expression
    // arena. Some of these we can most easily handle as a side effect
    // during the simplification process, but we must handle the rest
    // in a final fixup pass, guided by `adjusted_global_expressions`. We
    // add their handles to this set, so that the final fixup step can
    // leave them alone.
    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());

    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
    let mut layouter = crate::proc::Layouter::default();

    // An iterator through the original overrides table, consumed in
    // approximate tandem with the global expressions.
    // The arena is moved out of `module` so its entries can be borrowed
    // mutably while `module` is also borrowed mutably. It is put back
    // near the end of this function.
    let mut overrides = mem::take(&mut module.overrides);
    let mut override_iter = overrides.iter_mut_span();

    // Do two things in tandem:
    //
    // - Rebuild the global expression arena from scratch, fully
    //   evaluating all expressions, and replacing each `Override`
    //   expression in `module.global_expressions` with a `Constant`
    //   expression.
    //
    // - Build a new `Constant` in `module.constants` to take the
    //   place of each `Override`.
    //
    // Build a map from old global expression handles to their
    // fully-evaluated counterparts in `adjusted_global_expressions` as we
    // go.
    //
    // Why in tandem? Overrides refer to expressions, and expressions
    // refer to overrides, so we can't disentangle the two into
    // separate phases. However, we can take advantage of the fact
    // that the overrides and expressions must form a DAG, and work
    // our way from the leaves to the roots, replacing and evaluating
    // as we go.
    //
    // Although the two loops are nested, this is really two
    // alternating phases: we adjust and evaluate constant expressions
    // until we hit an `Override` expression, at which point we switch
    // to building `Constant`s for `Overrides` until we've handled the
    // one used by the expression. Then we switch back to processing
    // expressions. Because we know they form a DAG, we know the
    // `Override` expressions we encounter can only have initializers
    // referring to global expressions we've already simplified.
    for (old_h, expr, span) in module.global_expressions.drain() {
        let mut expr = match expr {
            Expression::Override(h) => {
                let c_h = if let Some(new_h) = override_map.get(h) {
                    *new_h
                } else {
                    let mut new_h = None;
                    for entry in override_iter.by_ref() {
                        let stop = entry.0 == h;
                        new_h = Some(process_override(
                            entry,
                            pipeline_constants,
                            &mut module,
                            &mut override_map,
                            &adjusted_global_expressions,
                            &mut adjusted_constant_initializers,
                            &mut global_expression_kind_tracker,
                        )?);
                        if stop {
                            break;
                        }
                    }
                    // Every entry consumed from `override_iter` goes through
                    // `process_override`, which records it in `override_map`.
                    // Since `h` was not in the map, it has not been consumed
                    // yet, so the loop above reaches it and sets `new_h`.
                    new_h.unwrap()
                };
                Expression::Constant(c_h)
            }
            Expression::Constant(c_h) => {
                // The first time a constant is referenced, remap its
                // initializer into the new arena. Its initializer precedes
                // this expression, so it has already been adjusted.
                if adjusted_constant_initializers.insert(c_h) {
                    let init = &mut module.constants[c_h].init;
                    *init = adjusted_global_expressions[*init];
                }
                expr
            }
            expr => expr,
        };
        let mut evaluator = ConstantEvaluator::for_wgsl_module(
            &mut module,
            &mut global_expression_kind_tracker,
            &mut layouter,
            false,
        );
        adjust_expr(&adjusted_global_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_global_expressions.insert(old_h, h);
    }

    // Finish processing any overrides we didn't visit in the loop above.
    // Overrides with a name or id still get a `Constant` (function bodies may
    // refer to them through `override_map`). Anonymous ones only have their
    // initializer remapped into the new arena.
    // NOTE(review): overrides with neither a name nor an id get no entry in
    // `override_map`, and `process_function` indexes `override_map` directly,
    // so such an override must never be referenced from a function body —
    // confirm this holds for all frontends.
    for entry in override_iter {
        match *entry.1 {
            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
                process_override(
                    entry,
                    pipeline_constants,
                    &mut module,
                    &mut override_map,
                    &adjusted_global_expressions,
                    &mut adjusted_constant_initializers,
                    &mut global_expression_kind_tracker,
                )?;
            }
            Override {
                init: Some(ref mut init),
                ..
            } => {
                *init = adjusted_global_expressions[*init];
            }
            _ => {}
        }
    }

    // Update the initialization expression handles of all `Constant`s
    // and `GlobalVariable`s. Skip `Constant`s we'd already updated en
    // passant.
    for (_, c) in module
        .constants
        .iter_mut()
        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
    {
        c.init = adjusted_global_expressions[c.init];
    }

    for (_, v) in module.global_variables.iter_mut() {
        if let Some(ref mut init) = v.init {
            *init = adjusted_global_expressions[*init];
        }
    }

    // Functions and entry points are moved out temporarily so each one can be
    // rewritten while `module` is borrowed mutably by the evaluator.
    let mut functions = mem::take(&mut module.functions);
    for (_, function) in functions.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, function)?;
    }
    module.functions = functions;

    let mut entry_points = mem::take(&mut module.entry_points);
    for ep in entry_points.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
        process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?;
    }
    module.entry_points = entry_points;
    // Put the overrides arena back. Its entries' `init` handles now refer to
    // the rebuilt global expression arena.
    module.overrides = overrides;

    // Now that we've rewritten all the expressions, we need to
    // recompute their types and other metadata. For the time being,
    // do a full re-validation.
    revalidate(module)
}
259
260fn revalidate(
261    module: Module,
262) -> Result<(Cow<'static, Module>, Cow<'static, ModuleInfo>), PipelineConstantError> {
263    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
264    let module_info = validator.validate_resolved_overrides(&module)?;
265    Ok((Cow::Owned(module), Cow::Owned(module_info)))
266}
267
268fn process_workgroup_size_override(
269    module: &mut Module,
270    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
271    ep: &mut crate::EntryPoint,
272) -> Result<(), PipelineConstantError> {
273    match ep.workgroup_size_overrides {
274        None => {}
275        Some(overrides) => {
276            overrides.iter().enumerate().try_for_each(
277                |(i, overridden)| -> Result<(), PipelineConstantError> {
278                    match *overridden {
279                        None => Ok(()),
280                        Some(h) => {
281                            ep.workgroup_size[i] = module
282                                .to_ctx()
283                                .get_const_val(adjusted_global_expressions[h])
284                                .map(|n| {
285                                    if n == 0 {
286                                        Err(PipelineConstantError::NegativeWorkgroupSize)
287                                    } else {
288                                        Ok(n)
289                                    }
290                                })
291                                .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
292                            Ok(())
293                        }
294                    }
295                },
296            )?;
297            ep.workgroup_size_overrides = None;
298        }
299    }
300    Ok(())
301}
302
303fn process_mesh_shader_overrides(
304    module: &mut Module,
305    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
306    ep: &mut crate::EntryPoint,
307) -> Result<(), PipelineConstantError> {
308    if let Some(ref mut mesh_info) = ep.mesh_info {
309        if let Some(r#override) = mesh_info.max_vertices_override {
310            mesh_info.max_vertices = module
311                .to_ctx()
312                .get_const_val(adjusted_global_expressions[r#override])
313                .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
314        }
315        if let Some(r#override) = mesh_info.max_primitives_override {
316            mesh_info.max_primitives = module
317                .to_ctx()
318                .get_const_val(adjusted_global_expressions[r#override])
319                .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
320        }
321    }
322    Ok(())
323}
324
325/// Add a [`Constant`] to `module` for the override `old_h`.
326///
327/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
328fn process_override(
329    (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
330    pipeline_constants: &PipelineConstants,
331    module: &mut Module,
332    override_map: &mut HandleVec<Override, Handle<Constant>>,
333    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
334    adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
335    global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
336) -> Result<Handle<Constant>, PipelineConstantError> {
337    // Determine which key to use for `r#override` in `pipeline_constants`.
338    let key = if let Some(id) = r#override.id {
339        Cow::Owned(id.to_string())
340    } else if let Some(ref name) = r#override.name {
341        Cow::Borrowed(name)
342    } else {
343        unreachable!();
344    };
345
346    // Generate a global expression for `r#override`'s value, either
347    // from the provided `pipeline_constants` table or its initializer
348    // in the module.
349    let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
350        let literal = match module.types[r#override.ty].inner {
351            TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
352            _ => unreachable!(),
353        };
354        let expr = module
355            .global_expressions
356            .append(Expression::Literal(literal), Span::UNDEFINED);
357        global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
358        expr
359    } else if let Some(init) = r#override.init {
360        adjusted_global_expressions[init]
361    } else {
362        return Err(PipelineConstantError::MissingValue(key.to_string()));
363    };
364
365    // Generate a new `Constant` to represent the override's value.
366    let constant = Constant {
367        name: r#override.name.clone(),
368        ty: r#override.ty,
369        init,
370    };
371    let h = module.constants.append(constant, *span);
372    override_map.insert(old_h, h);
373    adjusted_constant_initializers.insert(h);
374    r#override.init = Some(init);
375    Ok(h)
376}
377
378/// Replace all override expressions in `function` with fully-evaluated constants.
379///
380/// Replace all `Expression::Override`s in `function`'s expression arena with
381/// the corresponding `Expression::Constant`s, as given in `override_map`.
382/// Replace any expressions whose values are now known with their fully
383/// evaluated form.
384///
385/// If `h` is a `Handle<Override>`, then `override_map[h]` is the
386/// `Handle<Constant>` for the override's final value.
387fn process_function(
388    module: &mut Module,
389    override_map: &HandleVec<Override, Handle<Constant>>,
390    layouter: &mut crate::proc::Layouter,
391    function: &mut Function,
392) -> Result<(), ConstantEvaluatorError> {
393    // A map from original local expression handles to
394    // handles in the new, local expression arena.
395    let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
396
397    let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
398
399    let mut expressions = mem::take(&mut function.expressions);
400
401    // Dummy `emitter` and `block` for the constant evaluator.
402    // We can ignore the concept of emitting expressions here since
403    // expressions have already been covered by a `Statement::Emit`
404    // in the frontend.
405    // The only thing we might have to do is remove some expressions
406    // that have been covered by a `Statement::Emit`. See the docs of
407    // `filter_emits_in_block` for the reasoning.
408    let mut emitter = Emitter::default();
409    let mut block = Block::new();
410
411    let mut evaluator = ConstantEvaluator::for_wgsl_function(
412        module,
413        &mut function.expressions,
414        &mut local_expression_kind_tracker,
415        layouter,
416        &mut emitter,
417        &mut block,
418        false,
419    );
420
421    for (old_h, mut expr, span) in expressions.drain() {
422        if let Expression::Override(h) = expr {
423            expr = Expression::Constant(override_map[h]);
424        }
425        adjust_expr(&adjusted_local_expressions, &mut expr);
426        let h = evaluator.try_eval_and_append(expr, span)?;
427        adjusted_local_expressions.insert(old_h, h);
428    }
429
430    adjust_block(&adjusted_local_expressions, &mut function.body);
431
432    filter_emits_in_block(&mut function.body, &function.expressions);
433
434    // Update local expression initializers.
435    for (_, local) in function.local_variables.iter_mut() {
436        if let &mut Some(ref mut init) = &mut local.init {
437            *init = adjusted_local_expressions[*init];
438        }
439    }
440
441    // We've changed the keys of `function.named_expression`, so we have to
442    // rebuild it from scratch.
443    let named_expressions = mem::take(&mut function.named_expressions);
444    for (expr_h, name) in named_expressions {
445        function
446            .named_expressions
447            .insert(adjusted_local_expressions[expr_h], name);
448    }
449
450    Ok(())
451}
452
453/// Replace every expression handle in `expr` with its counterpart
454/// given by `new_pos`.
455fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
456    let adjust = |expr: &mut Handle<Expression>| {
457        *expr = new_pos[*expr];
458    };
459    match *expr {
460        Expression::Compose {
461            ref mut components,
462            ty: _,
463        } => {
464            for c in components.iter_mut() {
465                adjust(c);
466            }
467        }
468        Expression::Access {
469            ref mut base,
470            ref mut index,
471        } => {
472            adjust(base);
473            adjust(index);
474        }
475        Expression::AccessIndex {
476            ref mut base,
477            index: _,
478        } => {
479            adjust(base);
480        }
481        Expression::Splat {
482            ref mut value,
483            size: _,
484        } => {
485            adjust(value);
486        }
487        Expression::Swizzle {
488            ref mut vector,
489            size: _,
490            pattern: _,
491        } => {
492            adjust(vector);
493        }
494        Expression::Load { ref mut pointer } => {
495            adjust(pointer);
496        }
497        Expression::ImageSample {
498            ref mut image,
499            ref mut sampler,
500            ref mut coordinate,
501            ref mut array_index,
502            ref mut offset,
503            ref mut level,
504            ref mut depth_ref,
505            gather: _,
506            clamp_to_edge: _,
507        } => {
508            adjust(image);
509            adjust(sampler);
510            adjust(coordinate);
511            if let Some(e) = array_index.as_mut() {
512                adjust(e);
513            }
514            if let Some(e) = offset.as_mut() {
515                adjust(e);
516            }
517            match *level {
518                crate::SampleLevel::Exact(ref mut expr)
519                | crate::SampleLevel::Bias(ref mut expr) => {
520                    adjust(expr);
521                }
522                crate::SampleLevel::Gradient {
523                    ref mut x,
524                    ref mut y,
525                } => {
526                    adjust(x);
527                    adjust(y);
528                }
529                _ => {}
530            }
531            if let Some(e) = depth_ref.as_mut() {
532                adjust(e);
533            }
534        }
535        Expression::ImageLoad {
536            ref mut image,
537            ref mut coordinate,
538            ref mut array_index,
539            ref mut sample,
540            ref mut level,
541        } => {
542            adjust(image);
543            adjust(coordinate);
544            if let Some(e) = array_index.as_mut() {
545                adjust(e);
546            }
547            if let Some(e) = sample.as_mut() {
548                adjust(e);
549            }
550            if let Some(e) = level.as_mut() {
551                adjust(e);
552            }
553        }
554        Expression::ImageQuery {
555            ref mut image,
556            ref mut query,
557        } => {
558            adjust(image);
559            match *query {
560                crate::ImageQuery::Size { ref mut level } => {
561                    if let Some(e) = level.as_mut() {
562                        adjust(e);
563                    }
564                }
565                crate::ImageQuery::NumLevels
566                | crate::ImageQuery::NumLayers
567                | crate::ImageQuery::NumSamples => {}
568            }
569        }
570        Expression::Unary {
571            ref mut expr,
572            op: _,
573        } => {
574            adjust(expr);
575        }
576        Expression::Binary {
577            ref mut left,
578            ref mut right,
579            op: _,
580        } => {
581            adjust(left);
582            adjust(right);
583        }
584        Expression::Select {
585            ref mut condition,
586            ref mut accept,
587            ref mut reject,
588        } => {
589            adjust(condition);
590            adjust(accept);
591            adjust(reject);
592        }
593        Expression::Derivative {
594            ref mut expr,
595            axis: _,
596            ctrl: _,
597        } => {
598            adjust(expr);
599        }
600        Expression::Relational {
601            ref mut argument,
602            fun: _,
603        } => {
604            adjust(argument);
605        }
606        Expression::Math {
607            ref mut arg,
608            ref mut arg1,
609            ref mut arg2,
610            ref mut arg3,
611            fun: _,
612        } => {
613            adjust(arg);
614            if let Some(e) = arg1.as_mut() {
615                adjust(e);
616            }
617            if let Some(e) = arg2.as_mut() {
618                adjust(e);
619            }
620            if let Some(e) = arg3.as_mut() {
621                adjust(e);
622            }
623        }
624        Expression::As {
625            ref mut expr,
626            kind: _,
627            convert: _,
628        } => {
629            adjust(expr);
630        }
631        Expression::ArrayLength(ref mut expr) => {
632            adjust(expr);
633        }
634        Expression::RayQueryGetIntersection {
635            ref mut query,
636            committed: _,
637        } => {
638            adjust(query);
639        }
640        Expression::Literal(_)
641        | Expression::FunctionArgument(_)
642        | Expression::GlobalVariable(_)
643        | Expression::LocalVariable(_)
644        | Expression::CallResult(_)
645        | Expression::RayQueryProceedResult
646        | Expression::Constant(_)
647        | Expression::Override(_)
648        | Expression::ZeroValue(_)
649        | Expression::AtomicResult {
650            ty: _,
651            comparison: _,
652        }
653        | Expression::WorkGroupUniformLoadResult { ty: _ }
654        | Expression::SubgroupBallotResult
655        | Expression::SubgroupOperationResult { .. } => {}
656        Expression::RayQueryVertexPositions {
657            ref mut query,
658            committed: _,
659        } => {
660            adjust(query);
661        }
662        Expression::CooperativeLoad { ref mut data, .. } => {
663            adjust(&mut data.pointer);
664            adjust(&mut data.stride);
665        }
666        Expression::CooperativeMultiplyAdd {
667            ref mut a,
668            ref mut b,
669            ref mut c,
670        } => {
671            adjust(a);
672            adjust(b);
673            adjust(c);
674        }
675    }
676}
677
678/// Replace every expression handle in `block` with its counterpart
679/// given by `new_pos`.
680fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
681    for stmt in block.iter_mut() {
682        adjust_stmt(new_pos, stmt);
683    }
684}
685
686/// Replace every expression handle in `stmt` with its counterpart
687/// given by `new_pos`.
688fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
689    let adjust = |expr: &mut Handle<Expression>| {
690        *expr = new_pos[*expr];
691    };
692    match *stmt {
693        Statement::Emit(ref mut range) => {
694            if let Some((mut first, mut last)) = range.first_and_last() {
695                adjust(&mut first);
696                adjust(&mut last);
697                *range = Range::new_from_bounds(first, last);
698            }
699        }
700        Statement::Block(ref mut block) => {
701            adjust_block(new_pos, block);
702        }
703        Statement::If {
704            ref mut condition,
705            ref mut accept,
706            ref mut reject,
707        } => {
708            adjust(condition);
709            adjust_block(new_pos, accept);
710            adjust_block(new_pos, reject);
711        }
712        Statement::Switch {
713            ref mut selector,
714            ref mut cases,
715        } => {
716            adjust(selector);
717            for case in cases.iter_mut() {
718                adjust_block(new_pos, &mut case.body);
719            }
720        }
721        Statement::Loop {
722            ref mut body,
723            ref mut continuing,
724            ref mut break_if,
725        } => {
726            adjust_block(new_pos, body);
727            adjust_block(new_pos, continuing);
728            if let Some(e) = break_if.as_mut() {
729                adjust(e);
730            }
731        }
732        Statement::Return { ref mut value } => {
733            if let Some(e) = value.as_mut() {
734                adjust(e);
735            }
736        }
737        Statement::Store {
738            ref mut pointer,
739            ref mut value,
740        } => {
741            adjust(pointer);
742            adjust(value);
743        }
744        Statement::ImageStore {
745            ref mut image,
746            ref mut coordinate,
747            ref mut array_index,
748            ref mut value,
749        } => {
750            adjust(image);
751            adjust(coordinate);
752            if let Some(e) = array_index.as_mut() {
753                adjust(e);
754            }
755            adjust(value);
756        }
757        Statement::Atomic {
758            ref mut pointer,
759            ref mut value,
760            ref mut result,
761            ref mut fun,
762        } => {
763            adjust(pointer);
764            adjust(value);
765            if let Some(ref mut result) = *result {
766                adjust(result);
767            }
768            match *fun {
769                crate::AtomicFunction::Exchange {
770                    compare: Some(ref mut compare),
771                } => {
772                    adjust(compare);
773                }
774                crate::AtomicFunction::Add
775                | crate::AtomicFunction::Subtract
776                | crate::AtomicFunction::And
777                | crate::AtomicFunction::ExclusiveOr
778                | crate::AtomicFunction::InclusiveOr
779                | crate::AtomicFunction::Min
780                | crate::AtomicFunction::Max
781                | crate::AtomicFunction::Exchange { compare: None } => {}
782            }
783        }
784        Statement::ImageAtomic {
785            ref mut image,
786            ref mut coordinate,
787            ref mut array_index,
788            fun: _,
789            ref mut value,
790        } => {
791            adjust(image);
792            adjust(coordinate);
793            if let Some(ref mut array_index) = *array_index {
794                adjust(array_index);
795            }
796            adjust(value);
797        }
798        Statement::WorkGroupUniformLoad {
799            ref mut pointer,
800            ref mut result,
801        } => {
802            adjust(pointer);
803            adjust(result);
804        }
805        Statement::SubgroupBallot {
806            ref mut result,
807            ref mut predicate,
808        } => {
809            if let Some(ref mut predicate) = *predicate {
810                adjust(predicate);
811            }
812            adjust(result);
813        }
814        Statement::SubgroupCollectiveOperation {
815            ref mut argument,
816            ref mut result,
817            ..
818        } => {
819            adjust(argument);
820            adjust(result);
821        }
822        Statement::SubgroupGather {
823            ref mut mode,
824            ref mut argument,
825            ref mut result,
826        } => {
827            match *mode {
828                crate::GatherMode::BroadcastFirst => {}
829                crate::GatherMode::Broadcast(ref mut index)
830                | crate::GatherMode::Shuffle(ref mut index)
831                | crate::GatherMode::ShuffleDown(ref mut index)
832                | crate::GatherMode::ShuffleUp(ref mut index)
833                | crate::GatherMode::ShuffleXor(ref mut index)
834                | crate::GatherMode::QuadBroadcast(ref mut index) => {
835                    adjust(index);
836                }
837                crate::GatherMode::QuadSwap(_) => {}
838            }
839            adjust(argument);
840            adjust(result)
841        }
842        Statement::Call {
843            ref mut arguments,
844            ref mut result,
845            function: _,
846        } => {
847            for argument in arguments.iter_mut() {
848                adjust(argument);
849            }
850            if let Some(e) = result.as_mut() {
851                adjust(e);
852            }
853        }
854        Statement::RayQuery {
855            ref mut query,
856            ref mut fun,
857        } => {
858            adjust(query);
859            match *fun {
860                crate::RayQueryFunction::Initialize {
861                    ref mut acceleration_structure,
862                    ref mut descriptor,
863                } => {
864                    adjust(acceleration_structure);
865                    adjust(descriptor);
866                }
867                crate::RayQueryFunction::Proceed { ref mut result } => {
868                    adjust(result);
869                }
870                crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
871                    adjust(hit_t);
872                }
873                crate::RayQueryFunction::ConfirmIntersection => {}
874                crate::RayQueryFunction::Terminate => {}
875            }
876        }
877        Statement::CooperativeStore {
878            ref mut target,
879            ref mut data,
880        } => {
881            adjust(target);
882            adjust(&mut data.pointer);
883            adjust(&mut data.stride);
884        }
885        Statement::Break
886        | Statement::Continue
887        | Statement::Kill
888        | Statement::ControlBarrier(_)
889        | Statement::MemoryBarrier(_) => {}
890    }
891}
892
893/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced.
894///
895/// According to validation, [`Emit`] statements must not cover any expressions
896/// for which [`Expression::needs_pre_emit`] returns true. All expressions built
897/// by successful constant evaluation fall into that category, meaning that
898/// `process_function` will usually rewrite [`Override`] expressions and those
899/// that use their values into pre-emitted expressions, leaving any [`Emit`]
900/// statements that cover them invalid.
901///
902/// This function rewrites all [`Emit`] statements into zero or more new
903/// [`Emit`] statements covering only those expressions in the original range
904/// that are not pre-emitted.
905///
906/// [`Emit`]: Statement::Emit
907/// [`needs_pre_emit`]: Expression::needs_pre_emit
908/// [`Override`]: Expression::Override
909fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
910    let original = mem::replace(block, Block::with_capacity(block.len()));
911    for (stmt, span) in original.span_into_iter() {
912        match stmt {
913            Statement::Emit(range) => {
914                let mut current = None;
915                for expr_h in range {
916                    if expressions[expr_h].needs_pre_emit() {
917                        if let Some((first, last)) = current {
918                            block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
919                        }
920
921                        current = None;
922                    } else if let Some((_, ref mut last)) = current {
923                        *last = expr_h;
924                    } else {
925                        current = Some((expr_h, expr_h));
926                    }
927                }
928                if let Some((first, last)) = current {
929                    block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
930                }
931            }
932            Statement::Block(mut child) => {
933                filter_emits_in_block(&mut child, expressions);
934                block.push(Statement::Block(child), span);
935            }
936            Statement::If {
937                condition,
938                mut accept,
939                mut reject,
940            } => {
941                filter_emits_in_block(&mut accept, expressions);
942                filter_emits_in_block(&mut reject, expressions);
943                block.push(
944                    Statement::If {
945                        condition,
946                        accept,
947                        reject,
948                    },
949                    span,
950                );
951            }
952            Statement::Switch {
953                selector,
954                mut cases,
955            } => {
956                for case in &mut cases {
957                    filter_emits_in_block(&mut case.body, expressions);
958                }
959                block.push(Statement::Switch { selector, cases }, span);
960            }
961            Statement::Loop {
962                mut body,
963                mut continuing,
964                break_if,
965            } => {
966                filter_emits_in_block(&mut body, expressions);
967                filter_emits_in_block(&mut continuing, expressions);
968                block.push(
969                    Statement::Loop {
970                        body,
971                        continuing,
972                        break_if,
973                    },
974                    span,
975                );
976            }
977            stmt => block.push(stmt.clone(), span),
978        }
979    }
980}
981
982fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
983    // note that in rust 0.0 == -0.0
984    match scalar {
985        Scalar::BOOL => {
986            // https://webidl.spec.whatwg.org/#js-boolean
987            let value = value != 0.0 && !value.is_nan();
988            Ok(Literal::Bool(value))
989        }
990        Scalar::I32 => {
991            // https://webidl.spec.whatwg.org/#js-long
992            if !value.is_finite() {
993                return Err(PipelineConstantError::SrcNeedsToBeFinite);
994            }
995
996            let value = value.trunc();
997            if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
998                return Err(PipelineConstantError::DstRangeTooSmall);
999            }
1000
1001            let value = value as i32;
1002            Ok(Literal::I32(value))
1003        }
1004        Scalar::U32 => {
1005            // https://webidl.spec.whatwg.org/#js-unsigned-long
1006            if !value.is_finite() {
1007                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1008            }
1009
1010            let value = value.trunc();
1011            if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
1012                return Err(PipelineConstantError::DstRangeTooSmall);
1013            }
1014
1015            let value = value as u32;
1016            Ok(Literal::U32(value))
1017        }
1018        Scalar::F16 => {
1019            // https://webidl.spec.whatwg.org/#js-float
1020            if !value.is_finite() {
1021                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1022            }
1023
1024            let value = half::f16::from_f64(value);
1025            if !value.is_finite() {
1026                return Err(PipelineConstantError::DstRangeTooSmall);
1027            }
1028
1029            Ok(Literal::F16(value))
1030        }
1031        Scalar::F32 => {
1032            // https://webidl.spec.whatwg.org/#js-float
1033            if !value.is_finite() {
1034                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1035            }
1036
1037            let value = value as f32;
1038            if !value.is_finite() {
1039                return Err(PipelineConstantError::DstRangeTooSmall);
1040            }
1041
1042            Ok(Literal::F32(value))
1043        }
1044        Scalar::F64 => {
1045            // https://webidl.spec.whatwg.org/#js-double
1046            if !value.is_finite() {
1047                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1048            }
1049
1050            Ok(Literal::F64(value))
1051        }
1052        Scalar::ABSTRACT_FLOAT | Scalar::ABSTRACT_INT => {
1053            unreachable!("abstract values should not be validated out of override processing")
1054        }
1055        _ => unreachable!("unrecognized scalar type for override"),
1056    }
1057}
1058
/// Exercises every destination branch of `map_value_to_literal`.
///
/// Covers non-finite rejection, integer truncation and range limits, and
/// float rounding at the edge of the representable range for f16 and f32.
#[test]
fn test_map_value_to_literal() {
    // bool: only ±0 and NaN convert to `false`.
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    // Every numeric destination rejects non-finite sources.
    for scalar in [Scalar::I32, Scalar::U32, Scalar::F16, Scalar::F32, Scalar::F64] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Fractions are truncated toward zero before the range check.
    assert_eq!(map_value_to_literal(1.9, Scalar::I32), Ok(Literal::I32(1)));
    assert_eq!(map_value_to_literal(-1.9, Scalar::I32), Ok(Literal::I32(-1)));
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Fractions are truncated toward zero before the range check.
    assert_eq!(map_value_to_literal(-0.5, Scalar::U32), Ok(Literal::U32(0)));
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 0.5, Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );

    // f16
    assert_eq!(
        map_value_to_literal(half::f16::MIN.to_f64(), Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(half::f16::MAX.to_f64(), Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    // f16::MAX is 65504 and the next step up would be 65536, so 65520 is the
    // midpoint. Values just below it round down to f16::MAX.
    assert_eq!(
        map_value_to_literal(-65519.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(65519.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    // The midpoint itself rounds to even, i.e. away from f16::MAX, and overflows.
    assert_eq!(
        map_value_to_literal(-65520.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(65520.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f32
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}