1use alloc::{format, string::String, vec::Vec};
2use core::ops::Index;
3
4use super::{
5 ast::{
6 GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier,
7 VariableReference,
8 },
9 error::{Error, ErrorKind},
10 types::{scalar_components, type_power},
11 Frontend, Result,
12};
13use crate::{
14 front::Typifier, proc::Emitter, proc::Layouter, AddressSpace, Arena, BinaryOperator, Block,
15 Expression, FastHashMap, FunctionArgument, Handle, Literal, LocalVariable, RelationalFunction,
16 Scalar, Span, Statement, Type, TypeInner, VectorSize,
17};
18
/// Where a lowered expression is going to be used.
///
/// The position decides whether lowering has to keep a pointer around (for
/// stores and indexing) or may produce a plain value.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ExprPos {
    /// The expression is written to (assignment target or the operand of an
    /// increment/decrement).
    Lhs,
    /// The expression is read as a value.
    Rhs,
    /// The expression is the base of an indexing operation.
    AccessBase {
        /// Whether the index applied to this base is a compile-time constant.
        constant_index: bool,
    },
}

impl ExprPos {
    /// Computes the position for the base of an access whose index is (or is
    /// not) a compile-time constant.
    ///
    /// `Lhs` and `AccessBase { constant_index: false }` are sticky: an outer
    /// store or dynamic index must still see the whole chain as addressable.
    /// Every other position becomes `AccessBase { constant_index }`.
    const fn maybe_access_base(&self, constant_index: bool) -> Self {
        let sticky = matches!(
            *self,
            ExprPos::Lhs | ExprPos::AccessBase { constant_index: false }
        );

        if sticky {
            *self
        } else {
            ExprPos::AccessBase { constant_index }
        }
    }
}
46
/// Lowering state for one function body (or constant initializer) of the
/// GLSL frontend.
///
/// Owns the function-local arenas being built and mutably borrows the module
/// that types, constants and global expressions are added to.
#[derive(Debug)]
pub struct Context<'a> {
    /// Function-local expression arena (appended to when `is_const` is false).
    pub expressions: Arena<Expression>,
    /// Function-local variables.
    pub locals: Arena<LocalVariable>,

    /// The function's arguments as they appear in the IR; LHS-qualified
    /// parameters have their type wrapped in a `Function`-space pointer.
    pub arguments: Vec<FunctionArgument>,

    /// Declared type of each parameter, before any pointer wrapping.
    pub parameters: Vec<Handle<Type>>,
    /// Qualifier information for each parameter, parallel to `parameters`.
    pub parameters_info: Vec<ParameterInfo>,

    /// Maps identifiers to the expressions they refer to.
    pub symbol_table: crate::front::SymbolTable<String, VariableReference>,
    /// Expression-to-expression map; NOTE(review): presumably texture → sampler
    /// for combined image samplers, not used in this file — confirm.
    pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>,

    /// NOTE(review): presumably the type resolver for global (constant)
    /// expressions — not used in this file, confirm.
    pub const_typifier: Typifier,
    /// NOTE(review): presumably the type resolver for `expressions` — not used
    /// in this file, confirm.
    pub typifier: Typifier,
    /// Layout cache handed to the constant evaluator in `add_expression`.
    layouter: Layouter,
    /// Tracks expression ranges that still need an `Emit` statement.
    emitter: Emitter,
    /// Cached statement context; taken by `stmt_ctx` and put back by
    /// `lower` / `lower_expect`.
    stmt_ctx: Option<StmtContext>,
    /// The statement block currently being written to.
    pub body: Block,
    /// The module being built.
    pub module: &'a mut crate::Module,
    /// When true, `add_expression` uses the module-level constant evaluator
    /// and handles refer to `module.global_expressions` (see the `Index` impl).
    pub is_const: bool,
    /// Expression-kind tracker for the local `expressions` arena.
    pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker,
    /// Expression-kind tracker for the module's global expressions.
    pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
}
88
89impl<'a> Context<'a> {
90 pub fn new(
91 frontend: &Frontend,
92 module: &'a mut crate::Module,
93 is_const: bool,
94 global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
95 ) -> Result<Self> {
96 let mut this = Context {
97 expressions: Arena::new(),
98 locals: Arena::new(),
99 arguments: Vec::new(),
100
101 parameters: Vec::new(),
102 parameters_info: Vec::new(),
103
104 symbol_table: crate::front::SymbolTable::default(),
105 samplers: FastHashMap::default(),
106
107 const_typifier: Typifier::new(),
108 typifier: Typifier::new(),
109 layouter: Layouter::default(),
110 emitter: Emitter::default(),
111 stmt_ctx: Some(StmtContext::new()),
112 body: Block::new(),
113 module,
114 is_const: false,
115 local_expression_kind_tracker: crate::proc::ExpressionKindTracker::new(),
116 global_expression_kind_tracker,
117 };
118
119 this.emit_start();
120
121 for &(ref name, lookup) in frontend.global_variables.iter() {
122 this.add_global(name, lookup)?
123 }
124 this.is_const = is_const;
125
126 Ok(this)
127 }
128
129 pub fn new_body<F>(&mut self, cb: F) -> Result<Block>
130 where
131 F: FnOnce(&mut Self) -> Result<()>,
132 {
133 self.new_body_with_ret(cb).map(|(b, _)| b)
134 }
135
136 pub fn new_body_with_ret<F, R>(&mut self, cb: F) -> Result<(Block, R)>
137 where
138 F: FnOnce(&mut Self) -> Result<R>,
139 {
140 self.emit_restart();
141 let old_body = core::mem::replace(&mut self.body, Block::new());
142 let res = cb(self);
143 self.emit_restart();
144 let new_body = core::mem::replace(&mut self.body, old_body);
145 res.map(|r| (new_body, r))
146 }
147
148 pub fn with_body<F>(&mut self, body: Block, cb: F) -> Result<Block>
149 where
150 F: FnOnce(&mut Self) -> Result<()>,
151 {
152 self.emit_restart();
153 let old_body = core::mem::replace(&mut self.body, body);
154 let res = cb(self);
155 self.emit_restart();
156 let body = core::mem::replace(&mut self.body, old_body);
157 res.map(|_| body)
158 }
159
160 pub fn add_global(
161 &mut self,
162 name: &str,
163 GlobalLookup {
164 kind,
165 entry_arg,
166 mutable,
167 }: GlobalLookup,
168 ) -> Result<()> {
169 let (expr, load, constant) = match kind {
170 GlobalLookupKind::Variable(v) => {
171 let span = self.module.global_variables.get_span(v);
172 (
173 self.add_expression(Expression::GlobalVariable(v), span)?,
174 self.module.global_variables[v].space != AddressSpace::Handle,
175 None,
176 )
177 }
178 GlobalLookupKind::BlockSelect(handle, index) => {
179 let span = self.module.global_variables.get_span(handle);
180 let base = self.add_expression(Expression::GlobalVariable(handle), span)?;
181 let expr = self.add_expression(Expression::AccessIndex { base, index }, span)?;
182
183 (
184 expr,
185 {
186 let ty = self.module.global_variables[handle].ty;
187
188 match self.module.types[ty].inner {
189 TypeInner::Struct { ref members, .. } => {
190 if let TypeInner::Array {
191 size: crate::ArraySize::Dynamic,
192 ..
193 } = self.module.types[members[index as usize].ty].inner
194 {
195 false
196 } else {
197 true
198 }
199 }
200 _ => true,
201 }
202 },
203 None,
204 )
205 }
206 GlobalLookupKind::Constant(v, ty) => {
207 let span = self.module.constants.get_span(v);
208 (
209 self.add_expression(Expression::Constant(v), span)?,
210 false,
211 Some((v, ty)),
212 )
213 }
214 GlobalLookupKind::Override(v, _ty) => {
215 let span = self.module.overrides.get_span(v);
216 (
217 self.add_expression(Expression::Override(v), span)?,
218 false,
219 None,
220 )
221 }
222 };
223
224 let var = VariableReference {
225 expr,
226 load,
227 mutable,
228 constant,
229 entry_arg,
230 };
231
232 self.symbol_table.add(name.into(), var);
233
234 Ok(())
235 }
236
237 #[inline]
243 pub fn emit_start(&mut self) {
244 self.emitter.start(&self.expressions)
245 }
246
247 pub fn emit_end(&mut self) {
256 self.body.extend(self.emitter.finish(&self.expressions))
257 }
258
259 pub fn emit_restart(&mut self) {
266 self.emit_end();
267 self.emit_start()
268 }
269
270 pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
271 let mut eval = if self.is_const {
272 crate::proc::ConstantEvaluator::for_glsl_module(
273 self.module,
274 self.global_expression_kind_tracker,
275 &mut self.layouter,
276 )
277 } else {
278 crate::proc::ConstantEvaluator::for_glsl_function(
279 self.module,
280 &mut self.expressions,
281 &mut self.local_expression_kind_tracker,
282 &mut self.layouter,
283 &mut self.emitter,
284 &mut self.body,
285 )
286 };
287
288 eval.try_eval_and_append(expr, meta).map_err(|e| Error {
289 kind: e.into(),
290 meta,
291 })
292 }
293
294 pub fn add_local_var(
299 &mut self,
300 name: String,
301 expr: Handle<Expression>,
302 mutable: bool,
303 ) -> Option<VariableReference> {
304 let var = VariableReference {
305 expr,
306 load: true,
307 mutable,
308 constant: None,
309 entry_arg: None,
310 };
311
312 self.symbol_table.add(name, var)
313 }
314
315 pub fn add_function_arg(
317 &mut self,
318 name_meta: Option<(String, Span)>,
319 ty: Handle<Type>,
320 qualifier: ParameterQualifier,
321 ) -> Result<()> {
322 let index = self.arguments.len();
323 let mut arg = FunctionArgument {
324 name: name_meta.as_ref().map(|&(ref name, _)| name.clone()),
325 ty,
326 binding: None,
327 };
328 self.parameters.push(ty);
329
330 let opaque = match self.module.types[ty].inner {
331 TypeInner::Image { .. } | TypeInner::Sampler { .. } => true,
332 _ => false,
333 };
334
335 if qualifier.is_lhs() {
336 let span = self.module.types.get_span(arg.ty);
337 arg.ty = self.module.types.insert(
338 Type {
339 name: None,
340 inner: TypeInner::Pointer {
341 base: arg.ty,
342 space: AddressSpace::Function,
343 },
344 },
345 span,
346 )
347 }
348
349 self.arguments.push(arg);
350
351 self.parameters_info.push(ParameterInfo {
352 qualifier,
353 depth: false,
354 });
355
356 if let Some((name, meta)) = name_meta {
357 let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta)?;
358 let mutable = qualifier != ParameterQualifier::Const && !opaque;
359 let load = qualifier.is_lhs();
360
361 let var = if mutable && !load {
362 let handle = self.locals.append(
363 LocalVariable {
364 name: Some(name.clone()),
365 ty,
366 init: None,
367 },
368 meta,
369 );
370 let local_expr = self.add_expression(Expression::LocalVariable(handle), meta)?;
371
372 self.emit_restart();
373
374 self.body.push(
375 Statement::Store {
376 pointer: local_expr,
377 value: expr,
378 },
379 meta,
380 );
381
382 VariableReference {
383 expr: local_expr,
384 load: true,
385 mutable,
386 constant: None,
387 entry_arg: None,
388 }
389 } else {
390 VariableReference {
391 expr,
392 load,
393 mutable,
394 constant: None,
395 entry_arg: None,
396 }
397 };
398
399 self.symbol_table.add(name, var);
400 }
401
402 Ok(())
403 }
404
405 #[must_use]
412 pub fn stmt_ctx(&mut self) -> StmtContext {
413 self.stmt_ctx.take().unwrap()
414 }
415
416 pub fn lower(
421 &mut self,
422 mut stmt: StmtContext,
423 frontend: &mut Frontend,
424 expr: Handle<HirExpr>,
425 pos: ExprPos,
426 ) -> Result<(Option<Handle<Expression>>, Span)> {
427 let res = self.lower_inner(&stmt, frontend, expr, pos);
428
429 stmt.hir_exprs.clear();
430 self.stmt_ctx = Some(stmt);
431
432 res
433 }
434
435 pub fn lower_expect(
441 &mut self,
442 mut stmt: StmtContext,
443 frontend: &mut Frontend,
444 expr: Handle<HirExpr>,
445 pos: ExprPos,
446 ) -> Result<(Handle<Expression>, Span)> {
447 let res = self.lower_expect_inner(&stmt, frontend, expr, pos);
448
449 stmt.hir_exprs.clear();
450 self.stmt_ctx = Some(stmt);
451
452 res
453 }
454
455 pub fn lower_expect_inner(
461 &mut self,
462 stmt: &StmtContext,
463 frontend: &mut Frontend,
464 expr: Handle<HirExpr>,
465 pos: ExprPos,
466 ) -> Result<(Handle<Expression>, Span)> {
467 let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos)?;
468
469 let expr = match maybe_expr {
470 Some(e) => e,
471 None => {
472 return Err(Error {
473 kind: ErrorKind::SemanticError("Expression returns void".into()),
474 meta,
475 })
476 }
477 };
478
479 Ok((expr, meta))
480 }
481
482 fn lower_store(
483 &mut self,
484 pointer: Handle<Expression>,
485 value: Handle<Expression>,
486 meta: Span,
487 ) -> Result<()> {
488 if let Expression::Swizzle {
489 size,
490 mut vector,
491 pattern,
492 } = self.expressions[pointer]
493 {
494 let size = match size {
497 VectorSize::Bi => 2,
498 VectorSize::Tri => 3,
499 VectorSize::Quad => 4,
500 };
501
502 if let Expression::Load { pointer } = self.expressions[vector] {
503 vector = pointer;
504 }
505
506 #[allow(clippy::needless_range_loop)]
507 for index in 0..size {
508 let dst = self.add_expression(
509 Expression::AccessIndex {
510 base: vector,
511 index: pattern[index].index(),
512 },
513 meta,
514 )?;
515 let src = self.add_expression(
516 Expression::AccessIndex {
517 base: value,
518 index: index as u32,
519 },
520 meta,
521 )?;
522
523 self.emit_restart();
524
525 self.body.push(
526 Statement::Store {
527 pointer: dst,
528 value: src,
529 },
530 meta,
531 );
532 }
533 } else {
534 self.emit_restart();
535
536 self.body.push(Statement::Store { pointer, value }, meta);
537 }
538
539 Ok(())
540 }
541
542 fn lower_inner(
544 &mut self,
545 stmt: &StmtContext,
546 frontend: &mut Frontend,
547 expr: Handle<HirExpr>,
548 pos: ExprPos,
549 ) -> Result<(Option<Handle<Expression>>, Span)> {
550 let HirExpr { ref kind, meta } = stmt.hir_exprs[expr];
551
552 log::debug!("Lowering {expr:?} (kind {kind:?}, pos {pos:?})");
553
554 let handle = match *kind {
555 HirExprKind::Access { base, index } => {
556 let (index, _) = self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs)?;
557 let maybe_constant_index = match pos {
558 ExprPos::Lhs => None,
561 _ => self
562 .module
563 .to_ctx()
564 .get_const_val_from(index, &self.expressions)
565 .ok(),
566 };
567
568 let base = self
569 .lower_expect_inner(
570 stmt,
571 frontend,
572 base,
573 pos.maybe_access_base(maybe_constant_index.is_some()),
574 )?
575 .0;
576
577 let pointer = maybe_constant_index
578 .map(|index| self.add_expression(Expression::AccessIndex { base, index }, meta))
579 .unwrap_or_else(|| {
580 self.add_expression(Expression::Access { base, index }, meta)
581 })?;
582
583 if ExprPos::Rhs == pos {
584 let resolved = self.resolve_type(pointer, meta)?;
585 if resolved.pointer_space().is_some() {
586 return Ok((
587 Some(self.add_expression(Expression::Load { pointer }, meta)?),
588 meta,
589 ));
590 }
591 }
592
593 pointer
594 }
595 HirExprKind::Select { base, ref field } => {
596 let base = self.lower_expect_inner(stmt, frontend, base, pos)?.0;
597
598 frontend.field_selection(self, pos, base, field, meta)?
599 }
600 HirExprKind::Literal(literal) if pos != ExprPos::Lhs => {
601 self.add_expression(Expression::Literal(literal), meta)?
602 }
603 HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
604 let (mut left, left_meta) =
605 self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs)?;
606 let (mut right, right_meta) =
607 self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs)?;
608
609 match op {
610 BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
611 self.implicit_conversion(&mut right, right_meta, Scalar::U32)?
612 }
613 _ => self
614 .binary_implicit_conversion(&mut left, left_meta, &mut right, right_meta)?,
615 }
616
617 self.typifier_grow(left, left_meta)?;
618 self.typifier_grow(right, right_meta)?;
619
620 let left_inner = self.get_type(left);
621 let right_inner = self.get_type(right);
622
623 match (left_inner, right_inner) {
624 (
625 &TypeInner::Matrix {
626 columns: left_columns,
627 rows: left_rows,
628 scalar: left_scalar,
629 },
630 &TypeInner::Matrix {
631 columns: right_columns,
632 rows: right_rows,
633 scalar: right_scalar,
634 },
635 ) => {
636 let dimensions_ok = if op == BinaryOperator::Multiply {
637 left_columns == right_rows
638 } else {
639 left_columns == right_columns && left_rows == right_rows
640 };
641
642 if !dimensions_ok || left_scalar != right_scalar {
644 frontend.errors.push(Error {
645 kind: ErrorKind::SemanticError(
646 format!(
647 "Cannot apply operation to {left_inner:?} and {right_inner:?}"
648 )
649 .into(),
650 ),
651 meta,
652 })
653 }
654
655 match op {
656 BinaryOperator::Divide => {
657 let mut components = Vec::with_capacity(left_columns as usize);
660
661 for index in 0..left_columns as u32 {
662 let left_vector = self.add_expression(
664 Expression::AccessIndex { base: left, index },
665 meta,
666 )?;
667 let right_vector = self.add_expression(
668 Expression::AccessIndex { base: right, index },
669 meta,
670 )?;
671
672 let column = self.add_expression(
674 Expression::Binary {
675 op,
676 left: left_vector,
677 right: right_vector,
678 },
679 meta,
680 )?;
681
682 components.push(column)
683 }
684
685 let ty = self.module.types.insert(
686 Type {
687 name: None,
688 inner: TypeInner::Matrix {
689 columns: left_columns,
690 rows: left_rows,
691 scalar: left_scalar,
692 },
693 },
694 Span::default(),
695 );
696
697 self.add_expression(Expression::Compose { ty, components }, meta)?
699 }
700 BinaryOperator::Equal | BinaryOperator::NotEqual => {
701 let equals = op == BinaryOperator::Equal;
707
708 let (op, combine, fun) = match equals {
709 true => (
710 BinaryOperator::Equal,
711 BinaryOperator::LogicalAnd,
712 RelationalFunction::All,
713 ),
714 false => (
715 BinaryOperator::NotEqual,
716 BinaryOperator::LogicalOr,
717 RelationalFunction::Any,
718 ),
719 };
720
721 let mut root = None;
722
723 for index in 0..left_columns as u32 {
724 let left_vector = self.add_expression(
726 Expression::AccessIndex { base: left, index },
727 meta,
728 )?;
729 let right_vector = self.add_expression(
730 Expression::AccessIndex { base: right, index },
731 meta,
732 )?;
733
734 let argument = self.add_expression(
735 Expression::Binary {
736 op,
737 left: left_vector,
738 right: right_vector,
739 },
740 meta,
741 )?;
742
743 let compare = self.add_expression(
747 Expression::Relational { fun, argument },
748 meta,
749 )?;
750
751 root = Some(match root {
753 Some(right) => self.add_expression(
754 Expression::Binary {
755 op: combine,
756 left: compare,
757 right,
758 },
759 meta,
760 )?,
761 None => compare,
762 });
763 }
764
765 root.unwrap()
766 }
767 _ => {
768 self.add_expression(Expression::Binary { left, op, right }, meta)?
769 }
770 }
771 }
772 (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op {
773 BinaryOperator::Equal | BinaryOperator::NotEqual => {
774 let equals = op == BinaryOperator::Equal;
775
776 let (op, fun) = match equals {
777 true => (BinaryOperator::Equal, RelationalFunction::All),
778 false => (BinaryOperator::NotEqual, RelationalFunction::Any),
779 };
780
781 let argument =
782 self.add_expression(Expression::Binary { op, left, right }, meta)?;
783
784 self.add_expression(Expression::Relational { fun, argument }, meta)?
785 }
786 _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
787 },
788 (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op {
789 BinaryOperator::Add
790 | BinaryOperator::Subtract
791 | BinaryOperator::Divide
792 | BinaryOperator::And
793 | BinaryOperator::ExclusiveOr
794 | BinaryOperator::InclusiveOr
795 | BinaryOperator::ShiftLeft
796 | BinaryOperator::ShiftRight => {
797 let scalar_vector = self
798 .add_expression(Expression::Splat { size, value: right }, meta)?;
799
800 self.add_expression(
801 Expression::Binary {
802 op,
803 left,
804 right: scalar_vector,
805 },
806 meta,
807 )?
808 }
809 _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
810 },
811 (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op {
812 BinaryOperator::Add
813 | BinaryOperator::Subtract
814 | BinaryOperator::Divide
815 | BinaryOperator::And
816 | BinaryOperator::ExclusiveOr
817 | BinaryOperator::InclusiveOr => {
818 let scalar_vector =
819 self.add_expression(Expression::Splat { size, value: left }, meta)?;
820
821 self.add_expression(
822 Expression::Binary {
823 op,
824 left: scalar_vector,
825 right,
826 },
827 meta,
828 )?
829 }
830 _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
831 },
832 (
833 &TypeInner::Scalar(left_scalar),
834 &TypeInner::Matrix {
835 rows,
836 columns,
837 scalar: right_scalar,
838 },
839 ) => {
840 if left_scalar != right_scalar {
842 frontend.errors.push(Error {
843 kind: ErrorKind::SemanticError(
844 format!(
845 "Cannot apply operation to {left_inner:?} and {right_inner:?}"
846 )
847 .into(),
848 ),
849 meta,
850 })
851 }
852
853 match op {
854 BinaryOperator::Divide
855 | BinaryOperator::Add
856 | BinaryOperator::Subtract => {
857 let scalar_vector = self.add_expression(
862 Expression::Splat {
863 size: rows,
864 value: left,
865 },
866 meta,
867 )?;
868
869 let mut components = Vec::with_capacity(columns as usize);
870
871 for index in 0..columns as u32 {
872 let matrix_column = self.add_expression(
874 Expression::AccessIndex { base: right, index },
875 meta,
876 )?;
877
878 let column = self.add_expression(
881 Expression::Binary {
882 op,
883 left: scalar_vector,
884 right: matrix_column,
885 },
886 meta,
887 )?;
888
889 components.push(column)
890 }
891
892 let ty = self.module.types.insert(
893 Type {
894 name: None,
895 inner: TypeInner::Matrix {
896 columns,
897 rows,
898 scalar: left_scalar,
899 },
900 },
901 Span::default(),
902 );
903
904 self.add_expression(Expression::Compose { ty, components }, meta)?
906 }
907 _ => {
908 self.add_expression(Expression::Binary { left, op, right }, meta)?
909 }
910 }
911 }
912 (
913 &TypeInner::Matrix {
914 rows,
915 columns,
916 scalar: left_scalar,
917 },
918 &TypeInner::Scalar(right_scalar),
919 ) => {
920 if left_scalar != right_scalar {
922 frontend.errors.push(Error {
923 kind: ErrorKind::SemanticError(
924 format!(
925 "Cannot apply operation to {left_inner:?} and {right_inner:?}"
926 )
927 .into(),
928 ),
929 meta,
930 })
931 }
932
933 match op {
934 BinaryOperator::Divide
935 | BinaryOperator::Add
936 | BinaryOperator::Subtract => {
937 let scalar_vector = self.add_expression(
943 Expression::Splat {
944 size: rows,
945 value: right,
946 },
947 meta,
948 )?;
949
950 let mut components = Vec::with_capacity(columns as usize);
951
952 for index in 0..columns as u32 {
953 let matrix_column = self.add_expression(
955 Expression::AccessIndex { base: left, index },
956 meta,
957 )?;
958
959 let column = self.add_expression(
962 Expression::Binary {
963 op,
964 left: matrix_column,
965 right: scalar_vector,
966 },
967 meta,
968 )?;
969
970 components.push(column)
971 }
972
973 let ty = self.module.types.insert(
974 Type {
975 name: None,
976 inner: TypeInner::Matrix {
977 columns,
978 rows,
979 scalar: left_scalar,
980 },
981 },
982 Span::default(),
983 );
984
985 self.add_expression(Expression::Compose { ty, components }, meta)?
987 }
988 _ => {
989 self.add_expression(Expression::Binary { left, op, right }, meta)?
990 }
991 }
992 }
993 _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
994 }
995 }
996 HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => {
997 let expr = self
998 .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs)?
999 .0;
1000
1001 self.add_expression(Expression::Unary { op, expr }, meta)?
1002 }
1003 HirExprKind::Variable(ref var) => match pos {
1004 ExprPos::Lhs => {
1005 if !var.mutable {
1006 frontend.errors.push(Error {
1007 kind: ErrorKind::SemanticError(
1008 "Variable cannot be used in LHS position".into(),
1009 ),
1010 meta,
1011 })
1012 }
1013
1014 var.expr
1015 }
1016 ExprPos::AccessBase { constant_index } => {
1017 if !constant_index {
1021 if let Some((constant, ty)) = var.constant {
1022 let init = self
1023 .add_expression(Expression::Constant(constant), Span::default())?;
1024 let local = self.locals.append(
1025 LocalVariable {
1026 name: None,
1027 ty,
1028 init: Some(init),
1029 },
1030 Span::default(),
1031 );
1032
1033 self.add_expression(Expression::LocalVariable(local), Span::default())?
1034 } else {
1035 var.expr
1036 }
1037 } else {
1038 var.expr
1039 }
1040 }
1041 _ if var.load => {
1042 self.add_expression(Expression::Load { pointer: var.expr }, meta)?
1043 }
1044 ExprPos::Rhs => {
1045 if let Some((constant, _)) = self.is_const.then_some(var.constant).flatten() {
1046 self.add_expression(Expression::Constant(constant), meta)?
1047 } else {
1048 if self.is_const {
1050 if let Expression::Override(o) = self.expressions[var.expr] {
1051 self.add_expression(Expression::Override(o), meta)?
1053 } else {
1054 var.expr
1055 }
1056 } else {
1057 var.expr
1058 }
1059 }
1060 }
1061 },
1062 HirExprKind::Call(ref call) if pos != ExprPos::Lhs => {
1063 let maybe_expr = frontend.function_or_constructor_call(
1064 self,
1065 stmt,
1066 call.kind.clone(),
1067 &call.args,
1068 meta,
1069 )?;
1070 return Ok((maybe_expr, meta));
1071 }
1072 HirExprKind::Conditional {
1078 condition,
1079 accept,
1080 reject,
1081 } if ExprPos::Lhs != pos => {
1082 let condition = self
1098 .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs)?
1099 .0;
1100
1101 let (mut accept_body, (mut accept, accept_meta)) =
1102 self.new_body_with_ret(|ctx| {
1103 ctx.lower_expect_inner(stmt, frontend, accept, pos)
1105 })?;
1106
1107 let (mut reject_body, (mut reject, reject_meta)) =
1108 self.new_body_with_ret(|ctx| {
1109 ctx.lower_expect_inner(stmt, frontend, reject, pos)
1111 })?;
1112
1113 if let (Some((accept_power, accept_scalar)), Some((reject_power, reject_scalar))) = (
1116 self.expr_scalar_components(accept, accept_meta)?
1118 .and_then(|scalar| Some((type_power(scalar)?, scalar))),
1119 self.expr_scalar_components(reject, reject_meta)?
1120 .and_then(|scalar| Some((type_power(scalar)?, scalar))),
1121 ) {
1122 match accept_power.cmp(&reject_power) {
1123 core::cmp::Ordering::Less => {
1124 accept_body = self.with_body(accept_body, |ctx| {
1125 ctx.conversion(&mut accept, accept_meta, reject_scalar)?;
1126 Ok(())
1127 })?;
1128 }
1129 core::cmp::Ordering::Equal => {}
1130 core::cmp::Ordering::Greater => {
1131 reject_body = self.with_body(reject_body, |ctx| {
1132 ctx.conversion(&mut reject, reject_meta, accept_scalar)?;
1133 Ok(())
1134 })?;
1135 }
1136 }
1137 }
1138
1139 let ty = self.resolve_type_handle(accept, accept_meta)?;
1143
1144 let local = self.locals.append(
1146 LocalVariable {
1147 name: None,
1148 ty,
1149 init: None,
1150 },
1151 meta,
1152 );
1153
1154 let local_expr = self.add_expression(Expression::LocalVariable(local), meta)?;
1155
1156 accept_body.push(
1158 Statement::Store {
1159 pointer: local_expr,
1160 value: accept,
1161 },
1162 accept_meta,
1163 );
1164 reject_body.push(
1165 Statement::Store {
1166 pointer: local_expr,
1167 value: reject,
1168 },
1169 reject_meta,
1170 );
1171
1172 self.body.push(
1175 Statement::If {
1176 condition,
1177 accept: accept_body,
1178 reject: reject_body,
1179 },
1180 meta,
1181 );
1182
1183 self.add_expression(
1186 Expression::Load {
1187 pointer: local_expr,
1188 },
1189 meta,
1190 )?
1191 }
1192 HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => {
1193 let (pointer, ptr_meta) =
1194 self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs)?;
1195 let (mut value, value_meta) =
1196 self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs)?;
1197
1198 let ty = match *self.resolve_type(pointer, ptr_meta)? {
1199 TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
1200 ref ty => ty,
1201 };
1202
1203 if let Some(scalar) = scalar_components(ty) {
1204 self.implicit_conversion(&mut value, value_meta, scalar)?;
1205 }
1206
1207 self.lower_store(pointer, value, meta)?;
1208
1209 value
1210 }
1211 HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => {
1212 let (pointer, _) = self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs)?;
1213 let left = if let Expression::Swizzle { .. } = self.expressions[pointer] {
1214 pointer
1215 } else {
1216 self.add_expression(Expression::Load { pointer }, meta)?
1217 };
1218
1219 let res = match *self.resolve_type(left, meta)? {
1220 TypeInner::Scalar(scalar) => {
1221 let ty = TypeInner::Scalar(scalar);
1222 Literal::one(scalar).map(|i| (ty, i, None, None))
1223 }
1224 TypeInner::Vector { size, scalar } => {
1225 let ty = TypeInner::Vector { size, scalar };
1226 Literal::one(scalar).map(|i| (ty, i, Some(size), None))
1227 }
1228 TypeInner::Matrix {
1229 columns,
1230 rows,
1231 scalar,
1232 } => {
1233 let ty = TypeInner::Matrix {
1234 columns,
1235 rows,
1236 scalar,
1237 };
1238 Literal::one(scalar).map(|i| (ty, i, Some(rows), Some(columns)))
1239 }
1240 _ => None,
1241 };
1242 let (ty_inner, literal, rows, columns) = match res {
1243 Some(res) => res,
1244 None => {
1245 frontend.errors.push(Error {
1246 kind: ErrorKind::SemanticError(
1247 "Increment/decrement only works on scalar/vector/matrix".into(),
1248 ),
1249 meta,
1250 });
1251 return Ok((Some(left), meta));
1252 }
1253 };
1254
1255 let mut right = self.add_expression(Expression::Literal(literal), meta)?;
1256
1257 if let Some(size) = rows {
1262 right = self.add_expression(Expression::Splat { size, value: right }, meta)?;
1263
1264 if let Some(cols) = columns {
1265 let ty = self.module.types.insert(
1266 Type {
1267 name: None,
1268 inner: ty_inner,
1269 },
1270 meta,
1271 );
1272
1273 right = self.add_expression(
1274 Expression::Compose {
1275 ty,
1276 components: core::iter::repeat_n(right, cols as usize).collect(),
1277 },
1278 meta,
1279 )?;
1280 }
1281 }
1282
1283 let value = self.add_expression(Expression::Binary { op, left, right }, meta)?;
1284
1285 self.lower_store(pointer, value, meta)?;
1286
1287 if postfix {
1288 left
1289 } else {
1290 value
1291 }
1292 }
1293 HirExprKind::Method {
1294 expr: object,
1295 ref name,
1296 ref args,
1297 } if ExprPos::Lhs != pos => {
1298 let args = args
1299 .iter()
1300 .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs))
1301 .collect::<Result<Vec<_>>>()?;
1302 match name.as_ref() {
1303 "length" => {
1304 if !args.is_empty() {
1305 frontend.errors.push(Error {
1306 kind: ErrorKind::SemanticError(
1307 ".length() doesn't take any arguments".into(),
1308 ),
1309 meta,
1310 });
1311 }
1312 let lowered_array = self.lower_expect_inner(stmt, frontend, object, pos)?.0;
1313 let array_type = self.resolve_type(lowered_array, meta)?;
1314
1315 match *array_type {
1316 TypeInner::Array {
1317 size: crate::ArraySize::Constant(size),
1318 ..
1319 } => {
1320 let mut array_length = self.add_expression(
1321 Expression::Literal(Literal::U32(size.get())),
1322 meta,
1323 )?;
1324 self.forced_conversion(&mut array_length, meta, Scalar::I32)?;
1325 array_length
1326 }
1327 _ => {
1329 let mut array_length = self
1330 .add_expression(Expression::ArrayLength(lowered_array), meta)?;
1331 self.conversion(&mut array_length, meta, Scalar::I32)?;
1332 array_length
1333 }
1334 }
1335 }
1336
1337 _ => {
1338 return Err(Error {
1339 kind: ErrorKind::SemanticError(
1340 format!("unknown method '{name}'").into(),
1341 ),
1342 meta,
1343 });
1344 }
1345 }
1346 }
1347 HirExprKind::Sequence { ref exprs } if pos != ExprPos::Lhs => {
1348 let mut last_handle = None;
1349 for expr in exprs.iter() {
1350 let (handle, _) =
1351 self.lower_expect_inner(stmt, frontend, *expr, ExprPos::Rhs)?;
1352 last_handle = Some(handle);
1353 }
1354 match last_handle {
1355 Some(handle) => handle,
1356 None => unreachable!(),
1357 }
1358 }
1359 _ => {
1360 return Err(Error {
1361 kind: ErrorKind::SemanticError(
1362 format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr])
1363 .into(),
1364 ),
1365 meta,
1366 })
1367 }
1368 };
1369
1370 log::trace!("Lowered {expr:?}\n\tKind = {kind:?}\n\tPos = {pos:?}\n\tResult = {handle:?}");
1371
1372 Ok((Some(handle), meta))
1373 }
1374
1375 pub fn expr_scalar_components(
1376 &mut self,
1377 expr: Handle<Expression>,
1378 meta: Span,
1379 ) -> Result<Option<Scalar>> {
1380 let ty = self.resolve_type(expr, meta)?;
1381 Ok(scalar_components(ty))
1382 }
1383
1384 pub fn expr_power(&mut self, expr: Handle<Expression>, meta: Span) -> Result<Option<u32>> {
1385 Ok(self
1386 .expr_scalar_components(expr, meta)?
1387 .and_then(type_power))
1388 }
1389
1390 pub fn conversion(
1391 &mut self,
1392 expr: &mut Handle<Expression>,
1393 meta: Span,
1394 scalar: Scalar,
1395 ) -> Result<()> {
1396 *expr = self.add_expression(
1397 Expression::As {
1398 expr: *expr,
1399 kind: scalar.kind,
1400 convert: Some(scalar.width),
1401 },
1402 meta,
1403 )?;
1404
1405 Ok(())
1406 }
1407
1408 pub fn implicit_conversion(
1409 &mut self,
1410 expr: &mut Handle<Expression>,
1411 meta: Span,
1412 scalar: Scalar,
1413 ) -> Result<()> {
1414 if let (Some(tgt_power), Some(expr_power)) =
1415 (type_power(scalar), self.expr_power(*expr, meta)?)
1416 {
1417 if tgt_power > expr_power {
1418 self.conversion(expr, meta, scalar)?;
1419 }
1420 }
1421
1422 Ok(())
1423 }
1424
1425 pub fn forced_conversion(
1426 &mut self,
1427 expr: &mut Handle<Expression>,
1428 meta: Span,
1429 scalar: Scalar,
1430 ) -> Result<()> {
1431 if let Some(expr_scalar) = self.expr_scalar_components(*expr, meta)? {
1432 if expr_scalar != scalar {
1433 self.conversion(expr, meta, scalar)?;
1434 }
1435 }
1436
1437 Ok(())
1438 }
1439
1440 pub fn binary_implicit_conversion(
1441 &mut self,
1442 left: &mut Handle<Expression>,
1443 left_meta: Span,
1444 right: &mut Handle<Expression>,
1445 right_meta: Span,
1446 ) -> Result<()> {
1447 let left_components = self.expr_scalar_components(*left, left_meta)?;
1448 let right_components = self.expr_scalar_components(*right, right_meta)?;
1449
1450 if let (Some((left_power, left_scalar)), Some((right_power, right_scalar))) = (
1451 left_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1452 right_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1453 ) {
1454 match left_power.cmp(&right_power) {
1455 core::cmp::Ordering::Less => {
1456 self.conversion(left, left_meta, right_scalar)?;
1457 }
1458 core::cmp::Ordering::Equal => {}
1459 core::cmp::Ordering::Greater => {
1460 self.conversion(right, right_meta, left_scalar)?;
1461 }
1462 }
1463 }
1464
1465 Ok(())
1466 }
1467
1468 pub fn implicit_splat(
1469 &mut self,
1470 expr: &mut Handle<Expression>,
1471 meta: Span,
1472 vector_size: Option<VectorSize>,
1473 ) -> Result<()> {
1474 let expr_type = self.resolve_type(*expr, meta)?;
1475
1476 if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) {
1477 *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta)?
1478 }
1479
1480 Ok(())
1481 }
1482
1483 pub fn vector_resize(
1484 &mut self,
1485 size: VectorSize,
1486 vector: Handle<Expression>,
1487 meta: Span,
1488 ) -> Result<Handle<Expression>> {
1489 self.add_expression(
1490 Expression::Swizzle {
1491 size,
1492 vector,
1493 pattern: crate::SwizzleComponent::XYZW,
1494 },
1495 meta,
1496 )
1497 }
1498}
1499
1500impl Index<Handle<Expression>> for Context<'_> {
1501 type Output = Expression;
1502
1503 fn index(&self, index: Handle<Expression>) -> &Self::Output {
1504 if self.is_const {
1505 &self.module.global_expressions[index]
1506 } else {
1507 &self.expressions[index]
1508 }
1509 }
1510}
1511
1512#[derive(Debug)]
1517pub struct StmtContext {
1518 pub hir_exprs: Arena<HirExpr>,
1521}
1522
1523impl StmtContext {
1524 const fn new() -> Self {
1525 StmtContext {
1526 hir_exprs: Arena::new(),
1527 }
1528 }
1529}