naga/front/glsl/
context.rs

1use alloc::{format, string::String, vec::Vec};
2use core::ops::Index;
3
4use super::{
5    ast::{
6        GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier,
7        VariableReference,
8    },
9    error::{Error, ErrorKind},
10    types::{scalar_components, type_power},
11    Frontend, Result,
12};
13use crate::{
14    front::Typifier, proc::Emitter, proc::Layouter, AddressSpace, Arena, BinaryOperator, Block,
15    Expression, FastHashMap, FunctionArgument, Handle, Literal, LocalVariable, RelationalFunction,
16    Scalar, Span, Statement, Type, TypeInner, VectorSize,
17};
18
/// The position at which an expression is, used while lowering
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ExprPos {
    /// The expression is in the left hand side of an assignment
    Lhs,
    /// The expression is in the right hand side of an assignment
    Rhs,
    /// The expression is an array being indexed, needed to allow constant
    /// arrays to be dynamically indexed
    AccessBase {
        /// The index is a constant
        constant_index: bool,
    },
}

impl ExprPos {
    /// Returns the position in which the base of an indexing expression
    /// should be lowered, given that the indexing expression itself is at
    /// position `self` and whether its index is constant.
    ///
    /// - [`Lhs`](Self::Lhs) is returned unchanged, so the base is still
    ///   lowered in a left hand side position.
    /// - `AccessBase { constant_index: false }` is returned unchanged: once an
    ///   enclosing access in the chain uses a dynamic index, every base further
    ///   down the chain is treated as dynamically indexed too, even if its own
    ///   index happens to be constant.
    /// - Any other position becomes `AccessBase { constant_index }`.
    const fn maybe_access_base(&self, constant_index: bool) -> Self {
        match *self {
            ExprPos::Lhs
            | ExprPos::AccessBase {
                constant_index: false,
            } => *self,
            _ => ExprPos::AccessBase { constant_index },
        }
    }
}
46
/// Lowering state used by the GLSL frontend while building Naga IR for one
/// function body or, when `is_const` is set, for module-level constant
/// expressions.
#[derive(Debug)]
pub(crate) struct Context<'a> {
    /// Expressions of the function being built; `add_expression` appends here
    /// whenever `is_const` is not set.
    pub expressions: Arena<Expression>,
    /// Local variables of the function being built. `add_function_arg` also
    /// appends here when it copies a writable by-value parameter into a local.
    pub locals: Arena<LocalVariable>,

    /// The [`FunctionArgument`]s for the final [`crate::Function`].
    ///
    /// Parameters with the `out` and `inout` qualifiers have [`Pointer`] types
    /// here. For example, an `inout vec2 a` argument would be a [`Pointer`] to
    /// a [`Vector`].
    ///
    /// [`Pointer`]: crate::TypeInner::Pointer
    /// [`Vector`]: crate::TypeInner::Vector
    pub arguments: Vec<FunctionArgument>,

    /// The parameter types given in the source code.
    ///
    /// The `out` and `inout` qualifiers don't affect the types that appear
    /// here. For example, an `inout vec2 a` argument would simply be a
    /// [`Vector`], not a pointer to one.
    ///
    /// [`Vector`]: crate::TypeInner::Vector
    pub parameters: Vec<Handle<Type>>,
    /// Per-parameter qualifier information, pushed alongside `parameters`.
    pub parameters_info: Vec<ParameterInfo>,

    /// Maps identifiers (globals, locals, arguments) to the expressions that
    /// reference them.
    pub symbol_table: crate::front::SymbolTable<String, VariableReference>,
    /// NOTE(review): not used in this part of the file; appears to associate
    /// an expression with another (presumably image -> sampler) — confirm
    /// against its users.
    pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>,

    /// Type resolution cache kept separate from `typifier`.
    /// NOTE(review): presumably used for constant (module-level)
    /// expressions — confirm against its users.
    pub const_typifier: Typifier,
    /// Type resolution cache for expressions.
    /// NOTE(review): presumably covers `expressions` — confirm.
    pub typifier: Typifier,
    /// Layout cache handed to the constant evaluator in `add_expression`.
    layouter: Layouter,
    /// Expression emitter driven by `emit_start` / `emit_end`.
    emitter: Emitter,
    /// The reusable statement context; `None` while it has been taken by
    /// `stmt_ctx()` and not yet handed back by `lower` / `lower_expect`.
    stmt_ctx: Option<StmtContext>,
    /// The block that statements are currently appended to; temporarily
    /// swapped out by `new_body` / `with_body`.
    pub body: Block,
    /// The module being built.
    pub module: &'a mut crate::Module,
    /// When set, `add_expression` evaluates in module context
    /// (`ConstantEvaluator::for_glsl_module`) instead of into `expressions`.
    pub is_const: bool,
    /// Tracks the expression kind of `Expression`s residing in `self.expressions`
    pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker,
    /// Tracks the expression kind of `Expression`s residing in `self.module.global_expressions`
    pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
}
88
89impl<'a> Context<'a> {
90    pub fn new(
91        frontend: &Frontend,
92        module: &'a mut crate::Module,
93        is_const: bool,
94        global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
95    ) -> Result<Self> {
96        let mut this = Context {
97            expressions: Arena::new(),
98            locals: Arena::new(),
99            arguments: Vec::new(),
100
101            parameters: Vec::new(),
102            parameters_info: Vec::new(),
103
104            symbol_table: crate::front::SymbolTable::default(),
105            samplers: FastHashMap::default(),
106
107            const_typifier: Typifier::new(),
108            typifier: Typifier::new(),
109            layouter: Layouter::default(),
110            emitter: Emitter::default(),
111            stmt_ctx: Some(StmtContext::new()),
112            body: Block::new(),
113            module,
114            is_const: false,
115            local_expression_kind_tracker: crate::proc::ExpressionKindTracker::new(),
116            global_expression_kind_tracker,
117        };
118
119        this.emit_start();
120
121        for &(ref name, lookup) in frontend.global_variables.iter() {
122            this.add_global(name, lookup)?
123        }
124        this.is_const = is_const;
125
126        Ok(this)
127    }
128
129    pub fn new_body<F>(&mut self, cb: F) -> Result<Block>
130    where
131        F: FnOnce(&mut Self) -> Result<()>,
132    {
133        self.new_body_with_ret(cb).map(|(b, _)| b)
134    }
135
136    pub fn new_body_with_ret<F, R>(&mut self, cb: F) -> Result<(Block, R)>
137    where
138        F: FnOnce(&mut Self) -> Result<R>,
139    {
140        self.emit_restart();
141        let old_body = core::mem::replace(&mut self.body, Block::new());
142        let res = cb(self);
143        self.emit_restart();
144        let new_body = core::mem::replace(&mut self.body, old_body);
145        res.map(|r| (new_body, r))
146    }
147
148    pub fn with_body<F>(&mut self, body: Block, cb: F) -> Result<Block>
149    where
150        F: FnOnce(&mut Self) -> Result<()>,
151    {
152        self.emit_restart();
153        let old_body = core::mem::replace(&mut self.body, body);
154        let res = cb(self);
155        self.emit_restart();
156        let body = core::mem::replace(&mut self.body, old_body);
157        res.map(|_| body)
158    }
159
160    pub fn add_global(
161        &mut self,
162        name: &str,
163        GlobalLookup {
164            kind,
165            entry_arg,
166            mutable,
167        }: GlobalLookup,
168    ) -> Result<()> {
169        let (expr, load, constant) = match kind {
170            GlobalLookupKind::Variable(v) => {
171                let span = self.module.global_variables.get_span(v);
172                (
173                    self.add_expression(Expression::GlobalVariable(v), span)?,
174                    self.module.global_variables[v].space != AddressSpace::Handle,
175                    None,
176                )
177            }
178            GlobalLookupKind::BlockSelect(handle, index) => {
179                let span = self.module.global_variables.get_span(handle);
180                let base = self.add_expression(Expression::GlobalVariable(handle), span)?;
181                let expr = self.add_expression(Expression::AccessIndex { base, index }, span)?;
182
183                (
184                    expr,
185                    {
186                        let ty = self.module.global_variables[handle].ty;
187
188                        match self.module.types[ty].inner {
189                            TypeInner::Struct { ref members, .. } => {
190                                if let TypeInner::Array {
191                                    size: crate::ArraySize::Dynamic,
192                                    ..
193                                } = self.module.types[members[index as usize].ty].inner
194                                {
195                                    false
196                                } else {
197                                    true
198                                }
199                            }
200                            _ => true,
201                        }
202                    },
203                    None,
204                )
205            }
206            GlobalLookupKind::Constant(v, ty) => {
207                let span = self.module.constants.get_span(v);
208                (
209                    self.add_expression(Expression::Constant(v), span)?,
210                    false,
211                    Some((v, ty)),
212                )
213            }
214            GlobalLookupKind::Override(v, _ty) => {
215                let span = self.module.overrides.get_span(v);
216                (
217                    self.add_expression(Expression::Override(v), span)?,
218                    false,
219                    None,
220                )
221            }
222        };
223
224        let var = VariableReference {
225            expr,
226            load,
227            mutable,
228            constant,
229            entry_arg,
230        };
231
232        self.symbol_table.add(name.into(), var);
233
234        Ok(())
235    }
236
237    /// Starts the expression emitter
238    ///
239    /// # Panics
240    ///
241    /// - If called twice in a row without calling [`emit_end`][Self::emit_end].
242    #[inline]
243    pub fn emit_start(&mut self) {
244        self.emitter.start(&self.expressions)
245    }
246
247    /// Emits all the expressions captured by the emitter to the current body
248    ///
249    /// # Panics
250    ///
251    /// - If called before calling [`emit_start`].
252    /// - If called twice in a row without calling [`emit_start`].
253    ///
254    /// [`emit_start`]: Self::emit_start
255    pub fn emit_end(&mut self) {
256        self.body.extend(self.emitter.finish(&self.expressions))
257    }
258
259    /// Emits all the expressions captured by the emitter to the current body
260    /// and starts the emitter again
261    ///
262    /// # Panics
263    ///
264    /// - If called before calling [`emit_start`][Self::emit_start].
265    pub fn emit_restart(&mut self) {
266        self.emit_end();
267        self.emit_start()
268    }
269
270    pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
271        let mut eval = if self.is_const {
272            crate::proc::ConstantEvaluator::for_glsl_module(
273                self.module,
274                self.global_expression_kind_tracker,
275                &mut self.layouter,
276            )
277        } else {
278            crate::proc::ConstantEvaluator::for_glsl_function(
279                self.module,
280                &mut self.expressions,
281                &mut self.local_expression_kind_tracker,
282                &mut self.layouter,
283                &mut self.emitter,
284                &mut self.body,
285            )
286        };
287
288        eval.try_eval_and_append(expr, meta).map_err(|e| Error {
289            kind: e.into(),
290            meta,
291        })
292    }
293
294    /// Add variable to current scope
295    ///
296    /// Returns a variable if a variable with the same name was already defined,
297    /// otherwise returns `None`
298    pub fn add_local_var(
299        &mut self,
300        name: String,
301        expr: Handle<Expression>,
302        mutable: bool,
303    ) -> Option<VariableReference> {
304        let var = VariableReference {
305            expr,
306            load: true,
307            mutable,
308            constant: None,
309            entry_arg: None,
310        };
311
312        self.symbol_table.add(name, var)
313    }
314
315    /// Add function argument to current scope
316    pub fn add_function_arg(
317        &mut self,
318        name_meta: Option<(String, Span)>,
319        ty: Handle<Type>,
320        qualifier: ParameterQualifier,
321    ) -> Result<()> {
322        let index = self.arguments.len();
323        let mut arg = FunctionArgument {
324            name: name_meta.as_ref().map(|&(ref name, _)| name.clone()),
325            ty,
326            binding: None,
327        };
328        self.parameters.push(ty);
329
330        let opaque = match self.module.types[ty].inner {
331            TypeInner::Image { .. } | TypeInner::Sampler { .. } => true,
332            _ => false,
333        };
334
335        if qualifier.is_lhs() {
336            let span = self.module.types.get_span(arg.ty);
337            arg.ty = self.module.types.insert(
338                Type {
339                    name: None,
340                    inner: TypeInner::Pointer {
341                        base: arg.ty,
342                        space: AddressSpace::Function,
343                    },
344                },
345                span,
346            )
347        }
348
349        self.arguments.push(arg);
350
351        self.parameters_info.push(ParameterInfo {
352            qualifier,
353            depth: false,
354        });
355
356        if let Some((name, meta)) = name_meta {
357            let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta)?;
358            let mutable = qualifier != ParameterQualifier::Const && !opaque;
359            let load = qualifier.is_lhs();
360
361            let var = if mutable && !load {
362                let handle = self.locals.append(
363                    LocalVariable {
364                        name: Some(name.clone()),
365                        ty,
366                        init: None,
367                    },
368                    meta,
369                );
370                let local_expr = self.add_expression(Expression::LocalVariable(handle), meta)?;
371
372                self.emit_restart();
373
374                self.body.push(
375                    Statement::Store {
376                        pointer: local_expr,
377                        value: expr,
378                    },
379                    meta,
380                );
381
382                VariableReference {
383                    expr: local_expr,
384                    load: true,
385                    mutable,
386                    constant: None,
387                    entry_arg: None,
388                }
389            } else {
390                VariableReference {
391                    expr,
392                    load,
393                    mutable,
394                    constant: None,
395                    entry_arg: None,
396                }
397            };
398
399            self.symbol_table.add(name, var);
400        }
401
402        Ok(())
403    }
404
405    /// Returns a [`StmtContext`] to be used in parsing and lowering
406    ///
407    /// # Panics
408    ///
409    /// - If more than one [`StmtContext`] are active at the same time or if the
410    ///   previous call didn't use it in lowering.
411    #[must_use]
412    pub const fn stmt_ctx(&mut self) -> StmtContext {
413        self.stmt_ctx.take().unwrap()
414    }
415
416    /// Lowers a [`HirExpr`] which might produce a [`Expression`].
417    ///
418    /// consumes a [`StmtContext`] returning it to the context so that it can be
419    /// used again later.
420    pub fn lower(
421        &mut self,
422        mut stmt: StmtContext,
423        frontend: &mut Frontend,
424        expr: Handle<HirExpr>,
425        pos: ExprPos,
426    ) -> Result<(Option<Handle<Expression>>, Span)> {
427        let res = self.lower_inner(&stmt, frontend, expr, pos);
428
429        stmt.hir_exprs.clear();
430        self.stmt_ctx = Some(stmt);
431
432        res
433    }
434
435    /// Similar to [`lower`](Self::lower) but returns an error if the expression
436    /// returns void (ie. doesn't produce a [`Expression`]).
437    ///
438    /// consumes a [`StmtContext`] returning it to the context so that it can be
439    /// used again later.
440    pub fn lower_expect(
441        &mut self,
442        mut stmt: StmtContext,
443        frontend: &mut Frontend,
444        expr: Handle<HirExpr>,
445        pos: ExprPos,
446    ) -> Result<(Handle<Expression>, Span)> {
447        let res = self.lower_expect_inner(&stmt, frontend, expr, pos);
448
449        stmt.hir_exprs.clear();
450        self.stmt_ctx = Some(stmt);
451
452        res
453    }
454
455    /// internal implementation of [`lower_expect`](Self::lower_expect)
456    ///
457    /// this method is only public because it's used in
458    /// [`function_call`](Frontend::function_call), unless you know what
459    /// you're doing use [`lower_expect`](Self::lower_expect)
460    pub fn lower_expect_inner(
461        &mut self,
462        stmt: &StmtContext,
463        frontend: &mut Frontend,
464        expr: Handle<HirExpr>,
465        pos: ExprPos,
466    ) -> Result<(Handle<Expression>, Span)> {
467        let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos)?;
468
469        let expr = match maybe_expr {
470            Some(e) => e,
471            None => {
472                return Err(Error {
473                    kind: ErrorKind::SemanticError("Expression returns void".into()),
474                    meta,
475                })
476            }
477        };
478
479        Ok((expr, meta))
480    }
481
482    fn lower_store(
483        &mut self,
484        pointer: Handle<Expression>,
485        value: Handle<Expression>,
486        meta: Span,
487    ) -> Result<()> {
488        if let Expression::Swizzle {
489            size,
490            mut vector,
491            pattern,
492        } = self.expressions[pointer]
493        {
494            // Stores to swizzled values are not directly supported,
495            // lower them as series of per-component stores.
496            let size = match size {
497                VectorSize::Bi => 2,
498                VectorSize::Tri => 3,
499                VectorSize::Quad => 4,
500            };
501
502            if let Expression::Load { pointer } = self.expressions[vector] {
503                vector = pointer;
504            }
505
506            #[allow(clippy::needless_range_loop)]
507            for index in 0..size {
508                let dst = self.add_expression(
509                    Expression::AccessIndex {
510                        base: vector,
511                        index: pattern[index].index(),
512                    },
513                    meta,
514                )?;
515                let src = self.add_expression(
516                    Expression::AccessIndex {
517                        base: value,
518                        index: index as u32,
519                    },
520                    meta,
521                )?;
522
523                self.emit_restart();
524
525                self.body.push(
526                    Statement::Store {
527                        pointer: dst,
528                        value: src,
529                    },
530                    meta,
531                );
532            }
533        } else {
534            self.emit_restart();
535
536            self.body.push(Statement::Store { pointer, value }, meta);
537        }
538
539        Ok(())
540    }
541
542    /// Internal implementation of [`lower`](Self::lower)
543    #[allow(clippy::large_stack_frames)] // TODO(https://github.com/gfx-rs/wgpu/issues/9456)
544    fn lower_inner(
545        &mut self,
546        stmt: &StmtContext,
547        frontend: &mut Frontend,
548        expr: Handle<HirExpr>,
549        pos: ExprPos,
550    ) -> Result<(Option<Handle<Expression>>, Span)> {
551        let HirExpr { ref kind, meta } = stmt.hir_exprs[expr];
552
553        log::debug!("Lowering {expr:?} (kind {kind:?}, pos {pos:?})");
554
555        let handle = match *kind {
556            HirExprKind::Access { base, index } => {
557                let (index, _) = self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs)?;
558                let maybe_constant_index = match pos {
559                    // Don't try to generate `AccessIndex` if in a LHS position, since it
560                    // wouldn't produce a pointer.
561                    ExprPos::Lhs => None,
562                    _ => self
563                        .module
564                        .to_ctx()
565                        .get_const_val_from(index, &self.expressions)
566                        .ok(),
567                };
568
569                let base = self
570                    .lower_expect_inner(
571                        stmt,
572                        frontend,
573                        base,
574                        pos.maybe_access_base(maybe_constant_index.is_some()),
575                    )?
576                    .0;
577
578                let pointer = maybe_constant_index
579                    .map(|index| self.add_expression(Expression::AccessIndex { base, index }, meta))
580                    .unwrap_or_else(|| {
581                        self.add_expression(Expression::Access { base, index }, meta)
582                    })?;
583
584                if ExprPos::Rhs == pos {
585                    let resolved = self.resolve_type(pointer, meta)?;
586                    if resolved.pointer_space().is_some() {
587                        return Ok((
588                            Some(self.add_expression(Expression::Load { pointer }, meta)?),
589                            meta,
590                        ));
591                    }
592                }
593
594                pointer
595            }
596            HirExprKind::Select { base, ref field } => {
597                let base = self.lower_expect_inner(stmt, frontend, base, pos)?.0;
598
599                frontend.field_selection(self, pos, base, field, meta)?
600            }
601            HirExprKind::Literal(literal) if pos != ExprPos::Lhs => {
602                self.add_expression(Expression::Literal(literal), meta)?
603            }
604            HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
605                let (mut left, left_meta) =
606                    self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs)?;
607                let (mut right, right_meta) =
608                    self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs)?;
609
610                match op {
611                    BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
612                        self.implicit_conversion(&mut right, right_meta, Scalar::U32)?
613                    }
614                    _ => self
615                        .binary_implicit_conversion(&mut left, left_meta, &mut right, right_meta)?,
616                }
617
618                self.typifier_grow(left, left_meta)?;
619                self.typifier_grow(right, right_meta)?;
620
621                let left_inner = self.get_type(left);
622                let right_inner = self.get_type(right);
623
624                match (left_inner, right_inner) {
625                    (
626                        &TypeInner::Matrix {
627                            columns: left_columns,
628                            rows: left_rows,
629                            scalar: left_scalar,
630                        },
631                        &TypeInner::Matrix {
632                            columns: right_columns,
633                            rows: right_rows,
634                            scalar: right_scalar,
635                        },
636                    ) => {
637                        let dimensions_ok = if op == BinaryOperator::Multiply {
638                            left_columns == right_rows
639                        } else {
640                            left_columns == right_columns && left_rows == right_rows
641                        };
642
643                        // Check that the two arguments have the same dimensions
644                        if !dimensions_ok || left_scalar != right_scalar {
645                            frontend.errors.push(Error {
646                                kind: ErrorKind::SemanticError(
647                                    format!(
648                                        "Cannot apply operation to {left_inner:?} and {right_inner:?}"
649                                    )
650                                    .into(),
651                                ),
652                                meta,
653                            })
654                        }
655
656                        match op {
657                            BinaryOperator::Divide => {
658                                // Naga IR doesn't support matrix division so we need to
659                                // divide the columns individually and reassemble the matrix
660                                let mut components = Vec::with_capacity(left_columns as usize);
661
662                                for index in 0..left_columns as u32 {
663                                    // Get the column vectors
664                                    let left_vector = self.add_expression(
665                                        Expression::AccessIndex { base: left, index },
666                                        meta,
667                                    )?;
668                                    let right_vector = self.add_expression(
669                                        Expression::AccessIndex { base: right, index },
670                                        meta,
671                                    )?;
672
673                                    // Divide the vectors
674                                    let column = self.add_expression(
675                                        Expression::Binary {
676                                            op,
677                                            left: left_vector,
678                                            right: right_vector,
679                                        },
680                                        meta,
681                                    )?;
682
683                                    components.push(column)
684                                }
685
686                                let ty = self.module.types.insert(
687                                    Type {
688                                        name: None,
689                                        inner: TypeInner::Matrix {
690                                            columns: left_columns,
691                                            rows: left_rows,
692                                            scalar: left_scalar,
693                                        },
694                                    },
695                                    Span::default(),
696                                );
697
698                                // Rebuild the matrix from the divided vectors
699                                self.add_expression(Expression::Compose { ty, components }, meta)?
700                            }
701                            BinaryOperator::Equal | BinaryOperator::NotEqual => {
702                                // Naga IR doesn't support matrix comparisons so we need to
703                                // compare the columns individually and then fold them together
704                                //
705                                // The folding is done using a logical and for equality and
706                                // a logical or for inequality
707                                let equals = op == BinaryOperator::Equal;
708
709                                let (op, combine, fun) = match equals {
710                                    true => (
711                                        BinaryOperator::Equal,
712                                        BinaryOperator::LogicalAnd,
713                                        RelationalFunction::All,
714                                    ),
715                                    false => (
716                                        BinaryOperator::NotEqual,
717                                        BinaryOperator::LogicalOr,
718                                        RelationalFunction::Any,
719                                    ),
720                                };
721
722                                let mut root = None;
723
724                                for index in 0..left_columns as u32 {
725                                    // Get the column vectors
726                                    let left_vector = self.add_expression(
727                                        Expression::AccessIndex { base: left, index },
728                                        meta,
729                                    )?;
730                                    let right_vector = self.add_expression(
731                                        Expression::AccessIndex { base: right, index },
732                                        meta,
733                                    )?;
734
735                                    let argument = self.add_expression(
736                                        Expression::Binary {
737                                            op,
738                                            left: left_vector,
739                                            right: right_vector,
740                                        },
741                                        meta,
742                                    )?;
743
744                                    // The result of comparing two vectors is a boolean vector
745                                    // so use a relational function like all to get a single
746                                    // boolean value
747                                    let compare = self.add_expression(
748                                        Expression::Relational { fun, argument },
749                                        meta,
750                                    )?;
751
752                                    // Fold the result
753                                    root = Some(match root {
754                                        Some(right) => self.add_expression(
755                                            Expression::Binary {
756                                                op: combine,
757                                                left: compare,
758                                                right,
759                                            },
760                                            meta,
761                                        )?,
762                                        None => compare,
763                                    });
764                                }
765
766                                root.unwrap()
767                            }
768                            _ => {
769                                self.add_expression(Expression::Binary { left, op, right }, meta)?
770                            }
771                        }
772                    }
773                    (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op {
774                        BinaryOperator::Equal | BinaryOperator::NotEqual => {
775                            let equals = op == BinaryOperator::Equal;
776
777                            let (op, fun) = match equals {
778                                true => (BinaryOperator::Equal, RelationalFunction::All),
779                                false => (BinaryOperator::NotEqual, RelationalFunction::Any),
780                            };
781
782                            let argument =
783                                self.add_expression(Expression::Binary { op, left, right }, meta)?;
784
785                            self.add_expression(Expression::Relational { fun, argument }, meta)?
786                        }
787                        _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
788                    },
789                    (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op {
790                        BinaryOperator::Add
791                        | BinaryOperator::Subtract
792                        | BinaryOperator::Divide
793                        | BinaryOperator::And
794                        | BinaryOperator::ExclusiveOr
795                        | BinaryOperator::InclusiveOr
796                        | BinaryOperator::ShiftLeft
797                        | BinaryOperator::ShiftRight => {
798                            let scalar_vector = self
799                                .add_expression(Expression::Splat { size, value: right }, meta)?;
800
801                            self.add_expression(
802                                Expression::Binary {
803                                    op,
804                                    left,
805                                    right: scalar_vector,
806                                },
807                                meta,
808                            )?
809                        }
810                        _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
811                    },
812                    (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op {
813                        BinaryOperator::Add
814                        | BinaryOperator::Subtract
815                        | BinaryOperator::Divide
816                        | BinaryOperator::And
817                        | BinaryOperator::ExclusiveOr
818                        | BinaryOperator::InclusiveOr => {
819                            let scalar_vector =
820                                self.add_expression(Expression::Splat { size, value: left }, meta)?;
821
822                            self.add_expression(
823                                Expression::Binary {
824                                    op,
825                                    left: scalar_vector,
826                                    right,
827                                },
828                                meta,
829                            )?
830                        }
831                        _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
832                    },
833                    (
834                        &TypeInner::Scalar(left_scalar),
835                        &TypeInner::Matrix {
836                            rows,
837                            columns,
838                            scalar: right_scalar,
839                        },
840                    ) => {
841                        // Check that the two arguments have the same scalar type
842                        if left_scalar != right_scalar {
843                            frontend.errors.push(Error {
844                                kind: ErrorKind::SemanticError(
845                                    format!(
846                                        "Cannot apply operation to {left_inner:?} and {right_inner:?}"
847                                    )
848                                    .into(),
849                                ),
850                                meta,
851                            })
852                        }
853
854                        match op {
855                            BinaryOperator::Divide
856                            | BinaryOperator::Add
857                            | BinaryOperator::Subtract => {
858                                // Naga IR doesn't support all matrix by scalar operations so
859                                // we need for some to turn the scalar into a vector by
860                                // splatting it and then for each column vector apply the
861                                // operation and finally reconstruct the matrix
862                                let scalar_vector = self.add_expression(
863                                    Expression::Splat {
864                                        size: rows,
865                                        value: left,
866                                    },
867                                    meta,
868                                )?;
869
870                                let mut components = Vec::with_capacity(columns as usize);
871
872                                for index in 0..columns as u32 {
873                                    // Get the column vector
874                                    let matrix_column = self.add_expression(
875                                        Expression::AccessIndex { base: right, index },
876                                        meta,
877                                    )?;
878
879                                    // Apply the operation to the splatted vector and
880                                    // the column vector
881                                    let column = self.add_expression(
882                                        Expression::Binary {
883                                            op,
884                                            left: scalar_vector,
885                                            right: matrix_column,
886                                        },
887                                        meta,
888                                    )?;
889
890                                    components.push(column)
891                                }
892
893                                let ty = self.module.types.insert(
894                                    Type {
895                                        name: None,
896                                        inner: TypeInner::Matrix {
897                                            columns,
898                                            rows,
899                                            scalar: left_scalar,
900                                        },
901                                    },
902                                    Span::default(),
903                                );
904
905                                // Rebuild the matrix from the operation result vectors
906                                self.add_expression(Expression::Compose { ty, components }, meta)?
907                            }
908                            _ => {
909                                self.add_expression(Expression::Binary { left, op, right }, meta)?
910                            }
911                        }
912                    }
913                    (
914                        &TypeInner::Matrix {
915                            rows,
916                            columns,
917                            scalar: left_scalar,
918                        },
919                        &TypeInner::Scalar(right_scalar),
920                    ) => {
921                        // Check that the two arguments have the same scalar type
922                        if left_scalar != right_scalar {
923                            frontend.errors.push(Error {
924                                kind: ErrorKind::SemanticError(
925                                    format!(
926                                        "Cannot apply operation to {left_inner:?} and {right_inner:?}"
927                                    )
928                                    .into(),
929                                ),
930                                meta,
931                            })
932                        }
933
934                        match op {
935                            BinaryOperator::Divide
936                            | BinaryOperator::Add
937                            | BinaryOperator::Subtract => {
938                                // Naga IR doesn't support all matrix by scalar operations so
939                                // we need for some to turn the scalar into a vector by
940                                // splatting it and then for each column vector apply the
941                                // operation and finally reconstruct the matrix
942
943                                let scalar_vector = self.add_expression(
944                                    Expression::Splat {
945                                        size: rows,
946                                        value: right,
947                                    },
948                                    meta,
949                                )?;
950
951                                let mut components = Vec::with_capacity(columns as usize);
952
953                                for index in 0..columns as u32 {
954                                    // Get the column vector
955                                    let matrix_column = self.add_expression(
956                                        Expression::AccessIndex { base: left, index },
957                                        meta,
958                                    )?;
959
960                                    // Apply the operation to the splatted vector and
961                                    // the column vector
962                                    let column = self.add_expression(
963                                        Expression::Binary {
964                                            op,
965                                            left: matrix_column,
966                                            right: scalar_vector,
967                                        },
968                                        meta,
969                                    )?;
970
971                                    components.push(column)
972                                }
973
974                                let ty = self.module.types.insert(
975                                    Type {
976                                        name: None,
977                                        inner: TypeInner::Matrix {
978                                            columns,
979                                            rows,
980                                            scalar: left_scalar,
981                                        },
982                                    },
983                                    Span::default(),
984                                );
985
986                                // Rebuild the matrix from the operation result vectors
987                                self.add_expression(Expression::Compose { ty, components }, meta)?
988                            }
989                            _ => {
990                                self.add_expression(Expression::Binary { left, op, right }, meta)?
991                            }
992                        }
993                    }
994                    _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
995                }
996            }
997            HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => {
998                let expr = self
999                    .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs)?
1000                    .0;
1001
1002                if let TypeInner::Matrix { scalar, .. } = *self.resolve_type(expr, meta)? {
1003                    // Naga IR doesn't support matrix negation, so we need to turn it into
1004                    // multiplication by scalar -1.
1005                    let minus_one = Literal::minus_one(scalar).ok_or_else(|| Error {
1006                        kind: ErrorKind::SemanticError(
1007                            format!("Cannot apply operator {op:?} to type {scalar:?}").into(),
1008                        ),
1009                        meta,
1010                    })?;
1011                    let lhs = self.add_expression(Expression::Literal(minus_one), meta)?;
1012                    self.add_expression(
1013                        Expression::Binary {
1014                            op: BinaryOperator::Multiply,
1015                            left: lhs,
1016                            right: expr,
1017                        },
1018                        meta,
1019                    )?
1020                } else {
1021                    self.add_expression(Expression::Unary { op, expr }, meta)?
1022                }
1023            }
1024            HirExprKind::Variable(ref var) => match pos {
1025                ExprPos::Lhs => {
1026                    if !var.mutable {
1027                        frontend.errors.push(Error {
1028                            kind: ErrorKind::SemanticError(
1029                                "Variable cannot be used in LHS position".into(),
1030                            ),
1031                            meta,
1032                        })
1033                    }
1034
1035                    var.expr
1036                }
1037                ExprPos::AccessBase { constant_index } => {
1038                    // If the index isn't constant all accesses backed by a constant base need
1039                    // to be done through a proxy local variable, since constants have a non
1040                    // pointer type which is required for dynamic indexing
1041                    if !constant_index {
1042                        if let Some((constant, ty)) = var.constant {
1043                            let init = self
1044                                .add_expression(Expression::Constant(constant), Span::default())?;
1045                            let local = self.locals.append(
1046                                LocalVariable {
1047                                    name: None,
1048                                    ty,
1049                                    init: Some(init),
1050                                },
1051                                Span::default(),
1052                            );
1053
1054                            self.add_expression(Expression::LocalVariable(local), Span::default())?
1055                        } else {
1056                            var.expr
1057                        }
1058                    } else {
1059                        var.expr
1060                    }
1061                }
1062                _ if var.load => {
1063                    self.add_expression(Expression::Load { pointer: var.expr }, meta)?
1064                }
1065                ExprPos::Rhs => {
1066                    if let Some((constant, _)) = self.is_const.then_some(var.constant).flatten() {
1067                        self.add_expression(Expression::Constant(constant), meta)?
1068                    } else {
1069                        // Check if this is an Override expression in const context
1070                        if self.is_const {
1071                            if let Expression::Override(o) = self.expressions[var.expr] {
1072                                // Need to add the Override expression to the global arena
1073                                self.add_expression(Expression::Override(o), meta)?
1074                            } else {
1075                                var.expr
1076                            }
1077                        } else {
1078                            var.expr
1079                        }
1080                    }
1081                }
1082            },
1083            HirExprKind::Call(ref call) if pos != ExprPos::Lhs => {
1084                let maybe_expr = frontend.function_or_constructor_call(
1085                    self,
1086                    stmt,
1087                    call.kind.clone(),
1088                    &call.args,
1089                    meta,
1090                )?;
1091                return Ok((maybe_expr, meta));
1092            }
1093            // `HirExprKind::Conditional` represents the ternary operator in glsl (`:?`)
1094            //
1095            // The ternary operator is defined to only evaluate one of the two possible
1096            // expressions which means that it's behavior is that of an `if` statement,
1097            // and it's merely syntactic sugar for it.
1098            HirExprKind::Conditional {
1099                condition,
1100                accept,
1101                reject,
1102            } if ExprPos::Lhs != pos => {
1103                // Given an expression `a ? b : c`, we need to produce a Naga
1104                // statement roughly like:
1105                //
1106                //     var temp;
1107                //     if a {
1108                //         temp = convert(b);
1109                //     } else  {
1110                //         temp = convert(c);
1111                //     }
1112                //
1113                // where `convert` stands for type conversions to bring `b` and `c` to
1114                // the same type, and then use `temp` to represent the value of the whole
1115                // conditional expression in subsequent code.
1116
1117                // Lower the condition first to the current bodyy
1118                let condition = self
1119                    .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs)?
1120                    .0;
1121
1122                let (mut accept_body, (mut accept, accept_meta)) =
1123                    self.new_body_with_ret(|ctx| {
1124                        // Lower the `true` branch
1125                        ctx.lower_expect_inner(stmt, frontend, accept, pos)
1126                    })?;
1127
1128                let (mut reject_body, (mut reject, reject_meta)) =
1129                    self.new_body_with_ret(|ctx| {
1130                        // Lower the `false` branch
1131                        ctx.lower_expect_inner(stmt, frontend, reject, pos)
1132                    })?;
1133
1134                // We need to do some custom implicit conversions since the two target expressions
1135                // are in different bodies
1136                if let (Some((accept_power, accept_scalar)), Some((reject_power, reject_scalar))) = (
1137                    // Get the components of both branches and calculate the type power
1138                    self.expr_scalar_components(accept, accept_meta)?
1139                        .and_then(|scalar| Some((type_power(scalar)?, scalar))),
1140                    self.expr_scalar_components(reject, reject_meta)?
1141                        .and_then(|scalar| Some((type_power(scalar)?, scalar))),
1142                ) {
1143                    match accept_power.cmp(&reject_power) {
1144                        core::cmp::Ordering::Less => {
1145                            accept_body = self.with_body(accept_body, |ctx| {
1146                                ctx.conversion(&mut accept, accept_meta, reject_scalar)?;
1147                                Ok(())
1148                            })?;
1149                        }
1150                        core::cmp::Ordering::Equal => {}
1151                        core::cmp::Ordering::Greater => {
1152                            reject_body = self.with_body(reject_body, |ctx| {
1153                                ctx.conversion(&mut reject, reject_meta, accept_scalar)?;
1154                                Ok(())
1155                            })?;
1156                        }
1157                    }
1158                }
1159
1160                // We need to get the type of the resulting expression to create the local,
1161                // this must be done after implicit conversions to ensure both branches have
1162                // the same type.
1163                let ty = self.resolve_type_handle(accept, accept_meta)?;
1164
1165                // Add the local that will hold the result of our conditional
1166                let local = self.locals.append(
1167                    LocalVariable {
1168                        name: None,
1169                        ty,
1170                        init: None,
1171                    },
1172                    meta,
1173                );
1174
1175                let local_expr = self.add_expression(Expression::LocalVariable(local), meta)?;
1176
1177                // Add to each  the store to the result variable
1178                accept_body.push(
1179                    Statement::Store {
1180                        pointer: local_expr,
1181                        value: accept,
1182                    },
1183                    accept_meta,
1184                );
1185                reject_body.push(
1186                    Statement::Store {
1187                        pointer: local_expr,
1188                        value: reject,
1189                    },
1190                    reject_meta,
1191                );
1192
1193                // Finally add the `If` to the main body with the `condition` we lowered
1194                // earlier and the branches we prepared.
1195                self.body.push(
1196                    Statement::If {
1197                        condition,
1198                        accept: accept_body,
1199                        reject: reject_body,
1200                    },
1201                    meta,
1202                );
1203
1204                // Note: `Expression::Load` must be emitted before it's used so make
1205                // sure the emitter is active here.
1206                self.add_expression(
1207                    Expression::Load {
1208                        pointer: local_expr,
1209                    },
1210                    meta,
1211                )?
1212            }
1213            HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => {
1214                let (pointer, ptr_meta) =
1215                    self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs)?;
1216                let (mut value, value_meta) =
1217                    self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs)?;
1218
1219                let ty = match *self.resolve_type(pointer, ptr_meta)? {
1220                    TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
1221                    ref ty => ty,
1222                };
1223
1224                if let Some(scalar) = scalar_components(ty) {
1225                    self.implicit_conversion(&mut value, value_meta, scalar)?;
1226                }
1227
1228                self.lower_store(pointer, value, meta)?;
1229
1230                value
1231            }
1232            HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => {
1233                let (pointer, _) = self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs)?;
1234                let left = if let Expression::Swizzle { .. } = self.expressions[pointer] {
1235                    pointer
1236                } else {
1237                    self.add_expression(Expression::Load { pointer }, meta)?
1238                };
1239
1240                let res = match *self.resolve_type(left, meta)? {
1241                    TypeInner::Scalar(scalar) => {
1242                        let ty = TypeInner::Scalar(scalar);
1243                        Literal::one(scalar).map(|i| (ty, i, None, None))
1244                    }
1245                    TypeInner::Vector { size, scalar } => {
1246                        let ty = TypeInner::Vector { size, scalar };
1247                        Literal::one(scalar).map(|i| (ty, i, Some(size), None))
1248                    }
1249                    TypeInner::Matrix {
1250                        columns,
1251                        rows,
1252                        scalar,
1253                    } => {
1254                        let ty = TypeInner::Matrix {
1255                            columns,
1256                            rows,
1257                            scalar,
1258                        };
1259                        Literal::one(scalar).map(|i| (ty, i, Some(rows), Some(columns)))
1260                    }
1261                    _ => None,
1262                };
1263                let (ty_inner, literal, rows, columns) = match res {
1264                    Some(res) => res,
1265                    None => {
1266                        frontend.errors.push(Error {
1267                            kind: ErrorKind::SemanticError(
1268                                "Increment/decrement only works on scalar/vector/matrix".into(),
1269                            ),
1270                            meta,
1271                        });
1272                        return Ok((Some(left), meta));
1273                    }
1274                };
1275
1276                let mut right = self.add_expression(Expression::Literal(literal), meta)?;
1277
1278                // Glsl allows pre/postfixes operations on vectors and matrices, so if the
1279                // target is either of them change the right side of the addition to be splatted
1280                // to the same size as the target, furthermore if the target is a matrix
1281                // use a composed matrix using the splatted value.
1282                if let Some(size) = rows {
1283                    right = self.add_expression(Expression::Splat { size, value: right }, meta)?;
1284
1285                    if let Some(cols) = columns {
1286                        let ty = self.module.types.insert(
1287                            Type {
1288                                name: None,
1289                                inner: ty_inner,
1290                            },
1291                            meta,
1292                        );
1293
1294                        right = self.add_expression(
1295                            Expression::Compose {
1296                                ty,
1297                                components: core::iter::repeat_n(right, cols as usize).collect(),
1298                            },
1299                            meta,
1300                        )?;
1301                    }
1302                }
1303
1304                let value = self.add_expression(Expression::Binary { op, left, right }, meta)?;
1305
1306                self.lower_store(pointer, value, meta)?;
1307
1308                if postfix {
1309                    left
1310                } else {
1311                    value
1312                }
1313            }
1314            HirExprKind::Method {
1315                expr: object,
1316                ref name,
1317                ref args,
1318            } if ExprPos::Lhs != pos => {
1319                let args = args
1320                    .iter()
1321                    .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs))
1322                    .collect::<Result<Vec<_>>>()?;
1323                match name.as_ref() {
1324                    "length" => {
1325                        if !args.is_empty() {
1326                            frontend.errors.push(Error {
1327                                kind: ErrorKind::SemanticError(
1328                                    ".length() doesn't take any arguments".into(),
1329                                ),
1330                                meta,
1331                            });
1332                        }
1333                        let lowered_array = self.lower_expect_inner(stmt, frontend, object, pos)?.0;
1334                        let array_type = self.resolve_type(lowered_array, meta)?;
1335
1336                        match *array_type {
1337                            TypeInner::Array {
1338                                size: crate::ArraySize::Constant(size),
1339                                ..
1340                            } => {
1341                                let mut array_length = self.add_expression(
1342                                    Expression::Literal(Literal::U32(size.get())),
1343                                    meta,
1344                                )?;
1345                                self.forced_conversion(&mut array_length, meta, Scalar::I32)?;
1346                                array_length
1347                            }
1348                            // let the error be handled in type checking if it's not a dynamic array
1349                            _ => {
1350                                let mut array_length = self
1351                                    .add_expression(Expression::ArrayLength(lowered_array), meta)?;
1352                                self.conversion(&mut array_length, meta, Scalar::I32)?;
1353                                array_length
1354                            }
1355                        }
1356                    }
1357
1358                    _ => {
1359                        return Err(Error {
1360                            kind: ErrorKind::SemanticError(
1361                                format!("unknown method '{name}'").into(),
1362                            ),
1363                            meta,
1364                        });
1365                    }
1366                }
1367            }
1368            HirExprKind::Sequence { ref exprs } if pos != ExprPos::Lhs => {
1369                let mut last_handle = None;
1370                for expr in exprs.iter() {
1371                    let (handle, _) =
1372                        self.lower_expect_inner(stmt, frontend, *expr, ExprPos::Rhs)?;
1373                    last_handle = Some(handle);
1374                }
1375                match last_handle {
1376                    Some(handle) => handle,
1377                    None => unreachable!(),
1378                }
1379            }
1380            _ => {
1381                return Err(Error {
1382                    kind: ErrorKind::SemanticError(
1383                        format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr])
1384                            .into(),
1385                    ),
1386                    meta,
1387                })
1388            }
1389        };
1390
1391        log::trace!("Lowered {expr:?}\n\tKind = {kind:?}\n\tPos = {pos:?}\n\tResult = {handle:?}");
1392
1393        Ok((Some(handle), meta))
1394    }
1395
1396    pub fn expr_scalar_components(
1397        &mut self,
1398        expr: Handle<Expression>,
1399        meta: Span,
1400    ) -> Result<Option<Scalar>> {
1401        let ty = self.resolve_type(expr, meta)?;
1402        Ok(scalar_components(ty))
1403    }
1404
1405    pub fn expr_power(&mut self, expr: Handle<Expression>, meta: Span) -> Result<Option<u32>> {
1406        Ok(self
1407            .expr_scalar_components(expr, meta)?
1408            .and_then(type_power))
1409    }
1410
1411    pub fn conversion(
1412        &mut self,
1413        expr: &mut Handle<Expression>,
1414        meta: Span,
1415        scalar: Scalar,
1416    ) -> Result<()> {
1417        *expr = self.add_expression(
1418            Expression::As {
1419                expr: *expr,
1420                kind: scalar.kind,
1421                convert: Some(scalar.width),
1422            },
1423            meta,
1424        )?;
1425
1426        Ok(())
1427    }
1428
1429    pub fn implicit_conversion(
1430        &mut self,
1431        expr: &mut Handle<Expression>,
1432        meta: Span,
1433        scalar: Scalar,
1434    ) -> Result<()> {
1435        if let (Some(tgt_power), Some(expr_power)) =
1436            (type_power(scalar), self.expr_power(*expr, meta)?)
1437        {
1438            if tgt_power > expr_power {
1439                self.conversion(expr, meta, scalar)?;
1440            }
1441        }
1442
1443        Ok(())
1444    }
1445
1446    pub fn forced_conversion(
1447        &mut self,
1448        expr: &mut Handle<Expression>,
1449        meta: Span,
1450        scalar: Scalar,
1451    ) -> Result<()> {
1452        if let Some(expr_scalar) = self.expr_scalar_components(*expr, meta)? {
1453            if expr_scalar != scalar {
1454                self.conversion(expr, meta, scalar)?;
1455            }
1456        }
1457
1458        Ok(())
1459    }
1460
1461    pub fn binary_implicit_conversion(
1462        &mut self,
1463        left: &mut Handle<Expression>,
1464        left_meta: Span,
1465        right: &mut Handle<Expression>,
1466        right_meta: Span,
1467    ) -> Result<()> {
1468        let left_components = self.expr_scalar_components(*left, left_meta)?;
1469        let right_components = self.expr_scalar_components(*right, right_meta)?;
1470
1471        if let (Some((left_power, left_scalar)), Some((right_power, right_scalar))) = (
1472            left_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1473            right_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1474        ) {
1475            match left_power.cmp(&right_power) {
1476                core::cmp::Ordering::Less => {
1477                    self.conversion(left, left_meta, right_scalar)?;
1478                }
1479                core::cmp::Ordering::Equal => {}
1480                core::cmp::Ordering::Greater => {
1481                    self.conversion(right, right_meta, left_scalar)?;
1482                }
1483            }
1484        }
1485
1486        Ok(())
1487    }
1488
1489    pub fn implicit_splat(
1490        &mut self,
1491        expr: &mut Handle<Expression>,
1492        meta: Span,
1493        vector_size: Option<VectorSize>,
1494    ) -> Result<()> {
1495        let expr_type = self.resolve_type(*expr, meta)?;
1496
1497        if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) {
1498            *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta)?
1499        }
1500
1501        Ok(())
1502    }
1503
1504    pub fn vector_resize(
1505        &mut self,
1506        size: VectorSize,
1507        vector: Handle<Expression>,
1508        meta: Span,
1509    ) -> Result<Handle<Expression>> {
1510        self.add_expression(
1511            Expression::Swizzle {
1512                size,
1513                vector,
1514                pattern: crate::SwizzleComponent::XYZW,
1515            },
1516            meta,
1517        )
1518    }
1519}
1520
1521impl Index<Handle<Expression>> for Context<'_> {
1522    type Output = Expression;
1523
1524    fn index(&self, index: Handle<Expression>) -> &Self::Output {
1525        if self.is_const {
1526            &self.module.global_expressions[index]
1527        } else {
1528            &self.expressions[index]
1529        }
1530    }
1531}
1532
1533/// Helper struct passed when parsing expressions
1534///
1535/// This struct should only be obtained through [`stmt_ctx`](Context::stmt_ctx)
1536/// and only one of these may be active at any time per context.
1537#[derive(Debug)]
1538pub struct StmtContext {
1539    /// A arena of high level expressions which can be lowered through a
1540    /// [`Context`] to Naga's [`Expression`]s
1541    pub hir_exprs: Arena<HirExpr>,
1542}
1543
1544impl StmtContext {
1545    const fn new() -> Self {
1546        StmtContext {
1547            hir_exprs: Arena::new(),
1548        }
1549    }
1550}