1use alloc::format;
59use core::{fmt, mem};
60
61use super::{super::FunctionCtx, BackendResult, Error};
62use crate::{
63 proc::{Alignment, NameKey, TypeResolution},
64 Handle,
65};
66
/// Name prefix of the temporaries emitted while writing a storage store.
///
/// A depth number is appended to it wherever it is written (for example
/// `_value1`), so temporaries declared at different depths have distinct
/// names.
const STORE_TEMP_NAME: &str = "_value";
68
/// One term of the address expression of a storage buffer access.
///
/// `write_storage_address` writes the terms of a chain joined with `+`.
#[derive(Debug)]
pub(super) enum SubAccess {
    /// A dynamic buffer offset, written as
    /// `__dynamic_buffer_offsets{group}._{offset}`.
    BufferOffset {
        /// Group number, taken from the binding of the accessed variable.
        group: u32,
        /// Index of the offset; becomes the `_{offset}` member name.
        offset: u32,
    },

    /// A constant offset, written as a literal.
    Offset(u32),

    /// A dynamically indexed term, written as `<value>*<stride>`.
    Index {
        /// Expression that produces the index.
        value: Handle<crate::Expression>,
        /// Constant the index is multiplied by.
        stride: u32,
    },
}
98
/// A value handed to `write_storage_store`, together with the information
/// `write_store_value` needs to spell it in the output.
pub(super) enum StoreValue {
    /// An expression of the current function, written with `write_expr`.
    Expression(Handle<crate::Expression>),
    /// One element of a store temporary, written as
    /// `_value{depth}[{index}]`.
    TempIndex {
        /// Depth number of the temporary being indexed.
        depth: usize,
        /// Index of the element within the temporary.
        index: u32,
        /// Type of the element; decides how the element itself is stored.
        ty: TypeResolution,
    },
    /// One member of a struct-typed store temporary, written as
    /// `_value{depth}.{member name}`.
    TempAccess {
        /// Depth number of the temporary holding the struct.
        depth: usize,
        /// The struct type, used to look up the member's name and type.
        base: Handle<crate::Type>,
        /// Index of the member within `base`.
        member_index: u32,
    },
    /// One column of a matrix member of a struct-typed store temporary,
    /// written as `_value{depth}.{member name}_{column}`.
    ///
    /// Only built by `write_storage_store` for a matrix stored within a
    /// struct whose row count is `VectorSize::Bi`; passing it to
    /// `write_storage_store` itself panics.
    TempColumnAccess {
        /// Depth number of the temporary holding the struct.
        depth: usize,
        /// The struct type, used to look up the member's name.
        base: Handle<crate::Type>,
        /// Index of the matrix member within `base`.
        member_index: u32,
        /// Index of the column within the matrix.
        column: u32,
    },
}
119
120impl<W: fmt::Write> super::Writer<'_, W> {
121 pub(super) fn write_storage_address(
122 &mut self,
123 module: &crate::Module,
124 chain: &[SubAccess],
125 func_ctx: &FunctionCtx,
126 ) -> BackendResult {
127 if chain.is_empty() {
128 write!(self.out, "0")?;
129 }
130 for (i, access) in chain.iter().enumerate() {
131 if i != 0 {
132 write!(self.out, "+")?;
133 }
134 match *access {
135 SubAccess::BufferOffset { group, offset } => {
136 write!(self.out, "__dynamic_buffer_offsets{group}._{offset}")?;
137 }
138 SubAccess::Offset(offset) => {
139 write!(self.out, "{offset}")?;
140 }
141 SubAccess::Index { value, stride } => {
142 self.write_expr(module, value, func_ctx)?;
143 write!(self.out, "*{stride}")?;
144 }
145 }
146 }
147 Ok(())
148 }
149
150 fn write_storage_load_sequence<I: Iterator<Item = (TypeResolution, u32)>>(
151 &mut self,
152 module: &crate::Module,
153 var_handle: Handle<crate::GlobalVariable>,
154 sequence: I,
155 func_ctx: &FunctionCtx,
156 ) -> BackendResult {
157 for (i, (ty_resolution, offset)) in sequence.enumerate() {
158 self.temp_access_chain.push(SubAccess::Offset(offset));
160 if i != 0 {
161 write!(self.out, ", ")?;
162 };
163 self.write_storage_load(module, var_handle, ty_resolution, func_ctx)?;
164 self.temp_access_chain.pop();
165 }
166 Ok(())
167 }
168
169 pub(super) fn write_storage_load(
180 &mut self,
181 module: &crate::Module,
182 var_handle: Handle<crate::GlobalVariable>,
183 result_ty: TypeResolution,
184 func_ctx: &FunctionCtx,
185 ) -> BackendResult {
186 match *result_ty.inner_with(&module.types) {
187 crate::TypeInner::Scalar(scalar) => {
188 let chain = mem::take(&mut self.temp_access_chain);
190 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
191 if scalar.width == 4 {
193 let cast = scalar.kind.to_hlsl_cast();
194 write!(self.out, "{cast}({var_name}.Load(")?;
195 } else {
196 let ty = scalar.to_hlsl_str()?;
197 write!(self.out, "{var_name}.Load<{ty}>(")?;
198 };
199 self.write_storage_address(module, &chain, func_ctx)?;
200 write!(self.out, ")")?;
201 if scalar.width == 4 {
202 write!(self.out, ")")?;
203 }
204 self.temp_access_chain = chain;
205 }
206 crate::TypeInner::Vector { size, scalar } => {
207 let chain = mem::take(&mut self.temp_access_chain);
209 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
210 let size = size as u8;
211 if scalar.width == 4 {
213 let cast = scalar.kind.to_hlsl_cast();
214 write!(self.out, "{cast}({var_name}.Load{size}(")?;
215 } else {
216 let ty = scalar.to_hlsl_str()?;
217 write!(self.out, "{var_name}.Load<{ty}{size}>(")?;
218 };
219 self.write_storage_address(module, &chain, func_ctx)?;
220 write!(self.out, ")")?;
221 if scalar.width == 4 {
222 write!(self.out, ")")?;
223 }
224 self.temp_access_chain = chain;
225 }
226 crate::TypeInner::Matrix {
227 columns,
228 rows,
229 scalar,
230 } => {
231 write!(
232 self.out,
233 "{}{}x{}(",
234 scalar.to_hlsl_str()?,
235 columns as u8,
236 rows as u8,
237 )?;
238
239 let row_stride = Alignment::from(rows) * scalar.width as u32;
241 let iter = (0..columns as u32).map(|i| {
242 let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
243 (TypeResolution::Value(ty_inner), i * row_stride)
244 });
245 self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
246 write!(self.out, ")")?;
247 }
248 crate::TypeInner::Array {
249 base,
250 size: crate::ArraySize::Constant(size),
251 stride,
252 } => {
253 let constructor = super::help::WrappedConstructor {
254 ty: result_ty.handle().unwrap(),
255 };
256 self.write_wrapped_constructor_function_name(module, constructor)?;
257 write!(self.out, "(")?;
258 let iter = (0..size.get()).map(|i| (TypeResolution::Handle(base), stride * i));
259 self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
260 write!(self.out, ")")?;
261 }
262 crate::TypeInner::Struct { ref members, .. } => {
263 let constructor = super::help::WrappedConstructor {
264 ty: result_ty.handle().unwrap(),
265 };
266 self.write_wrapped_constructor_function_name(module, constructor)?;
267 write!(self.out, "(")?;
268 let iter = members
269 .iter()
270 .map(|m| (TypeResolution::Handle(m.ty), m.offset));
271 self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
272 write!(self.out, ")")?;
273 }
274 _ => unreachable!(),
275 }
276 Ok(())
277 }
278
279 fn write_store_value(
280 &mut self,
281 module: &crate::Module,
282 value: &StoreValue,
283 func_ctx: &FunctionCtx,
284 ) -> BackendResult {
285 match *value {
286 StoreValue::Expression(expr) => self.write_expr(module, expr, func_ctx)?,
287 StoreValue::TempIndex {
288 depth,
289 index,
290 ty: _,
291 } => write!(self.out, "{STORE_TEMP_NAME}{depth}[{index}]")?,
292 StoreValue::TempAccess {
293 depth,
294 base,
295 member_index,
296 } => {
297 let name = &self.names[&NameKey::StructMember(base, member_index)];
298 write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}")?
299 }
300 StoreValue::TempColumnAccess {
301 depth,
302 base,
303 member_index,
304 column,
305 } => {
306 let name = &self.names[&NameKey::StructMember(base, member_index)];
307 write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}_{column}")?
308 }
309 }
310 Ok(())
311 }
312
313 pub(super) fn write_storage_store(
315 &mut self,
316 module: &crate::Module,
317 var_handle: Handle<crate::GlobalVariable>,
318 value: StoreValue,
319 func_ctx: &FunctionCtx,
320 level: crate::back::Level,
321 within_struct: Option<Handle<crate::Type>>,
322 ) -> BackendResult {
323 let temp_resolution;
324 let ty_resolution = match value {
325 StoreValue::Expression(expr) => &func_ctx.info[expr].ty,
326 StoreValue::TempIndex {
327 depth: _,
328 index: _,
329 ref ty,
330 } => ty,
331 StoreValue::TempAccess {
332 depth: _,
333 base,
334 member_index,
335 } => {
336 let ty_handle = match module.types[base].inner {
337 crate::TypeInner::Struct { ref members, .. } => {
338 members[member_index as usize].ty
339 }
340 _ => unreachable!(),
341 };
342 temp_resolution = TypeResolution::Handle(ty_handle);
343 &temp_resolution
344 }
345 StoreValue::TempColumnAccess { .. } => {
346 unreachable!("attempting write_storage_store for TempColumnAccess");
347 }
348 };
349 match *ty_resolution.inner_with(&module.types) {
350 crate::TypeInner::Scalar(scalar) => {
351 let chain = mem::take(&mut self.temp_access_chain);
353 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
354 if scalar.width == 4 {
356 write!(self.out, "{level}{var_name}.Store(")?;
357 self.write_storage_address(module, &chain, func_ctx)?;
358 write!(self.out, ", asuint(")?;
359 self.write_store_value(module, &value, func_ctx)?;
360 writeln!(self.out, "));")?;
361 } else {
362 write!(self.out, "{level}{var_name}.Store(")?;
363 self.write_storage_address(module, &chain, func_ctx)?;
364 write!(self.out, ", ")?;
365 self.write_store_value(module, &value, func_ctx)?;
366 writeln!(self.out, ");")?;
367 }
368 self.temp_access_chain = chain;
369 }
370 crate::TypeInner::Vector { size, scalar } => {
371 let chain = mem::take(&mut self.temp_access_chain);
373 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
374 if scalar.width == 4 {
376 write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?;
377 self.write_storage_address(module, &chain, func_ctx)?;
378 write!(self.out, ", asuint(")?;
379 self.write_store_value(module, &value, func_ctx)?;
380 writeln!(self.out, "));")?;
381 } else {
382 write!(self.out, "{level}{var_name}.Store(")?;
383 self.write_storage_address(module, &chain, func_ctx)?;
384 write!(self.out, ", ")?;
385 self.write_store_value(module, &value, func_ctx)?;
386 writeln!(self.out, ");")?;
387 }
388 self.temp_access_chain = chain;
389 }
390 crate::TypeInner::Matrix {
391 columns,
392 rows,
393 scalar,
394 } => {
395 let row_stride = Alignment::from(rows) * scalar.width as u32;
397
398 writeln!(self.out, "{level}{{")?;
399
400 match within_struct {
401 Some(containing_struct) if rows == crate::VectorSize::Bi => {
402 let mut chain = mem::take(&mut self.temp_access_chain);
405 for i in 0..columns as u32 {
406 chain.push(SubAccess::Offset(i * row_stride));
407 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
409 let StoreValue::TempAccess { member_index, .. } = value else {
410 unreachable!(
411 "write_storage_store within_struct but not TempAccess"
412 );
413 };
414 let column_value = StoreValue::TempColumnAccess {
415 depth: level.0, base: containing_struct,
417 member_index,
418 column: i,
419 };
420 if scalar.width == 4 {
422 write!(
423 self.out,
424 "{}{}.Store{}(",
425 level.next(),
426 var_name,
427 rows as u8
428 )?;
429 self.write_storage_address(module, &chain, func_ctx)?;
430 write!(self.out, ", asuint(")?;
431 self.write_store_value(module, &column_value, func_ctx)?;
432 writeln!(self.out, "));")?;
433 } else {
434 write!(self.out, "{}{var_name}.Store(", level.next())?;
435 self.write_storage_address(module, &chain, func_ctx)?;
436 write!(self.out, ", ")?;
437 self.write_store_value(module, &column_value, func_ctx)?;
438 writeln!(self.out, ");")?;
439 }
440 chain.pop();
441 }
442 self.temp_access_chain = chain;
443 }
444 _ => {
445 let depth = level.0 + 1;
447 write!(
448 self.out,
449 "{}{}{}x{} {}{} = ",
450 level.next(),
451 scalar.to_hlsl_str()?,
452 columns as u8,
453 rows as u8,
454 STORE_TEMP_NAME,
455 depth,
456 )?;
457 self.write_store_value(module, &value, func_ctx)?;
458 writeln!(self.out, ";")?;
459
460 for i in 0..columns as u32 {
462 self.temp_access_chain
463 .push(SubAccess::Offset(i * row_stride));
464 let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
465 let sv = StoreValue::TempIndex {
466 depth,
467 index: i,
468 ty: TypeResolution::Value(ty_inner),
469 };
470 self.write_storage_store(
471 module,
472 var_handle,
473 sv,
474 func_ctx,
475 level.next(),
476 None,
477 )?;
478 self.temp_access_chain.pop();
479 }
480 }
481 }
482 writeln!(self.out, "{level}}}")?;
484 }
485 crate::TypeInner::Array {
486 base,
487 size: crate::ArraySize::Constant(size),
488 stride,
489 } => {
490 writeln!(self.out, "{level}{{")?;
492 write!(self.out, "{}", level.next())?;
493 self.write_type(module, base)?;
494 let depth = level.next().0;
495 write!(self.out, " {STORE_TEMP_NAME}{depth}")?;
496 self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
497 write!(self.out, " = ")?;
498 self.write_store_value(module, &value, func_ctx)?;
499 writeln!(self.out, ";")?;
500 for i in 0..size.get() {
502 self.temp_access_chain.push(SubAccess::Offset(i * stride));
503 let sv = StoreValue::TempIndex {
504 depth,
505 index: i,
506 ty: TypeResolution::Handle(base),
507 };
508 self.write_storage_store(module, var_handle, sv, func_ctx, level.next(), None)?;
509 self.temp_access_chain.pop();
510 }
511 writeln!(self.out, "{level}}}")?;
513 }
514 crate::TypeInner::Struct { ref members, .. } => {
515 writeln!(self.out, "{level}{{")?;
517 let depth = level.next().0;
518 let struct_ty = ty_resolution.handle().unwrap();
519 let struct_name = &self.names[&NameKey::Type(struct_ty)];
520 write!(
521 self.out,
522 "{}{} {}{} = ",
523 level.next(),
524 struct_name,
525 STORE_TEMP_NAME,
526 depth
527 )?;
528 self.write_store_value(module, &value, func_ctx)?;
529 writeln!(self.out, ";")?;
530 for (i, member) in members.iter().enumerate() {
532 self.temp_access_chain
533 .push(SubAccess::Offset(member.offset));
534 let sv = StoreValue::TempAccess {
535 depth,
536 base: struct_ty,
537 member_index: i as u32,
538 };
539 self.write_storage_store(
540 module,
541 var_handle,
542 sv,
543 func_ctx,
544 level.next(),
545 Some(struct_ty),
546 )?;
547 self.temp_access_chain.pop();
548 }
549 writeln!(self.out, "{level}}}")?;
551 }
552 _ => unreachable!(),
553 }
554 Ok(())
555 }
556
557 pub(super) fn fill_access_chain(
569 &mut self,
570 module: &crate::Module,
571 mut cur_expr: Handle<crate::Expression>,
572 func_ctx: &FunctionCtx,
573 ) -> Result<Handle<crate::GlobalVariable>, Error> {
574 enum AccessIndex {
575 Expression(Handle<crate::Expression>),
576 Constant(u32),
577 }
578 enum Parent<'a> {
579 Array { stride: u32 },
580 Struct(&'a [crate::StructMember]),
581 }
582 self.temp_access_chain.clear();
583
584 loop {
585 let (next_expr, access_index) = match func_ctx.expressions[cur_expr] {
586 crate::Expression::GlobalVariable(handle) => {
587 if let Some(ref binding) = module.global_variables[handle].binding {
588 let bt = self.options.resolve_resource_binding(binding).unwrap();
590 if let Some(dynamic_storage_buffer_offsets_index) =
591 bt.dynamic_storage_buffer_offsets_index
592 {
593 self.temp_access_chain.push(SubAccess::BufferOffset {
594 group: binding.group,
595 offset: dynamic_storage_buffer_offsets_index,
596 });
597 }
598 }
599 return Ok(handle);
600 }
601 crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)),
602 crate::Expression::AccessIndex { base, index } => {
603 (base, AccessIndex::Constant(index))
604 }
605 ref other => {
606 return Err(Error::Unimplemented(format!("Pointer access of {other:?}")))
607 }
608 };
609
610 let parent = match *func_ctx.resolve_type(next_expr, &module.types) {
611 crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
612 crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members),
613 crate::TypeInner::Array { stride, .. } => Parent::Array { stride },
614 crate::TypeInner::Vector { scalar, .. } => Parent::Array {
615 stride: scalar.width as u32,
616 },
617 crate::TypeInner::Matrix { rows, scalar, .. } => Parent::Array {
618 stride: Alignment::from(rows) * scalar.width as u32,
621 },
622 _ => unreachable!(),
623 },
624 crate::TypeInner::ValuePointer { scalar, .. } => Parent::Array {
625 stride: scalar.width as u32,
626 },
627 _ => unreachable!(),
628 };
629
630 let sub = match (parent, access_index) {
631 (Parent::Array { stride }, AccessIndex::Expression(value)) => {
632 SubAccess::Index { value, stride }
633 }
634 (Parent::Array { stride }, AccessIndex::Constant(index)) => {
635 SubAccess::Offset(stride * index)
636 }
637 (Parent::Struct(members), AccessIndex::Constant(index)) => {
638 SubAccess::Offset(members[index as usize].offset)
639 }
640 (Parent::Struct(_), AccessIndex::Expression(_)) => unreachable!(),
641 };
642
643 self.temp_access_chain.push(sub);
644 cur_expr = next_expr;
645 }
646 }
647}