naga/front/glsl/variables.rs

1use alloc::{format, string::String, vec::Vec};
2
3use super::{
4    ast::*,
5    context::{Context, ExprPos},
6    error::{Error, ErrorKind},
7    Frontend, Result, Span,
8};
9use crate::{
10    AddressSpace, Binding, BuiltIn, Constant, Expression, GlobalVariable, Handle, Interpolation,
11    LocalVariable, Override, ResourceBinding, Scalar, ScalarKind, ShaderStage, SwizzleComponent,
12    Type, TypeInner, VectorSize,
13};
14
15pub struct VarDeclaration<'a, 'key> {
16    pub qualifiers: &'a mut TypeQualifiers<'key>,
17    pub ty: Handle<Type>,
18    pub name: Option<String>,
19    pub init: Option<Handle<Expression>>,
20    pub meta: Span,
21}
22
23/// Information about a builtin used in [`add_builtin`](Frontend::add_builtin).
24struct BuiltInData {
25    /// The type of the builtin.
26    inner: TypeInner,
27    /// The associated builtin class.
28    builtin: BuiltIn,
29    /// Whether the builtin can be written to or not.
30    mutable: bool,
31    /// The storage used for the builtin.
32    storage: StorageQualifier,
33}
34
35pub enum GlobalOrConstant {
36    Global(Handle<GlobalVariable>),
37    Constant(Handle<Constant>),
38    Override(Handle<Override>),
39}
40
41impl Frontend {
42    /// Adds a builtin and returns a variable reference to it
43    fn add_builtin(
44        &mut self,
45        ctx: &mut Context,
46        name: &str,
47        data: BuiltInData,
48        meta: Span,
49    ) -> Result<Option<VariableReference>> {
50        let ty = ctx.module.types.insert(
51            Type {
52                name: None,
53                inner: data.inner,
54            },
55            meta,
56        );
57
58        let handle = ctx.module.global_variables.append(
59            GlobalVariable {
60                name: Some(name.into()),
61                space: AddressSpace::Private,
62                binding: None,
63                ty,
64                init: None,
65            },
66            meta,
67        );
68
69        let idx = self.entry_args.len();
70        self.entry_args.push(EntryArg {
71            name: Some(name.into()),
72            binding: Binding::BuiltIn(data.builtin),
73            handle,
74            storage: data.storage,
75        });
76
77        self.global_variables.push((
78            name.into(),
79            GlobalLookup {
80                kind: GlobalLookupKind::Variable(handle),
81                entry_arg: Some(idx),
82                mutable: data.mutable,
83            },
84        ));
85
86        let expr = ctx.add_expression(Expression::GlobalVariable(handle), meta)?;
87
88        let var = VariableReference {
89            expr,
90            load: true,
91            mutable: data.mutable,
92            constant: None,
93            entry_arg: Some(idx),
94        };
95
96        ctx.symbol_table.add_root(name.into(), var.clone());
97
98        Ok(Some(var))
99    }
100
101    pub(crate) fn lookup_variable(
102        &mut self,
103        ctx: &mut Context,
104        name: &str,
105        meta: Span,
106    ) -> Result<Option<VariableReference>> {
107        if let Some(var) = ctx.symbol_table.lookup(name).cloned() {
108            return Ok(Some(var));
109        }
110
111        let data = match name {
112            "gl_Position" => BuiltInData {
113                inner: TypeInner::Vector {
114                    size: VectorSize::Quad,
115                    scalar: Scalar::F32,
116                },
117                builtin: BuiltIn::Position { invariant: false },
118                mutable: true,
119                storage: StorageQualifier::Output,
120            },
121            "gl_FragCoord" => BuiltInData {
122                inner: TypeInner::Vector {
123                    size: VectorSize::Quad,
124                    scalar: Scalar::F32,
125                },
126                builtin: BuiltIn::Position { invariant: false },
127                mutable: false,
128                storage: StorageQualifier::Input,
129            },
130            "gl_PointCoord" => BuiltInData {
131                inner: TypeInner::Vector {
132                    size: VectorSize::Bi,
133                    scalar: Scalar::F32,
134                },
135                builtin: BuiltIn::PointCoord,
136                mutable: false,
137                storage: StorageQualifier::Input,
138            },
139            "gl_GlobalInvocationID"
140            | "gl_NumWorkGroups"
141            | "gl_WorkGroupSize"
142            | "gl_WorkGroupID"
143            | "gl_LocalInvocationID" => BuiltInData {
144                inner: TypeInner::Vector {
145                    size: VectorSize::Tri,
146                    scalar: Scalar::U32,
147                },
148                builtin: match name {
149                    "gl_GlobalInvocationID" => BuiltIn::GlobalInvocationId,
150                    "gl_NumWorkGroups" => BuiltIn::NumWorkGroups,
151                    "gl_WorkGroupSize" => BuiltIn::WorkGroupSize,
152                    "gl_WorkGroupID" => BuiltIn::WorkGroupId,
153                    "gl_LocalInvocationID" => BuiltIn::LocalInvocationId,
154                    _ => unreachable!(),
155                },
156                mutable: false,
157                storage: StorageQualifier::Input,
158            },
159            "gl_FrontFacing" => BuiltInData {
160                inner: TypeInner::Scalar(Scalar::BOOL),
161                builtin: BuiltIn::FrontFacing,
162                mutable: false,
163                storage: StorageQualifier::Input,
164            },
165            "gl_PointSize" | "gl_FragDepth" => BuiltInData {
166                inner: TypeInner::Scalar(Scalar::F32),
167                builtin: match name {
168                    "gl_PointSize" => BuiltIn::PointSize,
169                    "gl_FragDepth" => BuiltIn::FragDepth,
170                    _ => unreachable!(),
171                },
172                mutable: true,
173                storage: StorageQualifier::Output,
174            },
175            "gl_ClipDistance" | "gl_CullDistance" => {
176                let base = ctx.module.types.insert(
177                    Type {
178                        name: None,
179                        inner: TypeInner::Scalar(Scalar::F32),
180                    },
181                    meta,
182                );
183
184                BuiltInData {
185                    inner: TypeInner::Array {
186                        base,
187                        size: crate::ArraySize::Dynamic,
188                        stride: 4,
189                    },
190                    builtin: match name {
191                        "gl_ClipDistance" => BuiltIn::ClipDistance,
192                        "gl_CullDistance" => BuiltIn::CullDistance,
193                        _ => unreachable!(),
194                    },
195                    mutable: self.meta.stage == ShaderStage::Vertex,
196                    storage: StorageQualifier::Output,
197                }
198            }
199            _ => {
200                let builtin = match name {
201                    "gl_BaseVertex" => BuiltIn::BaseVertex,
202                    "gl_BaseInstance" => BuiltIn::BaseInstance,
203                    "gl_PrimitiveID" => BuiltIn::PrimitiveIndex,
204                    "gl_BaryCoordEXT" => BuiltIn::Barycentric { perspective: true },
205                    "gl_BaryCoordNoPerspEXT" => BuiltIn::Barycentric { perspective: false },
206                    "gl_InstanceIndex" => BuiltIn::InstanceIndex,
207                    "gl_VertexIndex" => BuiltIn::VertexIndex,
208                    "gl_SampleID" => BuiltIn::SampleIndex,
209                    "gl_LocalInvocationIndex" => BuiltIn::LocalInvocationIndex,
210                    "gl_DrawID" => BuiltIn::DrawID,
211                    _ => return Ok(None),
212                };
213
214                BuiltInData {
215                    inner: TypeInner::Scalar(Scalar::U32),
216                    builtin,
217                    mutable: false,
218                    storage: StorageQualifier::Input,
219                }
220            }
221        };
222
223        self.add_builtin(ctx, name, data, meta)
224    }
225
226    pub(crate) fn make_variable_invariant(
227        &mut self,
228        ctx: &mut Context,
229        name: &str,
230        meta: Span,
231    ) -> Result<()> {
232        if let Some(var) = self.lookup_variable(ctx, name, meta)? {
233            if let Some(index) = var.entry_arg {
234                if let Binding::BuiltIn(BuiltIn::Position { ref mut invariant }) =
235                    self.entry_args[index].binding
236                {
237                    *invariant = true;
238                }
239            }
240        }
241        Ok(())
242    }
243
244    pub(crate) fn field_selection(
245        &mut self,
246        ctx: &mut Context,
247        pos: ExprPos,
248        expression: Handle<Expression>,
249        name: &str,
250        meta: Span,
251    ) -> Result<Handle<Expression>> {
252        let (ty, is_pointer) = match *ctx.resolve_type(expression, meta)? {
253            TypeInner::Pointer { base, .. } => (&ctx.module.types[base].inner, true),
254            ref ty => (ty, false),
255        };
256        match *ty {
257            TypeInner::Struct { ref members, .. } => {
258                let index = members
259                    .iter()
260                    .position(|m| m.name == Some(name.into()))
261                    .ok_or_else(|| Error {
262                        kind: ErrorKind::UnknownField(name.into()),
263                        meta,
264                    })?;
265                let pointer = ctx.add_expression(
266                    Expression::AccessIndex {
267                        base: expression,
268                        index: index as u32,
269                    },
270                    meta,
271                )?;
272
273                Ok(match pos {
274                    ExprPos::Rhs if is_pointer => {
275                        ctx.add_expression(Expression::Load { pointer }, meta)?
276                    }
277                    _ => pointer,
278                })
279            }
280            // swizzles (xyzw, rgba, stpq)
281            TypeInner::Vector { size, .. } => {
282                let check_swizzle_components = |comps: &str| {
283                    name.chars()
284                        .map(|c| {
285                            comps
286                                .find(c)
287                                .filter(|i| *i < size as usize)
288                                .map(|i| SwizzleComponent::from_index(i as u32))
289                        })
290                        .collect::<Option<Vec<SwizzleComponent>>>()
291                };
292
293                let components = check_swizzle_components("xyzw")
294                    .or_else(|| check_swizzle_components("rgba"))
295                    .or_else(|| check_swizzle_components("stpq"));
296
297                if let Some(components) = components {
298                    if let ExprPos::Lhs = pos {
299                        let not_unique = (1..components.len())
300                            .any(|i| components[i..].contains(&components[i - 1]));
301                        if not_unique {
302                            self.errors.push(Error {
303                                kind: ErrorKind::SemanticError(
304                                    format!(
305                                        concat!(
306                                            "swizzle cannot have duplicate components in ",
307                                            "left-hand-side expression for \"{:?}\""
308                                        ),
309                                        name
310                                    )
311                                    .into(),
312                                ),
313                                meta,
314                            })
315                        }
316                    }
317
318                    let mut pattern = [SwizzleComponent::X; 4];
319                    for (pat, component) in pattern.iter_mut().zip(&components) {
320                        *pat = *component;
321                    }
322
323                    // flatten nested swizzles (vec.zyx.xy.x => vec.z)
324                    let mut expression = expression;
325                    if let Expression::Swizzle {
326                        size: _,
327                        vector,
328                        pattern: ref src_pattern,
329                    } = ctx[expression]
330                    {
331                        expression = vector;
332                        for pat in &mut pattern {
333                            *pat = src_pattern[pat.index() as usize];
334                        }
335                    }
336
337                    let size = match components.len() {
338                        // Swizzles with just one component are accesses and not swizzles
339                        1 => {
340                            match pos {
341                                // If the position is in the right hand side and the base
342                                // vector is a pointer, load it, otherwise the swizzle would
343                                // produce a pointer
344                                ExprPos::Rhs if is_pointer => {
345                                    expression = ctx.add_expression(
346                                        Expression::Load {
347                                            pointer: expression,
348                                        },
349                                        meta,
350                                    )?;
351                                }
352                                _ => {}
353                            };
354                            return ctx.add_expression(
355                                Expression::AccessIndex {
356                                    base: expression,
357                                    index: pattern[0].index(),
358                                },
359                                meta,
360                            );
361                        }
362                        2 => VectorSize::Bi,
363                        3 => VectorSize::Tri,
364                        4 => VectorSize::Quad,
365                        _ => {
366                            self.errors.push(Error {
367                                kind: ErrorKind::SemanticError(
368                                    format!("Bad swizzle size for \"{name:?}\"").into(),
369                                ),
370                                meta,
371                            });
372
373                            VectorSize::Quad
374                        }
375                    };
376
377                    if is_pointer {
378                        // NOTE: for lhs expression, this extra load ends up as an unused expr, because the
379                        // assignment will extract the pointer and use it directly anyway. Unfortunately we
380                        // need it for validation to pass, as swizzles cannot operate on pointer values.
381                        expression = ctx.add_expression(
382                            Expression::Load {
383                                pointer: expression,
384                            },
385                            meta,
386                        )?;
387                    }
388
389                    Ok(ctx.add_expression(
390                        Expression::Swizzle {
391                            size,
392                            vector: expression,
393                            pattern,
394                        },
395                        meta,
396                    )?)
397                } else {
398                    Err(Error {
399                        kind: ErrorKind::SemanticError(
400                            format!("Invalid swizzle for vector \"{name}\"").into(),
401                        ),
402                        meta,
403                    })
404                }
405            }
406            _ => Err(Error {
407                kind: ErrorKind::SemanticError(
408                    format!("Can't lookup field on this type \"{name}\"").into(),
409                ),
410                meta,
411            }),
412        }
413    }
414
415    pub(crate) fn add_global_var(
416        &mut self,
417        ctx: &mut Context,
418        VarDeclaration {
419            qualifiers,
420            mut ty,
421            name,
422            init,
423            meta,
424        }: VarDeclaration,
425    ) -> Result<GlobalOrConstant> {
426        let storage = qualifiers.storage.0;
427        let (ret, lookup) = match storage {
428            StorageQualifier::Input | StorageQualifier::Output => {
429                let input = storage == StorageQualifier::Input;
430                // TODO: glslang seems to use a counter for variables without
431                // explicit location (even if that causes collisions)
432                let location = qualifiers
433                    .uint_layout_qualifier("location", &mut self.errors)
434                    .unwrap_or(0);
435                let interpolation = qualifiers.interpolation.take().map(|(i, _)| i).or_else(|| {
436                    let kind = ctx.module.types[ty].inner.scalar_kind()?;
437                    Some(match kind {
438                        ScalarKind::Float => Interpolation::Perspective,
439                        _ => Interpolation::Flat,
440                    })
441                });
442                let sampling = qualifiers.sampling.take().map(|(s, _)| s);
443
444                let handle = ctx.module.global_variables.append(
445                    GlobalVariable {
446                        name: name.clone(),
447                        space: AddressSpace::Private,
448                        binding: None,
449                        ty,
450                        init,
451                    },
452                    meta,
453                );
454
455                let blend_src = qualifiers
456                    .layout_qualifiers
457                    .remove(&QualifierKey::Index)
458                    .and_then(|(value, _span)| match value {
459                        QualifierValue::Uint(index) => Some(index),
460                        _ => None,
461                    });
462
463                let idx = self.entry_args.len();
464                self.entry_args.push(EntryArg {
465                    name: name.clone(),
466                    binding: Binding::Location {
467                        location,
468                        interpolation,
469                        sampling,
470                        blend_src,
471                        per_primitive: false,
472                    },
473                    handle,
474                    storage,
475                });
476
477                let lookup = GlobalLookup {
478                    kind: GlobalLookupKind::Variable(handle),
479                    entry_arg: Some(idx),
480                    mutable: !input,
481                };
482
483                (GlobalOrConstant::Global(handle), lookup)
484            }
485            StorageQualifier::Const => {
486                // Check if this is a specialization constant with constant_id
487                let constant_id = qualifiers.uint_layout_qualifier("constant_id", &mut self.errors);
488
489                if let Some(id) = constant_id {
490                    // This is a specialization constant - convert to Override
491                    let id: Option<u16> = match id.try_into() {
492                        Ok(v) => Some(v),
493                        Err(_) => {
494                            self.errors.push(Error {
495                                kind: ErrorKind::SemanticError(
496                                    format!(
497                                        "constant_id value {id} is too high (maximum is {})",
498                                        u16::MAX
499                                    )
500                                    .into(),
501                                ),
502                                meta,
503                            });
504                            None
505                        }
506                    };
507
508                    let override_handle = ctx.module.overrides.append(
509                        Override {
510                            name: name.clone(),
511                            id,
512                            ty,
513                            init,
514                        },
515                        meta,
516                    );
517
518                    let lookup = GlobalLookup {
519                        kind: GlobalLookupKind::Override(override_handle, ty),
520                        entry_arg: None,
521                        mutable: false,
522                    };
523
524                    (GlobalOrConstant::Override(override_handle), lookup)
525                } else {
526                    // Regular constant
527                    let init = init.ok_or_else(|| Error {
528                        kind: ErrorKind::SemanticError(
529                            "const values must have an initializer".into(),
530                        ),
531                        meta,
532                    })?;
533
534                    let constant = Constant {
535                        name: name.clone(),
536                        ty,
537                        init,
538                    };
539                    let handle = ctx.module.constants.append(constant, meta);
540
541                    let lookup = GlobalLookup {
542                        kind: GlobalLookupKind::Constant(handle, ty),
543                        entry_arg: None,
544                        mutable: false,
545                    };
546
547                    (GlobalOrConstant::Constant(handle), lookup)
548                }
549            }
550            StorageQualifier::AddressSpace(mut space) => {
551                match space {
552                    AddressSpace::Storage { ref mut access } => {
553                        if let Some((allowed_access, _)) = qualifiers.storage_access.take() {
554                            *access = allowed_access;
555                        }
556                    }
557                    AddressSpace::Uniform => match ctx.module.types[ty].inner {
558                        TypeInner::Image {
559                            class,
560                            dim,
561                            arrayed,
562                        } => {
563                            if let crate::ImageClass::Storage {
564                                mut access,
565                                mut format,
566                            } = class
567                            {
568                                if let Some((allowed_access, _)) = qualifiers.storage_access.take()
569                                {
570                                    access = allowed_access;
571                                }
572
573                                match qualifiers.layout_qualifiers.remove(&QualifierKey::Format) {
574                                    Some((QualifierValue::Format(f), _)) => format = f,
575                                    // TODO: glsl supports images without format qualifier
576                                    // if they are `writeonly`
577                                    None => self.errors.push(Error {
578                                        kind: ErrorKind::SemanticError(
579                                            "image types require a format layout qualifier".into(),
580                                        ),
581                                        meta,
582                                    }),
583                                    _ => unreachable!(),
584                                }
585
586                                ty = ctx.module.types.insert(
587                                    Type {
588                                        name: None,
589                                        inner: TypeInner::Image {
590                                            dim,
591                                            arrayed,
592                                            class: crate::ImageClass::Storage { format, access },
593                                        },
594                                    },
595                                    meta,
596                                );
597                            }
598
599                            space = AddressSpace::Handle
600                        }
601                        TypeInner::Sampler { .. } => space = AddressSpace::Handle,
602                        _ => {
603                            if qualifiers.none_layout_qualifier("push_constant", &mut self.errors) {
604                                space = AddressSpace::Immediate
605                            }
606                        }
607                    },
608                    AddressSpace::Function => space = AddressSpace::Private,
609                    _ => {}
610                };
611
612                let binding = match space {
613                    AddressSpace::Uniform | AddressSpace::Storage { .. } | AddressSpace::Handle => {
614                        let binding = qualifiers.uint_layout_qualifier("binding", &mut self.errors);
615                        if binding.is_none() {
616                            self.errors.push(Error {
617                                kind: ErrorKind::SemanticError(
618                                    "uniform/buffer blocks require layout(binding=X)".into(),
619                                ),
620                                meta,
621                            });
622                        }
623                        let set = qualifiers.uint_layout_qualifier("set", &mut self.errors);
624                        binding.map(|binding| ResourceBinding {
625                            group: set.unwrap_or(0),
626                            binding,
627                        })
628                    }
629                    _ => None,
630                };
631
632                let handle = ctx.module.global_variables.append(
633                    GlobalVariable {
634                        name: name.clone(),
635                        space,
636                        binding,
637                        ty,
638                        init,
639                    },
640                    meta,
641                );
642
643                let lookup = GlobalLookup {
644                    kind: GlobalLookupKind::Variable(handle),
645                    entry_arg: None,
646                    mutable: true,
647                };
648
649                (GlobalOrConstant::Global(handle), lookup)
650            }
651        };
652
653        if let Some(name) = name {
654            ctx.add_global(&name, lookup)?;
655
656            self.global_variables.push((name, lookup));
657        }
658
659        qualifiers.unused_errors(&mut self.errors);
660
661        Ok(ret)
662    }
663
664    pub(crate) fn add_local_var(
665        &mut self,
666        ctx: &mut Context,
667        decl: VarDeclaration,
668    ) -> Result<Handle<Expression>> {
669        let storage = decl.qualifiers.storage;
670        let mutable = match storage.0 {
671            StorageQualifier::AddressSpace(AddressSpace::Function) => true,
672            StorageQualifier::Const => false,
673            _ => {
674                self.errors.push(Error {
675                    kind: ErrorKind::SemanticError("Locals cannot have a storage qualifier".into()),
676                    meta: storage.1,
677                });
678                true
679            }
680        };
681
682        let handle = ctx.locals.append(
683            LocalVariable {
684                name: decl.name.clone(),
685                ty: decl.ty,
686                init: decl.init,
687            },
688            decl.meta,
689        );
690        let expr = ctx.add_expression(Expression::LocalVariable(handle), decl.meta)?;
691
692        if let Some(name) = decl.name {
693            let maybe_var = ctx.add_local_var(name.clone(), expr, mutable);
694
695            if maybe_var.is_some() {
696                self.errors.push(Error {
697                    kind: ErrorKind::VariableAlreadyDeclared(name),
698                    meta: decl.meta,
699                })
700            }
701        }
702
703        decl.qualifiers.unused_errors(&mut self.errors);
704
705        Ok(expr)
706    }
707}