wgpu_core/timestamp_normalization/
mod.rs

1//! Utility for normalizing GPU timestamp queries to have a consistent
2//! 1GHz period. This uses a compute shader to do the normalization,
3//! so the timestamps exist in their correct format on the GPU, as
4//! is required by the WebGPU specification.
5//!
6//! ## Algorithm
7//!
8//! The fundamental operation is multiplying a u64 timestamp by an f32
9//! value. We have neither f64s nor u64s in shaders, so we need to do
10//! something more complicated.
11//!
12//! We first decompose the f32 into a u32 fraction where the denominator
13//! is a power of two. We do the computation with f64 for ease of computation,
14//! as those can store u32s losslessly.
15//!
16//! Because the denominator is a power of two, this means the shader can evaluate
17//! this divide by using a shift. Additionally, we always choose the largest denominator
18//! we can, so that the fraction is as precise as possible.
19//!
20//! To evaluate this function, we have two helper operations (both in common.wgsl).
21//!
22//! 1. `u64_mul_u32` multiplies a u64 by a u32 and returns a u96.
23//! 2. `shift_right_u96` shifts a u96 right by a given amount, returning a u96.
24//!
25//! See their implementations for more details.
26//!
//! We then multiply the timestamp by the numerator, and shift it right by the
//! base-2 logarithm of the denominator. This gives us the normalized timestamp.
29
30use core::num::NonZeroU64;
31
32use alloc::{boxed::Box, string::String, string::ToString, sync::Arc};
33
34use hashbrown::HashMap;
35
36use crate::{
37    device::{Device, DeviceError},
38    hal_label,
39    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
40    resource::Buffer,
41    snatch::SnatchGuard,
42    track::BufferTracker,
43};
44
/// Buffer uses associated with timestamp normalization: read-write storage
/// access, which is the same use that [`TimestampNormalizer::normalize`]
/// transitions the target buffer into before dispatching.
pub const TIMESTAMP_NORMALIZATION_BUFFER_USES: wgt::BufferUses =
    wgt::BufferUses::STORAGE_READ_WRITE;
