naga/front/spv/function.rs

1use alloc::{format, vec, vec::Vec};
2
3use super::{Error, Instruction, LookupExpression, LookupHelper as _};
4use crate::proc::Emitter;
5use crate::{
6    arena::{Arena, Handle},
7    front::spv::{BlockContext, BodyIndex},
8};
9
/// Identifier of a basic block in the SPIR-V function being parsed.
///
/// NOTE(review): presumably this is the result id of the block's `OpLabel`
/// (a SPIR-V word) — confirm against its users elsewhere in the frontend.
pub type BlockId = u32;
11
12impl<I: Iterator<Item = u32>> super::Frontend<I> {
13    // Registers a function call. It will generate a dummy handle to call, which
14    // gets resolved after all the functions are processed.
15    pub(super) fn add_call(
16        &mut self,
17        from: spirv::Word,
18        to: spirv::Word,
19    ) -> Handle<crate::Function> {
20        let dummy_handle = self
21            .dummy_functions
22            .append(crate::Function::default(), Default::default());
23        self.deferred_function_calls.push(to);
24        self.function_call_graph.add_edge(from, to, ());
25        dummy_handle
26    }
27
28    pub(super) fn parse_function(&mut self, module: &mut crate::Module) -> Result<(), Error> {
29        let start = self.data_offset;
30        self.lookup_expression.clear();
31        self.lookup_load_override.clear();
32        self.lookup_sampled_image.clear();
33
34        let result_type_id = self.next()?;
35        let fun_id = self.next()?;
36        let _fun_control = self.next()?;
37        let fun_type_id = self.next()?;
38
39        let mut fun = {
40            let ft = self.lookup_function_type.lookup(fun_type_id)?;
41            if ft.return_type_id != result_type_id {
42                return Err(Error::WrongFunctionResultType(result_type_id));
43            }
44            crate::Function {
45                name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name),
46                arguments: Vec::with_capacity(ft.parameter_type_ids.len()),
47                result: if self.lookup_void_type == Some(result_type_id) {
48                    None
49                } else {
50                    let lookup_result_ty = self.lookup_type.lookup(result_type_id)?;
51                    Some(crate::FunctionResult {
52                        ty: lookup_result_ty.handle,
53                        binding: None,
54                    })
55                },
56                local_variables: Arena::new(),
57                expressions: self.make_expression_storage(
58                    &module.global_variables,
59                    &module.constants,
60                    &module.overrides,
61                ),
62                named_expressions: crate::NamedExpressions::default(),
63                body: crate::Block::new(),
64                diagnostic_filter_leaf: None,
65            }
66        };
67
68        // read parameters
69        for i in 0..fun.arguments.capacity() {
70            let start = self.data_offset;
71            match self.next_inst()? {
72                Instruction {
73                    op: spirv::Op::FunctionParameter,
74                    wc: 3,
75                } => {
76                    let type_id = self.next()?;
77                    let id = self.next()?;
78                    let handle = fun.expressions.append(
79                        crate::Expression::FunctionArgument(i as u32),
80                        self.span_from(start),
81                    );
82                    self.lookup_expression.insert(
83                        id,
84                        LookupExpression {
85                            handle,
86                            type_id,
87                            // Setting this to an invalid id will cause get_expr_handle
88                            // to default to the main body making sure no load/stores
89                            // are added.
90                            block_id: 0,
91                        },
92                    );
93                    //Note: we redo the lookup in order to work around `self` borrowing
94
95                    if type_id
96                        != self
97                            .lookup_function_type
98                            .lookup(fun_type_id)?
99                            .parameter_type_ids[i]
100                    {
101                        return Err(Error::WrongFunctionArgumentType(type_id));
102                    }
103                    let ty = self.lookup_type.lookup(type_id)?.handle;
104                    let decor = self.future_decor.remove(&id).unwrap_or_default();
105                    fun.arguments.push(crate::FunctionArgument {
106                        name: decor.name,
107                        ty,
108                        binding: None,
109                    });
110                }
111                Instruction { op, .. } => return Err(Error::InvalidParameter(op)),
112            }
113        }
114
115        // Note the index this function's handle will be assigned, for tracing.
116        let function_index = module.functions.len();
117
118        // Read body
119        self.function_call_graph.add_node(fun_id);
120        let mut parameters_sampling =
121            vec![super::image::SamplingFlags::empty(); fun.arguments.len()];
122
123        let mut block_ctx = BlockContext {
124            phis: Default::default(),
125            blocks: Default::default(),
126            body_for_label: Default::default(),
127            mergers: Default::default(),
128            bodies: Default::default(),
129            module,
130            function_id: fun_id,
131            expressions: &mut fun.expressions,
132            local_arena: &mut fun.local_variables,
133            arguments: &fun.arguments,
134            parameter_sampling: &mut parameters_sampling,
135        };
136        // Insert the main body whose parent is also himself
137        block_ctx.bodies.push(super::Body::with_parent(0));
138
139        // Scan the blocks and add them as nodes
140        loop {
141            let fun_inst = self.next_inst()?;
142            log::debug!("{:?}", fun_inst.op);
143            match fun_inst.op {
144                spirv::Op::Line => {
145                    fun_inst.expect(4)?;
146                    let _file_id = self.next()?;
147                    let _row_id = self.next()?;
148                    let _col_id = self.next()?;
149                }
150                spirv::Op::Label => {
151                    // Read the label ID
152                    fun_inst.expect(2)?;
153                    let block_id = self.next()?;
154
155                    self.next_block(block_id, &mut block_ctx)?;
156                }
157                spirv::Op::FunctionEnd => {
158                    fun_inst.expect(1)?;
159                    break;
160                }
161                _ => {
162                    return Err(Error::UnsupportedInstruction(self.state, fun_inst.op));
163                }
164            }
165        }
166
167        if let Some(ref prefix) = self.options.block_ctx_dump_prefix {
168            let dump_suffix = match self.lookup_entry_point.get(&fun_id) {
169                Some(ep) => format!("block_ctx.{:?}-{}.txt", ep.stage, ep.name),
170                None => format!("block_ctx.Fun-{function_index}.txt"),
171            };
172
173            cfg_if::cfg_if! {
174                if #[cfg(feature = "fs")] {
175                    let prefix: &std::path::Path = prefix.as_ref();
176                    let dest = prefix.join(dump_suffix);
177                    let dump = format!("{block_ctx:#?}");
178                    if let Err(e) = std::fs::write(&dest, dump) {
179                        log::error!("Unable to dump the block context into {dest:?}: {e}");
180                    }
181                } else {
182                    log::error!("Unable to dump the block context into {prefix:?}/{dump_suffix}: file system integration was not enabled with the `fs` feature");
183                }
184            }
185        }
186
187        // Emit `Store` statements to properly initialize all the local variables we
188        // created for `phi` expressions.
189        //
190        // Note that get_expr_handle also contributes slightly odd entries to this table,
191        // to get the spill.
192        for phi in block_ctx.phis.iter() {
193            // Get a pointer to the local variable for the phi's value.
194            let phi_pointer: Handle<crate::Expression> = block_ctx.expressions.append(
195                crate::Expression::LocalVariable(phi.local),
196                crate::Span::default(),
197            );
198
199            // At the end of each of `phi`'s predecessor blocks, store the corresponding
200            // source value in the phi's local variable.
201            for &(source, predecessor) in phi.expressions.iter() {
202                let source_lexp = &self.lookup_expression[&source];
203                let predecessor_body_idx = block_ctx.body_for_label[&predecessor];
204                // If the expression is a global/argument it will have a 0 block
205                // id so we must use a default value instead of panicking
206                let source_body_idx = block_ctx
207                    .body_for_label
208                    .get(&source_lexp.block_id)
209                    .copied()
210                    .unwrap_or(0);
211
212                // If the Naga `Expression` generated for `source` is in scope, then we
213                // can simply store that in the phi's local variable.
214                //
215                // Otherwise, spill the source value to a local variable in the block that
216                // defines it. (We know this store dominates the predecessor; otherwise,
217                // the phi wouldn't have been able to refer to that source expression in
218                // the first place.) Then, the predecessor block can count on finding the
219                // source's value in that local variable.
220                let value = if super::is_parent(predecessor_body_idx, source_body_idx, &block_ctx) {
221                    source_lexp.handle
222                } else {
223                    // The source SPIR-V expression is not defined in the phi's
224                    // predecessor block, nor is it a globally available expression. So it
225                    // must be defined off in some other block that merely dominates the
226                    // predecessor. This means that the corresponding Naga `Expression`
227                    // may not be in scope in the predecessor block.
228                    //
229                    // In the block that defines `source`, spill it to a fresh local
230                    // variable, to ensure we can still use it at the end of the
231                    // predecessor.
232                    let ty = self.lookup_type[&source_lexp.type_id].handle;
233                    let local = block_ctx.local_arena.append(
234                        crate::LocalVariable {
235                            name: None,
236                            ty,
237                            init: None,
238                        },
239                        crate::Span::default(),
240                    );
241
242                    let pointer = block_ctx.expressions.append(
243                        crate::Expression::LocalVariable(local),
244                        crate::Span::default(),
245                    );
246
247                    // Get the spilled value of the source expression.
248                    let start = block_ctx.expressions.len();
249                    let expr = block_ctx
250                        .expressions
251                        .append(crate::Expression::Load { pointer }, crate::Span::default());
252                    let range = block_ctx.expressions.range_from(start);
253
254                    block_ctx
255                        .blocks
256                        .get_mut(&predecessor)
257                        .unwrap()
258                        .push(crate::Statement::Emit(range), crate::Span::default());
259
260                    // At the end of the block that defines it, spill the source
261                    // expression's value.
262                    block_ctx
263                        .blocks
264                        .get_mut(&source_lexp.block_id)
265                        .unwrap()
266                        .push(
267                            crate::Statement::Store {
268                                pointer,
269                                value: source_lexp.handle,
270                            },
271                            crate::Span::default(),
272                        );
273
274                    expr
275                };
276
277                // At the end of the phi predecessor block, store the source
278                // value in the phi's value.
279                block_ctx.blocks.get_mut(&predecessor).unwrap().push(
280                    crate::Statement::Store {
281                        pointer: phi_pointer,
282                        value,
283                    },
284                    crate::Span::default(),
285                )
286            }
287        }
288
289        fun.body = block_ctx.lower();
290
291        // done
292        let fun_handle = module.functions.append(fun, self.span_from_with_op(start));
293        self.lookup_function.insert(
294            fun_id,
295            super::LookupFunction {
296                handle: fun_handle,
297                parameters_sampling,
298            },
299        );
300
301        if let Some(ep) = self.lookup_entry_point.remove(&fun_id) {
302            self.deferred_entry_points.push((ep, fun_id));
303        }
304
305        Ok(())
306    }
307
308    pub(super) fn process_entry_point(
309        &mut self,
310        module: &mut crate::Module,
311        ep: super::EntryPoint,
312        fun_id: u32,
313    ) -> Result<(), Error> {
314        // create a wrapping function
315        let mut function = crate::Function {
316            name: Some(format!("{}_wrap", ep.name)),
317            arguments: Vec::new(),
318            result: None,
319            local_variables: Arena::new(),
320            expressions: Arena::new(),
321            named_expressions: crate::NamedExpressions::default(),
322            body: crate::Block::new(),
323            diagnostic_filter_leaf: None,
324        };
325
326        // 1. copy the inputs from arguments to privates
327        for &v_id in ep.variable_ids.iter() {
328            let lvar = self.lookup_variable.lookup(v_id)?;
329            if let super::Variable::Input(ref arg) = lvar.inner {
330                let span = module.global_variables.get_span(lvar.handle);
331                let arg_expr = function.expressions.append(
332                    crate::Expression::FunctionArgument(function.arguments.len() as u32),
333                    span,
334                );
335                let load_expr = if arg.ty == module.global_variables[lvar.handle].ty {
336                    arg_expr
337                } else {
338                    // The only case where the type is different is if we need to treat
339                    // unsigned integer as signed.
340                    let mut emitter = Emitter::default();
341                    emitter.start(&function.expressions);
342                    let handle = function.expressions.append(
343                        crate::Expression::As {
344                            expr: arg_expr,
345                            kind: crate::ScalarKind::Sint,
346                            convert: Some(4),
347                        },
348                        span,
349                    );
350                    function.body.extend(emitter.finish(&function.expressions));
351                    handle
352                };
353                function.body.push(
354                    crate::Statement::Store {
355                        pointer: function
356                            .expressions
357                            .append(crate::Expression::GlobalVariable(lvar.handle), span),
358                        value: load_expr,
359                    },
360                    span,
361                );
362
363                let mut arg = arg.clone();
364                if ep.stage == crate::ShaderStage::Fragment {
365                    if let Some(ref mut binding) = arg.binding {
366                        binding.apply_default_interpolation(&module.types[arg.ty].inner);
367                    }
368                }
369                function.arguments.push(arg);
370            }
371        }
372        // 2. call the wrapped function
373        let fake_id = !(module.entry_points.len() as u32); // doesn't matter, as long as it's not a collision
374        let dummy_handle = self.add_call(fake_id, fun_id);
375        function.body.push(
376            crate::Statement::Call {
377                function: dummy_handle,
378                arguments: Vec::new(),
379                result: None,
380            },
381            crate::Span::default(),
382        );
383
384        // 3. copy the outputs from privates to the result
385        //
386        // It would be nice to share struct layout code here with `parse_type_struct`,
387        // but that case needs to take into account offset decorations, which makes an
388        // abstraction harder to follow than just writing out what we mean. `Layouter`
389        // and `Alignment` cover the worst parts already.
390        let mut members = Vec::new();
391        self.layouter.update(module.to_ctx()).unwrap();
392        let mut next_member_offset = 0;
393        let mut struct_alignment = crate::proc::Alignment::ONE;
394        let mut components = Vec::new();
395        for &v_id in ep.variable_ids.iter() {
396            let lvar = self.lookup_variable.lookup(v_id)?;
397            if let super::Variable::Output(ref result) = lvar.inner {
398                let span = module.global_variables.get_span(lvar.handle);
399                let expr_handle = function
400                    .expressions
401                    .append(crate::Expression::GlobalVariable(lvar.handle), span);
402
403                // Cull problematic builtins of gl_PerVertex.
404                // See the docs for `Frontend::gl_per_vertex_builtin_access`.
405                {
406                    let ty = &module.types[result.ty];
407                    if let crate::TypeInner::Struct {
408                        members: ref original_members,
409                        span,
410                    } = ty.inner
411                    {
412                        let mut new_members = None;
413                        for (idx, member) in original_members.iter().enumerate() {
414                            if let Some(crate::Binding::BuiltIn(built_in)) = member.binding {
415                                if !self.gl_per_vertex_builtin_access.contains(&built_in) {
416                                    new_members.get_or_insert_with(|| original_members.clone())
417                                        [idx]
418                                        .binding = None;
419                                }
420                            }
421                        }
422                        if let Some(new_members) = new_members {
423                            module.types.replace(
424                                result.ty,
425                                crate::Type {
426                                    name: ty.name.clone(),
427                                    inner: crate::TypeInner::Struct {
428                                        members: new_members,
429                                        span,
430                                    },
431                                },
432                            );
433                        }
434                    }
435                }
436
437                match module.types[result.ty].inner {
438                    crate::TypeInner::Struct {
439                        members: ref sub_members,
440                        ..
441                    } => {
442                        for (index, sm) in sub_members.iter().enumerate() {
443                            if sm.binding.is_none() {
444                                continue;
445                            }
446                            let mut sm = sm.clone();
447
448                            if let Some(ref mut binding) = sm.binding {
449                                if ep.stage == crate::ShaderStage::Vertex {
450                                    binding.apply_default_interpolation(&module.types[sm.ty].inner);
451                                }
452                            }
453
454                            let member_alignment = self.layouter[sm.ty].alignment;
455                            next_member_offset = member_alignment.round_up(next_member_offset);
456                            sm.offset = next_member_offset;
457                            struct_alignment = struct_alignment.max(member_alignment);
458                            next_member_offset += self.layouter[sm.ty].size;
459                            members.push(sm);
460
461                            components.push(function.expressions.append(
462                                crate::Expression::AccessIndex {
463                                    base: expr_handle,
464                                    index: index as u32,
465                                },
466                                span,
467                            ));
468                        }
469                    }
470                    ref inner => {
471                        let mut binding = result.binding.clone();
472                        if let Some(ref mut binding) = binding {
473                            if ep.stage == crate::ShaderStage::Vertex {
474                                binding.apply_default_interpolation(inner);
475                            }
476                        }
477
478                        let member_alignment = self.layouter[result.ty].alignment;
479                        next_member_offset = member_alignment.round_up(next_member_offset);
480                        members.push(crate::StructMember {
481                            name: None,
482                            ty: result.ty,
483                            binding,
484                            offset: next_member_offset,
485                        });
486                        struct_alignment = struct_alignment.max(member_alignment);
487                        next_member_offset += self.layouter[result.ty].size;
488                        // populate just the globals first, then do `Load` in a
489                        // separate step, so that we can get a range.
490                        components.push(expr_handle);
491                    }
492                }
493            }
494        }
495
496        for (member_index, member) in members.iter().enumerate() {
497            match member.binding {
498                Some(crate::Binding::BuiltIn(crate::BuiltIn::Position { .. }))
499                    if self.options.adjust_coordinate_space =>
500                {
501                    let mut emitter = Emitter::default();
502                    emitter.start(&function.expressions);
503                    let global_expr = components[member_index];
504                    let span = function.expressions.get_span(global_expr);
505                    let access_expr = function.expressions.append(
506                        crate::Expression::AccessIndex {
507                            base: global_expr,
508                            index: 1,
509                        },
510                        span,
511                    );
512                    let load_expr = function.expressions.append(
513                        crate::Expression::Load {
514                            pointer: access_expr,
515                        },
516                        span,
517                    );
518                    let neg_expr = function.expressions.append(
519                        crate::Expression::Unary {
520                            op: crate::UnaryOperator::Negate,
521                            expr: load_expr,
522                        },
523                        span,
524                    );
525                    function.body.extend(emitter.finish(&function.expressions));
526                    function.body.push(
527                        crate::Statement::Store {
528                            pointer: access_expr,
529                            value: neg_expr,
530                        },
531                        span,
532                    );
533                }
534                _ => {}
535            }
536        }
537
538        let mut emitter = Emitter::default();
539        emitter.start(&function.expressions);
540        for component in components.iter_mut() {
541            let load_expr = crate::Expression::Load {
542                pointer: *component,
543            };
544            let span = function.expressions.get_span(*component);
545            *component = function.expressions.append(load_expr, span);
546        }
547
548        match members[..] {
549            [] => {}
550            [ref member] => {
551                function.body.extend(emitter.finish(&function.expressions));
552                let span = function.expressions.get_span(components[0]);
553                function.body.push(
554                    crate::Statement::Return {
555                        value: components.first().cloned(),
556                    },
557                    span,
558                );
559                function.result = Some(crate::FunctionResult {
560                    ty: member.ty,
561                    binding: member.binding.clone(),
562                });
563            }
564            _ => {
565                let span = crate::Span::total_span(
566                    components.iter().map(|h| function.expressions.get_span(*h)),
567                );
568                let ty = module.types.insert(
569                    crate::Type {
570                        name: None,
571                        inner: crate::TypeInner::Struct {
572                            members,
573                            span: struct_alignment.round_up(next_member_offset),
574                        },
575                    },
576                    span,
577                );
578                let result_expr = function
579                    .expressions
580                    .append(crate::Expression::Compose { ty, components }, span);
581                function.body.extend(emitter.finish(&function.expressions));
582                function.body.push(
583                    crate::Statement::Return {
584                        value: Some(result_expr),
585                    },
586                    span,
587                );
588                function.result = Some(crate::FunctionResult { ty, binding: None });
589            }
590        }
591
592        module.entry_points.push(crate::EntryPoint {
593            name: ep.name,
594            stage: ep.stage,
595            early_depth_test: ep.early_depth_test,
596            workgroup_size: ep.workgroup_size,
597            workgroup_size_overrides: None,
598            function,
599        });
600
601        Ok(())
602    }
603}
604
605impl BlockContext<'_> {
606    pub(super) fn gctx(&self) -> crate::proc::GlobalCtx<'_> {
607        crate::proc::GlobalCtx {
608            types: &self.module.types,
609            constants: &self.module.constants,
610            overrides: &self.module.overrides,
611            global_expressions: &self.module.global_expressions,
612        }
613    }
614
615    /// Consumes the `BlockContext` producing a Ir [`Block`](crate::Block)
616    fn lower(mut self) -> crate::Block {
617        fn lower_impl(
618            blocks: &mut crate::FastHashMap<spirv::Word, crate::Block>,
619            bodies: &[super::Body],
620            body_idx: BodyIndex,
621        ) -> crate::Block {
622            let mut block = crate::Block::new();
623
624            for item in bodies[body_idx].data.iter() {
625                match *item {
626                    super::BodyFragment::BlockId(id) => block.append(blocks.get_mut(&id).unwrap()),
627                    super::BodyFragment::If {
628                        condition,
629                        accept,
630                        reject,
631                    } => {
632                        let accept = lower_impl(blocks, bodies, accept);
633                        let reject = lower_impl(blocks, bodies, reject);
634
635                        block.push(
636                            crate::Statement::If {
637                                condition,
638                                accept,
639                                reject,
640                            },
641                            crate::Span::default(),
642                        )
643                    }
644                    super::BodyFragment::Loop {
645                        body,
646                        continuing,
647                        break_if,
648                    } => {
649                        let body = lower_impl(blocks, bodies, body);
650                        let continuing = lower_impl(blocks, bodies, continuing);
651
652                        block.push(
653                            crate::Statement::Loop {
654                                body,
655                                continuing,
656                                break_if,
657                            },
658                            crate::Span::default(),
659                        )
660                    }
661                    super::BodyFragment::Switch {
662                        selector,
663                        ref cases,
664                        default,
665                    } => {
666                        let mut ir_cases: Vec<_> = cases
667                            .iter()
668                            .map(|&(value, body_idx)| {
669                                let body = lower_impl(blocks, bodies, body_idx);
670
671                                // Handle simple cases that would make a fallthrough statement unreachable code
672                                let fall_through = body.last().is_none_or(|s| !s.is_terminator());
673
674                                crate::SwitchCase {
675                                    value: crate::SwitchValue::I32(value),
676                                    body,
677                                    fall_through,
678                                }
679                            })
680                            .collect();
681                        ir_cases.push(crate::SwitchCase {
682                            value: crate::SwitchValue::Default,
683                            body: lower_impl(blocks, bodies, default),
684                            fall_through: false,
685                        });
686
687                        block.push(
688                            crate::Statement::Switch {
689                                selector,
690                                cases: ir_cases,
691                            },
692                            crate::Span::default(),
693                        )
694                    }
695                    super::BodyFragment::Break => {
696                        block.push(crate::Statement::Break, crate::Span::default())
697                    }
698                    super::BodyFragment::Continue => {
699                        block.push(crate::Statement::Continue, crate::Span::default())
700                    }
701                }
702            }
703
704            block
705        }
706
707        lower_impl(&mut self.blocks, &self.bodies, 0)
708    }
709}