use std::mem::size_of_val;
use wgpu::util::DeviceExt;
/// Runs the example end to end.
///
/// Two 100-element `i32` arrays (`a[i] = i`, `b[i] = 2 * i`) are uploaded into
/// read-write storage buffers and bound at bindings 0 and 1 of the `main` compute entry
/// point from `shader.wgsl`. One workgroup is dispatched per element of `local_a`. Both
/// buffers are then read back through a single shared staging buffer and logged.
///
/// Panics if no adapter/device is available (the `unwrap`s below).
async fn run() {
    // Host-side inputs; logged before the GPU touches them.
    let mut local_a: [i32; 100] = std::array::from_fn(|i| i as i32);
    log::info!("Input a: {local_a:?}");
    let mut local_b: [i32; 100] = std::array::from_fn(|i| i as i32 * 2);
    log::info!("Input b: {local_b:?}");

    // Default adapter, and a device restricted to downlevel limits.
    let instance = wgpu::Instance::default();
    let adapter = instance
        .request_adapter(&wgpu::RequestAdapterOptions::default())
        .await
        .unwrap();
    let device_descriptor = wgpu::DeviceDescriptor {
        label: None,
        required_features: wgpu::Features::empty(),
        required_limits: wgpu::Limits::downlevel_defaults(),
        memory_hints: wgpu::MemoryHints::MemoryUsage,
    };
    let (device, queue) = adapter
        .request_device(&device_descriptor, None)
        .await
        .unwrap();

    let shader = device.create_shader_module(wgpu::include_wgsl!("shader.wgsl"));

    // Both inputs get identical storage buffers. COPY_SRC is needed so their contents
    // can be copied into the staging buffer afterwards.
    let make_storage_buffer = |contents: &[i32]| {
        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: None,
            contents: bytemuck::cast_slice(contents),
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        })
    };
    let storage_buffer_a = make_storage_buffer(&local_a[..]);
    let storage_buffer_b = make_storage_buffer(&local_b[..]);

    // One mappable buffer, reused for both readbacks. It is sized after `local_a`,
    // which has the same length as `local_b`.
    let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: None,
        size: size_of_val(&local_a) as u64,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    // Every binding is a read-write storage buffer visible to the compute stage.
    let storage_entry = |binding: u32| wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: wgpu::BufferBindingType::Storage { read_only: false },
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    };
    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: None,
        entries: &[storage_entry(0), storage_entry(1)],
    });
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: None,
        layout: &bind_group_layout,
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: storage_buffer_a.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: storage_buffer_b.as_entire_binding(),
            },
        ],
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: None,
        bind_group_layouts: &[&bind_group_layout],
        push_constant_ranges: &[],
    });
    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: None,
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: Some("main"),
        compilation_options: Default::default(),
        cache: None,
    });

    // Record a single compute pass: one workgroup per element of `local_a`.
    let workgroup_count = local_a.len() as u32;
    let mut command_encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    let mut compute_pass = command_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: None,
        timestamp_writes: None,
    });
    compute_pass.set_pipeline(&pipeline);
    compute_pass.set_bind_group(0, &bind_group, &[]);
    compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
    // The pass must end (be dropped) before the encoder can be finished.
    drop(compute_pass);
    queue.submit(Some(command_encoder.finish()));

    // Read A, then B, back through the shared staging buffer. The readbacks run one after
    // another because `get_data` unmaps the staging buffer before returning.
    let readbacks = [
        (&mut local_a[..], &storage_buffer_a),
        (&mut local_b[..], &storage_buffer_b),
    ];
    for (host_copy, gpu_buffer) in readbacks {
        get_data(host_copy, gpu_buffer, &output_staging_buffer, &device, &queue).await;
    }

    log::info!("Output in A: {local_a:?}");
    log::info!("Output in B: {local_b:?}");
}
/// Reads the first `size_of_val(output)` bytes of `storage_buffer` back into `output`.
///
/// A buffer-to-buffer copy into `staging_buffer` is recorded and submitted on `queue`.
/// The copied range is then mapped for reading. The function blocks on `device.poll`
/// until the mapping completes, copies the bytes into `output`, and unmaps
/// `staging_buffer` so later calls can reuse it.
///
/// `staging_buffer` is expected to have `COPY_DST | MAP_READ` usage and to be at least as
/// large as `output`. It may be larger: only the copied prefix is mapped and read.
/// An empty `output` is a no-op.
///
/// NOTE(review): wgpu requires copy/map sizes to be multiples of
/// `wgpu::COPY_BUFFER_ALIGNMENT` (4 bytes), so `size_of_val(output)` should satisfy that.
///
/// # Panics
///
/// Panics if the device poll times out, if mapping the staging buffer fails, or if
/// `staging_buffer` is smaller than `output` (the slice would be out of bounds).
async fn get_data<T: bytemuck::Pod>(
    output: &mut [T],
    storage_buffer: &wgpu::Buffer,
    staging_buffer: &wgpu::Buffer,
    device: &wgpu::Device,
    queue: &wgpu::Queue,
) {
    let size = size_of_val(output) as u64;
    // There is nothing to read back, and wgpu rejects empty buffer slices.
    if size == 0 {
        return;
    }

    let mut command_encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    command_encoder.copy_buffer_to_buffer(storage_buffer, 0, staging_buffer, 0, size);
    queue.submit(Some(command_encoder.finish()));

    // Map only the bytes that were just copied. Mapping the whole staging buffer would make
    // `copy_from_slice` below panic whenever the staging buffer is larger than `output`.
    let buffer_slice = staging_buffer.slice(..size);
    let (sender, receiver) = flume::bounded(1);
    buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
    device.poll(wgpu::Maintain::wait()).panic_on_timeout();
    receiver
        .recv_async()
        .await
        .expect("map_async callback dropped without reporting a result")
        .expect("failed to map staging buffer for reading");

    {
        // The mapped view must be dropped before `unmap`, hence this inner scope.
        let mapped = buffer_slice.get_mapped_range();
        output.copy_from_slice(bytemuck::cast_slice(&mapped[..]));
    }
    staging_buffer.unmap();
}
/// Entry point.
///
/// Native targets set up `env_logger` at `Info` level with nanosecond timestamps and drive
/// `run()` to completion on the current thread. On `wasm32`, panics are routed to the
/// browser console, `console_log` is initialised at `Info` level, and `run()` is spawned on
/// the browser's executor.
pub fn main() {
    #[cfg(not(target_arch = "wasm32"))]
    {
        let mut logger = env_logger::builder();
        logger
            .filter_level(log::LevelFilter::Info)
            .format_timestamp_nanos();
        logger.init();

        let example = run();
        pollster::block_on(example);
    }
    #[cfg(target_arch = "wasm32")]
    {
        std::panic::set_hook(Box::new(console_error_panic_hook::hook));
        console_log::init_with_level(log::Level::Info).expect("could not initialize logger");
        crate::utils::add_web_nothing_to_see_msg();

        let example = run();
        wasm_bindgen_futures::spawn_local(example);
    }
}