// Collatz step counts computed on the GPU with a wgpu compute shader (`shader.wgsl`).
use std::{mem::size_of_val, str::FromStr};
use wgpu::util::DeviceExt;
// Marker step count signalling that an intermediate Collatz value overflowed `u32`
// (all 32 bits set).
const OVERFLOW: u32 = u32::MAX;
// NOTE(review): this copy of the function appears to be missing lines (e.g. the receiver
// of the `.map` calls and the closing delimiters); confirm against the full source.
/// Drives the demo: chooses the input numbers, passes them to `execute_gpu`, and prints
/// the resulting step counts.
///
/// # Panics
/// Panics if an argument cannot be parsed as a `u32`, or if `execute_gpu` yields `None`
/// (its result is unwrapped).
#[cfg_attr(test, allow(dead_code))]
async fn run() {
// With at most two process arguments (the first is traditionally the executable path),
// fall back to a built-in sample and say so on stdout.
let numbers = if std::env::args().len() <= 2 {
let default = vec![1, 2, 3, 4];
println!("No numbers were provided, defaulting to {default:?}");
} else {
// Each string must parse as a `u32`; anything else aborts with the message below.
.map(|s| u32::from_str(&s).expect("You must pass a list of positive integers!"))
// A `None` from `execute_gpu` is treated as fatal here.
let steps = execute_gpu(&numbers).await.unwrap();
// Render every result as text, substituting a readable label for the overflow sentinel.
let disp_steps: Vec<String> = steps
.map(|&n| match n {
OVERFLOW => "OVERFLOW".to_string(),
_ => n.to_string(),
println!("Steps: [{}]", disp_steps.join(", "));
// On wasm32 the same line is additionally emitted through the `log` macros.
#[cfg(target_arch = "wasm32")]
log::info!("Steps: [{}]", disp_steps.join(", "));
// NOTE(review): the adapter and device request call chains are only partially visible in
// this copy of the function; confirm the missing calls against the full source.
/// Creates a default `wgpu::Instance`, obtains an adapter plus a device/queue pair from
/// it, and hands `numbers` to `execute_gpu_inner`, returning whatever that produces.
#[cfg_attr(test, allow(dead_code))]
async fn execute_gpu(numbers: &[u32]) -> Option<Vec<u32>> {
// Instantiates instance of WebGPU
let instance = wgpu::Instance::default();
// `request_adapter` instantiates the general connection to the GPU
let adapter = instance
// `request_device` instantiates the feature specific connection to the GPU, defining some parameters,
// `required_features` being the features the device is asked to provide.
let (device, queue) = adapter
&wgpu::DeviceDescriptor {
label: None,
// No optional features are requested.
required_features: wgpu::Features::empty(),
// The downlevel default limit set, rather than the regular defaults.
required_limits: wgpu::Limits::downlevel_defaults(),
// Hint to favour lower memory usage (see the wgpu `MemoryHints` documentation).
memory_hints: wgpu::MemoryHints::MemoryUsage,
// The actual compute work is done against the device and queue obtained above.
execute_gpu_inner(&device, &queue, numbers).await
// NOTE(review): several comments below describe statements (queue submission, device
// polling, dropping the mapped view, the final return value) that are not visible in this
// copy of the function; confirm them against the full source.
/// Runs the `main` entry point of `shader.wgsl` over `numbers` on the given device and
/// reads the storage buffer back as `u32` values.
///
/// The input is uploaded into a storage buffer bound at binding 0, one workgroup is
/// dispatched per input element, and the storage buffer is then copied into a mappable
/// staging buffer for reading.
///
/// # Panics
/// Panics with "failed to run compute on gpu!" if the mapping result is not received
/// successfully, and if the mapping callback cannot send its result through the channel.
async fn execute_gpu_inner(
device: &wgpu::Device,
queue: &wgpu::Queue,
numbers: &[u32],
) -> Option<Vec<u32>> {
// Loads the shader from WGSL
let cs_module = device.create_shader_module(wgpu::include_wgsl!("shader.wgsl"));
// Gets the size in bytes of the input slice; it is also the byte count copied back later.
let size = size_of_val(numbers) as wgpu::BufferAddress;
// Instantiates buffer without data.
// `usage` of buffer specifies how it can be used:
// `BufferUsages::MAP_READ` allows it to be read (outside the shader).
// `BufferUsages::COPY_DST` allows it to be the destination of the copy.
let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
// Instantiates buffer with data (`numbers`).
// Usage allowing the buffer to be:
// A storage buffer (can be bound within a bind group and thus available to a shader).
// The destination of a copy.
// The source of a copy.
let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Storage Buffer"),
// The `u32` slice is reinterpreted as raw bytes for the upload.
contents: bytemuck::cast_slice(numbers),
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
// A bind group defines how buffers are accessed by shaders.
// It is to WebGPU what a descriptor set is to Vulkan.
// `binding` here refers to the `binding` of a buffer in the shader
// (GLSL would write `layout(set = 0, binding = 0) buffer`; WGSL uses `@group(0) @binding(0)`).
// A pipeline specifies the operation of a shader
// Instantiates the pipeline.
let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
// No explicit pipeline layout is supplied; the bind group layout is queried from the
// pipeline below instead.
layout: None,
module: &cs_module,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
// Instantiates the bind group, once again specifying the binding of buffers.
let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &bind_group_layout,
entries: &[wgpu::BindGroupEntry {
binding: 0,
// The whole storage buffer is exposed to the shader at this binding.
resource: storage_buffer.as_entire_binding(),
// A command encoder executes one or many pipelines.
// It is to WebGPU what a command buffer is to Vulkan.
let mut encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: None,
timestamp_writes: None,
cpass.set_bind_group(0, &bind_group, &[]);
cpass.insert_debug_marker("compute collatz iterations");
// NOTE(review): `as u32` truncates silently if `numbers.len()` exceeds `u32::MAX`.
cpass.dispatch_workgroups(numbers.len() as u32, 1, 1); // Number of cells to run, the (x,y,z) size of item being processed
// Adds a copy operation to the command encoder.
// Will copy `size` bytes from the storage buffer to the CPU-readable staging buffer.
encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, size);
// Submits command encoder for processing
// Note that we're not calling `.await` here.
let buffer_slice = staging_buffer.slice(..);
// Sets the buffer up for mapping, sending over the result of the mapping back to us when it is finished.
// The channel holds a single message: the result passed to the mapping callback.
let (sender, receiver) = flume::bounded(1);
buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
// Poll the device in a blocking manner so that our future resolves.
// In an actual application, `device.poll(...)` should
// be called in an event loop or on another thread.
// Waits for the mapping callback to deliver its result through the channel.
// Outer `Ok`: the message was received; inner `Ok(())`: the mapping itself succeeded.
if let Ok(Ok(())) = receiver.recv_async().await {
// Gets contents of buffer
let data = buffer_slice.get_mapped_range();
// Since contents are got in bytes, this converts these bytes back to u32
// and copies them into an owned `Vec`.
let result = bytemuck::cast_slice(&data).to_vec();
// With the current interface, we have to make sure all mapped views are
// dropped before we unmap the buffer.
staging_buffer.unmap(); // Unmaps buffer from memory
// If you are familiar with C++ these 2 lines can be thought of similarly to:
// delete myPointer;
// myPointer = NULL;
// It effectively frees the memory
// Returns data from buffer
} else {
panic!("failed to run compute on gpu!")
/// Program entry point, with separate native and wasm32 paths selected by `cfg` attributes.
pub fn main() {
// Native (non-wasm32) path.
// NOTE(review): the statement gated by this attribute is not visible in this copy of
// the file; confirm it against the full source.
#[cfg(not(target_arch = "wasm32"))]
// wasm32 path: initialise `console_log`; a failure panics with the message below.
#[cfg(target_arch = "wasm32")]
console_log::init().expect("could not initialize logger");
// Declares the `tests` submodule; its body lives in a separate source file.
mod tests;