naga/compact/
mod.rs

1mod expressions;
2mod functions;
3mod handle_set_map;
4mod statements;
5mod types;
6
7use alloc::vec::Vec;
8
9use crate::{
10    arena::{self, HandleSet},
11    compact::functions::FunctionTracer,
12    ir,
13};
14use handle_set_map::HandleMap;
15
16#[cfg(test)]
17use alloc::{format, string::ToString};
18
/// Configuration option for [`compact`]. See [`compact`] for details.
///
/// Decides whether named-but-unreferenced items survive compaction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeepUnused {
    /// Functions, global variables, named types and named overrides are
    /// removed when nothing uses them, like everything else.
    No,
    /// Functions, global variables, named types and named overrides are
    /// always treated as used, so they are retained.
    Yes,
}
25
26impl From<KeepUnused> for bool {
27    fn from(keep_unused: KeepUnused) -> Self {
28        match keep_unused {
29            KeepUnused::No => false,
30            KeepUnused::Yes => true,
31        }
32    }
33}
34
35/// Remove most unused objects from `module`, which must be valid.
36///
37/// Always removes the following unused objects:
38/// - anonymous types, overrides, and constants
39/// - abstract-typed constants
40/// - expressions
41///
42/// If `keep_unused` is `Yes`, the following are never considered unused,
43/// otherwise, they will also be removed if unused:
44/// - functions
45/// - global variables
46/// - named types and overrides
47///
48/// The following are never removed:
49/// - named constants with a concrete type
50/// - special types
51/// - entry points
52/// - within an entry point or a used function:
53///     - arguments
54///     - local variables
55///     - named expressions
56///
57/// After removing items according to the rules above, all handles in the
58/// remaining objects are adjusted as necessary. When `KeepUnused` is `Yes`, the
59/// resulting module should have all the named objects (except abstract-typed
60/// constants) present in the original, and those objects should be functionally
61/// identical. When `KeepUnused` is `No`, the resulting module should have the
62/// entry points present in the original, and those entry points should be
63/// functionally identical.
64///
65/// # Panics
66///
67/// If `module` would not pass validation, this may panic.
68pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) {
69    // The trickiest part of compaction is determining what is used and what is
70    // not. Once we have computed that correctly, it's easy enough to call
71    // `retain_mut` on each arena, drop unused elements, and fix up the handles
72    // in what's left.
73    //
74    // For every compactable arena in a `Module`, whether global to the `Module`
75    // or local to a function or entry point, the `ModuleTracer` type holds a
76    // bitmap indicating which elements of that arena are used. Our task is to
77    // populate those bitmaps correctly.
78    //
79    // First, we mark everything that is considered used by definition, as
80    // described in this function's documentation.
81    //
82    // Since functions and entry points are considered used by definition, we
83    // traverse their statement trees, and mark the referents of all handles
84    // appearing in those statements as used.
85    //
86    // Once we've marked which elements of an arena are referred to directly by
87    // handles elsewhere (for example, which of a function's expressions are
88    // referred to by handles in its body statements), we can mark all the other
89    // arena elements that are used indirectly in a single pass, traversing the
90    // arena from back to front. Since Naga allows arena elements to refer only
91    // to prior elements, we know that by the time we reach an element, all
92    // other elements that could possibly refer to it have already been visited.
93    // Thus, if the present element has not been marked as used, then it is
94    // definitely unused, and compaction can remove it. Otherwise, the element
95    // is used and must be retained, so we must mark everything it refers to.
96    //
97    // The final step is to mark the global expressions and types, which must be
98    // traversed simultaneously; see `ModuleTracer::type_expression_tandem`'s
99    // documentation for details.
100    //
101    // # A definition and a rule of thumb
102    //
103    // In this module, to "trace" something is to mark everything else it refers
104    // to as used, on the assumption that the thing itself is used. For example,
105    // to trace an `Expression` is to mark its subexpressions as used, as well
106    // as any types, constants, overrides, etc. that it refers to. This is what
107    // `ExpressionTracer::trace_expression` does.
108    //
109    // Given that we we want to visit each thing only once (to keep compaction
110    // linear in the size of the module), this definition of "trace" implies
111    // that things that are not "used by definition" must be marked as used
112    // *before* we trace them.
113    //
114    // Thus, whenever you are marking something as used, it's a good idea to ask
115    // yourself how you know that thing will be traced in the future. If you're
116    // not sure, then you could be marking it too late to be noticed. The thing
117    // itself will be retained by compaction, but since it will not be traced,
118    // anything it refers to could be compacted away.
119    let mut module_tracer = ModuleTracer::new(module);
120
121    // Observe what each entry point actually uses.
122    log::trace!("tracing entry points");
123    let entry_point_maps = module
124        .entry_points
125        .iter()
126        .map(|e| {
127            log::trace!("tracing entry point {:?}", e.function.name);
128
129            if let Some(sizes) = e.workgroup_size_overrides {
130                for size in sizes.iter().filter_map(|x| *x) {
131                    module_tracer.global_expressions_used.insert(size);
132                }
133            }
134
135            let mut used = module_tracer.as_function(&e.function);
136            used.trace();
137            FunctionMap::from(used)
138        })
139        .collect::<Vec<_>>();
140
141    // Observe which types, constant expressions, constants, and expressions
142    // each function uses, and produce maps for each function from
143    // pre-compaction to post-compaction expression handles.
144    //
145    // The function tracing logic here works in conjunction with
146    // `FunctionTracer::trace_call`, which, when tracing a `Statement::Call`
147    // to a function not already identified as used, adds the called function
148    // to both `functions_used` and `functions_pending`.
149    //
150    // Called functions are required to appear before their callers in the
151    // functions arena (recursion is disallowed). We have already traced the
152    // entry point(s) and added any functions called directly by the entry
153    // point(s) to `functions_pending`. We proceed by repeatedly tracing the
154    // last function in `functions_pending`. By an inductive argument, any
155    // functions after the last function in `functions_pending` must be unused.
156    //
157    // When `KeepUnused` is active, we simply mark all functions as pending,
158    // and then trace all of them.
159    log::trace!("tracing functions");
160    let mut function_maps = HandleMap::with_capacity(module.functions.len());
161    if keep_unused.into() {
162        module_tracer.functions_used.add_all();
163        module_tracer.functions_pending.add_all();
164    }
165    while let Some(handle) = module_tracer.functions_pending.pop() {
166        let function = &module.functions[handle];
167        log::trace!("tracing function {function:?}");
168        let mut function_tracer = module_tracer.as_function(function);
169        function_tracer.trace();
170        function_maps.insert(handle, FunctionMap::from(function_tracer));
171    }
172
173    // We treat all special types as used by definition.
174    log::trace!("tracing special types");
175    module_tracer.trace_special_types(&module.special_types);
176
177    log::trace!("tracing global variables");
178    if keep_unused.into() {
179        module_tracer.global_variables_used.add_all();
180    }
181    for global in module_tracer.global_variables_used.iter() {
182        log::trace!("tracing global {:?}", module.global_variables[global].name);
183        module_tracer
184            .types_used
185            .insert(module.global_variables[global].ty);
186        if let Some(init) = module.global_variables[global].init {
187            module_tracer.global_expressions_used.insert(init);
188        }
189    }
190
191    // We treat all named constants as used by definition, unless they have an
192    // abstract type as we do not want those reaching the validator.
193    log::trace!("tracing named constants");
194    for (handle, constant) in module.constants.iter() {
195        if constant.name.is_none() || module.types[constant.ty].inner.is_abstract(&module.types) {
196            continue;
197        }
198
199        log::trace!("tracing constant {:?}", constant.name.as_ref().unwrap());
200        module_tracer.constants_used.insert(handle);
201        module_tracer.types_used.insert(constant.ty);
202        module_tracer.global_expressions_used.insert(constant.init);
203    }
204
205    if keep_unused.into() {
206        // Treat all named overrides as used.
207        for (handle, r#override) in module.overrides.iter() {
208            if r#override.name.is_some() && module_tracer.overrides_used.insert(handle) {
209                module_tracer.types_used.insert(r#override.ty);
210                if let Some(init) = r#override.init {
211                    module_tracer.global_expressions_used.insert(init);
212                }
213            }
214        }
215
216        // Treat all named types as used.
217        for (handle, ty) in module.types.iter() {
218            if ty.name.is_some() {
219                module_tracer.types_used.insert(handle);
220            }
221        }
222    }
223
224    for entry in &module.entry_points {
225        if let Some(task_payload) = entry.task_payload {
226            module_tracer.global_variables_used.insert(task_payload);
227        }
228        if let Some(ref mesh_info) = entry.mesh_info {
229            module_tracer
230                .types_used
231                .insert(mesh_info.vertex_output_type);
232            module_tracer
233                .types_used
234                .insert(mesh_info.primitive_output_type);
235            if let Some(max_vertices_override) = mesh_info.max_vertices_override {
236                module_tracer
237                    .global_expressions_used
238                    .insert(max_vertices_override);
239            }
240            if let Some(max_primitives_override) = mesh_info.max_primitives_override {
241                module_tracer
242                    .global_expressions_used
243                    .insert(max_primitives_override);
244            }
245        }
246        if entry.stage == crate::ShaderStage::Task || entry.stage == crate::ShaderStage::Mesh {
247            // u32 should always be there if the module is valid, as it is e.g. the type of some expressions
248            let u32_type = module
249                .types
250                .iter()
251                .find_map(|tuple| {
252                    if tuple.1.inner == crate::TypeInner::Scalar(crate::Scalar::U32) {
253                        Some(tuple.0)
254                    } else {
255                        None
256                    }
257                })
258                .unwrap();
259            module_tracer.types_used.insert(u32_type);
260        }
261    }
262
263    module_tracer.type_expression_tandem();
264
265    // Now that we know what is used and what is never touched,
266    // produce maps from the `Handle`s that appear in `module` now to
267    // the corresponding `Handle`s that will refer to the same items
268    // in the compacted module.
269    let module_map = ModuleMap::from(module_tracer);
270
271    // Drop unused types from the type arena.
272    //
273    // `FastIndexSet`s don't have an underlying Vec<T> that we can
274    // steal, compact in place, and then rebuild the `FastIndexSet`
275    // from. So we have to rebuild the type arena from scratch.
276    log::trace!("compacting types");
277    let mut new_types = arena::UniqueArena::new();
278    for (old_handle, mut ty, span) in module.types.drain_all() {
279        if let Some(expected_new_handle) = module_map.types.try_adjust(old_handle) {
280            module_map.adjust_type(&mut ty);
281            let actual_new_handle = new_types.insert(ty, span);
282            assert_eq!(actual_new_handle, expected_new_handle);
283        }
284    }
285    module.types = new_types;
286    log::trace!("adjusting special types");
287    module_map.adjust_special_types(&mut module.special_types);
288
289    // Drop unused constant expressions, reusing existing storage.
290    log::trace!("adjusting constant expressions");
291    module.global_expressions.retain_mut(|handle, expr| {
292        if module_map.global_expressions.used(handle) {
293            module_map.adjust_expression(expr, &module_map.global_expressions);
294            true
295        } else {
296            false
297        }
298    });
299
300    // Drop unused constants in place, reusing existing storage.
301    log::trace!("adjusting constants");
302    module.constants.retain_mut(|handle, constant| {
303        if module_map.constants.used(handle) {
304            module_map.types.adjust(&mut constant.ty);
305            module_map.global_expressions.adjust(&mut constant.init);
306            true
307        } else {
308            false
309        }
310    });
311
312    // Drop unused overrides in place, reusing existing storage.
313    log::trace!("adjusting overrides");
314    module.overrides.retain_mut(|handle, r#override| {
315        if module_map.overrides.used(handle) {
316            module_map.types.adjust(&mut r#override.ty);
317            if let Some(ref mut init) = r#override.init {
318                module_map.global_expressions.adjust(init);
319            }
320            true
321        } else {
322            false
323        }
324    });
325
326    // Adjust workgroup_size_overrides
327    log::trace!("adjusting workgroup_size_overrides");
328    for e in module.entry_points.iter_mut() {
329        if let Some(sizes) = e.workgroup_size_overrides.as_mut() {
330            for size in sizes.iter_mut() {
331                if let Some(expr) = size.as_mut() {
332                    module_map.global_expressions.adjust(expr);
333                }
334            }
335        }
336    }
337
338    // Drop unused global variables, reusing existing storage.
339    // Adjust used global variables' types and initializers.
340    log::trace!("adjusting global variables");
341    module.global_variables.retain_mut(|handle, global| {
342        if module_map.globals.used(handle) {
343            log::trace!("retaining global variable {:?}", global.name);
344            module_map.types.adjust(&mut global.ty);
345            if let Some(ref mut init) = global.init {
346                module_map.global_expressions.adjust(init);
347            }
348            true
349        } else {
350            log::trace!("dropping global variable {:?}", global.name);
351            false
352        }
353    });
354
355    // Adjust doc comments
356    if let Some(ref mut doc_comments) = module.doc_comments {
357        module_map.adjust_doc_comments(doc_comments.as_mut());
358    }
359
360    // Temporary storage to help us reuse allocations of existing
361    // named expression tables.
362    let mut reused_named_expressions = crate::NamedExpressions::default();
363
364    // Drop unused functions. Compact and adjust used functions.
365    module.functions.retain_mut(|handle, function| {
366        if let Some(map) = function_maps.get(handle) {
367            log::trace!("retaining and compacting function {:?}", function.name);
368            map.compact(function, &module_map, &mut reused_named_expressions);
369            true
370        } else {
371            log::trace!("dropping function {:?}", function.name);
372            false
373        }
374    });
375
376    // Compact each entry point.
377    for (entry, map) in module.entry_points.iter_mut().zip(entry_point_maps.iter()) {
378        log::trace!("compacting entry point {:?}", entry.function.name);
379        map.compact(
380            &mut entry.function,
381            &module_map,
382            &mut reused_named_expressions,
383        );
384        if let Some(ref mut task_payload) = entry.task_payload {
385            module_map.globals.adjust(task_payload);
386        }
387        if let Some(ref mut mesh_info) = entry.mesh_info {
388            module_map.types.adjust(&mut mesh_info.vertex_output_type);
389            module_map
390                .types
391                .adjust(&mut mesh_info.primitive_output_type);
392            if let Some(ref mut max_vertices_override) = mesh_info.max_vertices_override {
393                module_map.global_expressions.adjust(max_vertices_override);
394            }
395            if let Some(ref mut max_primitives_override) = mesh_info.max_primitives_override {
396                module_map
397                    .global_expressions
398                    .adjust(max_primitives_override);
399            }
400        }
401    }
402}
403
404struct ModuleTracer<'module> {
405    module: &'module crate::Module,
406
407    /// The subset of functions in `functions_used` that have not yet been
408    /// traced.
409    functions_pending: HandleSet<crate::Function>,
410
411    functions_used: HandleSet<crate::Function>,
412    types_used: HandleSet<crate::Type>,
413    global_variables_used: HandleSet<crate::GlobalVariable>,
414    constants_used: HandleSet<crate::Constant>,
415    overrides_used: HandleSet<crate::Override>,
416    global_expressions_used: HandleSet<crate::Expression>,
417}
418
impl<'module> ModuleTracer<'module> {
    /// Create a tracer for `module`, with one `HandleSet` per compactable
    /// arena, each built by `HandleSet::for_arena` from that arena.
    fn new(module: &'module crate::Module) -> Self {
        Self {
            module,
            functions_pending: HandleSet::for_arena(&module.functions),
            functions_used: HandleSet::for_arena(&module.functions),
            types_used: HandleSet::for_arena(&module.types),
            global_variables_used: HandleSet::for_arena(&module.global_variables),
            constants_used: HandleSet::for_arena(&module.constants),
            overrides_used: HandleSet::for_arena(&module.overrides),
            global_expressions_used: HandleSet::for_arena(&module.global_expressions),
        }
    }

