naga/front/spv/function.rs

1use alloc::{format, vec, vec::Vec};
2
3use super::{Error, Instruction, LookupExpression, LookupHelper as _};
4use crate::proc::Emitter;
5use crate::{
6    arena::{Arena, Handle},
7    front::spv::{BlockContext, BodyIndex},
8};
9
/// Identifier of a SPIR-V block.
///
/// In this file, block ids are read as the operand of `OpLabel` and are plain
/// `u32` words. Presumably this alias names exactly those label ids; TODO
/// confirm at use sites.
pub type BlockId = u32;
11
12impl<I: Iterator<Item = u32>> super::Frontend<I> {
13    // Registers a function call. It will generate a dummy handle to call, which
14    // gets resolved after all the functions are processed.
15    pub(super) fn add_call(
16        &mut self,
17        from: spirv::Word,
18        to: spirv::Word,
19    ) -> Handle<crate::Function> {
20        let dummy_handle = self
21            .dummy_functions
22            .append(crate::Function::default(), Default::default());
23        self.deferred_function_calls.push(to);
24        self.function_call_graph.add_edge(from, to, ());
25        dummy_handle
26    }
27
28    pub(super) fn parse_function(&mut self, module: &mut crate::Module) -> Result<(), Error> {
29        let start = self.data_offset;
30        self.lookup_expression.clear();
31        self.lookup_load_override.clear();
32        self.lookup_sampled_image.clear();
33
34        let result_type_id = self.next()?;
35        let fun_id = self.next()?;
36        let _fun_control = self.next()?;
37        let fun_type_id = self.next()?;
38
39        let mut fun = {
40            let ft = self.lookup_function_type.lookup(fun_type_id)?;
41            if ft.return_type_id != result_type_id {
42                return Err(Error::WrongFunctionResultType(result_type_id));
43            }
44            crate::Function {
45                name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name),
46                arguments: Vec::with_capacity(ft.parameter_type_ids.len()),
47                result: if self.lookup_void_type == Some(result_type_id) {
48                    None
49                } else {
50                    let lookup_result_ty = self.lookup_type.lookup(result_type_id)?;
51                    Some(crate::FunctionResult {
52                        ty: lookup_result_ty.handle,
53                        binding: None,
54                    })
55                },
56                local_variables: Arena::new(),
57                expressions: self.make_expression_storage(
58                    &module.global_variables,
59                    &module.constants,
60                    &module.overrides,
61                ),
62                named_expressions: crate::NamedExpressions::default(),
63                body: crate::Block::new(),
64                diagnostic_filter_leaf: None,
65            }
66        };
67
68        // read parameters
69        for i in 0..fun.arguments.capacity() {
70            let start = self.data_offset;
71            match self.next_inst()? {
72                Instruction {
73                    op: spirv::Op::FunctionParameter,
74                    wc: 3,
75                } => {
76                    let type_id = self.next()?;
77                    let id = self.next()?;
78                    let handle = fun.expressions.append(
79                        crate::Expression::FunctionArgument(i as u32),
80                        self.span_from(start),
81                    );
82                    self.lookup_expression.insert(
83                        id,
84                        LookupExpression {
85                            handle,
86                            type_id,
87                            // Setting this to an invalid id will cause get_expr_handle
88                            // to default to the main body making sure no load/stores
89                            // are added.
90                            block_id: 0,
91                        },
92                    );
93                    //Note: we redo the lookup in order to work around `self` borrowing
94
95                    if type_id
96                        != self
97                            .lookup_function_type
98                            .lookup(fun_type_id)?
99                            .parameter_type_ids[i]
100                    {
101                        return Err(Error::WrongFunctionArgumentType(type_id));
102                    }
103                    let ty = self.lookup_type.lookup(type_id)?.handle;
104                    let decor = self.future_decor.remove(&id).unwrap_or_default();
105                    fun.arguments.push(crate::FunctionArgument {
106                        name: decor.name,
107                        ty,
108                        binding: None,
109                    });
110                }
111                Instruction { op, .. } => return Err(Error::InvalidParameter(op)),
112            }
113        }
114
115        // Note the index this function's handle will be assigned, for tracing.
116        let function_index = module.functions.len();
117
118        // Read body
119        self.function_call_graph.add_node(fun_id);
120        let mut parameters_sampling =
121            vec![super::image::SamplingFlags::empty(); fun.arguments.len()];
122
123        let mut block_ctx = BlockContext {
124            phis: Default::default(),
125            blocks: Default::default(),
126            body_for_label: Default::default(),
127            mergers: Default::default(),
128            bodies: Default::default(),
129            module,
130            function_id: fun_id,
131            expressions: &mut fun.expressions,
132            local_arena: &mut fun.local_variables,
133            arguments: &fun.arguments,
134            parameter_sampling: &mut parameters_sampling,
135        };
136        // Insert the main body whose parent is also himself
137        block_ctx.bodies.push(super::Body::with_parent(0));
138
139        // Scan the blocks and add them as nodes
140        loop {
141            let fun_inst = self.next_inst()?;
142            log::debug!("{:?}", fun_inst.op);
143            match fun_inst.op {
144                spirv::Op::Line => {
145                    fun_inst.expect(4)?;
146                    let _file_id = self.next()?;
147                    let _row_id = self.next()?;
148                    let _col_id = self.next()?;
149                }
150                spirv::Op::Label => {
151                    // Read the label ID
152                    fun_inst.expect(2)?;
153                    let block_id = self.next()?;
154
155                    self.next_block(block_id, &mut block_ctx)?;
156                }
157                spirv::Op::FunctionEnd => {
158                    fun_inst.expect(1)?;
159                    break;
160                }
161                spirv::Op::ExtInst => {
162                    let _ = self.next()?;
163                    let _ = self.next()?;
164                    let set_id = self.next()?;
165                    if Some(set_id) == self.ext_non_semantic_id {
166                        for _ in 0..fun_inst.wc - 4 {
167                            self.next()?;
168                        }
169                    } else {
170                        return Err(Error::UnsupportedInstruction(self.state, fun_inst.op));
171                    }
172                }
173                _ => {
174                    return Err(Error::UnsupportedInstruction(self.state, fun_inst.op));
175                }
176            }
177        }
178
179        if let Some(ref prefix) = self.options.block_ctx_dump_prefix {
180            let dump_suffix = match self.lookup_entry_point.get(&fun_id) {
181                Some(ep) => format!("block_ctx.{:?}-{}.txt", ep.stage, ep.name),
182                None => format!("block_ctx.Fun-{function_index}.txt"),
183            };
184
185            cfg_if::cfg_if! {
186                if #[cfg(feature = "fs")] {
187                    let prefix: &std::path::Path = prefix.as_ref();
188                    let dest = prefix.join(dump_suffix);
189                    let dump = format!("{block_ctx:#?}");
190                    if let Err(e) = std::fs::write(&dest, dump) {
191                        log::error!("Unable to dump the block context into {dest:?}: {e}");
192                    }
193                } else {
194                    log::error!("Unable to dump the block context into {prefix:?}/{dump_suffix}: file system integration was not enabled with the `fs` feature");
195                }
196            }
197        }
198
199        // Emit `Store` statements to properly initialize all the local variables we
200        // created for `phi` expressions.
201        //
202        // Note that get_expr_handle also contributes slightly odd entries to this table,
203        // to get the spill.
204        for phi in block_ctx.phis.iter() {
205            // Get a pointer to the local variable for the phi's value.
206            let phi_pointer: Handle<crate::Expression> = block_ctx.expressions.append(
207                crate::Expression::LocalVariable(phi.local),
208                crate::Span::default(),
209            );
210
211            // At the end of each of `phi`'s predecessor blocks, store the corresponding
212            // source value in the phi's local variable.
213            for &(source, predecessor) in phi.expressions.iter() {
214                let source_lexp = &self.lookup_expression[&source];
215                let predecessor_body_idx = block_ctx.body_for_label[&predecessor];
216                // If the expression is a global/argument it will have a 0 block
217                // id so we must use a default value instead of panicking
218                let source_body_idx = block_ctx
219                    .body_for_label
220                    .get(&source_lexp.block_id)
221                    .copied()
222                    .unwrap_or(0);
223
224                // If the Naga `Expression` generated for `source` is in scope, then we
225                // can simply store that in the phi's local variable.
226                //
227                // Otherwise, spill the source value to a local variable in the block that
228                // defines it. (We know this store dominates the predecessor; otherwise,
229                // the phi wouldn't have been able to refer to that source expression in
230                // the first place.) Then, the predecessor block can count on finding the
231                // source's value in that local variable.
232                let value = if super::is_parent(predecessor_body_idx, source_body_idx, &block_ctx) {
233                    source_lexp.handle
234                } else {
235                    // The source SPIR-V expression is not defined in the phi's
236                    // predecessor block, nor is it a globally available expression. So it
237                    // must be defined off in some other block that merely dominates the
238                    // predecessor. This means that the corresponding Naga `Expression`
239                    // may not be in scope in the predecessor block.
240                    //
241                    // In the block that defines `source`, spill it to a fresh local
242                    // variable, to ensure we can still use it at the end of the
243                    // predecessor.
244                    let ty = self.lookup_type[&source_lexp.type_id].handle;
245                    let local = block_ctx.local_arena.append(
246                        crate::LocalVariable {
247                            name: None,
248                            ty,
249                            init: None,
250                        },
251                        crate::Span::default(),
252                    );
253
254                    let pointer = block_ctx.expressions.append(
255                        crate::Expression::LocalVariable(local),
256                        crate::Span::default(),
257                    );
258
259                    // Get the spilled value of the source expression.
260                    let start = block_ctx.expressions.len();
261                    let expr = block_ctx
262                        .expressions
263                        .append(crate::Expression::Load { pointer }, crate::Span::default());
264                    let range = block_ctx.expressions.range_from(start);
265
266                    block_ctx
267                        .blocks
268                        .get_mut(&predecessor)
269                        .unwrap()
270                        .push(crate::Statement::Emit(range), crate::Span::default());
271
272                    // At the end of the block that defines it, spill the source
273                    // expression's value.
274                    block_ctx
275                        .blocks
276                        .get_mut(&source_lexp.block_id)
277                        .unwrap()
278                        .push(
279                            crate::Statement::Store {
280                                pointer,
281                                value: source_lexp.handle,
282                            },
283                            crate::Span::default(),
284                        );
285
286                    expr
287                };
288
289                // At the end of the phi predecessor block, store the source
290                // value in the phi's value.
291                block_ctx.blocks.get_mut(&predecessor).unwrap().push(
292                    crate::Statement::Store {
293                        pointer: phi_pointer,
294                        value,
295                    },
296                    crate::Span::default(),
297                )
298            }
299        }
300
301        fun.body = block_ctx.lower();
302
303        // done
304        let fun_handle = module.functions.append(fun, self.span_from_with_op(start));
305        self.lookup_function.insert(
306            fun_id,
307            super::LookupFunction {
308                handle: fun_handle,
309                parameters_sampling,
310            },
311        );
312
313        if let Some(ep) = self.lookup_entry_point.remove(&fun_id) {
314            self.deferred_entry_points.push((ep, fun_id));
315        }
316
317        Ok(())
318    }
319
320    pub(super) fn process_entry_point(
321        &mut self,
322        module: &mut crate::Module,
323        ep: super::EntryPoint,
324        fun_id: u32,
325    ) -> Result<(), Error> {
326        // create a wrapping function
327        let mut function = crate::Function {
328            name: Some(format!("{}_wrap", ep.name)),
329            arguments: Vec::new(),
330            result: None,
331            local_variables: Arena::new(),
332            expressions: Arena::new(),
333            named_expressions: crate::NamedExpressions::default(),
334            body: crate::Block::new(),
335            diagnostic_filter_leaf: None,
336        };
337
338        // 1. copy the inputs from arguments to privates
339        for &v_id in ep.variable_ids.iter() {
340            let lvar = self.lookup_variable.lookup(v_id)?;
341            if let super::Variable::Input(ref arg) = lvar.inner {
342                let span = module.global_variables.get_span(lvar.handle);
343                let arg_expr = function.expressions.append(
344                    crate::Expression::FunctionArgument(function.arguments.len() as u32),
345                    span,
346                );
347                let load_expr = if arg.ty == module.global_variables[lvar.handle].ty {
348                    arg_expr
349                } else {
350                    // The only case where the type is different is if we need to treat
351                    // unsigned integer as signed.
352                    let mut emitter = Emitter::default();
353                    emitter.start(&function.expressions);
354                    let handle = function.expressions.append(
355                        crate::Expression::As {
356                            expr: arg_expr,
357                            kind: crate::ScalarKind::Sint,
358                            convert: Some(4),
359                        },
360                        span,
361                    );
362                    function.body.extend(emitter.finish(&function.expressions));
363                    handle
364                };
365                function.body.push(
366                    crate::Statement::Store {
367                        pointer: function
368                            .expressions
369                            .append(crate::Expression::GlobalVariable(lvar.handle), span),
370                        value: load_expr,
371                    },
372                    span,
373                );
374
375                let mut arg = arg.clone();
376                if ep.stage == crate::ShaderStage::Fragment {
377                    if let Some(ref mut binding) = arg.binding {
378                        binding.apply_default_interpolation(&module.types[arg.ty].inner);
379                    }
380                }
381                function.arguments.push(arg);
382            }
383        }
384        // 2. call the wrapped function
385        let fake_id = !(module.entry_points.len() as u32); // doesn't matter, as long as it's not a collision
386        let dummy_handle = self.add_call(fake_id, fun_id);
387        function.body.push(
388            crate::Statement::Call {
389                function: dummy_handle,
390                arguments: Vec::new(),
391                result: None,
392            },
393            crate::Span::default(),
394        );
395
396        // 3. copy the outputs from privates to the result
397        //
398        // It would be nice to share struct layout code here with `parse_type_struct`,
399        // but that case needs to take into account offset decorations, which makes an
400        // abstraction harder to follow than just writing out what we mean. `Layouter`
401        // and `Alignment` cover the worst parts already.
402        let mut members = Vec::new();
403        self.layouter.update(module.to_ctx()).unwrap();
404        let mut next_member_offset = 0;
405        let mut struct_alignment = crate::proc::Alignment::ONE;
406        let mut components = Vec::new();
407        for &v_id in ep.variable_ids.iter() {
408            let lvar = self.lookup_variable.lookup(v_id)?;
409            if let super::Variable::Output(ref result) = lvar.inner {
410                let span = module.global_variables.get_span(lvar.handle);
411                let expr_handle = function
412                    .expressions
413                    .append(crate::Expression::GlobalVariable(lvar.handle), span);
414
415                // Cull problematic builtins of gl_PerVertex.
416                // See the docs for `Frontend::gl_per_vertex_builtin_access`.
417                {
418                    let ty = &module.types[result.ty];
419                    if let crate::TypeInner::Struct {
420                        members: ref original_members,
421                        span,
422                    } = ty.inner
423                    {
424                        let mut new_members = None;
425                        for (idx, member) in original_members.iter().enumerate() {
426                            if let Some(crate::Binding::BuiltIn(built_in)) = member.binding {
427                                if !self.gl_per_vertex_builtin_access.contains(&built_in) {
428                                    new_members.get_or_insert_with(|| original_members.clone())
429                                        [idx]
430                                        .binding = None;
431                                }
432                            }
433                        }
434                        if let Some(new_members) = new_members {
435                            module.types.replace(
436                                result.ty,
437                                crate::Type {
438                                    name: ty.name.clone(),
439                                    inner: crate::TypeInner::Struct {
440                                        members: new_members,
441                                        span,
442                                    },
443                                },
444                            );
445                        }
446                    }
447                }
448
449                match module.types[result.ty].inner {
450                    crate::TypeInner::Struct {
451                        members: ref sub_members,
452                        ..
453                    } => {
454                        for (index, sm) in sub_members.iter().enumerate() {
455                            if sm.binding.is_none() {
456                                continue;
457                            }
458                            let mut sm = sm.clone();
459
460                            if let Some(ref mut binding) = sm.binding {
461                                if ep.stage == crate::ShaderStage::Vertex {
462                                    binding.apply_default_interpolation(&module.types[sm.ty].inner);
463                                }
464                            }
465
466                            let member_alignment = self.layouter[sm.ty].alignment;
467                            next_member_offset = member_alignment.round_up(next_member_offset);
468                            sm.offset = next_member_offset;
469                            struct_alignment = struct_alignment.max(member_alignment);
470                            next_member_offset += self.layouter[sm.ty].size;
471                            members.push(sm);
472
473                            components.push(function.expressions.append(
474                                crate::Expression::AccessIndex {
475                                    base: expr_handle,
476                                    index: index as u32,
477                                },
478                                span,
479                            ));
480                        }
481                    }
482                    ref inner => {
483                        let mut binding = result.binding.clone();
484                        if let Some(ref mut binding) = binding {
485                            if ep.stage == crate::ShaderStage::Vertex {
486                                binding.apply_default_interpolation(inner);
487                            }
488                        }
489
490                        let member_alignment = self.layouter[result.ty].alignment;
491                        next_member_offset = member_alignment.round_up(next_member_offset);
492                        members.push(crate::StructMember {
493                            name: None,
494                            ty: result.ty,
495                            binding,
496                            offset: next_member_offset,
497                        });
498                        struct_alignment = struct_alignment.max(member_alignment);
499                        next_member_offset += self.layouter[result.ty].size;
500                        // populate just the globals first, then do `Load` in a
501                        // separate step, so that we can get a range.
502                        components.push(expr_handle);
503                    }
504                }
505            }
506        }
507
508        for (member_index, member) in members.iter().enumerate() {
509            match member.binding {
510                Some(crate::Binding::BuiltIn(crate::BuiltIn::Position { .. }))
511                    if self.options.adjust_coordinate_space =>
512                {
513                    let mut emitter = Emitter::default();
514                    emitter.start(&function.expressions);
515                    let global_expr = components[member_index];
516                    let span = function.expressions.get_span(global_expr);
517                    let access_expr = function.expressions.append(
518                        crate::Expression::AccessIndex {
519                            base: global_expr,
520                            index: 1,
521                        },
522                        span,
523                    );
524                    let load_expr = function.expressions.append(
525                        crate::Expression::Load {
526                            pointer: access_expr,
527                        },
528                        span,
529                    );
530                    let neg_expr = function.expressions.append(
531                        crate::Expression::Unary {
532                            op: crate::UnaryOperator::Negate,
533                            expr: load_expr,
534                        },
535                        span,
536                    );
537                    function.body.extend(emitter.finish(&function.expressions));
538                    function.body.push(
539                        crate::Statement::Store {
540                            pointer: access_expr,
541                            value: neg_expr,
542                        },
543                        span,
544                    );
545                }
546                _ => {}
547            }
548        }
549
550        let mut emitter = Emitter::default();
551        emitter.start(&function.expressions);
552        for component in components.iter_mut() {
553            let load_expr = crate::Expression::Load {
554                pointer: *component,
555            };
556            let span = function.expressions.get_span(*component);
557            *component = function.expressions.append(load_expr, span);
558        }
559
560        match members[..] {
561            [] => {}
562            [ref member] => {
563                function.body.extend(emitter.finish(&function.expressions));
564                let span = function.expressions.get_span(components[0]);
565                function.body.push(
566                    crate::Statement::Return {
567                        value: components.first().cloned(),
568                    },
569                    span,
570                );
571                function.result = Some(crate::FunctionResult {
572                    ty: member.ty,
573                    binding: member.binding.clone(),
574                });
575            }
576            _ => {
577                let span = crate::Span::total_span(
578                    components.iter().map(|h| function.expressions.get_span(*h)),
579                );
580                let ty = module.types.insert(
581                    crate::Type {
582                        name: None,
583                        inner: crate::TypeInner::Struct {
584                            members,
585                            span: struct_alignment.round_up(next_member_offset),
586                        },
587                    },
588                    span,
589                );
590                let result_expr = function
591                    .expressions
592                    .append(crate::Expression::Compose { ty, components }, span);
593                function.body.extend(emitter.finish(&function.expressions));
594                function.body.push(
595                    crate::Statement::Return {
596                        value: Some(result_expr),
597                    },
598                    span,
599                );
600                function.result = Some(crate::FunctionResult { ty, binding: None });
601            }
602        }
603
604        module.entry_points.push(crate::EntryPoint {
605            name: ep.name,
606            stage: ep.stage,
607            early_depth_test: ep.early_depth_test,
608            workgroup_size: ep.workgroup_size,
609            workgroup_size_overrides: None,
610            function,
611            mesh_info: None,
612            task_payload: None,
613        });
614
615        Ok(())
616    }
617}
618
619impl BlockContext<'_> {
620    pub(super) fn gctx(&self) -> crate::proc::GlobalCtx<'_> {
621        crate::proc::GlobalCtx {
622            types: &self.module.types,
623            constants: &self.module.constants,
624            overrides: &self.module.overrides,
625            global_expressions: &self.module.global_expressions,
626        }
627    }
628
629    /// Consumes the `BlockContext` producing a Ir [`Block`](crate::Block)
630    fn lower(mut self) -> crate::Block {
631        fn lower_impl(
632            blocks: &mut crate::FastHashMap<spirv::Word, crate::Block>,
633            bodies: &[super::Body],
634            body_idx: BodyIndex,
635        ) -> crate::Block {
636            let mut block = crate::Block::new();
637
638            for item in bodies[body_idx].data.iter() {
639                match *item {
640                    super::BodyFragment::BlockId(id) => block.append(blocks.get_mut(&id).unwrap()),
641                    super::BodyFragment::If {
642                        condition,
643                        accept,
644                        reject,
645                    } => {
646                        let accept = lower_impl(blocks, bodies, accept);
647                        let reject = lower_impl(blocks, bodies, reject);
648
649                        block.push(
650                            crate::Statement::If {
651                                condition,
652                                accept,
653                                reject,
654                            },
655                            crate::Span::default(),
656                        )
657                    }
658                    super::BodyFragment::Loop {
659                        body,
660                        continuing,
661                        break_if,
662                    } => {
663                        let body = lower_impl(blocks, bodies, body);
664                        let continuing = lower_impl(blocks, bodies, continuing);
665
666                        block.push(
667                            crate::Statement::Loop {
668                                body,
669                                continuing,
670                                break_if,
671                            },
672                            crate::Span::default(),
673                        )
674                    }
675                    super::BodyFragment::Switch {
676                        selector,
677                        ref cases,
678                        default,
679                    } => {
680                        let mut ir_cases: Vec<_> = cases
681                            .iter()
682                            .map(|&(value, body_idx)| {
683                                let body = lower_impl(blocks, bodies, body_idx);
684
685                                // Handle simple cases that would make a fallthrough statement unreachable code
686                                let fall_through = body.last().is_none_or(|s| !s.is_terminator());
687
688                                crate::SwitchCase {
689                                    value: crate::SwitchValue::I32(value),
690                                    body,
691                                    fall_through,
692                                }
693                            })
694                            .collect();
695                        ir_cases.push(crate::SwitchCase {
696                            value: crate::SwitchValue::Default,
697                            body: lower_impl(blocks, bodies, default),
698                            fall_through: false,
699                        });
700
701                        block.push(
702                            crate::Statement::Switch {
703                                selector,
704                                cases: ir_cases,
705                            },
706                            crate::Span::default(),
707                        )
708                    }
709                    super::BodyFragment::Break => {
710                        block.push(crate::Statement::Break, crate::Span::default())
711                    }
712                    super::BodyFragment::Continue => {
713                        block.push(crate::Statement::Continue, crate::Span::default())
714                    }
715                }
716            }
717
718            block
719        }
720
721        lower_impl(&mut self.blocks, &self.bodies, 0)
722    }
723}