1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
mod expressions;
mod functions;
mod handle_set_map;
mod statements;
mod types;

use crate::arena::HandleSet;
use crate::{arena, compact::functions::FunctionTracer};
use handle_set_map::HandleMap;

/// Remove unused types, expressions, and constants from `module`.
///
/// Assuming that all globals, named constants, overrides, named
/// types, special types, functions and entry points in `module` are
/// used, determine which types, constants, and expressions (both
/// function-local and global constant expressions) are actually
/// used, and remove the rest, adjusting all handles as necessary.
/// The result should be a module functionally identical to the
/// original.
///
/// The work proceeds in two phases:
///
/// 1. *Tracing*: a `ModuleTracer` holds a shared borrow of `module`
///    and accumulates the sets of used types, constants, and global
///    expressions, plus one set of used expressions per function and
///    entry point.
/// 2. *Compacting*: the used-sets are turned into handle maps
///    (`ModuleMap` / `FunctionMap`), and every arena and every handle
///    stored in `module` is rewritten through those maps.
///
/// This may be useful to apply to modules generated in the snapshot
/// tests. Our backends often generate temporary names based on handle
/// indices, which means that adding or removing unused arena entries
/// can affect the output even though they have no semantic effect.
/// Such meaningless changes add noise to snapshot diffs, making
/// accurate patch review difficult. Compacting the modules before
/// generating snapshots makes the output independent of unused arena
/// entries.
///
/// # Panics
///
/// If `module` has not passed validation, this may panic.
pub fn compact(module: &mut crate::Module) {
    // Phase 1: tracing. `module_tracer` borrows `module` immutably
    // until it is consumed by `ModuleMap::from` below, so nothing in
    // `module` is modified before that point.
    let mut module_tracer = ModuleTracer::new(module);

    // We treat all globals as used by definition.
    log::trace!("tracing global variables");
    {
        for (_, global) in module.global_variables.iter() {
            log::trace!("tracing global {:?}", global.name);
            module_tracer.types_used.insert(global.ty);
            if let Some(init) = global.init {
                module_tracer.global_expressions_used.insert(init);
            }
        }
    }

    // We treat all special types as used by definition.
    module_tracer.trace_special_types(&module.special_types);

    // We treat all named constants as used by definition.
    for (handle, constant) in module.constants.iter() {
        if constant.name.is_some() {
            module_tracer.constants_used.insert(handle);
            module_tracer.global_expressions_used.insert(constant.init);
        }
    }

    // We treat all overrides as used by definition.
    for (_, override_) in module.overrides.iter() {
        module_tracer.types_used.insert(override_.ty);
        if let Some(init) = override_.init {
            module_tracer.global_expressions_used.insert(init);
        }
    }

    // Array types whose size is a pending expression refer to a
    // global expression. Mark that expression as used for every such
    // type in the arena; note that this loop does not check whether
    // the array type itself has been marked as used.
    for (_, ty) in module.types.iter() {
        if let crate::TypeInner::Array {
            size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(size_expr)),
            ..
        } = ty.inner
        {
            module_tracer.global_expressions_used.insert(size_expr);
        }
    }

    // Entry points may give their workgroup size as global
    // expressions; mark each one that is present as used.
    for e in module.entry_points.iter() {
        if let Some(sizes) = e.workgroup_size_overrides {
            for size in sizes.iter().filter_map(|x| *x) {
                module_tracer.global_expressions_used.insert(size);
            }
        }
    }

    // We assume that all functions are used.
    //
    // Observe which types, constant expressions, constants, and
    // expressions each function uses, and produce maps for each
    // function from pre-compaction to post-compaction expression
    // handles.
    //
    // The maps are collected in `module.functions` iteration order;
    // the compaction loop further down pairs them back up with `zip`.
    log::trace!("tracing functions");
    let function_maps: Vec<FunctionMap> = module
        .functions
        .iter()
        .map(|(_, f)| {
            log::trace!("tracing function {:?}", f.name);
            let mut function_tracer = module_tracer.as_function(f);
            function_tracer.trace();
            FunctionMap::from(function_tracer)
        })
        .collect();

    // Similarly, observe what each entry point actually uses.
    log::trace!("tracing entry points");
    let entry_point_maps: Vec<FunctionMap> = module
        .entry_points
        .iter()
        .map(|e| {
            log::trace!("tracing entry point {:?}", e.function.name);
            let mut used = module_tracer.as_function(&e.function);
            used.trace();
            FunctionMap::from(used)
        })
        .collect();

    // Given that the above steps have marked all the constant
    // expressions used directly by globals, constants, functions, and
    // entry points, walk the constant expression arena to find all
    // constant expressions used, directly or indirectly.
    module_tracer.as_const_expression().trace_expressions();

    // Constants' initializers are taken care of already, because
    // expression tracing sees through constants. But we still need to
    // note type usage.
    for (handle, constant) in module.constants.iter() {
        if module_tracer.constants_used.contains(handle) {
            module_tracer.types_used.insert(constant.ty);
        }
    }

    // Treat all named types as used.
    for (handle, ty) in module.types.iter() {
        log::trace!("tracing type {:?}, name {:?}", handle, ty.name);
        if ty.name.is_some() {
            module_tracer.types_used.insert(handle);
        }
    }

    // Propagate usage through types.
    module_tracer.as_type().trace_types();

    // Phase 2: compacting.
    //
    // Now that we know what is used and what is never touched,
    // produce maps from the `Handle`s that appear in `module` now to
    // the corresponding `Handle`s that will refer to the same items
    // in the compacted module.
    //
    // This consumes `module_tracer`, releasing its borrow of `module`
    // so that the code below may mutate it.
    let module_map = ModuleMap::from(module_tracer);

    // Drop unused types from the type arena.
    //
    // `FastIndexSet`s don't have an underlying Vec<T> that we can
    // steal, compact in place, and then rebuild the `FastIndexSet`
    // from. So we have to rebuild the type arena from scratch.
    log::trace!("compacting types");
    let mut new_types = arena::UniqueArena::new();
    for (old_handle, mut ty, span) in module.types.drain_all() {
        if let Some(expected_new_handle) = module_map.types.try_adjust(old_handle) {
            module_map.adjust_type(&mut ty);
            let actual_new_handle = new_types.insert(ty, span);
            // The handle the rebuilt arena hands out must be the one
            // the map predicted; otherwise every adjusted type handle
            // elsewhere in the module would point at the wrong type.
            assert_eq!(actual_new_handle, expected_new_handle);
        }
    }
    module.types = new_types;
    log::trace!("adjusting special types");
    module_map.adjust_special_types(&mut module.special_types);

    // Drop unused constant expressions, reusing existing storage.
    log::trace!("adjusting constant expressions");
    module.global_expressions.retain_mut(|handle, expr| {
        if module_map.global_expressions.used(handle) {
            // Global expressions refer to other global expressions,
            // so the global expression map is also the map for their
            // operands.
            module_map.adjust_expression(expr, &module_map.global_expressions);
            true
        } else {
            false
        }
    });

    // Drop unused constants in place, reusing existing storage.
    log::trace!("adjusting constants");
    module.constants.retain_mut(|handle, constant| {
        if module_map.constants.used(handle) {
            module_map.types.adjust(&mut constant.ty);
            module_map.global_expressions.adjust(&mut constant.init);
            true
        } else {
            false
        }
    });

    // Adjust override types and initializers.
    log::trace!("adjusting overrides");
    for (_, override_) in module.overrides.iter_mut() {
        module_map.types.adjust(&mut override_.ty);
        if let Some(init) = override_.init.as_mut() {
            module_map.global_expressions.adjust(init);
        }
    }

    // Adjust workgroup_size_overrides
    log::trace!("adjusting workgroup_size_overrides");
    for e in module.entry_points.iter_mut() {
        if let Some(sizes) = e.workgroup_size_overrides.as_mut() {
            for size in sizes.iter_mut() {
                if let Some(expr) = size.as_mut() {
                    module_map.global_expressions.adjust(expr);
                }
            }
        }
    }

    // Adjust global variables' types and initializers.
    log::trace!("adjusting global variables");
    for (_, global) in module.global_variables.iter_mut() {
        log::trace!("adjusting global {:?}", global.name);
        module_map.types.adjust(&mut global.ty);
        if let Some(ref mut init) = global.init {
            module_map.global_expressions.adjust(init);
        }
    }

    // Temporary storage to help us reuse allocations of existing
    // named expression tables.
    let mut reused_named_expressions = crate::NamedExpressions::default();

    // Compact each function.
    //
    // `function_maps` was collected by iterating `module.functions`,
    // so zipping pairs each function with its own map.
    for ((_, function), map) in module.functions.iter_mut().zip(function_maps.iter()) {
        log::trace!("compacting function {:?}", function.name);
        map.compact(function, &module_map, &mut reused_named_expressions);
    }

    // Compact each entry point.
    //
    // As above, `entry_point_maps` is in `module.entry_points` order.
    for (entry, map) in module.entry_points.iter_mut().zip(entry_point_maps.iter()) {
        log::trace!("compacting entry point {:?}", entry.function.name);
        map.compact(
            &mut entry.function,
            &module_map,
            &mut reused_named_expressions,
        );
    }
}

