// wgpu_core/timestamp_normalization/mod.rs
1use core::num::NonZeroU64;
31
32use alloc::{boxed::Box, string::String, string::ToString, sync::Arc};
33
34use hashbrown::HashMap;
35use wgt::PushConstantRange;
36
37use crate::{
38 device::{Device, DeviceError},
39 hal_label,
40 pipeline::{CreateComputePipelineError, CreateShaderModuleError},
41 resource::Buffer,
42 snatch::SnatchGuard,
43 track::BufferTracker,
44};
45
46pub const TIMESTAMP_NORMALIZATION_BUFFER_USES: wgt::BufferUses =
47 wgt::BufferUses::STORAGE_READ_WRITE;
48
49struct InternalState {
50 temporary_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
51 pipeline_layout: Box<dyn hal::DynPipelineLayout>,
52 pipeline: Box<dyn hal::DynComputePipeline>,
53}
54
55#[derive(Debug, Clone, thiserror::Error)]
56pub enum TimestampNormalizerInitError {
57 #[error("Failed to initialize bind group layout")]
58 BindGroupLayout(#[source] DeviceError),
59 #[cfg(feature = "wgsl")]
60 #[error("Failed to parse shader")]
61 ParseWgsl(#[source] naga::error::ShaderError<naga::front::wgsl::ParseError>),
62 #[error("Failed to validate shader module")]
63 ValidateWgsl(#[source] naga::error::ShaderError<naga::WithSpan<naga::valid::ValidationError>>),
64 #[error("Failed to create shader module")]
65 CreateShaderModule(#[from] CreateShaderModuleError),
66 #[error("Failed to create pipeline layout")]
67 PipelineLayout(#[source] DeviceError),
68 #[error("Failed to create compute pipeline")]
69 ComputePipeline(#[from] CreateComputePipelineError),
70}
71
72pub struct TimestampNormalizer {
75 state: Option<InternalState>,
76}
77
78impl TimestampNormalizer {
79 pub fn new(
88 device: &Device,
89 timestamp_period: f32,
90 ) -> Result<Self, TimestampNormalizerInitError> {
91 unsafe {
92 if !device
93 .instance_flags
94 .contains(wgt::InstanceFlags::AUTOMATIC_TIMESTAMP_NORMALIZATION)
95 {
96 return Ok(Self { state: None });
97 }
98
99 if !device
100 .downlevel
101 .flags
102 .contains(wgt::DownlevelFlags::COMPUTE_SHADERS)
103 {
104 log::error!("Automatic timestamp normalization was requested, but compute shaders are not supported.");
105 return Ok(Self { state: None });
106 }
107
108 if timestamp_period == 1.0 {
109 return Ok(Self { state: None });
111 }
112
113 let temporary_bind_group_layout = device
114 .raw()
115 .create_bind_group_layout(&hal::BindGroupLayoutDescriptor {
116 label: hal_label(
117 Some("Timestamp Normalization Bind Group Layout"),
118 device.instance_flags,
119 ),
120 flags: hal::BindGroupLayoutFlags::empty(),
121 entries: &[wgt::BindGroupLayoutEntry {
122 binding: 0,
123 visibility: wgt::ShaderStages::COMPUTE,
124 ty: wgt::BindingType::Buffer {
125 ty: wgt::BufferBindingType::Storage { read_only: false },
126 has_dynamic_offset: false,
127 min_binding_size: Some(NonZeroU64::new(8).unwrap()),
128 },
129 count: None,
130 }],
131 })
132 .map_err(|e| {
133 TimestampNormalizerInitError::BindGroupLayout(device.handle_hal_error(e))
134 })?;
135
136 let common_src = include_str!("common.wgsl");
137 let src = include_str!("timestamp_normalization.wgsl");
138
139 let preprocessed_src = alloc::format!("{common_src}\n{src}");
140
141 #[cfg(feature = "wgsl")]
142 let module = naga::front::wgsl::parse_str(&preprocessed_src).map_err(|inner| {
143 TimestampNormalizerInitError::ParseWgsl(naga::error::ShaderError {
144 source: preprocessed_src.clone(),
145 label: None,
146 inner: Box::new(inner),
147 })
148 })?;
149 #[cfg(not(feature = "wgsl"))]
150 #[allow(clippy::diverging_sub_expression)]
151 let module =
152 panic!("Timestamp normalization requires the wgsl feature flag to be enabled!");
153
154 let info = crate::device::create_validator(
155 wgt::Features::PUSH_CONSTANTS,
156 wgt::DownlevelFlags::empty(),
157 naga::valid::ValidationFlags::all(),
158 )
159 .validate(&module)
160 .map_err(|inner| {
161 TimestampNormalizerInitError::ValidateWgsl(naga::error::ShaderError {
162 source: preprocessed_src.clone(),
163 label: None,
164 inner: Box::new(inner),
165 })
166 })?;
167 let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
168 module: alloc::borrow::Cow::Owned(module),
169 info,
170 debug_source: None,
171 });
172 let hal_desc = hal::ShaderModuleDescriptor {
173 label: None,
174 runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
175 };
176 let module = device
177 .raw()
178 .create_shader_module(&hal_desc, hal_shader)
179 .map_err(|error| match error {
180 hal::ShaderError::Device(error) => {
181 CreateShaderModuleError::Device(device.handle_hal_error(error))
182 }
183 hal::ShaderError::Compilation(ref msg) => {
184 log::error!("Shader error: {msg}");
185 CreateShaderModuleError::Generation
186 }
187 })?;
188
189 let pipeline_layout = device
190 .raw()
191 .create_pipeline_layout(&hal::PipelineLayoutDescriptor {
192 label: None,
193 bind_group_layouts: &[temporary_bind_group_layout.as_ref()],
194 push_constant_ranges: &[PushConstantRange {
195 stages: wgt::ShaderStages::COMPUTE,
196 range: 0..8,
197 }],
198 flags: hal::PipelineLayoutFlags::empty(),
199 })
200 .map_err(|e| {
201 TimestampNormalizerInitError::PipelineLayout(device.handle_hal_error(e))
202 })?;
203
204 let (multiplier, shift) = compute_timestamp_period(timestamp_period);
205
206 let mut constants = HashMap::with_capacity(2);
207 constants.insert(String::from("TIMESTAMP_PERIOD_MULTIPLY"), multiplier as f64);
208 constants.insert(String::from("TIMESTAMP_PERIOD_SHIFT"), shift as f64);
209
210 let pipeline_desc = hal::ComputePipelineDescriptor {
211 label: None,
212 layout: pipeline_layout.as_ref(),
213 stage: hal::ProgrammableStage {
214 module: module.as_ref(),
215 entry_point: "main",
216 constants: &constants,
217 zero_initialize_workgroup_memory: false,
218 },
219 cache: None,
220 };
221 let pipeline = device
222 .raw()
223 .create_compute_pipeline(&pipeline_desc)
224 .map_err(|err| match err {
225 hal::PipelineError::Device(error) => {
226 CreateComputePipelineError::Device(device.handle_hal_error(error))
227 }
228 hal::PipelineError::Linkage(_stages, msg) => {
229 CreateComputePipelineError::Internal(msg)
230 }
231 hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
232 crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
233 ),
234 hal::PipelineError::PipelineConstants(_, error) => {
235 CreateComputePipelineError::PipelineConstants(error)
236 }
237 })?;
238
239 Ok(Self {
240 state: Some(InternalState {
241 temporary_bind_group_layout,
242 pipeline_layout,
243 pipeline,
244 }),
245 })
246 }
247 }
248
249 pub unsafe fn create_normalization_bind_group(
254 &self,
255 device: &Device,
256 buffer: &dyn hal::DynBuffer,
257 buffer_label: Option<&str>,
258 buffer_size: wgt::BufferSize,
259 buffer_usages: wgt::BufferUsages,
260 ) -> Result<TimestampNormalizationBindGroup, DeviceError> {
261 unsafe {
262 let Some(ref state) = &self.state else {
263 return Ok(TimestampNormalizationBindGroup { raw: None });
264 };
265
266 if !buffer_usages.contains(wgt::BufferUsages::QUERY_RESOLVE) {
267 return Ok(TimestampNormalizationBindGroup { raw: None });
268 }
269
270 if buffer_size.get() > device.adapter.limits().max_storage_buffer_binding_size as u64 {
275 return Err(DeviceError::OutOfMemory);
276 }
277
278 let bg_label_alloc;
279 let label = match buffer_label {
280 Some(label) => {
281 bg_label_alloc = alloc::format!("Timestamp normalization bind group ({label})");
282 &*bg_label_alloc
283 }
284 None => "Timestamp normalization bind group",
285 };
286
287 let bg = device
288 .raw()
289 .create_bind_group(&hal::BindGroupDescriptor {
290 label: hal_label(Some(label), device.instance_flags),
291 layout: &*state.temporary_bind_group_layout,
292 buffers: &[hal::BufferBinding::new_unchecked(buffer, 0, buffer_size)],
293 samplers: &[],
294 textures: &[],
295 acceleration_structures: &[],
296 external_textures: &[],
297 entries: &[hal::BindGroupEntry {
298 binding: 0,
299 resource_index: 0,
300 count: 1,
301 }],
302 })
303 .map_err(|e| device.handle_hal_error(e))?;
304
305 Ok(TimestampNormalizationBindGroup { raw: Some(bg) })
306 }
307 }
308
309 pub fn normalize(
310 &self,
311 snatch_guard: &SnatchGuard<'_>,
312 encoder: &mut dyn hal::DynCommandEncoder,
313 tracker: &mut BufferTracker,
314 bind_group: &TimestampNormalizationBindGroup,
315 buffer: &Arc<Buffer>,
316 buffer_offset_bytes: u64,
317 total_timestamps: u32,
318 ) {
319 let Some(ref state) = &self.state else {
320 return;
321 };
322
323 let Some(bind_group) = bind_group.raw.as_deref() else {
324 return;
325 };
326
327 let buffer_offset_timestamps: u32 = (buffer_offset_bytes / 8).try_into().unwrap(); let pending_barrier = tracker.set_single(buffer, wgt::BufferUses::STORAGE_READ_WRITE);
330
331 let barrier = pending_barrier.map(|pending| pending.into_hal(buffer, snatch_guard));
332
333 let needed_workgroups = total_timestamps.div_ceil(64);
334
335 unsafe {
336 encoder.transition_buffers(barrier.as_slice());
337 encoder.begin_compute_pass(&hal::ComputePassDescriptor {
338 label: hal_label(
339 Some("Timestamp normalization pass"),
340 buffer.device.instance_flags,
341 ),
342 timestamp_writes: None,
343 });
344 encoder.set_compute_pipeline(&*state.pipeline);
345 encoder.set_bind_group(&*state.pipeline_layout, 0, Some(bind_group), &[]);
346 encoder.set_push_constants(
347 &*state.pipeline_layout,
348 wgt::ShaderStages::COMPUTE,
349 0,
350 &[buffer_offset_timestamps, total_timestamps],
351 );
352 encoder.dispatch([needed_workgroups, 1, 1]);
353 encoder.end_compute_pass();
354 }
355 }
356
357 pub fn dispose(self, device: &dyn hal::DynDevice) {
358 unsafe {
359 let Some(state) = self.state else {
360 return;
361 };
362
363 device.destroy_compute_pipeline(state.pipeline);
364 device.destroy_pipeline_layout(state.pipeline_layout);
365 device.destroy_bind_group_layout(state.temporary_bind_group_layout);
366 }
367 }
368
369 pub fn enabled(&self) -> bool {
370 self.state.is_some()
371 }
372}
373
374#[derive(Debug)]
375pub struct TimestampNormalizationBindGroup {
376 raw: Option<Box<dyn hal::DynBindGroup>>,
377}
378
379impl TimestampNormalizationBindGroup {
380 pub fn dispose(self, device: &dyn hal::DynDevice) {
381 unsafe {
382 if let Some(raw) = self.raw {
383 device.destroy_bind_group(raw);
384 }
385 }
386 }
387}
388
/// Encodes `input` as a fixed-point pair `(multiplier, shift)` such that
/// `multiplier as f64 / 2^shift` approximates `input`.
///
/// `shift` is the largest value in `0..=32` for which `input * 2^shift` still
/// fits in a `u32`. This keeps as many significant bits in `multiplier` as
/// possible:
/// - periods below `1.0` use `shift == 32`;
/// - exact powers of two are represented exactly (e.g. `2.0 -> (2^31, 30)`);
/// - periods of `2^32` or more saturate `multiplier` at `u32::MAX` with
///   `shift == 0`.
///
/// NaN, zero and negative inputs yield `multiplier == 0`.
fn compute_timestamp_period(input: f32) -> (u32, u32) {
    // 2^32, the first value that no longer fits in a u32.
    const U32_LIMIT: f64 = 4_294_967_296.0;

    let input = f64::from(input);
    let scaled = |shift: u32| input * (1u64 << shift) as f64;

    // Search from the most precise shift downwards. Testing the scaled value
    // directly, instead of deriving the shift from `log2().ceil()`, avoids two
    // problems. One is saturating the multiplier at exact powers of two. The
    // other is needlessly shrinking the shift (and so the precision) for
    // periods below 1.
    let shift = (0..=32u32)
        .rev()
        .find(|&shift| scaled(shift) < U32_LIMIT)
        .unwrap_or(0);

    // For finite f32 inputs the scaled value is either an exact integer or far
    // below 2^32, so rounding cannot overflow; `as` saturates otherwise.
    let multiplier = scaled(shift).round() as u32;

    (multiplier, shift)
}
401
#[cfg(test)]
mod tests {
    /// Asserts that the fixed-point encoding of `input` decodes back to within
    /// `1e-7` of the original value.
    fn assert_timestamp_case(input: f32) {
        let (multiplier, shift) = super::compute_timestamp_period(input);
        let decoded = multiplier as f64 / (1u64 << shift) as f64;
        let error = (input as f64 - decoded).abs();
        assert!(error < 0.0000001);
    }

    #[test]
    fn compute_timestamp_period() {
        for input in [0.01, 0.5, 1.0, 2.0, 2.7, 1000.7] {
            assert_timestamp_case(input);
        }
    }
}