1use alloc::{format, string::String, vec::Vec};
2use core::ops::Index;
3
4use super::{
5 ast::{
6 GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier,
7 VariableReference,
8 },
9 error::{Error, ErrorKind},
10 types::{scalar_components, type_power},
11 Frontend, Result,
12};
13use crate::{
14 front::Typifier, proc::Emitter, proc::Layouter, AddressSpace, Arena, BinaryOperator, Block,
15 Expression, FastHashMap, FunctionArgument, Handle, Literal, LocalVariable, RelationalFunction,
16 Scalar, Span, Statement, Type, TypeInner, VectorSize,
17};
18
/// Position of an expression relative to the statement it appears in, which
/// decides how the expression is lowered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExprPos {
    /// Assignment target: lowering keeps the expression as a reference that can be
    /// stored through, and rejects forms that cannot be assigned.
    Lhs,
    /// The expression is read as a value.
    Rhs,
    /// The expression is the base of an indexing (`Access`/`AccessIndex`) expression.
    AccessBase {
        /// Whether the index applied to this base is a compile-time constant.
        constant_index: bool,
    },
}

impl ExprPos {
    /// Picks the position to use when lowering the base of an access expression.
    ///
    /// `Lhs`, and an `AccessBase` already known to use a non-constant index, are
    /// propagated unchanged. Every other position becomes an `AccessBase` that
    /// records `constant_index`.
    const fn maybe_access_base(&self, constant_index: bool) -> Self {
        let keep_current = matches!(
            *self,
            ExprPos::Lhs
                | ExprPos::AccessBase {
                    constant_index: false
                }
        );

        if keep_current {
            *self
        } else {
            ExprPos::AccessBase { constant_index }
        }
    }
}
46
/// State used while lowering GLSL into one Naga function body or, when
/// `is_const` is set, into module-level (global) expressions.
#[derive(Debug)]
pub(crate) struct Context<'a> {
    /// Expressions of the function being lowered; `add_expression` appends here
    /// when `is_const` is `false`.
    pub expressions: Arena<Expression>,
    /// Local variables of the function being lowered.
    pub locals: Arena<LocalVariable>,

    /// Arguments of the function being lowered. For parameters whose qualifier
    /// satisfies `is_lhs()`, the argument type is wrapped in a function-space
    /// pointer (see `add_function_arg`).
    pub arguments: Vec<FunctionArgument>,

    /// Declared type of each parameter, without the pointer wrapping that
    /// `arguments` may apply.
    pub parameters: Vec<Handle<Type>>,
    /// Qualifier information for each parameter, parallel to `parameters`.
    pub parameters_info: Vec<ParameterInfo>,

    /// Maps identifiers to the variable references they resolve to.
    pub symbol_table: crate::front::SymbolTable<String, VariableReference>,
    /// Expression-to-expression map. NOTE(review): presumably associates a
    /// texture/image expression with its sampler — confirm at the use sites.
    pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>,

    /// Typifier for constant-mode expressions (used by helpers outside this view).
    pub const_typifier: Typifier,
    /// Typifier for function expressions (used by helpers outside this view).
    pub typifier: Typifier,
    /// Layout information handed to the constant evaluator in `add_expression`.
    layouter: Layouter,
    /// Tracks the range of expressions not yet emitted into `body`
    /// (see `emit_start` / `emit_end`).
    emitter: Emitter,
    /// Reusable statement context: taken by `stmt_ctx` and handed back by
    /// `lower` / `lower_expect`.
    stmt_ctx: Option<StmtContext>,
    /// Block statements are currently appended to; temporarily swapped by
    /// `new_body`, `new_body_with_ret` and `with_body`.
    pub body: Block,
    /// Module being built.
    pub module: &'a mut crate::Module,
    /// When `true`, new expressions go to the module's global expression arena
    /// and indexing the context reads from that arena instead of `expressions`.
    pub is_const: bool,
    /// Expression kind tracker for `expressions`, passed to the constant evaluator.
    pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker,
    /// Expression kind tracker for the module's global expressions, passed to
    /// the constant evaluator in const mode.
    pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
}
88
89impl<'a> Context<'a> {
90 pub fn new(
91 frontend: &Frontend,
92 module: &'a mut crate::Module,
93 is_const: bool,
94 global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
95 ) -> Result<Self> {
96 let mut this = Context {
97 expressions: Arena::new(),
98 locals: Arena::new(),
99 arguments: Vec::new(),
100
101 parameters: Vec::new(),
102 parameters_info: Vec::new(),
103
104 symbol_table: crate::front::SymbolTable::default(),
105 samplers: FastHashMap::default(),
106
107 const_typifier: Typifier::new(),
108 typifier: Typifier::new(),
109 layouter: Layouter::default(),
110 emitter: Emitter::default(),
111 stmt_ctx: Some(StmtContext::new()),
112 body: Block::new(),
113 module,
114 is_const: false,
115 local_expression_kind_tracker: crate::proc::ExpressionKindTracker::new(),
116 global_expression_kind_tracker,
117 };
118
119 this.emit_start();
120
121 for &(ref name, lookup) in frontend.global_variables.iter() {
122 this.add_global(name, lookup)?
123 }
124 this.is_const = is_const;
125
126 Ok(this)
127 }
128
129 pub fn new_body<F>(&mut self, cb: F) -> Result<Block>
130 where
131 F: FnOnce(&mut Self) -> Result<()>,
132 {
133 self.new_body_with_ret(cb).map(|(b, _)| b)
134 }
135
136 pub fn new_body_with_ret<F, R>(&mut self, cb: F) -> Result<(Block, R)>
137 where
138 F: FnOnce(&mut Self) -> Result<R>,
139 {
140 self.emit_restart();
141 let old_body = core::mem::replace(&mut self.body, Block::new());
142 let res = cb(self);
143 self.emit_restart();
144 let new_body = core::mem::replace(&mut self.body, old_body);
145 res.map(|r| (new_body, r))
146 }
147
148 pub fn with_body<F>(&mut self, body: Block, cb: F) -> Result<Block>
149 where
150 F: FnOnce(&mut Self) -> Result<()>,
151 {
152 self.emit_restart();
153 let old_body = core::mem::replace(&mut self.body, body);
154 let res = cb(self);
155 self.emit_restart();
156 let body = core::mem::replace(&mut self.body, old_body);
157 res.map(|_| body)
158 }
159
160 pub fn add_global(
161 &mut self,
162 name: &str,
163 GlobalLookup {
164 kind,
165 entry_arg,
166 mutable,
167 }: GlobalLookup,
168 ) -> Result<()> {
169 let (expr, load, constant) = match kind {
170 GlobalLookupKind::Variable(v) => {
171 let span = self.module.global_variables.get_span(v);
172 (
173 self.add_expression(Expression::GlobalVariable(v), span)?,
174 self.module.global_variables[v].space != AddressSpace::Handle,
175 None,
176 )
177 }
178 GlobalLookupKind::BlockSelect(handle, index) => {
179 let span = self.module.global_variables.get_span(handle);
180 let base = self.add_expression(Expression::GlobalVariable(handle), span)?;
181 let expr = self.add_expression(Expression::AccessIndex { base, index }, span)?;
182
183 (
184 expr,
185 {
186 let ty = self.module.global_variables[handle].ty;
187
188 match self.module.types[ty].inner {
189 TypeInner::Struct { ref members, .. } => {
190 if let TypeInner::Array {
191 size: crate::ArraySize::Dynamic,
192 ..
193 } = self.module.types[members[index as usize].ty].inner
194 {
195 false
196 } else {
197 true
198 }
199 }
200 _ => true,
201 }
202 },
203 None,
204 )
205 }
206 GlobalLookupKind::Constant(v, ty) => {
207 let span = self.module.constants.get_span(v);
208 (
209 self.add_expression(Expression::Constant(v), span)?,
210 false,
211 Some((v, ty)),
212 )
213 }
214 GlobalLookupKind::Override(v, _ty) => {
215 let span = self.module.overrides.get_span(v);
216 (
217 self.add_expression(Expression::Override(v), span)?,
218 false,
219 None,
220 )
221 }
222 };
223
224 let var = VariableReference {
225 expr,
226 load,
227 mutable,
228 constant,
229 entry_arg,
230 };
231
232 self.symbol_table.add(name.into(), var);
233
234 Ok(())
235 }
236
237 #[inline]
243 pub fn emit_start(&mut self) {
244 self.emitter.start(&self.expressions)
245 }
246
247 pub fn emit_end(&mut self) {
256 self.body.extend(self.emitter.finish(&self.expressions))
257 }
258
259 pub fn emit_restart(&mut self) {
266 self.emit_end();
267 self.emit_start()
268 }
269
270 pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
271 let mut eval = if self.is_const {
272 crate::proc::ConstantEvaluator::for_glsl_module(
273 self.module,
274 self.global_expression_kind_tracker,
275 &mut self.layouter,
276 )
277 } else {
278 crate::proc::ConstantEvaluator::for_glsl_function(
279 self.module,
280 &mut self.expressions,
281 &mut self.local_expression_kind_tracker,
282 &mut self.layouter,
283 &mut self.emitter,
284 &mut self.body,
285 )
286 };
287
288 eval.try_eval_and_append(expr, meta).map_err(|e| Error {
289 kind: e.into(),
290 meta,
291 })
292 }
293
294 pub fn add_local_var(
299 &mut self,
300 name: String,
301 expr: Handle<Expression>,
302 mutable: bool,
303 ) -> Option<VariableReference> {
304 let var = VariableReference {
305 expr,
306 load: true,
307 mutable,
308 constant: None,
309 entry_arg: None,
310 };
311
312 self.symbol_table.add(name, var)
313 }
314
315 pub fn add_function_arg(
317 &mut self,
318 name_meta: Option<(String, Span)>,
319 ty: Handle<Type>,
320 qualifier: ParameterQualifier,
321 ) -> Result<()> {
322 let index = self.arguments.len();
323 let mut arg = FunctionArgument {
324 name: name_meta.as_ref().map(|&(ref name, _)| name.clone()),
325 ty,
326 binding: None,
327 };
328 self.parameters.push(ty);
329
330 let opaque = match self.module.types[ty].inner {
331 TypeInner::Image { .. } | TypeInner::Sampler { .. } => true,
332 _ => false,
333 };
334
335 if qualifier.is_lhs() {
336 let span = self.module.types.get_span(arg.ty);
337 arg.ty = self.module.types.insert(
338 Type {
339 name: None,
340 inner: TypeInner::Pointer {
341 base: arg.ty,
342 space: AddressSpace::Function,
343 },
344 },
345 span,
346 )
347 }
348
349 self.arguments.push(arg);
350
351 self.parameters_info.push(ParameterInfo {
352 qualifier,
353 depth: false,
354 });
355
356 if let Some((name, meta)) = name_meta {
357 let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta)?;
358 let mutable = qualifier != ParameterQualifier::Const && !opaque;
359 let load = qualifier.is_lhs();
360
361 let var = if mutable && !load {
362 let handle = self.locals.append(
363 LocalVariable {
364 name: Some(name.clone()),
365 ty,
366 init: None,
367 },
368 meta,
369 );
370 let local_expr = self.add_expression(Expression::LocalVariable(handle), meta)?;
371
372 self.emit_restart();
373
374 self.body.push(
375 Statement::Store {
376 pointer: local_expr,
377 value: expr,
378 },
379 meta,
380 );
381
382 VariableReference {
383 expr: local_expr,
384 load: true,
385 mutable,
386 constant: None,
387 entry_arg: None,
388 }
389 } else {
390 VariableReference {
391 expr,
392 load,
393 mutable,
394 constant: None,
395 entry_arg: None,
396 }
397 };
398
399 self.symbol_table.add(name, var);
400 }
401
402 Ok(())
403 }
404
405 #[must_use]
412 pub const fn stmt_ctx(&mut self) -> StmtContext {
413 self.stmt_ctx.take().unwrap()
414 }
415
416 pub fn lower(
421 &mut self,
422 mut stmt: StmtContext,
423 frontend: &mut Frontend,
424 expr: Handle<HirExpr>,
425 pos: ExprPos,
426 ) -> Result<(Option<Handle<Expression>>, Span)> {
427 let res = self.lower_inner(&stmt, frontend, expr, pos);
428
429 stmt.hir_exprs.clear();
430 self.stmt_ctx = Some(stmt);
431
432 res
433 }
434
435 pub fn lower_expect(
441 &mut self,
442 mut stmt: StmtContext,
443 frontend: &mut Frontend,
444 expr: Handle<HirExpr>,
445 pos: ExprPos,
446 ) -> Result<(Handle<Expression>, Span)> {
447 let res = self.lower_expect_inner(&stmt, frontend, expr, pos);
448
449 stmt.hir_exprs.clear();
450 self.stmt_ctx = Some(stmt);
451
452 res
453 }
454
455 pub fn lower_expect_inner(
461 &mut self,
462 stmt: &StmtContext,
463 frontend: &mut Frontend,
464 expr: Handle<HirExpr>,
465 pos: ExprPos,
466 ) -> Result<(Handle<Expression>, Span)> {
467 let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos)?;
468
469 let expr = match maybe_expr {
470 Some(e) => e,
471 None => {
472 return Err(Error {
473 kind: ErrorKind::SemanticError("Expression returns void".into()),
474 meta,
475 })
476 }
477 };
478
479 Ok((expr, meta))
480 }
481
482 fn lower_store(
483 &mut self,
484 pointer: Handle<Expression>,
485 value: Handle<Expression>,
486 meta: Span,
487 ) -> Result<()> {
488 if let Expression::Swizzle {
489 size,
490 mut vector,
491 pattern,
492 } = self.expressions[pointer]
493 {
494 let size = match size {
497 VectorSize::Bi => 2,
498 VectorSize::Tri => 3,
499 VectorSize::Quad => 4,
500 };
501
502 if let Expression::Load { pointer } = self.expressions[vector] {
503 vector = pointer;
504 }
505
506 #[allow(clippy::needless_range_loop)]
507 for index in 0..size {
508 let dst = self.add_expression(
509 Expression::AccessIndex {
510 base: vector,
511 index: pattern[index].index(),
512 },
513 meta,
514 )?;
515 let src = self.add_expression(
516 Expression::AccessIndex {
517 base: value,
518 index: index as u32,
519 },
520 meta,
521 )?;
522
523 self.emit_restart();
524
525 self.body.push(
526 Statement::Store {
527 pointer: dst,
528 value: src,
529 },
530 meta,
531 );
532 }
533 } else {
534 self.emit_restart();
535
536 self.body.push(Statement::Store { pointer, value }, meta);
537 }
538
539 Ok(())
540 }
541
542 #[allow(clippy::large_stack_frames)] fn lower_inner(
545 &mut self,
546 stmt: &StmtContext,
547 frontend: &mut Frontend,
548 expr: Handle<HirExpr>,
549 pos: ExprPos,
550 ) -> Result<(Option<Handle<Expression>>, Span)> {
551 let HirExpr { ref kind, meta } = stmt.hir_exprs[expr];
552
553 log::debug!("Lowering {expr:?} (kind {kind:?}, pos {pos:?})");
554
555 let handle = match *kind {
556 HirExprKind::Access { base, index } => {
557 let (index, _) = self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs)?;
558 let maybe_constant_index = match pos {
559 ExprPos::Lhs => None,
562 _ => self
563 .module
564 .to_ctx()
565 .get_const_val_from(index, &self.expressions)
566 .ok(),
567 };
568
569 let base = self
570 .lower_expect_inner(
571 stmt,
572 frontend,
573 base,
574 pos.maybe_access_base(maybe_constant_index.is_some()),
575 )?
576 .0;
577
578 let pointer = maybe_constant_index
579 .map(|index| self.add_expression(Expression::AccessIndex { base, index }, meta))
580 .unwrap_or_else(|| {
581 self.add_expression(Expression::Access { base, index }, meta)
582 })?;
583
584 if ExprPos::Rhs == pos {
585 let resolved = self.resolve_type(pointer, meta)?;
586 if resolved.pointer_space().is_some() {
587 return Ok((
588 Some(self.add_expression(Expression::Load { pointer }, meta)?),
589 meta,
590 ));
591 }
592 }
593
594 pointer
595 }
596 HirExprKind::Select { base, ref field } => {
597 let base = self.lower_expect_inner(stmt, frontend, base, pos)?.0;
598
599 frontend.field_selection(self, pos, base, field, meta)?
600 }
601 HirExprKind::Literal(literal) if pos != ExprPos::Lhs => {
602 self.add_expression(Expression::Literal(literal), meta)?
603 }
604 HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
605 let (mut left, left_meta) =
606 self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs)?;
607 let (mut right, right_meta) =
608 self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs)?;
609
610 match op {
611 BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
612 self.implicit_conversion(&mut right, right_meta, Scalar::U32)?
613 }
614 _ => self
615 .binary_implicit_conversion(&mut left, left_meta, &mut right, right_meta)?,
616 }
617
618 self.typifier_grow(left, left_meta)?;
619 self.typifier_grow(right, right_meta)?;
620
621 let left_inner = self.get_type(left);
622 let right_inner = self.get_type(right);
623
624 match (left_inner, right_inner) {
625 (
626 &TypeInner::Matrix {
627 columns: left_columns,
628 rows: left_rows,
629 scalar: left_scalar,
630 },
631 &TypeInner::Matrix {
632 columns: right_columns,
633 rows: right_rows,
634 scalar: right_scalar,
635 },
636 ) => {
637 let dimensions_ok = if op == BinaryOperator::Multiply {
638 left_columns == right_rows
639 } else {
640 left_columns == right_columns && left_rows == right_rows
641 };
642
643 if !dimensions_ok || left_scalar != right_scalar {
645 frontend.errors.push(Error {
646 kind: ErrorKind::SemanticError(
647 format!(
648 "Cannot apply operation to {left_inner:?} and {right_inner:?}"
649 )
650 .into(),
651 ),
652 meta,
653 })
654 }
655
656 match op {
657 BinaryOperator::Divide => {
658 let mut components = Vec::with_capacity(left_columns as usize);
661
662 for index in 0..left_columns as u32 {
663 let left_vector = self.add_expression(
665 Expression::AccessIndex { base: left, index },
666 meta,
667 )?;
668 let right_vector = self.add_expression(
669 Expression::AccessIndex { base: right, index },
670 meta,
671 )?;
672
673 let column = self.add_expression(
675 Expression::Binary {
676 op,
677 left: left_vector,
678 right: right_vector,
679 },
680 meta,
681 )?;
682
683 components.push(column)
684 }
685
686 let ty = self.module.types.insert(
687 Type {
688 name: None,
689 inner: TypeInner::Matrix {
690 columns: left_columns,
691 rows: left_rows,
692 scalar: left_scalar,
693 },
694 },
695 Span::default(),
696 );
697
698 self.add_expression(Expression::Compose { ty, components }, meta)?
700 }
701 BinaryOperator::Equal | BinaryOperator::NotEqual => {
702 let equals = op == BinaryOperator::Equal;
708
709 let (op, combine, fun) = match equals {
710 true => (
711 BinaryOperator::Equal,
712 BinaryOperator::LogicalAnd,
713 RelationalFunction::All,
714 ),
715 false => (
716 BinaryOperator::NotEqual,
717 BinaryOperator::LogicalOr,
718 RelationalFunction::Any,
719 ),
720 };
721
722 let mut root = None;
723
724 for index in 0..left_columns as u32 {
725 let left_vector = self.add_expression(
727 Expression::AccessIndex { base: left, index },
728 meta,
729 )?;
730 let right_vector = self.add_expression(
731 Expression::AccessIndex { base: right, index },
732 meta,
733 )?;
734
735 let argument = self.add_expression(
736 Expression::Binary {
737 op,
738 left: left_vector,
739 right: right_vector,
740 },
741 meta,
742 )?;
743
744 let compare = self.add_expression(
748 Expression::Relational { fun, argument },
749 meta,
750 )?;
751
752 root = Some(match root {
754 Some(right) => self.add_expression(
755 Expression::Binary {
756 op: combine,
757 left: compare,
758 right,
759 },
760 meta,
761 )?,
762 None => compare,
763 });
764 }
765
766 root.unwrap()
767 }
768 _ => {
769 self.add_expression(Expression::Binary { left, op, right }, meta)?
770 }
771 }
772 }
773 (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op {
774 BinaryOperator::Equal | BinaryOperator::NotEqual => {
775 let equals = op == BinaryOperator::Equal;
776
777 let (op, fun) = match equals {
778 true => (BinaryOperator::Equal, RelationalFunction::All),
779 false => (BinaryOperator::NotEqual, RelationalFunction::Any),
780 };
781
782 let argument =
783 self.add_expression(Expression::Binary { op, left, right }, meta)?;
784
785 self.add_expression(Expression::Relational { fun, argument }, meta)?
786 }
787 _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
788 },
789 (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op {
790 BinaryOperator::Add
791 | BinaryOperator::Subtract
792 | BinaryOperator::Divide
793 | BinaryOperator::And
794 | BinaryOperator::ExclusiveOr
795 | BinaryOperator::InclusiveOr
796 | BinaryOperator::ShiftLeft
797 | BinaryOperator::ShiftRight => {
798 let scalar_vector = self
799 .add_expression(Expression::Splat { size, value: right }, meta)?;
800
801 self.add_expression(
802 Expression::Binary {
803 op,
804 left,
805 right: scalar_vector,
806 },
807 meta,
808 )?
809 }
810 _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
811 },
812 (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op {
813 BinaryOperator::Add
814 | BinaryOperator::Subtract
815 | BinaryOperator::Divide
816 | BinaryOperator::And
817 | BinaryOperator::ExclusiveOr
818 | BinaryOperator::InclusiveOr => {
819 let scalar_vector =
820 self.add_expression(Expression::Splat { size, value: left }, meta)?;
821
822 self.add_expression(
823 Expression::Binary {
824 op,
825 left: scalar_vector,
826 right,
827 },
828 meta,
829 )?
830 }
831 _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
832 },
833 (
834 &TypeInner::Scalar(left_scalar),
835 &TypeInner::Matrix {
836 rows,
837 columns,
838 scalar: right_scalar,
839 },
840 ) => {
841 if left_scalar != right_scalar {
843 frontend.errors.push(Error {
844 kind: ErrorKind::SemanticError(
845 format!(
846 "Cannot apply operation to {left_inner:?} and {right_inner:?}"
847 )
848 .into(),
849 ),
850 meta,
851 })
852 }
853
854 match op {
855 BinaryOperator::Divide
856 | BinaryOperator::Add
857 | BinaryOperator::Subtract => {
858 let scalar_vector = self.add_expression(
863 Expression::Splat {
864 size: rows,
865 value: left,
866 },
867 meta,
868 )?;
869
870 let mut components = Vec::with_capacity(columns as usize);
871
872 for index in 0..columns as u32 {
873 let matrix_column = self.add_expression(
875 Expression::AccessIndex { base: right, index },
876 meta,
877 )?;
878
879 let column = self.add_expression(
882 Expression::Binary {
883 op,
884 left: scalar_vector,
885 right: matrix_column,
886 },
887 meta,
888 )?;
889
890 components.push(column)
891 }
892
893 let ty = self.module.types.insert(
894 Type {
895 name: None,
896 inner: TypeInner::Matrix {
897 columns,
898 rows,
899 scalar: left_scalar,
900 },
901 },
902 Span::default(),
903 );
904
905 self.add_expression(Expression::Compose { ty, components }, meta)?
907 }
908 _ => {
909 self.add_expression(Expression::Binary { left, op, right }, meta)?
910 }
911 }
912 }
913 (
914 &TypeInner::Matrix {
915 rows,
916 columns,
917 scalar: left_scalar,
918 },
919 &TypeInner::Scalar(right_scalar),
920 ) => {
921 if left_scalar != right_scalar {
923 frontend.errors.push(Error {
924 kind: ErrorKind::SemanticError(
925 format!(
926 "Cannot apply operation to {left_inner:?} and {right_inner:?}"
927 )
928 .into(),
929 ),
930 meta,
931 })
932 }
933
934 match op {
935 BinaryOperator::Divide
936 | BinaryOperator::Add
937 | BinaryOperator::Subtract => {
938 let scalar_vector = self.add_expression(
944 Expression::Splat {
945 size: rows,
946 value: right,
947 },
948 meta,
949 )?;
950
951 let mut components = Vec::with_capacity(columns as usize);
952
953 for index in 0..columns as u32 {
954 let matrix_column = self.add_expression(
956 Expression::AccessIndex { base: left, index },
957 meta,
958 )?;
959
960 let column = self.add_expression(
963 Expression::Binary {
964 op,
965 left: matrix_column,
966 right: scalar_vector,
967 },
968 meta,
969 )?;
970
971 components.push(column)
972 }
973
974 let ty = self.module.types.insert(
975 Type {
976 name: None,
977 inner: TypeInner::Matrix {
978 columns,
979 rows,
980 scalar: left_scalar,
981 },
982 },
983 Span::default(),
984 );
985
986 self.add_expression(Expression::Compose { ty, components }, meta)?
988 }
989 _ => {
990 self.add_expression(Expression::Binary { left, op, right }, meta)?
991 }
992 }
993 }
994 _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
995 }
996 }
997 HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => {
998 let expr = self
999 .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs)?
1000 .0;
1001
1002 if let TypeInner::Matrix { scalar, .. } = *self.resolve_type(expr, meta)? {
1003 let minus_one = Literal::minus_one(scalar).ok_or_else(|| Error {
1006 kind: ErrorKind::SemanticError(
1007 format!("Cannot apply operator {op:?} to type {scalar:?}").into(),
1008 ),
1009 meta,
1010 })?;
1011 let lhs = self.add_expression(Expression::Literal(minus_one), meta)?;
1012 self.add_expression(
1013 Expression::Binary {
1014 op: BinaryOperator::Multiply,
1015 left: lhs,
1016 right: expr,
1017 },
1018 meta,
1019 )?
1020 } else {
1021 self.add_expression(Expression::Unary { op, expr }, meta)?
1022 }
1023 }
1024 HirExprKind::Variable(ref var) => match pos {
1025 ExprPos::Lhs => {
1026 if !var.mutable {
1027 frontend.errors.push(Error {
1028 kind: ErrorKind::SemanticError(
1029 "Variable cannot be used in LHS position".into(),
1030 ),
1031 meta,
1032 })
1033 }
1034
1035 var.expr
1036 }
1037 ExprPos::AccessBase { constant_index } => {
1038 if !constant_index {
1042 if let Some((constant, ty)) = var.constant {
1043 let init = self
1044 .add_expression(Expression::Constant(constant), Span::default())?;
1045 let local = self.locals.append(
1046 LocalVariable {
1047 name: None,
1048 ty,
1049 init: Some(init),
1050 },
1051 Span::default(),
1052 );
1053
1054 self.add_expression(Expression::LocalVariable(local), Span::default())?
1055 } else {
1056 var.expr
1057 }
1058 } else {
1059 var.expr
1060 }
1061 }
1062 _ if var.load => {
1063 self.add_expression(Expression::Load { pointer: var.expr }, meta)?
1064 }
1065 ExprPos::Rhs => {
1066 if let Some((constant, _)) = self.is_const.then_some(var.constant).flatten() {
1067 self.add_expression(Expression::Constant(constant), meta)?
1068 } else {
1069 if self.is_const {
1071 if let Expression::Override(o) = self.expressions[var.expr] {
1072 self.add_expression(Expression::Override(o), meta)?
1074 } else {
1075 var.expr
1076 }
1077 } else {
1078 var.expr
1079 }
1080 }
1081 }
1082 },
1083 HirExprKind::Call(ref call) if pos != ExprPos::Lhs => {
1084 let maybe_expr = frontend.function_or_constructor_call(
1085 self,
1086 stmt,
1087 call.kind.clone(),
1088 &call.args,
1089 meta,
1090 )?;
1091 return Ok((maybe_expr, meta));
1092 }
1093 HirExprKind::Conditional {
1099 condition,
1100 accept,
1101 reject,
1102 } if ExprPos::Lhs != pos => {
1103 let condition = self
1119 .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs)?
1120 .0;
1121
1122 let (mut accept_body, (mut accept, accept_meta)) =
1123 self.new_body_with_ret(|ctx| {
1124 ctx.lower_expect_inner(stmt, frontend, accept, pos)
1126 })?;
1127
1128 let (mut reject_body, (mut reject, reject_meta)) =
1129 self.new_body_with_ret(|ctx| {
1130 ctx.lower_expect_inner(stmt, frontend, reject, pos)
1132 })?;
1133
1134 if let (Some((accept_power, accept_scalar)), Some((reject_power, reject_scalar))) = (
1137 self.expr_scalar_components(accept, accept_meta)?
1139 .and_then(|scalar| Some((type_power(scalar)?, scalar))),
1140 self.expr_scalar_components(reject, reject_meta)?
1141 .and_then(|scalar| Some((type_power(scalar)?, scalar))),
1142 ) {
1143 match accept_power.cmp(&reject_power) {
1144 core::cmp::Ordering::Less => {
1145 accept_body = self.with_body(accept_body, |ctx| {
1146 ctx.conversion(&mut accept, accept_meta, reject_scalar)?;
1147 Ok(())
1148 })?;
1149 }
1150 core::cmp::Ordering::Equal => {}
1151 core::cmp::Ordering::Greater => {
1152 reject_body = self.with_body(reject_body, |ctx| {
1153 ctx.conversion(&mut reject, reject_meta, accept_scalar)?;
1154 Ok(())
1155 })?;
1156 }
1157 }
1158 }
1159
1160 let ty = self.resolve_type_handle(accept, accept_meta)?;
1164
1165 let local = self.locals.append(
1167 LocalVariable {
1168 name: None,
1169 ty,
1170 init: None,
1171 },
1172 meta,
1173 );
1174
1175 let local_expr = self.add_expression(Expression::LocalVariable(local), meta)?;
1176
1177 accept_body.push(
1179 Statement::Store {
1180 pointer: local_expr,
1181 value: accept,
1182 },
1183 accept_meta,
1184 );
1185 reject_body.push(
1186 Statement::Store {
1187 pointer: local_expr,
1188 value: reject,
1189 },
1190 reject_meta,
1191 );
1192
1193 self.body.push(
1196 Statement::If {
1197 condition,
1198 accept: accept_body,
1199 reject: reject_body,
1200 },
1201 meta,
1202 );
1203
1204 self.add_expression(
1207 Expression::Load {
1208 pointer: local_expr,
1209 },
1210 meta,
1211 )?
1212 }
1213 HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => {
1214 let (pointer, ptr_meta) =
1215 self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs)?;
1216 let (mut value, value_meta) =
1217 self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs)?;
1218
1219 let ty = match *self.resolve_type(pointer, ptr_meta)? {
1220 TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
1221 ref ty => ty,
1222 };
1223
1224 if let Some(scalar) = scalar_components(ty) {
1225 self.implicit_conversion(&mut value, value_meta, scalar)?;
1226 }
1227
1228 self.lower_store(pointer, value, meta)?;
1229
1230 value
1231 }
1232 HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => {
1233 let (pointer, _) = self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs)?;
1234 let left = if let Expression::Swizzle { .. } = self.expressions[pointer] {
1235 pointer
1236 } else {
1237 self.add_expression(Expression::Load { pointer }, meta)?
1238 };
1239
1240 let res = match *self.resolve_type(left, meta)? {
1241 TypeInner::Scalar(scalar) => {
1242 let ty = TypeInner::Scalar(scalar);
1243 Literal::one(scalar).map(|i| (ty, i, None, None))
1244 }
1245 TypeInner::Vector { size, scalar } => {
1246 let ty = TypeInner::Vector { size, scalar };
1247 Literal::one(scalar).map(|i| (ty, i, Some(size), None))
1248 }
1249 TypeInner::Matrix {
1250 columns,
1251 rows,
1252 scalar,
1253 } => {
1254 let ty = TypeInner::Matrix {
1255 columns,
1256 rows,
1257 scalar,
1258 };
1259 Literal::one(scalar).map(|i| (ty, i, Some(rows), Some(columns)))
1260 }
1261 _ => None,
1262 };
1263 let (ty_inner, literal, rows, columns) = match res {
1264 Some(res) => res,
1265 None => {
1266 frontend.errors.push(Error {
1267 kind: ErrorKind::SemanticError(
1268 "Increment/decrement only works on scalar/vector/matrix".into(),
1269 ),
1270 meta,
1271 });
1272 return Ok((Some(left), meta));
1273 }
1274 };
1275
1276 let mut right = self.add_expression(Expression::Literal(literal), meta)?;
1277
1278 if let Some(size) = rows {
1283 right = self.add_expression(Expression::Splat { size, value: right }, meta)?;
1284
1285 if let Some(cols) = columns {
1286 let ty = self.module.types.insert(
1287 Type {
1288 name: None,
1289 inner: ty_inner,
1290 },
1291 meta,
1292 );
1293
1294 right = self.add_expression(
1295 Expression::Compose {
1296 ty,
1297 components: core::iter::repeat_n(right, cols as usize).collect(),
1298 },
1299 meta,
1300 )?;
1301 }
1302 }
1303
1304 let value = self.add_expression(Expression::Binary { op, left, right }, meta)?;
1305
1306 self.lower_store(pointer, value, meta)?;
1307
1308 if postfix {
1309 left
1310 } else {
1311 value
1312 }
1313 }
1314 HirExprKind::Method {
1315 expr: object,
1316 ref name,
1317 ref args,
1318 } if ExprPos::Lhs != pos => {
1319 let args = args
1320 .iter()
1321 .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs))
1322 .collect::<Result<Vec<_>>>()?;
1323 match name.as_ref() {
1324 "length" => {
1325 if !args.is_empty() {
1326 frontend.errors.push(Error {
1327 kind: ErrorKind::SemanticError(
1328 ".length() doesn't take any arguments".into(),
1329 ),
1330 meta,
1331 });
1332 }
1333 let lowered_array = self.lower_expect_inner(stmt, frontend, object, pos)?.0;
1334 let array_type = self.resolve_type(lowered_array, meta)?;
1335
1336 match *array_type {
1337 TypeInner::Array {
1338 size: crate::ArraySize::Constant(size),
1339 ..
1340 } => {
1341 let mut array_length = self.add_expression(
1342 Expression::Literal(Literal::U32(size.get())),
1343 meta,
1344 )?;
1345 self.forced_conversion(&mut array_length, meta, Scalar::I32)?;
1346 array_length
1347 }
1348 _ => {
1350 let mut array_length = self
1351 .add_expression(Expression::ArrayLength(lowered_array), meta)?;
1352 self.conversion(&mut array_length, meta, Scalar::I32)?;
1353 array_length
1354 }
1355 }
1356 }
1357
1358 _ => {
1359 return Err(Error {
1360 kind: ErrorKind::SemanticError(
1361 format!("unknown method '{name}'").into(),
1362 ),
1363 meta,
1364 });
1365 }
1366 }
1367 }
1368 HirExprKind::Sequence { ref exprs } if pos != ExprPos::Lhs => {
1369 let mut last_handle = None;
1370 for expr in exprs.iter() {
1371 let (handle, _) =
1372 self.lower_expect_inner(stmt, frontend, *expr, ExprPos::Rhs)?;
1373 last_handle = Some(handle);
1374 }
1375 match last_handle {
1376 Some(handle) => handle,
1377 None => unreachable!(),
1378 }
1379 }
1380 _ => {
1381 return Err(Error {
1382 kind: ErrorKind::SemanticError(
1383 format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr])
1384 .into(),
1385 ),
1386 meta,
1387 })
1388 }
1389 };
1390
1391 log::trace!("Lowered {expr:?}\n\tKind = {kind:?}\n\tPos = {pos:?}\n\tResult = {handle:?}");
1392
1393 Ok((Some(handle), meta))
1394 }
1395
1396 pub fn expr_scalar_components(
1397 &mut self,
1398 expr: Handle<Expression>,
1399 meta: Span,
1400 ) -> Result<Option<Scalar>> {
1401 let ty = self.resolve_type(expr, meta)?;
1402 Ok(scalar_components(ty))
1403 }
1404
1405 pub fn expr_power(&mut self, expr: Handle<Expression>, meta: Span) -> Result<Option<u32>> {
1406 Ok(self
1407 .expr_scalar_components(expr, meta)?
1408 .and_then(type_power))
1409 }
1410
1411 pub fn conversion(
1412 &mut self,
1413 expr: &mut Handle<Expression>,
1414 meta: Span,
1415 scalar: Scalar,
1416 ) -> Result<()> {
1417 *expr = self.add_expression(
1418 Expression::As {
1419 expr: *expr,
1420 kind: scalar.kind,
1421 convert: Some(scalar.width),
1422 },
1423 meta,
1424 )?;
1425
1426 Ok(())
1427 }
1428
1429 pub fn implicit_conversion(
1430 &mut self,
1431 expr: &mut Handle<Expression>,
1432 meta: Span,
1433 scalar: Scalar,
1434 ) -> Result<()> {
1435 if let (Some(tgt_power), Some(expr_power)) =
1436 (type_power(scalar), self.expr_power(*expr, meta)?)
1437 {
1438 if tgt_power > expr_power {
1439 self.conversion(expr, meta, scalar)?;
1440 }
1441 }
1442
1443 Ok(())
1444 }
1445
1446 pub fn forced_conversion(
1447 &mut self,
1448 expr: &mut Handle<Expression>,
1449 meta: Span,
1450 scalar: Scalar,
1451 ) -> Result<()> {
1452 if let Some(expr_scalar) = self.expr_scalar_components(*expr, meta)? {
1453 if expr_scalar != scalar {
1454 self.conversion(expr, meta, scalar)?;
1455 }
1456 }
1457
1458 Ok(())
1459 }
1460
1461 pub fn binary_implicit_conversion(
1462 &mut self,
1463 left: &mut Handle<Expression>,
1464 left_meta: Span,
1465 right: &mut Handle<Expression>,
1466 right_meta: Span,
1467 ) -> Result<()> {
1468 let left_components = self.expr_scalar_components(*left, left_meta)?;
1469 let right_components = self.expr_scalar_components(*right, right_meta)?;
1470
1471 if let (Some((left_power, left_scalar)), Some((right_power, right_scalar))) = (
1472 left_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1473 right_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
1474 ) {
1475 match left_power.cmp(&right_power) {
1476 core::cmp::Ordering::Less => {
1477 self.conversion(left, left_meta, right_scalar)?;
1478 }
1479 core::cmp::Ordering::Equal => {}
1480 core::cmp::Ordering::Greater => {
1481 self.conversion(right, right_meta, left_scalar)?;
1482 }
1483 }
1484 }
1485
1486 Ok(())
1487 }
1488
1489 pub fn implicit_splat(
1490 &mut self,
1491 expr: &mut Handle<Expression>,
1492 meta: Span,
1493 vector_size: Option<VectorSize>,
1494 ) -> Result<()> {
1495 let expr_type = self.resolve_type(*expr, meta)?;
1496
1497 if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) {
1498 *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta)?
1499 }
1500
1501 Ok(())
1502 }
1503
1504 pub fn vector_resize(
1505 &mut self,
1506 size: VectorSize,
1507 vector: Handle<Expression>,
1508 meta: Span,
1509 ) -> Result<Handle<Expression>> {
1510 self.add_expression(
1511 Expression::Swizzle {
1512 size,
1513 vector,
1514 pattern: crate::SwizzleComponent::XYZW,
1515 },
1516 meta,
1517 )
1518 }
1519}
1520
1521impl Index<Handle<Expression>> for Context<'_> {
1522 type Output = Expression;
1523
1524 fn index(&self, index: Handle<Expression>) -> &Self::Output {
1525 if self.is_const {
1526 &self.module.global_expressions[index]
1527 } else {
1528 &self.expressions[index]
1529 }
1530 }
1531}
1532
1533#[derive(Debug)]
1538pub struct StmtContext {
1539 pub hir_exprs: Arena<HirExpr>,
1542}
1543
1544impl StmtContext {
1545 const fn new() -> Self {
1546 StmtContext {
1547 hir_exprs: Arena::new(),
1548 }
1549 }
1550}