wgpu_core/timestamp_normalization/
mod.rs1use core::num::NonZeroU64;
31
32use alloc::{boxed::Box, string::String, string::ToString, sync::Arc};
33
34use hashbrown::HashMap;
35
36use crate::{
37 device::{Device, DeviceError},
38 hal_label,
39 pipeline::{CreateComputePipelineError, CreateShaderModuleError},
40 resource::Buffer,
41 snatch::SnatchGuard,
42 track::BufferTracker,
43};
44
45pub const TIMESTAMP_NORMALIZATION_BUFFER_USES: wgt::BufferUses =
46 wgt::BufferUses::STORAGE_READ_WRITE;
47
48struct InternalState {
49 temporary_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
50 pipeline_layout: Box<dyn hal::DynPipelineLayout>,
51 pipeline: Box<dyn hal::DynComputePipeline>,
52}
53
54#[derive(Debug, Clone, thiserror::Error)]
55pub enum TimestampNormalizerInitError {
56 #[error("Failed to initialize bind group layout")]
57 BindGroupLayout(#[source] DeviceError),
58 #[cfg(feature = "wgsl")]
59 #[error("Failed to parse shader")]
60 ParseWgsl(#[source] naga::error::ShaderError<naga::front::wgsl::ParseError>),
61 #[error("Failed to validate shader module")]
62 ValidateWgsl(#[source] naga::error::ShaderError<naga::WithSpan<naga::valid::ValidationError>>),
63 #[error("Failed to create shader module")]
64 CreateShaderModule(#[from] CreateShaderModuleError),
65 #[error("Failed to create pipeline layout")]
66 PipelineLayout(#[source] DeviceError),
67 #[error("Failed to create compute pipeline")]
68 ComputePipeline(#[from] CreateComputePipelineError),
69}
70
71pub struct TimestampNormalizer {
74 state: Option<InternalState>,
75}
76
77impl TimestampNormalizer {
78 pub fn new(
87 device: &Device,
88 timestamp_period: f32,
89 ) -> Result<Self, TimestampNormalizerInitError> {
90 unsafe {
91 if !device
92 .instance_flags
93 .contains(wgt::InstanceFlags::AUTOMATIC_TIMESTAMP_NORMALIZATION)
94 {
95 return Ok(Self { state: None });
96 }
97
98 if !device
99 .downlevel
100 .flags
101 .contains(wgt::DownlevelFlags::COMPUTE_SHADERS)
102 {
103 log::error!("Automatic timestamp normalization was requested, but compute shaders are not supported.");
104 return Ok(Self { state: None });
105 }
106
107 if timestamp_period == 1.0 {
108 return Ok(Self { state: None });
110 }
111
112 let temporary_bind_group_layout = device
113 .raw()
114 .create_bind_group_layout(&hal::BindGroupLayoutDescriptor {
115 label: hal_label(
116 Some("(wgpu internal) Timestamp Normalization Bind Group Layout"),
117 device.instance_flags,
118 ),
119 flags: hal::BindGroupLayoutFlags::empty(),
120 entries: &[wgt::BindGroupLayoutEntry {
121 binding: 0,
122 visibility: wgt::ShaderStages::COMPUTE,
123 ty: wgt::BindingType::Buffer {
124 ty: wgt::BufferBindingType::Storage { read_only: false },
125 has_dynamic_offset: false,
126 min_binding_size: Some(NonZeroU64::new(8).unwrap()),
127 },
128 count: None,
129 }],
130 })
131 .map_err(|e| {
132 TimestampNormalizerInitError::BindGroupLayout(device.handle_hal_error(e))
133 })?;
134
135 let common_src = include_str!("common.wgsl");
136 let src = include_str!("timestamp_normalization.wgsl");
137
138 let preprocessed_src = alloc::format!("{common_src}\n{src}");
139
140 #[cfg(feature = "wgsl")]
141 let module = naga::front::wgsl::parse_str(&preprocessed_src).map_err(|inner| {
142 TimestampNormalizerInitError::ParseWgsl(naga::error::ShaderError {
143 source: preprocessed_src.clone(),
144 label: None,
145 inner: Box::new(inner),
146 })
147 })?;
148 #[cfg(not(feature = "wgsl"))]
149 #[allow(clippy::diverging_sub_expression)]
150 let module =
151 panic!("Timestamp normalization requires the wgsl feature flag to be enabled!");
152
153 let info = crate::device::create_validator(
154 wgt::Features::IMMEDIATES,
155 wgt::DownlevelFlags::empty(),
156 naga::valid::ValidationFlags::all(),
157 )
158 .validate(&module)
159 .map_err(|inner| {
160 TimestampNormalizerInitError::ValidateWgsl(naga::error::ShaderError {
161 source: preprocessed_src.clone(),
162 label: None,
163 inner: Box::new(inner),
164 })
165 })?;
166 let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
167 module: alloc::borrow::Cow::Owned(module),
168 info,
169 debug_source: None,
170 });
171 let hal_desc = hal::ShaderModuleDescriptor {
172 label: hal_label(
173 Some("(wgpu internal) Timestamp normalizer shader module"),
174 device.instance_flags,
175 ),
176 runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
177 };
178 let module = device
179 .raw()
180 .create_shader_module(&hal_desc, hal_shader)
181 .map_err(|error| match error {
182 hal::ShaderError::Device(error) => {
183 CreateShaderModuleError::Device(device.handle_hal_error(error))
184 }
185 hal::ShaderError::Compilation(ref msg) => {
186 log::error!("Shader error: {msg}");
187 CreateShaderModuleError::Generation
188 }
189 })?;
190
191 let pipeline_layout = device
192 .raw()
193 .create_pipeline_layout(&hal::PipelineLayoutDescriptor {
194 label: hal_label(
195 Some("(wgpu internal) Timestamp normalizer pipeline layout"),
196 device.instance_flags,
197 ),
198 bind_group_layouts: &[Some(temporary_bind_group_layout.as_ref())],
199 immediate_size: 8,
200 flags: hal::PipelineLayoutFlags::empty(),
201 })
202 .map_err(|e| {
203 TimestampNormalizerInitError::PipelineLayout(device.handle_hal_error(e))
204 })?;
205
206 let (multiplier, shift) = compute_timestamp_period(timestamp_period);
207
208 let mut constants = HashMap::with_capacity(2);
209 constants.insert(String::from("TIMESTAMP_PERIOD_MULTIPLY"), multiplier as f64);
210 constants.insert(String::from("TIMESTAMP_PERIOD_SHIFT"), shift as f64);
211
212 let pipeline_desc = hal::ComputePipelineDescriptor {
213 label: hal_label(
214 Some("(wgpu internal) Timestamp normalizer pipeline"),
215 device.instance_flags,
216 ),
217 layout: pipeline_layout.as_ref(),
218 stage: hal::ProgrammableStage {
219 module: module.as_ref(),
220 entry_point: "main",
221 constants: &constants,
222 zero_initialize_workgroup_memory: false,
223 },
224 cache: None,
225 };
226 let pipeline = device
227 .raw()
228 .create_compute_pipeline(&pipeline_desc)
229 .map_err(|err| match err {
230 hal::PipelineError::Device(error) => {
231 CreateComputePipelineError::Device(device.handle_hal_error(error))
232 }
233 hal::PipelineError::Linkage(_stages, msg) => {
234 CreateComputePipelineError::Internal(msg)
235 }
236 hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
237 crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
238 ),
239 hal::PipelineError::PipelineConstants(_, error) => {
240 CreateComputePipelineError::PipelineConstants(error)
241 }
242 })?;
243
244 Ok(Self {
245 state: Some(InternalState {
246 temporary_bind_group_layout,
247 pipeline_layout,
248 pipeline,
249 }),
250 })
251 }
252 }
253
254 pub unsafe fn create_normalization_bind_group(
259 &self,
260 device: &Device,
261 buffer: &dyn hal::DynBuffer,
262 buffer_label: Option<&str>,
263 buffer_size: wgt::BufferSize,
264 buffer_usages: wgt::BufferUsages,
265 ) -> Result<TimestampNormalizationBindGroup, DeviceError> {
266 unsafe {
267 let Some(ref state) = &self.state else {
268 return Ok(TimestampNormalizationBindGroup { raw: None });
269 };
270
271 if !buffer_usages.contains(wgt::BufferUsages::QUERY_RESOLVE) {
272 return Ok(TimestampNormalizationBindGroup { raw: None });
273 }
274
275 if buffer_size.get() > device.adapter.limits().max_storage_buffer_binding_size as u64 {
280 return Err(DeviceError::OutOfMemory);
281 }
282
283 let bg_label_alloc;
284 let label = match buffer_label {
285 Some(label) => {
286 bg_label_alloc = alloc::format!("Timestamp normalization bind group ({label})");
287 &*bg_label_alloc
288 }
289 None => "Timestamp normalization bind group",
290 };
291
292 let bg = device
293 .raw()
294 .create_bind_group(&hal::BindGroupDescriptor {
295 label: hal_label(Some(label), device.instance_flags),
296 layout: &*state.temporary_bind_group_layout,
297 buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, buffer_size)],
298 samplers: &[],
299 textures: &[],
300 acceleration_structures: &[],
301 external_textures: &[],
302 entries: &[hal::BindGroupEntry {
303 binding: 0,
304 resource_index: 0,
305 count: 1,
306 }],
307 })
308 .map_err(|e| device.handle_hal_error(e))?;
309
310 Ok(TimestampNormalizationBindGroup { raw: Some(bg) })
311 }
312 }
313
314 pub fn normalize(
315 &self,
316 snatch_guard: &SnatchGuard<'_>,
317 encoder: &mut dyn hal::DynCommandEncoder,
318 tracker: &mut BufferTracker,
319 bind_group: &TimestampNormalizationBindGroup,
320 buffer: &Arc<Buffer>,
321 buffer_offset_bytes: u64,
322 total_timestamps: u32,
323 ) {
324 let Some(ref state) = &self.state else {
325 return;
326 };
327
328 let Some(bind_group) = bind_group.raw.as_deref() else {
329 return;
330 };
331
332 let buffer_offset_timestamps: u32 = (buffer_offset_bytes / 8).try_into().unwrap(); let pending_barrier = tracker.set_single(buffer, wgt::BufferUses::STORAGE_READ_WRITE);
335
336 let barrier = pending_barrier.map(|pending| pending.into_hal(buffer, snatch_guard));
337
338 let needed_workgroups = total_timestamps.div_ceil(64);
339
340 unsafe {
341 encoder.transition_buffers(barrier.as_slice());
342 encoder.begin_compute_pass(&hal::ComputePassDescriptor {
343 label: hal_label(
344 Some("(wgpu internal) Timestamp normalization pass"),
345 buffer.device.instance_flags,
346 ),
347 timestamp_writes: None,
348 });
349 encoder.set_compute_pipeline(&*state.pipeline);
350 encoder.set_bind_group(&*state.pipeline_layout, 0, bind_group, &[]);
351 encoder.set_immediates(
352 &*state.pipeline_layout,
353 0,
354 &[buffer_offset_timestamps, total_timestamps],
355 );
356 encoder.dispatch([needed_workgroups, 1, 1]);
357 encoder.end_compute_pass();
358 }
359 }
360
361 pub fn dispose(self, device: &dyn hal::DynDevice) {
362 unsafe {
363 let Some(state) = self.state else {
364 return;
365 };
366
367 device.destroy_compute_pipeline(state.pipeline);
368 device.destroy_pipeline_layout(state.pipeline_layout);
369 device.destroy_bind_group_layout(state.temporary_bind_group_layout);
370 }
371 }
372
373 pub fn enabled(&self) -> bool {
374 self.state.is_some()
375 }
376}
377
378#[derive(Debug)]
379pub struct TimestampNormalizationBindGroup {
380 raw: Option<Box<dyn hal::DynBindGroup>>,
381}
382
383impl TimestampNormalizationBindGroup {
384 pub fn dispose(self, device: &dyn hal::DynDevice) {
385 unsafe {
386 if let Some(raw) = self.raw {
387 device.destroy_bind_group(raw);
388 }
389 }
390 }
391}
392
393fn compute_timestamp_period(input: f32) -> (u32, u32) {
394 let pow2 = input.log2().ceil() as i32;
395 let clamped_pow2 = pow2.clamp(-32, 32).unsigned_abs();
396 let shift = 32 - clamped_pow2;
397
398 let denominator = (1u64 << shift) as f64;
399
400 let multiplier = (input as f64 * denominator).round() as u32;
402
403 (multiplier, shift)
404}
405
#[cfg(test)]
mod tests {
    /// Asserts that `compute_timestamp_period(input)` returns a shift within
    /// the `0..=32` range and reconstructs `input` to within 1e-7.
    fn assert_timestamp_case(input: f32) {
        let (multiplier, shift) = super::compute_timestamp_period(input);
        assert!(shift <= 32, "shift {shift} out of range for input {input}");

        let output = multiplier as f64 / (1u64 << shift) as f64;

        assert!(
            (input as f64 - output).abs() < 0.0000001,
            "input {input} reconstructed as {output} (multiplier {multiplier}, shift {shift})"
        );
    }

    #[test]
    fn compute_timestamp_period() {
        assert_timestamp_case(0.01);
        assert_timestamp_case(0.5);
        assert_timestamp_case(1.0);
        assert_timestamp_case(2.0);
        assert_timestamp_case(2.7);
        assert_timestamp_case(1000.7);
    }

    #[test]
    fn compute_timestamp_period_powers_of_two() {
        // Exact powers of two sit on the boundary where `input * 2^shift`
        // can reach 2^32.
        for input in [0.25, 4.0, 16.0, 64.0] {
            assert_timestamp_case(input);
        }
    }

    #[test]
    fn compute_timestamp_period_typical_values() {
        // Non-unit periods of the magnitude real devices report.
        for input in [10.0, 40.0, 52.083332, 83.333336] {
            assert_timestamp_case(input);
        }
    }
}