naga/back/
pipeline_constants.rs

1use alloc::{
2    borrow::Cow,
3    string::{String, ToString},
4};
5use core::mem;
6
7use hashbrown::HashSet;
8use thiserror::Error;
9
10use super::PipelineConstants;
11use crate::{
12    arena::HandleVec,
13    compact::{compact, KeepUnused},
14    ir,
15    proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
16    valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
17    Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
18    Span, Statement, TypeInner, WithSpan,
19};
20
21#[cfg(no_std)]
22use num_traits::float::FloatCore as _;
23
24#[derive(Error, Debug, Clone)]
25#[cfg_attr(test, derive(PartialEq))]
26pub enum PipelineConstantError {
27    #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
28    MissingValue(String),
29    #[error(
30        "Source f64 value needs to be finite ({}) for number destinations",
31        "NaNs and Inifinites are not allowed"
32    )]
33    SrcNeedsToBeFinite,
34    #[error("Source f64 value doesn't fit in destination")]
35    DstRangeTooSmall,
36    #[error(transparent)]
37    ConstantEvaluatorError(#[from] ConstantEvaluatorError),
38    #[error(transparent)]
39    ValidationError(#[from] WithSpan<ValidationError>),
40    #[error("workgroup_size override isn't strictly positive")]
41    NegativeWorkgroupSize,
42    #[error("max vertices or max primitives is negative")]
43    NegativeMeshOutputMax,
44}
45
/// Compact `module` and replace all overrides with constants.
///
/// If no changes are needed, this just returns `Cow::Borrowed` references to
/// `module` and `module_info`. Otherwise, it clones `module`, retains only the
/// selected entry point, compacts the module, edits its [`global_expressions`]
/// arena to contain only fully-evaluated expressions, and returns the
/// simplified module and its validation results.
///
/// In the module returned, every `Expression::Override` (in the
/// `global_expressions` arena and in function and entry-point bodies) has
/// been replaced by an `Expression::Constant` referring to a newly added
/// [`Constant`], and the `global_expressions` arena contains only
/// fully-evaluated expressions. Workgroup sizes and mesh shader output limits
/// that were given by overrides are resolved to concrete values.
///
/// Note that the `overrides` arena is *not* emptied. It is put back at the
/// end, and each override's `init` (where present) is rewritten to refer to
/// its resolved value in the new `global_expressions` arena.
///
/// # Errors
///
/// Fails with [`PipelineConstantError::MissingValue`] if an override that
/// survives compaction has neither an entry in `pipeline_constants` nor an
/// initializer. Also propagates:
/// - failures converting a supplied value to the override's type;
/// - constant-evaluation errors;
/// - invalid workgroup-size or mesh-output override values;
/// - errors from re-validating the result.
///
/// [`global_expressions`]: Module::global_expressions
pub fn process_overrides<'a>(
    module: &'a Module,
    module_info: &'a ModuleInfo,
    entry_point: Option<(ir::ShaderStage, &str)>,
    pipeline_constants: &PipelineConstants,
) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
    if (entry_point.is_none() || module.entry_points.len() <= 1) && module.overrides.is_empty() {
        // We skip compacting the module here mostly to reduce the risk of
        // hitting corner cases like https://github.com/gfx-rs/wgpu/issues/7793.
        // Compaction doesn't cost very much [1], so it would also be reasonable
        // to do it unconditionally. Even when there is a single entry point or
        // when no entry point is specified, it is still possible that there
        // are unreferenced items in the module that would be removed by this
        // compaction.
        //
        // [1]: https://github.com/gfx-rs/wgpu/pull/7703#issuecomment-2902153760
        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
    }

    let mut module = module.clone();
    if let Some((ep_stage, ep_name)) = entry_point {
        module
            .entry_points
            .retain(|ep| ep.stage == ep_stage && ep.name == ep_name);
    }

    // Compact the module to remove anything not reachable from an entry point.
    // This is necessary because we may not have values for overrides that are
    // not reachable from the/an entry point.
    compact(&mut module, KeepUnused::No);

    // If there are no overrides in the module, then we can skip the rest.
    if module.overrides.is_empty() {
        return revalidate(module);
    }

    // A map from override handles to the handles of the constants
    // we've replaced them with.
    let mut override_map = HandleVec::with_capacity(module.overrides.len());

    // A map from `module`'s original global expression handles to
    // handles in the new, simplified global expression arena.
    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());

    // The set of constants whose initializer handles we've already
    // updated to refer to the newly built global expression arena.
    //
    // All constants in `module` must have their `init` handles
    // updated to point into the new, simplified global expression
    // arena. Some of these we can most easily handle as a side effect
    // during the simplification process, but we must handle the rest
    // in a final fixup pass, guided by `adjusted_global_expressions`. We
    // add their handles to this set, so that the final fixup step can
    // leave them alone.
    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());

    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
    let mut layouter = crate::proc::Layouter::default();

    // An iterator through the original overrides table, consumed in
    // approximate tandem with the global expressions.
    let mut overrides = mem::take(&mut module.overrides);
    let mut override_iter = overrides.iter_mut_span();

    // Do two things in tandem:
    //
    // - Rebuild the global expression arena from scratch, fully
    //   evaluating all expressions, and replacing each `Override`
    //   expression in `module.global_expressions` with a `Constant`
    //   expression.
    //
    // - Build a new `Constant` in `module.constants` to take the
    //   place of each `Override`.
    //
    // Build a map from old global expression handles to their
    // fully-evaluated counterparts in `adjusted_global_expressions` as we
    // go.
    //
    // Why in tandem? Overrides refer to expressions, and expressions
    // refer to overrides, so we can't disentangle the two into
    // separate phases. However, we can take advantage of the fact
    // that the overrides and expressions must form a DAG, and work
    // our way from the leaves to the roots, replacing and evaluating
    // as we go.
    //
    // Although the two loops are nested, this is really two
    // alternating phases: we adjust and evaluate constant expressions
    // until we hit an `Override` expression, at which point we switch
    // to building `Constant`s for `Overrides` until we've handled the
    // one used by the expression. Then we switch back to processing
    // expressions. Because we know they form a DAG, we know the
    // `Override` expressions we encounter can only have initializers
    // referring to global expressions we've already simplified.
    for (old_h, expr, span) in module.global_expressions.drain() {
        let mut expr = match expr {
            Expression::Override(h) => {
                let c_h = if let Some(new_h) = override_map.get(h) {
                    *new_h
                } else {
                    let mut new_h = None;
                    for entry in override_iter.by_ref() {
                        let stop = entry.0 == h;
                        new_h = Some(process_override(
                            entry,
                            pipeline_constants,
                            &mut module,
                            &mut override_map,
                            &adjusted_global_expressions,
                            &mut adjusted_constant_initializers,
                            &mut global_expression_kind_tracker,
                        )?);
                        if stop {
                            break;
                        }
                    }
                    // `h` is not in `override_map`. Every override the
                    // iterator hands out is added to `override_map` by
                    // `process_override` (or we have already returned an
                    // error). So `h` had not been consumed yet, the loop
                    // above stopped on it, and `new_h` holds `h`'s constant.
                    new_h.unwrap()
                };
                Expression::Constant(c_h)
            }
            Expression::Constant(c_h) => {
                // Fix up the constant's initializer the first time we see it.
                // It refers to an earlier global expression, which has
                // already been remapped.
                if adjusted_constant_initializers.insert(c_h) {
                    let init = &mut module.constants[c_h].init;
                    *init = adjusted_global_expressions[*init];
                }
                expr
            }
            expr => expr,
        };
        let mut evaluator = ConstantEvaluator::for_wgsl_module(
            &mut module,
            &mut global_expression_kind_tracker,
            &mut layouter,
            false,
        );
        adjust_expr(&adjusted_global_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_global_expressions.insert(old_h, h);
    }

    // Finish processing any overrides we didn't visit in the loop above.
    for entry in override_iter {
        match *entry.1 {
            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
                process_override(
                    entry,
                    pipeline_constants,
                    &mut module,
                    &mut override_map,
                    &adjusted_global_expressions,
                    &mut adjusted_constant_initializers,
                    &mut global_expression_kind_tracker,
                )?;
            }
            Override {
                init: Some(ref mut init),
                ..
            } => {
                *init = adjusted_global_expressions[*init];
            }
            _ => {}
        }
    }

    // Update the initialization expression handles of all `Constant`s
    // and `GlobalVariable`s. Skip `Constant`s we'd already updated en
    // passant.
    for (_, c) in module
        .constants
        .iter_mut()
        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
    {
        c.init = adjusted_global_expressions[c.init];
    }

    for (_, v) in module.global_variables.iter_mut() {
        if let Some(ref mut init) = v.init {
            *init = adjusted_global_expressions[*init];
        }
    }

    // Take the functions out temporarily so that `process_function` can
    // borrow `module` mutably while rewriting each one.
    let mut functions = mem::take(&mut module.functions);
    for (_, function) in functions.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, function)?;
    }
    module.functions = functions;

    let mut entry_points = mem::take(&mut module.entry_points);
    for ep in entry_points.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
        process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?;
    }
    module.entry_points = entry_points;
    // Put the overrides back. Each override's `init`, where present, now
    // refers into the rebuilt global expression arena.
    module.overrides = overrides;

    // Now that we've rewritten all the expressions, we need to
    // recompute their types and other metadata. For the time being,
    // do a full re-validation.
    revalidate(module)
}
258
259fn revalidate(
260    module: Module,
261) -> Result<(Cow<'static, Module>, Cow<'static, ModuleInfo>), PipelineConstantError> {
262    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
263    let module_info = validator.validate_resolved_overrides(&module)?;
264    Ok((Cow::Owned(module), Cow::Owned(module_info)))
265}
266
267fn process_workgroup_size_override(
268    module: &mut Module,
269    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
270    ep: &mut crate::EntryPoint,
271) -> Result<(), PipelineConstantError> {
272    match ep.workgroup_size_overrides {
273        None => {}
274        Some(overrides) => {
275            overrides.iter().enumerate().try_for_each(
276                |(i, overridden)| -> Result<(), PipelineConstantError> {
277                    match *overridden {
278                        None => Ok(()),
279                        Some(h) => {
280                            ep.workgroup_size[i] = module
281                                .to_ctx()
282                                .eval_expr_to_u32(adjusted_global_expressions[h])
283                                .map(|n| {
284                                    if n == 0 {
285                                        Err(PipelineConstantError::NegativeWorkgroupSize)
286                                    } else {
287                                        Ok(n)
288                                    }
289                                })
290                                .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
291                            Ok(())
292                        }
293                    }
294                },
295            )?;
296            ep.workgroup_size_overrides = None;
297        }
298    }
299    Ok(())
300}
301
302fn process_mesh_shader_overrides(
303    module: &mut Module,
304    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
305    ep: &mut crate::EntryPoint,
306) -> Result<(), PipelineConstantError> {
307    if let Some(ref mut mesh_info) = ep.mesh_info {
308        if let Some(r#override) = mesh_info.max_vertices_override {
309            mesh_info.max_vertices = module
310                .to_ctx()
311                .eval_expr_to_u32(adjusted_global_expressions[r#override])
312                .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
313        }
314        if let Some(r#override) = mesh_info.max_primitives_override {
315            mesh_info.max_primitives = module
316                .to_ctx()
317                .eval_expr_to_u32(adjusted_global_expressions[r#override])
318                .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?;
319        }
320    }
321    Ok(())
322}
323
324/// Add a [`Constant`] to `module` for the override `old_h`.
325///
326/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
327fn process_override(
328    (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
329    pipeline_constants: &PipelineConstants,
330    module: &mut Module,
331    override_map: &mut HandleVec<Override, Handle<Constant>>,
332    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
333    adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
334    global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
335) -> Result<Handle<Constant>, PipelineConstantError> {
336    // Determine which key to use for `r#override` in `pipeline_constants`.
337    let key = if let Some(id) = r#override.id {
338        Cow::Owned(id.to_string())
339    } else if let Some(ref name) = r#override.name {
340        Cow::Borrowed(name)
341    } else {
342        unreachable!();
343    };
344
345    // Generate a global expression for `r#override`'s value, either
346    // from the provided `pipeline_constants` table or its initializer
347    // in the module.
348    let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
349        let literal = match module.types[r#override.ty].inner {
350            TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
351            _ => unreachable!(),
352        };
353        let expr = module
354            .global_expressions
355            .append(Expression::Literal(literal), Span::UNDEFINED);
356        global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
357        expr
358    } else if let Some(init) = r#override.init {
359        adjusted_global_expressions[init]
360    } else {
361        return Err(PipelineConstantError::MissingValue(key.to_string()));
362    };
363
364    // Generate a new `Constant` to represent the override's value.
365    let constant = Constant {
366        name: r#override.name.clone(),
367        ty: r#override.ty,
368        init,
369    };
370    let h = module.constants.append(constant, *span);
371    override_map.insert(old_h, h);
372    adjusted_constant_initializers.insert(h);
373    r#override.init = Some(init);
374    Ok(h)
375}
376
377/// Replace all override expressions in `function` with fully-evaluated constants.
378///
379/// Replace all `Expression::Override`s in `function`'s expression arena with
380/// the corresponding `Expression::Constant`s, as given in `override_map`.
381/// Replace any expressions whose values are now known with their fully
382/// evaluated form.
383///
384/// If `h` is a `Handle<Override>`, then `override_map[h]` is the
385/// `Handle<Constant>` for the override's final value.
386fn process_function(
387    module: &mut Module,
388    override_map: &HandleVec<Override, Handle<Constant>>,
389    layouter: &mut crate::proc::Layouter,
390    function: &mut Function,
391) -> Result<(), ConstantEvaluatorError> {
392    // A map from original local expression handles to
393    // handles in the new, local expression arena.
394    let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
395
396    let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
397
398    let mut expressions = mem::take(&mut function.expressions);
399
400    // Dummy `emitter` and `block` for the constant evaluator.
401    // We can ignore the concept of emitting expressions here since
402    // expressions have already been covered by a `Statement::Emit`
403    // in the frontend.
404    // The only thing we might have to do is remove some expressions
405    // that have been covered by a `Statement::Emit`. See the docs of
406    // `filter_emits_in_block` for the reasoning.
407    let mut emitter = Emitter::default();
408    let mut block = Block::new();
409
410    let mut evaluator = ConstantEvaluator::for_wgsl_function(
411        module,
412        &mut function.expressions,
413        &mut local_expression_kind_tracker,
414        layouter,
415        &mut emitter,
416        &mut block,
417        false,
418    );
419
420    for (old_h, mut expr, span) in expressions.drain() {
421        if let Expression::Override(h) = expr {
422            expr = Expression::Constant(override_map[h]);
423        }
424        adjust_expr(&adjusted_local_expressions, &mut expr);
425        let h = evaluator.try_eval_and_append(expr, span)?;
426        adjusted_local_expressions.insert(old_h, h);
427    }
428
429    adjust_block(&adjusted_local_expressions, &mut function.body);
430
431    filter_emits_in_block(&mut function.body, &function.expressions);
432
433    // Update local expression initializers.
434    for (_, local) in function.local_variables.iter_mut() {
435        if let &mut Some(ref mut init) = &mut local.init {
436            *init = adjusted_local_expressions[*init];
437        }
438    }
439
440    // We've changed the keys of `function.named_expression`, so we have to
441    // rebuild it from scratch.
442    let named_expressions = mem::take(&mut function.named_expressions);
443    for (expr_h, name) in named_expressions {
444        function
445            .named_expressions
446            .insert(adjusted_local_expressions[expr_h], name);
447    }
448
449    Ok(())
450}
451
452/// Replace every expression handle in `expr` with its counterpart
453/// given by `new_pos`.
454fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
455    let adjust = |expr: &mut Handle<Expression>| {
456        *expr = new_pos[*expr];
457    };
458    match *expr {
459        Expression::Compose {
460            ref mut components,
461            ty: _,
462        } => {
463            for c in components.iter_mut() {
464                adjust(c);
465            }
466        }
467        Expression::Access {
468            ref mut base,
469            ref mut index,
470        } => {
471            adjust(base);
472            adjust(index);
473        }
474        Expression::AccessIndex {
475            ref mut base,
476            index: _,
477        } => {
478            adjust(base);
479        }
480        Expression::Splat {
481            ref mut value,
482            size: _,
483        } => {
484            adjust(value);
485        }
486        Expression::Swizzle {
487            ref mut vector,
488            size: _,
489            pattern: _,
490        } => {
491            adjust(vector);
492        }
493        Expression::Load { ref mut pointer } => {
494            adjust(pointer);
495        }
496        Expression::ImageSample {
497            ref mut image,
498            ref mut sampler,
499            ref mut coordinate,
500            ref mut array_index,
501            ref mut offset,
502            ref mut level,
503            ref mut depth_ref,
504            gather: _,
505            clamp_to_edge: _,
506        } => {
507            adjust(image);
508            adjust(sampler);
509            adjust(coordinate);
510            if let Some(e) = array_index.as_mut() {
511                adjust(e);
512            }
513            if let Some(e) = offset.as_mut() {
514                adjust(e);
515            }
516            match *level {
517                crate::SampleLevel::Exact(ref mut expr)
518                | crate::SampleLevel::Bias(ref mut expr) => {
519                    adjust(expr);
520                }
521                crate::SampleLevel::Gradient {
522                    ref mut x,
523                    ref mut y,
524                } => {
525                    adjust(x);
526                    adjust(y);
527                }
528                _ => {}
529            }
530            if let Some(e) = depth_ref.as_mut() {
531                adjust(e);
532            }
533        }
534        Expression::ImageLoad {
535            ref mut image,
536            ref mut coordinate,
537            ref mut array_index,
538            ref mut sample,
539            ref mut level,
540        } => {
541            adjust(image);
542            adjust(coordinate);
543            if let Some(e) = array_index.as_mut() {
544                adjust(e);
545            }
546            if let Some(e) = sample.as_mut() {
547                adjust(e);
548            }
549            if let Some(e) = level.as_mut() {
550                adjust(e);
551            }
552        }
553        Expression::ImageQuery {
554            ref mut image,
555            ref mut query,
556        } => {
557            adjust(image);
558            match *query {
559                crate::ImageQuery::Size { ref mut level } => {
560                    if let Some(e) = level.as_mut() {
561                        adjust(e);
562                    }
563                }
564                crate::ImageQuery::NumLevels
565                | crate::ImageQuery::NumLayers
566                | crate::ImageQuery::NumSamples => {}
567            }
568        }
569        Expression::Unary {
570            ref mut expr,
571            op: _,
572        } => {
573            adjust(expr);
574        }
575        Expression::Binary {
576            ref mut left,
577            ref mut right,
578            op: _,
579        } => {
580            adjust(left);
581            adjust(right);
582        }
583        Expression::Select {
584            ref mut condition,
585            ref mut accept,
586            ref mut reject,
587        } => {
588            adjust(condition);
589            adjust(accept);
590            adjust(reject);
591        }
592        Expression::Derivative {
593            ref mut expr,
594            axis: _,
595            ctrl: _,
596        } => {
597            adjust(expr);
598        }
599        Expression::Relational {
600            ref mut argument,
601            fun: _,
602        } => {
603            adjust(argument);
604        }
605        Expression::Math {
606            ref mut arg,
607            ref mut arg1,
608            ref mut arg2,
609            ref mut arg3,
610            fun: _,
611        } => {
612            adjust(arg);
613            if let Some(e) = arg1.as_mut() {
614                adjust(e);
615            }
616            if let Some(e) = arg2.as_mut() {
617                adjust(e);
618            }
619            if let Some(e) = arg3.as_mut() {
620                adjust(e);
621            }
622        }
623        Expression::As {
624            ref mut expr,
625            kind: _,
626            convert: _,
627        } => {
628            adjust(expr);
629        }
630        Expression::ArrayLength(ref mut expr) => {
631            adjust(expr);
632        }
633        Expression::RayQueryGetIntersection {
634            ref mut query,
635            committed: _,
636        } => {
637            adjust(query);
638        }
639        Expression::Literal(_)
640        | Expression::FunctionArgument(_)
641        | Expression::GlobalVariable(_)
642        | Expression::LocalVariable(_)
643        | Expression::CallResult(_)
644        | Expression::RayQueryProceedResult
645        | Expression::Constant(_)
646        | Expression::Override(_)
647        | Expression::ZeroValue(_)
648        | Expression::AtomicResult {
649            ty: _,
650            comparison: _,
651        }
652        | Expression::WorkGroupUniformLoadResult { ty: _ }
653        | Expression::SubgroupBallotResult
654        | Expression::SubgroupOperationResult { .. } => {}
655        Expression::RayQueryVertexPositions {
656            ref mut query,
657            committed: _,
658        } => {
659            adjust(query);
660        }
661    }
662}
663
664/// Replace every expression handle in `block` with its counterpart
665/// given by `new_pos`.
666fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
667    for stmt in block.iter_mut() {
668        adjust_stmt(new_pos, stmt);
669    }
670}
671
672/// Replace every expression handle in `stmt` with its counterpart
673/// given by `new_pos`.
674fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
675    let adjust = |expr: &mut Handle<Expression>| {
676        *expr = new_pos[*expr];
677    };
678    match *stmt {
679        Statement::Emit(ref mut range) => {
680            if let Some((mut first, mut last)) = range.first_and_last() {
681                adjust(&mut first);
682                adjust(&mut last);
683                *range = Range::new_from_bounds(first, last);
684            }
685        }
686        Statement::Block(ref mut block) => {
687            adjust_block(new_pos, block);
688        }
689        Statement::If {
690            ref mut condition,
691            ref mut accept,
692            ref mut reject,
693        } => {
694            adjust(condition);
695            adjust_block(new_pos, accept);
696            adjust_block(new_pos, reject);
697        }
698        Statement::Switch {
699            ref mut selector,
700            ref mut cases,
701        } => {
702            adjust(selector);
703            for case in cases.iter_mut() {
704                adjust_block(new_pos, &mut case.body);
705            }
706        }
707        Statement::Loop {
708            ref mut body,
709            ref mut continuing,
710            ref mut break_if,
711        } => {
712            adjust_block(new_pos, body);
713            adjust_block(new_pos, continuing);
714            if let Some(e) = break_if.as_mut() {
715                adjust(e);
716            }
717        }
718        Statement::Return { ref mut value } => {
719            if let Some(e) = value.as_mut() {
720                adjust(e);
721            }
722        }
723        Statement::Store {
724            ref mut pointer,
725            ref mut value,
726        } => {
727            adjust(pointer);
728            adjust(value);
729        }
730        Statement::ImageStore {
731            ref mut image,
732            ref mut coordinate,
733            ref mut array_index,
734            ref mut value,
735        } => {
736            adjust(image);
737            adjust(coordinate);
738            if let Some(e) = array_index.as_mut() {
739                adjust(e);
740            }
741            adjust(value);
742        }
743        Statement::Atomic {
744            ref mut pointer,
745            ref mut value,
746            ref mut result,
747            ref mut fun,
748        } => {
749            adjust(pointer);
750            adjust(value);
751            if let Some(ref mut result) = *result {
752                adjust(result);
753            }
754            match *fun {
755                crate::AtomicFunction::Exchange {
756                    compare: Some(ref mut compare),
757                } => {
758                    adjust(compare);
759                }
760                crate::AtomicFunction::Add
761                | crate::AtomicFunction::Subtract
762                | crate::AtomicFunction::And
763                | crate::AtomicFunction::ExclusiveOr
764                | crate::AtomicFunction::InclusiveOr
765                | crate::AtomicFunction::Min
766                | crate::AtomicFunction::Max
767                | crate::AtomicFunction::Exchange { compare: None } => {}
768            }
769        }
770        Statement::ImageAtomic {
771            ref mut image,
772            ref mut coordinate,
773            ref mut array_index,
774            fun: _,
775            ref mut value,
776        } => {
777            adjust(image);
778            adjust(coordinate);
779            if let Some(ref mut array_index) = *array_index {
780                adjust(array_index);
781            }
782            adjust(value);
783        }
784        Statement::WorkGroupUniformLoad {
785            ref mut pointer,
786            ref mut result,
787        } => {
788            adjust(pointer);
789            adjust(result);
790        }
791        Statement::SubgroupBallot {
792            ref mut result,
793            ref mut predicate,
794        } => {
795            if let Some(ref mut predicate) = *predicate {
796                adjust(predicate);
797            }
798            adjust(result);
799        }
800        Statement::SubgroupCollectiveOperation {
801            ref mut argument,
802            ref mut result,
803            ..
804        } => {
805            adjust(argument);
806            adjust(result);
807        }
808        Statement::SubgroupGather {
809            ref mut mode,
810            ref mut argument,
811            ref mut result,
812        } => {
813            match *mode {
814                crate::GatherMode::BroadcastFirst => {}
815                crate::GatherMode::Broadcast(ref mut index)
816                | crate::GatherMode::Shuffle(ref mut index)
817                | crate::GatherMode::ShuffleDown(ref mut index)
818                | crate::GatherMode::ShuffleUp(ref mut index)
819                | crate::GatherMode::ShuffleXor(ref mut index)
820                | crate::GatherMode::QuadBroadcast(ref mut index) => {
821                    adjust(index);
822                }
823                crate::GatherMode::QuadSwap(_) => {}
824            }
825            adjust(argument);
826            adjust(result)
827        }
828        Statement::Call {
829            ref mut arguments,
830            ref mut result,
831            function: _,
832        } => {
833            for argument in arguments.iter_mut() {
834                adjust(argument);
835            }
836            if let Some(e) = result.as_mut() {
837                adjust(e);
838            }
839        }
840        Statement::RayQuery {
841            ref mut query,
842            ref mut fun,
843        } => {
844            adjust(query);
845            match *fun {
846                crate::RayQueryFunction::Initialize {
847                    ref mut acceleration_structure,
848                    ref mut descriptor,
849                } => {
850                    adjust(acceleration_structure);
851                    adjust(descriptor);
852                }
853                crate::RayQueryFunction::Proceed { ref mut result } => {
854                    adjust(result);
855                }
856                crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
857                    adjust(hit_t);
858                }
859                crate::RayQueryFunction::ConfirmIntersection => {}
860                crate::RayQueryFunction::Terminate => {}
861            }
862        }
863        Statement::Break
864        | Statement::Continue
865        | Statement::Kill
866        | Statement::ControlBarrier(_)
867        | Statement::MemoryBarrier(_) => {}
868    }
869}
870
871/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced.
872///
873/// According to validation, [`Emit`] statements must not cover any expressions
874/// for which [`Expression::needs_pre_emit`] returns true. All expressions built
875/// by successful constant evaluation fall into that category, meaning that
876/// `process_function` will usually rewrite [`Override`] expressions and those
877/// that use their values into pre-emitted expressions, leaving any [`Emit`]
878/// statements that cover them invalid.
879///
880/// This function rewrites all [`Emit`] statements into zero or more new
881/// [`Emit`] statements covering only those expressions in the original range
882/// that are not pre-emitted.
883///
884/// [`Emit`]: Statement::Emit
885/// [`needs_pre_emit`]: Expression::needs_pre_emit
886/// [`Override`]: Expression::Override
887fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
888    let original = mem::replace(block, Block::with_capacity(block.len()));
889    for (stmt, span) in original.span_into_iter() {
890        match stmt {
891            Statement::Emit(range) => {
892                let mut current = None;
893                for expr_h in range {
894                    if expressions[expr_h].needs_pre_emit() {
895                        if let Some((first, last)) = current {
896                            block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
897                        }
898
899                        current = None;
900                    } else if let Some((_, ref mut last)) = current {
901                        *last = expr_h;
902                    } else {
903                        current = Some((expr_h, expr_h));
904                    }
905                }
906                if let Some((first, last)) = current {
907                    block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
908                }
909            }
910            Statement::Block(mut child) => {
911                filter_emits_in_block(&mut child, expressions);
912                block.push(Statement::Block(child), span);
913            }
914            Statement::If {
915                condition,
916                mut accept,
917                mut reject,
918            } => {
919                filter_emits_in_block(&mut accept, expressions);
920                filter_emits_in_block(&mut reject, expressions);
921                block.push(
922                    Statement::If {
923                        condition,
924                        accept,
925                        reject,
926                    },
927                    span,
928                );
929            }
930            Statement::Switch {
931                selector,
932                mut cases,
933            } => {
934                for case in &mut cases {
935                    filter_emits_in_block(&mut case.body, expressions);
936                }
937                block.push(Statement::Switch { selector, cases }, span);
938            }
939            Statement::Loop {
940                mut body,
941                mut continuing,
942                break_if,
943            } => {
944                filter_emits_in_block(&mut body, expressions);
945                filter_emits_in_block(&mut continuing, expressions);
946                block.push(
947                    Statement::Loop {
948                        body,
949                        continuing,
950                        break_if,
951                    },
952                    span,
953                );
954            }
955            stmt => block.push(stmt.clone(), span),
956        }
957    }
958}
959
960fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
961    // note that in rust 0.0 == -0.0
962    match scalar {
963        Scalar::BOOL => {
964            // https://webidl.spec.whatwg.org/#js-boolean
965            let value = value != 0.0 && !value.is_nan();
966            Ok(Literal::Bool(value))
967        }
968        Scalar::I32 => {
969            // https://webidl.spec.whatwg.org/#js-long
970            if !value.is_finite() {
971                return Err(PipelineConstantError::SrcNeedsToBeFinite);
972            }
973
974            let value = value.trunc();
975            if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
976                return Err(PipelineConstantError::DstRangeTooSmall);
977            }
978
979            let value = value as i32;
980            Ok(Literal::I32(value))
981        }
982        Scalar::U32 => {
983            // https://webidl.spec.whatwg.org/#js-unsigned-long
984            if !value.is_finite() {
985                return Err(PipelineConstantError::SrcNeedsToBeFinite);
986            }
987
988            let value = value.trunc();
989            if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
990                return Err(PipelineConstantError::DstRangeTooSmall);
991            }
992
993            let value = value as u32;
994            Ok(Literal::U32(value))
995        }
996        Scalar::F16 => {
997            // https://webidl.spec.whatwg.org/#js-float
998            if !value.is_finite() {
999                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1000            }
1001
1002            let value = half::f16::from_f64(value);
1003            if !value.is_finite() {
1004                return Err(PipelineConstantError::DstRangeTooSmall);
1005            }
1006
1007            Ok(Literal::F16(value))
1008        }
1009        Scalar::F32 => {
1010            // https://webidl.spec.whatwg.org/#js-float
1011            if !value.is_finite() {
1012                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1013            }
1014
1015            let value = value as f32;
1016            if !value.is_finite() {
1017                return Err(PipelineConstantError::DstRangeTooSmall);
1018            }
1019
1020            Ok(Literal::F32(value))
1021        }
1022        Scalar::F64 => {
1023            // https://webidl.spec.whatwg.org/#js-double
1024            if !value.is_finite() {
1025                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1026            }
1027
1028            Ok(Literal::F64(value))
1029        }
1030        Scalar::ABSTRACT_FLOAT | Scalar::ABSTRACT_INT => {
1031            unreachable!("abstract values should not be validated out of override processing")
1032        }
1033        _ => unreachable!("unrecognized scalar type for override"),
1034    }
1035}
1036
/// Exercise every conversion arm of `map_value_to_literal`: boolean coercion,
/// rejection of non-finite sources, truncation, and destination range limits.
#[test]
fn test_map_value_to_literal() {
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    // Every numeric destination rejects NaN and the infinities.
    for scalar in [
        Scalar::I32,
        Scalar::U32,
        Scalar::F16,
        Scalar::F32,
        Scalar::F64,
    ] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Fractional values truncate toward zero.
    assert_eq!(
        map_value_to_literal(1.9, Scalar::I32),
        Ok(Literal::I32(1))
    );
    assert_eq!(
        map_value_to_literal(-1.9, Scalar::I32),
        Ok(Literal::I32(-1))
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Small negative fractions truncate to (negative) zero, which is in range.
    assert_eq!(
        map_value_to_literal(-0.5, Scalar::U32),
        Ok(Literal::U32(0))
    );

    // f16
    assert_eq!(
        map_value_to_literal(half::f16::MIN.to_f64(), Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(half::f16::MAX.to_f64(), Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    // 2^16 exceeds f16's largest finite magnitude (65504) and overflows to infinity.
    assert_eq!(
        map_value_to_literal(-65536.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(65536.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f32
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}