wgpu_core/timestamp_normalization/
mod.rs

1//! Utility for normalizing GPU timestamp queries to have a consistent
2//! 1GHz period. This uses a compute shader to do the normalization,
3//! so the timestamps exist in their correct format on the GPU, as
4//! is required by the WebGPU specification.
5//!
6//! ## Algorithm
7//!
8//! The fundamental operation is multiplying a u64 timestamp by an f32
9//! value. We have neither f64s nor u64s in shaders, so we need to do
10//! something more complicated.
11//!
12//! We first decompose the f32 into a u32 fraction where the denominator
13//! is a power of two. We do the computation with f64 for ease of computation,
14//! as those can store u32s losslessly.
15//!
16//! Because the denominator is a power of two, this means the shader can evaluate
17//! this divide by using a shift. Additionally, we always choose the largest denominator
18//! we can, so that the fraction is as precise as possible.
19//!
20//! To evaluate this function, we have two helper operations (both in common.wgsl).
21//!
22//! 1. `u64_mul_u32` multiplies a u64 by a u32 and returns a u96.
23//! 2. `shift_right_u96` shifts a u96 right by a given amount, returning a u96.
24//!
25//! See their implementations for more details.
26//!
//! We then multiply the timestamp by the numerator, and shift the product right
//! by the base-2 logarithm of the denominator (the denominator is a power of two,
//! so this is the division). This gives us the normalized timestamp.
29
30use core::num::NonZeroU64;
31
32use alloc::{boxed::Box, string::String, string::ToString, sync::Arc};
33
34use hashbrown::HashMap;
35use wgt::PushConstantRange;
36
37use crate::{
38    device::{Device, DeviceError},
39    hal_label,
40    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
41    resource::Buffer,
42    snatch::SnatchGuard,
43    track::BufferTracker,
44};
45
/// Buffer use under which the normalization pass accesses a timestamp buffer.
///
/// The normalization shader binds the buffer as a read-write storage buffer and
/// `TimestampNormalizer::normalize` transitions it to storage read-write use
/// before dispatching.
///
/// NOTE(review): presumably consumed by buffer-creation code elsewhere so that
/// query-resolve buffers also permit this use — confirm against callers.
pub const TIMESTAMP_NORMALIZATION_BUFFER_USES: wgt::BufferUses =
    wgt::BufferUses::STORAGE_READ_WRITE;