    /// Mark every type referenced by `special_types` as used.
    ///
    /// Special types are treated as used by definition, whether or not
    /// anything in the IR refers to them. The exhaustive destructuring below
    /// makes the compiler flag this function when a new special type is added.
    fn trace_special_types(&mut self, special_types: &crate::SpecialTypes) {
        let crate::SpecialTypes {
            ref ray_desc,
            ref ray_intersection,
            ref ray_vertex_return,
            ref predeclared_types,
            ref external_texture_params,
            ref external_texture_transfer_function,
        } = *special_types;

        if let Some(ray_desc) = *ray_desc {
            self.types_used.insert(ray_desc);
        }
        if let Some(ray_intersection) = *ray_intersection {
            self.types_used.insert(ray_intersection);
        }
        if let Some(ray_vertex_return) = *ray_vertex_return {
            self.types_used.insert(ray_vertex_return);
        }
        // The `external_texture_params` type is generated purely as a
        // convenience to the backends. While it will never actually be used in
        // the IR, it must be marked as used so that it survives compaction.
        if let Some(external_texture_params) = *external_texture_params {
            self.types_used.insert(external_texture_params);
        }
        if let Some(external_texture_transfer_function) = *external_texture_transfer_function {
            self.types_used.insert(external_texture_transfer_function);
        }
        for (_, &handle) in predeclared_types {
            self.types_used.insert(handle);
        }
    }

