wgpu_core/timestamp_normalization/
mod.rs1use core::num::NonZeroU64;
31
32use alloc::{boxed::Box, string::String, string::ToString, sync::Arc};
33
34use hashbrown::HashMap;
35
36use crate::{
37 device::{Device, DeviceError},
38 hal_label,
39 pipeline::{CreateComputePipelineError, CreateShaderModuleError},
40 resource::Buffer,
41 snatch::SnatchGuard,
42 track::BufferTracker,
43};
44
/// Buffer uses corresponding to how the timestamp normalization shader binds a query
/// buffer: as a read-write storage buffer (`Storage { read_only: false }` in the
/// normalization bind group layout).
///
/// NOTE(review): presumably used by callers when creating or tracking query-resolve
/// buffers so they can be bound by the normalization pass; verify at call sites.
pub const TIMESTAMP_NORMALIZATION_BUFFER_USES: wgt::BufferUses =
    wgt::BufferUses::STORAGE_READ_WRITE;
47
/// hal objects owned by an enabled [`TimestampNormalizer`].
struct InternalState {
    /// Layout with one read-write storage buffer at binding 0 (minimum binding size of
    /// 8 bytes), visible to the compute stage.
    temporary_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Pipeline layout made of the bind group layout above plus 8 bytes of immediates.
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    /// Compute pipeline running the `main` entry point of the normalization shader.
    pipeline: Box<dyn hal::DynComputePipeline>,
}
53
/// Errors that can occur while building the timestamp normalization pipeline.
#[derive(Debug, Clone, thiserror::Error)]
pub enum TimestampNormalizerInitError {
    /// The hal device failed to create the bind group layout.
    #[error("Failed to initialize bind group layout")]
    BindGroupLayout(#[source] DeviceError),
    /// The embedded WGSL source failed to parse (only exists with the `wgsl` feature).
    #[cfg(feature = "wgsl")]
    #[error("Failed to parse shader")]
    ParseWgsl(#[source] naga::error::ShaderError<naga::front::wgsl::ParseError>),
    /// naga validation of the parsed shader module failed.
    #[error("Failed to validate shader module")]
    ValidateWgsl(#[source] naga::error::ShaderError<naga::WithSpan<naga::valid::ValidationError>>),
    /// The hal device failed to create the shader module.
    #[error("Failed to create shader module")]
    CreateShaderModule(#[from] CreateShaderModuleError),
    /// The hal device failed to create the pipeline layout.
    #[error("Failed to create pipeline layout")]
    PipelineLayout(#[source] DeviceError),
    /// The hal device failed to create the compute pipeline.
    #[error("Failed to create compute pipeline")]
    ComputePipeline(#[from] CreateComputePipelineError),
}
70
/// Rescales resolved timestamp query values on the GPU with a compute pass, when automatic
/// timestamp normalization is enabled for the device.
pub struct TimestampNormalizer {
    /// `None` when normalization is disabled for this device (see `TimestampNormalizer::new`).
    state: Option<InternalState>,
}
76
77impl TimestampNormalizer {
78 pub fn new(
87 device: &Device,
88 timestamp_period: f32,
89 ) -> Result<Self, TimestampNormalizerInitError> {
90 unsafe {
91 if !device
92 .instance_flags
93 .contains(wgt::InstanceFlags::AUTOMATIC_TIMESTAMP_NORMALIZATION)
94 {
95 return Ok(Self { state: None });
96 }
97
98 if !device
99 .downlevel
100 .flags
101 .contains(wgt::DownlevelFlags::COMPUTE_SHADERS)
102 {
103 log::error!("Automatic timestamp normalization was requested, but compute shaders are not supported.");
104 return Ok(Self { state: None });
105 }
106
107 if timestamp_period == 1.0 {
108 return Ok(Self { state: None });
110 }
111
112 let temporary_bind_group_layout = device
113 .raw()
114 .create_bind_group_layout(&hal::BindGroupLayoutDescriptor {
115 label: hal_label(
116 Some("Timestamp Normalization Bind Group Layout"),
117 device.instance_flags,
118 ),
119 flags: hal::BindGroupLayoutFlags::empty(),
120 entries: &[wgt::BindGroupLayoutEntry {
121 binding: 0,
122 visibility: wgt::ShaderStages::COMPUTE,
123 ty: wgt::BindingType::Buffer {
124 ty: wgt::BufferBindingType::Storage { read_only: false },
125 has_dynamic_offset: false,
126 min_binding_size: Some(NonZeroU64::new(8).unwrap()),
127 },
128 count: None,
129 }],
130 })
131 .map_err(|e| {
132 TimestampNormalizerInitError::BindGroupLayout(device.handle_hal_error(e))
133 })?;
134
135 let common_src = include_str!("common.wgsl");
136 let src = include_str!("timestamp_normalization.wgsl");
137
138 let preprocessed_src = alloc::format!("{common_src}\n{src}");
139
140 #[cfg(feature = "wgsl")]
141 let module = naga::front::wgsl::parse_str(&preprocessed_src).map_err(|inner| {
142 TimestampNormalizerInitError::ParseWgsl(naga::error::ShaderError {
143 source: preprocessed_src.clone(),
144 label: None,
145 inner: Box::new(inner),
146 })
147 })?;
148 #[cfg(not(feature = "wgsl"))]
149 #[allow(clippy::diverging_sub_expression)]
150 let module =
151 panic!("Timestamp normalization requires the wgsl feature flag to be enabled!");
152
153 let info = crate::device::create_validator(
154 wgt::Features::IMMEDIATES,
155 wgt::DownlevelFlags::empty(),
156 naga::valid::ValidationFlags::all(),
157 )
158 .validate(&module)
159 .map_err(|inner| {
160 TimestampNormalizerInitError::ValidateWgsl(naga::error::ShaderError {
161 source: preprocessed_src.clone(),
162 label: None,
163 inner: Box::new(inner),
164 })
165 })?;
166 let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
167 module: alloc::borrow::Cow::Owned(module),
168 info,
169 debug_source: None,
170 });
171 let hal_desc = hal::ShaderModuleDescriptor {
172 label: None,
173 runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
174 };
175 let module = device
176 .raw()
177 .create_shader_module(&hal_desc, hal_shader)
178 .map_err(|error| match error {
179 hal::ShaderError::Device(error) => {
180 CreateShaderModuleError::Device(device.handle_hal_error(error))
181 }
182 hal::ShaderError::Compilation(ref msg) => {
183 log::error!("Shader error: {msg}");
184 CreateShaderModuleError::Generation
185 }
186 })?;
187
188 let pipeline_layout = device
189 .raw()
190 .create_pipeline_layout(&hal::PipelineLayoutDescriptor {
191 label: None,
192 bind_group_layouts: &[temporary_bind_group_layout.as_ref()],
193 immediate_size: 8,
194 flags: hal::PipelineLayoutFlags::empty(),
195 })
196 .map_err(|e| {
197 TimestampNormalizerInitError::PipelineLayout(device.handle_hal_error(e))
198 })?;
199
200 let (multiplier, shift) = compute_timestamp_period(timestamp_period);
201
202 let mut constants = HashMap::with_capacity(2);
203 constants.insert(String::from("TIMESTAMP_PERIOD_MULTIPLY"), multiplier as f64);
204 constants.insert(String::from("TIMESTAMP_PERIOD_SHIFT"), shift as f64);
205
206 let pipeline_desc = hal::ComputePipelineDescriptor {
207 label: None,
208 layout: pipeline_layout.as_ref(),
209 stage: hal::ProgrammableStage {
210 module: module.as_ref(),
211 entry_point: "main",
212 constants: &constants,
213 zero_initialize_workgroup_memory: false,
214 },
215 cache: None,
216 };
217 let pipeline = device
218 .raw()
219 .create_compute_pipeline(&pipeline_desc)
220 .map_err(|err| match err {
221 hal::PipelineError::Device(error) => {
222 CreateComputePipelineError::Device(device.handle_hal_error(error))
223 }
224 hal::PipelineError::Linkage(_stages, msg) => {
225 CreateComputePipelineError::Internal(msg)
226 }
227 hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
228 crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
229 ),
230 hal::PipelineError::PipelineConstants(_, error) => {
231 CreateComputePipelineError::PipelineConstants(error)
232 }
233 })?;
234
235 Ok(Self {
236 state: Some(InternalState {
237 temporary_bind_group_layout,
238 pipeline_layout,
239 pipeline,
240 }),
241 })
242 }
243 }
244
245 pub unsafe fn create_normalization_bind_group(
250 &self,
251 device: &Device,
252 buffer: &dyn hal::DynBuffer,
253 buffer_label: Option<&str>,
254 buffer_size: wgt::BufferSize,
255 buffer_usages: wgt::BufferUsages,
256 ) -> Result<TimestampNormalizationBindGroup, DeviceError> {
257 unsafe {
258 let Some(ref state) = &self.state else {
259 return Ok(TimestampNormalizationBindGroup { raw: None });
260 };
261
262 if !buffer_usages.contains(wgt::BufferUsages::QUERY_RESOLVE) {
263 return Ok(TimestampNormalizationBindGroup { raw: None });
264 }
265
266 if buffer_size.get() > device.adapter.limits().max_storage_buffer_binding_size as u64 {
271 return Err(DeviceError::OutOfMemory);
272 }
273
274 let bg_label_alloc;
275 let label = match buffer_label {
276 Some(label) => {
277 bg_label_alloc = alloc::format!("Timestamp normalization bind group ({label})");
278 &*bg_label_alloc
279 }
280 None => "Timestamp normalization bind group",
281 };
282
283 let bg = device
284 .raw()
285 .create_bind_group(&hal::BindGroupDescriptor {
286 label: hal_label(Some(label), device.instance_flags),
287 layout: &*state.temporary_bind_group_layout,
288 buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, buffer_size)],
289 samplers: &[],
290 textures: &[],
291 acceleration_structures: &[],
292 external_textures: &[],
293 entries: &[hal::BindGroupEntry {
294 binding: 0,
295 resource_index: 0,
296 count: 1,
297 }],
298 })
299 .map_err(|e| device.handle_hal_error(e))?;
300
301 Ok(TimestampNormalizationBindGroup { raw: Some(bg) })
302 }
303 }
304
305 pub fn normalize(
306 &self,
307 snatch_guard: &SnatchGuard<'_>,
308 encoder: &mut dyn hal::DynCommandEncoder,
309 tracker: &mut BufferTracker,
310 bind_group: &TimestampNormalizationBindGroup,
311 buffer: &Arc<Buffer>,
312 buffer_offset_bytes: u64,
313 total_timestamps: u32,
314 ) {
315 let Some(ref state) = &self.state else {
316 return;
317 };
318
319 let Some(bind_group) = bind_group.raw.as_deref() else {
320 return;
321 };
322
323 let buffer_offset_timestamps: u32 = (buffer_offset_bytes / 8).try_into().unwrap(); let pending_barrier = tracker.set_single(buffer, wgt::BufferUses::STORAGE_READ_WRITE);
326
327 let barrier = pending_barrier.map(|pending| pending.into_hal(buffer, snatch_guard));
328
329 let needed_workgroups = total_timestamps.div_ceil(64);
330
331 unsafe {
332 encoder.transition_buffers(barrier.as_slice());
333 encoder.begin_compute_pass(&hal::ComputePassDescriptor {
334 label: hal_label(
335 Some("Timestamp normalization pass"),
336 buffer.device.instance_flags,
337 ),
338 timestamp_writes: None,
339 });
340 encoder.set_compute_pipeline(&*state.pipeline);
341 encoder.set_bind_group(&*state.pipeline_layout, 0, Some(bind_group), &[]);
342 encoder.set_immediates(
343 &*state.pipeline_layout,
344 0,
345 &[buffer_offset_timestamps, total_timestamps],
346 );
347 encoder.dispatch([needed_workgroups, 1, 1]);
348 encoder.end_compute_pass();
349 }
350 }
351
352 pub fn dispose(self, device: &dyn hal::DynDevice) {
353 unsafe {
354 let Some(state) = self.state else {
355 return;
356 };
357
358 device.destroy_compute_pipeline(state.pipeline);
359 device.destroy_pipeline_layout(state.pipeline_layout);
360 device.destroy_bind_group_layout(state.temporary_bind_group_layout);
361 }
362 }
363
364 pub fn enabled(&self) -> bool {
365 self.state.is_some()
366 }
367}
368
/// Bind group that exposes one query-resolve buffer to the timestamp normalization shader.
#[derive(Debug)]
pub struct TimestampNormalizationBindGroup {
    /// `None` when normalization is disabled or the buffer lacks `QUERY_RESOLVE` usage;
    /// `TimestampNormalizer::normalize` returns early for such a bind group.
    raw: Option<Box<dyn hal::DynBindGroup>>,
}
373
374impl TimestampNormalizationBindGroup {
375 pub fn dispose(self, device: &dyn hal::DynDevice) {
376 unsafe {
377 if let Some(raw) = self.raw {
378 device.destroy_bind_group(raw);
379 }
380 }
381 }
382}
383
/// Converts a timestamp period into a fixed-point `(multiplier, shift)` pair such that
/// `period ~= multiplier / 2^shift`.
///
/// The period is presumably nanoseconds per tick, as reported by the device.
///
/// `shift` is made as large as possible (at most 32) while `multiplier` still fits in a
/// `u32`, which keeps the most precision in the 32-bit multiplier:
/// - periods below `1.0` always use `shift = 32`;
/// - otherwise `shift = 32 - k`, where `2^k` is the smallest power of two strictly greater
///   than the period, so exact powers of two are represented exactly.
///
/// Periods of `2^32` or more use `shift = 0` with a saturated multiplier. Zero, negative and
/// NaN periods yield a multiplier of `0`.
fn compute_timestamp_period(input: f32) -> (u32, u32) {
    let period = f64::from(input);

    // Number of integer bits the period needs: the smallest `bits` with
    // `period < 2^bits`. This uses exact comparisons rather than `log2().ceil()`. A
    // floating-point log2 can round up to an exact integer for periods just above a power
    // of two, and that overflowed the multiplier. Comparisons against NaN/inf are false,
    // so those fall back to 32.
    let integer_bits = (0..=32u32)
        .find(|&bits| period < (1u64 << bits) as f64)
        .unwrap_or(32);
    let shift = 32 - integer_bits;

    let denominator = (1u64 << shift) as f64;

    // float -> int `as` casts saturate (and map NaN to 0), so out-of-range inputs cannot wrap.
    let multiplier = (period * denominator).round() as u32;

    (multiplier, shift)
}
396
#[cfg(test)]
mod tests {
    use core::f64;

    /// Checks that the `(multiplier, shift)` pair computed for `period` reconstructs
    /// `period` to within an absolute error of 1e-7.
    fn assert_timestamp_case(period: f32) {
        let (multiplier, shift) = super::compute_timestamp_period(period);

        let scale = (1u64 << shift) as f64;
        let reconstructed = multiplier as f64 / scale;
        let error = (period as f64 - reconstructed).abs();

        assert!(error < 0.0000001);
    }

    #[test]
    fn compute_timestamp_period() {
        for period in [0.01, 0.5, 1.0, 2.0, 2.7, 1000.7] {
            assert_timestamp_case(period);
        }
    }
}