47
/// The hal objects backing an enabled [`TimestampNormalizer`].
///
/// All three are created in [`TimestampNormalizer::new`] and explicitly
/// destroyed in [`TimestampNormalizer::dispose`].
struct InternalState {
    /// Layout with a single read-write storage buffer at binding 0. Used as the
    /// layout of every bind group made by
    /// [`TimestampNormalizer::create_normalization_bind_group`].
    temporary_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Pipeline layout containing the bind group layout above plus 8 bytes of
    /// immediate data.
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    /// Compute pipeline running the normalization shader's `main` entry point.
    pipeline: Box<dyn hal::DynComputePipeline>,
}
53
/// Errors that [`TimestampNormalizer::new`] can return while building its
/// internal shader and pipeline objects.
#[derive(Debug, Clone, thiserror::Error)]
pub enum TimestampNormalizerInitError {
    /// The hal device failed to create the bind group layout.
    #[error("Failed to initialize bind group layout")]
    BindGroupLayout(#[source] DeviceError),
    /// The embedded WGSL source failed to parse. Only exists when the `wgsl`
    /// feature is enabled.
    #[cfg(feature = "wgsl")]
    #[error("Failed to parse shader")]
    ParseWgsl(#[source] naga::error::ShaderError<naga::front::wgsl::ParseError>),
    /// The parsed shader module was rejected by the naga validator.
    #[error("Failed to validate shader module")]
    ValidateWgsl(#[source] naga::error::ShaderError<naga::WithSpan<naga::valid::ValidationError>>),
    /// The hal device failed to create the shader module.
    #[error("Failed to create shader module")]
    CreateShaderModule(#[from] CreateShaderModuleError),
    /// The hal device failed to create the pipeline layout.
    #[error("Failed to create pipeline layout")]
    PipelineLayout(#[source] DeviceError),
    /// The hal device failed to create the compute pipeline.
    #[error("Failed to create compute pipeline")]
    ComputePipeline(#[from] CreateComputePipelineError),
}
70
/// Normalizes GPU timestamps to have a consistent 1GHz period.
/// See module documentation for more information.
pub struct TimestampNormalizer {
    /// `None` when normalization is disabled, unsupported, or unnecessary
    /// (see [`TimestampNormalizer::new`]); every operation is then a no-op.
    /// `Some` holds the hal objects used to run the normalization shader.
    state: Option<InternalState>,
}
76
77impl TimestampNormalizer {
    /// Creates a new timestamp normalizer.
    ///
    /// If the device cannot support automatic timestamp normalization,
    /// this will return a normalizer that does nothing. The same happens when
    /// the instance did not request automatic normalization, or when
    /// `timestamp_period` is exactly `1.0`.
    ///
    /// Otherwise this builds a bind group layout, a shader module (from
    /// `common.wgsl` followed by `timestamp_normalization.wgsl`), a pipeline
    /// layout and a compute pipeline. The multiplier and shift derived from
    /// `timestamp_period` are passed to the pipeline as the
    /// `TIMESTAMP_PERIOD_MULTIPLY` and `TIMESTAMP_PERIOD_SHIFT` constants.
    ///
    /// # Errors
    ///
    /// If any resources are invalid, this will return an error. Each variant
    /// of [`TimestampNormalizerInitError`] corresponds to one creation step.
    ///
    /// # Panics
    ///
    /// Panics if normalization is actually required but the `wgsl` feature is
    /// not enabled.
    pub fn new(
        device: &Device,
        timestamp_period: f32,
    ) -> Result<Self, TimestampNormalizerInitError> {
        unsafe {
            // Normalization was not requested for this instance: do nothing.
            if !device
                .instance_flags
                .contains(wgt::InstanceFlags::AUTOMATIC_TIMESTAMP_NORMALIZATION)
            {
                return Ok(Self { state: None });
            }

            // Normalization runs in a compute shader, so without compute support
            // we can only log the problem and fall back to a no-op normalizer.
            if !device
                .downlevel
                .flags
                .contains(wgt::DownlevelFlags::COMPUTE_SHADERS)
            {
                log::error!("Automatic timestamp normalization was requested, but compute shaders are not supported.");
                return Ok(Self { state: None });
            }

            if timestamp_period == 1.0 {
                // If the period is 1, we don't need to do anything to them.
                return Ok(Self { state: None });
            }

            // One read-write storage buffer at binding 0, visible to compute.
            // The minimum binding size is 8 bytes (presumably one u64 timestamp).
            let temporary_bind_group_layout = device
                .raw()
                .create_bind_group_layout(&hal::BindGroupLayoutDescriptor {
                    label: hal_label(
                        Some("(wgpu internal) Timestamp Normalization Bind Group Layout"),
                        device.instance_flags,
                    ),
                    flags: hal::BindGroupLayoutFlags::empty(),
                    entries: &[wgt::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgt::ShaderStages::COMPUTE,
                        ty: wgt::BindingType::Buffer {
                            ty: wgt::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: Some(NonZeroU64::new(8).unwrap()),
                        },
                        count: None,
                    }],
                })
                .map_err(|e| {
                    TimestampNormalizerInitError::BindGroupLayout(device.handle_hal_error(e))
                })?;

            // NOTE(review): from here on, a failing step returns early without
            // explicitly destroying the hal objects created before it. Confirm
            // whether that is acceptable for this initialization path.

            // The shader is assembled by plain concatenation: the shared helpers
            // first, then the normalization entry point.
            let common_src = include_str!("common.wgsl");
            let src = include_str!("timestamp_normalization.wgsl");

            let preprocessed_src = alloc::format!("{common_src}\n{src}");

            #[cfg(feature = "wgsl")]
            let module = naga::front::wgsl::parse_str(&preprocessed_src).map_err(|inner| {
                TimestampNormalizerInitError::ParseWgsl(naga::error::ShaderError {
                    source: preprocessed_src.clone(),
                    label: None,
                    inner: Box::new(inner),
                })
            })?;
            // Without the WGSL frontend there is no way to build the shader.
            #[cfg(not(feature = "wgsl"))]
            #[allow(clippy::diverging_sub_expression)]
            let module =
                panic!("Timestamp normalization requires the wgsl feature flag to be enabled!");

            // Validate with the immediates feature enabled and all validation
            // flags on.
            let info = crate::device::create_validator(
                wgt::Features::IMMEDIATES,
                wgt::DownlevelFlags::empty(),
                naga::valid::ValidationFlags::all(),
            )
            .validate(&module)
            .map_err(|inner| {
                TimestampNormalizerInitError::ValidateWgsl(naga::error::ShaderError {
                    source: preprocessed_src.clone(),
                    label: None,
                    inner: Box::new(inner),
                })
            })?;
            let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
                module: alloc::borrow::Cow::Owned(module),
                info,
                debug_source: None,
            });
            let hal_desc = hal::ShaderModuleDescriptor {
                label: hal_label(
                    Some("(wgpu internal) Timestamp normalizer shader module"),
                    device.instance_flags,
                ),
                runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
            };
            // Compilation failures are logged and reported as a generic
            // `Generation` error; the `?` converts into `CreateShaderModule`.
            // NOTE(review): this hal shader module is never passed to an explicit
            // destroy call; it is simply dropped when this function returns.
            // Confirm that is intended.
            let module = device
                .raw()
                .create_shader_module(&hal_desc, hal_shader)
                .map_err(|error| match error {
                    hal::ShaderError::Device(error) => {
                        CreateShaderModuleError::Device(device.handle_hal_error(error))
                    }
                    hal::ShaderError::Compilation(ref msg) => {
                        log::error!("Shader error: {msg}");
                        CreateShaderModuleError::Generation
                    }
                })?;

            // 8 bytes of immediates: `normalize` uploads two u32 values
            // (timestamp offset and timestamp count).
            let pipeline_layout = device
                .raw()
                .create_pipeline_layout(&hal::PipelineLayoutDescriptor {
                    label: hal_label(
                        Some("(wgpu internal) Timestamp normalizer pipeline layout"),
                        device.instance_flags,
                    ),
                    bind_group_layouts: &[Some(temporary_bind_group_layout.as_ref())],
                    immediate_size: 8,
                    flags: hal::PipelineLayoutFlags::empty(),
                })
                .map_err(|e| {
                    TimestampNormalizerInitError::PipelineLayout(device.handle_hal_error(e))
                })?;

            // Express the period as `multiplier / 2^shift` (see module docs).
            let (multiplier, shift) = compute_timestamp_period(timestamp_period);

            // The constants map stores f64 values; the u32 values are widened
            // with `as f64`, which represents every u32 exactly.
            let mut constants = HashMap::with_capacity(2);
            constants.insert(String::from("TIMESTAMP_PERIOD_MULTIPLY"), multiplier as f64);
            constants.insert(String::from("TIMESTAMP_PERIOD_SHIFT"), shift as f64);

            let pipeline_desc = hal::ComputePipelineDescriptor {
                label: hal_label(
                    Some("(wgpu internal) Timestamp normalizer pipeline"),
                    device.instance_flags,
                ),
                layout: pipeline_layout.as_ref(),
                stage: hal::ProgrammableStage {
                    module: module.as_ref(),
                    entry_point: "main",
                    constants: &constants,
                    zero_initialize_workgroup_memory: false,
                },
                cache: None,
            };
            // Map every hal pipeline error onto the matching core error; the `?`
            // then converts into `ComputePipeline`.
            let pipeline = device
                .raw()
                .create_compute_pipeline(&pipeline_desc)
                .map_err(|err| match err {
                    hal::PipelineError::Device(error) => {
                        CreateComputePipelineError::Device(device.handle_hal_error(error))
                    }
                    hal::PipelineError::Linkage(_stages, msg) => {
                        CreateComputePipelineError::Internal(msg)
                    }
                    hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                        crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
                    ),
                    hal::PipelineError::PipelineConstants(_, error) => {
                        CreateComputePipelineError::PipelineConstants(error)
                    }
                })?;

            Ok(Self {
                state: Some(InternalState {
                    temporary_bind_group_layout,
                    pipeline_layout,
                    pipeline,
                }),
            })
        }
    }
253
254    /// Create a bind group for normalizing timestamps in `buffer`.
255    ///
256    /// This function is unsafe because it does not know that `buffer_size` is
257    /// the true size of the buffer.
258    pub unsafe fn create_normalization_bind_group(
259        &self,
260        device: &Device,
261        buffer: &dyn hal::DynBuffer,
262        buffer_label: Option<&str>,
263        buffer_size: wgt::BufferSize,
264        buffer_usages: wgt::BufferUsages,
265    ) -> Result<TimestampNormalizationBindGroup, DeviceError> {
266        unsafe {
267            let Some(ref state) = &self.state else {
268                return Ok(TimestampNormalizationBindGroup { raw: None });
269            };
270
271            if !buffer_usages.contains(wgt::BufferUsages::QUERY_RESOLVE) {
272                return Ok(TimestampNormalizationBindGroup { raw: None });
273            }
274
275            // If this buffer is large enough that we wouldn't be able to bind the entire thing
276            // at once to normalize the timestamps, we can't use it. We force the buffer to fail
277            // to allocate. The lowest max binding size is 128MB, and query sets must be small
278            // (no more than 4096), so this should never be hit in practice by sane programs.
279            if buffer_size.get() > device.adapter.limits().max_storage_buffer_binding_size as u64 {
280                return Err(DeviceError::OutOfMemory);
281            }
282
283            let bg_label_alloc;
284            let label = match buffer_label {
285                Some(label) => {
286                    bg_label_alloc = alloc::format!("Timestamp normalization bind group ({label})");
287                    &*bg_label_alloc
288                }
289                None => "Timestamp normalization bind group",
290            };
291
292            let bg = device
293                .raw()
294                .create_bind_group(&hal::BindGroupDescriptor {
295                    label: hal_label(Some(label), device.instance_flags),
296                    layout: &*state.temporary_bind_group_layout,
297                    buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, buffer_size)],
298                    samplers: &[],
299                    textures: &[],
300                    acceleration_structures: &[],
301                    external_textures: &[],
302                    entries: &[hal::BindGroupEntry {
303                        binding: 0,
304                        resource_index: 0,
305                        count: 1,
306                    }],
307                })
308                .map_err(|e| device.handle_hal_error(e))?;
309
310            Ok(TimestampNormalizationBindGroup { raw: Some(bg) })
311        }
312    }
313
314    pub fn normalize(
315        &self,
316        snatch_guard: &SnatchGuard<'_>,
317        encoder: &mut dyn hal::DynCommandEncoder,
318        tracker: &mut BufferTracker,
319        bind_group: &TimestampNormalizationBindGroup,
320        buffer: &Arc<Buffer>,
321        buffer_offset_bytes: u64,
322        total_timestamps: u32,
323    ) {
324        let Some(ref state) = &self.state else {
325            return;
326        };
327
328        let Some(bind_group) = bind_group.raw.as_deref() else {
329            return;
330        };
331
332        let buffer_offset_timestamps: u32 = (buffer_offset_bytes / 8).try_into().unwrap(); // Unreachable as MAX_QUERIES is way less than u32::MAX
333
334        let pending_barrier = tracker.set_single(buffer, wgt::BufferUses::STORAGE_READ_WRITE);
335
336        let barrier = pending_barrier.map(|pending| pending.into_hal(buffer, snatch_guard));
337
338        let needed_workgroups = total_timestamps.div_ceil(64);
339
340        unsafe {
341            encoder.transition_buffers(barrier.as_slice());
342            encoder.begin_compute_pass(&hal::ComputePassDescriptor {
343                label: hal_label(
344                    Some("(wgpu internal) Timestamp normalization pass"),
345                    buffer.device.instance_flags,
346                ),
347                timestamp_writes: None,
348            });
349            encoder.set_compute_pipeline(&*state.pipeline);
350            encoder.set_bind_group(&*state.pipeline_layout, 0, bind_group, &[]);
351            encoder.set_immediates(
352                &*state.pipeline_layout,
353                0,
354                &[buffer_offset_timestamps, total_timestamps],
355            );
356            encoder.dispatch([needed_workgroups, 1, 1]);
357            encoder.end_compute_pass();
358        }
359    }
360
361    pub fn dispose(self, device: &dyn hal::DynDevice) {
362        unsafe {
363            let Some(state) = self.state else {
364                return;
365            };
366
367            device.destroy_compute_pipeline(state.pipeline);
368            device.destroy_pipeline_layout(state.pipeline_layout);
369            device.destroy_bind_group_layout(state.temporary_bind_group_layout);
370        }
371    }
372
    /// Returns `true` if this normalizer holds pipeline state, i.e. it will
    /// actually perform work rather than act as a no-op.
    pub fn enabled(&self) -> bool {
        self.state.is_some()
    }
376}
377
/// Bind group created by
/// [`TimestampNormalizer::create_normalization_bind_group`] for one buffer.
#[derive(Debug)]
pub struct TimestampNormalizationBindGroup {
    /// The hal bind group, or `None` when no normalization is needed for the
    /// buffer (normalizer disabled, or buffer lacks `QUERY_RESOLVE` usage).
    raw: Option<Box<dyn hal::DynBindGroup>>,
}
382
383impl TimestampNormalizationBindGroup {
384    pub fn dispose(self, device: &dyn hal::DynDevice) {
385        unsafe {
386            if let Some(raw) = self.raw {
387                device.destroy_bind_group(raw);
388            }
389        }
390    }
391}
392
/// Decomposes a timestamp period into a fraction with a power-of-two
/// denominator, returning `(multiplier, shift)` such that
/// `input ≈ multiplier / 2^shift`.
///
/// The shift starts at `32 - |ceil(log2(input))|`, with the exponent clamped
/// to `-32..=32`, so `shift` is always in `0..=32`. If the scaled value would
/// not fit in a `u32` (for example, when `input` is an exact power of two such
/// as `2.0`, where `input * 2^shift` is exactly `2^32`), the shift is reduced
/// until it does. Without this step, the saturating cast would silently turn
/// such a period into `u32::MAX / 2^shift`, which is slightly too small.
///
/// Inputs too large to represent even with `shift == 0` saturate to
/// `u32::MAX`. A zero, negative, or NaN input yields a multiplier of `0`.
fn compute_timestamp_period(input: f32) -> (u32, u32) {
    let pow2 = input.log2().ceil() as i32;
    let clamped_pow2 = pow2.clamp(-32, 32).unsigned_abs();
    let mut shift = 32 - clamped_pow2;

    // Computed in f64, which holds every u32 (and 2^32) exactly.
    let scale = |shift: u32| (input as f64 * (1u64 << shift) as f64).round();

    let mut scaled = scale(shift);
    // Give up one bit of denominator whenever the numerator overflows a u32.
    // NaN compares false here, so it falls straight through to the cast below.
    while shift > 0 && scaled > u32::MAX as f64 {
        shift -= 1;
        scaled = scale(shift);
    }

    // float -> int conversions are defined to saturate.
    let multiplier = scaled as u32;

    (multiplier, shift)
}
405
#[cfg(test)]
mod tests {
    use core::f64;

    /// Checks that the fraction produced for `input` reconstructs the
    /// original period to within an absolute error of `1e-7`.
    fn assert_timestamp_case(input: f32) {
        let (numerator, log2_denominator) = super::compute_timestamp_period(input);

        let denominator = (1u64 << log2_denominator) as f64;
        let reconstructed = numerator as f64 / denominator;
        let error = (input as f64 - reconstructed).abs();

        assert!(error < 0.0000001);
    }

    #[test]
    fn compute_timestamp_period() {
        // Periods below, at, and above 1.0, including exact powers of two.
        for period in [0.01, 0.5, 1.0, 2.0, 2.7, 1000.7] {
            assert_timestamp_case(period);
        }
    }
}