1use alloc::{boxed::Box, vec::Vec};
2
3use crate::{
4 front::wgsl::{
5 error::Error,
6 lower::{ExpressionContext, Lowerer, Result},
7 parse::{ast, conv},
8 },
9 ir, Handle, Span,
10};
11
12pub struct TemplateListIter<'iter, 'source> {
19 ident_span: Span,
20 template_list: core::slice::Iter<'iter, Handle<ast::Expression<'source>>>,
21}
22
23impl<'iter, 'source> TemplateListIter<'iter, 'source> {
24 pub fn new(ident_span: Span, template_list: &'iter [Handle<ast::Expression<'source>>]) -> Self {
25 Self {
26 ident_span,
27 template_list: template_list.iter(),
28 }
29 }
30
31 pub fn finish(self, ctx: &ExpressionContext<'source, '_, '_>) -> Result<'source, ()> {
32 let unused_args: Vec<Span> = self
33 .template_list
34 .map(|expr| ctx.ast_expressions.get_span(*expr))
35 .collect();
36 if unused_args.is_empty() {
37 Ok(())
38 } else {
39 Err(Box::new(Error::UnusedArgsForTemplate(unused_args)))
40 }
41 }
42
43 fn expect_next(
44 &mut self,
45 description: &'static str,
46 ) -> Result<'source, Handle<ast::Expression<'source>>> {
47 if let Some(expr) = self.template_list.next() {
48 Ok(*expr)
49 } else {
50 Err(Box::new(Error::MissingTemplateArg {
51 span: self.ident_span,
52 description,
53 }))
54 }
55 }
56
57 pub fn ty(
58 &mut self,
59 lowerer: &mut Lowerer<'source, '_>,
60 ctx: &mut ExpressionContext<'source, '_, '_>,
61 ) -> Result<'source, Handle<ir::Type>> {
62 let expr = self.expect_next("`T`, a type")?;
63 lowerer.type_expression(expr, ctx)
64 }
65
66 pub fn ty_with_span(
72 &mut self,
73 lowerer: &mut Lowerer<'source, '_>,
74 ctx: &mut ExpressionContext<'source, '_, '_>,
75 ) -> Result<'source, (Handle<ir::Type>, Span)> {
76 let expr = self.expect_next("`T`, a type")?;
77 let span = ctx.ast_expressions.get_span(expr);
78 let ty = lowerer.type_expression(expr, ctx)?;
79 Ok((ty, span))
80 }
81
82 pub fn scalar_ty(
83 &mut self,
84 lowerer: &mut Lowerer<'source, '_>,
85 ctx: &mut ExpressionContext<'source, '_, '_>,
86 ) -> Result<'source, (ir::Scalar, Span)> {
87 let expr = self.expect_next("`T`, a scalar type")?;
88 let ty = lowerer.type_expression(expr, ctx)?;
89 let span = ctx.ast_expressions.get_span(expr);
90 match ctx.module.types[ty].inner {
91 ir::TypeInner::Scalar(scalar) => Ok((scalar, span)),
92 _ => Err(Box::new(Error::UnknownScalarType(span))),
93 }
94 }
95
96 pub fn maybe_array_size(
97 &mut self,
98 lowerer: &mut Lowerer<'source, '_>,
99 ctx: &mut ExpressionContext<'source, '_, '_>,
100 ) -> Result<'source, ir::ArraySize> {
101 if let Some(expr) = self.template_list.next() {
102 lowerer.array_size(*expr, ctx)
103 } else {
104 Ok(ir::ArraySize::Dynamic)
105 }
106 }
107
108 pub fn address_space(
109 &mut self,
110 ctx: &ExpressionContext<'source, '_, '_>,
111 ) -> Result<'source, ir::AddressSpace> {
112 let expr = self.expect_next("`AS`, an address space")?;
113 let (enumerant, span) = ctx.enumerant(expr)?;
114 conv::map_address_space(enumerant, span, &ctx.enable_extensions)
115 }
116 pub fn maybe_address_space(
117 &mut self,
118 ctx: &ExpressionContext<'source, '_, '_>,
119 ) -> Result<'source, Option<ir::AddressSpace>> {
120 let Some(expr) = self.template_list.next() else {
121 return Ok(None);
122 };
123
124 let (enumerant, span) = ctx.enumerant(*expr)?;
125 Ok(Some(conv::map_address_space(
126 enumerant,
127 span,
128 &ctx.enable_extensions,
129 )?))
130 }
131
132 pub fn access_mode(
133 &mut self,
134 ctx: &ExpressionContext<'source, '_, '_>,
135 ) -> Result<'source, ir::StorageAccess> {
136 let expr = self.expect_next("`Access`, an access mode")?;
137 let (enumerant, span) = ctx.enumerant(expr)?;
138 conv::map_access_mode(enumerant, span)
139 }
140
141 pub fn maybe_access_mode(
152 &mut self,
153 space: &mut ir::AddressSpace,
154 ctx: &ExpressionContext<'source, '_, '_>,
155 ) -> Result<'source, ()> {
156 if let &mut ir::AddressSpace::Storage { ref mut access } = space {
157 if let Some(expr) = self.template_list.next() {
158 let (enumerant, span) = ctx.enumerant(*expr)?;
159 let access_mode = conv::map_access_mode(enumerant, span)?;
160 *access = access_mode;
161 } else {
162 *access = ir::StorageAccess::LOAD
164 }
165 }
166 Ok(())
167 }
168
169 pub fn storage_format(
170 &mut self,
171 ctx: &ExpressionContext<'source, '_, '_>,
172 ) -> Result<'source, ir::StorageFormat> {
173 let expr = self.expect_next("`Format`, a texel format")?;
174 let (enumerant, span) = ctx.enumerant(expr)?;
175 conv::map_storage_format(enumerant, span)
176 }
177
178 pub fn maybe_vertex_return(
179 &mut self,
180 ctx: &ExpressionContext<'source, '_, '_>,
181 ) -> Result<'source, bool> {
182 let Some(expr) = self.template_list.next() else {
183 return Ok(false);
184 };
185
186 let (enumerant, span) = ctx.enumerant(*expr)?;
187 conv::map_ray_flag(&ctx.enable_extensions, enumerant, span)?;
188 Ok(true)
189 }
190
191 pub fn cooperative_role(
192 &mut self,
193 ctx: &ExpressionContext<'source, '_, '_>,
194 ) -> Result<'source, crate::CooperativeRole> {
195 let role_expr = self.expect_next("`Role`, a cooperative matrix role")?;
196 let (enumerant, span) = ctx.enumerant(role_expr)?;
197 let role = conv::map_cooperative_role(enumerant, span)?;
198 Ok(role)
199 }
200}