naga/front/glsl/
context.rs

1use alloc::{format, string::String, vec::Vec};
2use core::ops::Index;
3
4use super::{
5    ast::{
6        GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier,
7        VariableReference,
8    },
9    error::{Error, ErrorKind},
10    types::{scalar_components, type_power},
11    Frontend, Result,
12};
13use crate::{
14    front::Typifier, proc::Emitter, proc::Layouter, AddressSpace, Arena, BinaryOperator, Block,
15    Expression, FastHashMap, FunctionArgument, Handle, Literal, LocalVariable, RelationalFunction,
16    Scalar, Span, Statement, Type, TypeInner, VectorSize,
17};
18
/// Where an expression sits relative to an assignment while it is being lowered.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ExprPos {
    /// Target of an assignment (its left hand side).
    Lhs,
    /// Value being read (the right hand side of an assignment).
    Rhs,
    /// Base of an indexing operation; tracked so that constant arrays can
    /// still be indexed with a dynamic index.
    AccessBase {
        /// Whether the index applied to this base is a constant.
        constant_index: bool,
    },
}

impl ExprPos {
    /// Computes the position to use for the base of an access expression.
    ///
    /// `Lhs` stays `Lhs`, and an `AccessBase` whose index is not constant is
    /// kept unchanged. Every other position becomes an `AccessBase` that
    /// carries the given `constant_index`.
    const fn maybe_access_base(&self, constant_index: bool) -> Self {
        if let ExprPos::Lhs
        | ExprPos::AccessBase {
            constant_index: false,
        } = *self
        {
            return *self;
        }
        ExprPos::AccessBase { constant_index }
    }
}
46
/// Lowering state used while building the body of a [`crate::Function`]
/// (or, when `is_const` is set, constant expressions at module level).
#[derive(Debug)]
pub(crate) struct Context<'a> {
    /// Expression arena of the function being built; `add_expression` hands it
    /// to the function-level constant evaluator when `is_const` is false.
    pub expressions: Arena<Expression>,
    /// Local variables of the function being built, e.g. the copies that
    /// `add_function_arg` makes of mutable by-value parameters.
    pub locals: Arena<LocalVariable>,

    /// The [`FunctionArgument`]s for the final [`crate::Function`].
    ///
    /// Parameters with the `out` and `inout` qualifiers have [`Pointer`] types
    /// here. For example, an `inout vec2 a` argument would be a [`Pointer`] to
    /// a [`Vector`].
    ///
    /// [`Pointer`]: crate::TypeInner::Pointer
    /// [`Vector`]: crate::TypeInner::Vector
    pub arguments: Vec<FunctionArgument>,

    /// The parameter types given in the source code.
    ///
    /// The `out` and `inout` qualifiers don't affect the types that appear
    /// here. For example, an `inout vec2 a` argument would simply be a
    /// [`Vector`], not a pointer to one.
    ///
    /// [`Vector`]: crate::TypeInner::Vector
    pub parameters: Vec<Handle<Type>>,
    /// Qualifier information for each parameter, pushed by `add_function_arg`
    /// alongside the matching entry of `parameters`.
    pub parameters_info: Vec<ParameterInfo>,

    /// Names currently in scope; filled by `add_global`, `add_local_var` and
    /// `add_function_arg`.
    pub symbol_table: crate::front::SymbolTable<String, VariableReference>,
    /// NOTE(review): maps one expression handle to another; presumably an
    /// image expression to the sampler it is used with. It is not populated
    /// in the visible code, so confirm this against its users.
    pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>,

    /// NOTE(review): typifiers, presumably for constant and for function
    /// expressions respectively. Neither is accessed directly in the visible
    /// code, so confirm the split.
    pub const_typifier: Typifier,
    pub typifier: Typifier,
    /// Passed to the constant evaluator by `add_expression`.
    layouter: Layouter,
    /// Started by `emit_start`; `emit_end` appends what it captured to `body`.
    emitter: Emitter,
    /// The reusable [`StmtContext`]. It is taken by `stmt_ctx()`, put back by
    /// `lower`/`lower_expect`, and is `None` while checked out.
    stmt_ctx: Option<StmtContext>,
    /// Block that currently receives emitted statements; swapped out by
    /// `new_body`, `new_body_with_ret` and `with_body`.
    pub body: Block,
    /// The module being built. Types, globals, constants and overrides are
    /// looked up in, or inserted into, this module.
    pub module: &'a mut crate::Module,
    /// When true, `add_expression` uses the module-level (`for_glsl_module`)
    /// constant evaluator instead of the function-level one.
    pub is_const: bool,
    /// Tracks the expression kind of `Expression`s residing in `self.expressions`
    pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker,
    /// Tracks the expression kind of `Expression`s residing in `self.module.global_expressions`
    pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
}
88
89impl<'a> Context<'a> {
90    pub fn new(
91        frontend: &Frontend,
92        module: &'a mut crate::Module,
93        is_const: bool,
94        global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
95    ) -> Result<Self> {
96        let mut this = Context {
97            expressions: Arena::new(),
98            locals: Arena::new(),
99            arguments: Vec::new(),
100
101            parameters: Vec::new(),
102            parameters_info: Vec::new(),
103
104            symbol_table: crate::front::SymbolTable::default(),
105            samplers: FastHashMap::default(),
106
107            const_typifier: Typifier::new(),
108            typifier: Typifier::new(),
109            layouter: Layouter::default(),
110            emitter: Emitter::default(),
111            stmt_ctx: Some(StmtContext::new()),
112            body: Block::new(),
113            module,
114            is_const: false,
115            local_expression_kind_tracker: crate::proc::ExpressionKindTracker::new(),
116            global_expression_kind_tracker,
117        };
118
119        this.emit_start();
120
121        for &(ref name, lookup) in frontend.global_variables.iter() {
122            this.add_global(name, lookup)?
123        }
124        this.is_const = is_const;
125
126        Ok(this)
127    }
128
129    pub fn new_body<F>(&mut self, cb: F) -> Result<Block>
130    where
131        F: FnOnce(&mut Self) -> Result<()>,
132    {
133        self.new_body_with_ret(cb).map(|(b, _)| b)
134    }
135
136    pub fn new_body_with_ret<F, R>(&mut self, cb: F) -> Result<(Block, R)>
137    where
138        F: FnOnce(&mut Self) -> Result<R>,
139    {
140        self.emit_restart();
141        let old_body = core::mem::replace(&mut self.body, Block::new());
142        let res = cb(self);
143        self.emit_restart();
144        let new_body = core::mem::replace(&mut self.body, old_body);
145        res.map(|r| (new_body, r))
146    }
147
148    pub fn with_body<F>(&mut self, body: Block, cb: F) -> Result<Block>
149    where
150        F: FnOnce(&mut Self) -> Result<()>,
151    {
152        self.emit_restart();
153        let old_body = core::mem::replace(&mut self.body, body);
154        let res = cb(self);
155        self.emit_restart();
156        let body = core::mem::replace(&mut self.body, old_body);
157        res.map(|_| body)
158    }
159
160    pub fn add_global(
161        &mut self,
162        name: &str,
163        GlobalLookup {
164            kind,
165            entry_arg,
166            mutable,
167        }: GlobalLookup,
168    ) -> Result<()> {
169        let (expr, load, constant) = match kind {
170            GlobalLookupKind::Variable(v) => {
171                let span = self.module.global_variables.get_span(v);
172                (
173                    self.add_expression(Expression::GlobalVariable(v), span)?,
174                    self.module.global_variables[v].space != AddressSpace::Handle,
175                    None,
176                )
177            }
178            GlobalLookupKind::BlockSelect(handle, index) => {
179                let span = self.module.global_variables.get_span(handle);
180                let base = self.add_expression(Expression::GlobalVariable(handle), span)?;
181                let expr = self.add_expression(Expression::AccessIndex { base, index }, span)?;
182
183                (
184                    expr,
185                    {
186                        let ty = self.module.global_variables[handle].ty;
187
188                        match self.module.types[ty].inner {
189                            TypeInner::Struct { ref members, .. } => {
190                                if let TypeInner::Array {
191                                    size: crate::ArraySize::Dynamic,
192                                    ..
193                                } = self.module.types[members[index as usize].ty].inner
194                                {
195                                    false
196                                } else {
197                                    true
198                                }
199                            }
200                            _ => true,
201                        }
202                    },
203                    None,
204                )
205            }
206            GlobalLookupKind::Constant(v, ty) => {
207                let span = self.module.constants.get_span(v);
208                (
209                    self.add_expression(Expression::Constant(v), span)?,
210                    false,
211                    Some((v, ty)),
212                )
213            }
214            GlobalLookupKind::Override(v, _ty) => {
215                let span = self.module.overrides.get_span(v);
216                (
217                    self.add_expression(Expression::Override(v), span)?,
218                    false,
219                    None,
220                )
221            }
222        };
223
224        let var = VariableReference {
225            expr,
226            load,
227            mutable,
228            constant,
229            entry_arg,
230        };
231
232        self.symbol_table.add(name.into(), var);
233
234        Ok(())
235    }
236
237    /// Starts the expression emitter
238    ///
239    /// # Panics
240    ///
241    /// - If called twice in a row without calling [`emit_end`][Self::emit_end].
242    #[inline]
243    pub fn emit_start(&mut self) {
244        self.emitter.start(&self.expressions)
245    }
246
247    /// Emits all the expressions captured by the emitter to the current body
248    ///
249    /// # Panics
250    ///
251    /// - If called before calling [`emit_start`].
252    /// - If called twice in a row without calling [`emit_start`].
253    ///
254    /// [`emit_start`]: Self::emit_start
255    pub fn emit_end(&mut self) {
256        self.body.extend(self.emitter.finish(&self.expressions))
257    }
258
259    /// Emits all the expressions captured by the emitter to the current body
260    /// and starts the emitter again
261    ///
262    /// # Panics
263    ///
264    /// - If called before calling [`emit_start`][Self::emit_start].
265    pub fn emit_restart(&mut self) {
266        self.emit_end();
267        self.emit_start()
268    }
269
270    pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
271        let mut eval = if self.is_const {
272            crate::proc::ConstantEvaluator::for_glsl_module(
273                self.module,
274                self.global_expression_kind_tracker,
275                &mut self.layouter,
276            )
277        } else {
278            crate::proc::ConstantEvaluator::for_glsl_function(
279                self.module,
280                &mut self.expressions,
281                &mut self.local_expression_kind_tracker,
282                &mut self.layouter,
283                &mut self.emitter,
284                &mut self.body,
285            )
286        };
287
288        eval.try_eval_and_append(expr, meta).map_err(|e| Error {
289            kind: e.into(),
290            meta,
291        })
292    }
293
294    /// Add variable to current scope
295    ///
296    /// Returns a variable if a variable with the same name was already defined,
297    /// otherwise returns `None`
298    pub fn add_local_var(
299        &mut self,
300        name: String,
301        expr: Handle<Expression>,
302        mutable: bool,
303    ) -> Option<VariableReference> {
304        let var = VariableReference {
305            expr,
306            load: true,
307            mutable,
308            constant: None,
309            entry_arg: None,
310        };
311
312        self.symbol_table.add(name, var)
313    }
314
315    /// Add function argument to current scope
316    pub fn add_function_arg(
317        &mut self,
318        name_meta: Option<(String, Span)>,
319        ty: Handle<Type>,
320        qualifier: ParameterQualifier,
321    ) -> Result<()> {
322        let index = self.arguments.len();
323        let mut arg = FunctionArgument {
324            name: name_meta.as_ref().map(|&(ref name, _)| name.clone()),
325            ty,
326            binding: None,
327        };
328        self.parameters.push(ty);
329
330        let opaque = match self.module.types[ty].inner {
331            TypeInner::Image { .. } | TypeInner::Sampler { .. } => true,
332            _ => false,
333        };
334
335        if qualifier.is_lhs() {
336            let span = self.module.types.get_span(arg.ty);
337            arg.ty = self.module.types.insert(
338                Type {
339                    name: None,
340                    inner: TypeInner::Pointer {
341                        base: arg.ty,
342                        space: AddressSpace::Function,
343                    },
344                },
345                span,
346            )
347        }
348
349        self.arguments.push(arg);
350
351        self.parameters_info.push(ParameterInfo {
352            qualifier,
353            depth: false,
354        });
355
356        if let Some((name, meta)) = name_meta {
357            let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta)?;
358            let mutable = qualifier != ParameterQualifier::Const && !opaque;
359            let load = qualifier.is_lhs();
360
361            let var = if mutable && !load {
362                let handle = self.locals.append(
363                    LocalVariable {
364                        name: Some(name.clone()),
365                        ty,
366                        init: None,
367                    },
368                    meta,
369                );
370                let local_expr = self.add_expression(Expression::LocalVariable(handle), meta)?;
371
372                self.emit_restart();
373
374                self.body.push(
375                    Statement::Store {
376                        pointer: local_expr,
377                        value: expr,
378                    },
379                    meta,
380                );
381
382                VariableReference {
383                    expr: local_expr,
384                    load: true,
385                    mutable,
386                    constant: None,
387                    entry_arg: None,
388                }
389            } else {
390                VariableReference {
391                    expr,
392                    load,
393                    mutable,
394                    constant: None,
395                    entry_arg: None,
396                }
397            };
398
399            self.symbol_table.add(name, var);
400        }
401
402        Ok(())
403    }
404
405    /// Returns a [`StmtContext`] to be used in parsing and lowering
406    ///
407    /// # Panics
408    ///
409    /// - If more than one [`StmtContext`] are active at the same time or if the
410    ///   previous call didn't use it in lowering.
411    #[must_use]
412    pub const fn stmt_ctx(&mut self) -> StmtContext {
413        self.stmt_ctx.take().unwrap()
414    }
415
416    /// Lowers a [`HirExpr`] which might produce a [`Expression`].
417    ///
418    /// consumes a [`StmtContext`] returning it to the context so that it can be
419    /// used again later.
420    pub fn lower(
421        &mut self,
422        mut stmt: StmtContext,
423        frontend: &mut Frontend,
424        expr: Handle<HirExpr>,
425        pos: ExprPos,
426    ) -> Result<(Option<Handle<Expression>>, Span)> {
427        let res = self.lower_inner(&stmt, frontend, expr, pos);
428
429        stmt.hir_exprs.clear();
430        self.stmt_ctx = Some(stmt);
431
432        res
433    }
434
435    /// Similar to [`lower`](Self::lower) but returns an error if the expression
436    /// returns void (ie. doesn't produce a [`Expression`]).
437    ///
438    /// consumes a [`StmtContext`] returning it to the context so that it can be
439    /// used again later.
440    pub fn lower_expect(
441        &mut self,
442        mut stmt: StmtContext,
443        frontend: &mut Frontend,
444        expr: Handle<HirExpr>,
445        pos: ExprPos,
446    ) -> Result<(Handle<Expression>, Span)> {
447        let res = self.lower_expect_inner(&stmt, frontend, expr, pos);
448
449        stmt.hir_exprs.clear();
450        self.stmt_ctx = Some(stmt);
451
452        res
453    }
454
455    /// internal implementation of [`lower_expect`](Self::lower_expect)
456    ///
457    /// this method is only public because it's used in
458    /// [`function_call`](Frontend::function_call), unless you know what
459    /// you're doing use [`lower_expect`](Self::lower_expect)
460    pub fn lower_expect_inner(
461        &mut self,
462        stmt: &StmtContext,
463        frontend: &mut Frontend,
464        expr: Handle<HirExpr>,
465        pos: ExprPos,
466    ) -> Result<(Handle<Expression>, Span)> {
467        let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos)?;
468
469        let expr = match maybe_expr {
470            Some(e) => e,
471            None => {
472                return Err(Error {
473                    kind: ErrorKind::SemanticError("Expression returns void".into()),
474                    meta,
475                })
476            }
477        };
478
479        Ok((expr, meta))
480    }
481
482    fn lower_store(
483        &mut self,
484        pointer: Handle<Expression>,
485        value: Handle<Expression>,
486        meta: Span,
487    ) -> Result<()> {
488        if let Expression::Swizzle {
489            size,
490            mut vector,
491            pattern,
492        } = self.expressions[pointer]
493        {
494            // Stores to swizzled values are not directly supported,
495            // lower them as series of per-component stores.
496            let size = match size {
497                VectorSize::Bi => 2,
498                VectorSize::Tri => 3,
499                VectorSize::Quad => 4,
500            };
501
502            if let Expression::Load { pointer } = self.expressions[vector] {
503                vector = pointer;
504            }
505
506            #[allow(clippy::needless_range_loop)]
507            for index in 0..size {
508                let dst = self.add_expression(
509                    Expression::AccessIndex {
510                        base: vector,
511                        index: pattern[index].index(),
512                    },
513                    meta,
514                )?;
515                let src = self.add_expression(
516                    Expression::AccessIndex {
517                        base: value,
518                        index: index as u32,
519                    },
520                    meta,
521                )?;
522
523                self.emit_restart();
524
525                self.body.push(
526                    Statement::Store {
527                        pointer: dst,
528                        value: src,
529                    },
530                    meta,
531                );
532            }
533        } else {
534            self.emit_restart();
535
536            self.body.push(Statement::Store { pointer, value }, meta);
537        }
538
539        Ok(())
540    }
541
542    /// Internal implementation of [`lower`](Self::lower)
543    fn lower_inner(
544        &mut self,
545        stmt: &StmtContext,
546        frontend: &mut Frontend,
547        expr: Handle<HirExpr>,
548        pos: ExprPos,
549    ) -> Result<(Option<Handle<Expression>>, Span)> {
550        let HirExpr { ref kind, meta } = stmt.hir_exprs[expr];
551
552        log::debug!("Lowering {expr:?} (kind {kind:?}, pos {pos:?})");
553
554        let handle = match *kind {
555            HirExprKind::Access { base, index } => {
556                let (index, _) = self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs)?;
557                let maybe_constant_index = match pos {
558                    // Don't try to generate `AccessIndex` if in a LHS position, since it
559                    // wouldn't produce a pointer.
560                    ExprPos::Lhs => None,
561                    _ => self
562                        .module
563                        .to_ctx()
564                        .get_const_val_from(index, &self.expressions)
565                        .ok(),
566                };
567
568                let base = self
569                    .lower_expect_inner(
570                        stmt,
571                        frontend,
572                        base,
573                        pos.maybe_access_base(maybe_constant_index.is_some()),
574                    )?
575                    .0;
576
577                let pointer = maybe_constant_index
578                    .map(|index| self.add_expression(Expression::AccessIndex { base, index }, meta))
579                    .unwrap_or_else(|| {
580                        self.add_expression(Expression::Access { base, index }, meta)
581                    })?;
582
583                if ExprPos::Rhs == pos {
584                    let resolved = self.resolve_type(pointer, meta)?;
585                    if resolved.pointer_space().is_some() {
586                        return Ok((
587                            Some(self.add_expression(Expression::Load { pointer }, meta)?),
588                            meta,
589                        ));
590                    }
591                }
592
593                pointer
594            }
595            HirExprKind::Select { base, ref field } => {
596                let base = self.lower_expect_inner(stmt, frontend, base, pos)?.0;
597
598                frontend.field_selection(self, pos, base, field, meta)?
599            }
600            HirExprKind::Literal(literal) if pos != ExprPos::Lhs => {
601                self.add_expression(Expression::Literal(literal), meta)?
602            }
603            HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
604                let (mut left, left_meta) =
605                    self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs)?;
606                let (mut right, right_meta) =
607                    self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs)?;
608
609                match op {
610                    BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
611                        self.implicit_conversion(&mut right, right_meta, Scalar::U32)?
612                    }
613                    _ => self
614                        .binary_implicit_conversion(&mut left, left_meta, &mut right, right_meta)?,
615                }
616
617                self.typifier_grow(left, left_meta)?;
618                self.typifier_grow(right, right_meta)?;
619
620                let left_inner = self.get_type(left);
621                let right_inner = self.get_type(right);
622
623                match (left_inner, right_inner) {
624                    (
625                        &TypeInner::Matrix {
626                            columns: left_columns,
627                            rows: left_rows,
628                            scalar: left_scalar,
629                        },
630                        &TypeInner::Matrix {
631                            columns: right_columns,
632                            rows: right_rows,
633                            scalar: right_scalar,
634                        },
635                    ) => {
636                        let dimensions_ok = if op == BinaryOperator::Multiply {
637                            left_columns == right_rows
638                        } else {
639                            left_columns == right_columns && left_rows == right_rows
640                        };
641
642                        // Check that the two arguments have the same dimensions
643                        if !dimensions_ok || left_scalar != right_scalar {
644                            frontend.errors.push(Error {
645                                kind: ErrorKind::SemanticError(
646                                    format!(
647                                        "Cannot apply operation to {left_inner:?} and {right_inner:?}"
648                                    )
649                                    .into(),
650                                ),
651                                meta,
652                            })
653                        }
654
655                        match op {
656                            BinaryOperator::Divide => {
657                                // Naga IR doesn't support matrix division so we need to
658                                // divide the columns individually and reassemble the matrix
659                                let mut components = Vec::with_capacity(left_columns as usize);
660
661                                for index in 0..left_columns as u32 {
662                                    // Get the column vectors
663                                    let left_vector = self.add_expression(
664                                        Expression::AccessIndex { base: left, index },
665                                        meta,
666                                    )?;
667                                    let right_vector = self.add_expression(
668                                        Expression::AccessIndex { base: right, index },
669                                        meta,
670                                    )?;
671
672                                    // Divide the vectors
673                                    let column = self.add_expression(
674                                        Expression::Binary {
675                                            op,
676                                            left: left_vector,
677                                            right: right_vector,
678                                        },
679                                        meta,
680                                    )?;
681
682                                    components.push(column)
683                                }
684
685                                let ty = self.module.types.insert(
686                                    Type {
687                                        name: None,
688                                        inner: TypeInner::Matrix {
689                                            columns: left_columns,
690                                            rows: left_rows,
691                                            scalar: left_scalar,
692                                        },
693                                    },
694                                    Span::default(),
695                                );
696
697                                // Rebuild the matrix from the divided vectors
698                                self.add_expression(Expression::Compose { ty, components }, meta)?
699                            }
700                            BinaryOperator::Equal | BinaryOperator::NotEqual => {
701                                // Naga IR doesn't support matrix comparisons so we need to
702                                // compare the columns individually and then fold them together
703                                //
704                                // The folding is done using a logical and for equality and
705                                // a logical or for inequality
706                                let equals = op == BinaryOperator::Equal;
707
708                                let (op, combine, fun) = match equals {
709                                    true => (
710                                        BinaryOperator::Equal,
711                                        BinaryOperator::LogicalAnd,
712                                        RelationalFunction::All,
713                                    ),
714                                    false => (
715                                        BinaryOperator::NotEqual,
716                                        BinaryOperator::LogicalOr,
717                                        RelationalFunction::Any,
718                                    ),
719                                };
720
721                                let mut root = None;
722
723                                for index in 0..left_columns as u32 {
724                                    // Get the column vectors
725                                    let left_vector = self.add_expression(
726                                        Expression::AccessIndex { base: left, index },
727                                        meta,
728                                    )?;
729                                    let right_vector = self.add_expression(
730                                        Expression::AccessIndex { base: right, index },
731                                        meta,
732                                    )?;
733
734                                    let argument = self.add_expression(
735                                        Expression::Binary {
736                                            op,
737                                            left: left_vector,
738                                            right: right_vector,
739                                        },
740                                        meta,
741                                    )?;
742
743                                    // The result of comparing two vectors is a boolean vector
744                                    // so use a relational function like all to get a single
745                                    // boolean value
746                                    let compare = self.add_expression(
747                                        Expression::Relational { fun, argument },
748                                        meta,
749                                    )?;
750
751                                    // Fold the result
752                                    root = Some(match root {
753                                        Some(right) => self.add_expression(
754                                            Expression::Binary {
755                                                op: combine,
756                                                left: compare,
757                                                right,
758                                            },
759                                            meta,
760                                        )?,
761                                        None => compare,
762                                    });
763                                }
764
765                                root.unwrap()
766                            }
767                            _ => {
768                                self.add_expression(Expression::Binary { left, op, right }, meta)?
769                            }
770                        }
771                    }
772                    (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op {
773                        BinaryOperator::Equal | BinaryOperator::NotEqual => {
774                            let equals = op == BinaryOperator::Equal;
775
776                            let (op, fun) = match equals {
777                                true => (BinaryOperator::Equal, RelationalFunction::All),
778                                false => (BinaryOperator::NotEqual, RelationalFunction::Any),
779                            };
780
781                            let argument =
782                                self.add_expression(Expression::Binary { op, left, right }, meta)?;
783
784                            self.add_expression(Expression::Relational { fun, argument }, meta)?
785                        }
786                        _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
787                    },
788                    (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op {
789                        BinaryOperator::Add
790                        | BinaryOperator::Subtract
791                        | BinaryOperator::Divide
792                        | BinaryOperator::And
793                        | BinaryOperator::ExclusiveOr
794                        | BinaryOperator::InclusiveOr
795                        | BinaryOperator::ShiftLeft
796                        | BinaryOperator::ShiftRight => {
797                            let scalar_vector = self
798                                .add_expression(Expression::Splat { size, value: right }, meta)?;
799
800                            self.add_expression(
801                                Expression::Binary {
802                                    op,
803                                    left,
804                                    right: scalar_vector,
805                                },
806                                meta,
807                            )?
808                        }
809                        _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
810                    },
811                    (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op {
812                        BinaryOperator::Add
813                        | BinaryOperator::Subtract
814                        | BinaryOperator::Divide
815                        | BinaryOperator::And
816                        | BinaryOperator::ExclusiveOr
817                        | BinaryOperator::InclusiveOr => {
818                            let scalar_vector =
819                                self.add_expression(Expression::Splat { size, value: left }, meta)?;
820
821                            self.add_expression(
822                                Expression::Binary {
823                                    op,
824                                    left: scalar_vector,
825                                    right,
826                                },
827                                meta,
828                            )?
829                        }
830                        _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
831                    },
832                    (
833                        &TypeInner::Scalar(left_scalar),
834                        &TypeInner::Matrix {
835                            rows,
836                            columns,
837                            scalar: right_scalar,
838                        },
839                    ) => {
840                        // Check that the two arguments have the same scalar type
841                        if left_scalar != right_scalar {
842                            frontend.errors.push(Error {
843                                kind: ErrorKind::SemanticError(
844                                    format!(
845                                        "Cannot apply operation to {left_inner:?} and {right_inner:?}"
846                                    )
847                                    .into(),
848                                ),
849                                meta,
850                            })
851                        }
852
853                        match op {
854                            BinaryOperator::Divide
855                            | BinaryOperator::Add
856                            | BinaryOperator::Subtract => {
857                                // Naga IR doesn't support all matrix by scalar operations so
858                                // we need for some to turn the scalar into a vector by
859                                // splatting it and then for each column vector apply the
860                                // operation and finally reconstruct the matrix
861                                let scalar_vector = self.add_expression(
862                                    Expression::Splat {
863                                        size: rows,
864                                        value: left,
865                                    },
866                                    meta,
867                                )?;
868
869                                let mut components = Vec::with_capacity(columns as usize);
870
871                                for index in 0..columns as u32 {
872                                    // Get the column vector
873                                    let matrix_column = self.add_expression(
874                                        Expression::AccessIndex { base: right, index },
875                                        meta,
876                                    )?;
877
878                                    // Apply the operation to the splatted vector and
879                                    // the column vector
880                                    let column = self.add_expression(
881                                        Expression::Binary {
882                                            op,
883                                            left: scalar_vector,
884                                            right: matrix_column,
885                                        },
886                                        meta,
887                                    )?;
888
889                                    components.push(column)
890                                }
891
892                                let ty = self.module.types.insert(
893                                    Type {
894                                        name: None,
895                                        inner: TypeInner::Matrix {
896                                            columns,
897                                            rows,
898                                            scalar: left_scalar,
899                                        },
900                                    },
901                                    Span::default(),
902                                );
903
904                                // Rebuild the matrix from the operation result vectors
905                                self.add_expression(Expression::Compose { ty, components }, meta)?
906                            }
907                            _ => {
908                                self.add_expression(Expression::Binary { left, op, right }, meta)?
909                            }
910                        }
911                    }
912                    (
913                        &TypeInner::Matrix {
914                            rows,
915                            columns,
916                            scalar: left_scalar,
917                        },
918                        &TypeInner::Scalar(right_scalar),
919                    ) => {
920                        // Check that the two arguments have the same scalar type
921                        if left_scalar != right_scalar {
922                            frontend.errors.push(Error {
923                                kind: ErrorKind::SemanticError(
924                                    format!(
925                                        "Cannot apply operation to {left_inner:?} and {right_inner:?}"
926                                    )
927                                    .into(),
928                                ),
929                                meta,
930                            })
931                        }
932
933                        match op {
934                            BinaryOperator::Divide
935                            | BinaryOperator::Add
936                            | BinaryOperator::Subtract => {
937                                // Naga IR doesn't support all matrix by scalar operations so
938                                // we need for some to turn the scalar into a vector by
939                                // splatting it and then for each column vector apply the
940                                // operation and finally reconstruct the matrix
941
942                                let scalar_vector = self.add_expression(
943                                    Expression::Splat {
944                                        size: rows,
945                                        value: right,
946                                    },
947                                    meta,
948                                )?;
949
950                                let mut components = Vec::with_capacity(columns as usize);
951
952                                for index in 0..columns as u32 {
953                                    // Get the column vector
954                                    let matrix_column = self.add_expression(
955                                        Expression::AccessIndex { base: left, index },
956                                        meta,
957                                    )?;
958
959                                    // Apply the operation to the splatted vector and
960                                    // the column vector
961                                    let column = self.add_expression(
962                                        Expression::Binary {
963                                            op,
964                                            left: matrix_column,
965                                            right: scalar_vector,
966                                        },
967                                        meta,
968                                    )?;
969
970                                    components.push(column)
971                                }
972
973                                let ty = self.module.types.insert(
974                                    Type {
975                                        name: None,
976                                        inner: TypeInner::Matrix {
977                                            columns,
978                                            rows,
979                                            scalar: left_scalar,
980                                        },
981                                    },
982                                    Span::default(),
983                                );
984
985                                // Rebuild the matrix from the operation result vectors
986                                self.add_expression(Expression::Compose { ty, components }, meta)?
987                            }
988                            _ => {
989                                self.add_expression(Expression::Binary { left, op, right }, meta)?
990                            }
991                        }
992                    }
993                    _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
994                }
995            }
996            HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => {
997                let expr = self
998                    .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs)?
999                    .0;
1000
1001                if let TypeInner::Matrix { scalar, .. } = *self.resolve_type(expr, meta)? {
1002                    // Naga IR doesn't support matrix negation, so we need to turn it into
1003                    // multiplication by scalar -1.
1004                    let minus_one = Literal::minus_one(scalar).ok_or_else(|| Error {
1005                        kind: ErrorKind::SemanticError(
1006                            format!("Cannot apply operator {op:?} to type {scalar:?}").into(),
1007                        ),
1008                        meta,
1009                    })?;
1010                    let lhs = self.add_expression(Expression::Literal(minus_one), meta)?;
1011                    self.add_expression(
1012                        Expression::Binary {
1013                            op: BinaryOperator::Multiply,
1014                            left: lhs,
1015                            right: expr,
1016                        },
1017                        meta,
1018                    )?
1019                } else {
1020                    self.add_expression(Expression::Unary { op, expr }, meta)?
1021                }
1022            }
1023            HirExprKind::Variable(ref var) => match pos {
1024                ExprPos::Lhs => {
1025                    if !var.mutable {
1026                        frontend.errors.push(Error {
1027                            kind: ErrorKind::SemanticError(
1028                                "Variable cannot be used in LHS position".into(),
1029                            ),
1030                            meta,
1031                        })
1032                    }
1033
1034                    var.expr
1035                }
1036                ExprPos::AccessBase { constant_index } => {
1037                    // If the index isn't constant all accesses backed by a constant base need
1038                    // to be done through a proxy local variable, since constants have a non
1039                    // pointer type which is required for dynamic indexing
1040                    if !constant_index {
1041                        if let Some((constant, ty)) = var.constant {
1042                            let init = self
1043                                .add_expression(Expression::Constant(constant), Span::default())?;
1044                            let local = self.locals.append(
1045                                LocalVariable {
1046                                    name: None,
1047                                    ty,
1048                                    init: Some(init),
1049                                },
1050                                Span::default(),
1051                            );
1052
1053                            self.add_expression(Expression::LocalVariable(local), Span::default())?
1054                        } else {
1055                            var.expr
1056                        }
1057                    } else {
1058                        var.expr
1059                    }
1060                }
1061                _ if var.load => {
1062                    self.add_expression(Expression::Load { pointer: var.expr }, meta)?
1063                }
1064                ExprPos::Rhs => {
1065                    if let Some((constant, _)) = self.is_const.then_some(var.constant).flatten() {
1066                        self.add_expression(Expression::Constant(constant), meta)?
1067                    } else {
1068                        // Check if this is an Override expression in const context
1069                        if self.is_const {
1070                            if let Expression::Override(o) = self.expressions[var.expr] {
1071                                // Need to add the Override expression to the global arena
1072                                self.add_expression(Expression::Override(o), meta)?
1073                            } else {
1074                                var.expr
1075                            }
1076                        } else {
1077                            var.expr
1078                        }
1079                    }
1080                }
1081            },
1082            HirExprKind::Call(ref call) if pos != ExprPos::Lhs => {
1083                let maybe_expr = frontend.function_or_constructor_call(
1084                    self,
1085                    stmt,
1086                    call.kind.clone(),
1087                    &call.args,
1088                    meta,
1089                )?;
1090                return Ok((maybe_expr, meta));
1091            }
1092            // `HirExprKind::Conditional` represents the ternary operator in glsl (`:?`)
1093            //
1094            // The ternary operator is defined to only evaluate one of the two possible
1095            // expressions which means that it's behavior is that of an `if` statement,
1096            // and it's merely syntactic sugar for it.
1097            HirExprKind::Conditional {
1098                condition,
1099                accept,
1100                reject,
1101            } if ExprPos::Lhs != pos => {
1102                // Given an expression `a ? b : c`, we need to produce a Naga
1103                // statement roughly like:
1104                //
1105                //     var temp;
1106                //     if a {
1107                //         temp = convert(b);
1108                //     } else  {
1109                //         temp = convert(c);
1110                //     }
1111                //
1112                // where `convert` stands for type conversions to bring `b` and `c` to
1113                // the same type, and then use `temp` to represent the value of the whole
1114                // conditional expression in subsequent code.
1115
1116                // Lower the condition first to the current bodyy
1117                let condition = self
1118                    .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs)?
1119                    .0;
1120
1121                let (mut accept_body, (mut accept, accept_meta)) =
1122                    self.new_body_with_ret(|ctx| {
1123                        // Lower the `true` branch
1124                        ctx.lower_expect_inner(stmt, frontend, accept, pos)
1125                    })?;
1126
1127                let (mut reject_body, (mut reject, reject_meta)) =
1128                    self.new_body_with_ret(|ctx| {
1129                        // Lower the `false` branch
1130                        ctx.lower_expect_inner(stmt, frontend, reject, pos)
1131                    })?;
1132
1133                // We need to do some custom implicit conversions since the two target expressions
1134                // are in different bodies
1135                if let (Some((accept_power, accept_scalar)), Some((reject_power, reject_scalar))) = (
1136                    // Get the components of both branches and calculate the type power
1137                    self.expr_scalar_components(accept, accept_meta)?
1138                        .and_then(|scalar| Some((type_power(scalar)?, scalar))),
1139                    self.expr_scalar_components(reject, reject_meta)?
1140                        .and_then(|scalar| Some((type_power(scalar)?, scalar))),
1141                ) {
1142                    match accept_power.cmp(&reject_power) {
1143                        core::cmp::Ordering::Less => {
1144                            accept_body = self.with_body(accept_body, |ctx| {
1145                                ctx.conversion(&mut accept, accept_meta, reject_scalar)?;
1146                                Ok(())
1147                            })?;
1148                        }
1149                        core::cmp::Ordering::Equal => {}
1150                        core::cmp::Ordering::Greater => {
1151                            reject_body = self.with_body(reject_body, |ctx| {
1152                                ctx.conversion(&mut reject, reject_meta, accept_scalar)?;
1153                                Ok(())
1154                            })?;
1155                        }
1156                    }
1157                }
1158
1159                // We need to get the type of the resulting expression to create the local,
1160                // this must be done after implicit conversions to ensure both branches have
1161                // the same type.
1162                let ty = self.resolve_type_handle(accept, accept_meta)?;
1163
1164                // Add the local that will hold the result of our conditional
1165                let local = self.locals.append(
1166                    LocalVariable {
1167                        name: None,
1168                        ty,
1169                        init: None,
1170                    },
1171                    meta,
1172                );
1173
1174                let local_expr = self.add_expression(Expression::LocalVariable(local), meta)?;
1175
1176                // Add to each  the store to the result variable
1177                accept_body.push(
1178                    Statement::Store {
1179                        pointer: local_expr,
1180                        value: accept,
1181                    },
1182                    accept_meta,
1183                );
1184                reject_body.push(
1185                    Statement::Store {
1186                        pointer: local_expr,
1187                        value: reject,
1188                    },
1189                    reject_meta,
1190                );
1191
1192                // Finally add the `If` to the main body with the `condition` we lowered
1193                // earlier and the branches we prepared.
1194                self.body.push(
1195                    Statement::If {
1196                        condition,
1197                        accept: accept_body,
1198                        reject: reject_body,
1199                    },
1200                    meta,
1201                );
1202
1203                // Note: `Expression::Load` must be emitted before it's used so make
1204                // sure the emitter is active here.
1205                self.add_expression(
1206                    Expression::Load {
1207                        pointer: local_expr,
1208                    },
1209                    meta,
1210                )?
1211            }
1212            HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => {
1213                let (pointer, ptr_meta) =
1214                    self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs)?;
1215                let (mut value, value_meta) =
1216                    self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs)?;
1217
1218                let ty = match *self.resolve_type(pointer, ptr_meta)? {
1219                    TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
1220                    ref ty => ty,
1221                };
1222
1223                if let Some(scalar) = scalar_components(ty) {
1224                    self.implicit_conversion(&mut value, value_meta, scalar)?;
1225                }
1226
1227                self.lower_store(pointer, value, meta)?;
1228
1229                value
1230            }
1231            HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => {
1232                let (pointer, _) = self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs)?;
1233                let left = if let Expression::Swizzle { .. } = self.expressions[pointer] {
1234                    pointer
1235                } else {
1236                    self.add_expression(Expression::Load { pointer }, meta)?
1237                };
1238
1239                let res = match *self.resolve_type(left, meta)? {
1240                    TypeInner::Scalar(scalar) => {
1241                        let ty = TypeInner::Scalar(scalar);
1242                        Literal::one(scalar).map(|i| (ty, i, None, None))
1243                    }
1244                    TypeInner::Vector { size, scalar } => {
1245                        let ty = TypeInner::Vector { size, scalar };
1246                        Literal::one(scalar).map(|i| (ty, i, Some(size), None))
1247                    }
1248                    TypeInner::Matrix {
1249                        columns,
1250                        rows,
1251                        scalar,
1252                    } => {
1253                        let ty = TypeInner::Matrix {
1254                            columns,
1255                            rows,
1256                            scalar,
1257                        };
1258                        Literal::one(scalar).map(|i| (ty, i, Some(rows), Some(columns)))
1259                    }
1260                    _ => None,
1261                };
1262                let (ty_inner, literal, rows, columns) = match res {
1263                    Some(res) => res,
1264                    None => {
1265                        frontend.errors.push(Error {
1266                            kind: ErrorKind::SemanticError(
1267                                "Increment/decrement only works on scalar/vector/matrix".into(),
1268                            ),
1269                            meta,
1270                        });
1271                        return Ok((Some(left), meta));
1272                    }
1273                };
1274
1275                let mut right = self.add_expression(Expression::Literal(literal), meta)?;
1276
1277                // Glsl allows pre/postfixes operations on vectors and matrices, so if the
1278                // target is either of them change the right side of the addition to be splatted
1279                // to the same size as the target, furthermore if the target is a matrix
1280                // use a composed matrix using the splatted value.
1281                if let Some(size) = rows {
1282                    right = self.add_expression(Expression::Splat { size, value: right }, meta)?;
1283
1284                    if let Some(cols) = columns {
1285                        let ty = self.module.types.insert(
1286                            Type {
1287                                name: None,
1288                                inner: ty_inner,
1289                            },
1290                            meta,
1291                        );
1292
1293                        right = self.add_expression(
1294                            Expression::Compose {
1295                                ty,
1296                                components: core::iter::repeat_n(right, cols as usize).collect(),
1297                            },
1298                            meta,
1299                        )?;
1300                    }
1301                }
1302
1303                let value = self.add_expression(Expression::Binary { op, left, right }, meta)?;
1304
1305                self.lower_store(pointer, value, meta)?;
1306
1307                if postfix {
1308                    left
1309                } else {
1310                    value
1311                }
1312            }
1313            HirExprKind::Method {
1314                expr: object,
1315                ref name,
1316                ref args,
1317            } if ExprPos::Lhs != pos => {
1318                let args = args
1319                    .iter()
1320                    .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs))
1321                    .collect::<Result<Vec<_>>>()?;
1322                match name.as_ref() {
1323                    "length" => {
1324                        if !args.is_empty() {
1325                            frontend.errors.push(Error {
1326                                kind: ErrorKind::SemanticError(
1327                                    ".length() doesn't take any arguments".into(),
1328                                ),
1329                                meta,
1330                            });
1331                        }
1332                        let lowered_array = self.lower_expect_inner(stmt, frontend, object, pos)?.0;
1333                        let array_type = self.resolve_type(lowered_array, meta)?;
1334
1335                        match *array_type {
1336                            TypeInner::Array {
1337                                size: crate::ArraySize::Constant(size),
1338                                ..
1339                            } => {
1340                                let mut array_length = self.add_expression(
1341                                    Expression::Literal(Literal::U32(size.get())),
1342                                    meta,
1343                                )?;
1344                                self.forced_conversion(&mut array_length, meta, Scalar::I32)?;
1345                                array_length
1346                            }
1347                            // let the error be handled in type checking if it's not a dynamic array
1348                            _ => {
1349                                let mut array_length = self
1350                                    .add_expression(Expression::ArrayLength(lowered_array), meta)?;
1351                                self.conversion(&mut array_length, meta, Scalar::I32)?;
1352                                array_length
1353                            }
1354                        }
1355                    }
1356
1357                    _ => {
1358                        return Err(Error {
1359                            kind: ErrorKind::SemanticError(
1360                                format!("unknown method '{name}'").into(),
1361                            ),
1362                            meta,
1363                        });
1364                    }
1365                }
1366            }
1367            HirExprKind::Sequence { ref exprs } if pos != ExprPos::Lhs => {
1368                let mut last_handle = None;
1369                for expr in exprs.iter() {
1370                    let (handle, _) =
1371                        self.lower_expect_inner(stmt, frontend, *expr, ExprPos::Rhs)?;
1372                    last_handle = Some(handle);
1373                }
1374                match last_handle {
1375                    Some(handle) => handle,
1376                    None => unreachable!(),
1377                }
1378            }
1379            _ => {
1380                return Err(Error {
1381                    kind: ErrorKind::SemanticError(
1382                        format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr])
1383                            .into(),
1384                    ),
1385                    meta,
1386                })
1387            }
1388        };
1389
1390        log::trace!("Lowered {expr:?}\n\tKind = {kind:?}\n\tPos = {pos:?}\n\tResult = {handle:?}");
1391
1392        Ok((Some(handle), meta))
1393    }
1394
1395    pub fn expr_scalar_components(
1396        &mut self,
1397        expr: Handle<Expression>,
1398        meta: Span,
1399    ) -> Result<Option<Scalar>> {
1400        let ty = self.resolve_type(expr, meta)?;
1401        Ok(scalar_components(ty))
1402    }
1403
1404    pub fn expr_power(&mut self, expr: Handle<Expression>, meta: Span) -> Result<Option<u32>> {
1405        Ok(self
1406            .expr_scalar_components(expr, meta)?
1407            .and_then(type_power))
1408    }
1409
1410    pub fn conversion(
1411        &mut self,
1412        expr: &mut Handle<Expression>,
1413        meta: Span,
1414        scalar: Scalar,
1415    ) -> Result<()> {
1416        *expr = self.add_expression(
1417            Expression::As {
1418                expr: *expr,
1419                kind: scalar.kind,
1420                convert: Some(scalar.width),
1421            },
1422            meta,
1423        )?;
1424
1425        Ok(())
1426    }
1427
1428    pub fn implicit_conversion(
1429        &mut self,
1430        expr: &mut Handle<Expression>,
1431        meta: Span,
1432        scalar: Scalar,
1433    ) -> Result<()> {
1434        if let (Some(tgt_power), Some(expr_power)) =
1435            (type_power(scalar), self.expr_power(*expr, meta)?)
1436        {
1437            if tgt_power > expr_power {
1438                self.conversion(expr, meta, scalar)?;
1439            }
1440        }
1441
1442        Ok(())
1443    }
1444
1445    pub fn forced_conversion(
1446        &mut self,
1447        expr: &mut Handle<Expression>,
1448        meta: Span,
1449        scalar: Scalar,
1450    ) -> Result<()> {
1451        if let Some(expr_scalar) = self.expr_scalar_components(*expr, meta)? {
1452            if expr_scalar != scalar {
1453                self.conversion(expr, meta, scalar)?;
1454            }
1455        }
1456
1457        Ok(())
1458    }
1459
1460    pub fn binary_implicit_conversion(
1461        &mut self,
1462        left: &mut Handle<Expression>,
1463        left_meta: Span,
1464        right: &mut Handle<Expression>,
1465        right_meta: Span,
1466    ) -> Result<()> {
1467        let left_components = self.expr_scalar_components(*left, left_meta)?;
1468        let right_components = self.expr_scalar_components(*right, right_meta)?;
1469
1470        if let (Some((left_power, left_scalar)), Some((right_power, right_scalar))) = (
1471            left_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1472            right_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1473        ) {
1474            match left_power.cmp(&right_power) {
1475                core::cmp::Ordering::Less => {
1476                    self.conversion(left, left_meta, right_scalar)?;
1477                }
1478                core::cmp::Ordering::Equal => {}
1479                core::cmp::Ordering::Greater => {
1480                    self.conversion(right, right_meta, left_scalar)?;
1481                }
1482            }
1483        }
1484
1485        Ok(())
1486    }
1487
1488    pub fn implicit_splat(
1489        &mut self,
1490        expr: &mut Handle<Expression>,
1491        meta: Span,
1492        vector_size: Option<VectorSize>,
1493    ) -> Result<()> {
1494        let expr_type = self.resolve_type(*expr, meta)?;
1495
1496        if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) {
1497            *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta)?
1498        }
1499
1500        Ok(())
1501    }
1502
1503    pub fn vector_resize(
1504        &mut self,
1505        size: VectorSize,
1506        vector: Handle<Expression>,
1507        meta: Span,
1508    ) -> Result<Handle<Expression>> {
1509        self.add_expression(
1510            Expression::Swizzle {
1511                size,
1512                vector,
1513                pattern: crate::SwizzleComponent::XYZW,
1514            },
1515            meta,
1516        )
1517    }
1518}
1519
1520impl Index<Handle<Expression>> for Context<'_> {
1521    type Output = Expression;
1522
1523    fn index(&self, index: Handle<Expression>) -> &Self::Output {
1524        if self.is_const {
1525            &self.module.global_expressions[index]
1526        } else {
1527            &self.expressions[index]
1528        }
1529    }
1530}
1531
1532/// Helper struct passed when parsing expressions
1533///
1534/// This struct should only be obtained through [`stmt_ctx`](Context::stmt_ctx)
1535/// and only one of these may be active at any time per context.
1536#[derive(Debug)]
1537pub struct StmtContext {
1538    /// A arena of high level expressions which can be lowered through a
1539    /// [`Context`] to Naga's [`Expression`]s
1540    pub hir_exprs: Arena<HirExpr>,
1541}
1542
1543impl StmtContext {
1544    const fn new() -> Self {
1545        StmtContext {
1546            hir_exprs: Arena::new(),
1547        }
1548    }
1549}