1use alloc::{vec, vec::Vec};
2
3use super::functions::FunctionTracer;
4use super::FunctionMap;
5use crate::arena::Handle;
6use crate::compact::handle_set_map::HandleMap;
7
8impl FunctionTracer<'_> {
9 pub fn trace_block(&mut self, block: &[crate::Statement]) {
10 let mut worklist: Vec<&[crate::Statement]> = vec![block];
11 while let Some(last) = worklist.pop() {
12 for stmt in last {
13 use crate::Statement as St;
14 match *stmt {
15 St::Emit(ref _range) => {
16 }
22 St::Block(ref block) => worklist.push(block),
23 St::If {
24 condition,
25 ref accept,
26 ref reject,
27 } => {
28 self.expressions_used.insert(condition);
29 worklist.push(accept);
30 worklist.push(reject);
31 }
32 St::Switch {
33 selector,
34 ref cases,
35 } => {
36 self.expressions_used.insert(selector);
37 for case in cases {
38 worklist.push(&case.body);
39 }
40 }
41 St::Loop {
42 ref body,
43 ref continuing,
44 break_if,
45 } => {
46 if let Some(break_if) = break_if {
47 self.expressions_used.insert(break_if);
48 }
49 worklist.push(body);
50 worklist.push(continuing);
51 }
52 St::Return { value: Some(value) } => {
53 self.expressions_used.insert(value);
54 }
55 St::Store { pointer, value } => {
56 self.expressions_used.insert(pointer);
57 self.expressions_used.insert(value);
58 }
59 St::ImageStore {
60 image,
61 coordinate,
62 array_index,
63 value,
64 } => {
65 self.expressions_used.insert(image);
66 self.expressions_used.insert(coordinate);
67 if let Some(array_index) = array_index {
68 self.expressions_used.insert(array_index);
69 }
70 self.expressions_used.insert(value);
71 }
72 St::Atomic {
73 pointer,
74 ref fun,
75 value,
76 result,
77 } => {
78 self.expressions_used.insert(pointer);
79 self.trace_atomic_function(fun);
80 self.expressions_used.insert(value);
81 if let Some(result) = result {
82 self.expressions_used.insert(result);
83 }
84 }
85 St::ImageAtomic {
86 image,
87 coordinate,
88 array_index,
89 fun: _,
90 value,
91 } => {
92 self.expressions_used.insert(image);
93 self.expressions_used.insert(coordinate);
94 if let Some(array_index) = array_index {
95 self.expressions_used.insert(array_index);
96 }
97 self.expressions_used.insert(value);
98 }
99 St::WorkGroupUniformLoad { pointer, result } => {
100 self.expressions_used.insert(pointer);
101 self.expressions_used.insert(result);
102 }
103 St::Call {
104 function,
105 ref arguments,
106 result,
107 } => {
108 self.trace_call(function);
109 for expr in arguments {
110 self.expressions_used.insert(*expr);
111 }
112 if let Some(result) = result {
113 self.expressions_used.insert(result);
114 }
115 }
116 St::RayQuery { query, ref fun } => {
117 self.expressions_used.insert(query);
118 self.trace_ray_query_function(fun);
119 }
120 St::SubgroupBallot { result, predicate } => {
121 if let Some(predicate) = predicate {
122 self.expressions_used.insert(predicate);
123 }
124 self.expressions_used.insert(result);
125 }
126 St::SubgroupCollectiveOperation {
127 op: _,
128 collective_op: _,
129 argument,
130 result,
131 } => {
132 self.expressions_used.insert(argument);
133 self.expressions_used.insert(result);
134 }
135 St::SubgroupGather {
136 mode,
137 argument,
138 result,
139 } => {
140 match mode {
141 crate::GatherMode::BroadcastFirst => {}
142 crate::GatherMode::Broadcast(index)
143 | crate::GatherMode::Shuffle(index)
144 | crate::GatherMode::ShuffleDown(index)
145 | crate::GatherMode::ShuffleUp(index)
146 | crate::GatherMode::ShuffleXor(index)
147 | crate::GatherMode::QuadBroadcast(index) => {
148 self.expressions_used.insert(index);
149 }
150 crate::GatherMode::QuadSwap(_) => {}
151 }
152 self.expressions_used.insert(argument);
153 self.expressions_used.insert(result);
154 }
155 St::CooperativeStore { target, ref data } => {
156 self.expressions_used.insert(target);
157 self.expressions_used.insert(data.pointer);
158 self.expressions_used.insert(data.stride);
159 }
160 St::RayPipelineFunction(func) => match func {
161 crate::RayPipelineFunction::TraceRay {
162 acceleration_structure,
163 descriptor,
164 payload,
165 } => {
166 self.expressions_used.insert(acceleration_structure);
167 self.expressions_used.insert(descriptor);
168 self.expressions_used.insert(payload);
169 }
170 },
171
172 St::Break
174 | St::Continue
175 | St::Kill
176 | St::ControlBarrier(_)
177 | St::MemoryBarrier(_)
178 | St::Return { value: None } => {}
179 }
180 }
181 }
182 }
183
184 fn trace_atomic_function(&mut self, fun: &crate::AtomicFunction) {
185 use crate::AtomicFunction as Af;
186 match *fun {
187 Af::Exchange {
188 compare: Some(expr),
189 } => {
190 self.expressions_used.insert(expr);
191 }
192 Af::Exchange { compare: None }
193 | Af::Add
194 | Af::Subtract
195 | Af::And
196 | Af::ExclusiveOr
197 | Af::InclusiveOr
198 | Af::Min
199 | Af::Max => {}
200 }
201 }
202
203 fn trace_ray_query_function(&mut self, fun: &crate::RayQueryFunction) {
204 use crate::RayQueryFunction as Qf;
205 match *fun {
206 Qf::Initialize {
207 acceleration_structure,
208 descriptor,
209 } => {
210 self.expressions_used.insert(acceleration_structure);
211 self.expressions_used.insert(descriptor);
212 }
213 Qf::Proceed { result } => {
214 self.expressions_used.insert(result);
215 }
216 Qf::GenerateIntersection { hit_t } => {
217 self.expressions_used.insert(hit_t);
218 }
219 Qf::ConfirmIntersection => {}
220 Qf::Terminate => {}
221 }
222 }
223}
224
225impl FunctionMap {
226 pub fn adjust_body(
231 &self,
232 function: &mut crate::Function,
233 function_map: &HandleMap<crate::Function>,
234 ) {
235 let block = &mut function.body;
236 let mut worklist: Vec<&mut [crate::Statement]> = vec![block];
237 let adjust = |handle: &mut Handle<crate::Expression>| {
238 self.expressions.adjust(handle);
239 };
240 while let Some(last) = worklist.pop() {
241 for stmt in last {
242 use crate::Statement as St;
243 match *stmt {
244 St::Emit(ref mut range) => {
245 self.expressions.adjust_range(range, &function.expressions);
246 }
247 St::Block(ref mut block) => worklist.push(block),
248 St::If {
249 ref mut condition,
250 ref mut accept,
251 ref mut reject,
252 } => {
253 adjust(condition);
254 worklist.push(accept);
255 worklist.push(reject);
256 }
257 St::Switch {
258 ref mut selector,
259 ref mut cases,
260 } => {
261 adjust(selector);
262 for case in cases {
263 worklist.push(&mut case.body);
264 }
265 }
266 St::Loop {
267 ref mut body,
268 ref mut continuing,
269 ref mut break_if,
270 } => {
271 if let Some(ref mut break_if) = *break_if {
272 adjust(break_if);
273 }
274 worklist.push(body);
275 worklist.push(continuing);
276 }
277 St::Return {
278 value: Some(ref mut value),
279 } => adjust(value),
280 St::Store {
281 ref mut pointer,
282 ref mut value,
283 } => {
284 adjust(pointer);
285 adjust(value);
286 }
287 St::ImageStore {
288 ref mut image,
289 ref mut coordinate,
290 ref mut array_index,
291 ref mut value,
292 } => {
293 adjust(image);
294 adjust(coordinate);
295 if let Some(ref mut array_index) = *array_index {
296 adjust(array_index);
297 }
298 adjust(value);
299 }
300 St::Atomic {
301 ref mut pointer,
302 ref mut fun,
303 ref mut value,
304 ref mut result,
305 } => {
306 adjust(pointer);
307 self.adjust_atomic_function(fun);
308 adjust(value);
309 if let Some(ref mut result) = *result {
310 adjust(result);
311 }
312 }
313 St::ImageAtomic {
314 ref mut image,
315 ref mut coordinate,
316 ref mut array_index,
317 fun: _,
318 ref mut value,
319 } => {
320 adjust(image);
321 adjust(coordinate);
322 if let Some(ref mut array_index) = *array_index {
323 adjust(array_index);
324 }
325 adjust(value);
326 }
327 St::WorkGroupUniformLoad {
328 ref mut pointer,
329 ref mut result,
330 } => {
331 adjust(pointer);
332 adjust(result);
333 }
334 St::Call {
335 ref mut function,
336 ref mut arguments,
337 ref mut result,
338 } => {
339 function_map.adjust(function);
340 for expr in arguments {
341 adjust(expr);
342 }
343 if let Some(ref mut result) = *result {
344 adjust(result);
345 }
346 }
347 St::RayQuery {
348 ref mut query,
349 ref mut fun,
350 } => {
351 adjust(query);
352 self.adjust_ray_query_function(fun);
353 }
354 St::SubgroupBallot {
355 ref mut result,
356 ref mut predicate,
357 } => {
358 if let Some(ref mut predicate) = *predicate {
359 adjust(predicate);
360 }
361 adjust(result);
362 }
363 St::SubgroupCollectiveOperation {
364 op: _,
365 collective_op: _,
366 ref mut argument,
367 ref mut result,
368 } => {
369 adjust(argument);
370 adjust(result);
371 }
372 St::SubgroupGather {
373 ref mut mode,
374 ref mut argument,
375 ref mut result,
376 } => {
377 match *mode {
378 crate::GatherMode::BroadcastFirst => {}
379 crate::GatherMode::Broadcast(ref mut index)
380 | crate::GatherMode::Shuffle(ref mut index)
381 | crate::GatherMode::ShuffleDown(ref mut index)
382 | crate::GatherMode::ShuffleUp(ref mut index)
383 | crate::GatherMode::ShuffleXor(ref mut index)
384 | crate::GatherMode::QuadBroadcast(ref mut index) => adjust(index),
385 crate::GatherMode::QuadSwap(_) => {}
386 }
387 adjust(argument);
388 adjust(result);
389 }
390 St::CooperativeStore {
391 ref mut target,
392 ref mut data,
393 } => {
394 adjust(target);
395 adjust(&mut data.pointer);
396 adjust(&mut data.stride);
397 }
398 St::RayPipelineFunction(ref mut func) => match *func {
399 crate::RayPipelineFunction::TraceRay {
400 ref mut acceleration_structure,
401 ref mut descriptor,
402 ref mut payload,
403 } => {
404 adjust(acceleration_structure);
405 adjust(descriptor);
406 adjust(payload);
407 }
408 },
409
410 St::Break
412 | St::Continue
413 | St::Kill
414 | St::ControlBarrier(_)
415 | St::MemoryBarrier(_)
416 | St::Return { value: None } => {}
417 }
418 }
419 }
420 }
421
422 fn adjust_atomic_function(&self, fun: &mut crate::AtomicFunction) {
423 use crate::AtomicFunction as Af;
424 match *fun {
425 Af::Exchange {
426 compare: Some(ref mut expr),
427 } => {
428 self.expressions.adjust(expr);
429 }
430 Af::Exchange { compare: None }
431 | Af::Add
432 | Af::Subtract
433 | Af::And
434 | Af::ExclusiveOr
435 | Af::InclusiveOr
436 | Af::Min
437 | Af::Max => {}
438 }
439 }
440
441 fn adjust_ray_query_function(&self, fun: &mut crate::RayQueryFunction) {
442 use crate::RayQueryFunction as Qf;
443 match *fun {
444 Qf::Initialize {
445 ref mut acceleration_structure,
446 ref mut descriptor,
447 } => {
448 self.expressions.adjust(acceleration_structure);
449 self.expressions.adjust(descriptor);
450 }
451 Qf::Proceed { ref mut result } => {
452 self.expressions.adjust(result);
453 }
454 Qf::GenerateIntersection { ref mut hit_t } => {
455 self.expressions.adjust(hit_t);
456 }
457 Qf::ConfirmIntersection => {}
458 Qf::Terminate => {}
459 }
460 }
461}