/// Usage-tracking state for a module being compacted.
///
/// Holds a shared borrow of the module together with the sets of
/// module-level handles that have been marked as used so far.
struct ModuleTracer<'module> {
    /// The module whose arenas are being traced.
    module: &'module crate::Module,
    /// Handles into `module.types` that have been marked as used.
    types_used: HandleSet<crate::Type>,
    /// Handles into `module.constants` that have been marked as used.
    constants_used: HandleSet<crate::Constant>,
    /// Handles into `module.global_expressions` that have been marked
    /// as used.
    global_expressions_used: HandleSet<crate::Expression>,
}

impl<'module> ModuleTracer<'module> {
    /// Build a tracer for `module`, with one used-set per module-level
    /// arena (types, constants, global expressions).
    fn new(module: &'module crate::Module) -> Self {
        let types_used = HandleSet::for_arena(&module.types);
        let constants_used = HandleSet::for_arena(&module.constants);
        let global_expressions_used = HandleSet::for_arena(&module.global_expressions);

        Self {
            module,
            types_used,
            constants_used,
            global_expressions_used,
        }
    }

    /// Mark every type handle stored in `special_types` as used.
    fn trace_special_types(&mut self, special_types: &crate::SpecialTypes) {
        // Destructure exhaustively, so that adding a field to
        // `SpecialTypes` forces this function to be revisited.
        let crate::SpecialTypes {
            ref ray_desc,
            ref ray_intersection,
            ref predeclared_types,
        } = *special_types;

        // The two optional ray-query types first, then every
        // predeclared type.
        let handles = ray_desc
            .iter()
            .chain(ray_intersection.iter())
            .chain(predeclared_types.values());
        for handle in handles {
            self.types_used.insert(*handle);
        }
    }

