naga/back/
pipeline_constants.rs

1use alloc::{
2    borrow::Cow,
3    string::{String, ToString},
4};
5use core::mem;
6
7use hashbrown::HashSet;
8use thiserror::Error;
9
10use super::PipelineConstants;
11use crate::{
12    arena::HandleVec,
13    compact::{compact, KeepUnused},
14    ir,
15    proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
16    valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
17    Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
18    Span, Statement, TypeInner, WithSpan,
19};
20
21// Possibly unused if not compiled with no_std
22#[allow(unused_imports)]
23use num_traits::float::FloatCore as _;
24
25#[derive(Error, Debug, Clone)]
26#[cfg_attr(test, derive(PartialEq))]
27pub enum PipelineConstantError {
28    #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
29    MissingValue(String),
30    #[error(
31        "Source f64 value needs to be finite ({}) for number destinations",
32        "NaNs and Inifinites are not allowed"
33    )]
34    SrcNeedsToBeFinite,
35    #[error("Source f64 value doesn't fit in destination")]
36    DstRangeTooSmall,
37    #[error(transparent)]
38    ConstantEvaluatorError(#[from] ConstantEvaluatorError),
39    #[error(transparent)]
40    ValidationError(#[from] WithSpan<ValidationError>),
41    #[error("workgroup_size override isn't strictly positive")]
42    NegativeWorkgroupSize,
43    #[error("max vertices or max primitives is negative")]
44    NegativeMeshOutputMax,
45}
46
47/// Compact `module` and replace all overrides with constants.
48///
49/// `module` must be valid. Both compaction and constant evaluation may produce
50/// invalid results (e.g. replace an invalid expression with a constant) for
51/// invalid modules.
52///
53/// If no changes are needed, this just returns `Cow::Borrowed` references to
54/// `module` and `module_info`. Otherwise, it clones `module`, retains only the
55/// selected entry point, compacts the module, edits its [`global_expressions`]
56/// arena to contain only fully-evaluated expressions, and returns the
57/// simplified module and its validation results.
58///
59/// The module returned has an empty `overrides` arena, and the
60/// `global_expressions` arena contains only fully-evaluated expressions.
61///
62/// [`global_expressions`]: Module::global_expressions
63pub fn process_overrides<'a>(
64    module: &'a Module,
65    module_info: &'a ModuleInfo,
66    entry_point: Option<(ir::ShaderStage, &str)>,
67    pipeline_constants: &PipelineConstants,
68) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
69    if (entry_point.is_none() || module.entry_points.len() <= 1) && module.overrides.is_empty() {
70        // We skip compacting the module here mostly to reduce the risk of
71        // hitting corner cases like https://github.com/gfx-rs/wgpu/issues/7793.
72        // Compaction doesn't cost very much [1], so it would also be reasonable
73        // to do it unconditionally. Even when there is a single entry point or
74        // when no entry point is specified, it is still possible that there
75        // are unreferenced items in the module that would be removed by this
76        // compaction.
77        //
78        // [1]: https://github.com/gfx-rs/wgpu/pull/7703#issuecomment-2902153760
79        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
80    }
81
82    let mut module = module.clone();
83    if let Some((ep_stage, ep_name)) = entry_point {
84        module
85            .entry_points
86            .retain(|ep| ep.stage == ep_stage && ep.name == ep_name);
87    }
88
89    // Compact the module to remove anything not reachable from an entry point.
90    // This is necessary because we may not have values for overrides that are
91    // not reachable from the/an entry point.
92    compact(&mut module, KeepUnused::No);
93
94    // If there are no overrides in the module, then we can skip the rest.
95    if module.overrides.is_empty() {
96        return revalidate(module);
97    }
98
99    // A map from override handles to the handles of the constants
100    // we've replaced them with.
101    let mut override_map = HandleVec::with_capacity(module.overrides.len());
102
103    // A map from `module`'s original global expression handles to
104    // handles in the new, simplified global expression arena.
105    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());
106
107    // The set of constants whose initializer handles we've already
108    // updated to refer to the newly built global expression arena.
109    //
110    // All constants in `module` must have their `init` handles
111    // updated to point into the new, simplified global expression
112    // arena. Some of these we can most easily handle as a side effect
113    // during the simplification process, but we must handle the rest
114    // in a final fixup pass, guided by `adjusted_global_expressions`. We
115    // add their handles to this set, so that the final fixup step can
116    // leave them alone.
117    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());
118
119    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
120    let mut layouter = crate::proc::Layouter::default();
121
122    // An iterator through the original overrides table, consumed in
123    // approximate tandem with the global expressions.
124    let mut overrides = mem::take(&mut module.overrides);
125    let mut override_iter = overrides.iter_mut_span();
126
127    // Do two things in tandem:
128    //
129    // - Rebuild the global expression arena from scratch, fully
130    //   evaluating all expressions, and replacing each `Override`
131    //   expression in `module.global_expressions` with a `Constant`
132    //   expression.
133    //
134    // - Build a new `Constant` in `module.constants` to take the
135    //   place of each `Override`.
136    //
137    // Build a map from old global expression handles to their
138    // fully-evaluated counterparts in `adjusted_global_expressions` as we
139    // go.
140    //
141    // Why in tandem? Overrides refer to expressions, and expressions
142    // refer to overrides, so we can't disentangle the two into
143    // separate phases. However, we can take advantage of the fact
144    // that the overrides and expressions must form a DAG, and work
145    // our way from the leaves to the roots, replacing and evaluating
146    // as we go.
147    //
148    // Although the two loops are nested, this is really two
149    // alternating phases: we adjust and evaluate constant expressions
150    // until we hit an `Override` expression, at which point we switch
151    // to building `Constant`s for `Overrides` until we've handled the
152    // one used by the expression. Then we switch back to processing
153    // expressions. Because we know they form a DAG, we know the
154    // `Override` expressions we encounter can only have initializers
155    // referring to global expressions we've already simplified.
156    for (old_h, expr, span) in module.global_expressions.drain() {
157        let mut expr = match expr {
158            Expression::Override(h) => {
159                let c_h = if let Some(new_h) = override_map.get(h) {
160                    *new_h
161                } else {
162                    let mut new_h = None;
163                    for entry in override_iter.by_ref() {
164                        let stop = entry.0 == h;
165                        new_h = Some(process_override(
166                            entry,
167                            pipeline_constants,
168                            &mut module,
169                            &mut override_map,
170                            &adjusted_global_expressions,
171                            &mut adjusted_constant_initializers,
172                            &mut global_expression_kind_tracker,
173                        )?);
174                        if stop {
175                            break;
176                        }
177                    }
178                    new_h.unwrap()
179                };
180                Expression::Constant(c_h)
181            }
182            Expression::Constant(c_h) => {
183                if adjusted_constant_initializers.insert(c_h) {
184                    let init = &mut module.constants[c_h].init;
185                    *init = adjusted_global_expressions[*init];
186                }
187                expr
188            }
189            expr => expr,
190        };
191        let mut evaluator = ConstantEvaluator::for_wgsl_module(
192            &mut module,
193            &mut global_expression_kind_tracker,
194            &mut layouter,
195            false,
196        );
197        adjust_expr(&adjusted_global_expressions, &mut expr);
198        let h = evaluator.try_eval_and_append(expr, span)?;
199        adjusted_global_expressions.insert(old_h, h);
200    }
201
202    // Finish processing any overrides we didn't visit in the loop above.
203    for entry in override_iter {
204        match *entry.1 {
205            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
206                process_override(
207                    entry,
208                    pipeline_constants,
209                    &mut module,
210                    &mut override_map,
211                    &adjusted_global_expressions,
212                    &mut adjusted_constant_initializers,
213                    &mut global_expression_kind_tracker,
214                )?;
215            }
216            Override {
217                init: Some(ref mut init),
218                ..
219            } => {
220                *init = adjusted_global_expressions[*init];
221            }
222            _ => {}
223        }
224    }
225
226    // Update the initialization expression handles of all `Constant`s
227    // and `GlobalVariable`s. Skip `Constant`s we'd already updated en
228    // passant.
229    for (_, c) in module
230        .constants
231        .iter_mut()
232        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
233    {
234        c.init = adjusted_global_expressions[c.init];
235    }
236
237    for (_, v) in module.global_variables.iter_mut() {
238        if let Some(ref mut init) = v.init {
239            *init = adjusted_global_expressions[*init];
240        }
241    }
242
243    let mut functions = mem::take(&mut module.functions);
244    for (_, function) in functions.iter_mut() {
245        process_function(&mut module, &override_map, &mut layouter, function)?;
246    }
247    module.functions = functions;
248
249    let mut entry_points = mem::take(&mut module.entry_points);
250    for ep in entry_points.iter_mut() {
251        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
252        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
253        process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?;
254    }
255    module.entry_points = entry_points;
256    module.overrides = overrides;
257
258    // Now that we've rewritten all the expressions, we need to
259    // recompute their types and other metadata. For the time being,
260    // do a full re-validation.
261    revalidate(module)
262}
263
264fn revalidate(
265    module: Module,
266) -> Result<(Cow<'static, Module>, Cow<'static, ModuleInfo>), PipelineConstantError> {
267    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
268    let module_info = validator.validate_resolved_overrides(&module)?;
269    Ok((Cow::Owned(module), Cow::Owned(module_info)))
270}
271
272fn process_workgroup_size_override(
273    module: &mut Module,
274    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
275    ep: &mut crate::EntryPoint,
276) -> Result<(), PipelineConstantError> {
277    match ep.workgroup_size_overrides {
278        None => {}
279        Some(overrides) => {
280            overrides.iter().enumerate().try_for_each(
281                |(i, overridden)| -> Result<(), PipelineConstantError> {
282                    match *overridden {
283                        None => Ok(()),
284                        Some(h) => {
285                            ep.workgroup_size[i] = module
286                                .to_ctx()
287                                .get_const_val(adjusted_global_expressions[h])
288                                .map(|n| {
289                                    if n == 0 {
290                                        Err(PipelineConstantError::NegativeWorkgroupSize)
291                                    } else {
292                                        Ok(n)
293                                    }
294                                })
295                                .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
296                            Ok(())
297                        }
298                    }
299                },
300            )?;
301            ep.workgroup_size_overrides = None;
302        }
303    }
304    Ok(())
305}
306
307fn process_mesh_shader_overrides(
308    module: &mut Module,
309    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
310    ep: &mut crate::EntryPoint,
311) -> Result<(), PipelineConstantError> {
312    if let Some(ref mut mesh_info) = ep.mesh_info {
313        if let Some(r#override) = mesh_info.max_vertices_override {
314            mesh_info.max_vertices = module
315                .to_ctx()
316                .get_const_val(adjusted_global_expressions[r#override])
317                .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
318        }
319        if let Some(r#override) = mesh_info.max_primitives_override {
320            mesh_info.max_primitives = module
321                .to_ctx()
322                .get_const_val(adjusted_global_expressions[r#override])
323                .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
324        }
325    }
326    Ok(())
327}
328
329/// Add a [`Constant`] to `module` for the override `old_h`.
330///
331/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
332fn process_override(
333    (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
334    pipeline_constants: &PipelineConstants,
335    module: &mut Module,
336    override_map: &mut HandleVec<Override, Handle<Constant>>,
337    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
338    adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
339    global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
340) -> Result<Handle<Constant>, PipelineConstantError> {
341    // Determine which key to use for `r#override` in `pipeline_constants`.
342    let key = if let Some(id) = r#override.id {
343        Cow::Owned(id.to_string())
344    } else if let Some(ref name) = r#override.name {
345        Cow::Borrowed(name)
346    } else {
347        unreachable!();
348    };
349
350    // Generate a global expression for `r#override`'s value, either
351    // from the provided `pipeline_constants` table or its initializer
352    // in the module.
353    let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
354        let literal = match module.types[r#override.ty].inner {
355            TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
356            _ => unreachable!(),
357        };
358        let expr = module
359            .global_expressions
360            .append(Expression::Literal(literal), Span::UNDEFINED);
361        global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
362        expr
363    } else if let Some(init) = r#override.init {
364        adjusted_global_expressions[init]
365    } else {
366        return Err(PipelineConstantError::MissingValue(key.to_string()));
367    };
368
369    // Generate a new `Constant` to represent the override's value.
370    let constant = Constant {
371        name: r#override.name.clone(),
372        ty: r#override.ty,
373        init,
374    };
375    let h = module.constants.append(constant, *span);
376    override_map.insert(old_h, h);
377    adjusted_constant_initializers.insert(h);
378    r#override.init = Some(init);
379    Ok(h)
380}
381
382/// Replace all override expressions in `function` with fully-evaluated constants.
383///
384/// Replace all `Expression::Override`s in `function`'s expression arena with
385/// the corresponding `Expression::Constant`s, as given in `override_map`.
386/// Replace any expressions whose values are now known with their fully
387/// evaluated form.
388///
389/// If `h` is a `Handle<Override>`, then `override_map[h]` is the
390/// `Handle<Constant>` for the override's final value.
391fn process_function(
392    module: &mut Module,
393    override_map: &HandleVec<Override, Handle<Constant>>,
394    layouter: &mut crate::proc::Layouter,
395    function: &mut Function,
396) -> Result<(), ConstantEvaluatorError> {
397    // A map from original local expression handles to
398    // handles in the new, local expression arena.
399    let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
400
401    let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
402
403    let mut expressions = mem::take(&mut function.expressions);
404
405    // Dummy `emitter` and `block` for the constant evaluator.
406    // We can ignore the concept of emitting expressions here since
407    // expressions have already been covered by a `Statement::Emit`
408    // in the frontend.
409    // The only thing we might have to do is remove some expressions
410    // that have been covered by a `Statement::Emit`. See the docs of
411    // `filter_emits_in_block` for the reasoning.
412    let mut emitter = Emitter::default();
413    let mut block = Block::new();
414
415    let mut evaluator = ConstantEvaluator::for_wgsl_function(
416        module,
417        &mut function.expressions,
418        &mut local_expression_kind_tracker,
419        layouter,
420        &mut emitter,
421        &mut block,
422        false,
423    );
424
425    for (old_h, mut expr, span) in expressions.drain() {
426        if let Expression::Override(h) = expr {
427            expr = Expression::Constant(override_map[h]);
428        }
429        adjust_expr(&adjusted_local_expressions, &mut expr);
430        let h = evaluator.try_eval_and_append(expr, span)?;
431        adjusted_local_expressions.insert(old_h, h);
432    }
433
434    adjust_block(&adjusted_local_expressions, &mut function.body);
435
436    filter_emits_in_block(&mut function.body, &function.expressions);
437
438    // Update local expression initializers.
439    for (_, local) in function.local_variables.iter_mut() {
440        if let &mut Some(ref mut init) = &mut local.init {
441            *init = adjusted_local_expressions[*init];
442        }
443    }
444
445    // We've changed the keys of `function.named_expression`, so we have to
446    // rebuild it from scratch.
447    let named_expressions = mem::take(&mut function.named_expressions);
448    for (expr_h, name) in named_expressions {
449        function
450            .named_expressions
451            .insert(adjusted_local_expressions[expr_h], name);
452    }
453
454    Ok(())
455}
456
457/// Replace every expression handle in `expr` with its counterpart
458/// given by `new_pos`.
459fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
460    let adjust = |expr: &mut Handle<Expression>| {
461        *expr = new_pos[*expr];
462    };
463    match *expr {
464        Expression::Compose {
465            ref mut components,
466            ty: _,
467        } => {
468            for c in components.iter_mut() {
469                adjust(c);
470            }
471        }
472        Expression::Access {
473            ref mut base,
474            ref mut index,
475        } => {
476            adjust(base);
477            adjust(index);
478        }
479        Expression::AccessIndex {
480            ref mut base,
481            index: _,
482        } => {
483            adjust(base);
484        }
485        Expression::Splat {
486            ref mut value,
487            size: _,
488        } => {
489            adjust(value);
490        }
491        Expression::Swizzle {
492            ref mut vector,
493            size: _,
494            pattern: _,
495        } => {
496            adjust(vector);
497        }
498        Expression::Load { ref mut pointer } => {
499            adjust(pointer);
500        }
501        Expression::ImageSample {
502            ref mut image,
503            ref mut sampler,
504            ref mut coordinate,
505            ref mut array_index,
506            ref mut offset,
507            ref mut level,
508            ref mut depth_ref,
509            gather: _,
510            clamp_to_edge: _,
511        } => {
512            adjust(image);
513            adjust(sampler);
514            adjust(coordinate);
515            if let Some(e) = array_index.as_mut() {
516                adjust(e);
517            }
518            if let Some(e) = offset.as_mut() {
519                adjust(e);
520            }
521            match *level {
522                crate::SampleLevel::Exact(ref mut expr)
523                | crate::SampleLevel::Bias(ref mut expr) => {
524                    adjust(expr);
525                }
526                crate::SampleLevel::Gradient {
527                    ref mut x,
528                    ref mut y,
529                } => {
530                    adjust(x);
531                    adjust(y);
532                }
533                _ => {}
534            }
535            if let Some(e) = depth_ref.as_mut() {
536                adjust(e);
537            }
538        }
539        Expression::ImageLoad {
540            ref mut image,
541            ref mut coordinate,
542            ref mut array_index,
543            ref mut sample,
544            ref mut level,
545        } => {
546            adjust(image);
547            adjust(coordinate);
548            if let Some(e) = array_index.as_mut() {
549                adjust(e);
550            }
551            if let Some(e) = sample.as_mut() {
552                adjust(e);
553            }
554            if let Some(e) = level.as_mut() {
555                adjust(e);
556            }
557        }
558        Expression::ImageQuery {
559            ref mut image,
560            ref mut query,
561        } => {
562            adjust(image);
563            match *query {
564                crate::ImageQuery::Size { ref mut level } => {
565                    if let Some(e) = level.as_mut() {
566                        adjust(e);
567                    }
568                }
569                crate::ImageQuery::NumLevels
570                | crate::ImageQuery::NumLayers
571                | crate::ImageQuery::NumSamples => {}
572            }
573        }
574        Expression::Unary {
575            ref mut expr,
576            op: _,
577        } => {
578            adjust(expr);
579        }
580        Expression::Binary {
581            ref mut left,
582            ref mut right,
583            op: _,
584        } => {
585            adjust(left);
586            adjust(right);
587        }
588        Expression::Select {
589            ref mut condition,
590            ref mut accept,
591            ref mut reject,
592        } => {
593            adjust(condition);
594            adjust(accept);
595            adjust(reject);
596        }
597        Expression::Derivative {
598            ref mut expr,
599            axis: _,
600            ctrl: _,
601        } => {
602            adjust(expr);
603        }
604        Expression::Relational {
605            ref mut argument,
606            fun: _,
607        } => {
608            adjust(argument);
609        }
610        Expression::Math {
611            ref mut arg,
612            ref mut arg1,
613            ref mut arg2,
614            ref mut arg3,
615            fun: _,
616        } => {
617            adjust(arg);
618            if let Some(e) = arg1.as_mut() {
619                adjust(e);
620            }
621            if let Some(e) = arg2.as_mut() {
622                adjust(e);
623            }
624            if let Some(e) = arg3.as_mut() {
625                adjust(e);
626            }
627        }
628        Expression::As {
629            ref mut expr,
630            kind: _,
631            convert: _,
632        } => {
633            adjust(expr);
634        }
635        Expression::ArrayLength(ref mut expr) => {
636            adjust(expr);
637        }
638        Expression::RayQueryGetIntersection {
639            ref mut query,
640            committed: _,
641        } => {
642            adjust(query);
643        }
644        Expression::Literal(_)
645        | Expression::FunctionArgument(_)
646        | Expression::GlobalVariable(_)
647        | Expression::LocalVariable(_)
648        | Expression::CallResult(_)
649        | Expression::RayQueryProceedResult
650        | Expression::Constant(_)
651        | Expression::Override(_)
652        | Expression::ZeroValue(_)
653        | Expression::AtomicResult {
654            ty: _,
655            comparison: _,
656        }
657        | Expression::WorkGroupUniformLoadResult { ty: _ }
658        | Expression::SubgroupBallotResult
659        | Expression::SubgroupOperationResult { .. } => {}
660        Expression::RayQueryVertexPositions {
661            ref mut query,
662            committed: _,
663        } => {
664            adjust(query);
665        }
666        Expression::CooperativeLoad { ref mut data, .. } => {
667            adjust(&mut data.pointer);
668            adjust(&mut data.stride);
669        }
670        Expression::CooperativeMultiplyAdd {
671            ref mut a,
672            ref mut b,
673            ref mut c,
674        } => {
675            adjust(a);
676            adjust(b);
677            adjust(c);
678        }
679    }
680}
681
682/// Replace every expression handle in `block` with its counterpart
683/// given by `new_pos`.
684fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
685    for stmt in block.iter_mut() {
686        adjust_stmt(new_pos, stmt);
687    }
688}
689
690/// Replace every expression handle in `stmt` with its counterpart
691/// given by `new_pos`.
692fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
693    let adjust = |expr: &mut Handle<Expression>| {
694        *expr = new_pos[*expr];
695    };
696    match *stmt {
697        Statement::Emit(ref mut range) => {
698            if let Some((mut first, mut last)) = range.first_and_last() {
699                adjust(&mut first);
700                adjust(&mut last);
701                *range = Range::new_from_bounds(first, last);
702            }
703        }
704        Statement::Block(ref mut block) => {
705            adjust_block(new_pos, block);
706        }
707        Statement::If {
708            ref mut condition,
709            ref mut accept,
710            ref mut reject,
711        } => {
712            adjust(condition);
713            adjust_block(new_pos, accept);
714            adjust_block(new_pos, reject);
715        }
716        Statement::Switch {
717            ref mut selector,
718            ref mut cases,
719        } => {
720            adjust(selector);
721            for case in cases.iter_mut() {
722                adjust_block(new_pos, &mut case.body);
723            }
724        }
725        Statement::Loop {
726            ref mut body,
727            ref mut continuing,
728            ref mut break_if,
729        } => {
730            adjust_block(new_pos, body);
731            adjust_block(new_pos, continuing);
732            if let Some(e) = break_if.as_mut() {
733                adjust(e);
734            }
735        }
736        Statement::Return { ref mut value } => {
737            if let Some(e) = value.as_mut() {
738                adjust(e);
739            }
740        }
741        Statement::Store {
742            ref mut pointer,
743            ref mut value,
744        } => {
745            adjust(pointer);
746            adjust(value);
747        }
748        Statement::ImageStore {
749            ref mut image,
750            ref mut coordinate,
751            ref mut array_index,
752            ref mut value,
753        } => {
754            adjust(image);
755            adjust(coordinate);
756            if let Some(e) = array_index.as_mut() {
757                adjust(e);
758            }
759            adjust(value);
760        }
761        Statement::Atomic {
762            ref mut pointer,
763            ref mut value,
764            ref mut result,
765            ref mut fun,
766        } => {
767            adjust(pointer);
768            adjust(value);
769            if let Some(ref mut result) = *result {
770                adjust(result);
771            }
772            match *fun {
773                crate::AtomicFunction::Exchange {
774                    compare: Some(ref mut compare),
775                } => {
776                    adjust(compare);
777                }
778                crate::AtomicFunction::Add
779                | crate::AtomicFunction::Subtract
780                | crate::AtomicFunction::And
781                | crate::AtomicFunction::ExclusiveOr
782                | crate::AtomicFunction::InclusiveOr
783                | crate::AtomicFunction::Min
784                | crate::AtomicFunction::Max
785                | crate::AtomicFunction::Exchange { compare: None } => {}
786            }
787        }
788        Statement::ImageAtomic {
789            ref mut image,
790            ref mut coordinate,
791            ref mut array_index,
792            fun: _,
793            ref mut value,
794        } => {
795            adjust(image);
796            adjust(coordinate);
797            if let Some(ref mut array_index) = *array_index {
798                adjust(array_index);
799            }
800            adjust(value);
801        }
802        Statement::WorkGroupUniformLoad {
803            ref mut pointer,
804            ref mut result,
805        } => {
806            adjust(pointer);
807            adjust(result);
808        }
809        Statement::SubgroupBallot {
810            ref mut result,
811            ref mut predicate,
812        } => {
813            if let Some(ref mut predicate) = *predicate {
814                adjust(predicate);
815            }
816            adjust(result);
817        }
818        Statement::SubgroupCollectiveOperation {
819            ref mut argument,
820            ref mut result,
821            ..
822        } => {
823            adjust(argument);
824            adjust(result);
825        }
826        Statement::SubgroupGather {
827            ref mut mode,
828            ref mut argument,
829            ref mut result,
830        } => {
831            match *mode {
832                crate::GatherMode::BroadcastFirst => {}
833                crate::GatherMode::Broadcast(ref mut index)
834                | crate::GatherMode::Shuffle(ref mut index)
835                | crate::GatherMode::ShuffleDown(ref mut index)
836                | crate::GatherMode::ShuffleUp(ref mut index)
837                | crate::GatherMode::ShuffleXor(ref mut index)
838                | crate::GatherMode::QuadBroadcast(ref mut index) => {
839                    adjust(index);
840                }
841                crate::GatherMode::QuadSwap(_) => {}
842            }
843            adjust(argument);
844            adjust(result)
845        }
846        Statement::Call {
847            ref mut arguments,
848            ref mut result,
849            function: _,
850        } => {
851            for argument in arguments.iter_mut() {
852                adjust(argument);
853            }
854            if let Some(e) = result.as_mut() {
855                adjust(e);
856            }
857        }
858        Statement::RayQuery {
859            ref mut query,
860            ref mut fun,
861        } => {
862            adjust(query);
863            match *fun {
864                crate::RayQueryFunction::Initialize {
865                    ref mut acceleration_structure,
866                    ref mut descriptor,
867                } => {
868                    adjust(acceleration_structure);
869                    adjust(descriptor);
870                }
871                crate::RayQueryFunction::Proceed { ref mut result } => {
872                    adjust(result);
873                }
874                crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
875                    adjust(hit_t);
876                }
877                crate::RayQueryFunction::ConfirmIntersection => {}
878                crate::RayQueryFunction::Terminate => {}
879            }
880        }
881        Statement::CooperativeStore {
882            ref mut target,
883            ref mut data,
884        } => {
885            adjust(target);
886            adjust(&mut data.pointer);
887            adjust(&mut data.stride);
888        }
889        Statement::RayPipelineFunction(ref mut func) => match *func {
890            crate::RayPipelineFunction::TraceRay {
891                ref mut acceleration_structure,
892                ref mut descriptor,
893                ref mut payload,
894            } => {
895                adjust(acceleration_structure);
896                adjust(descriptor);
897                adjust(payload);
898            }
899        },
900        Statement::Break
901        | Statement::Continue
902        | Statement::Kill
903        | Statement::ControlBarrier(_)
904        | Statement::MemoryBarrier(_) => {}
905    }
906}
907
908/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced.
909///
910/// According to validation, [`Emit`] statements must not cover any expressions
911/// for which [`Expression::needs_pre_emit`] returns true. All expressions built
912/// by successful constant evaluation fall into that category, meaning that
913/// `process_function` will usually rewrite [`Override`] expressions and those
914/// that use their values into pre-emitted expressions, leaving any [`Emit`]
915/// statements that cover them invalid.
916///
917/// This function rewrites all [`Emit`] statements into zero or more new
918/// [`Emit`] statements covering only those expressions in the original range
919/// that are not pre-emitted.
920///
921/// [`Emit`]: Statement::Emit
922/// [`needs_pre_emit`]: Expression::needs_pre_emit
923/// [`Override`]: Expression::Override
924fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
925    let original = mem::replace(block, Block::with_capacity(block.len()));
926    for (stmt, span) in original.span_into_iter() {
927        match stmt {
928            Statement::Emit(range) => {
929                let mut current = None;
930                for expr_h in range {
931                    if expressions[expr_h].needs_pre_emit() {
932                        if let Some((first, last)) = current {
933                            block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
934                        }
935
936                        current = None;
937                    } else if let Some((_, ref mut last)) = current {
938                        *last = expr_h;
939                    } else {
940                        current = Some((expr_h, expr_h));
941                    }
942                }
943                if let Some((first, last)) = current {
944                    block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
945                }
946            }
947            Statement::Block(mut child) => {
948                filter_emits_in_block(&mut child, expressions);
949                block.push(Statement::Block(child), span);
950            }
951            Statement::If {
952                condition,
953                mut accept,
954                mut reject,
955            } => {
956                filter_emits_in_block(&mut accept, expressions);
957                filter_emits_in_block(&mut reject, expressions);
958                block.push(
959                    Statement::If {
960                        condition,
961                        accept,
962                        reject,
963                    },
964                    span,
965                );
966            }
967            Statement::Switch {
968                selector,
969                mut cases,
970            } => {
971                for case in &mut cases {
972                    filter_emits_in_block(&mut case.body, expressions);
973                }
974                block.push(Statement::Switch { selector, cases }, span);
975            }
976            Statement::Loop {
977                mut body,
978                mut continuing,
979                break_if,
980            } => {
981                filter_emits_in_block(&mut body, expressions);
982                filter_emits_in_block(&mut continuing, expressions);
983                block.push(
984                    Statement::Loop {
985                        body,
986                        continuing,
987                        break_if,
988                    },
989                    span,
990                );
991            }
992            stmt => block.push(stmt.clone(), span),
993        }
994    }
995}
996
997fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
998    // note that in rust 0.0 == -0.0
999    match scalar {
1000        Scalar::BOOL => {
1001            // https://webidl.spec.whatwg.org/#js-boolean
1002            let value = value != 0.0 && !value.is_nan();
1003            Ok(Literal::Bool(value))
1004        }
1005        Scalar::I32 => {
1006            // https://webidl.spec.whatwg.org/#js-long
1007            if !value.is_finite() {
1008                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1009            }
1010
1011            let value = value.trunc();
1012            if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
1013                return Err(PipelineConstantError::DstRangeTooSmall);
1014            }
1015
1016            let value = value as i32;
1017            Ok(Literal::I32(value))
1018        }
1019        Scalar::U32 => {
1020            // https://webidl.spec.whatwg.org/#js-unsigned-long
1021            if !value.is_finite() {
1022                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1023            }
1024
1025            let value = value.trunc();
1026            if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
1027                return Err(PipelineConstantError::DstRangeTooSmall);
1028            }
1029
1030            let value = value as u32;
1031            Ok(Literal::U32(value))
1032        }
1033        Scalar::F16 => {
1034            // https://webidl.spec.whatwg.org/#js-float
1035            if !value.is_finite() {
1036                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1037            }
1038
1039            let value = half::f16::from_f64(value);
1040            if !value.is_finite() {
1041                return Err(PipelineConstantError::DstRangeTooSmall);
1042            }
1043
1044            Ok(Literal::F16(value))
1045        }
1046        Scalar::F32 => {
1047            // https://webidl.spec.whatwg.org/#js-float
1048            if !value.is_finite() {
1049                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1050            }
1051
1052            let value = value as f32;
1053            if !value.is_finite() {
1054                return Err(PipelineConstantError::DstRangeTooSmall);
1055            }
1056
1057            Ok(Literal::F32(value))
1058        }
1059        Scalar::F64 => {
1060            // https://webidl.spec.whatwg.org/#js-double
1061            if !value.is_finite() {
1062                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1063            }
1064
1065            Ok(Literal::F64(value))
1066        }
1067        Scalar::ABSTRACT_FLOAT | Scalar::ABSTRACT_INT => {
1068            unreachable!("abstract values should not be validated out of override processing")
1069        }
1070        _ => unreachable!("unrecognized scalar type for override"),
1071    }
1072}
1073
/// Checks `map_value_to_literal` against the WebIDL conversion rules for every
/// concrete scalar type: bool truthiness, rejection of non-finite inputs, and
/// the exact range boundaries (including rounding at the float limits).
#[test]
fn test_map_value_to_literal() {
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    for scalar in [
        Scalar::I32,
        Scalar::U32,
        Scalar::F16,
        Scalar::F32,
        Scalar::F64,
    ] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f16
    assert_eq!(
        map_value_to_literal(half::f16::MIN.to_f64(), Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(half::f16::MAX.to_f64(), Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    // Below the halfway point to 2^16, values still round to MIN/MAX.
    assert_eq!(
        map_value_to_literal(-65519.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(65519.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    // The exact halfway point rounds (ties-to-even) to infinity, which doesn't fit.
    assert_eq!(
        map_value_to_literal(-65520.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(65520.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f32
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}