48
/// Backend objects owned by an active [`TimestampNormalizer`].
///
/// All three are created by [`TimestampNormalizer::new`] and destroyed by
/// [`TimestampNormalizer::dispose`].
struct InternalState {
    /// Layout with a single read-write storage buffer at binding 0; used both
    /// for the pipeline layout and for every bind group produced by
    /// `create_normalization_bind_group`.
    temporary_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Pipeline layout: the bind group layout above plus 8 bytes of
    /// compute-stage push constants.
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    /// Compute pipeline running the `main` entry point of the normalization shader.
    pipeline: Box<dyn hal::DynComputePipeline>,
}
54
/// Reasons [`TimestampNormalizer::new`] can fail.
#[derive(Debug, Clone, thiserror::Error)]
pub enum TimestampNormalizerInitError {
    /// The backend failed to create the bind group layout.
    #[error("Failed to initialize bind group layout")]
    BindGroupLayout(#[source] DeviceError),
    /// The built-in WGSL source failed to parse.
    #[cfg(feature = "wgsl")]
    #[error("Failed to parse shader")]
    ParseWgsl(#[source] naga::error::ShaderError<naga::front::wgsl::ParseError>),
    /// The parsed shader module was rejected by the naga validator.
    #[error("Failed to validate shader module")]
    ValidateWgsl(#[source] naga::error::ShaderError<naga::WithSpan<naga::valid::ValidationError>>),
    /// The backend failed to create the shader module.
    #[error("Failed to create shader module")]
    CreateShaderModule(#[from] CreateShaderModuleError),
    /// The backend failed to create the pipeline layout.
    #[error("Failed to create pipeline layout")]
    PipelineLayout(#[source] DeviceError),
    /// The backend failed to create the compute pipeline.
    #[error("Failed to create compute pipeline")]
    ComputePipeline(#[from] CreateComputePipelineError),
}
71
/// Normalizes GPU timestamps to have a consistent 1GHz period.
/// See module documentation for more information.
pub struct TimestampNormalizer {
    /// Backend objects for the normalization pass, or `None` when
    /// [`TimestampNormalizer::new`] returned a normalizer that does nothing
    /// (normalization not requested, compute shaders unsupported, or the
    /// period is already 1). With `None`, the other methods do no GPU work.
    state: Option<InternalState>,
}
77
78impl TimestampNormalizer {
79    /// Creates a new timestamp normalizer.
80    ///
81    /// If the device cannot support automatic timestamp normalization,
82    /// this will return a normalizer that does nothing.
83    ///
84    /// # Errors
85    ///
86    /// If any resources are invalid, this will return an error.
87    pub fn new(
88        device: &Device,
89        timestamp_period: f32,
90    ) -> Result<Self, TimestampNormalizerInitError> {
91        unsafe {
92            if !device
93                .instance_flags
94                .contains(wgt::InstanceFlags::AUTOMATIC_TIMESTAMP_NORMALIZATION)
95            {
96                return Ok(Self { state: None });
97            }
98
99            if !device
100                .downlevel
101                .flags
102                .contains(wgt::DownlevelFlags::COMPUTE_SHADERS)
103            {
104                log::error!("Automatic timestamp normalization was requested, but compute shaders are not supported.");
105                return Ok(Self { state: None });
106            }
107
108            if timestamp_period == 1.0 {
109                // If the period is 1, we don't need to do anything to them.
110                return Ok(Self { state: None });
111            }
112
113            let temporary_bind_group_layout = device
114                .raw()
115                .create_bind_group_layout(&hal::BindGroupLayoutDescriptor {
116                    label: hal_label(
117                        Some("Timestamp Normalization Bind Group Layout"),
118                        device.instance_flags,
119                    ),
120                    flags: hal::BindGroupLayoutFlags::empty(),
121                    entries: &[wgt::BindGroupLayoutEntry {
122                        binding: 0,
123                        visibility: wgt::ShaderStages::COMPUTE,
124                        ty: wgt::BindingType::Buffer {
125                            ty: wgt::BufferBindingType::Storage { read_only: false },
126                            has_dynamic_offset: false,
127                            min_binding_size: Some(NonZeroU64::new(8).unwrap()),
128                        },
129                        count: None,
130                    }],
131                })
132                .map_err(|e| {
133                    TimestampNormalizerInitError::BindGroupLayout(device.handle_hal_error(e))
134                })?;
135
136            let common_src = include_str!("common.wgsl");
137            let src = include_str!("timestamp_normalization.wgsl");
138
139            let preprocessed_src = alloc::format!("{common_src}\n{src}");
140
141            #[cfg(feature = "wgsl")]
142            let module = naga::front::wgsl::parse_str(&preprocessed_src).map_err(|inner| {
143                TimestampNormalizerInitError::ParseWgsl(naga::error::ShaderError {
144                    source: preprocessed_src.clone(),
145                    label: None,
146                    inner: Box::new(inner),
147                })
148            })?;
149            #[cfg(not(feature = "wgsl"))]
150            #[allow(clippy::diverging_sub_expression)]
151            let module =
152                panic!("Timestamp normalization requires the wgsl feature flag to be enabled!");
153
154            let info = crate::device::create_validator(
155                wgt::Features::PUSH_CONSTANTS,
156                wgt::DownlevelFlags::empty(),
157                naga::valid::ValidationFlags::all(),
158            )
159            .validate(&module)
160            .map_err(|inner| {
161                TimestampNormalizerInitError::ValidateWgsl(naga::error::ShaderError {
162                    source: preprocessed_src.clone(),
163                    label: None,
164                    inner: Box::new(inner),
165                })
166            })?;
167            let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
168                module: alloc::borrow::Cow::Owned(module),
169                info,
170                debug_source: None,
171            });
172            let hal_desc = hal::ShaderModuleDescriptor {
173                label: None,
174                runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
175            };
176            let module = device
177                .raw()
178                .create_shader_module(&hal_desc, hal_shader)
179                .map_err(|error| match error {
180                    hal::ShaderError::Device(error) => {
181                        CreateShaderModuleError::Device(device.handle_hal_error(error))
182                    }
183                    hal::ShaderError::Compilation(ref msg) => {
184                        log::error!("Shader error: {msg}");
185                        CreateShaderModuleError::Generation
186                    }
187                })?;
188
189            let pipeline_layout = device
190                .raw()
191                .create_pipeline_layout(&hal::PipelineLayoutDescriptor {
192                    label: None,
193                    bind_group_layouts: &[temporary_bind_group_layout.as_ref()],
194                    push_constant_ranges: &[PushConstantRange {
195                        stages: wgt::ShaderStages::COMPUTE,
196                        range: 0..8,
197                    }],
198                    flags: hal::PipelineLayoutFlags::empty(),
199                })
200                .map_err(|e| {
201                    TimestampNormalizerInitError::PipelineLayout(device.handle_hal_error(e))
202                })?;
203
204            let (multiplier, shift) = compute_timestamp_period(timestamp_period);
205
206            let mut constants = HashMap::with_capacity(2);
207            constants.insert(String::from("TIMESTAMP_PERIOD_MULTIPLY"), multiplier as f64);
208            constants.insert(String::from("TIMESTAMP_PERIOD_SHIFT"), shift as f64);
209
210            let pipeline_desc = hal::ComputePipelineDescriptor {
211                label: None,
212                layout: pipeline_layout.as_ref(),
213                stage: hal::ProgrammableStage {
214                    module: module.as_ref(),
215                    entry_point: "main",
216                    constants: &constants,
217                    zero_initialize_workgroup_memory: false,
218                },
219                cache: None,
220            };
221            let pipeline = device
222                .raw()
223                .create_compute_pipeline(&pipeline_desc)
224                .map_err(|err| match err {
225                    hal::PipelineError::Device(error) => {
226                        CreateComputePipelineError::Device(device.handle_hal_error(error))
227                    }
228                    hal::PipelineError::Linkage(_stages, msg) => {
229                        CreateComputePipelineError::Internal(msg)
230                    }
231                    hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
232                        crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
233                    ),
234                    hal::PipelineError::PipelineConstants(_, error) => {
235                        CreateComputePipelineError::PipelineConstants(error)
236                    }
237                })?;
238
239            Ok(Self {
240                state: Some(InternalState {
241                    temporary_bind_group_layout,
242                    pipeline_layout,
243                    pipeline,
244                }),
245            })
246        }
247    }
248
249    /// Create a bind group for normalizing timestamps in `buffer`.
250    ///
251    /// This function is unsafe because it does not know that `buffer_size` is
252    /// the true size of the buffer.
253    pub unsafe fn create_normalization_bind_group(
254        &self,
255        device: &Device,
256        buffer: &dyn hal::DynBuffer,
257        buffer_label: Option<&str>,
258        buffer_size: wgt::BufferSize,
259        buffer_usages: wgt::BufferUsages,
260    ) -> Result<TimestampNormalizationBindGroup, DeviceError> {
261        unsafe {
262            let Some(ref state) = &self.state else {
263                return Ok(TimestampNormalizationBindGroup { raw: None });
264            };
265
266            if !buffer_usages.contains(wgt::BufferUsages::QUERY_RESOLVE) {
267                return Ok(TimestampNormalizationBindGroup { raw: None });
268            }
269
270            // If this buffer is large enough that we wouldn't be able to bind the entire thing
271            // at once to normalize the timestamps, we can't use it. We force the buffer to fail
272            // to allocate. The lowest max binding size is 128MB, and query sets must be small
273            // (no more than 4096), so this should never be hit in practice by sane programs.
274            if buffer_size.get() > device.adapter.limits().max_storage_buffer_binding_size as u64 {
275                return Err(DeviceError::OutOfMemory);
276            }
277
278            let bg_label_alloc;
279            let label = match buffer_label {
280                Some(label) => {
281                    bg_label_alloc = alloc::format!("Timestamp normalization bind group ({label})");
282                    &*bg_label_alloc
283                }
284                None => "Timestamp normalization bind group",
285            };
286
287            let bg = device
288                .raw()
289                .create_bind_group(&hal::BindGroupDescriptor {
290                    label: hal_label(Some(label), device.instance_flags),
291                    layout: &*state.temporary_bind_group_layout,
292                    buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, buffer_size)],
293                    samplers: &[],
294                    textures: &[],
295                    acceleration_structures: &[],
296                    external_textures: &[],
297                    entries: &[hal::BindGroupEntry {
298                        binding: 0,
299                        resource_index: 0,
300                        count: 1,
301                    }],
302                })
303                .map_err(|e| device.handle_hal_error(e))?;
304
305            Ok(TimestampNormalizationBindGroup { raw: Some(bg) })
306        }
307    }
308
309    pub fn normalize(
310        &self,
311        snatch_guard: &SnatchGuard<'_>,
312        encoder: &mut dyn hal::DynCommandEncoder,
313        tracker: &mut BufferTracker,
314        bind_group: &TimestampNormalizationBindGroup,
315        buffer: &Arc<Buffer>,
316        buffer_offset_bytes: u64,
317        total_timestamps: u32,
318    ) {
319        let Some(ref state) = &self.state else {
320            return;
321        };
322
323        let Some(bind_group) = bind_group.raw.as_deref() else {
324            return;
325        };
326
327        let buffer_offset_timestamps: u32 = (buffer_offset_bytes / 8).try_into().unwrap(); // Unreachable as MAX_QUERIES is way less than u32::MAX
328
329        let pending_barrier = tracker.set_single(buffer, wgt::BufferUses::STORAGE_READ_WRITE);
330
331        let barrier = pending_barrier.map(|pending| pending.into_hal(buffer, snatch_guard));
332
333        let needed_workgroups = total_timestamps.div_ceil(64);
334
335        unsafe {
336            encoder.transition_buffers(barrier.as_slice());
337            encoder.begin_compute_pass(&hal::ComputePassDescriptor {
338                label: hal_label(
339                    Some("Timestamp normalization pass"),
340                    buffer.device.instance_flags,
341                ),
342                timestamp_writes: None,
343            });
344            encoder.set_compute_pipeline(&*state.pipeline);
345            encoder.set_bind_group(&*state.pipeline_layout, 0, Some(bind_group), &[]);
346            encoder.set_push_constants(
347                &*state.pipeline_layout,
348                wgt::ShaderStages::COMPUTE,
349                0,
350                &[buffer_offset_timestamps, total_timestamps],
351            );
352            encoder.dispatch([needed_workgroups, 1, 1]);
353            encoder.end_compute_pass();
354        }
355    }
356
357    pub fn dispose(self, device: &dyn hal::DynDevice) {
358        unsafe {
359            let Some(state) = self.state else {
360                return;
361            };
362
363            device.destroy_compute_pipeline(state.pipeline);
364            device.destroy_pipeline_layout(state.pipeline_layout);
365            device.destroy_bind_group_layout(state.temporary_bind_group_layout);
366        }
367    }
368
    /// Returns `true` if this normalizer holds backend state, i.e. `new` built
    /// the normalization pipeline instead of returning a normalizer that does
    /// nothing.
    pub fn enabled(&self) -> bool {
        self.state.is_some()
    }
372}
373
/// Bind group exposing one buffer to the timestamp normalization shader.
///
/// Produced by `TimestampNormalizer::create_normalization_bind_group`; the
/// backend object, if any, is destroyed by `dispose`.
#[derive(Debug)]
pub struct TimestampNormalizationBindGroup {
    /// `None` when the normalizer is disabled or the buffer lacks the
    /// `QUERY_RESOLVE` usage; normalizing with such a bind group does nothing.
    raw: Option<Box<dyn hal::DynBindGroup>>,
}
378
379impl TimestampNormalizationBindGroup {
380    pub fn dispose(self, device: &dyn hal::DynDevice) {
381        unsafe {
382            if let Some(raw) = self.raw {
383                device.destroy_bind_group(raw);
384            }
385        }
386    }
387}
388
/// Converts a timestamp period into a fixed-point fraction `multiplier / 2^shift`.
///
/// The shader normalizes a timestamp as `(timestamp * multiplier) >> shift`, so
/// the returned pair must approximate `input` as closely as a `u32` numerator
/// allows. `shift` is always in `0..=32`, and the largest shift whose numerator
/// still fits in a `u32` is chosen to maximize precision.
///
/// Edge cases:
/// - Periods too large for any numerator (>= 2^32) saturate to `(u32::MAX, 0)`.
/// - Zero, negative and NaN periods yield a multiplier of `0`.
fn compute_timestamp_period(input: f32) -> (u32, u32) {
    /// Largest right shift handed to the shader.
    const MAX_SHIFT: u32 = 32;

    // Numerator for a given shift, computed in f64 so every u32 is exact.
    let scale = |shift: u32| (f64::from(input) * (1u64 << shift) as f64).round();

    // `input < 2^pow2` (up to the exact-power-of-two case handled below), so
    // `input * 2^(32 - pow2)` fits in 32 bits. Periods below 1 have a negative
    // exponent; the shift is capped at 32, so they all use the maximum shift
    // instead of losing precision. float -> int conversions saturate, which
    // covers infinities (and NaN becomes 0).
    let pow2 = input.log2().ceil() as i32;
    let mut shift = MAX_SHIFT - pow2.clamp(0, MAX_SHIFT as i32) as u32;

    // Exact powers of two (and values rounding up to one) land exactly on 2^32,
    // which does not fit in a u32. Give up one bit of shift rather than letting
    // the numerator saturate to a slightly-too-small value.
    let mut numerator = scale(shift);
    if numerator > f64::from(u32::MAX) && shift > 0 {
        shift -= 1;
        numerator = scale(shift);
    }

    // float -> int conversions are defined to saturate.
    (numerator as u32, shift)
}
401
#[cfg(test)]
mod tests {
    /// Largest absolute error tolerated between a period and its
    /// `multiplier / 2^shift` approximation.
    const TOLERANCE: f64 = 0.0000001;

    /// Checks that the fraction computed for `input` reproduces `input` to
    /// within [`TOLERANCE`].
    fn assert_timestamp_case(input: f32) {
        let (multiplier, shift) = super::compute_timestamp_period(input);

        let denominator = (1u64 << shift) as f64;
        let reconstructed = f64::from(multiplier) / denominator;
        let error = (f64::from(input) - reconstructed).abs();

        assert!(error < TOLERANCE);
    }

    #[test]
    fn compute_timestamp_period() {
        // Periods below, at and above 1, including exact powers of two.
        for period in [0.01, 0.5, 1.0, 2.0, 2.7, 1000.7] {
            assert_timestamp_case(period);
        }
    }
}