    /// Borrow this tracer as a `TypeTracer` over the module's type
    /// arena, sharing the `types_used` set.
    fn as_type(&mut self) -> types::TypeTracer {
        let types = &self.module.types;
        let types_used = &mut self.types_used;

        types::TypeTracer { types, types_used }
    }

    /// Borrow this tracer as an `ExpressionTracer` over the module's
    /// global expression arena.
    ///
    /// The arena being traced is the global expression arena itself,
    /// so `expressions_used` is the global set and the separate
    /// `global_expressions_used` slot is `None`.
    fn as_const_expression(&mut self) -> expressions::ExpressionTracer {
        let expressions = &self.module.global_expressions;
        let constants = &self.module.constants;

        expressions::ExpressionTracer {
            expressions,
            constants,
            types_used: &mut self.types_used,
            constants_used: &mut self.constants_used,
            expressions_used: &mut self.global_expressions_used,
            global_expressions_used: None,
        }
    }

    /// Borrow this tracer as a `FunctionTracer` for `function`.
    ///
    /// The function tracer shares the module-level used-sets and gets
    /// a fresh used-set sized for `function.expressions`.
    pub fn as_function<'tracer>(
        &'tracer mut self,
        function: &'tracer crate::Function,
    ) -> FunctionTracer<'tracer> {
        let expressions_used = HandleSet::for_arena(&function.expressions);
        let constants = &self.module.constants;

        FunctionTracer {
            function,
            constants,
            types_used: &mut self.types_used,
            constants_used: &mut self.constants_used,
            global_expressions_used: &mut self.global_expressions_used,
            expressions_used,
        }
    }
}

/// Handle maps for the module-level arenas.
///
/// Built from the used-sets of a `ModuleTracer`; each map is used to
/// adjust handles of the corresponding kind stored in the module.
struct ModuleMap {
    /// Map for handles into the module's type arena.
    types: HandleMap<crate::Type>,
    /// Map for handles into the module's constant arena.
    constants: HandleMap<crate::Constant>,
    /// Map for handles into the module's global expression arena.
    global_expressions: HandleMap<crate::Expression>,
}

impl From<ModuleTracer<'_>> for ModuleMap {
    fn from(used: ModuleTracer) -> Self {
        ModuleMap {
            types: HandleMap::from_set(used.types_used),
            constants: HandleMap::from_set(used.constants_used),
            global_expressions: HandleMap::from_set(used.global_expressions_used),
        }
    }
}

impl ModuleMap {
    /// Rewrite every type handle stored in `special` through the type
    /// map, in place.
    fn adjust_special_types(&self, special: &mut crate::SpecialTypes) {
        // Destructure exhaustively, so that adding a field to
        // `SpecialTypes` forces this function to be revisited.
        let crate::SpecialTypes {
            ref mut ray_desc,
            ref mut ray_intersection,
            ref mut predeclared_types,
        } = *special;

        // The two optional ray-query types first, then every
        // predeclared type.
        let handles = ray_desc
            .iter_mut()
            .chain(ray_intersection.iter_mut())
            .chain(predeclared_types.values_mut());
        for handle in handles {
            self.types.adjust(handle);
        }
    }
}

/// Handle map for a single function's (or entry point's) expression
/// arena, built from the used-set of a `FunctionTracer`.
struct FunctionMap {
    /// Map for handles into the function's expression arena.
    expressions: HandleMap<crate::Expression>,
}

impl From<FunctionTracer<'_>> for FunctionMap {
    /// Consume a finished function tracer, turning its set of used
    /// expressions into the function's expression handle map.
    fn from(used: FunctionTracer) -> Self {
        let expressions = HandleMap::from_set(used.expressions_used);
        Self { expressions }
    }
}