1use alloc::{format, string::String, vec::Vec};
2use core::ops::Index;
3
4use super::{
5 ast::{
6 GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier,
7 VariableReference,
8 },
9 error::{Error, ErrorKind},
10 types::{scalar_components, type_power},
11 Frontend, Result,
12};
13use crate::{
14 front::Typifier, proc::Emitter, proc::Layouter, AddressSpace, Arena, BinaryOperator, Block,
15 Expression, FastHashMap, FunctionArgument, Handle, Literal, LocalVariable, RelationalFunction,
16 Scalar, Span, Statement, Type, TypeInner, VectorSize,
17};
18
/// Where an expression appears, which decides how `Context::lower` turns it
/// into Naga IR.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ExprPos {
    /// The expression is written to (assignment target or increment/decrement
    /// operand), so it must lower to something that can be stored through.
    Lhs,
    /// The expression is used for its value.
    Rhs,
    /// The expression is the base of an index operation.
    AccessBase {
        /// Whether the index applied to this base evaluated to a constant `u32`.
        constant_index: bool,
    },
}

impl ExprPos {
    /// Returns the position to use when lowering the base of an index
    /// expression.
    ///
    /// `Lhs` and `AccessBase { constant_index: false }` are kept as they are.
    /// `Rhs` and `AccessBase { constant_index: true }` become
    /// `AccessBase { constant_index }` with the value passed in.
    const fn maybe_access_base(&self, constant_index: bool) -> Self {
        match *self {
            ExprPos::Rhs
            | ExprPos::AccessBase {
                constant_index: true,
            } => ExprPos::AccessBase { constant_index },
            unchanged => unchanged,
        }
    }
}
46
/// State used while lowering GLSL HIR expressions, either into the body of a
/// function or (when `is_const` is set) into module-level expressions.
#[derive(Debug)]
pub struct Context<'a> {
    /// Expressions of the function being built; `add_expression` appends here
    /// while `is_const` is `false`.
    pub expressions: Arena<Expression>,
    /// Local variables of the function being built.
    pub locals: Arena<LocalVariable>,

    /// Arguments of the function being built, in declaration order. Arguments
    /// whose qualifier `is_lhs()` are stored with a `Function`-space pointer
    /// type (see `add_function_arg`).
    pub arguments: Vec<FunctionArgument>,

    /// Declared parameter types, before any pointer wrapping.
    pub parameters: Vec<Handle<Type>>,
    /// Qualifier information for each parameter, parallel to `parameters`.
    pub parameters_info: Vec<ParameterInfo>,

    /// Names in scope and the expressions they resolve to.
    pub symbol_table: crate::front::SymbolTable<String, VariableReference>,
    /// NOTE(review): not used in this file; presumably associates image
    /// expressions with sampler expressions — confirm with its users.
    pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>,

    /// NOTE(review): not used in this file; presumably the typifier for
    /// module-level (constant) expressions — confirm.
    pub const_typifier: Typifier,
    /// Typifier for expressions; assumed to back the `resolve_type` /
    /// `typifier_grow` helpers defined elsewhere — confirm.
    pub typifier: Typifier,
    /// Passed to the constant evaluator in `add_expression`.
    layouter: Layouter,
    /// Tracks which newly added expressions must be emitted into `body`.
    emitter: Emitter,
    /// Statement context lent out by `stmt_ctx()` and returned by
    /// `lower` / `lower_expect`; `None` while lent out.
    stmt_ctx: Option<StmtContext>,
    /// Block currently receiving statements; temporarily swapped by
    /// `new_body`, `new_body_with_ret` and `with_body`.
    pub body: Block,
    /// Module being built (types, constants, globals, global expressions).
    pub module: &'a mut crate::Module,
    /// When `true`, `add_expression` appends to `module.global_expressions`
    /// and indexing a `Context` reads from there instead of `expressions`.
    pub is_const: bool,
    /// Expression-kind tracker for the function-local `expressions`.
    pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker,
    /// Expression-kind tracker for module-level expressions, used while
    /// `is_const` is set.
    pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
}
88
89impl<'a> Context<'a> {
90 pub fn new(
91 frontend: &Frontend,
92 module: &'a mut crate::Module,
93 is_const: bool,
94 global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
95 ) -> Result<Self> {
96 let mut this = Context {
97 expressions: Arena::new(),
98 locals: Arena::new(),
99 arguments: Vec::new(),
100
101 parameters: Vec::new(),
102 parameters_info: Vec::new(),
103
104 symbol_table: crate::front::SymbolTable::default(),
105 samplers: FastHashMap::default(),
106
107 const_typifier: Typifier::new(),
108 typifier: Typifier::new(),
109 layouter: Layouter::default(),
110 emitter: Emitter::default(),
111 stmt_ctx: Some(StmtContext::new()),
112 body: Block::new(),
113 module,
114 is_const: false,
115 local_expression_kind_tracker: crate::proc::ExpressionKindTracker::new(),
116 global_expression_kind_tracker,
117 };
118
119 this.emit_start();
120
121 for &(ref name, lookup) in frontend.global_variables.iter() {
122 this.add_global(name, lookup)?
123 }
124 this.is_const = is_const;
125
126 Ok(this)
127 }
128
129 pub fn new_body<F>(&mut self, cb: F) -> Result<Block>
130 where
131 F: FnOnce(&mut Self) -> Result<()>,
132 {
133 self.new_body_with_ret(cb).map(|(b, _)| b)
134 }
135
136 pub fn new_body_with_ret<F, R>(&mut self, cb: F) -> Result<(Block, R)>
137 where
138 F: FnOnce(&mut Self) -> Result<R>,
139 {
140 self.emit_restart();
141 let old_body = core::mem::replace(&mut self.body, Block::new());
142 let res = cb(self);
143 self.emit_restart();
144 let new_body = core::mem::replace(&mut self.body, old_body);
145 res.map(|r| (new_body, r))
146 }
147
148 pub fn with_body<F>(&mut self, body: Block, cb: F) -> Result<Block>
149 where
150 F: FnOnce(&mut Self) -> Result<()>,
151 {
152 self.emit_restart();
153 let old_body = core::mem::replace(&mut self.body, body);
154 let res = cb(self);
155 self.emit_restart();
156 let body = core::mem::replace(&mut self.body, old_body);
157 res.map(|_| body)
158 }
159
160 pub fn add_global(
161 &mut self,
162 name: &str,
163 GlobalLookup {
164 kind,
165 entry_arg,
166 mutable,
167 }: GlobalLookup,
168 ) -> Result<()> {
169 let (expr, load, constant) = match kind {
170 GlobalLookupKind::Variable(v) => {
171 let span = self.module.global_variables.get_span(v);
172 (
173 self.add_expression(Expression::GlobalVariable(v), span)?,
174 self.module.global_variables[v].space != AddressSpace::Handle,
175 None,
176 )
177 }
178 GlobalLookupKind::BlockSelect(handle, index) => {
179 let span = self.module.global_variables.get_span(handle);
180 let base = self.add_expression(Expression::GlobalVariable(handle), span)?;
181 let expr = self.add_expression(Expression::AccessIndex { base, index }, span)?;
182
183 (
184 expr,
185 {
186 let ty = self.module.global_variables[handle].ty;
187
188 match self.module.types[ty].inner {
189 TypeInner::Struct { ref members, .. } => {
190 if let TypeInner::Array {
191 size: crate::ArraySize::Dynamic,
192 ..
193 } = self.module.types[members[index as usize].ty].inner
194 {
195 false
196 } else {
197 true
198 }
199 }
200 _ => true,
201 }
202 },
203 None,
204 )
205 }
206 GlobalLookupKind::Constant(v, ty) => {
207 let span = self.module.constants.get_span(v);
208 (
209 self.add_expression(Expression::Constant(v), span)?,
210 false,
211 Some((v, ty)),
212 )
213 }
214 };
215
216 let var = VariableReference {
217 expr,
218 load,
219 mutable,
220 constant,
221 entry_arg,
222 };
223
224 self.symbol_table.add(name.into(), var);
225
226 Ok(())
227 }
228
229 #[inline]
235 pub fn emit_start(&mut self) {
236 self.emitter.start(&self.expressions)
237 }
238
239 pub fn emit_end(&mut self) {
248 self.body.extend(self.emitter.finish(&self.expressions))
249 }
250
251 pub fn emit_restart(&mut self) {
258 self.emit_end();
259 self.emit_start()
260 }
261
262 pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
263 let mut eval = if self.is_const {
264 crate::proc::ConstantEvaluator::for_glsl_module(
265 self.module,
266 self.global_expression_kind_tracker,
267 &mut self.layouter,
268 )
269 } else {
270 crate::proc::ConstantEvaluator::for_glsl_function(
271 self.module,
272 &mut self.expressions,
273 &mut self.local_expression_kind_tracker,
274 &mut self.layouter,
275 &mut self.emitter,
276 &mut self.body,
277 )
278 };
279
280 eval.try_eval_and_append(expr, meta).map_err(|e| Error {
281 kind: e.into(),
282 meta,
283 })
284 }
285
286 pub fn add_local_var(
291 &mut self,
292 name: String,
293 expr: Handle<Expression>,
294 mutable: bool,
295 ) -> Option<VariableReference> {
296 let var = VariableReference {
297 expr,
298 load: true,
299 mutable,
300 constant: None,
301 entry_arg: None,
302 };
303
304 self.symbol_table.add(name, var)
305 }
306
307 pub fn add_function_arg(
309 &mut self,
310 name_meta: Option<(String, Span)>,
311 ty: Handle<Type>,
312 qualifier: ParameterQualifier,
313 ) -> Result<()> {
314 let index = self.arguments.len();
315 let mut arg = FunctionArgument {
316 name: name_meta.as_ref().map(|&(ref name, _)| name.clone()),
317 ty,
318 binding: None,
319 };
320 self.parameters.push(ty);
321
322 let opaque = match self.module.types[ty].inner {
323 TypeInner::Image { .. } | TypeInner::Sampler { .. } => true,
324 _ => false,
325 };
326
327 if qualifier.is_lhs() {
328 let span = self.module.types.get_span(arg.ty);
329 arg.ty = self.module.types.insert(
330 Type {
331 name: None,
332 inner: TypeInner::Pointer {
333 base: arg.ty,
334 space: AddressSpace::Function,
335 },
336 },
337 span,
338 )
339 }
340
341 self.arguments.push(arg);
342
343 self.parameters_info.push(ParameterInfo {
344 qualifier,
345 depth: false,
346 });
347
348 if let Some((name, meta)) = name_meta {
349 let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta)?;
350 let mutable = qualifier != ParameterQualifier::Const && !opaque;
351 let load = qualifier.is_lhs();
352
353 let var = if mutable && !load {
354 let handle = self.locals.append(
355 LocalVariable {
356 name: Some(name.clone()),
357 ty,
358 init: None,
359 },
360 meta,
361 );
362 let local_expr = self.add_expression(Expression::LocalVariable(handle), meta)?;
363
364 self.emit_restart();
365
366 self.body.push(
367 Statement::Store {
368 pointer: local_expr,
369 value: expr,
370 },
371 meta,
372 );
373
374 VariableReference {
375 expr: local_expr,
376 load: true,
377 mutable,
378 constant: None,
379 entry_arg: None,
380 }
381 } else {
382 VariableReference {
383 expr,
384 load,
385 mutable,
386 constant: None,
387 entry_arg: None,
388 }
389 };
390
391 self.symbol_table.add(name, var);
392 }
393
394 Ok(())
395 }
396
397 #[must_use]
404 pub fn stmt_ctx(&mut self) -> StmtContext {
405 self.stmt_ctx.take().unwrap()
406 }
407
408 pub fn lower(
413 &mut self,
414 mut stmt: StmtContext,
415 frontend: &mut Frontend,
416 expr: Handle<HirExpr>,
417 pos: ExprPos,
418 ) -> Result<(Option<Handle<Expression>>, Span)> {
419 let res = self.lower_inner(&stmt, frontend, expr, pos);
420
421 stmt.hir_exprs.clear();
422 self.stmt_ctx = Some(stmt);
423
424 res
425 }
426
427 pub fn lower_expect(
433 &mut self,
434 mut stmt: StmtContext,
435 frontend: &mut Frontend,
436 expr: Handle<HirExpr>,
437 pos: ExprPos,
438 ) -> Result<(Handle<Expression>, Span)> {
439 let res = self.lower_expect_inner(&stmt, frontend, expr, pos);
440
441 stmt.hir_exprs.clear();
442 self.stmt_ctx = Some(stmt);
443
444 res
445 }
446
447 pub fn lower_expect_inner(
453 &mut self,
454 stmt: &StmtContext,
455 frontend: &mut Frontend,
456 expr: Handle<HirExpr>,
457 pos: ExprPos,
458 ) -> Result<(Handle<Expression>, Span)> {
459 let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos)?;
460
461 let expr = match maybe_expr {
462 Some(e) => e,
463 None => {
464 return Err(Error {
465 kind: ErrorKind::SemanticError("Expression returns void".into()),
466 meta,
467 })
468 }
469 };
470
471 Ok((expr, meta))
472 }
473
474 fn lower_store(
475 &mut self,
476 pointer: Handle<Expression>,
477 value: Handle<Expression>,
478 meta: Span,
479 ) -> Result<()> {
480 if let Expression::Swizzle {
481 size,
482 mut vector,
483 pattern,
484 } = self.expressions[pointer]
485 {
486 let size = match size {
489 VectorSize::Bi => 2,
490 VectorSize::Tri => 3,
491 VectorSize::Quad => 4,
492 };
493
494 if let Expression::Load { pointer } = self.expressions[vector] {
495 vector = pointer;
496 }
497
498 #[allow(clippy::needless_range_loop)]
499 for index in 0..size {
500 let dst = self.add_expression(
501 Expression::AccessIndex {
502 base: vector,
503 index: pattern[index].index(),
504 },
505 meta,
506 )?;
507 let src = self.add_expression(
508 Expression::AccessIndex {
509 base: value,
510 index: index as u32,
511 },
512 meta,
513 )?;
514
515 self.emit_restart();
516
517 self.body.push(
518 Statement::Store {
519 pointer: dst,
520 value: src,
521 },
522 meta,
523 );
524 }
525 } else {
526 self.emit_restart();
527
528 self.body.push(Statement::Store { pointer, value }, meta);
529 }
530
531 Ok(())
532 }
533
/// Internal implementation of [`Context::lower`].
///
/// Lowers the HIR expression `expr` (stored in `stmt.hir_exprs`) at position
/// `pos`. The resulting Naga expressions, and any statements they need, are
/// appended to the current body.
///
/// Returns the lowered expression and the span of `expr`. The expression is
/// `None` when a call produces no value.
///
/// # Errors
///
/// - Any error from lowering a sub-expression or from `add_expression`.
/// - A semantic error when a value-only expression kind appears in
///   [`ExprPos::Lhs`]. These kinds are: literal, binary or unary operation,
///   call, conditional, assignment, increment/decrement, and method call.
/// - A semantic error for unknown methods.
///
/// Some problems are pushed to `frontend.errors` instead, and lowering
/// continues:
/// - operand type mismatches;
/// - writes to immutable variables;
/// - invalid increment/decrement targets;
/// - arguments passed to `.length()`.
fn lower_inner(
    &mut self,
    stmt: &StmtContext,
    frontend: &mut Frontend,
    expr: Handle<HirExpr>,
    pos: ExprPos,
) -> Result<(Option<Handle<Expression>>, Span)> {
    let HirExpr { ref kind, meta } = stmt.hir_exprs[expr];

    log::debug!("Lowering {expr:?} (kind {kind:?}, pos {pos:?})");

    let handle = match *kind {
        HirExprKind::Access { base, index } => {
            let (index, _) = self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs)?;
            // Try to evaluate the index to a constant `u32`, so that an
            // `AccessIndex` can be used instead of a dynamic `Access`.
            // This is never attempted in LHS position.
            let maybe_constant_index = match pos {
                ExprPos::Lhs => None,
                _ => self
                    .module
                    .to_ctx()
                    .eval_expr_to_u32_from(index, &self.expressions)
                    .ok(),
            };

            // The base is lowered in an access-base position, which records
            // whether the index turned out to be constant.
            let base = self
                .lower_expect_inner(
                    stmt,
                    frontend,
                    base,
                    pos.maybe_access_base(maybe_constant_index.is_some()),
                )?
                .0;

            let pointer = maybe_constant_index
                .map(|index| self.add_expression(Expression::AccessIndex { base, index }, meta))
                .unwrap_or_else(|| {
                    self.add_expression(Expression::Access { base, index }, meta)
                })?;

            // In value position, load through the access if it yields a pointer.
            if ExprPos::Rhs == pos {
                let resolved = self.resolve_type(pointer, meta)?;
                if resolved.pointer_space().is_some() {
                    return Ok((
                        Some(self.add_expression(Expression::Load { pointer }, meta)?),
                        meta,
                    ));
                }
            }

            pointer
        }
        HirExprKind::Select { base, ref field } => {
            let base = self.lower_expect_inner(stmt, frontend, base, pos)?.0;

            frontend.field_selection(self, pos, base, field, meta)?
        }
        HirExprKind::Literal(literal) if pos != ExprPos::Lhs => {
            self.add_expression(Expression::Literal(literal), meta)?
        }
        HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
            let (mut left, left_meta) =
                self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs)?;
            let (mut right, right_meta) =
                self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs)?;

            // Shift amounts are implicitly converted towards `u32`. For all
            // other operators, both operands go to a common scalar type.
            match op {
                BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
                    self.implicit_conversion(&mut right, right_meta, Scalar::U32)?
                }
                _ => self
                    .binary_implicit_conversion(&mut left, left_meta, &mut right, right_meta)?,
            }

            self.typifier_grow(left, left_meta)?;
            self.typifier_grow(right, right_meta)?;

            let left_inner = self.get_type(left);
            let right_inner = self.get_type(right);

            match (left_inner, right_inner) {
                (
                    &TypeInner::Matrix {
                        columns: left_columns,
                        rows: left_rows,
                        scalar: left_scalar,
                    },
                    &TypeInner::Matrix {
                        columns: right_columns,
                        rows: right_rows,
                        scalar: right_scalar,
                    },
                ) => {
                    // Multiplication requires `left.columns == right.rows`.
                    // Every other operation requires identical dimensions.
                    let dimensions_ok = if op == BinaryOperator::Multiply {
                        left_columns == right_rows
                    } else {
                        left_columns == right_columns && left_rows == right_rows
                    };

                    if !dimensions_ok || left_scalar != right_scalar {
                        frontend.errors.push(Error {
                            kind: ErrorKind::SemanticError(
                                format!(
                                    "Cannot apply operation to {left_inner:?} and {right_inner:?}"
                                )
                                .into(),
                            ),
                            meta,
                        })
                    }

                    match op {
                        BinaryOperator::Divide => {
                            // Divide column by column and re-compose the matrix.
                            // NOTE(review): presumably because the IR has no
                            // matrix/matrix division.
                            let mut components = Vec::with_capacity(left_columns as usize);

                            for index in 0..left_columns as u32 {
                                let left_vector = self.add_expression(
                                    Expression::AccessIndex { base: left, index },
                                    meta,
                                )?;
                                let right_vector = self.add_expression(
                                    Expression::AccessIndex { base: right, index },
                                    meta,
                                )?;

                                let column = self.add_expression(
                                    Expression::Binary {
                                        op,
                                        left: left_vector,
                                        right: right_vector,
                                    },
                                    meta,
                                )?;

                                components.push(column)
                            }

                            let ty = self.module.types.insert(
                                Type {
                                    name: None,
                                    inner: TypeInner::Matrix {
                                        columns: left_columns,
                                        rows: left_rows,
                                        scalar: left_scalar,
                                    },
                                },
                                Span::default(),
                            );

                            self.add_expression(Expression::Compose { ty, components }, meta)?
                        }
                        BinaryOperator::Equal | BinaryOperator::NotEqual => {
                            // Compare column by column, reduce each boolean
                            // vector with `All` (==) or `Any` (!=), and fold
                            // the per-column results with `&&` (==) or `||` (!=).
                            let equals = op == BinaryOperator::Equal;

                            let (op, combine, fun) = match equals {
                                true => (
                                    BinaryOperator::Equal,
                                    BinaryOperator::LogicalAnd,
                                    RelationalFunction::All,
                                ),
                                false => (
                                    BinaryOperator::NotEqual,
                                    BinaryOperator::LogicalOr,
                                    RelationalFunction::Any,
                                ),
                            };

                            let mut root = None;

                            for index in 0..left_columns as u32 {
                                let left_vector = self.add_expression(
                                    Expression::AccessIndex { base: left, index },
                                    meta,
                                )?;
                                let right_vector = self.add_expression(
                                    Expression::AccessIndex { base: right, index },
                                    meta,
                                )?;

                                let argument = self.add_expression(
                                    Expression::Binary {
                                        op,
                                        left: left_vector,
                                        right: right_vector,
                                    },
                                    meta,
                                )?;

                                let compare = self.add_expression(
                                    Expression::Relational { fun, argument },
                                    meta,
                                )?;

                                root = Some(match root {
                                    Some(right) => self.add_expression(
                                        Expression::Binary {
                                            op: combine,
                                            left: compare,
                                            right,
                                        },
                                        meta,
                                    )?,
                                    None => compare,
                                });
                            }

                            // `left_columns` is a `VectorSize` (at least two),
                            // so the loop ran and `root` is set.
                            root.unwrap()
                        }
                        _ => {
                            self.add_expression(Expression::Binary { left, op, right }, meta)?
                        }
                    }
                }
                (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op {
                    // Vector (in)equality yields a boolean vector; reduce it
                    // to a single bool with `All` / `Any`.
                    BinaryOperator::Equal | BinaryOperator::NotEqual => {
                        let equals = op == BinaryOperator::Equal;

                        let (op, fun) = match equals {
                            true => (BinaryOperator::Equal, RelationalFunction::All),
                            false => (BinaryOperator::NotEqual, RelationalFunction::Any),
                        };

                        let argument =
                            self.add_expression(Expression::Binary { op, left, right }, meta)?;

                        self.add_expression(Expression::Relational { fun, argument }, meta)?
                    }
                    _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
                },
                // For the operators listed below, the scalar operand is
                // splatted to the vector's size first. Other operators
                // (e.g. multiplication) are emitted unchanged.
                (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op {
                    BinaryOperator::Add
                    | BinaryOperator::Subtract
                    | BinaryOperator::Divide
                    | BinaryOperator::And
                    | BinaryOperator::ExclusiveOr
                    | BinaryOperator::InclusiveOr
                    | BinaryOperator::ShiftLeft
                    | BinaryOperator::ShiftRight => {
                        let scalar_vector = self
                            .add_expression(Expression::Splat { size, value: right }, meta)?;

                        self.add_expression(
                            Expression::Binary {
                                op,
                                left,
                                right: scalar_vector,
                            },
                            meta,
                        )?
                    }
                    _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
                },
                (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op {
                    BinaryOperator::Add
                    | BinaryOperator::Subtract
                    | BinaryOperator::Divide
                    | BinaryOperator::And
                    | BinaryOperator::ExclusiveOr
                    | BinaryOperator::InclusiveOr => {
                        let scalar_vector =
                            self.add_expression(Expression::Splat { size, value: left }, meta)?;

                        self.add_expression(
                            Expression::Binary {
                                op,
                                left: scalar_vector,
                                right,
                            },
                            meta,
                        )?
                    }
                    _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
                },
                (
                    &TypeInner::Scalar(left_scalar),
                    &TypeInner::Matrix {
                        rows,
                        columns,
                        scalar: right_scalar,
                    },
                ) => {
                    if left_scalar != right_scalar {
                        frontend.errors.push(Error {
                            kind: ErrorKind::SemanticError(
                                format!(
                                    "Cannot apply operation to {left_inner:?} and {right_inner:?}"
                                )
                                .into(),
                            ),
                            meta,
                        })
                    }

                    match op {
                        // Splat the scalar to a column-sized vector, apply the
                        // operation to each column, and re-compose the matrix.
                        BinaryOperator::Divide
                        | BinaryOperator::Add
                        | BinaryOperator::Subtract => {
                            let scalar_vector = self.add_expression(
                                Expression::Splat {
                                    size: rows,
                                    value: left,
                                },
                                meta,
                            )?;

                            let mut components = Vec::with_capacity(columns as usize);

                            for index in 0..columns as u32 {
                                let matrix_column = self.add_expression(
                                    Expression::AccessIndex { base: right, index },
                                    meta,
                                )?;

                                let column = self.add_expression(
                                    Expression::Binary {
                                        op,
                                        left: scalar_vector,
                                        right: matrix_column,
                                    },
                                    meta,
                                )?;

                                components.push(column)
                            }

                            let ty = self.module.types.insert(
                                Type {
                                    name: None,
                                    inner: TypeInner::Matrix {
                                        columns,
                                        rows,
                                        scalar: left_scalar,
                                    },
                                },
                                Span::default(),
                            );

                            self.add_expression(Expression::Compose { ty, components }, meta)?
                        }
                        _ => {
                            self.add_expression(Expression::Binary { left, op, right }, meta)?
                        }
                    }
                }
                (
                    &TypeInner::Matrix {
                        rows,
                        columns,
                        scalar: left_scalar,
                    },
                    &TypeInner::Scalar(right_scalar),
                ) => {
                    if left_scalar != right_scalar {
                        frontend.errors.push(Error {
                            kind: ErrorKind::SemanticError(
                                format!(
                                    "Cannot apply operation to {left_inner:?} and {right_inner:?}"
                                )
                                .into(),
                            ),
                            meta,
                        })
                    }

                    match op {
                        // Mirror of the scalar-matrix case, with the matrix
                        // column on the left-hand side of each operation.
                        BinaryOperator::Divide
                        | BinaryOperator::Add
                        | BinaryOperator::Subtract => {
                            let scalar_vector = self.add_expression(
                                Expression::Splat {
                                    size: rows,
                                    value: right,
                                },
                                meta,
                            )?;

                            let mut components = Vec::with_capacity(columns as usize);

                            for index in 0..columns as u32 {
                                let matrix_column = self.add_expression(
                                    Expression::AccessIndex { base: left, index },
                                    meta,
                                )?;

                                let column = self.add_expression(
                                    Expression::Binary {
                                        op,
                                        left: matrix_column,
                                        right: scalar_vector,
                                    },
                                    meta,
                                )?;

                                components.push(column)
                            }

                            let ty = self.module.types.insert(
                                Type {
                                    name: None,
                                    inner: TypeInner::Matrix {
                                        columns,
                                        rows,
                                        scalar: left_scalar,
                                    },
                                },
                                Span::default(),
                            );

                            self.add_expression(Expression::Compose { ty, components }, meta)?
                        }
                        _ => {
                            self.add_expression(Expression::Binary { left, op, right }, meta)?
                        }
                    }
                }
                _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
            }
        }
        HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => {
            let expr = self
                .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs)?
                .0;

            self.add_expression(Expression::Unary { op, expr }, meta)?
        }
        HirExprKind::Variable(ref var) => match pos {
            // Writes use the variable's expression directly; writing to an
            // immutable variable is reported but lowering continues.
            ExprPos::Lhs => {
                if !var.mutable {
                    frontend.errors.push(Error {
                        kind: ErrorKind::SemanticError(
                            "Variable cannot be used in LHS position".into(),
                        ),
                        meta,
                    })
                }

                var.expr
            }
            ExprPos::AccessBase { constant_index } => {
                // A constant indexed with a non-constant index is copied into
                // a fresh local variable (initialised with the constant), and
                // that local is indexed instead.
                if !constant_index {
                    if let Some((constant, ty)) = var.constant {
                        let init = self
                            .add_expression(Expression::Constant(constant), Span::default())?;
                        let local = self.locals.append(
                            LocalVariable {
                                name: None,
                                ty,
                                init: Some(init),
                            },
                            Span::default(),
                        );

                        self.add_expression(Expression::LocalVariable(local), Span::default())?
                    } else {
                        var.expr
                    }
                } else {
                    var.expr
                }
            }
            // Only `Rhs` reaches here: load variables marked for loading.
            _ if var.load => {
                self.add_expression(Expression::Load { pointer: var.expr }, meta)?
            }
            // In constant mode, constants are re-added as `Constant`
            // expressions; otherwise the stored expression is used as is.
            ExprPos::Rhs => {
                if let Some((constant, _)) = self.is_const.then_some(var.constant).flatten() {
                    self.add_expression(Expression::Constant(constant), meta)?
                } else {
                    var.expr
                }
            }
        },
        HirExprKind::Call(ref call) if pos != ExprPos::Lhs => {
            // Calls may produce no value, so return the `Option` directly.
            let maybe_expr = frontend.function_or_constructor_call(
                self,
                stmt,
                call.kind.clone(),
                &call.args,
                meta,
            )?;
            return Ok((maybe_expr, meta));
        }
        HirExprKind::Conditional {
            condition,
            accept,
            reject,
        } if ExprPos::Lhs != pos => {
            // `c ? a : b` is lowered roughly as:
            //
            //     var tmp;
            //     if c { tmp = a; } else { tmp = b; }
            //     ... load(tmp) ...
            //
            // Each branch is lowered into its own block. When the branch types
            // differ in `type_power`, the weaker one is converted inside its
            // own block.

            // The condition is lowered into the current body.
            let condition = self
                .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs)?
                .0;

            let (mut accept_body, (mut accept, accept_meta)) =
                self.new_body_with_ret(|ctx| {
                    ctx.lower_expect_inner(stmt, frontend, accept, pos)
                })?;

            let (mut reject_body, (mut reject, reject_meta)) =
                self.new_body_with_ret(|ctx| {
                    ctx.lower_expect_inner(stmt, frontend, reject, pos)
                })?;

            if let (Some((accept_power, accept_scalar)), Some((reject_power, reject_scalar))) = (
                self.expr_scalar_components(accept, accept_meta)?
                    .and_then(|scalar| Some((type_power(scalar)?, scalar))),
                self.expr_scalar_components(reject, reject_meta)?
                    .and_then(|scalar| Some((type_power(scalar)?, scalar))),
            ) {
                match accept_power.cmp(&reject_power) {
                    core::cmp::Ordering::Less => {
                        accept_body = self.with_body(accept_body, |ctx| {
                            ctx.conversion(&mut accept, accept_meta, reject_scalar)?;
                            Ok(())
                        })?;
                    }
                    core::cmp::Ordering::Equal => {}
                    core::cmp::Ordering::Greater => {
                        reject_body = self.with_body(reject_body, |ctx| {
                            ctx.conversion(&mut reject, reject_meta, accept_scalar)?;
                            Ok(())
                        })?;
                    }
                }
            }

            // The result type is taken from the (possibly converted) accept branch.
            let ty = self.resolve_type_handle(accept, accept_meta)?;

            // Temporary local holding the conditional's result.
            let local = self.locals.append(
                LocalVariable {
                    name: None,
                    ty,
                    init: None,
                },
                meta,
            );

            let local_expr = self.add_expression(Expression::LocalVariable(local), meta)?;

            // Each branch ends by storing its value into the temporary.
            accept_body.push(
                Statement::Store {
                    pointer: local_expr,
                    value: accept,
                },
                accept_meta,
            );
            reject_body.push(
                Statement::Store {
                    pointer: local_expr,
                    value: reject,
                },
                reject_meta,
            );

            self.body.push(
                Statement::If {
                    condition,
                    accept: accept_body,
                    reject: reject_body,
                },
                meta,
            );

            // The value of the whole expression is the temporary, loaded
            // after the `If`.
            self.add_expression(
                Expression::Load {
                    pointer: local_expr,
                },
                meta,
            )?
        }
        HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => {
            let (pointer, ptr_meta) =
                self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs)?;
            let (mut value, value_meta) =
                self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs)?;

            // Implicitly convert the value towards the target's scalar type
            // (looking through a pointer type if present).
            let ty = match *self.resolve_type(pointer, ptr_meta)? {
                TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
                ref ty => ty,
            };

            if let Some(scalar) = scalar_components(ty) {
                self.implicit_conversion(&mut value, value_meta, scalar)?;
            }

            self.lower_store(pointer, value, meta)?;

            // An assignment evaluates to the stored value.
            value
        }
        HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => {
            let (pointer, _) = self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs)?;
            // A swizzle target is used directly as the current value;
            // anything else is loaded.
            let left = if let Expression::Swizzle { .. } = self.expressions[pointer] {
                pointer
            } else {
                self.add_expression(Expression::Load { pointer }, meta)?
            };

            // Build the literal `1` of the operand's scalar type and remember
            // the shape (rows / columns) it must be expanded to.
            let res = match *self.resolve_type(left, meta)? {
                TypeInner::Scalar(scalar) => {
                    let ty = TypeInner::Scalar(scalar);
                    Literal::one(scalar).map(|i| (ty, i, None, None))
                }
                TypeInner::Vector { size, scalar } => {
                    let ty = TypeInner::Vector { size, scalar };
                    Literal::one(scalar).map(|i| (ty, i, Some(size), None))
                }
                TypeInner::Matrix {
                    columns,
                    rows,
                    scalar,
                } => {
                    let ty = TypeInner::Matrix {
                        columns,
                        rows,
                        scalar,
                    };
                    Literal::one(scalar).map(|i| (ty, i, Some(rows), Some(columns)))
                }
                _ => None,
            };
            let (ty_inner, literal, rows, columns) = match res {
                Some(res) => res,
                None => {
                    frontend.errors.push(Error {
                        kind: ErrorKind::SemanticError(
                            "Increment/decrement only works on scalar/vector/matrix".into(),
                        ),
                        meta,
                    });
                    return Ok((Some(left), meta));
                }
            };

            let mut right = self.add_expression(Expression::Literal(literal), meta)?;

            // Vector targets: splat the `1` to the vector size. Matrix targets:
            // splat to the column size, then compose `columns` copies.
            if let Some(size) = rows {
                right = self.add_expression(Expression::Splat { size, value: right }, meta)?;

                if let Some(cols) = columns {
                    let ty = self.module.types.insert(
                        Type {
                            name: None,
                            inner: ty_inner,
                        },
                        meta,
                    );

                    right = self.add_expression(
                        Expression::Compose {
                            ty,
                            components: core::iter::repeat_n(right, cols as usize).collect(),
                        },
                        meta,
                    )?;
                }
            }

            let value = self.add_expression(Expression::Binary { op, left, right }, meta)?;

            self.lower_store(pointer, value, meta)?;

            // Postfix yields the value before the update, prefix the value after.
            if postfix {
                left
            } else {
                value
            }
        }
        HirExprKind::Method {
            expr: object,
            ref name,
            ref args,
        } if ExprPos::Lhs != pos => {
            let args = args
                .iter()
                .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs))
                .collect::<Result<Vec<_>>>()?;
            match name.as_ref() {
                "length" => {
                    if !args.is_empty() {
                        frontend.errors.push(Error {
                            kind: ErrorKind::SemanticError(
                                ".length() doesn't take any arguments".into(),
                            ),
                            meta,
                        });
                    }
                    let lowered_array = self.lower_expect_inner(stmt, frontend, object, pos)?.0;
                    let array_type = self.resolve_type(lowered_array, meta)?;

                    match *array_type {
                        // Fixed-size arrays: the length is a literal,
                        // converted to `i32`.
                        TypeInner::Array {
                            size: crate::ArraySize::Constant(size),
                            ..
                        } => {
                            let mut array_length = self.add_expression(
                                Expression::Literal(Literal::U32(size.get())),
                                meta,
                            )?;
                            self.forced_conversion(&mut array_length, meta, Scalar::I32)?;
                            array_length
                        }
                        // Anything else becomes an `ArrayLength` cast to
                        // `i32`. Invalid operands are left for later type
                        // checking to reject.
                        _ => {
                            let mut array_length = self
                                .add_expression(Expression::ArrayLength(lowered_array), meta)?;
                            self.conversion(&mut array_length, meta, Scalar::I32)?;
                            array_length
                        }
                    }
                }
                _ => {
                    return Err(Error {
                        kind: ErrorKind::SemanticError(
                            format!("unknown method '{name}'").into(),
                        ),
                        meta,
                    });
                }
            }
        }
        // Every remaining combination is a value-only expression in LHS position.
        _ => {
            return Err(Error {
                kind: ErrorKind::SemanticError(
                    format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr])
                        .into(),
                ),
                meta,
            })
        }
    };

    log::trace!("Lowered {expr:?}\n\tKind = {kind:?}\n\tPos = {pos:?}\n\tResult = {handle:?}");

    Ok((Some(handle), meta))
}
1343
1344 pub fn expr_scalar_components(
1345 &mut self,
1346 expr: Handle<Expression>,
1347 meta: Span,
1348 ) -> Result<Option<Scalar>> {
1349 let ty = self.resolve_type(expr, meta)?;
1350 Ok(scalar_components(ty))
1351 }
1352
1353 pub fn expr_power(&mut self, expr: Handle<Expression>, meta: Span) -> Result<Option<u32>> {
1354 Ok(self
1355 .expr_scalar_components(expr, meta)?
1356 .and_then(type_power))
1357 }
1358
1359 pub fn conversion(
1360 &mut self,
1361 expr: &mut Handle<Expression>,
1362 meta: Span,
1363 scalar: Scalar,
1364 ) -> Result<()> {
1365 *expr = self.add_expression(
1366 Expression::As {
1367 expr: *expr,
1368 kind: scalar.kind,
1369 convert: Some(scalar.width),
1370 },
1371 meta,
1372 )?;
1373
1374 Ok(())
1375 }
1376
1377 pub fn implicit_conversion(
1378 &mut self,
1379 expr: &mut Handle<Expression>,
1380 meta: Span,
1381 scalar: Scalar,
1382 ) -> Result<()> {
1383 if let (Some(tgt_power), Some(expr_power)) =
1384 (type_power(scalar), self.expr_power(*expr, meta)?)
1385 {
1386 if tgt_power > expr_power {
1387 self.conversion(expr, meta, scalar)?;
1388 }
1389 }
1390
1391 Ok(())
1392 }
1393
1394 pub fn forced_conversion(
1395 &mut self,
1396 expr: &mut Handle<Expression>,
1397 meta: Span,
1398 scalar: Scalar,
1399 ) -> Result<()> {
1400 if let Some(expr_scalar) = self.expr_scalar_components(*expr, meta)? {
1401 if expr_scalar != scalar {
1402 self.conversion(expr, meta, scalar)?;
1403 }
1404 }
1405
1406 Ok(())
1407 }
1408
1409 pub fn binary_implicit_conversion(
1410 &mut self,
1411 left: &mut Handle<Expression>,
1412 left_meta: Span,
1413 right: &mut Handle<Expression>,
1414 right_meta: Span,
1415 ) -> Result<()> {
1416 let left_components = self.expr_scalar_components(*left, left_meta)?;
1417 let right_components = self.expr_scalar_components(*right, right_meta)?;
1418
1419 if let (Some((left_power, left_scalar)), Some((right_power, right_scalar))) = (
1420 left_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1421 right_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1422 ) {
1423 match left_power.cmp(&right_power) {
1424 core::cmp::Ordering::Less => {
1425 self.conversion(left, left_meta, right_scalar)?;
1426 }
1427 core::cmp::Ordering::Equal => {}
1428 core::cmp::Ordering::Greater => {
1429 self.conversion(right, right_meta, left_scalar)?;
1430 }
1431 }
1432 }
1433
1434 Ok(())
1435 }
1436
1437 pub fn implicit_splat(
1438 &mut self,
1439 expr: &mut Handle<Expression>,
1440 meta: Span,
1441 vector_size: Option<VectorSize>,
1442 ) -> Result<()> {
1443 let expr_type = self.resolve_type(*expr, meta)?;
1444
1445 if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) {
1446 *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta)?
1447 }
1448
1449 Ok(())
1450 }
1451
1452 pub fn vector_resize(
1453 &mut self,
1454 size: VectorSize,
1455 vector: Handle<Expression>,
1456 meta: Span,
1457 ) -> Result<Handle<Expression>> {
1458 self.add_expression(
1459 Expression::Swizzle {
1460 size,
1461 vector,
1462 pattern: crate::SwizzleComponent::XYZW,
1463 },
1464 meta,
1465 )
1466 }
1467}
1468
1469impl Index<Handle<Expression>> for Context<'_> {
1470 type Output = Expression;
1471
1472 fn index(&self, index: Handle<Expression>) -> &Self::Output {
1473 if self.is_const {
1474 &self.module.global_expressions[index]
1475 } else {
1476 &self.expressions[index]
1477 }
1478 }
1479}
1480
1481#[derive(Debug)]
1486pub struct StmtContext {
1487 pub hir_exprs: Arena<HirExpr>,
1490}
1491
1492impl StmtContext {
1493 const fn new() -> Self {
1494 StmtContext {
1495 hir_exprs: Arena::new(),
1496 }
1497 }
1498}