naga/compact/
statements.rs

1use alloc::{vec, vec::Vec};
2
3use super::functions::FunctionTracer;
4use super::FunctionMap;
5use crate::arena::Handle;
6use crate::compact::handle_set_map::HandleMap;
7
8impl FunctionTracer<'_> {
9    pub fn trace_block(&mut self, block: &[crate::Statement]) {
10        let mut worklist: Vec<&[crate::Statement]> = vec![block];
11        while let Some(last) = worklist.pop() {
12            for stmt in last {
13                use crate::Statement as St;
14                match *stmt {
15                    St::Emit(ref _range) => {
16                        // If we come across a statement that actually uses an
17                        // expression in this range, it'll get traced from
18                        // there. But since evaluating expressions has no
19                        // effect, we don't need to assume that everything
20                        // emitted is live.
21                    }
22                    St::Block(ref block) => worklist.push(block),
23                    St::If {
24                        condition,
25                        ref accept,
26                        ref reject,
27                    } => {
28                        self.expressions_used.insert(condition);
29                        worklist.push(accept);
30                        worklist.push(reject);
31                    }
32                    St::Switch {
33                        selector,
34                        ref cases,
35                    } => {
36                        self.expressions_used.insert(selector);
37                        for case in cases {
38                            worklist.push(&case.body);
39                        }
40                    }
41                    St::Loop {
42                        ref body,
43                        ref continuing,
44                        break_if,
45                    } => {
46                        if let Some(break_if) = break_if {
47                            self.expressions_used.insert(break_if);
48                        }
49                        worklist.push(body);
50                        worklist.push(continuing);
51                    }
52                    St::Return { value: Some(value) } => {
53                        self.expressions_used.insert(value);
54                    }
55                    St::Store { pointer, value } => {
56                        self.expressions_used.insert(pointer);
57                        self.expressions_used.insert(value);
58                    }
59                    St::ImageStore {
60                        image,
61                        coordinate,
62                        array_index,
63                        value,
64                    } => {
65                        self.expressions_used.insert(image);
66                        self.expressions_used.insert(coordinate);
67                        if let Some(array_index) = array_index {
68                            self.expressions_used.insert(array_index);
69                        }
70                        self.expressions_used.insert(value);
71                    }
72                    St::Atomic {
73                        pointer,
74                        ref fun,
75                        value,
76                        result,
77                    } => {
78                        self.expressions_used.insert(pointer);
79                        self.trace_atomic_function(fun);
80                        self.expressions_used.insert(value);
81                        if let Some(result) = result {
82                            self.expressions_used.insert(result);
83                        }
84                    }
85                    St::ImageAtomic {
86                        image,
87                        coordinate,
88                        array_index,
89                        fun: _,
90                        value,
91                    } => {
92                        self.expressions_used.insert(image);
93                        self.expressions_used.insert(coordinate);
94                        if let Some(array_index) = array_index {
95                            self.expressions_used.insert(array_index);
96                        }
97                        self.expressions_used.insert(value);
98                    }
99                    St::WorkGroupUniformLoad { pointer, result } => {
100                        self.expressions_used.insert(pointer);
101                        self.expressions_used.insert(result);
102                    }
103                    St::Call {
104                        function,
105                        ref arguments,
106                        result,
107                    } => {
108                        self.trace_call(function);
109                        for expr in arguments {
110                            self.expressions_used.insert(*expr);
111                        }
112                        if let Some(result) = result {
113                            self.expressions_used.insert(result);
114                        }
115                    }
116                    St::RayQuery { query, ref fun } => {
117                        self.expressions_used.insert(query);
118                        self.trace_ray_query_function(fun);
119                    }
120                    St::SubgroupBallot { result, predicate } => {
121                        if let Some(predicate) = predicate {
122                            self.expressions_used.insert(predicate);
123                        }
124                        self.expressions_used.insert(result);
125                    }
126                    St::SubgroupCollectiveOperation {
127                        op: _,
128                        collective_op: _,
129                        argument,
130                        result,
131                    } => {
132                        self.expressions_used.insert(argument);
133                        self.expressions_used.insert(result);
134                    }
135                    St::SubgroupGather {
136                        mode,
137                        argument,
138                        result,
139                    } => {
140                        match mode {
141                            crate::GatherMode::BroadcastFirst => {}
142                            crate::GatherMode::Broadcast(index)
143                            | crate::GatherMode::Shuffle(index)
144                            | crate::GatherMode::ShuffleDown(index)
145                            | crate::GatherMode::ShuffleUp(index)
146                            | crate::GatherMode::ShuffleXor(index)
147                            | crate::GatherMode::QuadBroadcast(index) => {
148                                self.expressions_used.insert(index);
149                            }
150                            crate::GatherMode::QuadSwap(_) => {}
151                        }
152                        self.expressions_used.insert(argument);
153                        self.expressions_used.insert(result);
154                    }
155
156                    // Trivial statements.
157                    St::Break
158                    | St::Continue
159                    | St::Kill
160                    | St::ControlBarrier(_)
161                    | St::MemoryBarrier(_)
162                    | St::Return { value: None } => {}
163                }
164            }
165        }
166    }
167
168    fn trace_atomic_function(&mut self, fun: &crate::AtomicFunction) {
169        use crate::AtomicFunction as Af;
170        match *fun {
171            Af::Exchange {
172                compare: Some(expr),
173            } => {
174                self.expressions_used.insert(expr);
175            }
176            Af::Exchange { compare: None }
177            | Af::Add
178            | Af::Subtract
179            | Af::And
180            | Af::ExclusiveOr
181            | Af::InclusiveOr
182            | Af::Min
183            | Af::Max => {}
184        }
185    }
186
187    fn trace_ray_query_function(&mut self, fun: &crate::RayQueryFunction) {
188        use crate::RayQueryFunction as Qf;
189        match *fun {
190            Qf::Initialize {
191                acceleration_structure,
192                descriptor,
193            } => {
194                self.expressions_used.insert(acceleration_structure);
195                self.expressions_used.insert(descriptor);
196            }
197            Qf::Proceed { result } => {
198                self.expressions_used.insert(result);
199            }
200            Qf::GenerateIntersection { hit_t } => {
201                self.expressions_used.insert(hit_t);
202            }
203            Qf::ConfirmIntersection => {}
204            Qf::Terminate => {}
205        }
206    }
207}
208
209impl FunctionMap {
210    /// Adjust statements in the body of `function`.
211    ///
212    /// Adjusts expressions using `self.expressions`, and adjusts calls to other
213    /// functions using `function_map`.
214    pub fn adjust_body(
215        &self,
216        function: &mut crate::Function,
217        function_map: &HandleMap<crate::Function>,
218    ) {
219        let block = &mut function.body;
220        let mut worklist: Vec<&mut [crate::Statement]> = vec![block];
221        let adjust = |handle: &mut Handle<crate::Expression>| {
222            self.expressions.adjust(handle);
223        };
224        while let Some(last) = worklist.pop() {
225            for stmt in last {
226                use crate::Statement as St;
227                match *stmt {
228                    St::Emit(ref mut range) => {
229                        self.expressions.adjust_range(range, &function.expressions);
230                    }
231                    St::Block(ref mut block) => worklist.push(block),
232                    St::If {
233                        ref mut condition,
234                        ref mut accept,
235                        ref mut reject,
236                    } => {
237                        adjust(condition);
238                        worklist.push(accept);
239                        worklist.push(reject);
240                    }
241                    St::Switch {
242                        ref mut selector,
243                        ref mut cases,
244                    } => {
245                        adjust(selector);
246                        for case in cases {
247                            worklist.push(&mut case.body);
248                        }
249                    }
250                    St::Loop {
251                        ref mut body,
252                        ref mut continuing,
253                        ref mut break_if,
254                    } => {
255                        if let Some(ref mut break_if) = *break_if {
256                            adjust(break_if);
257                        }
258                        worklist.push(body);
259                        worklist.push(continuing);
260                    }
261                    St::Return {
262                        value: Some(ref mut value),
263                    } => adjust(value),
264                    St::Store {
265                        ref mut pointer,
266                        ref mut value,
267                    } => {
268                        adjust(pointer);
269                        adjust(value);
270                    }
271                    St::ImageStore {
272                        ref mut image,
273                        ref mut coordinate,
274                        ref mut array_index,
275                        ref mut value,
276                    } => {
277                        adjust(image);
278                        adjust(coordinate);
279                        if let Some(ref mut array_index) = *array_index {
280                            adjust(array_index);
281                        }
282                        adjust(value);
283                    }
284                    St::Atomic {
285                        ref mut pointer,
286                        ref mut fun,
287                        ref mut value,
288                        ref mut result,
289                    } => {
290                        adjust(pointer);
291                        self.adjust_atomic_function(fun);
292                        adjust(value);
293                        if let Some(ref mut result) = *result {
294                            adjust(result);
295                        }
296                    }
297                    St::ImageAtomic {
298                        ref mut image,
299                        ref mut coordinate,
300                        ref mut array_index,
301                        fun: _,
302                        ref mut value,
303                    } => {
304                        adjust(image);
305                        adjust(coordinate);
306                        if let Some(ref mut array_index) = *array_index {
307                            adjust(array_index);
308                        }
309                        adjust(value);
310                    }
311                    St::WorkGroupUniformLoad {
312                        ref mut pointer,
313                        ref mut result,
314                    } => {
315                        adjust(pointer);
316                        adjust(result);
317                    }
318                    St::Call {
319                        ref mut function,
320                        ref mut arguments,
321                        ref mut result,
322                    } => {
323                        function_map.adjust(function);
324                        for expr in arguments {
325                            adjust(expr);
326                        }
327                        if let Some(ref mut result) = *result {
328                            adjust(result);
329                        }
330                    }
331                    St::RayQuery {
332                        ref mut query,
333                        ref mut fun,
334                    } => {
335                        adjust(query);
336                        self.adjust_ray_query_function(fun);
337                    }
338                    St::SubgroupBallot {
339                        ref mut result,
340                        ref mut predicate,
341                    } => {
342                        if let Some(ref mut predicate) = *predicate {
343                            adjust(predicate);
344                        }
345                        adjust(result);
346                    }
347                    St::SubgroupCollectiveOperation {
348                        op: _,
349                        collective_op: _,
350                        ref mut argument,
351                        ref mut result,
352                    } => {
353                        adjust(argument);
354                        adjust(result);
355                    }
356                    St::SubgroupGather {
357                        ref mut mode,
358                        ref mut argument,
359                        ref mut result,
360                    } => {
361                        match *mode {
362                            crate::GatherMode::BroadcastFirst => {}
363                            crate::GatherMode::Broadcast(ref mut index)
364                            | crate::GatherMode::Shuffle(ref mut index)
365                            | crate::GatherMode::ShuffleDown(ref mut index)
366                            | crate::GatherMode::ShuffleUp(ref mut index)
367                            | crate::GatherMode::ShuffleXor(ref mut index)
368                            | crate::GatherMode::QuadBroadcast(ref mut index) => adjust(index),
369                            crate::GatherMode::QuadSwap(_) => {}
370                        }
371                        adjust(argument);
372                        adjust(result);
373                    }
374
375                    // Trivial statements.
376                    St::Break
377                    | St::Continue
378                    | St::Kill
379                    | St::ControlBarrier(_)
380                    | St::MemoryBarrier(_)
381                    | St::Return { value: None } => {}
382                }
383            }
384        }
385    }
386
387    fn adjust_atomic_function(&self, fun: &mut crate::AtomicFunction) {
388        use crate::AtomicFunction as Af;
389        match *fun {
390            Af::Exchange {
391                compare: Some(ref mut expr),
392            } => {
393                self.expressions.adjust(expr);
394            }
395            Af::Exchange { compare: None }
396            | Af::Add
397            | Af::Subtract
398            | Af::And
399            | Af::ExclusiveOr
400            | Af::InclusiveOr
401            | Af::Min
402            | Af::Max => {}
403        }
404    }
405
406    fn adjust_ray_query_function(&self, fun: &mut crate::RayQueryFunction) {
407        use crate::RayQueryFunction as Qf;
408        match *fun {
409            Qf::Initialize {
410                ref mut acceleration_structure,
411                ref mut descriptor,
412            } => {
413                self.expressions.adjust(acceleration_structure);
414                self.expressions.adjust(descriptor);
415            }
416            Qf::Proceed { ref mut result } => {
417                self.expressions.adjust(result);
418            }
419            Qf::GenerateIntersection { ref mut hit_t } => {
420                self.expressions.adjust(hit_t);
421            }
422            Qf::ConfirmIntersection => {}
423            Qf::Terminate => {}
424        }
425    }
426}