wgpu_core/timestamp_normalization/
mod.rs

1//! Utility for normalizing GPU timestamp queries to have a consistent
2//! 1GHz period. This uses a compute shader to do the normalization,
3//! so the timestamps exist in their correct format on the GPU, as
4//! is required by the WebGPU specification.
5//!
6//! ## Algorithm
7//!
8//! The fundamental operation is multiplying a u64 timestamp by an f32
9//! value. We have neither f64s nor u64s in shaders, so we need to do
10//! something more complicated.
11//!
//! We first decompose the f32 into a fraction with a u32 numerator and a
//! power-of-two denominator. This decomposition is computed on the CPU in f64,
//! which can represent every u32 losslessly.
15//!
16//! Because the denominator is a power of two, this means the shader can evaluate
17//! this divide by using a shift. Additionally, we always choose the largest denominator
18//! we can, so that the fraction is as precise as possible.
19//!
//! To evaluate this multiplication on the GPU, the shader uses two helper operations (both in common.wgsl).
21//!
22//! 1. `u64_mul_u32` multiplies a u64 by a u32 and returns a u96.
23//! 2. `shift_right_u96` shifts a u96 right by a given amount, returning a u96.
24//!
25//! See their implementations for more details.
26//!
//! We then multiply the timestamp by the numerator and shift the product right
//! by log2 of the denominator. This gives us the normalized timestamp.
29
30use core::num::NonZeroU64;
31
32use alloc::{boxed::Box, string::String, string::ToString, sync::Arc};
33
34use hashbrown::HashMap;
35
36use crate::{
37    device::{Device, DeviceError},
38    hal_label,
39    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
40    resource::Buffer,
41    snatch::SnatchGuard,
42    track::BufferTracker,
43};
44
/// Buffer usage needed on a timestamp buffer while it is being normalized:
/// the normalization shader binds it as a read-write storage buffer.
pub const TIMESTAMP_NORMALIZATION_BUFFER_USES: wgt::BufferUses =
    wgt::BufferUses::STORAGE_READ_WRITE;
