use std::mem::size_of;
use std::num::NonZeroU64;
use thiserror::Error;
use crate::{
device::DeviceError,
pipeline::{CreateComputePipelineError, CreateShaderModuleError},
};
/// Errors that [`IndirectValidation::new`] can return while building the
/// indirect-dispatch validation pipeline.
///
/// Every variant is `#[error(transparent)]`, so the message shown is the one
/// of the wrapped error.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum CreateDispatchIndirectValidationPipelineError {
/// A hal device call failed (bind group layout, pipeline layout, buffer or
/// bind group creation in `new`).
#[error(transparent)]
DeviceError(#[from] DeviceError),
/// The validation shader failed to parse, validate, or be turned into a
/// hal shader module.
#[error(transparent)]
ShaderModule(#[from] CreateShaderModuleError),
/// The compute pipeline for the validation shader could not be created.
#[error(transparent)]
ComputePipeline(#[from] CreateComputePipelineError),
}
/// Owns the hal objects used to validate the arguments of an indirect compute
/// dispatch with a small compute shader.
///
/// As set up by [`IndirectValidation::new`], the pipeline layout uses:
///
/// - two bind groups (index 0: destination buffer, index 1: source buffer),
/// - one storage buffer binding with a dynamic offset (the source buffer),
/// - a 4-byte push constant range visible to the compute stage,
/// - a workgroup size of 1.
///
/// None of the fields are released automatically by this type as far as this
/// file shows; [`IndirectValidation::dispose`] hands each of them to the
/// matching `destroy_*` device method.
#[derive(Debug)]
pub struct IndirectValidation {
// Shader module compiled from the generated WGSL source.
module: Box<dyn hal::DynShaderModule>,
// Layout of bind group 0: the read-write destination storage buffer.
dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
// Layout of bind group 1: the read-only source storage buffer, bound with a
// dynamic offset.
src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
// Pipeline layout combining both bind group layouts and the push constant.
pipeline_layout: Box<dyn hal::DynPipelineLayout>,
// Compute pipeline running the validation shader's `main` entry point.
pipeline: Box<dyn hal::DynComputePipeline>,
// Buffer the shader writes the validated (or zeroed) values into; created
// with `INDIRECT | STORAGE_READ_WRITE` usage.
dst_buffer: Box<dyn hal::DynBuffer>,
// Bind group exposing the whole of `dst_buffer` at binding 0.
dst_bind_group: Box<dyn hal::DynBindGroup>,
}
/// Borrowed handles and offsets produced by [`IndirectValidation::params`]
/// for one validation pass.
///
/// The four references borrow the objects owned by the `IndirectValidation`
/// that produced this value.
pub struct Params<'a> {
/// Pipeline layout of the validation pipeline.
pub pipeline_layout: &'a dyn hal::DynPipelineLayout,
/// The validation compute pipeline.
pub pipeline: &'a dyn hal::DynComputePipeline,
/// Buffer the validation shader writes its output to.
pub dst_buffer: &'a dyn hal::DynBuffer,
/// Bind group that exposes `dst_buffer` to the shader.
pub dst_bind_group: &'a dyn hal::DynBindGroup,
/// The requested offset rounded down to a multiple of
/// `min_storage_buffer_offset_alignment`, then clamped so that the source
/// binding window still ends inside the buffer.
// NOTE(review): presumably passed as the dynamic offset of the source bind
// group — confirm against the caller.
pub aligned_offset: u64,
/// Difference between the requested offset and `aligned_offset`, in the
/// same unit as the requested offset.
// NOTE(review): presumably forwarded to the shader through the `offset`
// push constant (possibly after a unit conversion) — confirm against the
// caller.
pub offset_remainder: u64,
}
impl IndirectValidation {
/// Creates every hal object needed to validate indirect dispatch arguments.
///
/// The generated WGSL compute shader (workgroup size 1) reads three
/// consecutive `u32` values from the `src` storage buffer, starting at the
/// index held in the `offset` push constant. If any of the three exceeds
/// `limits.max_compute_workgroups_per_dimension`, it writes six zeros to
/// `dst`; otherwise it writes the three values twice in a row.
///
/// # Errors
///
/// - `ShaderModule` if the WGSL fails to parse or validate, or if the hal
///   shader module cannot be created.
/// - `ComputePipeline` if the hal compute pipeline cannot be created.
/// - `DeviceError` if creating a bind group layout, the pipeline layout, the
///   destination buffer or its bind group fails.
///
// NOTE(review): when a later step fails and returns early through `?`, the
// hal objects created by earlier steps are not passed to any `destroy_*`
// call in this function — confirm whether that is acceptable for hal objects.
// NOTE(review): the `unsafe` hal calls below have no visible safety
// justification here; their contracts are defined by `hal` — confirm.
pub fn new(
device: &dyn hal::DynDevice,
limits: &wgt::Limits,
) -> Result<Self, CreateDispatchIndirectValidationPipelineError> {
// The limit is interpolated into the shader text below, so it is fixed for
// the lifetime of the created pipeline.
let max_compute_workgroups_per_dimension = limits.max_compute_workgroups_per_dimension;
// WGSL source of the validation shader. Doubled braces are `format!`
// escapes for literal `{` / `}`.
let src = format!(
"
@group(0) @binding(0)
var<storage, read_write> dst: array<u32, 6>;
@group(1) @binding(0)
var<storage, read> src: array<u32>;
struct OffsetPc {{
inner: u32,
}}
var<push_constant> offset: OffsetPc;
@compute @workgroup_size(1)
fn main() {{
let src = vec3(src[offset.inner], src[offset.inner + 1], src[offset.inner + 2]);
let max_compute_workgroups_per_dimension = {max_compute_workgroups_per_dimension}u;
if (
src.x > max_compute_workgroups_per_dimension ||
src.y > max_compute_workgroups_per_dimension ||
src.z > max_compute_workgroups_per_dimension
) {{
dst = array(0u, 0u, 0u, 0u, 0u, 0u);
}} else {{
dst = array(src.x, src.y, src.z, src.x, src.y, src.z);
}}
}}
"
);
// Byte size of the three `u32` values the shader reads from `src`.
// SAFETY: `size_of::<u32>() * 3` is a non-zero constant.
const SRC_BUFFER_SIZE: NonZeroU64 =
unsafe { NonZeroU64::new_unchecked(size_of::<u32>() as u64 * 3) };
// Byte size of `dst: array<u32, 6>` as declared in the shader above: twice
// the source triple.
// SAFETY: twice a non-zero value that does not overflow is non-zero.
const DST_BUFFER_SIZE: NonZeroU64 = unsafe {
NonZeroU64::new_unchecked(
SRC_BUFFER_SIZE.get() * 2, )
};
// Parse the WGSL; a parse failure carries a copy of the source text.
let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
CreateShaderModuleError::Parsing(naga::error::ShaderError {
source: src.clone(),
label: None,
inner: Box::new(inner),
})
})?;
// Validate with only the PUSH_CONSTANTS feature enabled, no downlevel
// flags, and all naga validation flags. `src` is moved into the error on
// failure since it is not needed afterwards.
let info = crate::device::create_validator(
wgt::Features::PUSH_CONSTANTS,
wgt::DownlevelFlags::empty(),
naga::valid::ValidationFlags::all(),
)
.validate(&module)
.map_err(|inner| {
CreateShaderModuleError::Validation(naga::error::ShaderError {
source: src,
label: None,
inner: Box::new(inner),
})
})?;
// Hand the validated naga module to hal, without debug source.
let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
module: std::borrow::Cow::Owned(module),
info,
debug_source: None,
});
let hal_desc = hal::ShaderModuleDescriptor {
label: None,
// NOTE(review): presumably turns off shader runtime checks; the exact
// set is defined by `wgt::ShaderRuntimeChecks::unchecked` — confirm.
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
};
// Device errors are converted; compilation errors are logged and reported
// as `CreateShaderModuleError::Generation` (the message is not propagated).
let module =
unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(|error| {
match error {
hal::ShaderError::Device(error) => {
CreateShaderModuleError::Device(DeviceError::from_hal(error))
}
hal::ShaderError::Compilation(ref msg) => {
log::error!("Shader error: {}", msg);
CreateShaderModuleError::Generation
}
}
})?;
// Layout for `@group(0)`: one read-write storage buffer of at least
// `DST_BUFFER_SIZE` bytes, without a dynamic offset.
let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
label: None,
flags: hal::BindGroupLayoutFlags::empty(),
entries: &[wgt::BindGroupLayoutEntry {
binding: 0,
visibility: wgt::ShaderStages::COMPUTE,
ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: Some(DST_BUFFER_SIZE),
},
count: None,
}],
};
let dst_bind_group_layout = unsafe {
device
.create_bind_group_layout(&dst_bind_group_layout_desc)
.map_err(DeviceError::from_hal)?
};
// Layout for `@group(1)`: one read-only storage buffer of at least
// `SRC_BUFFER_SIZE` bytes, bound with a dynamic offset.
let src_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
label: None,
flags: hal::BindGroupLayoutFlags::empty(),
entries: &[wgt::BindGroupLayoutEntry {
binding: 0,
visibility: wgt::ShaderStages::COMPUTE,
ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: true,
min_binding_size: Some(SRC_BUFFER_SIZE),
},
count: None,
}],
};
let src_bind_group_layout = unsafe {
device
.create_bind_group_layout(&src_bind_group_layout_desc)
.map_err(DeviceError::from_hal)?
};
// The slice order defines the group indices: dst is group 0 and src is
// group 1, matching the shader. The push constant range `0..4` covers the
// single `u32` in `OffsetPc`.
let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
label: None,
flags: hal::PipelineLayoutFlags::empty(),
bind_group_layouts: &[
dst_bind_group_layout.as_ref(),
src_bind_group_layout.as_ref(),
],
push_constant_ranges: &[wgt::PushConstantRange {
stages: wgt::ShaderStages::COMPUTE,
range: 0..4,
}],
};
let pipeline_layout = unsafe {
device
.create_pipeline_layout(&pipeline_layout_desc)
.map_err(DeviceError::from_hal)?
};
// Compute pipeline for the shader's `main` entry point, with no pipeline
// constants and no pipeline cache.
let pipeline_desc = hal::ComputePipelineDescriptor {
label: None,
layout: pipeline_layout.as_ref(),
stage: hal::ProgrammableStage {
module: module.as_ref(),
entry_point: "main",
constants: &Default::default(),
// The shader declares no workgroup variables.
zero_initialize_workgroup_memory: false,
},
cache: None,
};
// Translate each hal pipeline error into its `CreateComputePipelineError`
// counterpart; linkage and entry point failures become `Internal`.
let pipeline =
unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
hal::PipelineError::Device(error) => {
CreateComputePipelineError::Device(DeviceError::from_hal(error))
}
hal::PipelineError::Linkage(_stages, msg) => {
CreateComputePipelineError::Internal(msg)
}
hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
),
hal::PipelineError::PipelineConstants(_, error) => {
CreateComputePipelineError::PipelineConstants(error)
}
})?;
// Destination buffer: written by the shader as a storage buffer and also
// created with INDIRECT usage.
let dst_buffer_desc = hal::BufferDescriptor {
label: None,
size: DST_BUFFER_SIZE.get(),
usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE,
memory_flags: hal::MemoryFlags::empty(),
};
let dst_buffer =
unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?;
// Bind the whole destination buffer at binding 0 of group 0.
let dst_bind_group_desc = hal::BindGroupDescriptor {
label: None,
layout: dst_bind_group_layout.as_ref(),
entries: &[hal::BindGroupEntry {
binding: 0,
resource_index: 0,
count: 1,
}],
buffers: &[hal::BufferBinding {
buffer: dst_buffer.as_ref(),
offset: 0,
size: Some(DST_BUFFER_SIZE),
}],
samplers: &[],
textures: &[],
acceleration_structures: &[],
};
let dst_bind_group = unsafe {
device
.create_bind_group(&dst_bind_group_desc)
.map_err(DeviceError::from_hal)
}?;
Ok(Self {
module,
dst_bind_group_layout,
src_bind_group_layout,
pipeline_layout,
pipeline,
dst_buffer,
dst_bind_group,
})
}
pub fn create_src_bind_group(
&self,
device: &dyn hal::DynDevice,
limits: &wgt::Limits,
buffer_size: u64,
buffer: &dyn hal::DynBuffer,
) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
let Some(binding_size) = NonZeroU64::new(binding_size) else {
return Ok(None);
};
let hal_desc = hal::BindGroupDescriptor {
label: None,
layout: self.src_bind_group_layout.as_ref(),
entries: &[hal::BindGroupEntry {
binding: 0,
resource_index: 0,
count: 1,
}],
buffers: &[hal::BufferBinding {
buffer,
offset: 0,
size: Some(binding_size),
}],
samplers: &[],
textures: &[],
acceleration_structures: &[],
};
unsafe {
device
.create_bind_group(&hal_desc)
.map(Some)
.map_err(DeviceError::from_hal)
}
}
/// Computes the offsets for one validation pass and bundles them with
/// borrowed handles to the validation pipeline objects.
///
/// `offset` is split into two parts:
///
/// - `aligned_offset`: `offset` rounded down to a multiple of
///   `limits.min_storage_buffer_offset_alignment`, then clamped to
///   `buffer_size - binding_size` so that a binding window of `binding_size`
///   bytes starting there still ends inside the buffer;
/// - `offset_remainder`: `offset - aligned_offset`.
///
/// `binding_size` is the value returned by
/// `calculate_src_buffer_binding_size(buffer_size, limits)`, which never
/// exceeds `buffer_size`, so the clamp bound cannot underflow.
// NOTE(review): assumes `min_storage_buffer_offset_alignment` is non-zero
// (otherwise the remainder operations panic) — confirm limits are validated
// by the caller.
pub fn params<'a>(&'a self, limits: &wgt::Limits, offset: u64, buffer_size: u64) -> Params<'a> {
    let alignment = u64::from(limits.min_storage_buffer_offset_alignment);
    let window = calculate_src_buffer_binding_size(buffer_size, limits);

    // Start of the alignment-sized chunk that contains `offset`.
    let chunk_start = offset - offset % alignment;
    // Last start position for which the binding window fits in the buffer.
    let last_start = buffer_size - window;
    let aligned_offset = std::cmp::min(chunk_start, last_start);

    Params {
        pipeline_layout: self.pipeline_layout.as_ref(),
        pipeline: self.pipeline.as_ref(),
        dst_buffer: self.dst_buffer.as_ref(),
        dst_bind_group: self.dst_bind_group.as_ref(),
        aligned_offset,
        offset_remainder: offset - aligned_offset,
    }
}
pub fn dispose(self, device: &dyn hal::DynDevice) {
let IndirectValidation {
module,
dst_bind_group_layout,
src_bind_group_layout,
pipeline_layout,
pipeline,
dst_buffer,
dst_bind_group,
} = self;
unsafe {
device.destroy_bind_group(dst_bind_group);
device.destroy_buffer(dst_buffer);
device.destroy_compute_pipeline(pipeline);
device.destroy_pipeline_layout(pipeline_layout);
device.destroy_bind_group_layout(src_bind_group_layout);
device.destroy_bind_group_layout(dst_bind_group_layout);
device.destroy_shader_module(module);
}
}
}
/// Returns the size in bytes of the binding window used for the source
/// buffer.
///
/// The result is `2 * alignment + buffer_size % alignment`, capped at
/// `buffer_size`, where `alignment` is
/// `limits.min_storage_buffer_offset_alignment`.
///
/// When the cap does not apply, `buffer_size - result` is a multiple of
/// `alignment`, so a window of this size can be placed at an aligned offset
/// and end exactly at the end of the buffer.
// NOTE(review): assumes `alignment` is non-zero; a zero alignment makes the
// remainder operation panic — confirm limits are validated by the caller.
fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &wgt::Limits) -> u64 {
    let alignment = u64::from(limits.min_storage_buffer_offset_alignment);
    // Bytes left over after the last whole alignment-sized chunk.
    let tail = buffer_size % alignment;
    std::cmp::min(buffer_size, alignment * 2 + tail)
}