naga/compact/
statements.rs

1use alloc::{vec, vec::Vec};
2
3use super::functions::FunctionTracer;
4use super::FunctionMap;
5use crate::arena::Handle;
6use crate::compact::handle_set_map::HandleMap;
7
8impl FunctionTracer<'_> {
9    pub fn trace_block(&mut self, block: &[crate::Statement]) {
10        let mut worklist: Vec<&[crate::Statement]> = vec![block];
11        while let Some(last) = worklist.pop() {
12            for stmt in last {
13                use crate::Statement as St;
14                match *stmt {
15                    St::Emit(ref _range) => {
16                        // If we come across a statement that actually uses an
17                        // expression in this range, it'll get traced from
18                        // there. But since evaluating expressions has no
19                        // effect, we don't need to assume that everything
20                        // emitted is live.
21                    }
22                    St::Block(ref block) => worklist.push(block),
23                    St::If {
24                        condition,
25                        ref accept,
26                        ref reject,
27                    } => {
28                        self.expressions_used.insert(condition);
29                        worklist.push(accept);
30                        worklist.push(reject);
31                    }
32                    St::Switch {
33                        selector,
34                        ref cases,
35                    } => {
36                        self.expressions_used.insert(selector);
37                        for case in cases {
38                            worklist.push(&case.body);
39                        }
40                    }
41                    St::Loop {
42                        ref body,
43                        ref continuing,
44                        break_if,
45                    } => {
46                        if let Some(break_if) = break_if {
47                            self.expressions_used.insert(break_if);
48                        }
49                        worklist.push(body);
50                        worklist.push(continuing);
51                    }
52                    St::Return { value: Some(value) } => {
53                        self.expressions_used.insert(value);
54                    }
55                    St::Store { pointer, value } => {
56                        self.expressions_used.insert(pointer);
57                        self.expressions_used.insert(value);
58                    }
59                    St::ImageStore {
60                        image,
61                        coordinate,
62                        array_index,
63                        value,
64                    } => {
65                        self.expressions_used.insert(image);
66                        self.expressions_used.insert(coordinate);
67                        if let Some(array_index) = array_index {
68                            self.expressions_used.insert(array_index);
69                        }
70                        self.expressions_used.insert(value);
71                    }
72                    St::Atomic {
73                        pointer,
74                        ref fun,
75                        value,
76                        result,
77                    } => {
78                        self.expressions_used.insert(pointer);
79                        self.trace_atomic_function(fun);
80                        self.expressions_used.insert(value);
81                        if let Some(result) = result {
82                            self.expressions_used.insert(result);
83                        }
84                    }
85                    St::ImageAtomic {
86                        image,
87                        coordinate,
88                        array_index,
89                        fun: _,
90                        value,
91                    } => {
92                        self.expressions_used.insert(image);
93                        self.expressions_used.insert(coordinate);
94                        if let Some(array_index) = array_index {
95                            self.expressions_used.insert(array_index);
96                        }
97                        self.expressions_used.insert(value);
98                    }
99                    St::WorkGroupUniformLoad { pointer, result } => {
100                        self.expressions_used.insert(pointer);
101                        self.expressions_used.insert(result);
102                    }
103                    St::Call {
104                        function,
105                        ref arguments,
106                        result,
107                    } => {
108                        self.trace_call(function);
109                        for expr in arguments {
110                            self.expressions_used.insert(*expr);
111                        }
112                        if let Some(result) = result {
113                            self.expressions_used.insert(result);
114                        }
115                    }
116                    St::RayQuery { query, ref fun } => {
117                        self.expressions_used.insert(query);
118                        self.trace_ray_query_function(fun);
119                    }
120                    St::SubgroupBallot { result, predicate } => {
121                        if let Some(predicate) = predicate {
122                            self.expressions_used.insert(predicate);
123                        }
124                        self.expressions_used.insert(result);
125                    }
126                    St::SubgroupCollectiveOperation {
127                        op: _,
128                        collective_op: _,
129                        argument,
130                        result,
131                    } => {
132                        self.expressions_used.insert(argument);
133                        self.expressions_used.insert(result);
134                    }
135                    St::SubgroupGather {
136                        mode,
137                        argument,
138                        result,
139                    } => {
140                        match mode {
141                            crate::GatherMode::BroadcastFirst => {}
142                            crate::GatherMode::Broadcast(index)
143                            | crate::GatherMode::Shuffle(index)
144                            | crate::GatherMode::ShuffleDown(index)
145                            | crate::GatherMode::ShuffleUp(index)
146                            | crate::GatherMode::ShuffleXor(index)
147                            | crate::GatherMode::QuadBroadcast(index) => {
148                                self.expressions_used.insert(index);
149                            }
150                            crate::GatherMode::QuadSwap(_) => {}
151                        }
152                        self.expressions_used.insert(argument);
153                        self.expressions_used.insert(result);
154                    }
155                    St::CooperativeStore { target, ref data } => {
156                        self.expressions_used.insert(target);
157                        self.expressions_used.insert(data.pointer);
158                        self.expressions_used.insert(data.stride);
159                    }
160
161                    // Trivial statements.
162                    St::Break
163                    | St::Continue
164                    | St::Kill
165                    | St::ControlBarrier(_)
166                    | St::MemoryBarrier(_)
167                    | St::Return { value: None } => {}
168                }
169            }
170        }
171    }
172
173    fn trace_atomic_function(&mut self, fun: &crate::AtomicFunction) {
174        use crate::AtomicFunction as Af;
175        match *fun {
176            Af::Exchange {
177                compare: Some(expr),
178            } => {
179                self.expressions_used.insert(expr);
180            }
181            Af::Exchange { compare: None }
182            | Af::Add
183            | Af::Subtract
184            | Af::And
185            | Af::ExclusiveOr
186            | Af::InclusiveOr
187            | Af::Min
188            | Af::Max => {}
189        }
190    }
191
192    fn trace_ray_query_function(&mut self, fun: &crate::RayQueryFunction) {
193        use crate::RayQueryFunction as Qf;
194        match *fun {
195            Qf::Initialize {
196                acceleration_structure,
197                descriptor,
198            } => {
199                self.expressions_used.insert(acceleration_structure);
200                self.expressions_used.insert(descriptor);
201            }
202            Qf::Proceed { result } => {
203                self.expressions_used.insert(result);
204            }
205            Qf::GenerateIntersection { hit_t } => {
206                self.expressions_used.insert(hit_t);
207            }
208            Qf::ConfirmIntersection => {}
209            Qf::Terminate => {}
210        }
211    }
212}
213
214impl FunctionMap {
215    /// Adjust statements in the body of `function`.
216    ///
217    /// Adjusts expressions using `self.expressions`, and adjusts calls to other
218    /// functions using `function_map`.
219    pub fn adjust_body(
220        &self,
221        function: &mut crate::Function,
222        function_map: &HandleMap<crate::Function>,
223    ) {
224        let block = &mut function.body;
225        let mut worklist: Vec<&mut [crate::Statement]> = vec![block];
226        let adjust = |handle: &mut Handle<crate::Expression>| {
227            self.expressions.adjust(handle);
228        };
229        while let Some(last) = worklist.pop() {
230            for stmt in last {
231                use crate::Statement as St;
232                match *stmt {
233                    St::Emit(ref mut range) => {
234                        self.expressions.adjust_range(range, &function.expressions);
235                    }
236                    St::Block(ref mut block) => worklist.push(block),
237                    St::If {
238                        ref mut condition,
239                        ref mut accept,
240                        ref mut reject,
241                    } => {
242                        adjust(condition);
243                        worklist.push(accept);
244                        worklist.push(reject);
245                    }
246                    St::Switch {
247                        ref mut selector,
248                        ref mut cases,
249                    } => {
250                        adjust(selector);
251                        for case in cases {
252                            worklist.push(&mut case.body);
253                        }
254                    }
255                    St::Loop {
256                        ref mut body,
257                        ref mut continuing,
258                        ref mut break_if,
259                    } => {
260                        if let Some(ref mut break_if) = *break_if {
261                            adjust(break_if);
262                        }
263                        worklist.push(body);
264                        worklist.push(continuing);
265                    }
266                    St::Return {
267                        value: Some(ref mut value),
268                    } => adjust(value),
269                    St::Store {
270                        ref mut pointer,
271                        ref mut value,
272                    } => {
273                        adjust(pointer);
274                        adjust(value);
275                    }
276                    St::ImageStore {
277                        ref mut image,
278                        ref mut coordinate,
279                        ref mut array_index,
280                        ref mut value,
281                    } => {
282                        adjust(image);
283                        adjust(coordinate);
284                        if let Some(ref mut array_index) = *array_index {
285                            adjust(array_index);
286                        }
287                        adjust(value);
288                    }
289                    St::Atomic {
290                        ref mut pointer,
291                        ref mut fun,
292                        ref mut value,
293                        ref mut result,
294                    } => {
295                        adjust(pointer);
296                        self.adjust_atomic_function(fun);
297                        adjust(value);
298                        if let Some(ref mut result) = *result {
299                            adjust(result);
300                        }
301                    }
302                    St::ImageAtomic {
303                        ref mut image,
304                        ref mut coordinate,
305                        ref mut array_index,
306                        fun: _,
307                        ref mut value,
308                    } => {
309                        adjust(image);
310                        adjust(coordinate);
311                        if let Some(ref mut array_index) = *array_index {
312                            adjust(array_index);
313                        }
314                        adjust(value);
315                    }
316                    St::WorkGroupUniformLoad {
317                        ref mut pointer,
318                        ref mut result,
319                    } => {
320                        adjust(pointer);
321                        adjust(result);
322                    }
323                    St::Call {
324                        ref mut function,
325                        ref mut arguments,
326                        ref mut result,
327                    } => {
328                        function_map.adjust(function);
329                        for expr in arguments {
330                            adjust(expr);
331                        }
332                        if let Some(ref mut result) = *result {
333                            adjust(result);
334                        }
335                    }
336                    St::RayQuery {
337                        ref mut query,
338                        ref mut fun,
339                    } => {
340                        adjust(query);
341                        self.adjust_ray_query_function(fun);
342                    }
343                    St::SubgroupBallot {
344                        ref mut result,
345                        ref mut predicate,
346                    } => {
347                        if let Some(ref mut predicate) = *predicate {
348                            adjust(predicate);
349                        }
350                        adjust(result);
351                    }
352                    St::SubgroupCollectiveOperation {
353                        op: _,
354                        collective_op: _,
355                        ref mut argument,
356                        ref mut result,
357                    } => {
358                        adjust(argument);
359                        adjust(result);
360                    }
361                    St::SubgroupGather {
362                        ref mut mode,
363                        ref mut argument,
364                        ref mut result,
365                    } => {
366                        match *mode {
367                            crate::GatherMode::BroadcastFirst => {}
368                            crate::GatherMode::Broadcast(ref mut index)
369                            | crate::GatherMode::Shuffle(ref mut index)
370                            | crate::GatherMode::ShuffleDown(ref mut index)
371                            | crate::GatherMode::ShuffleUp(ref mut index)
372                            | crate::GatherMode::ShuffleXor(ref mut index)
373                            | crate::GatherMode::QuadBroadcast(ref mut index) => adjust(index),
374                            crate::GatherMode::QuadSwap(_) => {}
375                        }
376                        adjust(argument);
377                        adjust(result);
378                    }
379                    St::CooperativeStore {
380                        ref mut target,
381                        ref mut data,
382                    } => {
383                        adjust(target);
384                        adjust(&mut data.pointer);
385                        adjust(&mut data.stride);
386                    }
387
388                    // Trivial statements.
389                    St::Break
390                    | St::Continue
391                    | St::Kill
392                    | St::ControlBarrier(_)
393                    | St::MemoryBarrier(_)
394                    | St::Return { value: None } => {}
395                }
396            }
397        }
398    }
399
400    fn adjust_atomic_function(&self, fun: &mut crate::AtomicFunction) {
401        use crate::AtomicFunction as Af;
402        match *fun {
403            Af::Exchange {
404                compare: Some(ref mut expr),
405            } => {
406                self.expressions.adjust(expr);
407            }
408            Af::Exchange { compare: None }
409            | Af::Add
410            | Af::Subtract
411            | Af::And
412            | Af::ExclusiveOr
413            | Af::InclusiveOr
414            | Af::Min
415            | Af::Max => {}
416        }
417    }
418
419    fn adjust_ray_query_function(&self, fun: &mut crate::RayQueryFunction) {
420        use crate::RayQueryFunction as Qf;
421        match *fun {
422            Qf::Initialize {
423                ref mut acceleration_structure,
424                ref mut descriptor,
425            } => {
426                self.expressions.adjust(acceleration_structure);
427                self.expressions.adjust(descriptor);
428            }
429            Qf::Proceed { ref mut result } => {
430                self.expressions.adjust(result);
431            }
432            Qf::GenerateIntersection { ref mut hit_t } => {
433                self.expressions.adjust(hit_t);
434            }
435            Qf::ConfirmIntersection => {}
436            Qf::Terminate => {}
437        }
438    }
439}