    /// Traverse types and global expressions in tandem to determine which are used.
    ///
    /// Assuming that all types and global expressions used by other parts of
    /// the module have been added to [`types_used`] and
    /// [`global_expressions_used`], expand those sets to include all types and
    /// global expressions reachable from those.
    ///
    /// [`types_used`]: ModuleTracer::types_used
    /// [`global_expressions_used`]: ModuleTracer::global_expressions_used
    fn type_expression_tandem(&mut self) {
        // For each type T, compute the latest global expression E that T and
        // its predecessors refer to. Given the ordering rules on types and
        // global expressions in valid modules, we can do this with a single
        // forward scan of the type arena. The rules further imply that T can
        // only be referred to by expressions after E.
        //
        // Only arrays and binding arrays whose size is a pending override
        // contribute a dependency (that override's initializer); every other
        // type contributes `None`. Since `None < Some(_)`, `max_dep[i]` stays
        // `None` until the first such type is seen.
        let mut max_dep = Vec::with_capacity(self.module.types.len());
        let mut previous = None;
        for (_handle, ty) in self.module.types.iter() {
            previous = core::cmp::max(
                previous,
                match ty.inner {
                    crate::TypeInner::Array { size, .. }
                    | crate::TypeInner::BindingArray { size, .. } => match size {
                        crate::ArraySize::Constant(_) | crate::ArraySize::Dynamic => None,
                        crate::ArraySize::Pending(handle) => self.module.overrides[handle].init,
                    },
                    _ => None,
                },
            );
            max_dep.push(previous);
        }

        // Visit types and global expressions from youngest to oldest.
        //
        // The outer loop visits types. Before visiting each type, the inner
        // loop ensures that all global expressions that could possibly refer to
        // it have been visited. And since the inner loop stops at the latest
        // expression that the type could possibly refer to, we know that we
        // have previously visited any types that might refer to each expression
        // we visit.
        //
        // This lets us assume that any type or expression that is *not* marked
        // as used by the time we visit it is genuinely unused, and can be
        // ignored.
        let mut exprs = self.module.global_expressions.iter().rev().peekable();

        for ((ty_handle, ty), dep) in self.module.types.iter().zip(max_dep).rev() {
            // Consume (newest first) every expression strictly newer than
            // `dep`; expressions at or before `dep` stay queued for older types.
            while let Some((expr_handle, expr)) = exprs.next_if(|&(h, _)| Some(h) > dep) {
                if self.global_expressions_used.contains(expr_handle) {
                    self.as_const_expression().trace_expression(expr);
                }
            }
            if self.types_used.contains(ty_handle) {
                self.as_type().trace_type(ty);
            }
        }
        // Visit any remaining expressions.
        for (expr_handle, expr) in exprs {
            if self.global_expressions_used.contains(expr_handle) {
                self.as_const_expression().trace_expression(expr);
            }
        }
    }