47
48struct InternalState {
49    temporary_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
50    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
51    pipeline: Box<dyn hal::DynComputePipeline>,
52}
53
54#[derive(Debug, Clone, thiserror::Error)]
55pub enum TimestampNormalizerInitError {
56    #[error("Failed to initialize bind group layout")]
57    BindGroupLayout(#[source] DeviceError),
58    #[cfg(feature = "wgsl")]
59    #[error("Failed to parse shader")]
60    ParseWgsl(#[source] naga::error::ShaderError<naga::front::wgsl::ParseError>),
61    #[error("Failed to validate shader module")]
62    ValidateWgsl(#[source] naga::error::ShaderError<naga::WithSpan<naga::valid::ValidationError>>),
63    #[error("Failed to create shader module")]
64    CreateShaderModule(#[from] CreateShaderModuleError),
65    #[error("Failed to create pipeline layout")]
66    PipelineLayout(#[source] DeviceError),
67    #[error("Failed to create compute pipeline")]
68    ComputePipeline(#[from] CreateComputePipelineError),
69}
70
71/// Normalizes GPU timestamps to have a consistent 1GHz period.
72/// See module documentation for more information.
73pub struct TimestampNormalizer {
74    state: Option<InternalState>,
75}
76
77impl TimestampNormalizer {
78    /// Creates a new timestamp normalizer.
79    ///
80    /// If the device cannot support automatic timestamp normalization,
81    /// this will return a normalizer that does nothing.
82    ///
83    /// # Errors
84    ///
85    /// If any resources are invalid, this will return an error.
86    pub fn new(
87        device: &Device,
88        timestamp_period: f32,
89    ) -> Result<Self, TimestampNormalizerInitError> {
90        unsafe {
91            if !device
92                .instance_flags
93                .contains(wgt::InstanceFlags::AUTOMATIC_TIMESTAMP_NORMALIZATION)
94            {
95                return Ok(Self { state: None });
96            }
97
98            if !device
99                .downlevel
100                .flags
101                .contains(wgt::DownlevelFlags::COMPUTE_SHADERS)
102            {
103                log::error!("Automatic timestamp normalization was requested, but compute shaders are not supported.");
104                return Ok(Self { state: None });
105            }
106
107            if timestamp_period == 1.0 {
108                // If the period is 1, we don't need to do anything to them.
109                return Ok(Self { state: None });
110            }
111
112            let temporary_bind_group_layout = device
113                .raw()
114                .create_bind_group_layout(&hal::BindGroupLayoutDescriptor {
115                    label: hal_label(
116                        Some("Timestamp Normalization Bind Group Layout"),
117                        device.instance_flags,
118                    ),
119                    flags: hal::BindGroupLayoutFlags::empty(),
120                    entries: &[wgt::BindGroupLayoutEntry {
121                        binding: 0,
122                        visibility: wgt::ShaderStages::COMPUTE,
123                        ty: wgt::BindingType::Buffer {
124                            ty: wgt::BufferBindingType::Storage { read_only: false },
125                            has_dynamic_offset: false,
126                            min_binding_size: Some(NonZeroU64::new(8).unwrap()),
127                        },
128                        count: None,
129                    }],
130                })
131                .map_err(|e| {
132                    TimestampNormalizerInitError::BindGroupLayout(device.handle_hal_error(e))
133                })?;
134
135            let common_src = include_str!("common.wgsl");
136            let src = include_str!("timestamp_normalization.wgsl");
137
138            let preprocessed_src = alloc::format!("{common_src}\n{src}");
139
140            #[cfg(feature = "wgsl")]
141            let module = naga::front::wgsl::parse_str(&preprocessed_src).map_err(|inner| {
142                TimestampNormalizerInitError::ParseWgsl(naga::error::ShaderError {
143                    source: preprocessed_src.clone(),
144                    label: None,
145                    inner: Box::new(inner),
146                })
147            })?;
148            #[cfg(not(feature = "wgsl"))]
149            #[allow(clippy::diverging_sub_expression)]
150            let module =
151                panic!("Timestamp normalization requires the wgsl feature flag to be enabled!");
152
153            let info = crate::device::create_validator(
154                wgt::Features::IMMEDIATES,
155                wgt::DownlevelFlags::empty(),
156                naga::valid::ValidationFlags::all(),
157            )
158            .validate(&module)
159            .map_err(|inner| {
160                TimestampNormalizerInitError::ValidateWgsl(naga::error::ShaderError {
161                    source: preprocessed_src.clone(),
162                    label: None,
163                    inner: Box::new(inner),
164                })
165            })?;
166            let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
167                module: alloc::borrow::Cow::Owned(module),
168                info,
169                debug_source: None,
170            });
171            let hal_desc = hal::ShaderModuleDescriptor {
172                label: None,
173                runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
174            };
175            let module = device
176                .raw()
177                .create_shader_module(&hal_desc, hal_shader)
178                .map_err(|error| match error {
179                    hal::ShaderError::Device(error) => {
180                        CreateShaderModuleError::Device(device.handle_hal_error(error))
181                    }
182                    hal::ShaderError::Compilation(ref msg) => {
183                        log::error!("Shader error: {msg}");
184                        CreateShaderModuleError::Generation
185                    }
186                })?;
187
188            let pipeline_layout = device
189                .raw()
190                .create_pipeline_layout(&hal::PipelineLayoutDescriptor {
191                    label: None,
192                    bind_group_layouts: &[temporary_bind_group_layout.as_ref()],
193                    immediate_size: 8,
194                    flags: hal::PipelineLayoutFlags::empty(),
195                })
196                .map_err(|e| {
197                    TimestampNormalizerInitError::PipelineLayout(device.handle_hal_error(e))
198                })?;
199
200            let (multiplier, shift) = compute_timestamp_period(timestamp_period);
201
202            let mut constants = HashMap::with_capacity(2);
203            constants.insert(String::from("TIMESTAMP_PERIOD_MULTIPLY"), multiplier as f64);
204            constants.insert(String::from("TIMESTAMP_PERIOD_SHIFT"), shift as f64);
205
206            let pipeline_desc = hal::ComputePipelineDescriptor {
207                label: None,
208                layout: pipeline_layout.as_ref(),
209                stage: hal::ProgrammableStage {
210                    module: module.as_ref(),
211                    entry_point: "main",
212                    constants: &constants,
213                    zero_initialize_workgroup_memory: false,
214                },
215                cache: None,
216            };
217            let pipeline = device
218                .raw()
219                .create_compute_pipeline(&pipeline_desc)
220                .map_err(|err| match err {
221                    hal::PipelineError::Device(error) => {
222                        CreateComputePipelineError::Device(device.handle_hal_error(error))
223                    }
224                    hal::PipelineError::Linkage(_stages, msg) => {
225                        CreateComputePipelineError::Internal(msg)
226                    }
227                    hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
228                        crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
229                    ),
230                    hal::PipelineError::PipelineConstants(_, error) => {
231                        CreateComputePipelineError::PipelineConstants(error)
232                    }
233                })?;
234
235            Ok(Self {
236                state: Some(InternalState {
237                    temporary_bind_group_layout,
238                    pipeline_layout,
239                    pipeline,
240                }),
241            })
242        }
243    }
244
245    /// Create a bind group for normalizing timestamps in `buffer`.
246    ///
247    /// This function is unsafe because it does not know that `buffer_size` is
248    /// the true size of the buffer.
249    pub unsafe fn create_normalization_bind_group(
250        &self,
251        device: &Device,
252        buffer: &dyn hal::DynBuffer,
253        buffer_label: Option<&str>,
254        buffer_size: wgt::BufferSize,
255        buffer_usages: wgt::BufferUsages,
256    ) -> Result<TimestampNormalizationBindGroup, DeviceError> {
257        unsafe {
258            let Some(ref state) = &self.state else {
259                return Ok(TimestampNormalizationBindGroup { raw: None });
260            };
261
262            if !buffer_usages.contains(wgt::BufferUsages::QUERY_RESOLVE) {
263                return Ok(TimestampNormalizationBindGroup { raw: None });
264            }
265
266            // If this buffer is large enough that we wouldn't be able to bind the entire thing
267            // at once to normalize the timestamps, we can't use it. We force the buffer to fail
268            // to allocate. The lowest max binding size is 128MB, and query sets must be small
269            // (no more than 4096), so this should never be hit in practice by sane programs.
270            if buffer_size.get() > device.adapter.limits().max_storage_buffer_binding_size as u64 {
271                return Err(DeviceError::OutOfMemory);
272            }
273
274            let bg_label_alloc;
275            let label = match buffer_label {
276                Some(label) => {
277                    bg_label_alloc = alloc::format!("Timestamp normalization bind group ({label})");
278                    &*bg_label_alloc
279                }
280                None => "Timestamp normalization bind group",
281            };
282
283            let bg = device
284                .raw()
285                .create_bind_group(&hal::BindGroupDescriptor {
286                    label: hal_label(Some(label), device.instance_flags),
287                    layout: &*state.temporary_bind_group_layout,
288                    buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, buffer_size)],
289                    samplers: &[],
290                    textures: &[],
291                    acceleration_structures: &[],
292                    external_textures: &[],
293                    entries: &[hal::BindGroupEntry {
294                        binding: 0,
295                        resource_index: 0,
296                        count: 1,
297                    }],
298                })
299                .map_err(|e| device.handle_hal_error(e))?;
300
301            Ok(TimestampNormalizationBindGroup { raw: Some(bg) })
302        }
303    }
304
305    pub fn normalize(
306        &self,
307        snatch_guard: &SnatchGuard<'_>,
308        encoder: &mut dyn hal::DynCommandEncoder,
309        tracker: &mut BufferTracker,
310        bind_group: &TimestampNormalizationBindGroup,
311        buffer: &Arc<Buffer>,
312        buffer_offset_bytes: u64,
313        total_timestamps: u32,
314    ) {
315        let Some(ref state) = &self.state else {
316            return;
317        };
318
319        let Some(bind_group) = bind_group.raw.as_deref() else {
320            return;
321        };
322
323        let buffer_offset_timestamps: u32 = (buffer_offset_bytes / 8).try_into().unwrap(); // Unreachable as MAX_QUERIES is way less than u32::MAX
324
325        let pending_barrier = tracker.set_single(buffer, wgt::BufferUses::STORAGE_READ_WRITE);
326
327        let barrier = pending_barrier.map(|pending| pending.into_hal(buffer, snatch_guard));
328
329        let needed_workgroups = total_timestamps.div_ceil(64);
330
331        unsafe {
332            encoder.transition_buffers(barrier.as_slice());
333            encoder.begin_compute_pass(&hal::ComputePassDescriptor {
334                label: hal_label(
335                    Some("Timestamp normalization pass"),
336                    buffer.device.instance_flags,
337                ),
338                timestamp_writes: None,
339            });
340            encoder.set_compute_pipeline(&*state.pipeline);
341            encoder.set_bind_group(&*state.pipeline_layout, 0, Some(bind_group), &[]);
342            encoder.set_immediates(
343                &*state.pipeline_layout,
344                0,
345                &[buffer_offset_timestamps, total_timestamps],
346            );
347            encoder.dispatch([needed_workgroups, 1, 1]);
348            encoder.end_compute_pass();
349        }
350    }
351
352    pub fn dispose(self, device: &dyn hal::DynDevice) {
353        unsafe {
354            let Some(state) = self.state else {
355                return;
356            };
357
358            device.destroy_compute_pipeline(state.pipeline);
359            device.destroy_pipeline_layout(state.pipeline_layout);
360            device.destroy_bind_group_layout(state.temporary_bind_group_layout);
361        }
362    }
363
364    pub fn enabled(&self) -> bool {
365        self.state.is_some()
366    }
367}
368
/// Bind group exposing a query-resolve buffer to the timestamp normalization shader.
///
/// `raw` is `None` when normalization is not needed for the buffer; normalizing
/// with such a bind group is a no-op.
#[derive(Debug)]
pub struct TimestampNormalizationBindGroup {
    /// Underlying hal bind group, released by `dispose`.
    raw: Option<Box<dyn hal::DynBindGroup>>,
}
373
374impl TimestampNormalizationBindGroup {
375    pub fn dispose(self, device: &dyn hal::DynDevice) {
376        unsafe {
377            if let Some(raw) = self.raw {
378                device.destroy_bind_group(raw);
379            }
380        }
381    }
382}
383
/// Decomposes `input` (the timestamp period) into `(multiplier, shift)` such
/// that `multiplier / 2^shift` approximates `input`.
///
/// `shift` is the largest value in `0..=32` for which the rounded multiplier
/// still fits in a `u32`. This makes the fraction as precise as possible within
/// that range. In particular:
/// - every period below 1 uses a shift of 32;
/// - an exact power of two such as 2.0 gets the exact multiplier `2^31`, instead
///   of one that would overflow `u32`.
///
/// Float-to-int `as` casts saturate, which gives the following edge cases:
/// - a period too large for any shift yields `(u32::MAX, 0)`;
/// - a NaN period yields `(0, 0)`;
/// - a negative period yields a multiplier of 0.
fn compute_timestamp_period(input: f32) -> (u32, u32) {
    // Shifts above 32 are never produced.
    // NOTE(review): `shift_right_u96` might support larger shifts (more precision
    // for periods below 1); that is not verified here.
    const MAX_SHIFT: u32 = 32;

    let period = f64::from(input);

    (0..=MAX_SHIFT)
        .rev()
        .find_map(|shift| {
            // This is exact: an f32 has a 24-bit significand, and scaling by 2^shift
            // (shift <= 32) stays well within f64's range and precision.
            let scaled = (period * (1u64 << shift) as f64).round();
            (scaled <= f64::from(u32::MAX)).then_some((scaled as u32, shift))
        })
        // Only reached for NaN or periods too large for any shift.
        .unwrap_or_else(|| (period.round() as u32, 0))
}
396
#[cfg(test)]
mod tests {
    /// Decomposes `input` and asserts that `multiplier / 2^shift` reproduces it
    /// to within an absolute tolerance of 1e-7.
    fn assert_timestamp_case(input: f32) {
        let (multiplier, shift) = super::compute_timestamp_period(input);

        let denominator = (1u64 << shift) as f64;
        let output = multiplier as f64 / denominator;

        assert!((input as f64 - output).abs() < 0.0000001);
    }

    #[test]
    fn compute_timestamp_period() {
        const CASES: [f32; 6] = [0.01, 0.5, 1.0, 2.0, 2.7, 1000.7];
        for input in CASES {
            assert_timestamp_case(input);
        }
    }
}