1use alloc::string::String;
2
3use pp_rs::{
4 pp::Preprocessor,
5 token::{PreprocessorError, Punct, TokenValue as PPTokenValue},
6};
7
8use super::{
9 ast::Precision,
10 token::{Directive, DirectiveKind, Token, TokenValue},
11 types::parse_type,
12};
13use crate::{FastHashMap, Span, StorageAccess};
14
/// One item produced by the [`Lexer`]: what was lexed plus the source span it
/// was lexed from.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct LexerResult {
    /// The token, preprocessor directive, or preprocessor error produced.
    pub kind: LexerResultKind,
    /// Source span covering the lexed item.
    pub meta: Span,
}
21
/// The three kinds of items the lexer can yield.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub enum LexerResultKind {
    /// A regular language token.
    Token(Token),
    /// A preprocessor directive (`#version`, `#pragma`, or `#extension`).
    Directive(Directive),
    /// An error reported by the preprocessor.
    Error(PreprocessorError),
}
29
/// GLSL lexer: a thin adapter over the `pp_rs` preprocessor that converts its
/// token stream into this frontend's [`Token`] values.
pub struct Lexer<'a> {
    // Underlying preprocessor driving the raw token stream.
    pp: Preprocessor<'a>,
}
33
34impl<'a> Lexer<'a> {
35 pub fn new(input: &'a str, defines: &'a FastHashMap<String, String>) -> Self {
36 let mut pp = Preprocessor::new(input);
37 for (define, value) in defines {
38 pp.add_define(define, value).unwrap(); }
40 Lexer { pp }
41 }
42}
43
impl Iterator for Lexer<'_> {
    type Item = LexerResult;

    /// Pulls the next token from the preprocessor and translates it.
    ///
    /// Preprocessor errors and directives (`#extension`, `#pragma`,
    /// `#version`) are returned early as their own [`LexerResultKind`]
    /// variants; everything else becomes a [`Token`].
    fn next(&mut self) -> Option<Self::Item> {
        // `None` from the preprocessor means end of input.
        let pp_token = match self.pp.next()? {
            Ok(t) => t,
            // Surface preprocessor errors as lexer results rather than
            // aborting the stream, keeping the error's source location.
            Err((err, loc)) => {
                return Some(LexerResult {
                    kind: LexerResultKind::Error(err),
                    meta: loc.into(),
                });
            }
        };

        let meta = pp_token.location.into();
        let value = match pp_token.value {
            // `#extension` directive: handed to the parser verbatim.
            PPTokenValue::Extension(extension) => {
                return Some(LexerResult {
                    kind: LexerResultKind::Directive(Directive {
                        kind: DirectiveKind::Extension,
                        tokens: extension.tokens,
                    }),
                    meta,
                })
            }
            PPTokenValue::Float(float) => TokenValue::FloatConstant(float),
            // Identifiers: first check the GLSL keyword table, then known
            // type names, and only then fall back to a plain identifier.
            PPTokenValue::Ident(ident) => {
                match ident.as_str() {
                    "layout" => TokenValue::Layout,
                    "in" => TokenValue::In,
                    "out" => TokenValue::Out,
                    "uniform" => TokenValue::Uniform,
                    "buffer" => TokenValue::Buffer,
                    "shared" => TokenValue::Shared,
                    "invariant" => TokenValue::Invariant,
                    // GLSL interpolation qualifiers mapped onto the IR's
                    // interpolation enum.
                    "flat" => TokenValue::Interpolation(crate::Interpolation::Flat),
                    "noperspective" => TokenValue::Interpolation(crate::Interpolation::Linear),
                    "smooth" => TokenValue::Interpolation(crate::Interpolation::Perspective),
                    "centroid" => TokenValue::Sampling(crate::Sampling::Centroid),
                    "sample" => TokenValue::Sampling(crate::Sampling::Sample),
                    "const" => TokenValue::Const,
                    "inout" => TokenValue::InOut,
                    "precision" => TokenValue::Precision,
                    "highp" => TokenValue::PrecisionQualifier(Precision::High),
                    "mediump" => TokenValue::PrecisionQualifier(Precision::Medium),
                    "lowp" => TokenValue::PrecisionQualifier(Precision::Low),
                    "restrict" => TokenValue::Restrict,
                    // `readonly`/`writeonly` restrict storage access.
                    "readonly" => TokenValue::MemoryQualifier(StorageAccess::LOAD),
                    "writeonly" => TokenValue::MemoryQualifier(StorageAccess::STORE),
                    "true" => TokenValue::BoolConstant(true),
                    "false" => TokenValue::BoolConstant(false),
                    "continue" => TokenValue::Continue,
                    "break" => TokenValue::Break,
                    "return" => TokenValue::Return,
                    "discard" => TokenValue::Discard,
                    "if" => TokenValue::If,
                    "else" => TokenValue::Else,
                    "switch" => TokenValue::Switch,
                    "case" => TokenValue::Case,
                    "default" => TokenValue::Default,
                    "while" => TokenValue::While,
                    "do" => TokenValue::Do,
                    "for" => TokenValue::For,
                    "void" => TokenValue::Void,
                    "struct" => TokenValue::Struct,
                    // Not a keyword: try a type name, else a plain identifier.
                    word => match parse_type(word) {
                        Some(t) => TokenValue::TypeName(t),
                        None => TokenValue::Identifier(String::from(word)),
                    },
                }
            }
            PPTokenValue::Integer(integer) => TokenValue::IntConstant(integer),
            // One-to-one mapping from preprocessor punctuation to lexer
            // token values.
            PPTokenValue::Punct(punct) => match punct {
                // Compound assignment operators.
                Punct::AddAssign => TokenValue::AddAssign,
                Punct::SubAssign => TokenValue::SubAssign,
                Punct::MulAssign => TokenValue::MulAssign,
                Punct::DivAssign => TokenValue::DivAssign,
                Punct::ModAssign => TokenValue::ModAssign,
                Punct::LeftShiftAssign => TokenValue::LeftShiftAssign,
                Punct::RightShiftAssign => TokenValue::RightShiftAssign,
                Punct::AndAssign => TokenValue::AndAssign,
                Punct::XorAssign => TokenValue::XorAssign,
                Punct::OrAssign => TokenValue::OrAssign,

                // Two-character operators.
                Punct::Increment => TokenValue::Increment,
                Punct::Decrement => TokenValue::Decrement,
                Punct::LogicalAnd => TokenValue::LogicalAnd,
                Punct::LogicalOr => TokenValue::LogicalOr,
                Punct::LogicalXor => TokenValue::LogicalXor,
                Punct::LessEqual => TokenValue::LessEqual,
                Punct::GreaterEqual => TokenValue::GreaterEqual,
                Punct::EqualEqual => TokenValue::Equal,
                Punct::NotEqual => TokenValue::NotEqual,
                Punct::LeftShift => TokenValue::LeftShift,
                Punct::RightShift => TokenValue::RightShift,

                // Brackets.
                Punct::LeftBrace => TokenValue::LeftBrace,
                Punct::RightBrace => TokenValue::RightBrace,
                Punct::LeftParen => TokenValue::LeftParen,
                Punct::RightParen => TokenValue::RightParen,
                Punct::LeftBracket => TokenValue::LeftBracket,
                Punct::RightBracket => TokenValue::RightBracket,

                // Single-character operators and separators.
                Punct::LeftAngle => TokenValue::LeftAngle,
                Punct::RightAngle => TokenValue::RightAngle,
                Punct::Semicolon => TokenValue::Semicolon,
                Punct::Comma => TokenValue::Comma,
                Punct::Colon => TokenValue::Colon,
                Punct::Dot => TokenValue::Dot,
                Punct::Equal => TokenValue::Assign,
                Punct::Bang => TokenValue::Bang,
                Punct::Minus => TokenValue::Dash,
                Punct::Tilde => TokenValue::Tilde,
                Punct::Plus => TokenValue::Plus,
                Punct::Star => TokenValue::Star,
                Punct::Slash => TokenValue::Slash,
                Punct::Percent => TokenValue::Percent,
                Punct::Pipe => TokenValue::VerticalBar,
                Punct::Caret => TokenValue::Caret,
                Punct::Ampersand => TokenValue::Ampersand,
                Punct::Question => TokenValue::Question,
            },
            // `#pragma` directive: handed to the parser verbatim.
            PPTokenValue::Pragma(pragma) => {
                return Some(LexerResult {
                    kind: LexerResultKind::Directive(Directive {
                        kind: DirectiveKind::Pragma,
                        tokens: pragma.tokens,
                    }),
                    meta,
                })
            }
            // `#version` directive; records whether it was the first
            // directive in the file, as GLSL requires.
            PPTokenValue::Version(version) => {
                return Some(LexerResult {
                    kind: LexerResultKind::Directive(Directive {
                        kind: DirectiveKind::Version {
                            is_first_directive: version.is_first_directive,
                        },
                        tokens: version.tokens,
                    }),
                    meta,
                })
            }
        };

        Some(LexerResult {
            kind: LexerResultKind::Token(Token { value, meta }),
            meta,
        })
    }
}
203
#[cfg(test)]
mod tests {
    use alloc::vec;