    /// Borrow this tracer's module-level state as a [`types::TypeTracer`],
    /// which records what a traced type refers to in `types_used`,
    /// `global_expressions_used` and `overrides_used`.
    fn as_type(&mut self) -> types::TypeTracer<'_> {
        types::TypeTracer {
            overrides: &self.module.overrides,
            types_used: &mut self.types_used,
            expressions_used: &mut self.global_expressions_used,
            overrides_used: &mut self.overrides_used,
        }
    }

    /// Borrow this tracer's state as an [`expressions::ExpressionTracer`] over
    /// the module's `global_expressions` arena.
    fn as_const_expression(&mut self) -> expressions::ExpressionTracer<'_> {
        expressions::ExpressionTracer {
            constants: &self.module.constants,
            overrides: &self.module.overrides,
            expressions: &self.module.global_expressions,
            types_used: &mut self.types_used,
            global_variables_used: &mut self.global_variables_used,
            constants_used: &mut self.constants_used,
            // The arena being traced *is* the global expression arena, tracked
            // through `expressions_used` above.
            // NOTE(review): presumably `None` tells the tracer there is no
            // separate global arena to mark; confirm against `ExpressionTracer`.
            expressions_used: &mut self.global_expressions_used,
            overrides_used: &mut self.overrides_used,
            global_expressions_used: None,
        }
    }

    /// Borrow this tracer's module-level state as a [`FunctionTracer`] for
    /// `function`, with a fresh expression set built for `function.expressions`.
    pub fn as_function<'tracer>(
        &'tracer mut self,
        function: &'tracer crate::Function,
    ) -> FunctionTracer<'tracer> {
        FunctionTracer {
            function,
            constants: &self.module.constants,
            overrides: &self.module.overrides,
            functions_pending: &mut self.functions_pending,
            functions_used: &mut self.functions_used,
            types_used: &mut self.types_used,
            global_variables_used: &mut self.global_variables_used,
            constants_used: &mut self.constants_used,
            overrides_used: &mut self.overrides_used,
            global_expressions_used: &mut self.global_expressions_used,
            expressions_used: HandleSet::for_arena(&function.expressions),
        }
    }
}
572
573struct ModuleMap {
574    functions: HandleMap<crate::Function>,
575    types: HandleMap<crate::Type>,
576    globals: HandleMap<crate::GlobalVariable>,
577    constants: HandleMap<crate::Constant>,
578    overrides: HandleMap<crate::Override>,
579    global_expressions: HandleMap<crate::Expression>,
580}
581
582impl From<ModuleTracer<'_>> for ModuleMap {
583    fn from(used: ModuleTracer) -> Self {
584        ModuleMap {
585            functions: HandleMap::from_set(used.functions_used),
586            types: HandleMap::from_set(used.types_used),
587            globals: HandleMap::from_set(used.global_variables_used),
588            constants: HandleMap::from_set(used.constants_used),
589            overrides: HandleMap::from_set(used.overrides_used),
590            global_expressions: HandleMap::from_set(used.global_expressions_used),
591        }
592    }
593}
594
595impl ModuleMap {
596    fn adjust_special_types(&self, special: &mut crate::SpecialTypes) {
597        let crate::SpecialTypes {
598            ref mut ray_desc,
599            ref mut ray_intersection,
600            ref mut ray_vertex_return,
601            ref mut predeclared_types,
602            ref mut external_texture_params,
603            ref mut external_texture_transfer_function,
604        } = *special;
605
606        if let Some(ref mut ray_desc) = *ray_desc {
607            self.types.adjust(ray_desc);
608        }
609        if let Some(ref mut ray_intersection) = *ray_intersection {
610            self.types.adjust(ray_intersection);
611        }
612
613        if let Some(ref mut ray_vertex_return) = *ray_vertex_return {
614            self.types.adjust(ray_vertex_return);
615        }
616
617        if let Some(ref mut external_texture_params) = *external_texture_params {
618            self.types.adjust(external_texture_params);
619        }
620
621        if let Some(ref mut external_texture_transfer_function) =
622            *external_texture_transfer_function
623        {
624            self.types.adjust(external_texture_transfer_function);
625        }
626
627        for handle in predeclared_types.values_mut() {
628            self.types.adjust(handle);
629        }
630    }
631
632    fn adjust_doc_comments(&self, doc_comments: &mut ir::DocComments) {
633        let crate::DocComments {
634            module: _,
635            types: ref mut doc_types,
636            struct_members: ref mut doc_struct_members,
637            entry_points: _,
638            functions: ref mut doc_functions,
639            constants: ref mut doc_constants,
640            global_variables: ref mut doc_globals,
641        } = *doc_comments;
642        log::trace!("adjusting doc comments for types");
643        for (mut ty, doc_comment) in core::mem::take(doc_types) {
644            if !self.types.used(ty) {
645                continue;
646            }
647            self.types.adjust(&mut ty);
648            doc_types.insert(ty, doc_comment);
649        }
650        log::trace!("adjusting doc comments for struct members");
651        for ((mut ty, index), doc_comment) in core::mem::take(doc_struct_members) {
652            if !self.types.used(ty) {
653                continue;
654            }
655            self.types.adjust(&mut ty);
656            doc_struct_members.insert((ty, index), doc_comment);
657        }
658        log::trace!("adjusting doc comments for functions");
659        for (mut handle, doc_comment) in core::mem::take(doc_functions) {
660            if !self.functions.used(handle) {
661                continue;
662            }
663            self.functions.adjust(&mut handle);
664            doc_functions.insert(handle, doc_comment);
665        }
666        log::trace!("adjusting doc comments for constants");
667        for (mut constant, doc_comment) in core::mem::take(doc_constants) {
668            if !self.constants.used(constant) {
669                continue;
670            }
671            self.constants.adjust(&mut constant);
672            doc_constants.insert(constant, doc_comment);
673        }
674        log::trace!("adjusting doc comments for globals");
675        for (mut handle, doc_comment) in core::mem::take(doc_globals) {
676            if !self.globals.used(handle) {
677                continue;
678            }
679            self.globals.adjust(&mut handle);
680            doc_globals.insert(handle, doc_comment);
681        }
682    }
683}
684
685struct FunctionMap {
686    expressions: HandleMap<crate::Expression>,
687}
688
689impl From<FunctionTracer<'_>> for FunctionMap {
690    fn from(used: FunctionTracer) -> Self {
691        FunctionMap {
692            expressions: HandleMap::from_set(used.expressions_used),
693        }
694    }
695}
696
697#[test]
fn type_expression_interdependence() {
    // Build a module in which types and global expressions depend on each
    // other in both directions: types reach expressions through unnamed
    // overrides used as `ArraySize::Pending` array lengths, and expressions
    // reach types through `Expression::ZeroValue`. Then check that `compact`
    // with `KeepUnused::Yes` leaves the type and global-expression arenas
    // exactly as they were, and that a newly added unreferenced chain is
    // removed again.
    let mut module: crate::Module = Default::default();
    let u32 = module.types.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Scalar(crate::Scalar {
                kind: crate::ScalarKind::Uint,
                width: 4,
            }),
        },
        crate::Span::default(),
    );
    let expr = module.global_expressions.append(
        crate::Expression::Literal(crate::Literal::U32(0)),
        crate::Span::default(),
    );
    // Each closure below appends one link of a dependency chain to `module`
    // and returns the handle of the item it just added.

    // type -> expression: an anonymous array of `u32` whose length is a new
    // unnamed override initialized by `handle`.
    let type_needs_expression = |module: &mut crate::Module, handle| {
        let override_handle = module.overrides.append(
            crate::Override {
                name: None,
                id: None,
                ty: u32,
                init: Some(handle),
            },
            crate::Span::default(),
        );
        module.types.insert(
            crate::Type {
                name: None,
                inner: crate::TypeInner::Array {
                    base: u32,
                    size: crate::ArraySize::Pending(override_handle),
                    stride: 4,
                },
            },
            crate::Span::default(),
        )
    };
    // expression -> type: `ZeroValue(handle)`.
    let expression_needs_type = |module: &mut crate::Module, handle| {
        module
            .global_expressions
            .append(crate::Expression::ZeroValue(handle), crate::Span::default())
    };
    // expression -> expression: `Load { pointer: handle }`. This test never
    // runs the validator, so `handle` does not have to be a pointer (below it
    // is a `ZeroValue` of `u32`).
    let expression_needs_expression = |module: &mut crate::Module, handle| {
        module.global_expressions.append(
            crate::Expression::Load { pointer: handle },
            crate::Span::default(),
        )
    };
    // type -> type: an anonymous runtime-sized array whose base is `handle`.
    let type_needs_type = |module: &mut crate::Module, handle| {
        module.types.insert(
            crate::Type {
                name: None,
                inner: crate::TypeInner::Array {
                    base: handle,
                    size: crate::ArraySize::Dynamic,
                    stride: 0,
                },
            },
            crate::Span::default(),
        )
    };
    // Chain root: a *named* runtime-sized array of `handle` ("type0",
    // "type1", ...). Named types are never considered unused with
    // `KeepUnused::Yes`, so this keeps its whole chain alive.
    let mut type_name_counter = 0;
    let mut type_needed = |module: &mut crate::Module, handle| {
        let name = Some(format!("type{type_name_counter}"));
        type_name_counter += 1;
        module.types.insert(
            crate::Type {
                name,
                inner: crate::TypeInner::Array {
                    base: handle,
                    size: crate::ArraySize::Dynamic,
                    stride: 0,
                },
            },
            crate::Span::default(),
        )
    };
    // Chain root: a *named* `u32` override ("override0", ...) initialized by
    // `handle`. Named overrides are likewise kept with `KeepUnused::Yes`.
    let mut override_name_counter = 0;
    let mut expression_needed = |module: &mut crate::Module, handle| {
        let name = Some(format!("override{override_name_counter}"));
        override_name_counter += 1;
        module.overrides.append(
            crate::Override {
                name,
                id: None,
                ty: u32,
                init: Some(handle),
            },
            crate::Span::default(),
        )
    };
    // Compares only the type and global-expression arenas (handles and
    // values). Overrides are not compared.
    let cmp_modules = |mod0: &crate::Module, mod1: &crate::Module| {
        (mod0.types.iter().collect::<Vec<_>>() == mod1.types.iter().collect::<Vec<_>>())
            && (mod0.global_expressions.iter().collect::<Vec<_>>()
                == mod1.global_expressions.iter().collect::<Vec<_>>())
    };
    // borrow checker breaks without the tmp variables as of Rust 1.83.0
    // (nesting the calls, e.g. `f(&mut module, g(&mut module, h))`, would
    // take two overlapping `&mut module` borrows in a single expression).
    //
    // The variable names describe a handle's position in its chain, not its
    // kind: `expr_end`, `ty_trace` and `ty_init` are type handles, while
    // `expr_init`, `ty_end` and `expr_trace` are expression handles.
    //
    // Chain rooted at the named override "override0":
    //   override0 -> ZeroValue(ty_trace) -> ty_trace (runtime-sized array)
    //   -> expr_end (array with override length) -> unnamed override -> `expr`
    let expr_end = type_needs_expression(&mut module, expr);
    let ty_trace = type_needs_type(&mut module, expr_end);
    let expr_init = expression_needs_type(&mut module, ty_trace);
    expression_needed(&mut module, expr_init);
    // Chain rooted at the named type "type0":
    //   type0 -> ty_init (array with override length) -> unnamed override
    //   -> expr_trace (Load) -> ty_end (ZeroValue) -> `u32`
    let ty_end = expression_needs_type(&mut module, u32);
    let expr_trace = expression_needs_expression(&mut module, ty_end);
    let ty_init = type_needs_expression(&mut module, expr_trace);
    type_needed(&mut module, ty_init);
    // Every item above is reachable from a named root, so compaction must
    // leave the compared arenas unchanged.
    let untouched = module.clone();
    compact(&mut module, KeepUnused::Yes);
    assert!(cmp_modules(&module, &untouched));
    // Add a literal that is reachable only from a fresh unnamed override and
    // an anonymous array type that nothing refers to. The module now differs
    // from the snapshot, and compaction must remove the new type and
    // expression again.
    let unused_expr = module.global_expressions.append(
        crate::Expression::Literal(crate::Literal::U32(1)),
        crate::Span::default(),
    );
    type_needs_expression(&mut module, unused_expr);
    assert!(!cmp_modules(&module, &untouched));
    compact(&mut module, KeepUnused::Yes);
    assert!(cmp_modules(&module, &untouched));
}
816
/// An override referenced only as the length of a named array type must be
/// kept by compaction. The array's `ArraySize::Pending` handle must also be
/// remapped once the unused override declared before it is dropped.
#[test]
fn array_length_override() {
    let span = crate::Span::default();
    let mut module = crate::Module::default();

    let scalar_type = |module: &mut crate::Module, scalar| {
        module.types.insert(
            crate::Type {
                name: None,
                inner: crate::TypeInner::Scalar(scalar),
            },
            span,
        )
    };
    let ty_bool = scalar_type(&mut module, crate::Scalar::BOOL);
    let ty_u32 = scalar_type(&mut module, crate::Scalar::U32);

    let literal_one = module
        .global_expressions
        .append(crate::Expression::Literal(crate::Literal::U32(1)), span);

    // Unnamed and unreferenced, so compaction drops it. The override declared
    // after it then ends up with a different handle.
    let _dropped_override = module.overrides.append(
        crate::Override {
            name: None,
            id: Some(40),
            ty: ty_u32,
            init: None,
        },
        span,
    );
    let array_length = module.overrides.append(
        crate::Override {
            name: None,
            id: Some(42),
            ty: ty_u32,
            init: Some(literal_one),
        },
        span,
    );

    // Named, so it is always kept. It is the only user of `array_length`.
    let _named_array = module.types.insert(
        crate::Type {
            name: Some("array<bool, o>".to_string()),
            inner: crate::TypeInner::Array {
                base: ty_bool,
                size: crate::ArraySize::Pending(array_length),
                stride: 4,
            },
        },
        span,
    );

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );
    assert!(validator.validate(&module).is_ok());
    compact(&mut module, KeepUnused::Yes);
    assert!(validator.validate(&module).is_ok());
}
877
/// Test mutual references between types and expressions via override
/// lengths.
///
/// `delicious_array` is named, so it is always kept. It uses
/// `second_override` as its length. That override's initializer is the only
/// path to `first_override` and `ty_i32`, so compaction must trace through
/// it for the module to remain valid.
#[test]
fn array_length_override_mutual() {
    use crate::Expression as Ex;
    use crate::Scalar as Sc;
    use crate::TypeInner as Ti;

    let nowhere = crate::Span::default();
    let mut module = crate::Module::default();
    let ty_u32 = module.types.insert(
        crate::Type {
            name: None,
            inner: Ti::Scalar(Sc::U32),
        },
        nowhere,
    );

    // This type is only reachable through `second_override`'s init
    // expression: it is `first_override`'s type and the type of the
    // `ZeroValue` below, and both are used only by that initializer. So if
    // we visit that expression too early, this type will be removed
    // incorrectly.
    let ty_i32 = module.types.insert(
        crate::Type {
            name: None,
            inner: Ti::Scalar(Sc::I32),
        },
        nowhere,
    );

    // An override that the other override's init can refer to.
    let first_override = module.overrides.append(
        crate::Override {
            name: None, // so it is not considered used by definition
            id: Some(41),
            ty: ty_i32,
            init: None,
        },
        nowhere,
    );

    // Initializer expression for the override:
    //
    //     (first_override + 0) as u32
    //
    // The `first_override` makes it an override expression; the `0`
    // gets a use of `ty_i32` in there; and the `as` makes it match
    // the type of `second_override` without actually making
    // `second_override` point at `ty_i32` directly.
    let first_override_expr = module
        .global_expressions
        .append(Ex::Override(first_override), nowhere);
    let zero = module
        .global_expressions
        .append(Ex::ZeroValue(ty_i32), nowhere);
    let sum = module.global_expressions.append(
        Ex::Binary {
            op: crate::BinaryOperator::Add,
            left: first_override_expr,
            right: zero,
        },
        nowhere,
    );
    let init = module.global_expressions.append(
        Ex::As {
            expr: sum,
            kind: crate::ScalarKind::Uint,
            convert: None,
        },
        nowhere,
    );

    // Override that serves as the array's length.
    let second_override = module.overrides.append(
        crate::Override {
            name: None, // so it is not considered used by definition
            id: Some(42),
            ty: ty_u32,
            init: Some(init),
        },
        nowhere,
    );

    // Array type that uses the override as its length.
    // Since this is named, it is considered used by definition.
    let _ty_array = module.types.insert(
        crate::Type {
            name: Some("delicious_array".to_string()),
            inner: Ti::Array {
                base: ty_u32,
                size: crate::ArraySize::Pending(second_override),
                stride: 4,
            },
        },
        nowhere,
    );

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );

    assert!(validator.validate(&module).is_ok());
    compact(&mut module, KeepUnused::Yes);
    assert!(validator.validate(&module).is_ok());
}
983
/// A named array type whose length is an override initialized by a literal.
/// Compaction drops the unreferenced literal declared ahead of that
/// initializer, so the override's `init` handle has to be remapped for the
/// module to stay valid.
#[test]
fn array_length_expression() {
    let span = crate::Span::default();
    let mut module = crate::Module::default();

    let ty_u32 = module.types.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Scalar(crate::Scalar::U32),
        },
        span,
    );

    let u32_literal = |value| crate::Expression::Literal(crate::Literal::U32(value));

    // Nothing refers to this, so compaction removes it and the handles of
    // the global expressions after it shift.
    let _dropped_literal = module.global_expressions.append(u32_literal(0), span);
    let length_init = module.global_expressions.append(u32_literal(1), span);

    let length = module.overrides.append(
        crate::Override {
            name: None,
            id: None,
            ty: ty_u32,
            init: Some(length_init),
        },
        span,
    );

    // Named, so it is always kept, and it keeps `length` alive.
    let _named_array = module.types.insert(
        crate::Type {
            name: Some("array<u32, 1>".to_string()),
            inner: crate::TypeInner::Array {
                base: ty_u32,
                size: crate::ArraySize::Pending(length),
                stride: 4,
            },
        },
        span,
    );

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );
    assert!(validator.validate(&module).is_ok());
    compact(&mut module, KeepUnused::Yes);
    assert!(validator.validate(&module).is_ok());
}
1032
/// A global `Expression::Override` must keep the override it names alive,
/// together with that override's own initializer.
#[test]
fn global_expression_override() {
    let span = crate::Span::default();
    let mut module = crate::Module::default();

    let ty_u32 = module.types.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Scalar(crate::Scalar::U32),
        },
        span,
    );

    // Survives only if compaction follows `Expression::Override` in global
    // expressions through to the referenced override's initializer.
    let literal_one = module
        .global_expressions
        .append(crate::Expression::Literal(crate::Literal::U32(1)), span);

    // Unnamed; the only thing pointing at it is `inner_ref` below.
    let inner = module.overrides.append(
        crate::Override {
            name: None,
            id: Some(42),
            ty: ty_u32,
            init: Some(literal_one),
        },
        span,
    );

    // Kept because it is the initializer of `_p`.
    let inner_ref = module
        .global_expressions
        .append(crate::Expression::Override(inner), span);

    // Named, so it is retained and acts as the root of the chain.
    let _p = module.overrides.append(
        crate::Override {
            name: Some("p".to_string()),
            id: None,
            ty: ty_u32,
            init: Some(inner_ref),
        },
        span,
    );

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );
    assert!(validator.validate(&module).is_ok());
    compact(&mut module, KeepUnused::Yes);
    assert!(validator.validate(&module).is_ok());
}
1088
/// An `Expression::Override` inside a function body must keep the override,
/// and that override's initializer, alive through compaction.
#[test]
fn local_expression_override() {
    let span = crate::Span::default();
    let mut module = crate::Module::default();

    let ty_u32 = module.types.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Scalar(crate::Scalar::U32),
        },
        span,
    );

    // Survives only if compaction traces the initializers of overrides
    // referenced from function expressions.
    let literal_one = module
        .global_expressions
        .append(crate::Expression::Literal(crate::Literal::U32(1)), span);

    // Unnamed and unreferenced: compaction removes it.
    let _dropped_override = module.overrides.append(
        crate::Override {
            name: None,
            id: Some(41),
            ty: ty_u32,
            init: None,
        },
        span,
    );

    // Reachable only through the function's `Expression::Override`.
    let traced_override = module.overrides.append(
        crate::Override {
            name: None,
            id: Some(42),
            ty: ty_u32,
            init: Some(literal_one),
        },
        span,
    );

    let mut function = crate::Function {
        result: Some(crate::FunctionResult {
            ty: ty_u32,
            binding: None,
        }),
        ..Default::default()
    };

    // The function simply returns the override's value.
    let returned = function
        .expressions
        .append(crate::Expression::Override(traced_override), span);
    function.body.push(
        crate::Statement::Return {
            value: Some(returned),
        },
        span,
    );
    module.functions.append(function, span);

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );
    assert!(validator.validate(&module).is_ok());
    compact(&mut module, KeepUnused::Yes);
    assert!(validator.validate(&module).is_ok());
}
1159
/// A type used only by an unnamed constant must survive compaction when
/// that constant is itself kept alive by a named constant's initializer.
#[test]
fn unnamed_constant_type() {
    let span = crate::Span::default();
    let mut module = crate::Module::default();

    // Only the unnamed constant uses this type.
    let ty_u32 = module.types.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Scalar(crate::Scalar::U32),
        },
        span,
    );

    // Only the named constant uses this type.
    let ty_vec2_u32 = module.types.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::U32,
            },
        },
        span,
    );

    let zero = module
        .global_expressions
        .append(crate::Expression::Literal(crate::Literal::U32(0)), span);
    let hidden = module.constants.append(
        crate::Constant {
            name: None,
            ty: ty_u32,
            init: zero,
        },
        span,
    );

    // Splatting `hidden` into a two-component vector gives the named constant
    // a type different from `hidden`'s. `ty_u32` is then reachable only
    // through `hidden`.
    let hidden_ref = module
        .global_expressions
        .append(crate::Expression::Constant(hidden), span);
    let splat = module.global_expressions.append(
        crate::Expression::Splat {
            size: crate::VectorSize::Bi,
            value: hidden_ref,
        },
        span,
    );

    let _named = module.constants.append(
        crate::Constant {
            name: Some("totally_named".to_string()),
            ty: ty_vec2_u32,
            init: splat,
        },
        span,
    );

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );
    assert!(validator.validate(&module).is_ok());
    compact(&mut module, KeepUnused::Yes);
    assert!(validator.validate(&module).is_ok());
}
1231
/// An unnamed override is kept only because a named override's initializer
/// refers to it. Compaction must then also keep the unnamed override's
/// type, which nothing else uses.
#[test]
fn unnamed_override_type() {
    let mut module = crate::Module::default();
    let nowhere = crate::Span::default();

    // This type is used only by the unnamed override.
    let ty_u32 = module.types.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Scalar(crate::Scalar::U32),
        },
        nowhere,
    );

    // This type is used by the named override.
    let ty_i32 = module.types.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Scalar(crate::Scalar::I32),
        },
        nowhere,
    );

    let unnamed_init = module
        .global_expressions
        .append(crate::Expression::Literal(crate::Literal::U32(0)), nowhere);

    let unnamed_override = module.overrides.append(
        crate::Override {
            name: None,
            id: Some(42),
            ty: ty_u32,
            init: Some(unnamed_init),
        },
        nowhere,
    );

    // The named override is initialized by casting the unnamed override with
    // an `As` expression. This gives the named override a type (`i32`)
    // distinct from the unnamed override's (`u32`), so `ty_u32` is reachable
    // only through the unnamed override.
    let unnamed_override_expr = module
        .global_expressions
        .append(crate::Expression::Override(unnamed_override), nowhere);
    let named_init = module.global_expressions.append(
        crate::Expression::As {
            expr: unnamed_override_expr,
            kind: crate::ScalarKind::Sint,
            convert: None,
        },
        nowhere,
    );

    // Named, so it is never removed with `KeepUnused::Yes`.
    let _named_override = module.overrides.append(
        crate::Override {
            name: Some("totally_named".to_string()),
            id: None,
            ty: ty_i32,
            init: Some(named_init),
        },
        nowhere,
    );

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );

    assert!(validator.validate(&module).is_ok());
    compact(&mut module, KeepUnused::Yes);
    assert!(validator.validate(&module).is_ok());
}