naga/compact/
mod.rs

1mod expressions;
2mod functions;
3mod handle_set_map;
4mod statements;
5mod types;
6
7use alloc::vec::Vec;
8
9use crate::{
10    arena::{self, HandleSet},
11    compact::functions::FunctionTracer,
12    ir,
13};
14use handle_set_map::HandleMap;
15
16#[cfg(test)]
17use alloc::{format, string::ToString};
18
/// Configuration option for [`compact`]. See [`compact`] for details.
///
/// Selects whether items that nothing in the module refers to are still
/// retained by compaction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeepUnused {
    /// Functions, global variables, and named types and overrides are removed
    /// when nothing refers to them.
    No,
    /// Functions, global variables, and named types and overrides are always
    /// treated as used, so they survive compaction.
    Yes,
}

impl From<KeepUnused> for bool {
    /// Yields `true` exactly when unreferenced items are to be kept.
    fn from(keep_unused: KeepUnused) -> Self {
        match keep_unused {
            KeepUnused::Yes => true,
            KeepUnused::No => false,
        }
    }
}
34
35/// Remove most unused objects from `module`, which must be valid.
36///
37/// Always removes the following unused objects:
38/// - anonymous types, overrides, and constants
39/// - abstract-typed constants
40/// - expressions
41///
42/// If `keep_unused` is `Yes`, the following are never considered unused,
43/// otherwise, they will also be removed if unused:
44/// - functions
45/// - global variables
46/// - named types and overrides
47///
48/// The following are never removed:
49/// - named constants with a concrete type
50/// - special types
51/// - entry points
52/// - within an entry point or a used function:
53///     - arguments
54///     - local variables
55///     - named expressions
56///
57/// After removing items according to the rules above, all handles in the
58/// remaining objects are adjusted as necessary. When `KeepUnused` is `Yes`, the
59/// resulting module should have all the named objects (except abstract-typed
60/// constants) present in the original, and those objects should be functionally
61/// identical. When `KeepUnused` is `No`, the resulting module should have the
62/// entry points present in the original, and those entry points should be
63/// functionally identical.
64///
65/// # Panics
66///
67/// If `module` would not pass validation, this may panic.
68pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) {
69    // The trickiest part of compaction is determining what is used and what is
70    // not. Once we have computed that correctly, it's easy enough to call
71    // `retain_mut` on each arena, drop unused elements, and fix up the handles
72    // in what's left.
73    //
74    // For every compactable arena in a `Module`, whether global to the `Module`
75    // or local to a function or entry point, the `ModuleTracer` type holds a
76    // bitmap indicating which elements of that arena are used. Our task is to
77    // populate those bitmaps correctly.
78    //
79    // First, we mark everything that is considered used by definition, as
80    // described in this function's documentation.
81    //
82    // Since functions and entry points are considered used by definition, we
83    // traverse their statement trees, and mark the referents of all handles
84    // appearing in those statements as used.
85    //
86    // Once we've marked which elements of an arena are referred to directly by
87    // handles elsewhere (for example, which of a function's expressions are
88    // referred to by handles in its body statements), we can mark all the other
89    // arena elements that are used indirectly in a single pass, traversing the
90    // arena from back to front. Since Naga allows arena elements to refer only
91    // to prior elements, we know that by the time we reach an element, all
92    // other elements that could possibly refer to it have already been visited.
93    // Thus, if the present element has not been marked as used, then it is
94    // definitely unused, and compaction can remove it. Otherwise, the element
95    // is used and must be retained, so we must mark everything it refers to.
96    //
97    // The final step is to mark the global expressions and types, which must be
98    // traversed simultaneously; see `ModuleTracer::type_expression_tandem`'s
99    // documentation for details.
100    //
101    // # A definition and a rule of thumb
102    //
103    // In this module, to "trace" something is to mark everything else it refers
104    // to as used, on the assumption that the thing itself is used. For example,
105    // to trace an `Expression` is to mark its subexpressions as used, as well
106    // as any types, constants, overrides, etc. that it refers to. This is what
107    // `ExpressionTracer::trace_expression` does.
108    //
109    // Given that we we want to visit each thing only once (to keep compaction
110    // linear in the size of the module), this definition of "trace" implies
111    // that things that are not "used by definition" must be marked as used
112    // *before* we trace them.
113    //
114    // Thus, whenever you are marking something as used, it's a good idea to ask
115    // yourself how you know that thing will be traced in the future. If you're
116    // not sure, then you could be marking it too late to be noticed. The thing
117    // itself will be retained by compaction, but since it will not be traced,
118    // anything it refers to could be compacted away.
119    let mut module_tracer = ModuleTracer::new(module);
120
121    // Observe what each entry point actually uses.
122    log::trace!("tracing entry points");
123    let entry_point_maps = module
124        .entry_points
125        .iter()
126        .map(|e| {
127            log::trace!("tracing entry point {:?}", e.function.name);
128
129            if let Some(sizes) = e.workgroup_size_overrides {
130                for size in sizes.iter().filter_map(|x| *x) {
131                    module_tracer.global_expressions_used.insert(size);
132                }
133            }
134
135            if let Some(task_payload) = e.task_payload {
136                module_tracer.global_variables_used.insert(task_payload);
137            }
138            if let Some(ref mesh_info) = e.mesh_info {
139                module_tracer
140                    .global_variables_used
141                    .insert(mesh_info.output_variable);
142                module_tracer
143                    .types_used
144                    .insert(mesh_info.vertex_output_type);
145                module_tracer
146                    .types_used
147                    .insert(mesh_info.primitive_output_type);
148                if let Some(max_vertices_override) = mesh_info.max_vertices_override {
149                    module_tracer
150                        .global_expressions_used
151                        .insert(max_vertices_override);
152                }
153                if let Some(max_primitives_override) = mesh_info.max_primitives_override {
154                    module_tracer
155                        .global_expressions_used
156                        .insert(max_primitives_override);
157                }
158            }
159            if e.stage == crate::ShaderStage::Task || e.stage == crate::ShaderStage::Mesh {
160                // Mesh shaders always need a u32 type, as it is e.g. the type of some
161                // expressions. We tolerate its absence here because compaction is
162                // infallible, but the module will fail validation.
163                if let Some(u32_type) = module.types.iter().find_map(|tuple| {
164                    (tuple.1.inner == crate::TypeInner::Scalar(crate::Scalar::U32))
165                        .then_some(tuple.0)
166                }) {
167                    module_tracer.types_used.insert(u32_type);
168                }
169            }
170
171            let mut used = module_tracer.as_function(&e.function);
172            used.trace();
173            FunctionMap::from(used)
174        })
175        .collect::<Vec<_>>();
176
177    // Observe which types, constant expressions, constants, and expressions
178    // each function uses, and produce maps for each function from
179    // pre-compaction to post-compaction expression handles.
180    //
181    // The function tracing logic here works in conjunction with
182    // `FunctionTracer::trace_call`, which, when tracing a `Statement::Call`
183    // to a function not already identified as used, adds the called function
184    // to both `functions_used` and `functions_pending`.
185    //
186    // Called functions are required to appear before their callers in the
187    // functions arena (recursion is disallowed). We have already traced the
188    // entry point(s) and added any functions called directly by the entry
189    // point(s) to `functions_pending`. We proceed by repeatedly tracing the
190    // last function in `functions_pending`. By an inductive argument, any
191    // functions after the last function in `functions_pending` must be unused.
192    //
193    // When `KeepUnused` is active, we simply mark all functions as pending,
194    // and then trace all of them.
195    log::trace!("tracing functions");
196    let mut function_maps = HandleMap::with_capacity(module.functions.len());
197    if keep_unused.into() {
198        module_tracer.functions_used.add_all();
199        module_tracer.functions_pending.add_all();
200    }
201    while let Some(handle) = module_tracer.functions_pending.pop() {
202        let function = &module.functions[handle];
203        log::trace!("tracing function {function:?}");
204        let mut function_tracer = module_tracer.as_function(function);
205        function_tracer.trace();
206        function_maps.insert(handle, FunctionMap::from(function_tracer));
207    }
208
209    // We treat all special types as used by definition.
210    log::trace!("tracing special types");
211    module_tracer.trace_special_types(&module.special_types);
212
213    log::trace!("tracing global variables");
214    if keep_unused.into() {
215        module_tracer.global_variables_used.add_all();
216    }
217    for global in module_tracer.global_variables_used.iter() {
218        log::trace!("tracing global {:?}", module.global_variables[global].name);
219        module_tracer
220            .types_used
221            .insert(module.global_variables[global].ty);
222        if let Some(init) = module.global_variables[global].init {
223            module_tracer.global_expressions_used.insert(init);
224        }
225    }
226
227    // We treat all named constants as used by definition, unless they have an
228    // abstract type as we do not want those reaching the validator.
229    log::trace!("tracing named constants");
230    for (handle, constant) in module.constants.iter() {
231        if constant.name.is_none() || module.types[constant.ty].inner.is_abstract(&module.types) {
232            continue;
233        }
234
235        log::trace!("tracing constant {:?}", constant.name.as_ref().unwrap());
236        module_tracer.constants_used.insert(handle);
237        module_tracer.types_used.insert(constant.ty);
238        module_tracer.global_expressions_used.insert(constant.init);
239    }
240
241    if keep_unused.into() {
242        // Treat all named overrides as used.
243        for (handle, r#override) in module.overrides.iter() {
244            if r#override.name.is_some() && module_tracer.overrides_used.insert(handle) {
245                module_tracer.types_used.insert(r#override.ty);
246                if let Some(init) = r#override.init {
247                    module_tracer.global_expressions_used.insert(init);
248                }
249            }
250        }
251
252        // Treat all named types as used.
253        for (handle, ty) in module.types.iter() {
254            if ty.name.is_some() {
255                module_tracer.types_used.insert(handle);
256            }
257        }
258    }
259
260    module_tracer.type_expression_tandem();
261
262    // Now that we know what is used and what is never touched,
263    // produce maps from the `Handle`s that appear in `module` now to
264    // the corresponding `Handle`s that will refer to the same items
265    // in the compacted module.
266    let module_map = ModuleMap::from(module_tracer);
267
268    // Drop unused types from the type arena.
269    //
270    // `FastIndexSet`s don't have an underlying Vec<T> that we can
271    // steal, compact in place, and then rebuild the `FastIndexSet`
272    // from. So we have to rebuild the type arena from scratch.
273    log::trace!("compacting types");
274    let mut new_types = arena::UniqueArena::new();
275    for (old_handle, mut ty, span) in module.types.drain_all() {
276        if let Some(expected_new_handle) = module_map.types.try_adjust(old_handle) {
277            module_map.adjust_type(&mut ty);
278            let actual_new_handle = new_types.insert(ty, span);
279            assert_eq!(actual_new_handle, expected_new_handle);
280        }
281    }
282    module.types = new_types;
283    log::trace!("adjusting special types");
284    module_map.adjust_special_types(&mut module.special_types);
285
286    // Drop unused constant expressions, reusing existing storage.
287    log::trace!("adjusting constant expressions");
288    module.global_expressions.retain_mut(|handle, expr| {
289        if module_map.global_expressions.used(handle) {
290            module_map.adjust_expression(expr, &module_map.global_expressions);
291            true
292        } else {
293            false
294        }
295    });
296
297    // Drop unused constants in place, reusing existing storage.
298    log::trace!("adjusting constants");
299    module.constants.retain_mut(|handle, constant| {
300        if module_map.constants.used(handle) {
301            module_map.types.adjust(&mut constant.ty);
302            module_map.global_expressions.adjust(&mut constant.init);
303            true
304        } else {
305            false
306        }
307    });
308
309    // Drop unused overrides in place, reusing existing storage.
310    log::trace!("adjusting overrides");
311    module.overrides.retain_mut(|handle, r#override| {
312        if module_map.overrides.used(handle) {
313            module_map.types.adjust(&mut r#override.ty);
314            if let Some(ref mut init) = r#override.init {
315                module_map.global_expressions.adjust(init);
316            }
317            true
318        } else {
319            false
320        }
321    });
322
323    // Adjust workgroup_size_overrides
324    log::trace!("adjusting workgroup_size_overrides");
325    for e in module.entry_points.iter_mut() {
326        if let Some(sizes) = e.workgroup_size_overrides.as_mut() {
327            for size in sizes.iter_mut() {
328                if let Some(expr) = size.as_mut() {
329                    module_map.global_expressions.adjust(expr);
330                }
331            }
332        }
333    }
334
335    // Drop unused global variables, reusing existing storage.
336    // Adjust used global variables' types and initializers.
337    log::trace!("adjusting global variables");
338    module.global_variables.retain_mut(|handle, global| {
339        if module_map.globals.used(handle) {
340            log::trace!("retaining global variable {:?}", global.name);
341            module_map.types.adjust(&mut global.ty);
342            if let Some(ref mut init) = global.init {
343                module_map.global_expressions.adjust(init);
344            }
345            true
346        } else {
347            log::trace!("dropping global variable {:?}", global.name);
348            false
349        }
350    });
351
352    // Adjust doc comments
353    if let Some(ref mut doc_comments) = module.doc_comments {
354        module_map.adjust_doc_comments(doc_comments.as_mut());
355    }
356
357    // Temporary storage to help us reuse allocations of existing
358    // named expression tables.
359    let mut reused_named_expressions = crate::NamedExpressions::default();
360
361    // Drop unused functions. Compact and adjust used functions.
362    module.functions.retain_mut(|handle, function| {
363        if let Some(map) = function_maps.get(handle) {
364            log::trace!("retaining and compacting function {:?}", function.name);
365            map.compact(function, &module_map, &mut reused_named_expressions);
366            true
367        } else {
368            log::trace!("dropping function {:?}", function.name);
369            false
370        }
371    });
372
373    // Compact each entry point.
374    for (entry, map) in module.entry_points.iter_mut().zip(entry_point_maps.iter()) {
375        log::trace!("compacting entry point {:?}", entry.function.name);
376        map.compact(
377            &mut entry.function,
378            &module_map,
379            &mut reused_named_expressions,
380        );
381        if let Some(ref mut task_payload) = entry.task_payload {
382            module_map.globals.adjust(task_payload);
383        }
384        if let Some(ref mut mesh_info) = entry.mesh_info {
385            module_map.globals.adjust(&mut mesh_info.output_variable);
386            module_map.types.adjust(&mut mesh_info.vertex_output_type);
387            module_map
388                .types
389                .adjust(&mut mesh_info.primitive_output_type);
390            if let Some(ref mut max_vertices_override) = mesh_info.max_vertices_override {
391                module_map.global_expressions.adjust(max_vertices_override);
392            }
393            if let Some(ref mut max_primitives_override) = mesh_info.max_primitives_override {
394                module_map
395                    .global_expressions
396                    .adjust(max_primitives_override);
397            }
398        }
399    }
400}
401
402struct ModuleTracer<'module> {
403    module: &'module crate::Module,
404
405    /// The subset of functions in `functions_used` that have not yet been
406    /// traced.
407    functions_pending: HandleSet<crate::Function>,
408
409    functions_used: HandleSet<crate::Function>,
410    types_used: HandleSet<crate::Type>,
411    global_variables_used: HandleSet<crate::GlobalVariable>,
412    constants_used: HandleSet<crate::Constant>,
413    overrides_used: HandleSet<crate::Override>,
414    global_expressions_used: HandleSet<crate::Expression>,
415}
416
417impl<'module> ModuleTracer<'module> {
418    fn new(module: &'module crate::Module) -> Self {
419        Self {
420            module,
421            functions_pending: HandleSet::for_arena(&module.functions),
422            functions_used: HandleSet::for_arena(&module.functions),
423            types_used: HandleSet::for_arena(&module.types),
424            global_variables_used: HandleSet::for_arena(&module.global_variables),
425            constants_used: HandleSet::for_arena(&module.constants),
426            overrides_used: HandleSet::for_arena(&module.overrides),
427            global_expressions_used: HandleSet::for_arena(&module.global_expressions),
428        }
429    }
430
431    fn trace_special_types(&mut self, special_types: &crate::SpecialTypes) {
432        let crate::SpecialTypes {
433            ref ray_desc,
434            ref ray_intersection,
435            ref ray_vertex_return,
436            ref predeclared_types,
437            ref external_texture_params,
438            ref external_texture_transfer_function,
439        } = *special_types;
440
441        if let Some(ray_desc) = *ray_desc {
442            self.types_used.insert(ray_desc);
443        }
444        if let Some(ray_intersection) = *ray_intersection {
445            self.types_used.insert(ray_intersection);
446        }
447        if let Some(ray_vertex_return) = *ray_vertex_return {
448            self.types_used.insert(ray_vertex_return);
449        }
450        // The `external_texture_params` type is generated purely as a
451        // convenience to the backends. While it will never actually be used in
452        // the IR, it must be marked as used so that it survives compaction.
453        if let Some(external_texture_params) = *external_texture_params {
454            self.types_used.insert(external_texture_params);
455        }
456        if let Some(external_texture_transfer_function) = *external_texture_transfer_function {
457            self.types_used.insert(external_texture_transfer_function);
458        }
459        for (_, &handle) in predeclared_types {
460            self.types_used.insert(handle);
461        }
462    }
463
464    /// Traverse types and global expressions in tandem to determine which are used.
465    ///
466    /// Assuming that all types and global expressions used by other parts of
467    /// the module have been added to [`types_used`] and
468    /// [`global_expressions_used`], expand those sets to include all types and
469    /// global expressions reachable from those.
470    ///
471    /// [`types_used`]: ModuleTracer::types_used
472    /// [`global_expressions_used`]: ModuleTracer::global_expressions_used
473    fn type_expression_tandem(&mut self) {
474        // For each type T, compute the latest global expression E that T and
475        // its predecessors refer to. Given the ordering rules on types and
476        // global expressions in valid modules, we can do this with a single
477        // forward scan of the type arena. The rules further imply that T can
478        // only be referred to by expressions after E.
479        let mut max_dep = Vec::with_capacity(self.module.types.len());
480        let mut previous = None;
481        for (_handle, ty) in self.module.types.iter() {
482            previous = core::cmp::max(
483                previous,
484                match ty.inner {
485                    crate::TypeInner::Array { size, .. }
486                    | crate::TypeInner::BindingArray { size, .. } => match size {
487                        crate::ArraySize::Constant(_) | crate::ArraySize::Dynamic => None,
488                        crate::ArraySize::Pending(handle) => self.module.overrides[handle].init,
489                    },
490                    _ => None,
491                },
492            );
493            max_dep.push(previous);
494        }
495
496        // Visit types and global expressions from youngest to oldest.
497        //
498        // The outer loop visits types. Before visiting each type, the inner
499        // loop ensures that all global expressions that could possibly refer to
500        // it have been visited. And since the inner loop stop at the latest
501        // expression that the type could possibly refer to, we know that we
502        // have previously visited any types that might refer to each expression
503        // we visit.
504        //
505        // This lets us assume that any type or expression that is *not* marked
506        // as used by the time we visit it is genuinely unused, and can be
507        // ignored.
508        let mut exprs = self.module.global_expressions.iter().rev().peekable();
509
510        for ((ty_handle, ty), dep) in self.module.types.iter().zip(max_dep).rev() {
511            while let Some((expr_handle, expr)) = exprs.next_if(|&(h, _)| Some(h) > dep) {
512                if self.global_expressions_used.contains(expr_handle) {
513                    self.as_const_expression().trace_expression(expr);
514                }
515            }
516            if self.types_used.contains(ty_handle) {
517                self.as_type().trace_type(ty);
518            }
519        }
520        // Visit any remaining expressions.
521        for (expr_handle, expr) in exprs {
522            if self.global_expressions_used.contains(expr_handle) {
523                self.as_const_expression().trace_expression(expr);
524            }
525        }
526    }
527
528    const fn as_type(&mut self) -> types::TypeTracer<'_> {
529        types::TypeTracer {
530            overrides: &self.module.overrides,
531            types_used: &mut self.types_used,
532            expressions_used: &mut self.global_expressions_used,
533            overrides_used: &mut self.overrides_used,
534        }
535    }
536
537    const fn as_const_expression(&mut self) -> expressions::ExpressionTracer<'_> {
538        expressions::ExpressionTracer {
539            constants: &self.module.constants,
540            overrides: &self.module.overrides,
541            expressions: &self.module.global_expressions,
542            types_used: &mut self.types_used,
543            global_variables_used: &mut self.global_variables_used,
544            constants_used: &mut self.constants_used,
545            expressions_used: &mut self.global_expressions_used,
546            overrides_used: &mut self.overrides_used,
547            global_expressions_used: None,
548        }
549    }
550
551    pub fn as_function<'tracer>(
552        &'tracer mut self,
553        function: &'tracer crate::Function,
554    ) -> FunctionTracer<'tracer> {
555        FunctionTracer {
556            function,
557            constants: &self.module.constants,
558            overrides: &self.module.overrides,
559            functions_pending: &mut self.functions_pending,
560            functions_used: &mut self.functions_used,
561            types_used: &mut self.types_used,
562            global_variables_used: &mut self.global_variables_used,
563            constants_used: &mut self.constants_used,
564            overrides_used: &mut self.overrides_used,
565            global_expressions_used: &mut self.global_expressions_used,
566            expressions_used: HandleSet::for_arena(&function.expressions),
567        }
568    }
569}
570
571struct ModuleMap {
572    functions: HandleMap<crate::Function>,
573    types: HandleMap<crate::Type>,
574    globals: HandleMap<crate::GlobalVariable>,
575    constants: HandleMap<crate::Constant>,
576    overrides: HandleMap<crate::Override>,
577    global_expressions: HandleMap<crate::Expression>,
578}
579
580impl From<ModuleTracer<'_>> for ModuleMap {
581    fn from(used: ModuleTracer) -> Self {
582        ModuleMap {
583            functions: HandleMap::from_set(used.functions_used),
584            types: HandleMap::from_set(used.types_used),
585            globals: HandleMap::from_set(used.global_variables_used),
586            constants: HandleMap::from_set(used.constants_used),
587            overrides: HandleMap::from_set(used.overrides_used),
588            global_expressions: HandleMap::from_set(used.global_expressions_used),
589        }
590    }
591}
592
593impl ModuleMap {
594    fn adjust_special_types(&self, special: &mut crate::SpecialTypes) {
595        let crate::SpecialTypes {
596            ref mut ray_desc,
597            ref mut ray_intersection,
598            ref mut ray_vertex_return,
599            ref mut predeclared_types,
600            ref mut external_texture_params,
601            ref mut external_texture_transfer_function,
602        } = *special;
603
604        if let Some(ref mut ray_desc) = *ray_desc {
605            self.types.adjust(ray_desc);
606        }
607        if let Some(ref mut ray_intersection) = *ray_intersection {
608            self.types.adjust(ray_intersection);
609        }
610
611        if let Some(ref mut ray_vertex_return) = *ray_vertex_return {
612            self.types.adjust(ray_vertex_return);
613        }
614
615        if let Some(ref mut external_texture_params) = *external_texture_params {
616            self.types.adjust(external_texture_params);
617        }
618
619        if let Some(ref mut external_texture_transfer_function) =
620            *external_texture_transfer_function
621        {
622            self.types.adjust(external_texture_transfer_function);
623        }
624
625        for handle in predeclared_types.values_mut() {
626            self.types.adjust(handle);
627        }
628    }
629
630    fn adjust_doc_comments(&self, doc_comments: &mut ir::DocComments) {
631        let crate::DocComments {
632            module: _,
633            types: ref mut doc_types,
634            struct_members: ref mut doc_struct_members,
635            entry_points: _,
636            functions: ref mut doc_functions,
637            constants: ref mut doc_constants,
638            global_variables: ref mut doc_globals,
639        } = *doc_comments;
640        log::trace!("adjusting doc comments for types");
641        for (mut ty, doc_comment) in core::mem::take(doc_types) {
642            if !self.types.used(ty) {
643                continue;
644            }
645            self.types.adjust(&mut ty);
646            doc_types.insert(ty, doc_comment);
647        }
648        log::trace!("adjusting doc comments for struct members");
649        for ((mut ty, index), doc_comment) in core::mem::take(doc_struct_members) {
650            if !self.types.used(ty) {
651                continue;
652            }
653            self.types.adjust(&mut ty);
654            doc_struct_members.insert((ty, index), doc_comment);
655        }
656        log::trace!("adjusting doc comments for functions");
657        for (mut handle, doc_comment) in core::mem::take(doc_functions) {
658            if !self.functions.used(handle) {
659                continue;
660            }
661            self.functions.adjust(&mut handle);
662            doc_functions.insert(handle, doc_comment);
663        }
664        log::trace!("adjusting doc comments for constants");
665        for (mut constant, doc_comment) in core::mem::take(doc_constants) {
666            if !self.constants.used(constant) {
667                continue;
668            }
669            self.constants.adjust(&mut constant);
670            doc_constants.insert(constant, doc_comment);
671        }
672        log::trace!("adjusting doc comments for globals");
673        for (mut handle, doc_comment) in core::mem::take(doc_globals) {
674            if !self.globals.used(handle) {
675                continue;
676            }
677            self.globals.adjust(&mut handle);
678            doc_globals.insert(handle, doc_comment);
679        }
680    }
681}
682
683struct FunctionMap {
684    expressions: HandleMap<crate::Expression>,
685}
686
687impl From<FunctionTracer<'_>> for FunctionMap {
688    fn from(used: FunctionTracer) -> Self {
689        FunctionMap {
690            expressions: HandleMap::from_set(used.expressions_used),
691        }
692    }
693}
694
/// Check that [`compact`] keeps items reached through chains that alternate
/// between types and global expressions.
///
/// Two chains are built, each rooted in a named item that `KeepUnused::Yes`
/// never treats as unused:
///
/// 1. a named override whose initializer is a `ZeroValue` of a dynamic array
///    type, whose element type is an array sized by an anonymous override,
///    whose initializer is the literal `expr`;
/// 2. a named dynamic array type whose element type is an array sized by an
///    anonymous override, whose initializer is a `Load` of a `ZeroValue`
///    expression of type `u32`.
///
/// Compacting must leave `types` and `global_expressions` unchanged. The test
/// then appends an unused literal and an anonymous type that depends on it,
/// checks that the module now differs, and checks that a second compaction
/// restores `types` and `global_expressions` to their original contents.
///
/// The validator is never run here. For example, the `Load` from a
/// non-pointer expression exists only to create a dependency edge.
#[test]
fn type_expression_interdependence() {
    let mut module: crate::Module = Default::default();
    // Anonymous `u32` scalar type, shared by most of the items below.
    let u32 = module.types.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Scalar(crate::Scalar {
                kind: crate::ScalarKind::Uint,
                width: 4,
            }),
        },
        crate::Span::default(),
    );
    // Literal at the far end of the first chain.
    let expr = module.global_expressions.append(
        crate::Expression::Literal(crate::Literal::U32(0)),
        crate::Span::default(),
    );
    // Type -> expression edge: appends an anonymous override initialized by
    // `handle`, then an anonymous `u32` array whose length is that override.
    // Returns the array type's handle.
    let type_needs_expression = |module: &mut crate::Module, handle| {
        let override_handle = module.overrides.append(
            crate::Override {
                name: None,
                id: None,
                ty: u32,
                init: Some(handle),
            },
            crate::Span::default(),
        );
        module.types.insert(
            crate::Type {
                name: None,
                inner: crate::TypeInner::Array {
                    base: u32,
                    size: crate::ArraySize::Pending(override_handle),
                    stride: 4,
                },
            },
            crate::Span::default(),
        )
    };
    // Expression -> type edge: a `ZeroValue` of type `handle`.
    let expression_needs_type = |module: &mut crate::Module, handle| {
        module
            .global_expressions
            .append(crate::Expression::ZeroValue(handle), crate::Span::default())
    };
    // Expression -> expression edge: a `Load` whose pointer operand is `handle`.
    let expression_needs_expression = |module: &mut crate::Module, handle| {
        module.global_expressions.append(
            crate::Expression::Load { pointer: handle },
            crate::Span::default(),
        )
    };
    // Type -> type edge: an anonymous dynamic array whose element type is `handle`.
    let type_needs_type = |module: &mut crate::Module, handle| {
        module.types.insert(
            crate::Type {
                name: None,
                inner: crate::TypeInner::Array {
                    base: handle,
                    size: crate::ArraySize::Dynamic,
                    stride: 0,
                },
            },
            crate::Span::default(),
        )
    };
    // Chain root: a named dynamic array ("type0", "type1", ...) whose element
    // type is `handle`. Named types are kept under `KeepUnused::Yes`.
    let mut type_name_counter = 0;
    let mut type_needed = |module: &mut crate::Module, handle| {
        let name = Some(format!("type{type_name_counter}"));
        type_name_counter += 1;
        module.types.insert(
            crate::Type {
                name,
                inner: crate::TypeInner::Array {
                    base: handle,
                    size: crate::ArraySize::Dynamic,
                    stride: 0,
                },
            },
            crate::Span::default(),
        )
    };
    // Chain root: a named override ("override0", "override1", ...)
    // initialized by `handle`. Named overrides are kept under
    // `KeepUnused::Yes`.
    let mut override_name_counter = 0;
    let mut expression_needed = |module: &mut crate::Module, handle| {
        let name = Some(format!("override{override_name_counter}"));
        override_name_counter += 1;
        module.overrides.append(
            crate::Override {
                name,
                id: None,
                ty: u32,
                init: Some(handle),
            },
            crate::Span::default(),
        )
    };
    // Only the `types` and `global_expressions` arenas are compared;
    // overrides are not.
    let cmp_modules = |mod0: &crate::Module, mod1: &crate::Module| {
        (mod0.types.iter().collect::<Vec<_>>() == mod1.types.iter().collect::<Vec<_>>())
            && (mod0.global_expressions.iter().collect::<Vec<_>>()
                == mod1.global_expressions.iter().collect::<Vec<_>>())
    };
    // The borrow checker rejects nesting these calls (as of Rust 1.83.0),
    // so each step of a chain gets its own temporary variable.
    //
    // Chain 1, built from the literal outward. Note that `expr_end` is a
    // *type* handle: the type at which this chain reaches `expr`.
    let expr_end = type_needs_expression(&mut module, expr);
    let ty_trace = type_needs_type(&mut module, expr_end);
    let expr_init = expression_needs_type(&mut module, ty_trace);
    expression_needed(&mut module, expr_init);
    // Chain 2, built from `u32` outward. Note that `ty_end` is an
    // *expression* handle: the expression at which this chain reaches `u32`.
    let ty_end = expression_needs_type(&mut module, u32);
    let expr_trace = expression_needs_expression(&mut module, ty_end);
    let ty_init = type_needs_expression(&mut module, expr_trace);
    type_needed(&mut module, ty_init);
    // Everything is reachable from a named root, so compaction must be a
    // no-op on the compared arenas.
    let untouched = module.clone();
    compact(&mut module, KeepUnused::Yes);
    assert!(cmp_modules(&module, &untouched));
    // Add an unused literal plus an anonymous array type that depends on it
    // (through an anonymous override); compaction must remove them again.
    let unused_expr = module.global_expressions.append(
        crate::Expression::Literal(crate::Literal::U32(1)),
        crate::Span::default(),
    );
    type_needs_expression(&mut module, unused_expr);
    assert!(!cmp_modules(&module, &untouched));
    compact(&mut module, KeepUnused::Yes);
    assert!(cmp_modules(&module, &untouched));
}
814
/// A named array type sized by an anonymous override must keep that override
/// (and its initializer) through compaction, while an unrelated anonymous
/// override appended earlier is dropped. The module must validate both before
/// and after compaction.
#[test]
fn array_length_override() {
    use crate::{Override, Scalar, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let scalar_type = |scalar| Type {
        name: None,
        inner: TypeInner::Scalar(scalar),
    };
    let bool_ty = module.types.insert(scalar_type(Scalar::BOOL), span);
    let u32_ty = module.types.insert(scalar_type(Scalar::U32), span);

    let literal_one = module
        .global_expressions
        .append(crate::Expression::Literal(crate::Literal::U32(1)), span);

    // Nothing refers to this anonymous override, so compaction removes it
    // and has to renumber the override appended after it.
    let _unused_override = module.overrides.append(
        Override {
            name: None,
            id: Some(40),
            ty: u32_ty,
            init: None,
        },
        span,
    );
    let length_override = module.overrides.append(
        Override {
            name: None,
            id: Some(42),
            ty: u32_ty,
            init: Some(literal_one),
        },
        span,
    );

    // Named, so it is kept, and with it the override used as its length.
    let array_inner = TypeInner::Array {
        base: bool_ty,
        size: crate::ArraySize::Pending(length_override),
        stride: 4,
    };
    let _ty_array = module.types.insert(
        Type {
            name: Some("array<bool, o>".to_string()),
            inner: array_inner,
        },
        span,
    );

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );
    let mut validates = |m: &crate::Module| validator.validate(m).is_ok();

    assert!(validates(&module));
    compact(&mut module, KeepUnused::Yes);
    assert!(validates(&module));
}
875
/// Test mutual references between types and expressions via override
/// lengths.
///
/// A named array type is sized by an anonymous override whose initializer,
/// `(first_override + 0) as u32`, is the only path to both another anonymous
/// override and the anonymous `i32` type. Compaction must keep all of them,
/// so the module validates both before and after compaction.
#[test]
fn array_length_override_mutual() {
    use crate::Expression as Ex;
    use crate::Scalar as Sc;
    use crate::TypeInner as Ti;

    let nowhere = crate::Span::default();
    let mut module = crate::Module::default();
    let ty_u32 = module.types.insert(
        crate::Type {
            name: None,
            inner: Ti::Scalar(Sc::U32),
        },
        nowhere,
    );

    // This type is reachable only through `second_override`'s init
    // expression: directly via the `ZeroValue` below, and indirectly as the
    // type of `first_override`. If we visit that expression too early, this
    // type will be removed incorrectly.
    let ty_i32 = module.types.insert(
        crate::Type {
            name: None,
            inner: Ti::Scalar(Sc::I32),
        },
        nowhere,
    );

    // An override that the other override's init can refer to.
    let first_override = module.overrides.append(
        crate::Override {
            name: None, // so it is not considered used by definition
            id: Some(41),
            ty: ty_i32,
            init: None,
        },
        nowhere,
    );

    // Initializer expression for the override:
    //
    //     (first_override + 0) as u32
    //
    // The `first_override` makes it an override expression; the `0`
    // gets a use of `ty_i32` in there; and the `as` makes it match
    // the type of `second_override` without actually making
    // `second_override` point at `ty_i32` directly.
    let first_override_expr = module
        .global_expressions
        .append(Ex::Override(first_override), nowhere);
    let zero = module
        .global_expressions
        .append(Ex::ZeroValue(ty_i32), nowhere);
    let sum = module.global_expressions.append(
        Ex::Binary {
            op: crate::BinaryOperator::Add,
            left: first_override_expr,
            right: zero,
        },
        nowhere,
    );
    let init = module.global_expressions.append(
        Ex::As {
            expr: sum,
            kind: crate::ScalarKind::Uint,
            convert: None,
        },
        nowhere,
    );

    // Override that serves as the array's length.
    let second_override = module.overrides.append(
        crate::Override {
            name: None, // so it is not considered used by definition
            id: Some(42),
            ty: ty_u32,
            init: Some(init),
        },
        nowhere,
    );

    // Array type that uses the override as its length.
    // Since this is named, it is considered used by definition.
    let _ty_array = module.types.insert(
        crate::Type {
            name: Some("delicious_array".to_string()),
            inner: Ti::Array {
                base: ty_u32,
                size: crate::ArraySize::Pending(second_override),
                stride: 4,
            },
        },
        nowhere,
    );

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );

    assert!(validator.validate(&module).is_ok());
    compact(&mut module, KeepUnused::Yes);
    assert!(validator.validate(&module).is_ok());
}
981
/// A named array type sized by an anonymous override must keep that override
/// and its literal initializer through compaction, while a global literal
/// that nothing refers to is dropped. The module must validate both before
/// and after compaction.
#[test]
fn array_length_expression() {
    use crate::{ArraySize, Expression, Literal, Override, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_ty = module.types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::U32),
        },
        span,
    );

    // Nothing refers to this literal, so compaction should remove it and
    // renumber the expression appended after it.
    let _unused_zero = module
        .global_expressions
        .append(Expression::Literal(Literal::U32(0)), span);
    let one = module
        .global_expressions
        .append(Expression::Literal(Literal::U32(1)), span);

    let length = Override {
        name: None,
        id: None,
        ty: u32_ty,
        init: Some(one),
    };
    let length = module.overrides.append(length, span);

    let array = Type {
        name: Some("array<u32, 1>".to_string()),
        inner: TypeInner::Array {
            base: u32_ty,
            size: ArraySize::Pending(length),
            stride: 4,
        },
    };
    let _ty_array = module.types.insert(array, span);

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );
    let mut validates = |m: &crate::Module| validator.validate(m).is_ok();

    assert!(validates(&module));
    compact(&mut module, KeepUnused::Yes);
    assert!(validates(&module));
}
1030
/// An anonymous override that is reachable only through a global
/// `Expression::Override` (used as a named override's initializer) must be
/// kept by compaction, together with its own initializer. The module must
/// validate both before and after compaction.
#[test]
fn global_expression_override() {
    use crate::{Expression, Override, Span, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_ty = module.types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::U32),
        },
        span,
    );

    // Kept only if compaction follows `Expression::Override` in global
    // expressions through to that override's initializer.
    let inner_init = module
        .global_expressions
        .append(Expression::Literal(crate::Literal::U32(1)), span);

    // Reachable only through the global `Expression::Override` below.
    let inner = module.overrides.append(
        Override {
            name: None,
            id: Some(42),
            ty: u32_ty,
            init: Some(inner_init),
        },
        span,
    );

    // Kept because it initializes the named override `p`.
    let inner_ref = module
        .global_expressions
        .append(Expression::Override(inner), span);

    // Named, so compaction keeps it under `KeepUnused::Yes`.
    let _p = module.overrides.append(
        Override {
            name: Some("p".to_string()),
            id: None,
            ty: u32_ty,
            init: Some(inner_ref),
        },
        span,
    );

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );
    let mut validates = |m: &crate::Module| validator.validate(m).is_ok();

    assert!(validates(&module));
    compact(&mut module, KeepUnused::Yes);
    assert!(validates(&module));
}
1086
/// An anonymous override used only by an `Expression::Override` inside a
/// function body must be kept by compaction, together with its initializer,
/// while an unrelated anonymous override is removed. The module must validate
/// both before and after compaction.
#[test]
fn local_expression_override() {
    use crate::{Expression, Override, Span, Statement, Type, TypeInner};

    let span = Span::default();
    let mut module = crate::Module::default();

    let u32_ty = module.types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::U32),
        },
        span,
    );

    // Kept only if compaction follows `Expression::Override` in a function's
    // expressions through to that override's initializer.
    let init_one = module
        .global_expressions
        .append(Expression::Literal(crate::Literal::U32(1)), span);

    // Nothing refers to this override, so compaction removes it.
    let _unused_override = module.overrides.append(
        Override {
            name: None,
            id: Some(41),
            ty: u32_ty,
            init: None,
        },
        span,
    );

    // Reachable only through the function body below.
    let used_override = module.overrides.append(
        Override {
            name: None,
            id: Some(42),
            ty: u32_ty,
            init: Some(init_one),
        },
        span,
    );

    // A function whose body just returns the override's value.
    let function = {
        let mut f = crate::Function {
            result: Some(crate::FunctionResult {
                ty: u32_ty,
                binding: None,
            }),
            ..crate::Function::default()
        };
        let value = f
            .expressions
            .append(Expression::Override(used_override), span);
        f.body.push(Statement::Return { value: Some(value) }, span);
        f
    };
    module.functions.append(function, span);

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );
    let mut validates = |m: &crate::Module| validator.validate(m).is_ok();

    assert!(validates(&module));
    compact(&mut module, KeepUnused::Yes);
    assert!(validates(&module));
}
1157
/// A type used only by an anonymous constant must survive compaction when a
/// named constant's initializer refers to that anonymous constant. The module
/// must validate both before and after compaction.
#[test]
fn unnamed_constant_type() {
    use crate::{Constant, Expression, Scalar, Span, Type, TypeInner, VectorSize};

    let span = Span::default();
    let mut module = crate::Module::default();

    // Used only by the anonymous constant.
    let scalar_ty = module.types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(Scalar::U32),
        },
        span,
    );

    // Used by the named constant.
    let vector_ty = module.types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: Scalar::U32,
            },
        },
        span,
    );

    let zero = module
        .global_expressions
        .append(Expression::Literal(crate::Literal::U32(0)), span);
    let anonymous = module.constants.append(
        Constant {
            name: None,
            ty: scalar_ty,
            init: zero,
        },
        span,
    );

    // Splatting the anonymous constant gives the named constant a vector
    // type, distinct from the anonymous constant's scalar type.
    let anonymous_ref = module
        .global_expressions
        .append(Expression::Constant(anonymous), span);
    let splat = module.global_expressions.append(
        Expression::Splat {
            size: VectorSize::Bi,
            value: anonymous_ref,
        },
        span,
    );

    let _named_constant = module.constants.append(
        Constant {
            name: Some("totally_named".to_string()),
            ty: vector_ty,
            init: splat,
        },
        span,
    );

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );
    let mut validates = |m: &crate::Module| validator.validate(m).is_ok();

    assert!(validates(&module));
    compact(&mut module, KeepUnused::Yes);
    assert!(validates(&module));
}
1229
/// A type used only by an anonymous override must survive compaction when a
/// named override's initializer refers to that anonymous override. The module
/// must validate both before and after compaction.
#[test]
fn unnamed_override_type() {
    let mut module = crate::Module::default();
    let nowhere = crate::Span::default();

    // This type is used only by the unnamed override.
    let ty_u32 = module.types.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Scalar(crate::Scalar::U32),
        },
        nowhere,
    );

    // This type is used by the named override.
    let ty_i32 = module.types.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Scalar(crate::Scalar::I32),
        },
        nowhere,
    );

    let unnamed_init = module
        .global_expressions
        .append(crate::Expression::Literal(crate::Literal::U32(0)), nowhere);

    let unnamed_override = module.overrides.append(
        crate::Override {
            name: None,
            id: Some(42),
            ty: ty_u32,
            init: Some(unnamed_init),
        },
        nowhere,
    );

    // The named override is initialized using an `As` expression that
    // converts the unnamed override to a signed integer. This gives the named
    // override a type distinct from the unnamed override's, so `ty_u32` is
    // reachable only through the unnamed override.
    let unnamed_override_expr = module
        .global_expressions
        .append(crate::Expression::Override(unnamed_override), nowhere);
    let named_init = module.global_expressions.append(
        crate::Expression::As {
            expr: unnamed_override_expr,
            kind: crate::ScalarKind::Sint,
            convert: None,
        },
        nowhere,
    );

    let _named_override = module.overrides.append(
        crate::Override {
            name: Some("totally_named".to_string()),
            id: None,
            ty: ty_i32,
            init: Some(named_init),
        },
        nowhere,
    );

    let mut validator = super::valid::Validator::new(
        super::valid::ValidationFlags::all(),
        super::valid::Capabilities::all(),
    );

    assert!(validator.validate(&module).is_ok());
    compact(&mut module, KeepUnused::Yes);
    assert!(validator.validate(&module).is_ok());
}