    use pp_rs::token::{Integer, Location, Token as PPToken, TokenValue as PPTokenValue};

    use super::{
        super::token::{Directive, DirectiveKind, Token, TokenValue},
        Lexer, LexerResult, LexerResultKind,
    };
    use crate::Span;

    /// Lexes a minimal shader (`#version 450` + an empty `main`) and checks
    /// every produced result — kind, value, and source span — in order.
    #[test]
    fn lex_tokens() {
        let defines = crate::FastHashMap::default();

        let mut lex = Lexer::new("#version 450\nvoid main () {}", &defines);
        // Location of the literal `450` inside the `#version` directive
        // (byte offsets into the input string).
        let mut location = Location::default();
        location.start = 9;
        location.end = 12;
        // First result: the `#version` directive carrying the `450` token.
        assert_eq!(
            lex.next().unwrap(),
            LexerResult {
                kind: LexerResultKind::Directive(Directive {
                    kind: DirectiveKind::Version {
                        is_first_directive: true
                    },
                    tokens: vec![PPToken {
                        value: PPTokenValue::Integer(Integer {
                            signed: true,
                            value: 450,
                            width: 32
                        }),
                        location
                    }]
                }),
                meta: Span::new(1, 8)
            }
        );
        // `void` keyword.
        assert_eq!(
            lex.next().unwrap(),
            LexerResult {
                kind: LexerResultKind::Token(Token {
                    value: TokenValue::Void,
                    meta: Span::new(13, 17)
                }),
                meta: Span::new(13, 17)
            }
        );
        // `main` — not a keyword or type name, so a plain identifier.
        assert_eq!(
            lex.next().unwrap(),
            LexerResult {
                kind: LexerResultKind::Token(Token {
                    value: TokenValue::Identifier("main".into()),
                    meta: Span::new(18, 22)
                }),
                meta: Span::new(18, 22)
            }
        );
        // `(`
        assert_eq!(
            lex.next().unwrap(),
            LexerResult {
                kind: LexerResultKind::Token(Token {
                    value: TokenValue::LeftParen,
                    meta: Span::new(23, 24)
                }),
                meta: Span::new(23, 24)
            }
        );
        // `)`
        assert_eq!(
            lex.next().unwrap(),
            LexerResult {
                kind: LexerResultKind::Token(Token {
                    value: TokenValue::RightParen,
                    meta: Span::new(24, 25)
                }),
                meta: Span::new(24, 25)
            }
        );
        // `{`
        assert_eq!(
            lex.next().unwrap(),
            LexerResult {
                kind: LexerResultKind::Token(Token {
                    value: TokenValue::LeftBrace,
                    meta: Span::new(26, 27)
                }),
                meta: Span::new(26, 27)
            }
        );
        // `}`
        assert_eq!(
            lex.next().unwrap(),
            LexerResult {
                kind: LexerResultKind::Token(Token {
                    value: TokenValue::RightBrace,
                    meta: Span::new(27, 28)
                }),
                meta: Span::new(27, 28)
            }
        );
        // End of input: the iterator must be exhausted.
        assert_eq!(lex.next(), None);
    }
}