naga/front/glsl/
context.rs

1use alloc::{format, string::String, vec::Vec};
2use core::ops::Index;
3
4use super::{
5    ast::{
6        GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier,
7        VariableReference,
8    },
9    error::{Error, ErrorKind},
10    types::{scalar_components, type_power},
11    Frontend, Result,
12};
13use crate::{
14    front::Typifier, proc::Emitter, proc::Layouter, AddressSpace, Arena, BinaryOperator, Block,
15    Expression, FastHashMap, FunctionArgument, Handle, Literal, LocalVariable, RelationalFunction,
16    Scalar, Span, Statement, Type, TypeInner, VectorSize,
17};
18
/// Where an expression sits while it is being lowered.
///
/// The same HIR expression may lower differently depending on whether it is
/// being assigned to, read from, or used as the base of an indexing operation.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ExprPos {
    /// The expression is on the left hand side of an assignment.
    Lhs,
    /// The expression is on the right hand side of an assignment.
    Rhs,
    /// The expression is an array being indexed; tracked so that constant
    /// arrays can still be indexed dynamically.
    AccessBase {
        /// Whether the index applied to this base is a constant.
        constant_index: bool,
    },
}

impl ExprPos {
    /// Computes the position to use for the base of an access expression.
    ///
    /// `Lhs` is returned unchanged, and so is an `AccessBase` whose index is
    /// not constant, so a dynamic index further out in an access chain is not
    /// overwritten by a constant one. Any other position becomes
    /// `AccessBase { constant_index }`.
    const fn maybe_access_base(&self, constant_index: bool) -> Self {
        let keep_current = matches!(
            *self,
            ExprPos::Lhs
                | ExprPos::AccessBase {
                    constant_index: false
                }
        );

        if keep_current {
            *self
        } else {
            ExprPos::AccessBase { constant_index }
        }
    }
}
46
/// Lowering state for a function body: its expression and local-variable
/// arenas, arguments and parameters, symbol table, and the block that is
/// currently receiving statements.
///
/// When [`is_const`](Self::is_const) is set, `add_expression` uses the
/// module-level constant evaluator instead of this context's own arena.
#[derive(Debug)]
pub struct Context<'a> {
    /// Expression arena of the function being built; `add_expression` appends
    /// here (via the function-level constant evaluator) when `is_const` is false.
    pub expressions: Arena<Expression>,
    /// Local variables of the function being built, including the local copies
    /// that `add_function_arg` creates for mutable by-value parameters.
    pub locals: Arena<LocalVariable>,

    /// The [`FunctionArgument`]s for the final [`crate::Function`].
    ///
    /// Parameters with the `out` and `inout` qualifiers have [`Pointer`] types
    /// here. For example, an `inout vec2 a` argument would be a [`Pointer`] to
    /// a [`Vector`].
    ///
    /// [`Pointer`]: crate::TypeInner::Pointer
    /// [`Vector`]: crate::TypeInner::Vector
    pub arguments: Vec<FunctionArgument>,

    /// The parameter types given in the source code.
    ///
    /// The `out` and `inout` qualifiers don't affect the types that appear
    /// here. For example, an `inout vec2 a` argument would simply be a
    /// [`Vector`], not a pointer to one.
    ///
    /// [`Vector`]: crate::TypeInner::Vector
    pub parameters: Vec<Handle<Type>>,
    /// Per-parameter info (including the qualifier), pushed in step with
    /// `parameters` by `add_function_arg`.
    pub parameters_info: Vec<ParameterInfo>,

    /// Maps source names of globals, parameters and locals to the expressions
    /// that refer to them.
    pub symbol_table: crate::front::SymbolTable<String, VariableReference>,
    // NOTE(review): maps one expression to another; presumably an image/texture
    // expression to the sampler it is used with — confirm against users.
    pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>,

    // NOTE(review): presumably resolves types of module-level (const)
    // expressions, while `typifier` covers `expressions` — confirm.
    pub const_typifier: Typifier,
    pub typifier: Typifier,
    /// Layout information handed to the constant evaluator in `add_expression`.
    layouter: Layouter,
    /// Tracks which newly added expressions still need an `Emit` statement;
    /// driven by `emit_start` / `emit_end` / `emit_restart`.
    emitter: Emitter,
    /// The reusable statement context; `None` while it has been taken with
    /// `stmt_ctx()` and not yet handed back by `lower` / `lower_expect`.
    stmt_ctx: Option<StmtContext>,
    /// The block statements are currently pushed to; temporarily swapped out
    /// by `new_body`, `new_body_with_ret` and `with_body`.
    pub body: Block,
    /// The module being built (types, globals, constants, overrides).
    pub module: &'a mut crate::Module,
    /// Selects the constant evaluator used by `add_expression`: module-level
    /// (`for_glsl_module`) when true, function-level otherwise.
    pub is_const: bool,
    /// Tracks the expression kind of `Expression`s residing in `self.expressions`
    pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker,
    /// Tracks the expression kind of `Expression`s residing in `self.module.global_expressions`
    pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
}
88
89impl<'a> Context<'a> {
90    pub fn new(
91        frontend: &Frontend,
92        module: &'a mut crate::Module,
93        is_const: bool,
94        global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
95    ) -> Result<Self> {
96        let mut this = Context {
97            expressions: Arena::new(),
98            locals: Arena::new(),
99            arguments: Vec::new(),
100
101            parameters: Vec::new(),
102            parameters_info: Vec::new(),
103
104            symbol_table: crate::front::SymbolTable::default(),
105            samplers: FastHashMap::default(),
106
107            const_typifier: Typifier::new(),
108            typifier: Typifier::new(),
109            layouter: Layouter::default(),
110            emitter: Emitter::default(),
111            stmt_ctx: Some(StmtContext::new()),
112            body: Block::new(),
113            module,
114            is_const: false,
115            local_expression_kind_tracker: crate::proc::ExpressionKindTracker::new(),
116            global_expression_kind_tracker,
117        };
118
119        this.emit_start();
120
121        for &(ref name, lookup) in frontend.global_variables.iter() {
122            this.add_global(name, lookup)?
123        }
124        this.is_const = is_const;
125
126        Ok(this)
127    }
128
129    pub fn new_body<F>(&mut self, cb: F) -> Result<Block>
130    where
131        F: FnOnce(&mut Self) -> Result<()>,
132    {
133        self.new_body_with_ret(cb).map(|(b, _)| b)
134    }
135
136    pub fn new_body_with_ret<F, R>(&mut self, cb: F) -> Result<(Block, R)>
137    where
138        F: FnOnce(&mut Self) -> Result<R>,
139    {
140        self.emit_restart();
141        let old_body = core::mem::replace(&mut self.body, Block::new());
142        let res = cb(self);
143        self.emit_restart();
144        let new_body = core::mem::replace(&mut self.body, old_body);
145        res.map(|r| (new_body, r))
146    }
147
148    pub fn with_body<F>(&mut self, body: Block, cb: F) -> Result<Block>
149    where
150        F: FnOnce(&mut Self) -> Result<()>,
151    {
152        self.emit_restart();
153        let old_body = core::mem::replace(&mut self.body, body);
154        let res = cb(self);
155        self.emit_restart();
156        let body = core::mem::replace(&mut self.body, old_body);
157        res.map(|_| body)
158    }
159
160    pub fn add_global(
161        &mut self,
162        name: &str,
163        GlobalLookup {
164            kind,
165            entry_arg,
166            mutable,
167        }: GlobalLookup,
168    ) -> Result<()> {
169        let (expr, load, constant) = match kind {
170            GlobalLookupKind::Variable(v) => {
171                let span = self.module.global_variables.get_span(v);
172                (
173                    self.add_expression(Expression::GlobalVariable(v), span)?,
174                    self.module.global_variables[v].space != AddressSpace::Handle,
175                    None,
176                )
177            }
178            GlobalLookupKind::BlockSelect(handle, index) => {
179                let span = self.module.global_variables.get_span(handle);
180                let base = self.add_expression(Expression::GlobalVariable(handle), span)?;
181                let expr = self.add_expression(Expression::AccessIndex { base, index }, span)?;
182
183                (
184                    expr,
185                    {
186                        let ty = self.module.global_variables[handle].ty;
187
188                        match self.module.types[ty].inner {
189                            TypeInner::Struct { ref members, .. } => {
190                                if let TypeInner::Array {
191                                    size: crate::ArraySize::Dynamic,
192                                    ..
193                                } = self.module.types[members[index as usize].ty].inner
194                                {
195                                    false
196                                } else {
197                                    true
198                                }
199                            }
200                            _ => true,
201                        }
202                    },
203                    None,
204                )
205            }
206            GlobalLookupKind::Constant(v, ty) => {
207                let span = self.module.constants.get_span(v);
208                (
209                    self.add_expression(Expression::Constant(v), span)?,
210                    false,
211                    Some((v, ty)),
212                )
213            }
214            GlobalLookupKind::Override(v, _ty) => {
215                let span = self.module.overrides.get_span(v);
216                (
217                    self.add_expression(Expression::Override(v), span)?,
218                    false,
219                    None,
220                )
221            }
222        };
223
224        let var = VariableReference {
225            expr,
226            load,
227            mutable,
228            constant,
229            entry_arg,
230        };
231
232        self.symbol_table.add(name.into(), var);
233
234        Ok(())
235    }
236
237    /// Starts the expression emitter
238    ///
239    /// # Panics
240    ///
241    /// - If called twice in a row without calling [`emit_end`][Self::emit_end].
242    #[inline]
243    pub fn emit_start(&mut self) {
244        self.emitter.start(&self.expressions)
245    }
246
247    /// Emits all the expressions captured by the emitter to the current body
248    ///
249    /// # Panics
250    ///
251    /// - If called before calling [`emit_start`].
252    /// - If called twice in a row without calling [`emit_start`].
253    ///
254    /// [`emit_start`]: Self::emit_start
255    pub fn emit_end(&mut self) {
256        self.body.extend(self.emitter.finish(&self.expressions))
257    }
258
259    /// Emits all the expressions captured by the emitter to the current body
260    /// and starts the emitter again
261    ///
262    /// # Panics
263    ///
264    /// - If called before calling [`emit_start`][Self::emit_start].
265    pub fn emit_restart(&mut self) {
266        self.emit_end();
267        self.emit_start()
268    }
269
270    pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
271        let mut eval = if self.is_const {
272            crate::proc::ConstantEvaluator::for_glsl_module(
273                self.module,
274                self.global_expression_kind_tracker,
275                &mut self.layouter,
276            )
277        } else {
278            crate::proc::ConstantEvaluator::for_glsl_function(
279                self.module,
280                &mut self.expressions,
281                &mut self.local_expression_kind_tracker,
282                &mut self.layouter,
283                &mut self.emitter,
284                &mut self.body,
285            )
286        };
287
288        eval.try_eval_and_append(expr, meta).map_err(|e| Error {
289            kind: e.into(),
290            meta,
291        })
292    }
293
294    /// Add variable to current scope
295    ///
296    /// Returns a variable if a variable with the same name was already defined,
297    /// otherwise returns `None`
298    pub fn add_local_var(
299        &mut self,
300        name: String,
301        expr: Handle<Expression>,
302        mutable: bool,
303    ) -> Option<VariableReference> {
304        let var = VariableReference {
305            expr,
306            load: true,
307            mutable,
308            constant: None,
309            entry_arg: None,
310        };
311
312        self.symbol_table.add(name, var)
313    }
314
315    /// Add function argument to current scope
316    pub fn add_function_arg(
317        &mut self,
318        name_meta: Option<(String, Span)>,
319        ty: Handle<Type>,
320        qualifier: ParameterQualifier,
321    ) -> Result<()> {
322        let index = self.arguments.len();
323        let mut arg = FunctionArgument {
324            name: name_meta.as_ref().map(|&(ref name, _)| name.clone()),
325            ty,
326            binding: None,
327        };
328        self.parameters.push(ty);
329
330        let opaque = match self.module.types[ty].inner {
331            TypeInner::Image { .. } | TypeInner::Sampler { .. } => true,
332            _ => false,
333        };
334
335        if qualifier.is_lhs() {
336            let span = self.module.types.get_span(arg.ty);
337            arg.ty = self.module.types.insert(
338                Type {
339                    name: None,
340                    inner: TypeInner::Pointer {
341                        base: arg.ty,
342                        space: AddressSpace::Function,
343                    },
344                },
345                span,
346            )
347        }
348
349        self.arguments.push(arg);
350
351        self.parameters_info.push(ParameterInfo {
352            qualifier,
353            depth: false,
354        });
355
356        if let Some((name, meta)) = name_meta {
357            let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta)?;
358            let mutable = qualifier != ParameterQualifier::Const && !opaque;
359            let load = qualifier.is_lhs();
360
361            let var = if mutable && !load {
362                let handle = self.locals.append(
363                    LocalVariable {
364                        name: Some(name.clone()),
365                        ty,
366                        init: None,
367                    },
368                    meta,
369                );
370                let local_expr = self.add_expression(Expression::LocalVariable(handle), meta)?;
371
372                self.emit_restart();
373
374                self.body.push(
375                    Statement::Store {
376                        pointer: local_expr,
377                        value: expr,
378                    },
379                    meta,
380                );
381
382                VariableReference {
383                    expr: local_expr,
384                    load: true,
385                    mutable,
386                    constant: None,
387                    entry_arg: None,
388                }
389            } else {
390                VariableReference {
391                    expr,
392                    load,
393                    mutable,
394                    constant: None,
395                    entry_arg: None,
396                }
397            };
398
399            self.symbol_table.add(name, var);
400        }
401
402        Ok(())
403    }
404
405    /// Returns a [`StmtContext`] to be used in parsing and lowering
406    ///
407    /// # Panics
408    ///
409    /// - If more than one [`StmtContext`] are active at the same time or if the
410    ///   previous call didn't use it in lowering.
411    #[must_use]
412    pub fn stmt_ctx(&mut self) -> StmtContext {
413        self.stmt_ctx.take().unwrap()
414    }
415
416    /// Lowers a [`HirExpr`] which might produce a [`Expression`].
417    ///
418    /// consumes a [`StmtContext`] returning it to the context so that it can be
419    /// used again later.
420    pub fn lower(
421        &mut self,
422        mut stmt: StmtContext,
423        frontend: &mut Frontend,
424        expr: Handle<HirExpr>,
425        pos: ExprPos,
426    ) -> Result<(Option<Handle<Expression>>, Span)> {
427        let res = self.lower_inner(&stmt, frontend, expr, pos);
428
429        stmt.hir_exprs.clear();
430        self.stmt_ctx = Some(stmt);
431
432        res
433    }
434
435    /// Similar to [`lower`](Self::lower) but returns an error if the expression
436    /// returns void (ie. doesn't produce a [`Expression`]).
437    ///
438    /// consumes a [`StmtContext`] returning it to the context so that it can be
439    /// used again later.
440    pub fn lower_expect(
441        &mut self,
442        mut stmt: StmtContext,
443        frontend: &mut Frontend,
444        expr: Handle<HirExpr>,
445        pos: ExprPos,
446    ) -> Result<(Handle<Expression>, Span)> {
447        let res = self.lower_expect_inner(&stmt, frontend, expr, pos);
448
449        stmt.hir_exprs.clear();
450        self.stmt_ctx = Some(stmt);
451
452        res
453    }
454
455    /// internal implementation of [`lower_expect`](Self::lower_expect)
456    ///
457    /// this method is only public because it's used in
458    /// [`function_call`](Frontend::function_call), unless you know what
459    /// you're doing use [`lower_expect`](Self::lower_expect)
460    pub fn lower_expect_inner(
461        &mut self,
462        stmt: &StmtContext,
463        frontend: &mut Frontend,
464        expr: Handle<HirExpr>,
465        pos: ExprPos,
466    ) -> Result<(Handle<Expression>, Span)> {
467        let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos)?;
468
469        let expr = match maybe_expr {
470            Some(e) => e,
471            None => {
472                return Err(Error {
473                    kind: ErrorKind::SemanticError("Expression returns void".into()),
474                    meta,
475                })
476            }
477        };
478
479        Ok((expr, meta))
480    }
481
482    fn lower_store(
483        &mut self,
484        pointer: Handle<Expression>,
485        value: Handle<Expression>,
486        meta: Span,
487    ) -> Result<()> {
488        if let Expression::Swizzle {
489            size,
490            mut vector,
491            pattern,
492        } = self.expressions[pointer]
493        {
494            // Stores to swizzled values are not directly supported,
495            // lower them as series of per-component stores.
496            let size = match size {
497                VectorSize::Bi => 2,
498                VectorSize::Tri => 3,
499                VectorSize::Quad => 4,
500            };
501
502            if let Expression::Load { pointer } = self.expressions[vector] {
503                vector = pointer;
504            }
505
506            #[allow(clippy::needless_range_loop)]
507            for index in 0..size {
508                let dst = self.add_expression(
509                    Expression::AccessIndex {
510                        base: vector,
511                        index: pattern[index].index(),
512                    },
513                    meta,
514                )?;
515                let src = self.add_expression(
516                    Expression::AccessIndex {
517                        base: value,
518                        index: index as u32,
519                    },
520                    meta,
521                )?;
522
523                self.emit_restart();
524
525                self.body.push(
526                    Statement::Store {
527                        pointer: dst,
528                        value: src,
529                    },
530                    meta,
531                );
532            }
533        } else {
534            self.emit_restart();
535
536            self.body.push(Statement::Store { pointer, value }, meta);
537        }
538
539        Ok(())
540    }
541
542    /// Internal implementation of [`lower`](Self::lower)
543    fn lower_inner(
544        &mut self,
545        stmt: &StmtContext,
546        frontend: &mut Frontend,
547        expr: Handle<HirExpr>,
548        pos: ExprPos,
549    ) -> Result<(Option<Handle<Expression>>, Span)> {
550        let HirExpr { ref kind, meta } = stmt.hir_exprs[expr];
551
552        log::debug!("Lowering {expr:?} (kind {kind:?}, pos {pos:?})");
553
554        let handle = match *kind {
555            HirExprKind::Access { base, index } => {
556                let (index, _) = self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs)?;
557                let maybe_constant_index = match pos {
558                    // Don't try to generate `AccessIndex` if in a LHS position, since it
559                    // wouldn't produce a pointer.
560                    ExprPos::Lhs => None,
561                    _ => self
562                        .module
563                        .to_ctx()
564                        .get_const_val_from(index, &self.expressions)
565                        .ok(),
566                };
567
568                let base = self
569                    .lower_expect_inner(
570                        stmt,
571                        frontend,
572                        base,
573                        pos.maybe_access_base(maybe_constant_index.is_some()),
574                    )?
575                    .0;
576
577                let pointer = maybe_constant_index
578                    .map(|index| self.add_expression(Expression::AccessIndex { base, index }, meta))
579                    .unwrap_or_else(|| {
580                        self.add_expression(Expression::Access { base, index }, meta)
581                    })?;
582
583                if ExprPos::Rhs == pos {
584                    let resolved = self.resolve_type(pointer, meta)?;
585                    if resolved.pointer_space().is_some() {
586                        return Ok((
587                            Some(self.add_expression(Expression::Load { pointer }, meta)?),
588                            meta,
589                        ));
590                    }
591                }
592
593                pointer
594            }
595            HirExprKind::Select { base, ref field } => {
596                let base = self.lower_expect_inner(stmt, frontend, base, pos)?.0;
597
598                frontend.field_selection(self, pos, base, field, meta)?
599            }
600            HirExprKind::Literal(literal) if pos != ExprPos::Lhs => {
601                self.add_expression(Expression::Literal(literal), meta)?
602            }
603            HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
604                let (mut left, left_meta) =
605                    self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs)?;
606                let (mut right, right_meta) =
607                    self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs)?;
608
609                match op {
610                    BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
611                        self.implicit_conversion(&mut right, right_meta, Scalar::U32)?
612                    }
613                    _ => self
614                        .binary_implicit_conversion(&mut left, left_meta, &mut right, right_meta)?,
615                }
616
617                self.typifier_grow(left, left_meta)?;
618                self.typifier_grow(right, right_meta)?;
619
620                let left_inner = self.get_type(left);
621                let right_inner = self.get_type(right);
622
623                match (left_inner, right_inner) {
624                    (
625                        &TypeInner::Matrix {
626                            columns: left_columns,
627                            rows: left_rows,
628                            scalar: left_scalar,
629                        },
630                        &TypeInner::Matrix {
631                            columns: right_columns,
632                            rows: right_rows,
633                            scalar: right_scalar,
634                        },
635                    ) => {
636                        let dimensions_ok = if op == BinaryOperator::Multiply {
637                            left_columns == right_rows
638                        } else {
639                            left_columns == right_columns && left_rows == right_rows
640                        };
641
642                        // Check that the two arguments have the same dimensions
643                        if !dimensions_ok || left_scalar != right_scalar {
644                            frontend.errors.push(Error {
645                                kind: ErrorKind::SemanticError(
646                                    format!(
647                                        "Cannot apply operation to {left_inner:?} and {right_inner:?}"
648                                    )
649                                    .into(),
650                                ),
651                                meta,
652                            })
653                        }
654
655                        match op {
656                            BinaryOperator::Divide => {
657                                // Naga IR doesn't support matrix division so we need to
658                                // divide the columns individually and reassemble the matrix
659                                let mut components = Vec::with_capacity(left_columns as usize);
660
661                                for index in 0..left_columns as u32 {
662                                    // Get the column vectors
663                                    let left_vector = self.add_expression(
664                                        Expression::AccessIndex { base: left, index },
665                                        meta,
666                                    )?;
667                                    let right_vector = self.add_expression(
668                                        Expression::AccessIndex { base: right, index },
669                                        meta,
670                                    )?;
671
672                                    // Divide the vectors
673                                    let column = self.add_expression(
674                                        Expression::Binary {
675                                            op,
676                                            left: left_vector,
677                                            right: right_vector,
678                                        },
679                                        meta,
680                                    )?;
681
682                                    components.push(column)
683                                }
684
685                                let ty = self.module.types.insert(
686                                    Type {
687                                        name: None,
688                                        inner: TypeInner::Matrix {
689                                            columns: left_columns,
690                                            rows: left_rows,
691                                            scalar: left_scalar,
692                                        },
693                                    },
694                                    Span::default(),
695                                );
696
697                                // Rebuild the matrix from the divided vectors
698                                self.add_expression(Expression::Compose { ty, components }, meta)?
699                            }
700                            BinaryOperator::Equal | BinaryOperator::NotEqual => {
701                                // Naga IR doesn't support matrix comparisons so we need to
702                                // compare the columns individually and then fold them together
703                                //
704                                // The folding is done using a logical and for equality and
705                                // a logical or for inequality
706                                let equals = op == BinaryOperator::Equal;
707
708                                let (op, combine, fun) = match equals {
709                                    true => (
710                                        BinaryOperator::Equal,
711                                        BinaryOperator::LogicalAnd,
712                                        RelationalFunction::All,
713                                    ),
714                                    false => (
715                                        BinaryOperator::NotEqual,
716                                        BinaryOperator::LogicalOr,
717                                        RelationalFunction::Any,
718                                    ),
719                                };
720
721                                let mut root = None;
722
723                                for index in 0..left_columns as u32 {
724                                    // Get the column vectors
725                                    let left_vector = self.add_expression(
726                                        Expression::AccessIndex { base: left, index },
727                                        meta,
728                                    )?;
729                                    let right_vector = self.add_expression(
730                                        Expression::AccessIndex { base: right, index },
731                                        meta,
732                                    )?;
733
734                                    let argument = self.add_expression(
735                                        Expression::Binary {
736                                            op,
737                                            left: left_vector,
738                                            right: right_vector,
739                                        },
740                                        meta,
741                                    )?;
742
743                                    // The result of comparing two vectors is a boolean vector
744                                    // so use a relational function like all to get a single
745                                    // boolean value
746                                    let compare = self.add_expression(
747                                        Expression::Relational { fun, argument },
748                                        meta,
749                                    )?;
750
751                                    // Fold the result
752                                    root = Some(match root {
753                                        Some(right) => self.add_expression(
754                                            Expression::Binary {
755                                                op: combine,
756                                                left: compare,
757                                                right,
758                                            },
759                                            meta,
760                                        )?,
761                                        None => compare,
762                                    });
763                                }
764
765                                root.unwrap()
766                            }
767                            _ => {
768                                self.add_expression(Expression::Binary { left, op, right }, meta)?
769                            }
770                        }
771                    }
772                    (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op {
773                        BinaryOperator::Equal | BinaryOperator::NotEqual => {
774                            let equals = op == BinaryOperator::Equal;
775
776                            let (op, fun) = match equals {
777                                true => (BinaryOperator::Equal, RelationalFunction::All),
778                                false => (BinaryOperator::NotEqual, RelationalFunction::Any),
779                            };
780
781                            let argument =
782                                self.add_expression(Expression::Binary { op, left, right }, meta)?;
783
784                            self.add_expression(Expression::Relational { fun, argument }, meta)?
785                        }
786                        _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
787                    },
788                    (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op {
789                        BinaryOperator::Add
790                        | BinaryOperator::Subtract
791                        | BinaryOperator::Divide
792                        | BinaryOperator::And
793                        | BinaryOperator::ExclusiveOr
794                        | BinaryOperator::InclusiveOr
795                        | BinaryOperator::ShiftLeft
796                        | BinaryOperator::ShiftRight => {
797                            let scalar_vector = self
798                                .add_expression(Expression::Splat { size, value: right }, meta)?;
799
800                            self.add_expression(
801                                Expression::Binary {
802                                    op,
803                                    left,
804                                    right: scalar_vector,
805                                },
806                                meta,
807                            )?
808                        }
809                        _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
810                    },
811                    (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op {
812                        BinaryOperator::Add
813                        | BinaryOperator::Subtract
814                        | BinaryOperator::Divide
815                        | BinaryOperator::And
816                        | BinaryOperator::ExclusiveOr
817                        | BinaryOperator::InclusiveOr => {
818                            let scalar_vector =
819                                self.add_expression(Expression::Splat { size, value: left }, meta)?;
820
821                            self.add_expression(
822                                Expression::Binary {
823                                    op,
824                                    left: scalar_vector,
825                                    right,
826                                },
827                                meta,
828                            )?
829                        }
830                        _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
831                    },
832                    (
833                        &TypeInner::Scalar(left_scalar),
834                        &TypeInner::Matrix {
835                            rows,
836                            columns,
837                            scalar: right_scalar,
838                        },
839                    ) => {
840                        // Check that the two arguments have the same scalar type
841                        if left_scalar != right_scalar {
842                            frontend.errors.push(Error {
843                                kind: ErrorKind::SemanticError(
844                                    format!(
845                                        "Cannot apply operation to {left_inner:?} and {right_inner:?}"
846                                    )
847                                    .into(),
848                                ),
849                                meta,
850                            })
851                        }
852
853                        match op {
854                            BinaryOperator::Divide
855                            | BinaryOperator::Add
856                            | BinaryOperator::Subtract => {
857                                // Naga IR doesn't support all matrix by scalar operations so
858                                // we need for some to turn the scalar into a vector by
859                                // splatting it and then for each column vector apply the
860                                // operation and finally reconstruct the matrix
861                                let scalar_vector = self.add_expression(
862                                    Expression::Splat {
863                                        size: rows,
864                                        value: left,
865                                    },
866                                    meta,
867                                )?;
868
869                                let mut components = Vec::with_capacity(columns as usize);
870
871                                for index in 0..columns as u32 {
872                                    // Get the column vector
873                                    let matrix_column = self.add_expression(
874                                        Expression::AccessIndex { base: right, index },
875                                        meta,
876                                    )?;
877
878                                    // Apply the operation to the splatted vector and
879                                    // the column vector
880                                    let column = self.add_expression(
881                                        Expression::Binary {
882                                            op,
883                                            left: scalar_vector,
884                                            right: matrix_column,
885                                        },
886                                        meta,
887                                    )?;
888
889                                    components.push(column)
890                                }
891
892                                let ty = self.module.types.insert(
893                                    Type {
894                                        name: None,
895                                        inner: TypeInner::Matrix {
896                                            columns,
897                                            rows,
898                                            scalar: left_scalar,
899                                        },
900                                    },
901                                    Span::default(),
902                                );
903
904                                // Rebuild the matrix from the operation result vectors
905                                self.add_expression(Expression::Compose { ty, components }, meta)?
906                            }
907                            _ => {
908                                self.add_expression(Expression::Binary { left, op, right }, meta)?
909                            }
910                        }
911                    }
912                    (
913                        &TypeInner::Matrix {
914                            rows,
915                            columns,
916                            scalar: left_scalar,
917                        },
918                        &TypeInner::Scalar(right_scalar),
919                    ) => {
920                        // Check that the two arguments have the same scalar type
921                        if left_scalar != right_scalar {
922                            frontend.errors.push(Error {
923                                kind: ErrorKind::SemanticError(
924                                    format!(
925                                        "Cannot apply operation to {left_inner:?} and {right_inner:?}"
926                                    )
927                                    .into(),
928                                ),
929                                meta,
930                            })
931                        }
932
933                        match op {
934                            BinaryOperator::Divide
935                            | BinaryOperator::Add
936                            | BinaryOperator::Subtract => {
937                                // Naga IR doesn't support all matrix by scalar operations so
938                                // we need for some to turn the scalar into a vector by
939                                // splatting it and then for each column vector apply the
940                                // operation and finally reconstruct the matrix
941
942                                let scalar_vector = self.add_expression(
943                                    Expression::Splat {
944                                        size: rows,
945                                        value: right,
946                                    },
947                                    meta,
948                                )?;
949
950                                let mut components = Vec::with_capacity(columns as usize);
951
952                                for index in 0..columns as u32 {
953                                    // Get the column vector
954                                    let matrix_column = self.add_expression(
955                                        Expression::AccessIndex { base: left, index },
956                                        meta,
957                                    )?;
958
959                                    // Apply the operation to the splatted vector and
960                                    // the column vector
961                                    let column = self.add_expression(
962                                        Expression::Binary {
963                                            op,
964                                            left: matrix_column,
965                                            right: scalar_vector,
966                                        },
967                                        meta,
968                                    )?;
969
970                                    components.push(column)
971                                }
972
973                                let ty = self.module.types.insert(
974                                    Type {
975                                        name: None,
976                                        inner: TypeInner::Matrix {
977                                            columns,
978                                            rows,
979                                            scalar: left_scalar,
980                                        },
981                                    },
982                                    Span::default(),
983                                );
984
985                                // Rebuild the matrix from the operation result vectors
986                                self.add_expression(Expression::Compose { ty, components }, meta)?
987                            }
988                            _ => {
989                                self.add_expression(Expression::Binary { left, op, right }, meta)?
990                            }
991                        }
992                    }
993                    _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
994                }
995            }
996            HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => {
997                let expr = self
998                    .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs)?
999                    .0;
1000
1001                self.add_expression(Expression::Unary { op, expr }, meta)?
1002            }
1003            HirExprKind::Variable(ref var) => match pos {
1004                ExprPos::Lhs => {
1005                    if !var.mutable {
1006                        frontend.errors.push(Error {
1007                            kind: ErrorKind::SemanticError(
1008                                "Variable cannot be used in LHS position".into(),
1009                            ),
1010                            meta,
1011                        })
1012                    }
1013
1014                    var.expr
1015                }
1016                ExprPos::AccessBase { constant_index } => {
1017                    // If the index isn't constant all accesses backed by a constant base need
1018                    // to be done through a proxy local variable, since constants have a non
1019                    // pointer type which is required for dynamic indexing
1020                    if !constant_index {
1021                        if let Some((constant, ty)) = var.constant {
1022                            let init = self
1023                                .add_expression(Expression::Constant(constant), Span::default())?;
1024                            let local = self.locals.append(
1025                                LocalVariable {
1026                                    name: None,
1027                                    ty,
1028                                    init: Some(init),
1029                                },
1030                                Span::default(),
1031                            );
1032
1033                            self.add_expression(Expression::LocalVariable(local), Span::default())?
1034                        } else {
1035                            var.expr
1036                        }
1037                    } else {
1038                        var.expr
1039                    }
1040                }
1041                _ if var.load => {
1042                    self.add_expression(Expression::Load { pointer: var.expr }, meta)?
1043                }
1044                ExprPos::Rhs => {
1045                    if let Some((constant, _)) = self.is_const.then_some(var.constant).flatten() {
1046                        self.add_expression(Expression::Constant(constant), meta)?
1047                    } else {
1048                        // Check if this is an Override expression in const context
1049                        if self.is_const {
1050                            if let Expression::Override(o) = self.expressions[var.expr] {
1051                                // Need to add the Override expression to the global arena
1052                                self.add_expression(Expression::Override(o), meta)?
1053                            } else {
1054                                var.expr
1055                            }
1056                        } else {
1057                            var.expr
1058                        }
1059                    }
1060                }
1061            },
1062            HirExprKind::Call(ref call) if pos != ExprPos::Lhs => {
1063                let maybe_expr = frontend.function_or_constructor_call(
1064                    self,
1065                    stmt,
1066                    call.kind.clone(),
1067                    &call.args,
1068                    meta,
1069                )?;
1070                return Ok((maybe_expr, meta));
1071            }
1072            // `HirExprKind::Conditional` represents the ternary operator in glsl (`:?`)
1073            //
1074            // The ternary operator is defined to only evaluate one of the two possible
1075            // expressions which means that it's behavior is that of an `if` statement,
1076            // and it's merely syntactic sugar for it.
1077            HirExprKind::Conditional {
1078                condition,
1079                accept,
1080                reject,
1081            } if ExprPos::Lhs != pos => {
1082                // Given an expression `a ? b : c`, we need to produce a Naga
1083                // statement roughly like:
1084                //
1085                //     var temp;
1086                //     if a {
1087                //         temp = convert(b);
1088                //     } else  {
1089                //         temp = convert(c);
1090                //     }
1091                //
1092                // where `convert` stands for type conversions to bring `b` and `c` to
1093                // the same type, and then use `temp` to represent the value of the whole
1094                // conditional expression in subsequent code.
1095
1096                // Lower the condition first to the current bodyy
1097                let condition = self
1098                    .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs)?
1099                    .0;
1100
1101                let (mut accept_body, (mut accept, accept_meta)) =
1102                    self.new_body_with_ret(|ctx| {
1103                        // Lower the `true` branch
1104                        ctx.lower_expect_inner(stmt, frontend, accept, pos)
1105                    })?;
1106
1107                let (mut reject_body, (mut reject, reject_meta)) =
1108                    self.new_body_with_ret(|ctx| {
1109                        // Lower the `false` branch
1110                        ctx.lower_expect_inner(stmt, frontend, reject, pos)
1111                    })?;
1112
1113                // We need to do some custom implicit conversions since the two target expressions
1114                // are in different bodies
1115                if let (Some((accept_power, accept_scalar)), Some((reject_power, reject_scalar))) = (
1116                    // Get the components of both branches and calculate the type power
1117                    self.expr_scalar_components(accept, accept_meta)?
1118                        .and_then(|scalar| Some((type_power(scalar)?, scalar))),
1119                    self.expr_scalar_components(reject, reject_meta)?
1120                        .and_then(|scalar| Some((type_power(scalar)?, scalar))),
1121                ) {
1122                    match accept_power.cmp(&reject_power) {
1123                        core::cmp::Ordering::Less => {
1124                            accept_body = self.with_body(accept_body, |ctx| {
1125                                ctx.conversion(&mut accept, accept_meta, reject_scalar)?;
1126                                Ok(())
1127                            })?;
1128                        }
1129                        core::cmp::Ordering::Equal => {}
1130                        core::cmp::Ordering::Greater => {
1131                            reject_body = self.with_body(reject_body, |ctx| {
1132                                ctx.conversion(&mut reject, reject_meta, accept_scalar)?;
1133                                Ok(())
1134                            })?;
1135                        }
1136                    }
1137                }
1138
1139                // We need to get the type of the resulting expression to create the local,
1140                // this must be done after implicit conversions to ensure both branches have
1141                // the same type.
1142                let ty = self.resolve_type_handle(accept, accept_meta)?;
1143
1144                // Add the local that will hold the result of our conditional
1145                let local = self.locals.append(
1146                    LocalVariable {
1147                        name: None,
1148                        ty,
1149                        init: None,
1150                    },
1151                    meta,
1152                );
1153
1154                let local_expr = self.add_expression(Expression::LocalVariable(local), meta)?;
1155
1156                // Add to each  the store to the result variable
1157                accept_body.push(
1158                    Statement::Store {
1159                        pointer: local_expr,
1160                        value: accept,
1161                    },
1162                    accept_meta,
1163                );
1164                reject_body.push(
1165                    Statement::Store {
1166                        pointer: local_expr,
1167                        value: reject,
1168                    },
1169                    reject_meta,
1170                );
1171
1172                // Finally add the `If` to the main body with the `condition` we lowered
1173                // earlier and the branches we prepared.
1174                self.body.push(
1175                    Statement::If {
1176                        condition,
1177                        accept: accept_body,
1178                        reject: reject_body,
1179                    },
1180                    meta,
1181                );
1182
1183                // Note: `Expression::Load` must be emitted before it's used so make
1184                // sure the emitter is active here.
1185                self.add_expression(
1186                    Expression::Load {
1187                        pointer: local_expr,
1188                    },
1189                    meta,
1190                )?
1191            }
1192            HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => {
1193                let (pointer, ptr_meta) =
1194                    self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs)?;
1195                let (mut value, value_meta) =
1196                    self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs)?;
1197
1198                let ty = match *self.resolve_type(pointer, ptr_meta)? {
1199                    TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
1200                    ref ty => ty,
1201                };
1202
1203                if let Some(scalar) = scalar_components(ty) {
1204                    self.implicit_conversion(&mut value, value_meta, scalar)?;
1205                }
1206
1207                self.lower_store(pointer, value, meta)?;
1208
1209                value
1210            }
1211            HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => {
1212                let (pointer, _) = self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs)?;
1213                let left = if let Expression::Swizzle { .. } = self.expressions[pointer] {
1214                    pointer
1215                } else {
1216                    self.add_expression(Expression::Load { pointer }, meta)?
1217                };
1218
1219                let res = match *self.resolve_type(left, meta)? {
1220                    TypeInner::Scalar(scalar) => {
1221                        let ty = TypeInner::Scalar(scalar);
1222                        Literal::one(scalar).map(|i| (ty, i, None, None))
1223                    }
1224                    TypeInner::Vector { size, scalar } => {
1225                        let ty = TypeInner::Vector { size, scalar };
1226                        Literal::one(scalar).map(|i| (ty, i, Some(size), None))
1227                    }
1228                    TypeInner::Matrix {
1229                        columns,
1230                        rows,
1231                        scalar,
1232                    } => {
1233                        let ty = TypeInner::Matrix {
1234                            columns,
1235                            rows,
1236                            scalar,
1237                        };
1238                        Literal::one(scalar).map(|i| (ty, i, Some(rows), Some(columns)))
1239                    }
1240                    _ => None,
1241                };
1242                let (ty_inner, literal, rows, columns) = match res {
1243                    Some(res) => res,
1244                    None => {
1245                        frontend.errors.push(Error {
1246                            kind: ErrorKind::SemanticError(
1247                                "Increment/decrement only works on scalar/vector/matrix".into(),
1248                            ),
1249                            meta,
1250                        });
1251                        return Ok((Some(left), meta));
1252                    }
1253                };
1254
1255                let mut right = self.add_expression(Expression::Literal(literal), meta)?;
1256
1257                // Glsl allows pre/postfixes operations on vectors and matrices, so if the
1258                // target is either of them change the right side of the addition to be splatted
1259                // to the same size as the target, furthermore if the target is a matrix
1260                // use a composed matrix using the splatted value.
1261                if let Some(size) = rows {
1262                    right = self.add_expression(Expression::Splat { size, value: right }, meta)?;
1263
1264                    if let Some(cols) = columns {
1265                        let ty = self.module.types.insert(
1266                            Type {
1267                                name: None,
1268                                inner: ty_inner,
1269                            },
1270                            meta,
1271                        );
1272
1273                        right = self.add_expression(
1274                            Expression::Compose {
1275                                ty,
1276                                components: core::iter::repeat_n(right, cols as usize).collect(),
1277                            },
1278                            meta,
1279                        )?;
1280                    }
1281                }
1282
1283                let value = self.add_expression(Expression::Binary { op, left, right }, meta)?;
1284
1285                self.lower_store(pointer, value, meta)?;
1286
1287                if postfix {
1288                    left
1289                } else {
1290                    value
1291                }
1292            }
1293            HirExprKind::Method {
1294                expr: object,
1295                ref name,
1296                ref args,
1297            } if ExprPos::Lhs != pos => {
1298                let args = args
1299                    .iter()
1300                    .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs))
1301                    .collect::<Result<Vec<_>>>()?;
1302                match name.as_ref() {
1303                    "length" => {
1304                        if !args.is_empty() {
1305                            frontend.errors.push(Error {
1306                                kind: ErrorKind::SemanticError(
1307                                    ".length() doesn't take any arguments".into(),
1308                                ),
1309                                meta,
1310                            });
1311                        }
1312                        let lowered_array = self.lower_expect_inner(stmt, frontend, object, pos)?.0;
1313                        let array_type = self.resolve_type(lowered_array, meta)?;
1314
1315                        match *array_type {
1316                            TypeInner::Array {
1317                                size: crate::ArraySize::Constant(size),
1318                                ..
1319                            } => {
1320                                let mut array_length = self.add_expression(
1321                                    Expression::Literal(Literal::U32(size.get())),
1322                                    meta,
1323                                )?;
1324                                self.forced_conversion(&mut array_length, meta, Scalar::I32)?;
1325                                array_length
1326                            }
1327                            // let the error be handled in type checking if it's not a dynamic array
1328                            _ => {
1329                                let mut array_length = self
1330                                    .add_expression(Expression::ArrayLength(lowered_array), meta)?;
1331                                self.conversion(&mut array_length, meta, Scalar::I32)?;
1332                                array_length
1333                            }
1334                        }
1335                    }
1336
1337                    _ => {
1338                        return Err(Error {
1339                            kind: ErrorKind::SemanticError(
1340                                format!("unknown method '{name}'").into(),
1341                            ),
1342                            meta,
1343                        });
1344                    }
1345                }
1346            }
1347            HirExprKind::Sequence { ref exprs } if pos != ExprPos::Lhs => {
1348                let mut last_handle = None;
1349                for expr in exprs.iter() {
1350                    let (handle, _) =
1351                        self.lower_expect_inner(stmt, frontend, *expr, ExprPos::Rhs)?;
1352                    last_handle = Some(handle);
1353                }
1354                match last_handle {
1355                    Some(handle) => handle,
1356                    None => unreachable!(),
1357                }
1358            }
1359            _ => {
1360                return Err(Error {
1361                    kind: ErrorKind::SemanticError(
1362                        format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr])
1363                            .into(),
1364                    ),
1365                    meta,
1366                })
1367            }
1368        };
1369
1370        log::trace!("Lowered {expr:?}\n\tKind = {kind:?}\n\tPos = {pos:?}\n\tResult = {handle:?}");
1371
1372        Ok((Some(handle), meta))
1373    }
1374
1375    pub fn expr_scalar_components(
1376        &mut self,
1377        expr: Handle<Expression>,
1378        meta: Span,
1379    ) -> Result<Option<Scalar>> {
1380        let ty = self.resolve_type(expr, meta)?;
1381        Ok(scalar_components(ty))
1382    }
1383
1384    pub fn expr_power(&mut self, expr: Handle<Expression>, meta: Span) -> Result<Option<u32>> {
1385        Ok(self
1386            .expr_scalar_components(expr, meta)?
1387            .and_then(type_power))
1388    }
1389
1390    pub fn conversion(
1391        &mut self,
1392        expr: &mut Handle<Expression>,
1393        meta: Span,
1394        scalar: Scalar,
1395    ) -> Result<()> {
1396        *expr = self.add_expression(
1397            Expression::As {
1398                expr: *expr,
1399                kind: scalar.kind,
1400                convert: Some(scalar.width),
1401            },
1402            meta,
1403        )?;
1404
1405        Ok(())
1406    }
1407
1408    pub fn implicit_conversion(
1409        &mut self,
1410        expr: &mut Handle<Expression>,
1411        meta: Span,
1412        scalar: Scalar,
1413    ) -> Result<()> {
1414        if let (Some(tgt_power), Some(expr_power)) =
1415            (type_power(scalar), self.expr_power(*expr, meta)?)
1416        {
1417            if tgt_power > expr_power {
1418                self.conversion(expr, meta, scalar)?;
1419            }
1420        }
1421
1422        Ok(())
1423    }
1424
1425    pub fn forced_conversion(
1426        &mut self,
1427        expr: &mut Handle<Expression>,
1428        meta: Span,
1429        scalar: Scalar,
1430    ) -> Result<()> {
1431        if let Some(expr_scalar) = self.expr_scalar_components(*expr, meta)? {
1432            if expr_scalar != scalar {
1433                self.conversion(expr, meta, scalar)?;
1434            }
1435        }
1436
1437        Ok(())
1438    }
1439
1440    pub fn binary_implicit_conversion(
1441        &mut self,
1442        left: &mut Handle<Expression>,
1443        left_meta: Span,
1444        right: &mut Handle<Expression>,
1445        right_meta: Span,
1446    ) -> Result<()> {
1447        let left_components = self.expr_scalar_components(*left, left_meta)?;
1448        let right_components = self.expr_scalar_components(*right, right_meta)?;
1449
1450        if let (Some((left_power, left_scalar)), Some((right_power, right_scalar))) = (
1451            left_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1452            right_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1453        ) {
1454            match left_power.cmp(&right_power) {
1455                core::cmp::Ordering::Less => {
1456                    self.conversion(left, left_meta, right_scalar)?;
1457                }
1458                core::cmp::Ordering::Equal => {}
1459                core::cmp::Ordering::Greater => {
1460                    self.conversion(right, right_meta, left_scalar)?;
1461                }
1462            }
1463        }
1464
1465        Ok(())
1466    }
1467
1468    pub fn implicit_splat(
1469        &mut self,
1470        expr: &mut Handle<Expression>,
1471        meta: Span,
1472        vector_size: Option<VectorSize>,
1473    ) -> Result<()> {
1474        let expr_type = self.resolve_type(*expr, meta)?;
1475
1476        if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) {
1477            *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta)?
1478        }
1479
1480        Ok(())
1481    }
1482
1483    pub fn vector_resize(
1484        &mut self,
1485        size: VectorSize,
1486        vector: Handle<Expression>,
1487        meta: Span,
1488    ) -> Result<Handle<Expression>> {
1489        self.add_expression(
1490            Expression::Swizzle {
1491                size,
1492                vector,
1493                pattern: crate::SwizzleComponent::XYZW,
1494            },
1495            meta,
1496        )
1497    }
1498}
1499
1500impl Index<Handle<Expression>> for Context<'_> {
1501    type Output = Expression;
1502
1503    fn index(&self, index: Handle<Expression>) -> &Self::Output {
1504        if self.is_const {
1505            &self.module.global_expressions[index]
1506        } else {
1507            &self.expressions[index]
1508        }
1509    }
1510}
1511
1512/// Helper struct passed when parsing expressions
1513///
1514/// This struct should only be obtained through [`stmt_ctx`](Context::stmt_ctx)
1515/// and only one of these may be active at any time per context.
1516#[derive(Debug)]
1517pub struct StmtContext {
1518    /// A arena of high level expressions which can be lowered through a
1519    /// [`Context`] to Naga's [`Expression`]s
1520    pub hir_exprs: Arena<HirExpr>,
1521}
1522
1523impl StmtContext {
1524    const fn new() -> Self {
1525        StmtContext {
1526            hir_exprs: Arena::new(),
1527        }
1528    }
1529}