naga/back/
pipeline_constants.rs

1use alloc::{
2    borrow::Cow,
3    string::{String, ToString},
4};
5use core::mem;
6
7use hashbrown::HashSet;
8use thiserror::Error;
9
10use super::PipelineConstants;
11use crate::{
12    arena::HandleVec,
13    compact::{compact, KeepUnused},
14    ir,
15    proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
16    valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
17    Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
18    Span, Statement, TypeInner, WithSpan,
19};
20
21#[cfg(no_std)]
22use num_traits::float::FloatCore as _;
23
24#[derive(Error, Debug, Clone)]
25#[cfg_attr(test, derive(PartialEq))]
26pub enum PipelineConstantError {
27    #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
28    MissingValue(String),
29    #[error(
30        "Source f64 value needs to be finite ({}) for number destinations",
31        "NaNs and Inifinites are not allowed"
32    )]
33    SrcNeedsToBeFinite,
34    #[error("Source f64 value doesn't fit in destination")]
35    DstRangeTooSmall,
36    #[error(transparent)]
37    ConstantEvaluatorError(#[from] ConstantEvaluatorError),
38    #[error(transparent)]
39    ValidationError(#[from] WithSpan<ValidationError>),
40    #[error("workgroup_size override isn't strictly positive")]
41    NegativeWorkgroupSize,
42    #[error("max vertices or max primitives is negative")]
43    NegativeMeshOutputMax,
44}
45
/// Compact `module` and replace all overrides with constants.
///
/// If no changes are needed, this just returns `Cow::Borrowed` references to
/// `module` and `module_info`. Otherwise, it clones `module`, retains only the
/// selected entry point, compacts the module, edits its [`global_expressions`]
/// arena to contain only fully-evaluated expressions, and returns the
/// simplified module and its validation results.
///
/// In the returned (owned) module:
///
/// - every `Expression::Override`, in the global expression arena and in
///   function and entry point bodies, has been replaced by an
///   `Expression::Constant` referring to a newly added [`Constant`], and
///   then run through the constant evaluator;
/// - the `global_expressions` arena contains only fully-evaluated
///   expressions;
/// - the `overrides` arena is still present (it is restored after
///   processing). Each override's `init`, where present, now refers into the
///   rebuilt global expression arena. Overrides that were given a value
///   through `pipeline_constants` have `init` set to that value;
/// - entry point workgroup sizes and mesh output limits that were given by
///   overrides have been evaluated.
///
/// # Errors
///
/// Returns [`PipelineConstantError::MissingValue`] if an override has
/// neither a value in `pipeline_constants` nor an initializer. Returns any
/// error produced by `map_value_to_literal` when converting a supplied value,
/// and [`PipelineConstantError::ConstantEvaluatorError`] if constant
/// evaluation fails. Returns
/// [`PipelineConstantError::NegativeWorkgroupSize`] /
/// [`PipelineConstantError::NegativeMeshOutputMax`] for invalid overridden
/// entry point sizes, and [`PipelineConstantError::ValidationError`] if the
/// rewritten module fails re-validation.
///
/// [`global_expressions`]: Module::global_expressions
pub fn process_overrides<'a>(
    module: &'a Module,
    module_info: &'a ModuleInfo,
    entry_point: Option<(ir::ShaderStage, &str)>,
    pipeline_constants: &PipelineConstants,
) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
    if (entry_point.is_none() || module.entry_points.len() <= 1) && module.overrides.is_empty() {
        // We skip compacting the module here mostly to reduce the risk of
        // hitting corner cases like https://github.com/gfx-rs/wgpu/issues/7793.
        // Compaction doesn't cost very much [1], so it would also be reasonable
        // to do it unconditionally. Even when there is a single entry point or
        // when no entry point is specified, it is still possible that there
        // are unreferenced items in the module that would be removed by this
        // compaction.
        //
        // [1]: https://github.com/gfx-rs/wgpu/pull/7703#issuecomment-2902153760
        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
    }

    let mut module = module.clone();
    if let Some((ep_stage, ep_name)) = entry_point {
        module
            .entry_points
            .retain(|ep| ep.stage == ep_stage && ep.name == ep_name);
    }

    // Compact the module to remove anything not reachable from an entry point.
    // This is necessary because we may not have values for overrides that are
    // not reachable from the/an entry point.
    compact(&mut module, KeepUnused::No);

    // If there are no overrides in the module, then we can skip the rest.
    if module.overrides.is_empty() {
        return revalidate(module);
    }

    // A map from override handles to the handles of the constants
    // we've replaced them with.
    let mut override_map = HandleVec::with_capacity(module.overrides.len());

    // A map from `module`'s original global expression handles to
    // handles in the new, simplified global expression arena.
    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());

    // The set of constants whose initializer handles we've already
    // updated to refer to the newly built global expression arena.
    //
    // All constants in `module` must have their `init` handles
    // updated to point into the new, simplified global expression
    // arena. Some of these we can most easily handle as a side effect
    // during the simplification process, but we must handle the rest
    // in a final fixup pass, guided by `adjusted_global_expressions`. We
    // add their handles to this set, so that the final fixup step can
    // leave them alone.
    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());

    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
    let mut layouter = crate::proc::Layouter::default();

    // An iterator through the original overrides table, consumed in
    // approximate tandem with the global expressions.
    //
    // The arena is taken out of `module` so that individual overrides can be
    // mutated while `module` itself is passed mutably to `process_override`;
    // it is put back just before re-validation.
    let mut overrides = mem::take(&mut module.overrides);
    let mut override_iter = overrides.iter_mut_span();

    // Do two things in tandem:
    //
    // - Rebuild the global expression arena from scratch, fully
    //   evaluating all expressions, and replacing each `Override`
    //   expression in `module.global_expressions` with a `Constant`
    //   expression.
    //
    // - Build a new `Constant` in `module.constants` to take the
    //   place of each `Override`.
    //
    // Build a map from old global expression handles to their
    // fully-evaluated counterparts in `adjusted_global_expressions` as we
    // go.
    //
    // Why in tandem? Overrides refer to expressions, and expressions
    // refer to overrides, so we can't disentangle the two into
    // separate phases. However, we can take advantage of the fact
    // that the overrides and expressions must form a DAG, and work
    // our way from the leaves to the roots, replacing and evaluating
    // as we go.
    //
    // Although the two loops are nested, this is really two
    // alternating phases: we adjust and evaluate constant expressions
    // until we hit an `Override` expression, at which point we switch
    // to building `Constant`s for `Overrides` until we've handled the
    // one used by the expression. Then we switch back to processing
    // expressions. Because we know they form a DAG, we know the
    // `Override` expressions we encounter can only have initializers
    // referring to global expressions we've already simplified.
    for (old_h, expr, span) in module.global_expressions.drain() {
        let mut expr = match expr {
            Expression::Override(h) => {
                let c_h = if let Some(new_h) = override_map.get(h) {
                    *new_h
                } else {
                    let mut new_h = None;
                    for entry in override_iter.by_ref() {
                        let stop = entry.0 == h;
                        new_h = Some(process_override(
                            entry,
                            pipeline_constants,
                            &mut module,
                            &mut override_map,
                            &adjusted_global_expressions,
                            &mut adjusted_constant_initializers,
                            &mut global_expression_kind_tracker,
                        )?);
                        if stop {
                            break;
                        }
                    }
                    // `h` was not yet in `override_map`, and every override
                    // consumed from `override_iter` is added to it, so `h` is
                    // still ahead in the iterator; the loop above ran until it
                    // processed `h`.
                    new_h.unwrap()
                };
                Expression::Constant(c_h)
            }
            Expression::Constant(c_h) => {
                // Fix up this constant's initializer the first time we see it;
                // its initializer precedes it, so it has already been mapped.
                if adjusted_constant_initializers.insert(c_h) {
                    let init = &mut module.constants[c_h].init;
                    *init = adjusted_global_expressions[*init];
                }
                expr
            }
            expr => expr,
        };
        let mut evaluator = ConstantEvaluator::for_wgsl_module(
            &mut module,
            &mut global_expression_kind_tracker,
            &mut layouter,
            false,
        );
        adjust_expr(&adjusted_global_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_global_expressions.insert(old_h, h);
    }

    // Finish processing any overrides we didn't visit in the loop above.
    //
    // NOTE(review): an override with neither a name nor an id gets no entry in
    // `override_map` here, while `process_function` indexes `override_map`
    // directly. This presumably relies on such overrides never occurring or
    // never being referenced from a function; confirm.
    for entry in override_iter {
        match *entry.1 {
            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
                process_override(
                    entry,
                    pipeline_constants,
                    &mut module,
                    &mut override_map,
                    &adjusted_global_expressions,
                    &mut adjusted_constant_initializers,
                    &mut global_expression_kind_tracker,
                )?;
            }
            Override {
                init: Some(ref mut init),
                ..
            } => {
                *init = adjusted_global_expressions[*init];
            }
            _ => {}
        }
    }

    // Update the initialization expression handles of all `Constant`s
    // and `GlobalVariable`s. Skip `Constant`s we'd already updated en
    // passant.
    for (_, c) in module
        .constants
        .iter_mut()
        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
    {
        c.init = adjusted_global_expressions[c.init];
    }

    for (_, v) in module.global_variables.iter_mut() {
        if let Some(ref mut init) = v.init {
            *init = adjusted_global_expressions[*init];
        }
    }

    // Functions are taken out temporarily so that `process_function` can
    // borrow `module` mutably while rewriting each one.
    let mut functions = mem::take(&mut module.functions);
    for (_, function) in functions.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, function)?;
    }
    module.functions = functions;

    // Same for entry points, which additionally have overridable sizes.
    let mut entry_points = mem::take(&mut module.entry_points);
    for ep in entry_points.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
        process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?;
    }
    module.entry_points = entry_points;
    module.overrides = overrides;

    // Now that we've rewritten all the expressions, we need to
    // recompute their types and other metadata. For the time being,
    // do a full re-validation.
    revalidate(module)
}
258
259fn revalidate(
260    module: Module,
261) -> Result<(Cow<'static, Module>, Cow<'static, ModuleInfo>), PipelineConstantError> {
262    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
263    let module_info = validator.validate_resolved_overrides(&module)?;
264    Ok((Cow::Owned(module), Cow::Owned(module_info)))
265}
266
267fn process_workgroup_size_override(
268    module: &mut Module,
269    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
270    ep: &mut crate::EntryPoint,
271) -> Result<(), PipelineConstantError> {
272    match ep.workgroup_size_overrides {
273        None => {}
274        Some(overrides) => {
275            overrides.iter().enumerate().try_for_each(
276                |(i, overridden)| -> Result<(), PipelineConstantError> {
277                    match *overridden {
278                        None => Ok(()),
279                        Some(h) => {
280                            ep.workgroup_size[i] = module
281                                .to_ctx()
282                                .eval_expr_to_u32(adjusted_global_expressions[h])
283                                .map(|n| {
284                                    if n == 0 {
285                                        Err(PipelineConstantError::NegativeWorkgroupSize)
286                                    } else {
287                                        Ok(n)
288                                    }
289                                })
290                                .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
291                            Ok(())
292                        }
293                    }
294                },
295            )?;
296            ep.workgroup_size_overrides = None;
297        }
298    }
299    Ok(())
300}
301
302fn process_mesh_shader_overrides(
303    module: &mut Module,
304    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
305    ep: &mut crate::EntryPoint,
306) -> Result<(), PipelineConstantError> {
307    if let Some(ref mut mesh_info) = ep.mesh_info {
308        if let Some(r#override) = mesh_info.max_vertices_override {
309            mesh_info.max_vertices = module
310                .to_ctx()
311                .eval_expr_to_u32(adjusted_global_expressions[r#override])
312                .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
313        }
314        if let Some(r#override) = mesh_info.max_primitives_override {
315            mesh_info.max_primitives = module
316                .to_ctx()
317                .eval_expr_to_u32(adjusted_global_expressions[r#override])
318                .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
319        }
320    }
321    Ok(())
322}
323
324/// Add a [`Constant`] to `module` for the override `old_h`.
325///
326/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
327fn process_override(
328    (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
329    pipeline_constants: &PipelineConstants,
330    module: &mut Module,
331    override_map: &mut HandleVec<Override, Handle<Constant>>,
332    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
333    adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
334    global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
335) -> Result<Handle<Constant>, PipelineConstantError> {
336    // Determine which key to use for `r#override` in `pipeline_constants`.
337    let key = if let Some(id) = r#override.id {
338        Cow::Owned(id.to_string())
339    } else if let Some(ref name) = r#override.name {
340        Cow::Borrowed(name)
341    } else {
342        unreachable!();
343    };
344
345    // Generate a global expression for `r#override`'s value, either
346    // from the provided `pipeline_constants` table or its initializer
347    // in the module.
348    let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
349        let literal = match module.types[r#override.ty].inner {
350            TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
351            _ => unreachable!(),
352        };
353        let expr = module
354            .global_expressions
355            .append(Expression::Literal(literal), Span::UNDEFINED);
356        global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
357        expr
358    } else if let Some(init) = r#override.init {
359        adjusted_global_expressions[init]
360    } else {
361        return Err(PipelineConstantError::MissingValue(key.to_string()));
362    };
363
364    // Generate a new `Constant` to represent the override's value.
365    let constant = Constant {
366        name: r#override.name.clone(),
367        ty: r#override.ty,
368        init,
369    };
370    let h = module.constants.append(constant, *span);
371    override_map.insert(old_h, h);
372    adjusted_constant_initializers.insert(h);
373    r#override.init = Some(init);
374    Ok(h)
375}
376
/// Replace all override expressions in `function` with fully-evaluated constants.
///
/// Replace all `Expression::Override`s in `function`'s expression arena with
/// the corresponding `Expression::Constant`s, as given in `override_map`.
/// Replace any expressions whose values are now known with their fully
/// evaluated form.
///
/// If `h` is a `Handle<Override>`, then `override_map[h]` is the
/// `Handle<Constant>` for the override's final value.
///
/// The expression arena is rebuilt from scratch. All handles into it are
/// rewritten to match: the ones in `function.body`, in local variable
/// initializers, and the keys of `function.named_expressions`. `Emit`
/// statements are then filtered with `filter_emits_in_block`.
///
/// # Errors
///
/// Returns any [`ConstantEvaluatorError`] raised while re-evaluating an
/// expression.
fn process_function(
    module: &mut Module,
    override_map: &HandleVec<Override, Handle<Constant>>,
    layouter: &mut crate::proc::Layouter,
    function: &mut Function,
) -> Result<(), ConstantEvaluatorError> {
    // A map from original local expression handles to
    // handles in the new, local expression arena.
    let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());

    let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();

    // Take the original arena; `function.expressions` is left empty and
    // refilled by the evaluator below.
    let mut expressions = mem::take(&mut function.expressions);

    // Dummy `emitter` and `block` for the constant evaluator.
    // We can ignore the concept of emitting expressions here since
    // expressions have already been covered by a `Statement::Emit`
    // in the frontend.
    // The only thing we might have to do is remove some expressions
    // that have been covered by a `Statement::Emit`. See the docs of
    // `filter_emits_in_block` for the reasoning.
    let mut emitter = Emitter::default();
    let mut block = Block::new();

    let mut evaluator = ConstantEvaluator::for_wgsl_function(
        module,
        &mut function.expressions,
        &mut local_expression_kind_tracker,
        layouter,
        &mut emitter,
        &mut block,
        false,
    );

    // Expressions are visited in arena order, so every handle `expr` refers
    // to should already have an entry in `adjusted_local_expressions`.
    for (old_h, mut expr, span) in expressions.drain() {
        if let Expression::Override(h) = expr {
            expr = Expression::Constant(override_map[h]);
        }
        adjust_expr(&adjusted_local_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_local_expressions.insert(old_h, h);
    }

    adjust_block(&adjusted_local_expressions, &mut function.body);

    filter_emits_in_block(&mut function.body, &function.expressions);

    // Update local variable initializers.
    for (_, local) in function.local_variables.iter_mut() {
        if let &mut Some(ref mut init) = &mut local.init {
            *init = adjusted_local_expressions[*init];
        }
    }

    // We've changed the keys of `function.named_expressions`, so we have to
    // rebuild it from scratch.
    let named_expressions = mem::take(&mut function.named_expressions);
    for (expr_h, name) in named_expressions {
        function
            .named_expressions
            .insert(adjusted_local_expressions[expr_h], name);
    }

    Ok(())
}
451
/// Replace every expression handle in `expr` with its counterpart
/// given by `new_pos`.
///
/// Used for both the global expression arena (`process_overrides`) and
/// function-local arenas (`process_function`). Every handle `expr` refers to
/// must already have an entry in `new_pos`, or the indexing below panics.
/// The callers arrange this by visiting expressions in arena order, which
/// assumes each expression refers only to expressions that precede it.
///
/// Handle-free fields are mostly spelled out as `_` rather than elided with
/// `..`. As a result, adding a field to one of those variants produces a
/// compile error here instead of silently skipping a handle.
fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
    let adjust = |expr: &mut Handle<Expression>| {
        *expr = new_pos[*expr];
    };
    match *expr {
        Expression::Compose {
            ref mut components,
            ty: _,
        } => {
            for c in components.iter_mut() {
                adjust(c);
            }
        }
        Expression::Access {
            ref mut base,
            ref mut index,
        } => {
            adjust(base);
            adjust(index);
        }
        Expression::AccessIndex {
            ref mut base,
            index: _,
        } => {
            adjust(base);
        }
        Expression::Splat {
            ref mut value,
            size: _,
        } => {
            adjust(value);
        }
        Expression::Swizzle {
            ref mut vector,
            size: _,
            pattern: _,
        } => {
            adjust(vector);
        }
        Expression::Load { ref mut pointer } => {
            adjust(pointer);
        }
        Expression::ImageSample {
            ref mut image,
            ref mut sampler,
            ref mut coordinate,
            ref mut array_index,
            ref mut offset,
            ref mut level,
            ref mut depth_ref,
            gather: _,
            clamp_to_edge: _,
        } => {
            adjust(image);
            adjust(sampler);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            if let Some(e) = offset.as_mut() {
                adjust(e);
            }
            // Only the handle-carrying sample levels need adjusting; the
            // remaining variants carry no expressions.
            match *level {
                crate::SampleLevel::Exact(ref mut expr)
                | crate::SampleLevel::Bias(ref mut expr) => {
                    adjust(expr);
                }
                crate::SampleLevel::Gradient {
                    ref mut x,
                    ref mut y,
                } => {
                    adjust(x);
                    adjust(y);
                }
                _ => {}
            }
            if let Some(e) = depth_ref.as_mut() {
                adjust(e);
            }
        }
        Expression::ImageLoad {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            ref mut sample,
            ref mut level,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            if let Some(e) = sample.as_mut() {
                adjust(e);
            }
            if let Some(e) = level.as_mut() {
                adjust(e);
            }
        }
        Expression::ImageQuery {
            ref mut image,
            ref mut query,
        } => {
            adjust(image);
            match *query {
                crate::ImageQuery::Size { ref mut level } => {
                    if let Some(e) = level.as_mut() {
                        adjust(e);
                    }
                }
                crate::ImageQuery::NumLevels
                | crate::ImageQuery::NumLayers
                | crate::ImageQuery::NumSamples => {}
            }
        }
        Expression::Unary {
            ref mut expr,
            op: _,
        } => {
            adjust(expr);
        }
        Expression::Binary {
            ref mut left,
            ref mut right,
            op: _,
        } => {
            adjust(left);
            adjust(right);
        }
        Expression::Select {
            ref mut condition,
            ref mut accept,
            ref mut reject,
        } => {
            adjust(condition);
            adjust(accept);
            adjust(reject);
        }
        Expression::Derivative {
            ref mut expr,
            axis: _,
            ctrl: _,
        } => {
            adjust(expr);
        }
        Expression::Relational {
            ref mut argument,
            fun: _,
        } => {
            adjust(argument);
        }
        Expression::Math {
            ref mut arg,
            ref mut arg1,
            ref mut arg2,
            ref mut arg3,
            fun: _,
        } => {
            adjust(arg);
            if let Some(e) = arg1.as_mut() {
                adjust(e);
            }
            if let Some(e) = arg2.as_mut() {
                adjust(e);
            }
            if let Some(e) = arg3.as_mut() {
                adjust(e);
            }
        }
        Expression::As {
            ref mut expr,
            kind: _,
            convert: _,
        } => {
            adjust(expr);
        }
        Expression::ArrayLength(ref mut expr) => {
            adjust(expr);
        }
        Expression::RayQueryGetIntersection {
            ref mut query,
            committed: _,
        } => {
            adjust(query);
        }
        // These variants hold no expression handles.
        Expression::Literal(_)
        | Expression::FunctionArgument(_)
        | Expression::GlobalVariable(_)
        | Expression::LocalVariable(_)
        | Expression::CallResult(_)
        | Expression::RayQueryProceedResult
        | Expression::Constant(_)
        | Expression::Override(_)
        | Expression::ZeroValue(_)
        | Expression::AtomicResult {
            ty: _,
            comparison: _,
        }
        | Expression::WorkGroupUniformLoadResult { ty: _ }
        | Expression::SubgroupBallotResult
        | Expression::SubgroupOperationResult { .. } => {}
        Expression::RayQueryVertexPositions {
            ref mut query,
            committed: _,
        } => {
            adjust(query);
        }
    }
}
663
664/// Replace every expression handle in `block` with its counterpart
665/// given by `new_pos`.
666fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
667    for stmt in block.iter_mut() {
668        adjust_stmt(new_pos, stmt);
669    }
670}
671
/// Replace every expression handle in `stmt` with its counterpart
/// given by `new_pos`.
///
/// Nested blocks (`Block`, `If`, `Switch`, `Loop`) are handled recursively
/// through `adjust_block`. Every handle must have an entry in `new_pos`, or
/// the indexing panics. The match is exhaustive over `Statement`, so adding
/// a variant requires updating this function.
fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
    let adjust = |expr: &mut Handle<Expression>| {
        *expr = new_pos[*expr];
    };
    match *stmt {
        Statement::Emit(ref mut range) => {
            // Only the range's endpoints are remapped. The rebuilt range may
            // cover expressions that must not be emitted; `process_function`
            // runs `filter_emits_in_block` afterwards to address that.
            if let Some((mut first, mut last)) = range.first_and_last() {
                adjust(&mut first);
                adjust(&mut last);
                *range = Range::new_from_bounds(first, last);
            }
        }
        Statement::Block(ref mut block) => {
            adjust_block(new_pos, block);
        }
        Statement::If {
            ref mut condition,
            ref mut accept,
            ref mut reject,
        } => {
            adjust(condition);
            adjust_block(new_pos, accept);
            adjust_block(new_pos, reject);
        }
        Statement::Switch {
            ref mut selector,
            ref mut cases,
        } => {
            adjust(selector);
            for case in cases.iter_mut() {
                adjust_block(new_pos, &mut case.body);
            }
        }
        Statement::Loop {
            ref mut body,
            ref mut continuing,
            ref mut break_if,
        } => {
            adjust_block(new_pos, body);
            adjust_block(new_pos, continuing);
            if let Some(e) = break_if.as_mut() {
                adjust(e);
            }
        }
        Statement::Return { ref mut value } => {
            if let Some(e) = value.as_mut() {
                adjust(e);
            }
        }
        Statement::Store {
            ref mut pointer,
            ref mut value,
        } => {
            adjust(pointer);
            adjust(value);
        }
        Statement::ImageStore {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            ref mut value,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(e) = array_index.as_mut() {
                adjust(e);
            }
            adjust(value);
        }
        Statement::Atomic {
            ref mut pointer,
            ref mut value,
            ref mut result,
            ref mut fun,
        } => {
            adjust(pointer);
            adjust(value);
            if let Some(ref mut result) = *result {
                adjust(result);
            }
            // Only compare-exchange carries an extra expression handle.
            match *fun {
                crate::AtomicFunction::Exchange {
                    compare: Some(ref mut compare),
                } => {
                    adjust(compare);
                }
                crate::AtomicFunction::Add
                | crate::AtomicFunction::Subtract
                | crate::AtomicFunction::And
                | crate::AtomicFunction::ExclusiveOr
                | crate::AtomicFunction::InclusiveOr
                | crate::AtomicFunction::Min
                | crate::AtomicFunction::Max
                | crate::AtomicFunction::Exchange { compare: None } => {}
            }
        }
        Statement::ImageAtomic {
            ref mut image,
            ref mut coordinate,
            ref mut array_index,
            fun: _,
            ref mut value,
        } => {
            adjust(image);
            adjust(coordinate);
            if let Some(ref mut array_index) = *array_index {
                adjust(array_index);
            }
            adjust(value);
        }
        Statement::WorkGroupUniformLoad {
            ref mut pointer,
            ref mut result,
        } => {
            adjust(pointer);
            adjust(result);
        }
        Statement::SubgroupBallot {
            ref mut result,
            ref mut predicate,
        } => {
            if let Some(ref mut predicate) = *predicate {
                adjust(predicate);
            }
            adjust(result);
        }
        Statement::SubgroupCollectiveOperation {
            ref mut argument,
            ref mut result,
            ..
        } => {
            adjust(argument);
            adjust(result);
        }
        Statement::SubgroupGather {
            ref mut mode,
            ref mut argument,
            ref mut result,
        } => {
            // Gather modes that take a lane index carry an expression handle.
            match *mode {
                crate::GatherMode::BroadcastFirst => {}
                crate::GatherMode::Broadcast(ref mut index)
                | crate::GatherMode::Shuffle(ref mut index)
                | crate::GatherMode::ShuffleDown(ref mut index)
                | crate::GatherMode::ShuffleUp(ref mut index)
                | crate::GatherMode::ShuffleXor(ref mut index)
                | crate::GatherMode::QuadBroadcast(ref mut index) => {
                    adjust(index);
                }
                crate::GatherMode::QuadSwap(_) => {}
            }
            adjust(argument);
            adjust(result)
        }
        Statement::Call {
            ref mut arguments,
            ref mut result,
            function: _,
        } => {
            for argument in arguments.iter_mut() {
                adjust(argument);
            }
            if let Some(e) = result.as_mut() {
                adjust(e);
            }
        }
        Statement::RayQuery {
            ref mut query,
            ref mut fun,
        } => {
            adjust(query);
            match *fun {
                crate::RayQueryFunction::Initialize {
                    ref mut acceleration_structure,
                    ref mut descriptor,
                } => {
                    adjust(acceleration_structure);
                    adjust(descriptor);
                }
                crate::RayQueryFunction::Proceed { ref mut result } => {
                    adjust(result);
                }
                crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
                    adjust(hit_t);
                }
                crate::RayQueryFunction::ConfirmIntersection => {}
                crate::RayQueryFunction::Terminate => {}
            }
        }
        Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs {
            ref mut vertex_count,
            ref mut primitive_count,
        }) => {
            adjust(vertex_count);
            adjust(primitive_count);
        }
        Statement::MeshFunction(
            crate::MeshFunction::SetVertex {
                ref mut index,
                ref mut value,
            }
            | crate::MeshFunction::SetPrimitive {
                ref mut index,
                ref mut value,
            },
        ) => {
            adjust(index);
            adjust(value);
        }
        // These statements hold no expression handles.
        Statement::Break
        | Statement::Continue
        | Statement::Kill
        | Statement::ControlBarrier(_)
        | Statement::MemoryBarrier(_) => {}
    }
}
890
891/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced.
892///
893/// According to validation, [`Emit`] statements must not cover any expressions
894/// for which [`Expression::needs_pre_emit`] returns true. All expressions built
895/// by successful constant evaluation fall into that category, meaning that
896/// `process_function` will usually rewrite [`Override`] expressions and those
897/// that use their values into pre-emitted expressions, leaving any [`Emit`]
898/// statements that cover them invalid.
899///
900/// This function rewrites all [`Emit`] statements into zero or more new
901/// [`Emit`] statements covering only those expressions in the original range
902/// that are not pre-emitted.
903///
904/// [`Emit`]: Statement::Emit
905/// [`needs_pre_emit`]: Expression::needs_pre_emit
906/// [`Override`]: Expression::Override
907fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
908    let original = mem::replace(block, Block::with_capacity(block.len()));
909    for (stmt, span) in original.span_into_iter() {
910        match stmt {
911            Statement::Emit(range) => {
912                let mut current = None;
913                for expr_h in range {
914                    if expressions[expr_h].needs_pre_emit() {
915                        if let Some((first, last)) = current {
916                            block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
917                        }
918
919                        current = None;
920                    } else if let Some((_, ref mut last)) = current {
921                        *last = expr_h;
922                    } else {
923                        current = Some((expr_h, expr_h));
924                    }
925                }
926                if let Some((first, last)) = current {
927                    block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
928                }
929            }
930            Statement::Block(mut child) => {
931                filter_emits_in_block(&mut child, expressions);
932                block.push(Statement::Block(child), span);
933            }
934            Statement::If {
935                condition,
936                mut accept,
937                mut reject,
938            } => {
939                filter_emits_in_block(&mut accept, expressions);
940                filter_emits_in_block(&mut reject, expressions);
941                block.push(
942                    Statement::If {
943                        condition,
944                        accept,
945                        reject,
946                    },
947                    span,
948                );
949            }
950            Statement::Switch {
951                selector,
952                mut cases,
953            } => {
954                for case in &mut cases {
955                    filter_emits_in_block(&mut case.body, expressions);
956                }
957                block.push(Statement::Switch { selector, cases }, span);
958            }
959            Statement::Loop {
960                mut body,
961                mut continuing,
962                break_if,
963            } => {
964                filter_emits_in_block(&mut body, expressions);
965                filter_emits_in_block(&mut continuing, expressions);
966                block.push(
967                    Statement::Loop {
968                        body,
969                        continuing,
970                        break_if,
971                    },
972                    span,
973                );
974            }
975            stmt => block.push(stmt.clone(), span),
976        }
977    }
978}
979
980fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
981    // note that in rust 0.0 == -0.0
982    match scalar {
983        Scalar::BOOL => {
984            // https://webidl.spec.whatwg.org/#js-boolean
985            let value = value != 0.0 && !value.is_nan();
986            Ok(Literal::Bool(value))
987        }
988        Scalar::I32 => {
989            // https://webidl.spec.whatwg.org/#js-long
990            if !value.is_finite() {
991                return Err(PipelineConstantError::SrcNeedsToBeFinite);
992            }
993
994            let value = value.trunc();
995            if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
996                return Err(PipelineConstantError::DstRangeTooSmall);
997            }
998
999            let value = value as i32;
1000            Ok(Literal::I32(value))
1001        }
1002        Scalar::U32 => {
1003            // https://webidl.spec.whatwg.org/#js-unsigned-long
1004            if !value.is_finite() {
1005                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1006            }
1007
1008            let value = value.trunc();
1009            if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
1010                return Err(PipelineConstantError::DstRangeTooSmall);
1011            }
1012
1013            let value = value as u32;
1014            Ok(Literal::U32(value))
1015        }
1016        Scalar::F16 => {
1017            // https://webidl.spec.whatwg.org/#js-float
1018            if !value.is_finite() {
1019                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1020            }
1021
1022            let value = half::f16::from_f64(value);
1023            if !value.is_finite() {
1024                return Err(PipelineConstantError::DstRangeTooSmall);
1025            }
1026
1027            Ok(Literal::F16(value))
1028        }
1029        Scalar::F32 => {
1030            // https://webidl.spec.whatwg.org/#js-float
1031            if !value.is_finite() {
1032                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1033            }
1034
1035            let value = value as f32;
1036            if !value.is_finite() {
1037                return Err(PipelineConstantError::DstRangeTooSmall);
1038            }
1039
1040            Ok(Literal::F32(value))
1041        }
1042        Scalar::F64 => {
1043            // https://webidl.spec.whatwg.org/#js-double
1044            if !value.is_finite() {
1045                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1046            }
1047
1048            Ok(Literal::F64(value))
1049        }
1050        Scalar::ABSTRACT_FLOAT | Scalar::ABSTRACT_INT => {
1051            unreachable!("abstract values should not be validated out of override processing")
1052        }
1053        _ => unreachable!("unrecognized scalar type for override"),
1054    }
1055}
1056
#[test]
fn test_map_value_to_literal() {
    // Web IDL boolean conversion: only ±0 and NaN are falsy.
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    // Every numeric destination, including f16, rejects non-finite sources.
    for scalar in [
        Scalar::I32,
        Scalar::U32,
        Scalar::F16,
        Scalar::F32,
        Scalar::F64,
    ] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Fractional parts are truncated toward zero.
    assert_eq!(
        map_value_to_literal(1.9, Scalar::I32),
        Ok(Literal::I32(1))
    );
    assert_eq!(
        map_value_to_literal(-1.9, Scalar::I32),
        Ok(Literal::I32(-1))
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Truncation happens before the range check, so -0.5 becomes 0.
    assert_eq!(
        map_value_to_literal(2.7, Scalar::U32),
        Ok(Literal::U32(2))
    );
    assert_eq!(
        map_value_to_literal(-0.5, Scalar::U32),
        Ok(Literal::U32(0))
    );

    // f16
    assert_eq!(
        map_value_to_literal(f64::from(half::f16::MIN), Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(half::f16::MAX), Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    // 2^16 exceeds the largest finite f16 (65504), so it rounds to infinity.
    assert_eq!(
        map_value_to_literal(-65536.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(65536.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f32
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // Just below the rounding midpoint above f32::MAX: rounds down to MAX.
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // Exactly the midpoint: rounds to infinity, which is out of range.
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}