naga/front/glsl/variables.rs

1use alloc::{format, string::String, vec::Vec};
2
3use super::{
4    ast::*,
5    context::{Context, ExprPos},
6    error::{Error, ErrorKind},
7    Frontend, Result, Span,
8};
9use crate::{
10    AddressSpace, Binding, BuiltIn, Constant, Expression, GlobalVariable, Handle, Interpolation,
11    LocalVariable, Override, ResourceBinding, Scalar, ScalarKind, ShaderStage, SwizzleComponent,
12    Type, TypeInner, VectorSize,
13};
14
15pub struct VarDeclaration<'a, 'key> {
16    pub qualifiers: &'a mut TypeQualifiers<'key>,
17    pub ty: Handle<Type>,
18    pub name: Option<String>,
19    pub init: Option<Handle<Expression>>,
20    pub meta: Span,
21}
22
23/// Information about a builtin used in [`add_builtin`](Frontend::add_builtin).
24struct BuiltInData {
25    /// The type of the builtin.
26    inner: TypeInner,
27    /// The associated builtin class.
28    builtin: BuiltIn,
29    /// Whether the builtin can be written to or not.
30    mutable: bool,
31    /// The storage used for the builtin.
32    storage: StorageQualifier,
33}
34
35pub enum GlobalOrConstant {
36    Global(Handle<GlobalVariable>),
37    Constant(Handle<Constant>),
38    Override(Handle<Override>),
39}
40
41impl Frontend {
42    /// Adds a builtin and returns a variable reference to it
43    fn add_builtin(
44        &mut self,
45        ctx: &mut Context,
46        name: &str,
47        data: BuiltInData,
48        meta: Span,
49    ) -> Result<Option<VariableReference>> {
50        let ty = ctx.module.types.insert(
51            Type {
52                name: None,
53                inner: data.inner,
54            },
55            meta,
56        );
57
58        let handle = ctx.module.global_variables.append(
59            GlobalVariable {
60                name: Some(name.into()),
61                space: AddressSpace::Private,
62                binding: None,
63                ty,
64                init: None,
65                memory_decorations: crate::MemoryDecorations::empty(),
66            },
67            meta,
68        );
69
70        let idx = self.entry_args.len();
71        self.entry_args.push(EntryArg {
72            name: Some(name.into()),
73            binding: Binding::BuiltIn(data.builtin),
74            handle,
75            storage: data.storage,
76        });
77
78        self.global_variables.push((
79            name.into(),
80            GlobalLookup {
81                kind: GlobalLookupKind::Variable(handle),
82                entry_arg: Some(idx),
83                mutable: data.mutable,
84            },
85        ));
86
87        let expr = ctx.add_expression(Expression::GlobalVariable(handle), meta)?;
88
89        let var = VariableReference {
90            expr,
91            load: true,
92            mutable: data.mutable,
93            constant: None,
94            entry_arg: Some(idx),
95        };
96
97        ctx.symbol_table.add_root(name.into(), var.clone());
98
99        Ok(Some(var))
100    }
101
102    pub(crate) fn lookup_variable(
103        &mut self,
104        ctx: &mut Context,
105        name: &str,
106        meta: Span,
107    ) -> Result<Option<VariableReference>> {
108        if let Some(var) = ctx.symbol_table.lookup(name).cloned() {
109            return Ok(Some(var));
110        }
111
112        let data = match name {
113            "gl_Position" => BuiltInData {
114                inner: TypeInner::Vector {
115                    size: VectorSize::Quad,
116                    scalar: Scalar::F32,
117                },
118                builtin: BuiltIn::Position { invariant: false },
119                mutable: true,
120                storage: StorageQualifier::Output,
121            },
122            "gl_FragCoord" => BuiltInData {
123                inner: TypeInner::Vector {
124                    size: VectorSize::Quad,
125                    scalar: Scalar::F32,
126                },
127                builtin: BuiltIn::Position { invariant: false },
128                mutable: false,
129                storage: StorageQualifier::Input,
130            },
131            "gl_PointCoord" => BuiltInData {
132                inner: TypeInner::Vector {
133                    size: VectorSize::Bi,
134                    scalar: Scalar::F32,
135                },
136                builtin: BuiltIn::PointCoord,
137                mutable: false,
138                storage: StorageQualifier::Input,
139            },
140            "gl_GlobalInvocationID"
141            | "gl_NumWorkGroups"
142            | "gl_WorkGroupSize"
143            | "gl_WorkGroupID"
144            | "gl_LocalInvocationID" => BuiltInData {
145                inner: TypeInner::Vector {
146                    size: VectorSize::Tri,
147                    scalar: Scalar::U32,
148                },
149                builtin: match name {
150                    "gl_GlobalInvocationID" => BuiltIn::GlobalInvocationId,
151                    "gl_NumWorkGroups" => BuiltIn::NumWorkGroups,
152                    "gl_WorkGroupSize" => BuiltIn::WorkGroupSize,
153                    "gl_WorkGroupID" => BuiltIn::WorkGroupId,
154                    "gl_LocalInvocationID" => BuiltIn::LocalInvocationId,
155                    _ => unreachable!(),
156                },
157                mutable: false,
158                storage: StorageQualifier::Input,
159            },
160            "gl_FrontFacing" => BuiltInData {
161                inner: TypeInner::Scalar(Scalar::BOOL),
162                builtin: BuiltIn::FrontFacing,
163                mutable: false,
164                storage: StorageQualifier::Input,
165            },
166            "gl_PointSize" | "gl_FragDepth" => BuiltInData {
167                inner: TypeInner::Scalar(Scalar::F32),
168                builtin: match name {
169                    "gl_PointSize" => BuiltIn::PointSize,
170                    "gl_FragDepth" => BuiltIn::FragDepth,
171                    _ => unreachable!(),
172                },
173                mutable: true,
174                storage: StorageQualifier::Output,
175            },
176            "gl_ClipDistance" | "gl_CullDistance" => {
177                let base = ctx.module.types.insert(
178                    Type {
179                        name: None,
180                        inner: TypeInner::Scalar(Scalar::F32),
181                    },
182                    meta,
183                );
184
185                BuiltInData {
186                    inner: TypeInner::Array {
187                        base,
188                        size: crate::ArraySize::Dynamic,
189                        stride: 4,
190                    },
191                    builtin: match name {
192                        "gl_ClipDistance" => BuiltIn::ClipDistances,
193                        "gl_CullDistance" => BuiltIn::CullDistance,
194                        _ => unreachable!(),
195                    },
196                    mutable: self.meta.stage == ShaderStage::Vertex,
197                    storage: StorageQualifier::Output,
198                }
199            }
200            _ => {
201                let builtin = match name {
202                    "gl_BaseVertex" => BuiltIn::BaseVertex,
203                    "gl_BaseInstance" => BuiltIn::BaseInstance,
204                    "gl_PrimitiveID" => BuiltIn::PrimitiveIndex,
205                    "gl_BaryCoordEXT" => BuiltIn::Barycentric { perspective: true },
206                    "gl_BaryCoordNoPerspEXT" => BuiltIn::Barycentric { perspective: false },
207                    "gl_InstanceIndex" => BuiltIn::InstanceIndex,
208                    "gl_VertexIndex" => BuiltIn::VertexIndex,
209                    "gl_SampleID" => BuiltIn::SampleIndex,
210                    "gl_LocalInvocationIndex" => BuiltIn::LocalInvocationIndex,
211                    "gl_DrawID" => BuiltIn::DrawIndex,
212                    _ => return Ok(None),
213                };
214
215                BuiltInData {
216                    inner: TypeInner::Scalar(Scalar::U32),
217                    builtin,
218                    mutable: false,
219                    storage: StorageQualifier::Input,
220                }
221            }
222        };
223
224        self.add_builtin(ctx, name, data, meta)
225    }
226
227    pub(crate) fn make_variable_invariant(
228        &mut self,
229        ctx: &mut Context,
230        name: &str,
231        meta: Span,
232    ) -> Result<()> {
233        if let Some(var) = self.lookup_variable(ctx, name, meta)? {
234            if let Some(index) = var.entry_arg {
235                if let Binding::BuiltIn(BuiltIn::Position { ref mut invariant }) =
236                    self.entry_args[index].binding
237                {
238                    *invariant = true;
239                }
240            }
241        }
242        Ok(())
243    }
244
245    pub(crate) fn field_selection(
246        &mut self,
247        ctx: &mut Context,
248        pos: ExprPos,
249        expression: Handle<Expression>,
250        name: &str,
251        meta: Span,
252    ) -> Result<Handle<Expression>> {
253        let (ty, is_pointer) = match *ctx.resolve_type(expression, meta)? {
254            TypeInner::Pointer { base, .. } => (&ctx.module.types[base].inner, true),
255            ref ty => (ty, false),
256        };
257        match *ty {
258            TypeInner::Struct { ref members, .. } => {
259                let index = members
260                    .iter()
261                    .position(|m| m.name == Some(name.into()))
262                    .ok_or_else(|| Error {
263                        kind: ErrorKind::UnknownField(name.into()),
264                        meta,
265                    })?;
266                let pointer = ctx.add_expression(
267                    Expression::AccessIndex {
268                        base: expression,
269                        index: index as u32,
270                    },
271                    meta,
272                )?;
273
274                Ok(match pos {
275                    ExprPos::Rhs if is_pointer => {
276                        ctx.add_expression(Expression::Load { pointer }, meta)?
277                    }
278                    _ => pointer,
279                })
280            }
281            // swizzles (xyzw, rgba, stpq)
282            TypeInner::Vector { size, .. } => {
283                let check_swizzle_components = |comps: &str| {
284                    name.chars()
285                        .map(|c| {
286                            comps
287                                .find(c)
288                                .filter(|i| *i < size as usize)
289                                .map(|i| SwizzleComponent::from_index(i as u32))
290                        })
291                        .collect::<Option<Vec<SwizzleComponent>>>()
292                };
293
294                let components = check_swizzle_components("xyzw")
295                    .or_else(|| check_swizzle_components("rgba"))
296                    .or_else(|| check_swizzle_components("stpq"));
297
298                if let Some(components) = components {
299                    if let ExprPos::Lhs = pos {
300                        let not_unique = (1..components.len())
301                            .any(|i| components[i..].contains(&components[i - 1]));
302                        if not_unique {
303                            self.errors.push(Error {
304                                kind: ErrorKind::SemanticError(
305                                    format!(
306                                        concat!(
307                                            "swizzle cannot have duplicate components in ",
308                                            "left-hand-side expression for \"{:?}\""
309                                        ),
310                                        name
311                                    )
312                                    .into(),
313                                ),
314                                meta,
315                            })
316                        }
317                    }
318
319                    let mut pattern = [SwizzleComponent::X; 4];
320                    for (pat, component) in pattern.iter_mut().zip(&components) {
321                        *pat = *component;
322                    }
323
324                    // flatten nested swizzles (vec.zyx.xy.x => vec.z)
325                    let mut expression = expression;
326                    if let Expression::Swizzle {
327                        size: _,
328                        vector,
329                        pattern: ref src_pattern,
330                    } = ctx[expression]
331                    {
332                        expression = vector;
333                        for pat in &mut pattern {
334                            *pat = src_pattern[pat.index() as usize];
335                        }
336                    }
337
338                    let size = match components.len() {
339                        // Swizzles with just one component are accesses and not swizzles
340                        1 => {
341                            match pos {
342                                // If the position is in the right hand side and the base
343                                // vector is a pointer, load it, otherwise the swizzle would
344                                // produce a pointer
345                                ExprPos::Rhs if is_pointer => {
346                                    expression = ctx.add_expression(
347                                        Expression::Load {
348                                            pointer: expression,
349                                        },
350                                        meta,
351                                    )?;
352                                }
353                                _ => {}
354                            };
355                            return ctx.add_expression(
356                                Expression::AccessIndex {
357                                    base: expression,
358                                    index: pattern[0].index(),
359                                },
360                                meta,
361                            );
362                        }
363                        2 => VectorSize::Bi,
364                        3 => VectorSize::Tri,
365                        4 => VectorSize::Quad,
366                        _ => {
367                            self.errors.push(Error {
368                                kind: ErrorKind::SemanticError(
369                                    format!("Bad swizzle size for \"{name:?}\"").into(),
370                                ),
371                                meta,
372                            });
373
374                            VectorSize::Quad
375                        }
376                    };
377
378                    if is_pointer {
379                        // NOTE: for lhs expression, this extra load ends up as an unused expr, because the
380                        // assignment will extract the pointer and use it directly anyway. Unfortunately we
381                        // need it for validation to pass, as swizzles cannot operate on pointer values.
382                        expression = ctx.add_expression(
383                            Expression::Load {
384                                pointer: expression,
385                            },
386                            meta,
387                        )?;
388                    }
389
390                    Ok(ctx.add_expression(
391                        Expression::Swizzle {
392                            size,
393                            vector: expression,
394                            pattern,
395                        },
396                        meta,
397                    )?)
398                } else {
399                    Err(Error {
400                        kind: ErrorKind::SemanticError(
401                            format!("Invalid swizzle for vector \"{name}\"").into(),
402                        ),
403                        meta,
404                    })
405                }
406            }
407            _ => Err(Error {
408                kind: ErrorKind::SemanticError(
409                    format!("Can't lookup field on this type \"{name}\"").into(),
410                ),
411                meta,
412            }),
413        }
414    }
415
416    pub(crate) fn add_global_var(
417        &mut self,
418        ctx: &mut Context,
419        VarDeclaration {
420            qualifiers,
421            mut ty,
422            name,
423            init,
424            meta,
425        }: VarDeclaration,
426    ) -> Result<GlobalOrConstant> {
427        let storage = qualifiers.storage.0;
428        let (ret, lookup) = match storage {
429            StorageQualifier::Input | StorageQualifier::Output => {
430                let input = storage == StorageQualifier::Input;
431                // TODO: glslang seems to use a counter for variables without
432                // explicit location (even if that causes collisions)
433                let location = qualifiers
434                    .uint_layout_qualifier("location", &mut self.errors)
435                    .unwrap_or(0);
436                let interpolation = qualifiers.interpolation.take().map(|(i, _)| i).or_else(|| {
437                    let kind = ctx.module.types[ty].inner.scalar_kind()?;
438                    Some(match kind {
439                        ScalarKind::Float => Interpolation::Perspective,
440                        _ => Interpolation::Flat,
441                    })
442                });
443                let sampling = qualifiers.sampling.take().map(|(s, _)| s);
444
445                let handle = ctx.module.global_variables.append(
446                    GlobalVariable {
447                        name: name.clone(),
448                        space: AddressSpace::Private,
449                        binding: None,
450                        ty,
451                        init,
452                        memory_decorations: crate::MemoryDecorations::empty(),
453                    },
454                    meta,
455                );
456
457                let blend_src = qualifiers
458                    .layout_qualifiers
459                    .remove(&QualifierKey::Index)
460                    .and_then(|(value, _span)| match value {
461                        QualifierValue::Uint(index) => Some(index),
462                        _ => None,
463                    });
464
465                let idx = self.entry_args.len();
466                self.entry_args.push(EntryArg {
467                    name: name.clone(),
468                    binding: Binding::Location {
469                        location,
470                        interpolation,
471                        sampling,
472                        blend_src,
473                        per_primitive: false,
474                    },
475                    handle,
476                    storage,
477                });
478
479                let lookup = GlobalLookup {
480                    kind: GlobalLookupKind::Variable(handle),
481                    entry_arg: Some(idx),
482                    mutable: !input,
483                };
484
485                (GlobalOrConstant::Global(handle), lookup)
486            }
487            StorageQualifier::Const => {
488                // Check if this is a specialization constant with constant_id
489                let constant_id = qualifiers.uint_layout_qualifier("constant_id", &mut self.errors);
490
491                if let Some(id) = constant_id {
492                    // This is a specialization constant - convert to Override
493                    let id: Option<u16> = match id.try_into() {
494                        Ok(v) => Some(v),
495                        Err(_) => {
496                            self.errors.push(Error {
497                                kind: ErrorKind::SemanticError(
498                                    format!(
499                                        "constant_id value {id} is too high (maximum is {})",
500                                        u16::MAX
501                                    )
502                                    .into(),
503                                ),
504                                meta,
505                            });
506                            None
507                        }
508                    };
509
510                    let override_handle = ctx.module.overrides.append(
511                        Override {
512                            name: name.clone(),
513                            id,
514                            ty,
515                            init,
516                        },
517                        meta,
518                    );
519
520                    let lookup = GlobalLookup {
521                        kind: GlobalLookupKind::Override(override_handle, ty),
522                        entry_arg: None,
523                        mutable: false,
524                    };
525
526                    (GlobalOrConstant::Override(override_handle), lookup)
527                } else {
528                    // Regular constant
529                    let init = init.ok_or_else(|| Error {
530                        kind: ErrorKind::SemanticError(
531                            "const values must have an initializer".into(),
532                        ),
533                        meta,
534                    })?;
535
536                    let constant = Constant {
537                        name: name.clone(),
538                        ty,
539                        init,
540                    };
541                    let handle = ctx.module.constants.append(constant, meta);
542
543                    let lookup = GlobalLookup {
544                        kind: GlobalLookupKind::Constant(handle, ty),
545                        entry_arg: None,
546                        mutable: false,
547                    };
548
549                    (GlobalOrConstant::Constant(handle), lookup)
550                }
551            }
552            StorageQualifier::AddressSpace(mut space) => {
553                match space {
554                    AddressSpace::Storage { ref mut access } => {
555                        if let Some((allowed_access, _)) = qualifiers.storage_access.take() {
556                            *access = allowed_access;
557                        }
558                    }
559                    AddressSpace::Uniform => match ctx.module.types[ty].inner {
560                        TypeInner::Image {
561                            class,
562                            dim,
563                            arrayed,
564                        } => {
565                            if let crate::ImageClass::Storage {
566                                mut access,
567                                mut format,
568                            } = class
569                            {
570                                if let Some((allowed_access, _)) = qualifiers.storage_access.take()
571                                {
572                                    access = allowed_access;
573                                }
574
575                                match qualifiers.layout_qualifiers.remove(&QualifierKey::Format) {
576                                    Some((QualifierValue::Format(f), _)) => format = f,
577                                    // TODO: glsl supports images without format qualifier
578                                    // if they are `writeonly`
579                                    None => self.errors.push(Error {
580                                        kind: ErrorKind::SemanticError(
581                                            "image types require a format layout qualifier".into(),
582                                        ),
583                                        meta,
584                                    }),
585                                    _ => unreachable!(),
586                                }
587
588                                ty = ctx.module.types.insert(
589                                    Type {
590                                        name: None,
591                                        inner: TypeInner::Image {
592                                            dim,
593                                            arrayed,
594                                            class: crate::ImageClass::Storage { format, access },
595                                        },
596                                    },
597                                    meta,
598                                );
599                            }
600
601                            space = AddressSpace::Handle
602                        }
603                        TypeInner::Sampler { .. } => space = AddressSpace::Handle,
604                        _ => {
605                            if qualifiers.none_layout_qualifier("push_constant", &mut self.errors) {
606                                space = AddressSpace::Immediate
607                            }
608                        }
609                    },
610                    AddressSpace::Function => space = AddressSpace::Private,
611                    _ => {}
612                };
613
614                let binding = match space {
615                    AddressSpace::Uniform | AddressSpace::Storage { .. } | AddressSpace::Handle => {
616                        let binding = qualifiers.uint_layout_qualifier("binding", &mut self.errors);
617                        if binding.is_none() {
618                            self.errors.push(Error {
619                                kind: ErrorKind::SemanticError(
620                                    "uniform/buffer blocks require layout(binding=X)".into(),
621                                ),
622                                meta,
623                            });
624                        }
625                        let set = qualifiers.uint_layout_qualifier("set", &mut self.errors);
626                        binding.map(|binding| ResourceBinding {
627                            group: set.unwrap_or(0),
628                            binding,
629                        })
630                    }
631                    _ => None,
632                };
633
634                let handle = ctx.module.global_variables.append(
635                    GlobalVariable {
636                        name: name.clone(),
637                        space,
638                        binding,
639                        ty,
640                        init,
641                        memory_decorations: crate::MemoryDecorations::empty(),
642                    },
643                    meta,
644                );
645
646                let lookup = GlobalLookup {
647                    kind: GlobalLookupKind::Variable(handle),
648                    entry_arg: None,
649                    mutable: true,
650                };
651
652                (GlobalOrConstant::Global(handle), lookup)
653            }
654        };
655
656        if let Some(name) = name {
657            ctx.add_global(&name, lookup)?;
658
659            self.global_variables.push((name, lookup));
660        }
661
662        qualifiers.unused_errors(&mut self.errors);
663
664        Ok(ret)
665    }
666
667    pub(crate) fn add_local_var(
668        &mut self,
669        ctx: &mut Context,
670        decl: VarDeclaration,
671    ) -> Result<Handle<Expression>> {
672        let storage = decl.qualifiers.storage;
673        let mutable = match storage.0 {
674            StorageQualifier::AddressSpace(AddressSpace::Function) => true,
675            StorageQualifier::Const => false,
676            _ => {
677                self.errors.push(Error {
678                    kind: ErrorKind::SemanticError("Locals cannot have a storage qualifier".into()),
679                    meta: storage.1,
680                });
681                true
682            }
683        };
684
685        let handle = ctx.locals.append(
686            LocalVariable {
687                name: decl.name.clone(),
688                ty: decl.ty,
689                init: decl.init,
690            },
691            decl.meta,
692        );
693        let expr = ctx.add_expression(Expression::LocalVariable(handle), decl.meta)?;
694
695        if let Some(name) = decl.name {
696            let maybe_var = ctx.add_local_var(name.clone(), expr, mutable);
697
698            if maybe_var.is_some() {
699                self.errors.push(Error {
700                    kind: ErrorKind::VariableAlreadyDeclared(name),
701                    meta: decl.meta,
702                })
703            }
704        }
705
706        decl.qualifiers.unused_errors(&mut self.errors);
707
708        Ok(expr)
709    }
710}