naga/front/glsl/
context.rs

1use alloc::{format, string::String, vec::Vec};
2use core::ops::Index;
3
4use super::{
5    ast::{
6        GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier,
7        VariableReference,
8    },
9    error::{Error, ErrorKind},
10    types::{scalar_components, type_power},
11    Frontend, Result,
12};
13use crate::{
14    front::Typifier, proc::Emitter, proc::Layouter, AddressSpace, Arena, BinaryOperator, Block,
15    Expression, FastHashMap, FunctionArgument, Handle, Literal, LocalVariable, RelationalFunction,
16    Scalar, Span, Statement, Type, TypeInner, VectorSize,
17};
18
/// Where an expression appears, which decides how it is lowered.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ExprPos {
    /// On the left-hand side of an assignment.
    Lhs,
    /// On the right-hand side of an assignment.
    Rhs,
    /// As the base of an indexing expression. Tracking this is needed so that
    /// constant arrays can still be indexed dynamically.
    AccessBase {
        /// Whether the index applied to this base is a constant.
        constant_index: bool,
    },
}

impl ExprPos {
    /// Returns the position at which the base of an index expression is lowered.
    ///
    /// `Lhs` and an already non-constant `AccessBase` are returned unchanged;
    /// every other position becomes `AccessBase { constant_index }`. Once a
    /// position is a non-constant `AccessBase` it therefore stays that way,
    /// whatever `constant_index` is passed.
    const fn maybe_access_base(&self, constant_index: bool) -> Self {
        if let ExprPos::Lhs | ExprPos::AccessBase { constant_index: false } = *self {
            *self
        } else {
            ExprPos::AccessBase { constant_index }
        }
    }
}
46
/// Lowering state used by the GLSL frontend while building the expressions,
/// locals, arguments and statements of a [`crate::Function`], or, when
/// [`is_const`](Self::is_const) is set, module-level constant expressions.
#[derive(Debug)]
pub struct Context<'a> {
    /// Function-local expression arena. [`add_expression`](Self::add_expression)
    /// appends here unless `is_const` is set.
    pub expressions: Arena<Expression>,
    /// Local variables of the function. `add_function_arg` appends one for
    /// every named, mutable, by-value parameter.
    pub locals: Arena<LocalVariable>,

    /// The [`FunctionArgument`]s for the final [`crate::Function`].
    ///
    /// Parameters with the `out` and `inout` qualifiers have [`Pointer`] types
    /// here. For example, an `inout vec2 a` argument would be a [`Pointer`] to
    /// a [`Vector`].
    ///
    /// [`Pointer`]: crate::TypeInner::Pointer
    /// [`Vector`]: crate::TypeInner::Vector
    pub arguments: Vec<FunctionArgument>,

    /// The parameter types given in the source code.
    ///
    /// The `out` and `inout` qualifiers don't affect the types that appear
    /// here. For example, an `inout vec2 a` argument would simply be a
    /// [`Vector`], not a pointer to one.
    ///
    /// [`Vector`]: crate::TypeInner::Vector
    pub parameters: Vec<Handle<Type>>,
    /// Qualifier (and related info) of each parameter. One entry is pushed
    /// alongside every entry of `parameters`, so the two are index-parallel.
    pub parameters_info: Vec<ParameterInfo>,

    /// Names currently in scope and how references to them are lowered.
    pub symbol_table: crate::front::SymbolTable<String, VariableReference>,
    /// NOTE(review): maps expression handles to expression handles and is not
    /// used in the visible part of this file; presumably it pairs images with
    /// their samplers. Confirm against its users.
    pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>,

    /// Type resolutions, presumably for module-level (`is_const`) expressions.
    /// TODO confirm.
    pub const_typifier: Typifier,
    /// Type resolutions, presumably for the function-local `expressions`.
    /// TODO confirm.
    pub typifier: Typifier,
    /// Layout cache handed to the constant evaluator in `add_expression`.
    layouter: Layouter,
    /// Tracks which newly appended expressions still need an `Emit`
    /// statement; driven by `emit_start` / `emit_end`.
    emitter: Emitter,
    /// The reusable statement context. It is `None` while checked out through
    /// `stmt_ctx()` and is put back by `lower` / `lower_expect`.
    stmt_ctx: Option<StmtContext>,
    /// The block statements are currently pushed to. `new_body` and
    /// `with_body` temporarily swap it out.
    pub body: Block,
    /// The module being built. Types, constants and global variables are
    /// looked up and inserted here.
    pub module: &'a mut crate::Module,
    /// When set, `add_expression` uses the module-level constant evaluator
    /// instead of the function-local one.
    pub is_const: bool,
    /// Tracks the expression kind of `Expression`s residing in `self.expressions`
    pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker,
    /// Tracks the expression kind of `Expression`s residing in `self.module.global_expressions`
    pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
}
88
89impl<'a> Context<'a> {
90    pub fn new(
91        frontend: &Frontend,
92        module: &'a mut crate::Module,
93        is_const: bool,
94        global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
95    ) -> Result<Self> {
96        let mut this = Context {
97            expressions: Arena::new(),
98            locals: Arena::new(),
99            arguments: Vec::new(),
100
101            parameters: Vec::new(),
102            parameters_info: Vec::new(),
103
104            symbol_table: crate::front::SymbolTable::default(),
105            samplers: FastHashMap::default(),
106
107            const_typifier: Typifier::new(),
108            typifier: Typifier::new(),
109            layouter: Layouter::default(),
110            emitter: Emitter::default(),
111            stmt_ctx: Some(StmtContext::new()),
112            body: Block::new(),
113            module,
114            is_const: false,
115            local_expression_kind_tracker: crate::proc::ExpressionKindTracker::new(),
116            global_expression_kind_tracker,
117        };
118
119        this.emit_start();
120
121        for &(ref name, lookup) in frontend.global_variables.iter() {
122            this.add_global(name, lookup)?
123        }
124        this.is_const = is_const;
125
126        Ok(this)
127    }
128
129    pub fn new_body<F>(&mut self, cb: F) -> Result<Block>
130    where
131        F: FnOnce(&mut Self) -> Result<()>,
132    {
133        self.new_body_with_ret(cb).map(|(b, _)| b)
134    }
135
136    pub fn new_body_with_ret<F, R>(&mut self, cb: F) -> Result<(Block, R)>
137    where
138        F: FnOnce(&mut Self) -> Result<R>,
139    {
140        self.emit_restart();
141        let old_body = core::mem::replace(&mut self.body, Block::new());
142        let res = cb(self);
143        self.emit_restart();
144        let new_body = core::mem::replace(&mut self.body, old_body);
145        res.map(|r| (new_body, r))
146    }
147
148    pub fn with_body<F>(&mut self, body: Block, cb: F) -> Result<Block>
149    where
150        F: FnOnce(&mut Self) -> Result<()>,
151    {
152        self.emit_restart();
153        let old_body = core::mem::replace(&mut self.body, body);
154        let res = cb(self);
155        self.emit_restart();
156        let body = core::mem::replace(&mut self.body, old_body);
157        res.map(|_| body)
158    }
159
160    pub fn add_global(
161        &mut self,
162        name: &str,
163        GlobalLookup {
164            kind,
165            entry_arg,
166            mutable,
167        }: GlobalLookup,
168    ) -> Result<()> {
169        let (expr, load, constant) = match kind {
170            GlobalLookupKind::Variable(v) => {
171                let span = self.module.global_variables.get_span(v);
172                (
173                    self.add_expression(Expression::GlobalVariable(v), span)?,
174                    self.module.global_variables[v].space != AddressSpace::Handle,
175                    None,
176                )
177            }
178            GlobalLookupKind::BlockSelect(handle, index) => {
179                let span = self.module.global_variables.get_span(handle);
180                let base = self.add_expression(Expression::GlobalVariable(handle), span)?;
181                let expr = self.add_expression(Expression::AccessIndex { base, index }, span)?;
182
183                (
184                    expr,
185                    {
186                        let ty = self.module.global_variables[handle].ty;
187
188                        match self.module.types[ty].inner {
189                            TypeInner::Struct { ref members, .. } => {
190                                if let TypeInner::Array {
191                                    size: crate::ArraySize::Dynamic,
192                                    ..
193                                } = self.module.types[members[index as usize].ty].inner
194                                {
195                                    false
196                                } else {
197                                    true
198                                }
199                            }
200                            _ => true,
201                        }
202                    },
203                    None,
204                )
205            }
206            GlobalLookupKind::Constant(v, ty) => {
207                let span = self.module.constants.get_span(v);
208                (
209                    self.add_expression(Expression::Constant(v), span)?,
210                    false,
211                    Some((v, ty)),
212                )
213            }
214        };
215
216        let var = VariableReference {
217            expr,
218            load,
219            mutable,
220            constant,
221            entry_arg,
222        };
223
224        self.symbol_table.add(name.into(), var);
225
226        Ok(())
227    }
228
229    /// Starts the expression emitter
230    ///
231    /// # Panics
232    ///
233    /// - If called twice in a row without calling [`emit_end`][Self::emit_end].
234    #[inline]
235    pub fn emit_start(&mut self) {
236        self.emitter.start(&self.expressions)
237    }
238
239    /// Emits all the expressions captured by the emitter to the current body
240    ///
241    /// # Panics
242    ///
243    /// - If called before calling [`emit_start`].
244    /// - If called twice in a row without calling [`emit_start`].
245    ///
246    /// [`emit_start`]: Self::emit_start
247    pub fn emit_end(&mut self) {
248        self.body.extend(self.emitter.finish(&self.expressions))
249    }
250
251    /// Emits all the expressions captured by the emitter to the current body
252    /// and starts the emitter again
253    ///
254    /// # Panics
255    ///
256    /// - If called before calling [`emit_start`][Self::emit_start].
257    pub fn emit_restart(&mut self) {
258        self.emit_end();
259        self.emit_start()
260    }
261
262    pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
263        let mut eval = if self.is_const {
264            crate::proc::ConstantEvaluator::for_glsl_module(
265                self.module,
266                self.global_expression_kind_tracker,
267                &mut self.layouter,
268            )
269        } else {
270            crate::proc::ConstantEvaluator::for_glsl_function(
271                self.module,
272                &mut self.expressions,
273                &mut self.local_expression_kind_tracker,
274                &mut self.layouter,
275                &mut self.emitter,
276                &mut self.body,
277            )
278        };
279
280        eval.try_eval_and_append(expr, meta).map_err(|e| Error {
281            kind: e.into(),
282            meta,
283        })
284    }
285
286    /// Add variable to current scope
287    ///
288    /// Returns a variable if a variable with the same name was already defined,
289    /// otherwise returns `None`
290    pub fn add_local_var(
291        &mut self,
292        name: String,
293        expr: Handle<Expression>,
294        mutable: bool,
295    ) -> Option<VariableReference> {
296        let var = VariableReference {
297            expr,
298            load: true,
299            mutable,
300            constant: None,
301            entry_arg: None,
302        };
303
304        self.symbol_table.add(name, var)
305    }
306
307    /// Add function argument to current scope
308    pub fn add_function_arg(
309        &mut self,
310        name_meta: Option<(String, Span)>,
311        ty: Handle<Type>,
312        qualifier: ParameterQualifier,
313    ) -> Result<()> {
314        let index = self.arguments.len();
315        let mut arg = FunctionArgument {
316            name: name_meta.as_ref().map(|&(ref name, _)| name.clone()),
317            ty,
318            binding: None,
319        };
320        self.parameters.push(ty);
321
322        let opaque = match self.module.types[ty].inner {
323            TypeInner::Image { .. } | TypeInner::Sampler { .. } => true,
324            _ => false,
325        };
326
327        if qualifier.is_lhs() {
328            let span = self.module.types.get_span(arg.ty);
329            arg.ty = self.module.types.insert(
330                Type {
331                    name: None,
332                    inner: TypeInner::Pointer {
333                        base: arg.ty,
334                        space: AddressSpace::Function,
335                    },
336                },
337                span,
338            )
339        }
340
341        self.arguments.push(arg);
342
343        self.parameters_info.push(ParameterInfo {
344            qualifier,
345            depth: false,
346        });
347
348        if let Some((name, meta)) = name_meta {
349            let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta)?;
350            let mutable = qualifier != ParameterQualifier::Const && !opaque;
351            let load = qualifier.is_lhs();
352
353            let var = if mutable && !load {
354                let handle = self.locals.append(
355                    LocalVariable {
356                        name: Some(name.clone()),
357                        ty,
358                        init: None,
359                    },
360                    meta,
361                );
362                let local_expr = self.add_expression(Expression::LocalVariable(handle), meta)?;
363
364                self.emit_restart();
365
366                self.body.push(
367                    Statement::Store {
368                        pointer: local_expr,
369                        value: expr,
370                    },
371                    meta,
372                );
373
374                VariableReference {
375                    expr: local_expr,
376                    load: true,
377                    mutable,
378                    constant: None,
379                    entry_arg: None,
380                }
381            } else {
382                VariableReference {
383                    expr,
384                    load,
385                    mutable,
386                    constant: None,
387                    entry_arg: None,
388                }
389            };
390
391            self.symbol_table.add(name, var);
392        }
393
394        Ok(())
395    }
396
397    /// Returns a [`StmtContext`] to be used in parsing and lowering
398    ///
399    /// # Panics
400    ///
401    /// - If more than one [`StmtContext`] are active at the same time or if the
402    ///   previous call didn't use it in lowering.
403    #[must_use]
404    pub fn stmt_ctx(&mut self) -> StmtContext {
405        self.stmt_ctx.take().unwrap()
406    }
407
408    /// Lowers a [`HirExpr`] which might produce a [`Expression`].
409    ///
410    /// consumes a [`StmtContext`] returning it to the context so that it can be
411    /// used again later.
412    pub fn lower(
413        &mut self,
414        mut stmt: StmtContext,
415        frontend: &mut Frontend,
416        expr: Handle<HirExpr>,
417        pos: ExprPos,
418    ) -> Result<(Option<Handle<Expression>>, Span)> {
419        let res = self.lower_inner(&stmt, frontend, expr, pos);
420
421        stmt.hir_exprs.clear();
422        self.stmt_ctx = Some(stmt);
423
424        res
425    }
426
427    /// Similar to [`lower`](Self::lower) but returns an error if the expression
428    /// returns void (ie. doesn't produce a [`Expression`]).
429    ///
430    /// consumes a [`StmtContext`] returning it to the context so that it can be
431    /// used again later.
432    pub fn lower_expect(
433        &mut self,
434        mut stmt: StmtContext,
435        frontend: &mut Frontend,
436        expr: Handle<HirExpr>,
437        pos: ExprPos,
438    ) -> Result<(Handle<Expression>, Span)> {
439        let res = self.lower_expect_inner(&stmt, frontend, expr, pos);
440
441        stmt.hir_exprs.clear();
442        self.stmt_ctx = Some(stmt);
443
444        res
445    }
446
447    /// internal implementation of [`lower_expect`](Self::lower_expect)
448    ///
449    /// this method is only public because it's used in
450    /// [`function_call`](Frontend::function_call), unless you know what
451    /// you're doing use [`lower_expect`](Self::lower_expect)
452    pub fn lower_expect_inner(
453        &mut self,
454        stmt: &StmtContext,
455        frontend: &mut Frontend,
456        expr: Handle<HirExpr>,
457        pos: ExprPos,
458    ) -> Result<(Handle<Expression>, Span)> {
459        let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos)?;
460
461        let expr = match maybe_expr {
462            Some(e) => e,
463            None => {
464                return Err(Error {
465                    kind: ErrorKind::SemanticError("Expression returns void".into()),
466                    meta,
467                })
468            }
469        };
470
471        Ok((expr, meta))
472    }
473
474    fn lower_store(
475        &mut self,
476        pointer: Handle<Expression>,
477        value: Handle<Expression>,
478        meta: Span,
479    ) -> Result<()> {
480        if let Expression::Swizzle {
481            size,
482            mut vector,
483            pattern,
484        } = self.expressions[pointer]
485        {
486            // Stores to swizzled values are not directly supported,
487            // lower them as series of per-component stores.
488            let size = match size {
489                VectorSize::Bi => 2,
490                VectorSize::Tri => 3,
491                VectorSize::Quad => 4,
492            };
493
494            if let Expression::Load { pointer } = self.expressions[vector] {
495                vector = pointer;
496            }
497
498            #[allow(clippy::needless_range_loop)]
499            for index in 0..size {
500                let dst = self.add_expression(
501                    Expression::AccessIndex {
502                        base: vector,
503                        index: pattern[index].index(),
504                    },
505                    meta,
506                )?;
507                let src = self.add_expression(
508                    Expression::AccessIndex {
509                        base: value,
510                        index: index as u32,
511                    },
512                    meta,
513                )?;
514
515                self.emit_restart();
516
517                self.body.push(
518                    Statement::Store {
519                        pointer: dst,
520                        value: src,
521                    },
522                    meta,
523                );
524            }
525        } else {
526            self.emit_restart();
527
528            self.body.push(Statement::Store { pointer, value }, meta);
529        }
530
531        Ok(())
532    }
533
534    /// Internal implementation of [`lower`](Self::lower)
535    fn lower_inner(
536        &mut self,
537        stmt: &StmtContext,
538        frontend: &mut Frontend,
539        expr: Handle<HirExpr>,
540        pos: ExprPos,
541    ) -> Result<(Option<Handle<Expression>>, Span)> {
542        let HirExpr { ref kind, meta } = stmt.hir_exprs[expr];
543
544        log::debug!("Lowering {expr:?} (kind {kind:?}, pos {pos:?})");
545
546        let handle = match *kind {
547            HirExprKind::Access { base, index } => {
548                let (index, _) = self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs)?;
549                let maybe_constant_index = match pos {
550                    // Don't try to generate `AccessIndex` if in a LHS position, since it
551                    // wouldn't produce a pointer.
552                    ExprPos::Lhs => None,
553                    _ => self
554                        .module
555                        .to_ctx()
556                        .eval_expr_to_u32_from(index, &self.expressions)
557                        .ok(),
558                };
559
560                let base = self
561                    .lower_expect_inner(
562                        stmt,
563                        frontend,
564                        base,
565                        pos.maybe_access_base(maybe_constant_index.is_some()),
566                    )?
567                    .0;
568
569                let pointer = maybe_constant_index
570                    .map(|index| self.add_expression(Expression::AccessIndex { base, index }, meta))
571                    .unwrap_or_else(|| {
572                        self.add_expression(Expression::Access { base, index }, meta)
573                    })?;
574
575                if ExprPos::Rhs == pos {
576                    let resolved = self.resolve_type(pointer, meta)?;
577                    if resolved.pointer_space().is_some() {
578                        return Ok((
579                            Some(self.add_expression(Expression::Load { pointer }, meta)?),
580                            meta,
581                        ));
582                    }
583                }
584
585                pointer
586            }
587            HirExprKind::Select { base, ref field } => {
588                let base = self.lower_expect_inner(stmt, frontend, base, pos)?.0;
589
590                frontend.field_selection(self, pos, base, field, meta)?
591            }
592            HirExprKind::Literal(literal) if pos != ExprPos::Lhs => {
593                self.add_expression(Expression::Literal(literal), meta)?
594            }
595            HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
596                let (mut left, left_meta) =
597                    self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs)?;
598                let (mut right, right_meta) =
599                    self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs)?;
600
601                match op {
602                    BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
603                        self.implicit_conversion(&mut right, right_meta, Scalar::U32)?
604                    }
605                    _ => self
606                        .binary_implicit_conversion(&mut left, left_meta, &mut right, right_meta)?,
607                }
608
609                self.typifier_grow(left, left_meta)?;
610                self.typifier_grow(right, right_meta)?;
611
612                let left_inner = self.get_type(left);
613                let right_inner = self.get_type(right);
614
615                match (left_inner, right_inner) {
616                    (
617                        &TypeInner::Matrix {
618                            columns: left_columns,
619                            rows: left_rows,
620                            scalar: left_scalar,
621                        },
622                        &TypeInner::Matrix {
623                            columns: right_columns,
624                            rows: right_rows,
625                            scalar: right_scalar,
626                        },
627                    ) => {
628                        let dimensions_ok = if op == BinaryOperator::Multiply {
629                            left_columns == right_rows
630                        } else {
631                            left_columns == right_columns && left_rows == right_rows
632                        };
633
634                        // Check that the two arguments have the same dimensions
635                        if !dimensions_ok || left_scalar != right_scalar {
636                            frontend.errors.push(Error {
637                                kind: ErrorKind::SemanticError(
638                                    format!(
639                                        "Cannot apply operation to {left_inner:?} and {right_inner:?}"
640                                    )
641                                    .into(),
642                                ),
643                                meta,
644                            })
645                        }
646
647                        match op {
648                            BinaryOperator::Divide => {
649                                // Naga IR doesn't support matrix division so we need to
650                                // divide the columns individually and reassemble the matrix
651                                let mut components = Vec::with_capacity(left_columns as usize);
652
653                                for index in 0..left_columns as u32 {
654                                    // Get the column vectors
655                                    let left_vector = self.add_expression(
656                                        Expression::AccessIndex { base: left, index },
657                                        meta,
658                                    )?;
659                                    let right_vector = self.add_expression(
660                                        Expression::AccessIndex { base: right, index },
661                                        meta,
662                                    )?;
663
664                                    // Divide the vectors
665                                    let column = self.add_expression(
666                                        Expression::Binary {
667                                            op,
668                                            left: left_vector,
669                                            right: right_vector,
670                                        },
671                                        meta,
672                                    )?;
673
674                                    components.push(column)
675                                }
676
677                                let ty = self.module.types.insert(
678                                    Type {
679                                        name: None,
680                                        inner: TypeInner::Matrix {
681                                            columns: left_columns,
682                                            rows: left_rows,
683                                            scalar: left_scalar,
684                                        },
685                                    },
686                                    Span::default(),
687                                );
688
689                                // Rebuild the matrix from the divided vectors
690                                self.add_expression(Expression::Compose { ty, components }, meta)?
691                            }
692                            BinaryOperator::Equal | BinaryOperator::NotEqual => {
693                                // Naga IR doesn't support matrix comparisons so we need to
694                                // compare the columns individually and then fold them together
695                                //
696                                // The folding is done using a logical and for equality and
697                                // a logical or for inequality
698                                let equals = op == BinaryOperator::Equal;
699
700                                let (op, combine, fun) = match equals {
701                                    true => (
702                                        BinaryOperator::Equal,
703                                        BinaryOperator::LogicalAnd,
704                                        RelationalFunction::All,
705                                    ),
706                                    false => (
707                                        BinaryOperator::NotEqual,
708                                        BinaryOperator::LogicalOr,
709                                        RelationalFunction::Any,
710                                    ),
711                                };
712
713                                let mut root = None;
714
715                                for index in 0..left_columns as u32 {
716                                    // Get the column vectors
717                                    let left_vector = self.add_expression(
718                                        Expression::AccessIndex { base: left, index },
719                                        meta,
720                                    )?;
721                                    let right_vector = self.add_expression(
722                                        Expression::AccessIndex { base: right, index },
723                                        meta,
724                                    )?;
725
726                                    let argument = self.add_expression(
727                                        Expression::Binary {
728                                            op,
729                                            left: left_vector,
730                                            right: right_vector,
731                                        },
732                                        meta,
733                                    )?;
734
735                                    // The result of comparing two vectors is a boolean vector
736                                    // so use a relational function like all to get a single
737                                    // boolean value
738                                    let compare = self.add_expression(
739                                        Expression::Relational { fun, argument },
740                                        meta,
741                                    )?;
742
743                                    // Fold the result
744                                    root = Some(match root {
745                                        Some(right) => self.add_expression(
746                                            Expression::Binary {
747                                                op: combine,
748                                                left: compare,
749                                                right,
750                                            },
751                                            meta,
752                                        )?,
753                                        None => compare,
754                                    });
755                                }
756
757                                root.unwrap()
758                            }
759                            _ => {
760                                self.add_expression(Expression::Binary { left, op, right }, meta)?
761                            }
762                        }
763                    }
764                    (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op {
765                        BinaryOperator::Equal | BinaryOperator::NotEqual => {
766                            let equals = op == BinaryOperator::Equal;
767
768                            let (op, fun) = match equals {
769                                true => (BinaryOperator::Equal, RelationalFunction::All),
770                                false => (BinaryOperator::NotEqual, RelationalFunction::Any),
771                            };
772
773                            let argument =
774                                self.add_expression(Expression::Binary { op, left, right }, meta)?;
775
776                            self.add_expression(Expression::Relational { fun, argument }, meta)?
777                        }
778                        _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
779                    },
780                    (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op {
781                        BinaryOperator::Add
782                        | BinaryOperator::Subtract
783                        | BinaryOperator::Divide
784                        | BinaryOperator::And
785                        | BinaryOperator::ExclusiveOr
786                        | BinaryOperator::InclusiveOr
787                        | BinaryOperator::ShiftLeft
788                        | BinaryOperator::ShiftRight => {
789                            let scalar_vector = self
790                                .add_expression(Expression::Splat { size, value: right }, meta)?;
791
792                            self.add_expression(
793                                Expression::Binary {
794                                    op,
795                                    left,
796                                    right: scalar_vector,
797                                },
798                                meta,
799                            )?
800                        }
801                        _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
802                    },
803                    (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op {
804                        BinaryOperator::Add
805                        | BinaryOperator::Subtract
806                        | BinaryOperator::Divide
807                        | BinaryOperator::And
808                        | BinaryOperator::ExclusiveOr
809                        | BinaryOperator::InclusiveOr => {
810                            let scalar_vector =
811                                self.add_expression(Expression::Splat { size, value: left }, meta)?;
812
813                            self.add_expression(
814                                Expression::Binary {
815                                    op,
816                                    left: scalar_vector,
817                                    right,
818                                },
819                                meta,
820                            )?
821                        }
822                        _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
823                    },
824                    (
825                        &TypeInner::Scalar(left_scalar),
826                        &TypeInner::Matrix {
827                            rows,
828                            columns,
829                            scalar: right_scalar,
830                        },
831                    ) => {
832                        // Check that the two arguments have the same scalar type
833                        if left_scalar != right_scalar {
834                            frontend.errors.push(Error {
835                                kind: ErrorKind::SemanticError(
836                                    format!(
837                                        "Cannot apply operation to {left_inner:?} and {right_inner:?}"
838                                    )
839                                    .into(),
840                                ),
841                                meta,
842                            })
843                        }
844
845                        match op {
846                            BinaryOperator::Divide
847                            | BinaryOperator::Add
848                            | BinaryOperator::Subtract => {
849                                // Naga IR doesn't support all matrix by scalar operations so
850                                // we need for some to turn the scalar into a vector by
851                                // splatting it and then for each column vector apply the
852                                // operation and finally reconstruct the matrix
853                                let scalar_vector = self.add_expression(
854                                    Expression::Splat {
855                                        size: rows,
856                                        value: left,
857                                    },
858                                    meta,
859                                )?;
860
861                                let mut components = Vec::with_capacity(columns as usize);
862
863                                for index in 0..columns as u32 {
864                                    // Get the column vector
865                                    let matrix_column = self.add_expression(
866                                        Expression::AccessIndex { base: right, index },
867                                        meta,
868                                    )?;
869
870                                    // Apply the operation to the splatted vector and
871                                    // the column vector
872                                    let column = self.add_expression(
873                                        Expression::Binary {
874                                            op,
875                                            left: scalar_vector,
876                                            right: matrix_column,
877                                        },
878                                        meta,
879                                    )?;
880
881                                    components.push(column)
882                                }
883
884                                let ty = self.module.types.insert(
885                                    Type {
886                                        name: None,
887                                        inner: TypeInner::Matrix {
888                                            columns,
889                                            rows,
890                                            scalar: left_scalar,
891                                        },
892                                    },
893                                    Span::default(),
894                                );
895
896                                // Rebuild the matrix from the operation result vectors
897                                self.add_expression(Expression::Compose { ty, components }, meta)?
898                            }
899                            _ => {
900                                self.add_expression(Expression::Binary { left, op, right }, meta)?
901                            }
902                        }
903                    }
904                    (
905                        &TypeInner::Matrix {
906                            rows,
907                            columns,
908                            scalar: left_scalar,
909                        },
910                        &TypeInner::Scalar(right_scalar),
911                    ) => {
912                        // Check that the two arguments have the same scalar type
913                        if left_scalar != right_scalar {
914                            frontend.errors.push(Error {
915                                kind: ErrorKind::SemanticError(
916                                    format!(
917                                        "Cannot apply operation to {left_inner:?} and {right_inner:?}"
918                                    )
919                                    .into(),
920                                ),
921                                meta,
922                            })
923                        }
924
925                        match op {
926                            BinaryOperator::Divide
927                            | BinaryOperator::Add
928                            | BinaryOperator::Subtract => {
929                                // Naga IR doesn't support all matrix by scalar operations so
930                                // we need for some to turn the scalar into a vector by
931                                // splatting it and then for each column vector apply the
932                                // operation and finally reconstruct the matrix
933
934                                let scalar_vector = self.add_expression(
935                                    Expression::Splat {
936                                        size: rows,
937                                        value: right,
938                                    },
939                                    meta,
940                                )?;
941
942                                let mut components = Vec::with_capacity(columns as usize);
943
944                                for index in 0..columns as u32 {
945                                    // Get the column vector
946                                    let matrix_column = self.add_expression(
947                                        Expression::AccessIndex { base: left, index },
948                                        meta,
949                                    )?;
950
951                                    // Apply the operation to the splatted vector and
952                                    // the column vector
953                                    let column = self.add_expression(
954                                        Expression::Binary {
955                                            op,
956                                            left: matrix_column,
957                                            right: scalar_vector,
958                                        },
959                                        meta,
960                                    )?;
961
962                                    components.push(column)
963                                }
964
965                                let ty = self.module.types.insert(
966                                    Type {
967                                        name: None,
968                                        inner: TypeInner::Matrix {
969                                            columns,
970                                            rows,
971                                            scalar: left_scalar,
972                                        },
973                                    },
974                                    Span::default(),
975                                );
976
977                                // Rebuild the matrix from the operation result vectors
978                                self.add_expression(Expression::Compose { ty, components }, meta)?
979                            }
980                            _ => {
981                                self.add_expression(Expression::Binary { left, op, right }, meta)?
982                            }
983                        }
984                    }
985                    _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
986                }
987            }
988            HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => {
989                let expr = self
990                    .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs)?
991                    .0;
992
993                self.add_expression(Expression::Unary { op, expr }, meta)?
994            }
995            HirExprKind::Variable(ref var) => match pos {
996                ExprPos::Lhs => {
997                    if !var.mutable {
998                        frontend.errors.push(Error {
999                            kind: ErrorKind::SemanticError(
1000                                "Variable cannot be used in LHS position".into(),
1001                            ),
1002                            meta,
1003                        })
1004                    }
1005
1006                    var.expr
1007                }
1008                ExprPos::AccessBase { constant_index } => {
1009                    // If the index isn't constant all accesses backed by a constant base need
1010                    // to be done through a proxy local variable, since constants have a non
1011                    // pointer type which is required for dynamic indexing
1012                    if !constant_index {
1013                        if let Some((constant, ty)) = var.constant {
1014                            let init = self
1015                                .add_expression(Expression::Constant(constant), Span::default())?;
1016                            let local = self.locals.append(
1017                                LocalVariable {
1018                                    name: None,
1019                                    ty,
1020                                    init: Some(init),
1021                                },
1022                                Span::default(),
1023                            );
1024
1025                            self.add_expression(Expression::LocalVariable(local), Span::default())?
1026                        } else {
1027                            var.expr
1028                        }
1029                    } else {
1030                        var.expr
1031                    }
1032                }
1033                _ if var.load => {
1034                    self.add_expression(Expression::Load { pointer: var.expr }, meta)?
1035                }
1036                ExprPos::Rhs => {
1037                    if let Some((constant, _)) = self.is_const.then_some(var.constant).flatten() {
1038                        self.add_expression(Expression::Constant(constant), meta)?
1039                    } else {
1040                        var.expr
1041                    }
1042                }
1043            },
1044            HirExprKind::Call(ref call) if pos != ExprPos::Lhs => {
1045                let maybe_expr = frontend.function_or_constructor_call(
1046                    self,
1047                    stmt,
1048                    call.kind.clone(),
1049                    &call.args,
1050                    meta,
1051                )?;
1052                return Ok((maybe_expr, meta));
1053            }
            // `HirExprKind::Conditional` represents the ternary operator in glsl (`?:`)
            //
            // The ternary operator is defined to only evaluate one of the two possible
            // expressions, which means that its behavior is that of an `if` statement,
            // and it's merely syntactic sugar for it.
1059            HirExprKind::Conditional {
1060                condition,
1061                accept,
1062                reject,
1063            } if ExprPos::Lhs != pos => {
1064                // Given an expression `a ? b : c`, we need to produce a Naga
1065                // statement roughly like:
1066                //
1067                //     var temp;
1068                //     if a {
1069                //         temp = convert(b);
1070                //     } else  {
1071                //         temp = convert(c);
1072                //     }
1073                //
1074                // where `convert` stands for type conversions to bring `b` and `c` to
1075                // the same type, and then use `temp` to represent the value of the whole
1076                // conditional expression in subsequent code.
1077
                // Lower the condition first into the current body
1079                let condition = self
1080                    .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs)?
1081                    .0;
1082
1083                let (mut accept_body, (mut accept, accept_meta)) =
1084                    self.new_body_with_ret(|ctx| {
1085                        // Lower the `true` branch
1086                        ctx.lower_expect_inner(stmt, frontend, accept, pos)
1087                    })?;
1088
1089                let (mut reject_body, (mut reject, reject_meta)) =
1090                    self.new_body_with_ret(|ctx| {
1091                        // Lower the `false` branch
1092                        ctx.lower_expect_inner(stmt, frontend, reject, pos)
1093                    })?;
1094
1095                // We need to do some custom implicit conversions since the two target expressions
1096                // are in different bodies
1097                if let (Some((accept_power, accept_scalar)), Some((reject_power, reject_scalar))) = (
1098                    // Get the components of both branches and calculate the type power
1099                    self.expr_scalar_components(accept, accept_meta)?
1100                        .and_then(|scalar| Some((type_power(scalar)?, scalar))),
1101                    self.expr_scalar_components(reject, reject_meta)?
1102                        .and_then(|scalar| Some((type_power(scalar)?, scalar))),
1103                ) {
1104                    match accept_power.cmp(&reject_power) {
1105                        core::cmp::Ordering::Less => {
1106                            accept_body = self.with_body(accept_body, |ctx| {
1107                                ctx.conversion(&mut accept, accept_meta, reject_scalar)?;
1108                                Ok(())
1109                            })?;
1110                        }
1111                        core::cmp::Ordering::Equal => {}
1112                        core::cmp::Ordering::Greater => {
1113                            reject_body = self.with_body(reject_body, |ctx| {
1114                                ctx.conversion(&mut reject, reject_meta, accept_scalar)?;
1115                                Ok(())
1116                            })?;
1117                        }
1118                    }
1119                }
1120
1121                // We need to get the type of the resulting expression to create the local,
1122                // this must be done after implicit conversions to ensure both branches have
1123                // the same type.
1124                let ty = self.resolve_type_handle(accept, accept_meta)?;
1125
1126                // Add the local that will hold the result of our conditional
1127                let local = self.locals.append(
1128                    LocalVariable {
1129                        name: None,
1130                        ty,
1131                        init: None,
1132                    },
1133                    meta,
1134                );
1135
1136                let local_expr = self.add_expression(Expression::LocalVariable(local), meta)?;
1137
                // Add to each branch the store to the result variable
1139                accept_body.push(
1140                    Statement::Store {
1141                        pointer: local_expr,
1142                        value: accept,
1143                    },
1144                    accept_meta,
1145                );
1146                reject_body.push(
1147                    Statement::Store {
1148                        pointer: local_expr,
1149                        value: reject,
1150                    },
1151                    reject_meta,
1152                );
1153
1154                // Finally add the `If` to the main body with the `condition` we lowered
1155                // earlier and the branches we prepared.
1156                self.body.push(
1157                    Statement::If {
1158                        condition,
1159                        accept: accept_body,
1160                        reject: reject_body,
1161                    },
1162                    meta,
1163                );
1164
1165                // Note: `Expression::Load` must be emitted before it's used so make
1166                // sure the emitter is active here.
1167                self.add_expression(
1168                    Expression::Load {
1169                        pointer: local_expr,
1170                    },
1171                    meta,
1172                )?
1173            }
1174            HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => {
1175                let (pointer, ptr_meta) =
1176                    self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs)?;
1177                let (mut value, value_meta) =
1178                    self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs)?;
1179
1180                let ty = match *self.resolve_type(pointer, ptr_meta)? {
1181                    TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
1182                    ref ty => ty,
1183                };
1184
1185                if let Some(scalar) = scalar_components(ty) {
1186                    self.implicit_conversion(&mut value, value_meta, scalar)?;
1187                }
1188
1189                self.lower_store(pointer, value, meta)?;
1190
1191                value
1192            }
1193            HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => {
1194                let (pointer, _) = self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs)?;
1195                let left = if let Expression::Swizzle { .. } = self.expressions[pointer] {
1196                    pointer
1197                } else {
1198                    self.add_expression(Expression::Load { pointer }, meta)?
1199                };
1200
1201                let res = match *self.resolve_type(left, meta)? {
1202                    TypeInner::Scalar(scalar) => {
1203                        let ty = TypeInner::Scalar(scalar);
1204                        Literal::one(scalar).map(|i| (ty, i, None, None))
1205                    }
1206                    TypeInner::Vector { size, scalar } => {
1207                        let ty = TypeInner::Vector { size, scalar };
1208                        Literal::one(scalar).map(|i| (ty, i, Some(size), None))
1209                    }
1210                    TypeInner::Matrix {
1211                        columns,
1212                        rows,
1213                        scalar,
1214                    } => {
1215                        let ty = TypeInner::Matrix {
1216                            columns,
1217                            rows,
1218                            scalar,
1219                        };
1220                        Literal::one(scalar).map(|i| (ty, i, Some(rows), Some(columns)))
1221                    }
1222                    _ => None,
1223                };
1224                let (ty_inner, literal, rows, columns) = match res {
1225                    Some(res) => res,
1226                    None => {
1227                        frontend.errors.push(Error {
1228                            kind: ErrorKind::SemanticError(
1229                                "Increment/decrement only works on scalar/vector/matrix".into(),
1230                            ),
1231                            meta,
1232                        });
1233                        return Ok((Some(left), meta));
1234                    }
1235                };
1236
1237                let mut right = self.add_expression(Expression::Literal(literal), meta)?;
1238
1239                // Glsl allows pre/postfixes operations on vectors and matrices, so if the
1240                // target is either of them change the right side of the addition to be splatted
1241                // to the same size as the target, furthermore if the target is a matrix
1242                // use a composed matrix using the splatted value.
1243                if let Some(size) = rows {
1244                    right = self.add_expression(Expression::Splat { size, value: right }, meta)?;
1245
1246                    if let Some(cols) = columns {
1247                        let ty = self.module.types.insert(
1248                            Type {
1249                                name: None,
1250                                inner: ty_inner,
1251                            },
1252                            meta,
1253                        );
1254
1255                        right = self.add_expression(
1256                            Expression::Compose {
1257                                ty,
1258                                components: core::iter::repeat_n(right, cols as usize).collect(),
1259                            },
1260                            meta,
1261                        )?;
1262                    }
1263                }
1264
1265                let value = self.add_expression(Expression::Binary { op, left, right }, meta)?;
1266
1267                self.lower_store(pointer, value, meta)?;
1268
1269                if postfix {
1270                    left
1271                } else {
1272                    value
1273                }
1274            }
1275            HirExprKind::Method {
1276                expr: object,
1277                ref name,
1278                ref args,
1279            } if ExprPos::Lhs != pos => {
1280                let args = args
1281                    .iter()
1282                    .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs))
1283                    .collect::<Result<Vec<_>>>()?;
1284                match name.as_ref() {
1285                    "length" => {
1286                        if !args.is_empty() {
1287                            frontend.errors.push(Error {
1288                                kind: ErrorKind::SemanticError(
1289                                    ".length() doesn't take any arguments".into(),
1290                                ),
1291                                meta,
1292                            });
1293                        }
1294                        let lowered_array = self.lower_expect_inner(stmt, frontend, object, pos)?.0;
1295                        let array_type = self.resolve_type(lowered_array, meta)?;
1296
1297                        match *array_type {
1298                            TypeInner::Array {
1299                                size: crate::ArraySize::Constant(size),
1300                                ..
1301                            } => {
1302                                let mut array_length = self.add_expression(
1303                                    Expression::Literal(Literal::U32(size.get())),
1304                                    meta,
1305                                )?;
1306                                self.forced_conversion(&mut array_length, meta, Scalar::I32)?;
1307                                array_length
1308                            }
1309                            // let the error be handled in type checking if it's not a dynamic array
1310                            _ => {
1311                                let mut array_length = self
1312                                    .add_expression(Expression::ArrayLength(lowered_array), meta)?;
1313                                self.conversion(&mut array_length, meta, Scalar::I32)?;
1314                                array_length
1315                            }
1316                        }
1317                    }
1318                    _ => {
1319                        return Err(Error {
1320                            kind: ErrorKind::SemanticError(
1321                                format!("unknown method '{name}'").into(),
1322                            ),
1323                            meta,
1324                        });
1325                    }
1326                }
1327            }
1328            _ => {
1329                return Err(Error {
1330                    kind: ErrorKind::SemanticError(
1331                        format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr])
1332                            .into(),
1333                    ),
1334                    meta,
1335                })
1336            }
1337        };
1338
1339        log::trace!("Lowered {expr:?}\n\tKind = {kind:?}\n\tPos = {pos:?}\n\tResult = {handle:?}");
1340
1341        Ok((Some(handle), meta))
1342    }
1343
1344    pub fn expr_scalar_components(
1345        &mut self,
1346        expr: Handle<Expression>,
1347        meta: Span,
1348    ) -> Result<Option<Scalar>> {
1349        let ty = self.resolve_type(expr, meta)?;
1350        Ok(scalar_components(ty))
1351    }
1352
1353    pub fn expr_power(&mut self, expr: Handle<Expression>, meta: Span) -> Result<Option<u32>> {
1354        Ok(self
1355            .expr_scalar_components(expr, meta)?
1356            .and_then(type_power))
1357    }
1358
1359    pub fn conversion(
1360        &mut self,
1361        expr: &mut Handle<Expression>,
1362        meta: Span,
1363        scalar: Scalar,
1364    ) -> Result<()> {
1365        *expr = self.add_expression(
1366            Expression::As {
1367                expr: *expr,
1368                kind: scalar.kind,
1369                convert: Some(scalar.width),
1370            },
1371            meta,
1372        )?;
1373
1374        Ok(())
1375    }
1376
1377    pub fn implicit_conversion(
1378        &mut self,
1379        expr: &mut Handle<Expression>,
1380        meta: Span,
1381        scalar: Scalar,
1382    ) -> Result<()> {
1383        if let (Some(tgt_power), Some(expr_power)) =
1384            (type_power(scalar), self.expr_power(*expr, meta)?)
1385        {
1386            if tgt_power > expr_power {
1387                self.conversion(expr, meta, scalar)?;
1388            }
1389        }
1390
1391        Ok(())
1392    }
1393
1394    pub fn forced_conversion(
1395        &mut self,
1396        expr: &mut Handle<Expression>,
1397        meta: Span,
1398        scalar: Scalar,
1399    ) -> Result<()> {
1400        if let Some(expr_scalar) = self.expr_scalar_components(*expr, meta)? {
1401            if expr_scalar != scalar {
1402                self.conversion(expr, meta, scalar)?;
1403            }
1404        }
1405
1406        Ok(())
1407    }
1408
1409    pub fn binary_implicit_conversion(
1410        &mut self,
1411        left: &mut Handle<Expression>,
1412        left_meta: Span,
1413        right: &mut Handle<Expression>,
1414        right_meta: Span,
1415    ) -> Result<()> {
1416        let left_components = self.expr_scalar_components(*left, left_meta)?;
1417        let right_components = self.expr_scalar_components(*right, right_meta)?;
1418
1419        if let (Some((left_power, left_scalar)), Some((right_power, right_scalar))) = (
1420            left_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1421            right_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1422        ) {
1423            match left_power.cmp(&right_power) {
1424                core::cmp::Ordering::Less => {
1425                    self.conversion(left, left_meta, right_scalar)?;
1426                }
1427                core::cmp::Ordering::Equal => {}
1428                core::cmp::Ordering::Greater => {
1429                    self.conversion(right, right_meta, left_scalar)?;
1430                }
1431            }
1432        }
1433
1434        Ok(())
1435    }
1436
1437    pub fn implicit_splat(
1438        &mut self,
1439        expr: &mut Handle<Expression>,
1440        meta: Span,
1441        vector_size: Option<VectorSize>,
1442    ) -> Result<()> {
1443        let expr_type = self.resolve_type(*expr, meta)?;
1444
1445        if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) {
1446            *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta)?
1447        }
1448
1449        Ok(())
1450    }
1451
1452    pub fn vector_resize(
1453        &mut self,
1454        size: VectorSize,
1455        vector: Handle<Expression>,
1456        meta: Span,
1457    ) -> Result<Handle<Expression>> {
1458        self.add_expression(
1459            Expression::Swizzle {
1460                size,
1461                vector,
1462                pattern: crate::SwizzleComponent::XYZW,
1463            },
1464            meta,
1465        )
1466    }
1467}
1468
1469impl Index<Handle<Expression>> for Context<'_> {
1470    type Output = Expression;
1471
1472    fn index(&self, index: Handle<Expression>) -> &Self::Output {
1473        if self.is_const {
1474            &self.module.global_expressions[index]
1475        } else {
1476            &self.expressions[index]
1477        }
1478    }
1479}
1480
1481/// Helper struct passed when parsing expressions
1482///
1483/// This struct should only be obtained through [`stmt_ctx`](Context::stmt_ctx)
1484/// and only one of these may be active at any time per context.
1485#[derive(Debug)]
1486pub struct StmtContext {
1487    /// A arena of high level expressions which can be lowered through a
1488    /// [`Context`] to Naga's [`Expression`]s
1489    pub hir_exprs: Arena<HirExpr>,
1490}
1491
1492impl StmtContext {
1493    const fn new() -> Self {
1494        StmtContext {
1495            hir_exprs: Arena::new(),
1496        }
1497    }
1498}