1use alloc::{sync::Arc, vec, vec::Vec};
2use core::{iter, mem};
3
4use crate::{
5 command::{encoder::EncodingState, ArcCommand, EncoderStateError},
6 device::{Device, DeviceError, MissingFeatures},
7 global::Global,
8 id,
9 init_tracker::MemoryInitKind,
10 resource::{
11 Buffer, DestroyedResourceError, InvalidResourceError, MissingBufferUsageError,
12 ParentDevice, QuerySet, RawResourceAccess, Trackable,
13 },
14 track::{StatelessTracker, TrackerIndex},
15 FastHashMap,
16};
17use thiserror::Error;
18use wgt::{
19 error::{ErrorType, WebGpuError},
20 BufferAddress,
21};
22
23#[derive(Debug)]
24pub(crate) struct QueryResetMap {
25 map: FastHashMap<TrackerIndex, (Vec<bool>, Arc<QuerySet>)>,
26}
27impl QueryResetMap {
28 pub fn new() -> Self {
29 Self {
30 map: FastHashMap::default(),
31 }
32 }
33
34 pub fn use_query_set(&mut self, query_set: &Arc<QuerySet>, query: u32) -> bool {
35 let vec_pair = self
36 .map
37 .entry(query_set.tracker_index())
38 .or_insert_with(|| {
39 (
40 vec![false; query_set.desc.count as usize],
41 query_set.clone(),
42 )
43 });
44
45 mem::replace(&mut vec_pair.0[query as usize], true)
46 }
47
48 pub fn reset_queries(&mut self, raw_encoder: &mut dyn hal::DynCommandEncoder) {
49 for (_, (state, query_set)) in self.map.drain() {
50 debug_assert_eq!(state.len(), query_set.desc.count as usize);
51
52 let mut run_start: Option<u32> = None;
56 for (idx, value) in state.into_iter().chain(iter::once(false)).enumerate() {
57 match (run_start, value) {
58 (Some(..), true) => {}
60 (Some(start), false) => {
62 run_start = None;
63 unsafe { raw_encoder.reset_queries(query_set.raw(), start..idx as u32) };
64 }
65 (None, true) => {
67 run_start = Some(idx as u32);
68 }
69 (None, false) => {}
71 }
72 }
73 }
74 }
75}
76
77#[derive(Debug, Copy, Clone, PartialEq, Eq)]
78pub enum SimplifiedQueryType {
79 Occlusion,
80 Timestamp,
81 PipelineStatistics,
82}
83impl From<wgt::QueryType> for SimplifiedQueryType {
84 fn from(q: wgt::QueryType) -> Self {
85 match q {
86 wgt::QueryType::Occlusion => SimplifiedQueryType::Occlusion,
87 wgt::QueryType::Timestamp => SimplifiedQueryType::Timestamp,
88 wgt::QueryType::PipelineStatistics(..) => SimplifiedQueryType::PipelineStatistics,
89 }
90 }
91}
92
/// Top-level error for query-related commands on a command encoder.
///
/// In this file it is returned by `write_timestamp` and `resolve_query_set`, and is the
/// error type of the closures passed to `push_with` by the `Global` query entry points.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum QueryError {
    /// A device-level error, e.g. a failed `same_device` check.
    #[error(transparent)]
    Device(#[from] DeviceError),
    /// The command encoder is not in a state that accepts this command.
    #[error(transparent)]
    EncoderState(#[from] EncoderStateError),
    /// A device feature required by the command is not enabled.
    #[error(transparent)]
    MissingFeature(#[from] MissingFeatures),
    /// Validation of an individual query use failed.
    #[error("Error encountered while trying to use queries")]
    Use(#[from] QueryUseError),
    /// Validation of a resolve command failed.
    #[error("Error encountered while trying to resolve a query")]
    Resolve(#[from] ResolveError),
    /// A referenced resource has been destroyed.
    #[error(transparent)]
    DestroyedResource(#[from] DestroyedResourceError),
    /// A referenced resource ID did not resolve to a valid resource.
    #[error(transparent)]
    InvalidResource(#[from] InvalidResourceError),
}
112
113impl WebGpuError for QueryError {
114 fn webgpu_error_type(&self) -> ErrorType {
115 let e: &dyn WebGpuError = match self {
116 Self::EncoderState(e) => e,
117 Self::Use(e) => e,
118 Self::Resolve(e) => e,
119 Self::InvalidResource(e) => e,
120 Self::Device(e) => e,
121 Self::MissingFeature(e) => e,
122 Self::DestroyedResource(e) => e,
123 };
124 e.webgpu_error_type()
125 }
126}
127
/// Errors from validating an individual use of a query: beginning, ending, or writing into a
/// query at a particular index of a query set.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum QueryUseError {
    /// A device-level error, e.g. a failed `same_device` check.
    #[error(transparent)]
    Device(#[from] DeviceError),
    /// `query_index` is not less than the number of queries in the set.
    #[error("Query {query_index} is out of bounds for a query set of size {query_set_size}")]
    OutOfBounds {
        query_index: u32,
        query_set_size: u32,
    },
    /// The query was already recorded in the `QueryResetMap` passed to `validate_query`.
    #[error("Query {query_index} has already been used within the same renderpass. Queries must only be used once per renderpass")]
    UsedTwiceInsideRenderpass { query_index: u32 },
    /// A query was begun while another query was still active.
    #[error("Query {new_query_index} was started while query {active_query_index} was already active. No more than one statistic or occlusion query may be active at once")]
    AlreadyStarted {
        active_query_index: u32,
        new_query_index: u32,
    },
    /// A query was ended while no query was active.
    #[error("Query was stopped while there was no active query")]
    AlreadyStopped,
    /// The query set's type does not match the kind of query being used.
    #[error("A query of type {query_type:?} was started using a query set of type {set_type:?}")]
    IncompatibleType {
        set_type: SimplifiedQueryType,
        query_type: SimplifiedQueryType,
    },
}
154
155impl WebGpuError for QueryUseError {
156 fn webgpu_error_type(&self) -> ErrorType {
157 match self {
158 Self::Device(e) => e.webgpu_error_type(),
159 Self::OutOfBounds { .. }
160 | Self::UsedTwiceInsideRenderpass { .. }
161 | Self::AlreadyStarted { .. }
162 | Self::AlreadyStopped
163 | Self::IncompatibleType { .. } => ErrorType::Validation,
164 }
165 }
166}
167
168#[derive(Clone, Debug, Error)]
170#[non_exhaustive]
171pub enum ResolveError {
172 #[error(transparent)]
173 MissingBufferUsage(#[from] MissingBufferUsageError),
174 #[error("Resolve buffer offset has to be aligned to `QUERY_RESOLVE_BUFFER_ALIGNMENT")]
175 BufferOffsetAlignment,
176 #[error("Resolving queries {start_query}..{end_query} would overrun the query set of size {query_set_size}")]
177 QueryOverrun {
178 start_query: u32,
179 end_query: u64,
180 query_set_size: u32,
181 },
182 #[error("Resolving queries {start_query}..{end_query} ({stride} byte queries) will end up overrunning the bounds of the destination buffer of size {buffer_size} using offsets {buffer_start_offset}..(<start> + {bytes_used})")]
183 BufferOverrun {
184 start_query: u32,
185 end_query: u32,
186 stride: u32,
187 buffer_size: BufferAddress,
188 buffer_start_offset: BufferAddress,
189 bytes_used: BufferAddress,
190 },
191}
192
193impl WebGpuError for ResolveError {
194 fn webgpu_error_type(&self) -> ErrorType {
195 match self {
196 Self::MissingBufferUsage(e) => e.webgpu_error_type(),
197 Self::BufferOffsetAlignment
198 | Self::QueryOverrun { .. }
199 | Self::BufferOverrun { .. } => ErrorType::Validation,
200 }
201 }
202}
203
204impl QuerySet {
205 pub(crate) fn validate_query(
206 self: &Arc<Self>,
207 query_type: SimplifiedQueryType,
208 query_index: u32,
209 reset_state: Option<&mut QueryResetMap>,
210 ) -> Result<(), QueryUseError> {
211 if query_index >= self.desc.count {
213 return Err(QueryUseError::OutOfBounds {
214 query_index,
215 query_set_size: self.desc.count,
216 });
217 }
218
219 if let Some(reset) = reset_state {
222 let used = reset.use_query_set(self, query_index);
223 if used {
224 return Err(QueryUseError::UsedTwiceInsideRenderpass { query_index });
225 }
226 }
227
228 let simple_set_type = SimplifiedQueryType::from(self.desc.ty);
229 if simple_set_type != query_type {
230 return Err(QueryUseError::IncompatibleType {
231 query_type,
232 set_type: simple_set_type,
233 });
234 }
235
236 Ok(())
237 }
238
239 pub(super) fn validate_and_write_timestamp(
240 self: &Arc<Self>,
241 raw_encoder: &mut dyn hal::DynCommandEncoder,
242 query_index: u32,
243 reset_state: Option<&mut QueryResetMap>,
244 ) -> Result<(), QueryUseError> {
245 let needs_reset = reset_state.is_none();
246 self.validate_query(SimplifiedQueryType::Timestamp, query_index, reset_state)?;
247
248 unsafe {
249 if needs_reset {
251 raw_encoder.reset_queries(self.raw(), query_index..(query_index + 1));
252 }
253 raw_encoder.write_timestamp(self.raw(), query_index);
254 }
255
256 Ok(())
257 }
258}
259
260pub(super) fn validate_and_begin_occlusion_query(
261 query_set: Arc<QuerySet>,
262 raw_encoder: &mut dyn hal::DynCommandEncoder,
263 tracker: &mut StatelessTracker<QuerySet>,
264 query_index: u32,
265 reset_state: Option<&mut QueryResetMap>,
266 active_query: &mut Option<(Arc<QuerySet>, u32)>,
267) -> Result<(), QueryUseError> {
268 let needs_reset = reset_state.is_none();
269 query_set.validate_query(SimplifiedQueryType::Occlusion, query_index, reset_state)?;
270
271 tracker.insert_single(query_set.clone());
272
273 if let Some((_old, old_idx)) = active_query.take() {
274 return Err(QueryUseError::AlreadyStarted {
275 active_query_index: old_idx,
276 new_query_index: query_index,
277 });
278 }
279 let (query_set, _) = &active_query.insert((query_set, query_index));
280
281 unsafe {
282 if needs_reset {
284 raw_encoder.reset_queries(query_set.raw(), query_index..(query_index + 1));
285 }
286 raw_encoder.begin_query(query_set.raw(), query_index);
287 }
288
289 Ok(())
290}
291
292pub(super) fn end_occlusion_query(
293 raw_encoder: &mut dyn hal::DynCommandEncoder,
294 active_query: &mut Option<(Arc<QuerySet>, u32)>,
295) -> Result<(), QueryUseError> {
296 if let Some((query_set, query_index)) = active_query.take() {
297 unsafe { raw_encoder.end_query(query_set.raw(), query_index) };
298 Ok(())
299 } else {
300 Err(QueryUseError::AlreadyStopped)
301 }
302}
303
304pub(super) fn validate_and_begin_pipeline_statistics_query(
305 query_set: Arc<QuerySet>,
306 raw_encoder: &mut dyn hal::DynCommandEncoder,
307 tracker: &mut StatelessTracker<QuerySet>,
308 device: &Arc<Device>,
309 query_index: u32,
310 reset_state: Option<&mut QueryResetMap>,
311 active_query: &mut Option<(Arc<QuerySet>, u32)>,
312) -> Result<(), QueryUseError> {
313 query_set.same_device(device)?;
314
315 let needs_reset = reset_state.is_none();
316 query_set.validate_query(
317 SimplifiedQueryType::PipelineStatistics,
318 query_index,
319 reset_state,
320 )?;
321
322 tracker.insert_single(query_set.clone());
323
324 if let Some((_old, old_idx)) = active_query.take() {
325 return Err(QueryUseError::AlreadyStarted {
326 active_query_index: old_idx,
327 new_query_index: query_index,
328 });
329 }
330 let (query_set, _) = &active_query.insert((query_set, query_index));
331
332 unsafe {
333 if needs_reset {
335 raw_encoder.reset_queries(query_set.raw(), query_index..(query_index + 1));
336 }
337 raw_encoder.begin_query(query_set.raw(), query_index);
338 }
339
340 Ok(())
341}
342
343pub(super) fn end_pipeline_statistics_query(
344 raw_encoder: &mut dyn hal::DynCommandEncoder,
345 active_query: &mut Option<(Arc<QuerySet>, u32)>,
346) -> Result<(), QueryUseError> {
347 if let Some((query_set, query_index)) = active_query.take() {
348 unsafe { raw_encoder.end_query(query_set.raw(), query_index) };
349 Ok(())
350 } else {
351 Err(QueryUseError::AlreadyStopped)
352 }
353}
354
355impl Global {
356 pub fn command_encoder_write_timestamp(
357 &self,
358 command_encoder_id: id::CommandEncoderId,
359 query_set_id: id::QuerySetId,
360 query_index: u32,
361 ) -> Result<(), EncoderStateError> {
362 let hub = &self.hub;
363
364 let cmd_enc = hub.command_encoders.get(command_encoder_id);
365 let mut cmd_buf_data = cmd_enc.data.lock();
366
367 cmd_buf_data.push_with(|| -> Result<_, QueryError> {
368 Ok(ArcCommand::WriteTimestamp {
369 query_set: self.resolve_query_set(query_set_id)?,
370 query_index,
371 })
372 })
373 }
374
375 pub fn command_encoder_resolve_query_set(
376 &self,
377 command_encoder_id: id::CommandEncoderId,
378 query_set_id: id::QuerySetId,
379 start_query: u32,
380 query_count: u32,
381 destination: id::BufferId,
382 destination_offset: BufferAddress,
383 ) -> Result<(), EncoderStateError> {
384 let hub = &self.hub;
385
386 let cmd_enc = hub.command_encoders.get(command_encoder_id);
387 let mut cmd_buf_data = cmd_enc.data.lock();
388
389 cmd_buf_data.push_with(|| -> Result<_, QueryError> {
390 Ok(ArcCommand::ResolveQuerySet {
391 query_set: self.resolve_query_set(query_set_id)?,
392 start_query,
393 query_count,
394 destination: self.resolve_buffer_id(destination)?,
395 destination_offset,
396 })
397 })
398 }
399}
400
401pub(super) fn write_timestamp(
402 state: &mut EncodingState,
403 query_set: Arc<QuerySet>,
404 query_index: u32,
405) -> Result<(), QueryError> {
406 state
407 .device
408 .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_ENCODERS)?;
409
410 query_set.same_device(state.device)?;
411
412 query_set.validate_and_write_timestamp(state.raw_encoder, query_index, None)?;
413
414 state.tracker.query_sets.insert_single(query_set);
415
416 Ok(())
417}
418
419pub(super) fn resolve_query_set(
420 state: &mut EncodingState,
421 query_set: Arc<QuerySet>,
422 start_query: u32,
423 query_count: u32,
424 dst_buffer: Arc<Buffer>,
425 destination_offset: BufferAddress,
426) -> Result<(), QueryError> {
427 if destination_offset % wgt::QUERY_RESOLVE_BUFFER_ALIGNMENT != 0 {
428 return Err(QueryError::Resolve(ResolveError::BufferOffsetAlignment));
429 }
430
431 query_set.same_device(state.device)?;
432 dst_buffer.same_device(state.device)?;
433
434 dst_buffer.check_destroyed(state.snatch_guard)?;
435
436 let dst_pending = state
437 .tracker
438 .buffers
439 .set_single(&dst_buffer, wgt::BufferUses::COPY_DST);
440 let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer, state.snatch_guard));
441
442 dst_buffer
443 .check_usage(wgt::BufferUsages::QUERY_RESOLVE)
444 .map_err(ResolveError::MissingBufferUsage)?;
445
446 let end_query = u64::from(start_query)
447 .checked_add(u64::from(query_count))
448 .expect("`u64` overflow from adding two `u32`s, should be unreachable");
449 if end_query > u64::from(query_set.desc.count) {
450 return Err(ResolveError::QueryOverrun {
451 start_query,
452 end_query,
453 query_set_size: query_set.desc.count,
454 }
455 .into());
456 }
457 let end_query =
458 u32::try_from(end_query).expect("`u32` overflow for `end_query`, which should be `u32`");
459
460 let elements_per_query = match query_set.desc.ty {
461 wgt::QueryType::Occlusion => 1,
462 wgt::QueryType::PipelineStatistics(ps) => ps.bits().count_ones(),
463 wgt::QueryType::Timestamp => 1,
464 };
465 let stride = elements_per_query * wgt::QUERY_SIZE;
466 let bytes_used: BufferAddress = u64::from(stride)
467 .checked_mul(u64::from(query_count))
468 .expect("`stride` * `query_count` overflowed `u32`, should be unreachable");
469
470 let buffer_start_offset = destination_offset;
471 let buffer_end_offset = buffer_start_offset
472 .checked_add(bytes_used)
473 .filter(|buffer_end_offset| *buffer_end_offset <= dst_buffer.size)
474 .ok_or(ResolveError::BufferOverrun {
475 start_query,
476 end_query,
477 stride,
478 buffer_size: dst_buffer.size,
479 buffer_start_offset,
480 bytes_used,
481 })?;
482
483 state
485 .buffer_memory_init_actions
486 .extend(dst_buffer.initialization_status.read().create_action(
487 &dst_buffer,
488 buffer_start_offset..buffer_end_offset,
489 MemoryInitKind::ImplicitlyInitialized,
490 ));
491
492 let raw_dst_buffer = dst_buffer.try_raw(state.snatch_guard)?;
493 unsafe {
494 state.raw_encoder.transition_buffers(dst_barrier.as_slice());
495 state.raw_encoder.copy_query_results(
496 query_set.raw(),
497 start_query..end_query,
498 raw_dst_buffer,
499 destination_offset,
500 wgt::BufferSize::new_unchecked(stride as u64),
501 );
502 }
503
504 if matches!(query_set.desc.ty, wgt::QueryType::Timestamp) {
505 state.device.timestamp_normalizer.get().unwrap().normalize(
507 state.snatch_guard,
508 state.raw_encoder,
509 &mut state.tracker.buffers,
510 dst_buffer
511 .timestamp_normalization_bind_group
512 .get(state.snatch_guard)
513 .unwrap(),
514 &dst_buffer,
515 destination_offset,
516 query_count,
517 );
518 }
519
520 state.tracker.query_sets.insert_single(query_set);
521
522 Ok(())
523}