// naga/compact/statements.rs

1use alloc::{vec, vec::Vec};
2
3use super::functions::FunctionTracer;
4use super::FunctionMap;
5use crate::arena::Handle;
6use crate::compact::handle_set_map::HandleMap;
7
8impl FunctionTracer<'_> {
9    pub fn trace_block(&mut self, block: &[crate::Statement]) {
10        let mut worklist: Vec<&[crate::Statement]> = vec![block];
11        while let Some(last) = worklist.pop() {
12            for stmt in last {
13                use crate::Statement as St;
14                match *stmt {
15                    St::Emit(ref _range) => {
16                        // If we come across a statement that actually uses an
17                        // expression in this range, it'll get traced from
18                        // there. But since evaluating expressions has no
19                        // effect, we don't need to assume that everything
20                        // emitted is live.
21                    }
22                    St::Block(ref block) => worklist.push(block),
23                    St::If {
24                        condition,
25                        ref accept,
26                        ref reject,
27                    } => {
28                        self.expressions_used.insert(condition);
29                        worklist.push(accept);
30                        worklist.push(reject);
31                    }
32                    St::Switch {
33                        selector,
34                        ref cases,
35                    } => {
36                        self.expressions_used.insert(selector);
37                        for case in cases {
38                            worklist.push(&case.body);
39                        }
40                    }
41                    St::Loop {
42                        ref body,
43                        ref continuing,
44                        break_if,
45                    } => {
46                        if let Some(break_if) = break_if {
47                            self.expressions_used.insert(break_if);
48                        }
49                        worklist.push(body);
50                        worklist.push(continuing);
51                    }
52                    St::Return { value: Some(value) } => {
53                        self.expressions_used.insert(value);
54                    }
55                    St::Store { pointer, value } => {
56                        self.expressions_used.insert(pointer);
57                        self.expressions_used.insert(value);
58                    }
59                    St::ImageStore {
60                        image,
61                        coordinate,
62                        array_index,
63                        value,
64                    } => {
65                        self.expressions_used.insert(image);
66                        self.expressions_used.insert(coordinate);
67                        if let Some(array_index) = array_index {
68                            self.expressions_used.insert(array_index);
69                        }
70                        self.expressions_used.insert(value);
71                    }
72                    St::Atomic {
73                        pointer,
74                        ref fun,
75                        value,
76                        result,
77                    } => {
78                        self.expressions_used.insert(pointer);
79                        self.trace_atomic_function(fun);
80                        self.expressions_used.insert(value);
81                        if let Some(result) = result {
82                            self.expressions_used.insert(result);
83                        }
84                    }
85                    St::ImageAtomic {
86                        image,
87                        coordinate,
88                        array_index,
89                        fun: _,
90                        value,
91                    } => {
92                        self.expressions_used.insert(image);
93                        self.expressions_used.insert(coordinate);
94                        if let Some(array_index) = array_index {
95                            self.expressions_used.insert(array_index);
96                        }
97                        self.expressions_used.insert(value);
98                    }
99                    St::WorkGroupUniformLoad { pointer, result } => {
100                        self.expressions_used.insert(pointer);
101                        self.expressions_used.insert(result);
102                    }
103                    St::Call {
104                        function,
105                        ref arguments,
106                        result,
107                    } => {
108                        self.trace_call(function);
109                        for expr in arguments {
110                            self.expressions_used.insert(*expr);
111                        }
112                        if let Some(result) = result {
113                            self.expressions_used.insert(result);
114                        }
115                    }
116                    St::RayQuery { query, ref fun } => {
117                        self.expressions_used.insert(query);
118                        self.trace_ray_query_function(fun);
119                    }
120                    St::SubgroupBallot { result, predicate } => {
121                        if let Some(predicate) = predicate {
122                            self.expressions_used.insert(predicate);
123                        }
124                        self.expressions_used.insert(result);
125                    }
126                    St::SubgroupCollectiveOperation {
127                        op: _,
128                        collective_op: _,
129                        argument,
130                        result,
131                    } => {
132                        self.expressions_used.insert(argument);
133                        self.expressions_used.insert(result);
134                    }
135                    St::SubgroupGather {
136                        mode,
137                        argument,
138                        result,
139                    } => {
140                        match mode {
141                            crate::GatherMode::BroadcastFirst => {}
142                            crate::GatherMode::Broadcast(index)
143                            | crate::GatherMode::Shuffle(index)
144                            | crate::GatherMode::ShuffleDown(index)
145                            | crate::GatherMode::ShuffleUp(index)
146                            | crate::GatherMode::ShuffleXor(index)
147                            | crate::GatherMode::QuadBroadcast(index) => {
148                                self.expressions_used.insert(index);
149                            }
150                            crate::GatherMode::QuadSwap(_) => {}
151                        }
152                        self.expressions_used.insert(argument);
153                        self.expressions_used.insert(result);
154                    }
155                    St::CooperativeStore { target, ref data } => {
156                        self.expressions_used.insert(target);
157                        self.expressions_used.insert(data.pointer);
158                        self.expressions_used.insert(data.stride);
159                    }
160                    St::RayPipelineFunction(func) => match func {
161                        crate::RayPipelineFunction::TraceRay {
162                            acceleration_structure,
163                            descriptor,
164                            payload,
165                        } => {
166                            self.expressions_used.insert(acceleration_structure);
167                            self.expressions_used.insert(descriptor);
168                            self.expressions_used.insert(payload);
169                        }
170                    },
171
172                    // Trivial statements.
173                    St::Break
174                    | St::Continue
175                    | St::Kill
176                    | St::ControlBarrier(_)
177                    | St::MemoryBarrier(_)
178                    | St::Return { value: None } => {}
179                }
180            }
181        }
182    }
183
184    fn trace_atomic_function(&mut self, fun: &crate::AtomicFunction) {
185        use crate::AtomicFunction as Af;
186        match *fun {
187            Af::Exchange {
188                compare: Some(expr),
189            } => {
190                self.expressions_used.insert(expr);
191            }
192            Af::Exchange { compare: None }
193            | Af::Add
194            | Af::Subtract
195            | Af::And
196            | Af::ExclusiveOr
197            | Af::InclusiveOr
198            | Af::Min
199            | Af::Max => {}
200        }
201    }
202
203    fn trace_ray_query_function(&mut self, fun: &crate::RayQueryFunction) {
204        use crate::RayQueryFunction as Qf;
205        match *fun {
206            Qf::Initialize {
207                acceleration_structure,
208                descriptor,
209            } => {
210                self.expressions_used.insert(acceleration_structure);
211                self.expressions_used.insert(descriptor);
212            }
213            Qf::Proceed { result } => {
214                self.expressions_used.insert(result);
215            }
216            Qf::GenerateIntersection { hit_t } => {
217                self.expressions_used.insert(hit_t);
218            }
219            Qf::ConfirmIntersection => {}
220            Qf::Terminate => {}
221        }
222    }
223}
224
225impl FunctionMap {
226    /// Adjust statements in the body of `function`.
227    ///
228    /// Adjusts expressions using `self.expressions`, and adjusts calls to other
229    /// functions using `function_map`.
230    pub fn adjust_body(
231        &self,
232        function: &mut crate::Function,
233        function_map: &HandleMap<crate::Function>,
234    ) {
235        let block = &mut function.body;
236        let mut worklist: Vec<&mut [crate::Statement]> = vec![block];
237        let adjust = |handle: &mut Handle<crate::Expression>| {
238            self.expressions.adjust(handle);
239        };
240        while let Some(last) = worklist.pop() {
241            for stmt in last {
242                use crate::Statement as St;
243                match *stmt {
244                    St::Emit(ref mut range) => {
245                        self.expressions.adjust_range(range, &function.expressions);
246                    }
247                    St::Block(ref mut block) => worklist.push(block),
248                    St::If {
249                        ref mut condition,
250                        ref mut accept,
251                        ref mut reject,
252                    } => {
253                        adjust(condition);
254                        worklist.push(accept);
255                        worklist.push(reject);
256                    }
257                    St::Switch {
258                        ref mut selector,
259                        ref mut cases,
260                    } => {
261                        adjust(selector);
262                        for case in cases {
263                            worklist.push(&mut case.body);
264                        }
265                    }
266                    St::Loop {
267                        ref mut body,
268                        ref mut continuing,
269                        ref mut break_if,
270                    } => {
271                        if let Some(ref mut break_if) = *break_if {
272                            adjust(break_if);
273                        }
274                        worklist.push(body);
275                        worklist.push(continuing);
276                    }
277                    St::Return {
278                        value: Some(ref mut value),
279                    } => adjust(value),
280                    St::Store {
281                        ref mut pointer,
282                        ref mut value,
283                    } => {
284                        adjust(pointer);
285                        adjust(value);
286                    }
287                    St::ImageStore {
288                        ref mut image,
289                        ref mut coordinate,
290                        ref mut array_index,
291                        ref mut value,
292                    } => {
293                        adjust(image);
294                        adjust(coordinate);
295                        if let Some(ref mut array_index) = *array_index {
296                            adjust(array_index);
297                        }
298                        adjust(value);
299                    }
300                    St::Atomic {
301                        ref mut pointer,
302                        ref mut fun,
303                        ref mut value,
304                        ref mut result,
305                    } => {
306                        adjust(pointer);
307                        self.adjust_atomic_function(fun);
308                        adjust(value);
309                        if let Some(ref mut result) = *result {
310                            adjust(result);
311                        }
312                    }
313                    St::ImageAtomic {
314                        ref mut image,
315                        ref mut coordinate,
316                        ref mut array_index,
317                        fun: _,
318                        ref mut value,
319                    } => {
320                        adjust(image);
321                        adjust(coordinate);
322                        if let Some(ref mut array_index) = *array_index {
323                            adjust(array_index);
324                        }
325                        adjust(value);
326                    }
327                    St::WorkGroupUniformLoad {
328                        ref mut pointer,
329                        ref mut result,
330                    } => {
331                        adjust(pointer);
332                        adjust(result);
333                    }
334                    St::Call {
335                        ref mut function,
336                        ref mut arguments,
337                        ref mut result,
338                    } => {
339                        function_map.adjust(function);
340                        for expr in arguments {
341                            adjust(expr);
342                        }
343                        if let Some(ref mut result) = *result {
344                            adjust(result);
345                        }
346                    }
347                    St::RayQuery {
348                        ref mut query,
349                        ref mut fun,
350                    } => {
351                        adjust(query);
352                        self.adjust_ray_query_function(fun);
353                    }
354                    St::SubgroupBallot {
355                        ref mut result,
356                        ref mut predicate,
357                    } => {
358                        if let Some(ref mut predicate) = *predicate {
359                            adjust(predicate);
360                        }
361                        adjust(result);
362                    }
363                    St::SubgroupCollectiveOperation {
364                        op: _,
365                        collective_op: _,
366                        ref mut argument,
367                        ref mut result,
368                    } => {
369                        adjust(argument);
370                        adjust(result);
371                    }
372                    St::SubgroupGather {
373                        ref mut mode,
374                        ref mut argument,
375                        ref mut result,
376                    } => {
377                        match *mode {
378                            crate::GatherMode::BroadcastFirst => {}
379                            crate::GatherMode::Broadcast(ref mut index)
380                            | crate::GatherMode::Shuffle(ref mut index)
381                            | crate::GatherMode::ShuffleDown(ref mut index)
382                            | crate::GatherMode::ShuffleUp(ref mut index)
383                            | crate::GatherMode::ShuffleXor(ref mut index)
384                            | crate::GatherMode::QuadBroadcast(ref mut index) => adjust(index),
385                            crate::GatherMode::QuadSwap(_) => {}
386                        }
387                        adjust(argument);
388                        adjust(result);
389                    }
390                    St::CooperativeStore {
391                        ref mut target,
392                        ref mut data,
393                    } => {
394                        adjust(target);
395                        adjust(&mut data.pointer);
396                        adjust(&mut data.stride);
397                    }
398                    St::RayPipelineFunction(ref mut func) => match *func {
399                        crate::RayPipelineFunction::TraceRay {
400                            ref mut acceleration_structure,
401                            ref mut descriptor,
402                            ref mut payload,
403                        } => {
404                            adjust(acceleration_structure);
405                            adjust(descriptor);
406                            adjust(payload);
407                        }
408                    },
409
410                    // Trivial statements.
411                    St::Break
412                    | St::Continue
413                    | St::Kill
414                    | St::ControlBarrier(_)
415                    | St::MemoryBarrier(_)
416                    | St::Return { value: None } => {}
417                }
418            }
419        }
420    }
421
422    fn adjust_atomic_function(&self, fun: &mut crate::AtomicFunction) {
423        use crate::AtomicFunction as Af;
424        match *fun {
425            Af::Exchange {
426                compare: Some(ref mut expr),
427            } => {
428                self.expressions.adjust(expr);
429            }
430            Af::Exchange { compare: None }
431            | Af::Add
432            | Af::Subtract
433            | Af::And
434            | Af::ExclusiveOr
435            | Af::InclusiveOr
436            | Af::Min
437            | Af::Max => {}
438        }
439    }
440
441    fn adjust_ray_query_function(&self, fun: &mut crate::RayQueryFunction) {
442        use crate::RayQueryFunction as Qf;
443        match *fun {
444            Qf::Initialize {
445                ref mut acceleration_structure,
446                ref mut descriptor,
447            } => {
448                self.expressions.adjust(acceleration_structure);
449                self.expressions.adjust(descriptor);
450            }
451            Qf::Proceed { ref mut result } => {
452                self.expressions.adjust(result);
453            }
454            Qf::GenerateIntersection { ref mut hit_t } => {
455                self.expressions.adjust(hit_t);
456            }
457            Qf::ConfirmIntersection => {}
458            Qf::Terminate => {}
459        }
460    }
461}