naga/back/
pipeline_constants.rs

1use alloc::{
2    borrow::Cow,
3    string::{String, ToString},
4};
5use core::mem;
6
7use hashbrown::HashSet;
8use thiserror::Error;
9
10use super::PipelineConstants;
11use crate::{
12    arena::HandleVec,
13    compact::{compact, KeepUnused},
14    ir,
15    proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
16    valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
17    Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
18    Span, Statement, TypeInner, WithSpan,
19};
20
21#[cfg(no_std)]
22use num_traits::float::FloatCore as _;
23
24#[derive(Error, Debug, Clone)]
25#[cfg_attr(test, derive(PartialEq))]
26pub enum PipelineConstantError {
27    #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
28    MissingValue(String),
29    #[error(
30        "Source f64 value needs to be finite ({}) for number destinations",
31        "NaNs and Inifinites are not allowed"
32    )]
33    SrcNeedsToBeFinite,
34    #[error("Source f64 value doesn't fit in destination")]
35    DstRangeTooSmall,
36    #[error(transparent)]
37    ConstantEvaluatorError(#[from] ConstantEvaluatorError),
38    #[error(transparent)]
39    ValidationError(#[from] WithSpan<ValidationError>),
40    #[error("workgroup_size override isn't strictly positive")]
41    NegativeWorkgroupSize,
42}
43
44/// Compact `module` and replace all overrides with constants.
45///
46/// If no changes are needed, this just returns `Cow::Borrowed` references to
47/// `module` and `module_info`. Otherwise, it clones `module`, retains only the
48/// selected entry point, compacts the module, edits its [`global_expressions`]
49/// arena to contain only fully-evaluated expressions, and returns the
50/// simplified module and its validation results.
51///
52/// The module returned has an empty `overrides` arena, and the
53/// `global_expressions` arena contains only fully-evaluated expressions.
54///
55/// [`global_expressions`]: Module::global_expressions
56pub fn process_overrides<'a>(
57    module: &'a Module,
58    module_info: &'a ModuleInfo,
59    entry_point: Option<(ir::ShaderStage, &str)>,
60    pipeline_constants: &PipelineConstants,
61) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
62    if (entry_point.is_none() || module.entry_points.len() <= 1) && module.overrides.is_empty() {
63        // We skip compacting the module here mostly to reduce the risk of
64        // hitting corner cases like https://github.com/gfx-rs/wgpu/issues/7793.
65        // Compaction doesn't cost very much [1], so it would also be reasonable
66        // to do it unconditionally. Even when there is a single entry point or
67        // when no entry point is specified, it is still possible that there
68        // are unreferenced items in the module that would be removed by this
69        // compaction.
70        //
71        // [1]: https://github.com/gfx-rs/wgpu/pull/7703#issuecomment-2902153760
72        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
73    }
74
75    let mut module = module.clone();
76    if let Some((ep_stage, ep_name)) = entry_point {
77        module
78            .entry_points
79            .retain(|ep| ep.stage == ep_stage && ep.name == ep_name);
80    }
81
82    // Compact the module to remove anything not reachable from an entry point.
83    // This is necessary because we may not have values for overrides that are
84    // not reachable from the/an entry point.
85    compact(&mut module, KeepUnused::No);
86
87    // If there are no overrides in the module, then we can skip the rest.
88    if module.overrides.is_empty() {
89        return revalidate(module);
90    }
91
92    // A map from override handles to the handles of the constants
93    // we've replaced them with.
94    let mut override_map = HandleVec::with_capacity(module.overrides.len());
95
96    // A map from `module`'s original global expression handles to
97    // handles in the new, simplified global expression arena.
98    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());
99
100    // The set of constants whose initializer handles we've already
101    // updated to refer to the newly built global expression arena.
102    //
103    // All constants in `module` must have their `init` handles
104    // updated to point into the new, simplified global expression
105    // arena. Some of these we can most easily handle as a side effect
106    // during the simplification process, but we must handle the rest
107    // in a final fixup pass, guided by `adjusted_global_expressions`. We
108    // add their handles to this set, so that the final fixup step can
109    // leave them alone.
110    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());
111
112    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
113    let mut layouter = crate::proc::Layouter::default();
114
115    // An iterator through the original overrides table, consumed in
116    // approximate tandem with the global expressions.
117    let mut overrides = mem::take(&mut module.overrides);
118    let mut override_iter = overrides.iter_mut_span();
119
120    // Do two things in tandem:
121    //
122    // - Rebuild the global expression arena from scratch, fully
123    //   evaluating all expressions, and replacing each `Override`
124    //   expression in `module.global_expressions` with a `Constant`
125    //   expression.
126    //
127    // - Build a new `Constant` in `module.constants` to take the
128    //   place of each `Override`.
129    //
130    // Build a map from old global expression handles to their
131    // fully-evaluated counterparts in `adjusted_global_expressions` as we
132    // go.
133    //
134    // Why in tandem? Overrides refer to expressions, and expressions
135    // refer to overrides, so we can't disentangle the two into
136    // separate phases. However, we can take advantage of the fact
137    // that the overrides and expressions must form a DAG, and work
138    // our way from the leaves to the roots, replacing and evaluating
139    // as we go.
140    //
141    // Although the two loops are nested, this is really two
142    // alternating phases: we adjust and evaluate constant expressions
143    // until we hit an `Override` expression, at which point we switch
144    // to building `Constant`s for `Overrides` until we've handled the
145    // one used by the expression. Then we switch back to processing
146    // expressions. Because we know they form a DAG, we know the
147    // `Override` expressions we encounter can only have initializers
148    // referring to global expressions we've already simplified.
149    for (old_h, expr, span) in module.global_expressions.drain() {
150        let mut expr = match expr {
151            Expression::Override(h) => {
152                let c_h = if let Some(new_h) = override_map.get(h) {
153                    *new_h
154                } else {
155                    let mut new_h = None;
156                    for entry in override_iter.by_ref() {
157                        let stop = entry.0 == h;
158                        new_h = Some(process_override(
159                            entry,
160                            pipeline_constants,
161                            &mut module,
162                            &mut override_map,
163                            &adjusted_global_expressions,
164                            &mut adjusted_constant_initializers,
165                            &mut global_expression_kind_tracker,
166                        )?);
167                        if stop {
168                            break;
169                        }
170                    }
171                    new_h.unwrap()
172                };
173                Expression::Constant(c_h)
174            }
175            Expression::Constant(c_h) => {
176                if adjusted_constant_initializers.insert(c_h) {
177                    let init = &mut module.constants[c_h].init;
178                    *init = adjusted_global_expressions[*init];
179                }
180                expr
181            }
182            expr => expr,
183        };
184        let mut evaluator = ConstantEvaluator::for_wgsl_module(
185            &mut module,
186            &mut global_expression_kind_tracker,
187            &mut layouter,
188            false,
189        );
190        adjust_expr(&adjusted_global_expressions, &mut expr);
191        let h = evaluator.try_eval_and_append(expr, span)?;
192        adjusted_global_expressions.insert(old_h, h);
193    }
194
195    // Finish processing any overrides we didn't visit in the loop above.
196    for entry in override_iter {
197        match *entry.1 {
198            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
199                process_override(
200                    entry,
201                    pipeline_constants,
202                    &mut module,
203                    &mut override_map,
204                    &adjusted_global_expressions,
205                    &mut adjusted_constant_initializers,
206                    &mut global_expression_kind_tracker,
207                )?;
208            }
209            Override {
210                init: Some(ref mut init),
211                ..
212            } => {
213                *init = adjusted_global_expressions[*init];
214            }
215            _ => {}
216        }
217    }
218
219    // Update the initialization expression handles of all `Constant`s
220    // and `GlobalVariable`s. Skip `Constant`s we'd already updated en
221    // passant.
222    for (_, c) in module
223        .constants
224        .iter_mut()
225        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
226    {
227        c.init = adjusted_global_expressions[c.init];
228    }
229
230    for (_, v) in module.global_variables.iter_mut() {
231        if let Some(ref mut init) = v.init {
232            *init = adjusted_global_expressions[*init];
233        }
234    }
235
236    let mut functions = mem::take(&mut module.functions);
237    for (_, function) in functions.iter_mut() {
238        process_function(&mut module, &override_map, &mut layouter, function)?;
239    }
240    module.functions = functions;
241
242    let mut entry_points = mem::take(&mut module.entry_points);
243    for ep in entry_points.iter_mut() {
244        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
245        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
246    }
247    module.entry_points = entry_points;
248    module.overrides = overrides;
249
250    // Now that we've rewritten all the expressions, we need to
251    // recompute their types and other metadata. For the time being,
252    // do a full re-validation.
253    revalidate(module)
254}
255
256fn revalidate(
257    module: Module,
258) -> Result<(Cow<'static, Module>, Cow<'static, ModuleInfo>), PipelineConstantError> {
259    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
260    let module_info = validator.validate_resolved_overrides(&module)?;
261    Ok((Cow::Owned(module), Cow::Owned(module_info)))
262}
263
264fn process_workgroup_size_override(
265    module: &mut Module,
266    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
267    ep: &mut crate::EntryPoint,
268) -> Result<(), PipelineConstantError> {
269    match ep.workgroup_size_overrides {
270        None => {}
271        Some(overrides) => {
272            overrides.iter().enumerate().try_for_each(
273                |(i, overridden)| -> Result<(), PipelineConstantError> {
274                    match *overridden {
275                        None => Ok(()),
276                        Some(h) => {
277                            ep.workgroup_size[i] = module
278                                .to_ctx()
279                                .eval_expr_to_u32(adjusted_global_expressions[h])
280                                .map(|n| {
281                                    if n == 0 {
282                                        Err(PipelineConstantError::NegativeWorkgroupSize)
283                                    } else {
284                                        Ok(n)
285                                    }
286                                })
287                                .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
288                            Ok(())
289                        }
290                    }
291                },
292            )?;
293            ep.workgroup_size_overrides = None;
294        }
295    }
296    Ok(())
297}
298
299/// Add a [`Constant`] to `module` for the override `old_h`.
300///
301/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
302fn process_override(
303    (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
304    pipeline_constants: &PipelineConstants,
305    module: &mut Module,
306    override_map: &mut HandleVec<Override, Handle<Constant>>,
307    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
308    adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
309    global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
310) -> Result<Handle<Constant>, PipelineConstantError> {
311    // Determine which key to use for `r#override` in `pipeline_constants`.
312    let key = if let Some(id) = r#override.id {
313        Cow::Owned(id.to_string())
314    } else if let Some(ref name) = r#override.name {
315        Cow::Borrowed(name)
316    } else {
317        unreachable!();
318    };
319
320    // Generate a global expression for `r#override`'s value, either
321    // from the provided `pipeline_constants` table or its initializer
322    // in the module.
323    let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
324        let literal = match module.types[r#override.ty].inner {
325            TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
326            _ => unreachable!(),
327        };
328        let expr = module
329            .global_expressions
330            .append(Expression::Literal(literal), Span::UNDEFINED);
331        global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
332        expr
333    } else if let Some(init) = r#override.init {
334        adjusted_global_expressions[init]
335    } else {
336        return Err(PipelineConstantError::MissingValue(key.to_string()));
337    };
338
339    // Generate a new `Constant` to represent the override's value.
340    let constant = Constant {
341        name: r#override.name.clone(),
342        ty: r#override.ty,
343        init,
344    };
345    let h = module.constants.append(constant, *span);
346    override_map.insert(old_h, h);
347    adjusted_constant_initializers.insert(h);
348    r#override.init = Some(init);
349    Ok(h)
350}
351
352/// Replace all override expressions in `function` with fully-evaluated constants.
353///
354/// Replace all `Expression::Override`s in `function`'s expression arena with
355/// the corresponding `Expression::Constant`s, as given in `override_map`.
356/// Replace any expressions whose values are now known with their fully
357/// evaluated form.
358///
359/// If `h` is a `Handle<Override>`, then `override_map[h]` is the
360/// `Handle<Constant>` for the override's final value.
361fn process_function(
362    module: &mut Module,
363    override_map: &HandleVec<Override, Handle<Constant>>,
364    layouter: &mut crate::proc::Layouter,
365    function: &mut Function,
366) -> Result<(), ConstantEvaluatorError> {
367    // A map from original local expression handles to
368    // handles in the new, local expression arena.
369    let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
370
371    let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
372
373    let mut expressions = mem::take(&mut function.expressions);
374
375    // Dummy `emitter` and `block` for the constant evaluator.
376    // We can ignore the concept of emitting expressions here since
377    // expressions have already been covered by a `Statement::Emit`
378    // in the frontend.
379    // The only thing we might have to do is remove some expressions
380    // that have been covered by a `Statement::Emit`. See the docs of
381    // `filter_emits_in_block` for the reasoning.
382    let mut emitter = Emitter::default();
383    let mut block = Block::new();
384
385    let mut evaluator = ConstantEvaluator::for_wgsl_function(
386        module,
387        &mut function.expressions,
388        &mut local_expression_kind_tracker,
389        layouter,
390        &mut emitter,
391        &mut block,
392        false,
393    );
394
395    for (old_h, mut expr, span) in expressions.drain() {
396        if let Expression::Override(h) = expr {
397            expr = Expression::Constant(override_map[h]);
398        }
399        adjust_expr(&adjusted_local_expressions, &mut expr);
400        let h = evaluator.try_eval_and_append(expr, span)?;
401        adjusted_local_expressions.insert(old_h, h);
402    }
403
404    adjust_block(&adjusted_local_expressions, &mut function.body);
405
406    filter_emits_in_block(&mut function.body, &function.expressions);
407
408    // Update local expression initializers.
409    for (_, local) in function.local_variables.iter_mut() {
410        if let &mut Some(ref mut init) = &mut local.init {
411            *init = adjusted_local_expressions[*init];
412        }
413    }
414
415    // We've changed the keys of `function.named_expression`, so we have to
416    // rebuild it from scratch.
417    let named_expressions = mem::take(&mut function.named_expressions);
418    for (expr_h, name) in named_expressions {
419        function
420            .named_expressions
421            .insert(adjusted_local_expressions[expr_h], name);
422    }
423
424    Ok(())
425}
426
427/// Replace every expression handle in `expr` with its counterpart
428/// given by `new_pos`.
429fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
430    let adjust = |expr: &mut Handle<Expression>| {
431        *expr = new_pos[*expr];
432    };
433    match *expr {
434        Expression::Compose {
435            ref mut components,
436            ty: _,
437        } => {
438            for c in components.iter_mut() {
439                adjust(c);
440            }
441        }
442        Expression::Access {
443            ref mut base,
444            ref mut index,
445        } => {
446            adjust(base);
447            adjust(index);
448        }
449        Expression::AccessIndex {
450            ref mut base,
451            index: _,
452        } => {
453            adjust(base);
454        }
455        Expression::Splat {
456            ref mut value,
457            size: _,
458        } => {
459            adjust(value);
460        }
461        Expression::Swizzle {
462            ref mut vector,
463            size: _,
464            pattern: _,
465        } => {
466            adjust(vector);
467        }
468        Expression::Load { ref mut pointer } => {
469            adjust(pointer);
470        }
471        Expression::ImageSample {
472            ref mut image,
473            ref mut sampler,
474            ref mut coordinate,
475            ref mut array_index,
476            ref mut offset,
477            ref mut level,
478            ref mut depth_ref,
479            gather: _,
480            clamp_to_edge: _,
481        } => {
482            adjust(image);
483            adjust(sampler);
484            adjust(coordinate);
485            if let Some(e) = array_index.as_mut() {
486                adjust(e);
487            }
488            if let Some(e) = offset.as_mut() {
489                adjust(e);
490            }
491            match *level {
492                crate::SampleLevel::Exact(ref mut expr)
493                | crate::SampleLevel::Bias(ref mut expr) => {
494                    adjust(expr);
495                }
496                crate::SampleLevel::Gradient {
497                    ref mut x,
498                    ref mut y,
499                } => {
500                    adjust(x);
501                    adjust(y);
502                }
503                _ => {}
504            }
505            if let Some(e) = depth_ref.as_mut() {
506                adjust(e);
507            }
508        }
509        Expression::ImageLoad {
510            ref mut image,
511            ref mut coordinate,
512            ref mut array_index,
513            ref mut sample,
514            ref mut level,
515        } => {
516            adjust(image);
517            adjust(coordinate);
518            if let Some(e) = array_index.as_mut() {
519                adjust(e);
520            }
521            if let Some(e) = sample.as_mut() {
522                adjust(e);
523            }
524            if let Some(e) = level.as_mut() {
525                adjust(e);
526            }
527        }
528        Expression::ImageQuery {
529            ref mut image,
530            ref mut query,
531        } => {
532            adjust(image);
533            match *query {
534                crate::ImageQuery::Size { ref mut level } => {
535                    if let Some(e) = level.as_mut() {
536                        adjust(e);
537                    }
538                }
539                crate::ImageQuery::NumLevels
540                | crate::ImageQuery::NumLayers
541                | crate::ImageQuery::NumSamples => {}
542            }
543        }
544        Expression::Unary {
545            ref mut expr,
546            op: _,
547        } => {
548            adjust(expr);
549        }
550        Expression::Binary {
551            ref mut left,
552            ref mut right,
553            op: _,
554        } => {
555            adjust(left);
556            adjust(right);
557        }
558        Expression::Select {
559            ref mut condition,
560            ref mut accept,
561            ref mut reject,
562        } => {
563            adjust(condition);
564            adjust(accept);
565            adjust(reject);
566        }
567        Expression::Derivative {
568            ref mut expr,
569            axis: _,
570            ctrl: _,
571        } => {
572            adjust(expr);
573        }
574        Expression::Relational {
575            ref mut argument,
576            fun: _,
577        } => {
578            adjust(argument);
579        }
580        Expression::Math {
581            ref mut arg,
582            ref mut arg1,
583            ref mut arg2,
584            ref mut arg3,
585            fun: _,
586        } => {
587            adjust(arg);
588            if let Some(e) = arg1.as_mut() {
589                adjust(e);
590            }
591            if let Some(e) = arg2.as_mut() {
592                adjust(e);
593            }
594            if let Some(e) = arg3.as_mut() {
595                adjust(e);
596            }
597        }
598        Expression::As {
599            ref mut expr,
600            kind: _,
601            convert: _,
602        } => {
603            adjust(expr);
604        }
605        Expression::ArrayLength(ref mut expr) => {
606            adjust(expr);
607        }
608        Expression::RayQueryGetIntersection {
609            ref mut query,
610            committed: _,
611        } => {
612            adjust(query);
613        }
614        Expression::Literal(_)
615        | Expression::FunctionArgument(_)
616        | Expression::GlobalVariable(_)
617        | Expression::LocalVariable(_)
618        | Expression::CallResult(_)
619        | Expression::RayQueryProceedResult
620        | Expression::Constant(_)
621        | Expression::Override(_)
622        | Expression::ZeroValue(_)
623        | Expression::AtomicResult {
624            ty: _,
625            comparison: _,
626        }
627        | Expression::WorkGroupUniformLoadResult { ty: _ }
628        | Expression::SubgroupBallotResult
629        | Expression::SubgroupOperationResult { .. } => {}
630        Expression::RayQueryVertexPositions {
631            ref mut query,
632            committed: _,
633        } => {
634            adjust(query);
635        }
636    }
637}
638
639/// Replace every expression handle in `block` with its counterpart
640/// given by `new_pos`.
641fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
642    for stmt in block.iter_mut() {
643        adjust_stmt(new_pos, stmt);
644    }
645}
646
647/// Replace every expression handle in `stmt` with its counterpart
648/// given by `new_pos`.
649fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
650    let adjust = |expr: &mut Handle<Expression>| {
651        *expr = new_pos[*expr];
652    };
653    match *stmt {
654        Statement::Emit(ref mut range) => {
655            if let Some((mut first, mut last)) = range.first_and_last() {
656                adjust(&mut first);
657                adjust(&mut last);
658                *range = Range::new_from_bounds(first, last);
659            }
660        }
661        Statement::Block(ref mut block) => {
662            adjust_block(new_pos, block);
663        }
664        Statement::If {
665            ref mut condition,
666            ref mut accept,
667            ref mut reject,
668        } => {
669            adjust(condition);
670            adjust_block(new_pos, accept);
671            adjust_block(new_pos, reject);
672        }
673        Statement::Switch {
674            ref mut selector,
675            ref mut cases,
676        } => {
677            adjust(selector);
678            for case in cases.iter_mut() {
679                adjust_block(new_pos, &mut case.body);
680            }
681        }
682        Statement::Loop {
683            ref mut body,
684            ref mut continuing,
685            ref mut break_if,
686        } => {
687            adjust_block(new_pos, body);
688            adjust_block(new_pos, continuing);
689            if let Some(e) = break_if.as_mut() {
690                adjust(e);
691            }
692        }
693        Statement::Return { ref mut value } => {
694            if let Some(e) = value.as_mut() {
695                adjust(e);
696            }
697        }
698        Statement::Store {
699            ref mut pointer,
700            ref mut value,
701        } => {
702            adjust(pointer);
703            adjust(value);
704        }
705        Statement::ImageStore {
706            ref mut image,
707            ref mut coordinate,
708            ref mut array_index,
709            ref mut value,
710        } => {
711            adjust(image);
712            adjust(coordinate);
713            if let Some(e) = array_index.as_mut() {
714                adjust(e);
715            }
716            adjust(value);
717        }
718        Statement::Atomic {
719            ref mut pointer,
720            ref mut value,
721            ref mut result,
722            ref mut fun,
723        } => {
724            adjust(pointer);
725            adjust(value);
726            if let Some(ref mut result) = *result {
727                adjust(result);
728            }
729            match *fun {
730                crate::AtomicFunction::Exchange {
731                    compare: Some(ref mut compare),
732                } => {
733                    adjust(compare);
734                }
735                crate::AtomicFunction::Add
736                | crate::AtomicFunction::Subtract
737                | crate::AtomicFunction::And
738                | crate::AtomicFunction::ExclusiveOr
739                | crate::AtomicFunction::InclusiveOr
740                | crate::AtomicFunction::Min
741                | crate::AtomicFunction::Max
742                | crate::AtomicFunction::Exchange { compare: None } => {}
743            }
744        }
745        Statement::ImageAtomic {
746            ref mut image,
747            ref mut coordinate,
748            ref mut array_index,
749            fun: _,
750            ref mut value,
751        } => {
752            adjust(image);
753            adjust(coordinate);
754            if let Some(ref mut array_index) = *array_index {
755                adjust(array_index);
756            }
757            adjust(value);
758        }
759        Statement::WorkGroupUniformLoad {
760            ref mut pointer,
761            ref mut result,
762        } => {
763            adjust(pointer);
764            adjust(result);
765        }
766        Statement::SubgroupBallot {
767            ref mut result,
768            ref mut predicate,
769        } => {
770            if let Some(ref mut predicate) = *predicate {
771                adjust(predicate);
772            }
773            adjust(result);
774        }
775        Statement::SubgroupCollectiveOperation {
776            ref mut argument,
777            ref mut result,
778            ..
779        } => {
780            adjust(argument);
781            adjust(result);
782        }
783        Statement::SubgroupGather {
784            ref mut mode,
785            ref mut argument,
786            ref mut result,
787        } => {
788            match *mode {
789                crate::GatherMode::BroadcastFirst => {}
790                crate::GatherMode::Broadcast(ref mut index)
791                | crate::GatherMode::Shuffle(ref mut index)
792                | crate::GatherMode::ShuffleDown(ref mut index)
793                | crate::GatherMode::ShuffleUp(ref mut index)
794                | crate::GatherMode::ShuffleXor(ref mut index)
795                | crate::GatherMode::QuadBroadcast(ref mut index) => {
796                    adjust(index);
797                }
798                crate::GatherMode::QuadSwap(_) => {}
799            }
800            adjust(argument);
801            adjust(result)
802        }
803        Statement::Call {
804            ref mut arguments,
805            ref mut result,
806            function: _,
807        } => {
808            for argument in arguments.iter_mut() {
809                adjust(argument);
810            }
811            if let Some(e) = result.as_mut() {
812                adjust(e);
813            }
814        }
815        Statement::RayQuery {
816            ref mut query,
817            ref mut fun,
818        } => {
819            adjust(query);
820            match *fun {
821                crate::RayQueryFunction::Initialize {
822                    ref mut acceleration_structure,
823                    ref mut descriptor,
824                } => {
825                    adjust(acceleration_structure);
826                    adjust(descriptor);
827                }
828                crate::RayQueryFunction::Proceed { ref mut result } => {
829                    adjust(result);
830                }
831                crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
832                    adjust(hit_t);
833                }
834                crate::RayQueryFunction::ConfirmIntersection => {}
835                crate::RayQueryFunction::Terminate => {}
836            }
837        }
838        Statement::Break
839        | Statement::Continue
840        | Statement::Kill
841        | Statement::ControlBarrier(_)
842        | Statement::MemoryBarrier(_) => {}
843    }
844}
845
846/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced.
847///
848/// According to validation, [`Emit`] statements must not cover any expressions
849/// for which [`Expression::needs_pre_emit`] returns true. All expressions built
850/// by successful constant evaluation fall into that category, meaning that
851/// `process_function` will usually rewrite [`Override`] expressions and those
852/// that use their values into pre-emitted expressions, leaving any [`Emit`]
853/// statements that cover them invalid.
854///
855/// This function rewrites all [`Emit`] statements into zero or more new
856/// [`Emit`] statements covering only those expressions in the original range
857/// that are not pre-emitted.
858///
859/// [`Emit`]: Statement::Emit
860/// [`needs_pre_emit`]: Expression::needs_pre_emit
861/// [`Override`]: Expression::Override
862fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
863    let original = mem::replace(block, Block::with_capacity(block.len()));
864    for (stmt, span) in original.span_into_iter() {
865        match stmt {
866            Statement::Emit(range) => {
867                let mut current = None;
868                for expr_h in range {
869                    if expressions[expr_h].needs_pre_emit() {
870                        if let Some((first, last)) = current {
871                            block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
872                        }
873
874                        current = None;
875                    } else if let Some((_, ref mut last)) = current {
876                        *last = expr_h;
877                    } else {
878                        current = Some((expr_h, expr_h));
879                    }
880                }
881                if let Some((first, last)) = current {
882                    block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
883                }
884            }
885            Statement::Block(mut child) => {
886                filter_emits_in_block(&mut child, expressions);
887                block.push(Statement::Block(child), span);
888            }
889            Statement::If {
890                condition,
891                mut accept,
892                mut reject,
893            } => {
894                filter_emits_in_block(&mut accept, expressions);
895                filter_emits_in_block(&mut reject, expressions);
896                block.push(
897                    Statement::If {
898                        condition,
899                        accept,
900                        reject,
901                    },
902                    span,
903                );
904            }
905            Statement::Switch {
906                selector,
907                mut cases,
908            } => {
909                for case in &mut cases {
910                    filter_emits_in_block(&mut case.body, expressions);
911                }
912                block.push(Statement::Switch { selector, cases }, span);
913            }
914            Statement::Loop {
915                mut body,
916                mut continuing,
917                break_if,
918            } => {
919                filter_emits_in_block(&mut body, expressions);
920                filter_emits_in_block(&mut continuing, expressions);
921                block.push(
922                    Statement::Loop {
923                        body,
924                        continuing,
925                        break_if,
926                    },
927                    span,
928                );
929            }
930            stmt => block.push(stmt.clone(), span),
931        }
932    }
933}
934
935fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
936    // note that in rust 0.0 == -0.0
937    match scalar {
938        Scalar::BOOL => {
939            // https://webidl.spec.whatwg.org/#js-boolean
940            let value = value != 0.0 && !value.is_nan();
941            Ok(Literal::Bool(value))
942        }
943        Scalar::I32 => {
944            // https://webidl.spec.whatwg.org/#js-long
945            if !value.is_finite() {
946                return Err(PipelineConstantError::SrcNeedsToBeFinite);
947            }
948
949            let value = value.trunc();
950            if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
951                return Err(PipelineConstantError::DstRangeTooSmall);
952            }
953
954            let value = value as i32;
955            Ok(Literal::I32(value))
956        }
957        Scalar::U32 => {
958            // https://webidl.spec.whatwg.org/#js-unsigned-long
959            if !value.is_finite() {
960                return Err(PipelineConstantError::SrcNeedsToBeFinite);
961            }
962
963            let value = value.trunc();
964            if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
965                return Err(PipelineConstantError::DstRangeTooSmall);
966            }
967
968            let value = value as u32;
969            Ok(Literal::U32(value))
970        }
971        Scalar::F16 => {
972            // https://webidl.spec.whatwg.org/#js-float
973            if !value.is_finite() {
974                return Err(PipelineConstantError::SrcNeedsToBeFinite);
975            }
976
977            let value = half::f16::from_f64(value);
978            if !value.is_finite() {
979                return Err(PipelineConstantError::DstRangeTooSmall);
980            }
981
982            Ok(Literal::F16(value))
983        }
984        Scalar::F32 => {
985            // https://webidl.spec.whatwg.org/#js-float
986            if !value.is_finite() {
987                return Err(PipelineConstantError::SrcNeedsToBeFinite);
988            }
989
990            let value = value as f32;
991            if !value.is_finite() {
992                return Err(PipelineConstantError::DstRangeTooSmall);
993            }
994
995            Ok(Literal::F32(value))
996        }
997        Scalar::F64 => {
998            // https://webidl.spec.whatwg.org/#js-double
999            if !value.is_finite() {
1000                return Err(PipelineConstantError::SrcNeedsToBeFinite);
1001            }
1002
1003            Ok(Literal::F64(value))
1004        }
1005        Scalar::ABSTRACT_FLOAT | Scalar::ABSTRACT_INT => {
1006            unreachable!("abstract values should not be validated out of override processing")
1007        }
1008        _ => unreachable!("unrecognized scalar type for override"),
1009    }
1010}
1011
#[test]
fn test_map_value_to_literal() {
    // bool: only the zeros and NaN are falsy.
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    // Every numeric destination rejects non-finite sources.
    for scalar in [
        Scalar::I32,
        Scalar::U32,
        Scalar::F16,
        Scalar::F32,
        Scalar::F64,
    ] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f16
    // `f16::MAX` is 65504 and the midpoint to the next (overflowing) step is
    // 65520, which rounds to even and therefore becomes infinite. 65519.0 and
    // 65520.0 are exactly representable in `f32`, so these expectations hold
    // even if the conversion rounds through `f32` first.
    assert_eq!(
        map_value_to_literal(-65504.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(65504.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    assert_eq!(
        map_value_to_literal(-65519.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MIN))
    );
    assert_eq!(
        map_value_to_literal(65519.0, Scalar::F16),
        Ok(Literal::F16(half::f16::MAX))
    );
    assert_eq!(
        map_value_to_literal(-65520.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(65520.0, Scalar::F16),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f32
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // 0x47efffffefffffff is the largest f64 below the midpoint between
    // `f32::MAX` and 2^128, so it still rounds to `f32::MAX`.
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    // 0x47effffff0000000 is exactly that midpoint; ties-to-even rounds it up
    // to 2^128, which is infinite in f32.
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}