1use alloc::{sync::Arc, vec, vec::Vec};
2use bit_vec::BitVec;
3use core::{iter, mem};
4
5use crate::{
6 command::{encoder::EncodingState, ArcCommand, EncoderStateError, InnerCommandEncoder},
7 device::{Device, DeviceError, MissingFeatures},
8 global::Global,
9 id,
10 init_tracker::MemoryInitKind,
11 resource::{
12 Buffer, DestroyedResourceError, InvalidResourceError, MissingBufferUsageError,
13 ParentDevice, QuerySet, RawResourceAccess, Trackable,
14 },
15 snatch::SnatchGuard,
16 track::{QuerySetTracker, TrackerIndex},
17 FastHashMap,
18};
19use thiserror::Error;
20use wgt::{
21 error::{ErrorType, WebGpuError},
22 BufferAddress,
23};
24
/// A query-set resolve whose copy of query results could not be recorded
/// immediately, because not every query in the requested range was known to
/// have been written by this encoder when the resolve was recorded.
pub(crate) struct DeferredQuerySetResolve {
    /// The query set whose results are being resolved.
    pub(crate) query_set: Arc<QuerySet>,
    /// Snapshot of which slots of `query_set` had been written when the
    /// resolve was recorded; `None` if no write to this set had been recorded.
    pub(crate) query_set_writes: Option<BitVec>,
    /// First query (inclusive) of the resolved range.
    pub(crate) start_query: u32,
    /// End query (exclusive) of the resolved range.
    pub(crate) end_query: u32,
    /// Buffer receiving the resolved results.
    pub(crate) dst_buffer: Arc<Buffer>,
    /// Byte offset into `dst_buffer` where the first result is written.
    pub(crate) destination_offset: BufferAddress,
    /// Size in bytes of one query's result.
    pub(crate) stride: u64,
    /// Length of the encoder's raw command-buffer list at the time the resolve
    /// was recorded (after closing the current buffer); presumably the index at
    /// which the deferred copy is inserted later — confirm against the consumer.
    pub(crate) insertion_point: usize,
}

/// Query slots written so far, per query set, keyed by the query set's tracker
/// index. Each bit vector has one bit per query in the set.
pub(crate) type QuerySetWrites = FastHashMap<TrackerIndex, BitVec>;
42
43pub(super) fn record_pass_timestamp_writes(
44 tw: &crate::command::ArcPassTimestampWrites,
45 query_set_writes: &mut QuerySetWrites,
46) {
47 for index in tw
48 .beginning_of_pass_write_index
49 .into_iter()
50 .chain(tw.end_of_pass_write_index)
51 {
52 record_query_write(query_set_writes, &tw.query_set, index);
53 }
54}
55
56pub(crate) fn record_query_write(
57 query_set_writes: &mut QuerySetWrites,
58 query_set: &Arc<QuerySet>,
59 slot_index: u32,
60) {
61 query_set_writes
62 .entry(query_set.tracker_index())
63 .or_insert_with(|| BitVec::from_elem(query_set.desc.count as usize, false))
64 .set(slot_index as usize, true);
65}
66
67#[derive(Debug)]
68pub(crate) struct QueryResetMap {
69 map: FastHashMap<TrackerIndex, (Vec<bool>, Arc<QuerySet>)>,
70}
71impl QueryResetMap {
72 pub fn new() -> Self {
73 Self {
74 map: FastHashMap::default(),
75 }
76 }
77
78 pub fn use_query_set(&mut self, query_set: &Arc<QuerySet>, query: u32) -> bool {
79 let vec_pair = self
80 .map
81 .entry(query_set.tracker_index())
82 .or_insert_with(|| {
83 (
84 vec![false; query_set.desc.count as usize],
85 query_set.clone(),
86 )
87 });
88
89 mem::replace(&mut vec_pair.0[query as usize], true)
90 }
91
92 pub fn reset_queries(
93 &mut self,
94 raw_encoder: &mut dyn hal::DynCommandEncoder,
95 snatch_guard: &SnatchGuard<'_>,
96 ) -> Result<(), DestroyedResourceError> {
97 for (_, (state, query_set)) in self.map.drain() {
98 debug_assert_eq!(state.len(), query_set.desc.count as usize);
99
100 let mut run_start: Option<u32> = None;
104 for (idx, value) in state.into_iter().chain(iter::once(false)).enumerate() {
105 match (run_start, value) {
106 (Some(..), true) => {}
108 (Some(start), false) => {
110 run_start = None;
111 unsafe {
112 raw_encoder
113 .reset_queries(query_set.try_raw(snatch_guard)?, start..idx as u32)
114 };
115 }
116 (None, true) => {
118 run_start = Some(idx as u32);
119 }
120 (None, false) => {}
122 }
123 }
124 }
125 Ok(())
126 }
127}
128
129#[derive(Debug, Copy, Clone, PartialEq, Eq)]
130pub enum SimplifiedQueryType {
131 Occlusion,
132 Timestamp,
133 PipelineStatistics,
134}
135impl From<wgt::QueryType> for SimplifiedQueryType {
136 fn from(q: wgt::QueryType) -> Self {
137 match q {
138 wgt::QueryType::Occlusion => SimplifiedQueryType::Occlusion,
139 wgt::QueryType::Timestamp => SimplifiedQueryType::Timestamp,
140 wgt::QueryType::PipelineStatistics(..) => SimplifiedQueryType::PipelineStatistics,
141 }
142 }
143}
144
145#[derive(Clone, Debug, Error)]
147#[non_exhaustive]
148pub enum QueryError {
149 #[error(transparent)]
150 Device(#[from] DeviceError),
151 #[error(transparent)]
152 EncoderState(#[from] EncoderStateError),
153 #[error(transparent)]
154 MissingFeature(#[from] MissingFeatures),
155 #[error("Error encountered while trying to use queries")]
156 Use(#[from] QueryUseError),
157 #[error("Error encountered while trying to resolve a query")]
158 Resolve(#[from] ResolveError),
159 #[error(transparent)]
160 DestroyedResource(#[from] DestroyedResourceError),
161 #[error(transparent)]
162 InvalidResource(#[from] InvalidResourceError),
163}
164
165impl WebGpuError for QueryError {
166 fn webgpu_error_type(&self) -> ErrorType {
167 match self {
168 Self::EncoderState(e) => e.webgpu_error_type(),
169 Self::Use(e) => e.webgpu_error_type(),
170 Self::Resolve(e) => e.webgpu_error_type(),
171 Self::InvalidResource(e) => e.webgpu_error_type(),
172 Self::Device(e) => e.webgpu_error_type(),
173 Self::MissingFeature(e) => e.webgpu_error_type(),
174 Self::DestroyedResource(e) => e.webgpu_error_type(),
175 }
176 }
177}
178
179#[derive(Clone, Debug, Error)]
181#[non_exhaustive]
182pub enum QueryUseError {
183 #[error(transparent)]
184 Device(#[from] DeviceError),
185 #[error("Query {query_index} is out of bounds for a query set of size {query_set_size}")]
186 OutOfBounds {
187 query_index: u32,
188 query_set_size: u32,
189 },
190 #[error("Query {query_index} has already been used within the same renderpass. Queries must only be used once per renderpass")]
191 UsedTwiceInsideRenderpass { query_index: u32 },
192 #[error("Query {new_query_index} was started while query {active_query_index} was already active. No more than one statistic or occlusion query may be active at once")]
193 AlreadyStarted {
194 active_query_index: u32,
195 new_query_index: u32,
196 },
197 #[error("Query was stopped while there was no active query")]
198 AlreadyStopped,
199 #[error("A query of type {query_type:?} was started using a query set of type {set_type:?}")]
200 IncompatibleType {
201 set_type: SimplifiedQueryType,
202 query_type: SimplifiedQueryType,
203 },
204 #[error("A query of type {query_type:?} was not ended before the encoder was finished")]
205 MissingEnd { query_type: SimplifiedQueryType },
206 #[error(transparent)]
207 DestroyedResource(#[from] DestroyedResourceError),
208}
209
210impl WebGpuError for QueryUseError {
211 fn webgpu_error_type(&self) -> ErrorType {
212 match self {
213 Self::Device(e) => e.webgpu_error_type(),
214 Self::DestroyedResource(e) => e.webgpu_error_type(),
215 Self::OutOfBounds { .. }
216 | Self::UsedTwiceInsideRenderpass { .. }
217 | Self::AlreadyStarted { .. }
218 | Self::AlreadyStopped
219 | Self::IncompatibleType { .. }
220 | Self::MissingEnd { .. } => ErrorType::Validation,
221 }
222 }
223}
224
225#[derive(Clone, Debug, Error)]
227#[non_exhaustive]
228pub enum ResolveError {
229 #[error(transparent)]
230 MissingBufferUsage(#[from] MissingBufferUsageError),
231 #[error("Resolve buffer offset has to be aligned to `QUERY_RESOLVE_BUFFER_ALIGNMENT")]
232 BufferOffsetAlignment,
233 #[error("Resolving queries {start_query}..{end_query} would overrun the query set of size {query_set_size}")]
234 QueryOverrun {
235 start_query: u32,
236 end_query: u64,
237 query_set_size: u32,
238 },
239 #[error("Resolving queries {start_query}..{end_query} ({stride} byte queries) will end up overrunning the bounds of the destination buffer of size {buffer_size} using offsets {buffer_start_offset}..(<start> + {bytes_used})")]
240 BufferOverrun {
241 start_query: u32,
242 end_query: u32,
243 stride: u32,
244 buffer_size: BufferAddress,
245 buffer_start_offset: BufferAddress,
246 bytes_used: BufferAddress,
247 },
248}
249
250impl WebGpuError for ResolveError {
251 fn webgpu_error_type(&self) -> ErrorType {
252 match self {
253 Self::MissingBufferUsage(e) => e.webgpu_error_type(),
254 Self::BufferOffsetAlignment
255 | Self::QueryOverrun { .. }
256 | Self::BufferOverrun { .. } => ErrorType::Validation,
257 }
258 }
259}
260
261impl QuerySet {
262 pub(crate) fn validate_query(
263 self: &Arc<Self>,
264 query_type: SimplifiedQueryType,
265 query_index: u32,
266 reset_state: Option<&mut QueryResetMap>,
267 ) -> Result<(), QueryUseError> {
268 if query_index >= self.desc.count {
270 return Err(QueryUseError::OutOfBounds {
271 query_index,
272 query_set_size: self.desc.count,
273 });
274 }
275
276 if let Some(reset) = reset_state {
279 let used = reset.use_query_set(self, query_index);
280 if used {
281 return Err(QueryUseError::UsedTwiceInsideRenderpass { query_index });
282 }
283 }
284
285 let simple_set_type = SimplifiedQueryType::from(self.desc.ty);
286 if simple_set_type != query_type {
287 return Err(QueryUseError::IncompatibleType {
288 query_type,
289 set_type: simple_set_type,
290 });
291 }
292
293 Ok(())
294 }
295
296 pub(super) fn validate_and_write_timestamp(
297 self: &Arc<Self>,
298 raw_encoder: &mut dyn hal::DynCommandEncoder,
299 query_index: u32,
300 reset_state: Option<&mut QueryResetMap>,
301 snatch_guard: &SnatchGuard<'_>,
302 query_set_writes: &mut QuerySetWrites,
303 ) -> Result<(), QueryUseError> {
304 let needs_reset = reset_state.is_none();
305 self.validate_query(SimplifiedQueryType::Timestamp, query_index, reset_state)?;
306
307 unsafe {
308 if needs_reset {
310 raw_encoder
311 .reset_queries(self.try_raw(snatch_guard)?, query_index..(query_index + 1));
312 }
313 raw_encoder.write_timestamp(self.try_raw(snatch_guard)?, query_index);
314 }
315
316 record_query_write(query_set_writes, self, query_index);
317 Ok(())
318 }
319}
320
321pub(super) fn validate_and_begin_occlusion_query(
322 query_set: Arc<QuerySet>,
323 raw_encoder: &mut dyn hal::DynCommandEncoder,
324 tracker: &mut QuerySetTracker,
325 query_index: u32,
326 reset_state: Option<&mut QueryResetMap>,
327 active_query: &mut Option<(Arc<QuerySet>, u32)>,
328 snatch_guard: &SnatchGuard<'_>,
329) -> Result<(), QueryUseError> {
330 let needs_reset = reset_state.is_none();
331 query_set.validate_query(SimplifiedQueryType::Occlusion, query_index, reset_state)?;
332
333 tracker.insert_single(query_set.clone());
334
335 if let Some((_old, old_idx)) = active_query.take() {
336 return Err(QueryUseError::AlreadyStarted {
337 active_query_index: old_idx,
338 new_query_index: query_index,
339 });
340 }
341 let (query_set, _) = &active_query.insert((query_set, query_index));
342
343 unsafe {
344 if needs_reset {
346 raw_encoder.reset_queries(
347 query_set.try_raw(snatch_guard)?,
348 query_index..(query_index + 1),
349 );
350 }
351 raw_encoder.begin_query(query_set.try_raw(snatch_guard)?, query_index);
352 }
353
354 Ok(())
355}
356
357pub(super) fn end_occlusion_query(
358 raw_encoder: &mut dyn hal::DynCommandEncoder,
359 active_query: &mut Option<(Arc<QuerySet>, u32)>,
360 snatch_guard: &SnatchGuard<'_>,
361 query_set_writes: &mut QuerySetWrites,
362) -> Result<(), QueryUseError> {
363 if let Some((query_set, query_index)) = active_query.take() {
364 unsafe { raw_encoder.end_query(query_set.try_raw(snatch_guard)?, query_index) };
365 record_query_write(query_set_writes, &query_set, query_index);
366 Ok(())
367 } else {
368 Err(QueryUseError::AlreadyStopped)
369 }
370}
371
372pub(super) fn validate_and_begin_pipeline_statistics_query(
373 query_set: Arc<QuerySet>,
374 raw_encoder: &mut dyn hal::DynCommandEncoder,
375 tracker: &mut QuerySetTracker,
376 device: &Arc<Device>,
377 query_index: u32,
378 reset_state: Option<&mut QueryResetMap>,
379 active_query: &mut Option<(Arc<QuerySet>, u32)>,
380 snatch_guard: &SnatchGuard<'_>,
381) -> Result<(), QueryUseError> {
382 query_set.same_device(device)?;
383
384 let needs_reset = reset_state.is_none();
385 query_set.validate_query(
386 SimplifiedQueryType::PipelineStatistics,
387 query_index,
388 reset_state,
389 )?;
390
391 tracker.insert_single(query_set.clone());
392
393 if let Some((_old, old_idx)) = active_query.take() {
394 return Err(QueryUseError::AlreadyStarted {
395 active_query_index: old_idx,
396 new_query_index: query_index,
397 });
398 }
399 let (query_set, _) = &active_query.insert((query_set, query_index));
400
401 unsafe {
402 if needs_reset {
404 raw_encoder.reset_queries(
405 query_set.try_raw(snatch_guard)?,
406 query_index..(query_index + 1),
407 );
408 }
409 raw_encoder.begin_query(query_set.try_raw(snatch_guard)?, query_index);
410 }
411
412 Ok(())
413}
414
415pub(super) fn end_pipeline_statistics_query(
416 raw_encoder: &mut dyn hal::DynCommandEncoder,
417 active_query: &mut Option<(Arc<QuerySet>, u32)>,
418 snatch_guard: &SnatchGuard<'_>,
419 query_set_writes: &mut QuerySetWrites,
420) -> Result<(), QueryUseError> {
421 if let Some((query_set, query_index)) = active_query.take() {
422 unsafe { raw_encoder.end_query(query_set.try_raw(snatch_guard)?, query_index) };
423 record_query_write(query_set_writes, &query_set, query_index);
424 Ok(())
425 } else {
426 Err(QueryUseError::AlreadyStopped)
427 }
428}
429
430impl Global {
431 pub fn command_encoder_write_timestamp(
432 &self,
433 command_encoder_id: id::CommandEncoderId,
434 query_set_id: id::QuerySetId,
435 query_index: u32,
436 ) -> Result<(), EncoderStateError> {
437 let hub = &self.hub;
438
439 let cmd_enc = hub.command_encoders.get(command_encoder_id);
440 let mut cmd_buf_data = cmd_enc.data.lock();
441
442 cmd_buf_data.push_with(|| -> Result<_, QueryError> {
443 let query_set = self.resolve_query_set_id(query_set_id);
444 query_set.check_is_valid()?;
445 Ok(ArcCommand::WriteTimestamp {
446 query_set,
447 query_index,
448 })
449 })
450 }
451
452 pub fn command_encoder_resolve_query_set(
453 &self,
454 command_encoder_id: id::CommandEncoderId,
455 query_set_id: id::QuerySetId,
456 start_query: u32,
457 query_count: u32,
458 destination: id::BufferId,
459 destination_offset: BufferAddress,
460 ) -> Result<(), EncoderStateError> {
461 let hub = &self.hub;
462
463 let cmd_enc = hub.command_encoders.get(command_encoder_id);
464 let mut cmd_buf_data = cmd_enc.data.lock();
465
466 cmd_buf_data.push_with(|| -> Result<_, QueryError> {
467 let query_set = self.resolve_query_set_id(query_set_id);
468 query_set.check_is_valid()?;
469 Ok(ArcCommand::ResolveQuerySet {
470 query_set,
471 start_query,
472 query_count,
473 destination: self.resolve_buffer_id(destination)?,
474 destination_offset,
475 })
476 })
477 }
478}
479
480pub(super) fn write_timestamp(
481 state: &mut EncodingState,
482 query_set: Arc<QuerySet>,
483 query_index: u32,
484) -> Result<(), QueryError> {
485 state
486 .device
487 .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_ENCODERS)?;
488
489 query_set.same_device(state.device)?;
490
491 query_set.validate_and_write_timestamp(
492 state.raw_encoder,
493 query_index,
494 None,
495 state.snatch_guard,
496 state.query_set_writes,
497 )?;
498
499 state.tracker.query_sets.insert_single(query_set);
500
501 Ok(())
502}
503
504pub(super) fn resolve_query_set(
505 state: &mut EncodingState<'_, '_, InnerCommandEncoder>,
506 query_set: Arc<QuerySet>,
507 start_query: u32,
508 query_count: u32,
509 dst_buffer: Arc<Buffer>,
510 destination_offset: BufferAddress,
511) -> Result<(), QueryError> {
512 if !destination_offset.is_multiple_of(wgt::QUERY_RESOLVE_BUFFER_ALIGNMENT) {
513 return Err(QueryError::Resolve(ResolveError::BufferOffsetAlignment));
514 }
515
516 query_set.same_device(state.device)?;
517 dst_buffer.same_device(state.device)?;
518
519 dst_buffer.check_destroyed(state.snatch_guard)?;
520
521 let dst_pending = state
522 .tracker
523 .buffers
524 .set_single(&dst_buffer, wgt::BufferUses::COPY_DST);
525 let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer, state.snatch_guard));
526
527 dst_buffer
528 .check_usage(wgt::BufferUsages::QUERY_RESOLVE)
529 .map_err(ResolveError::MissingBufferUsage)?;
530
531 let end_query = u64::from(start_query)
532 .checked_add(u64::from(query_count))
533 .expect("`u64` overflow from adding two `u32`s, should be unreachable");
534 if end_query > u64::from(query_set.desc.count) {
535 return Err(ResolveError::QueryOverrun {
536 start_query,
537 end_query,
538 query_set_size: query_set.desc.count,
539 }
540 .into());
541 }
542 let end_query =
543 u32::try_from(end_query).expect("`u32` overflow for `end_query`, which should be `u32`");
544
545 let elements_per_query = match query_set.desc.ty {
546 wgt::QueryType::Occlusion => 1,
547 wgt::QueryType::PipelineStatistics(ps) => ps.bits().count_ones(),
548 wgt::QueryType::Timestamp => 1,
549 };
550 let stride = elements_per_query * wgt::QUERY_SIZE;
551 let bytes_used: BufferAddress = u64::from(stride)
552 .checked_mul(u64::from(query_count))
553 .expect("`stride` * `query_count` overflowed `u32`, should be unreachable");
554
555 let buffer_start_offset = destination_offset;
556 let buffer_end_offset = buffer_start_offset
557 .checked_add(bytes_used)
558 .filter(|buffer_end_offset| *buffer_end_offset <= dst_buffer.size)
559 .ok_or(ResolveError::BufferOverrun {
560 start_query,
561 end_query,
562 stride,
563 buffer_size: dst_buffer.size,
564 buffer_start_offset,
565 bytes_used,
566 })?;
567
568 let query_set = state.tracker.query_sets.insert_single(query_set);
569
570 state
571 .buffer_memory_init_actions
572 .extend(dst_buffer.initialization_status.read().create_action(
573 &dst_buffer,
574 buffer_start_offset..buffer_end_offset,
575 MemoryInitKind::ImplicitlyInitialized,
576 ));
577
578 let raw_encoder = state.raw_encoder.open_if_closed()?;
579 let raw_dst_buffer = dst_buffer.try_raw(state.snatch_guard)?;
580 unsafe {
581 raw_encoder.transition_buffers(dst_barrier.as_slice());
582 }
583
584 let query_set_writes = state.query_set_writes.get(&query_set.tracker_index());
589 let all_written =
590 query_set_writes.is_some_and(|slots| (start_query..end_query).all(|i| slots[i as usize]));
591
592 if all_written {
593 unsafe {
594 raw_encoder.copy_query_results(
595 query_set.try_raw(state.snatch_guard)?,
596 start_query..end_query,
597 raw_dst_buffer,
598 destination_offset,
599 wgt::BufferSize::new_unchecked(stride as u64),
600 );
601 }
602 } else {
603 state.raw_encoder.close_if_open()?;
604 let insertion_point = state.raw_encoder.list.len();
605
606 state
607 .deferred_query_set_resolves
608 .push(DeferredQuerySetResolve {
609 query_set: query_set.clone(),
610 query_set_writes: query_set_writes.cloned(),
611 start_query,
612 end_query,
613 dst_buffer: dst_buffer.clone(),
614 destination_offset,
615 stride: stride as u64,
616 insertion_point,
617 });
618 }
619
620 if matches!(query_set.desc.ty, wgt::QueryType::Timestamp) {
621 let raw_encoder = state.raw_encoder.open_if_closed()?;
622
623 state.device.timestamp_normalizer.get().unwrap().normalize(
625 state.snatch_guard,
626 raw_encoder,
627 &mut state.tracker.buffers,
628 dst_buffer
629 .timestamp_normalization_bind_group
630 .get(state.snatch_guard)
631 .unwrap(),
632 &dst_buffer,
633 destination_offset,
634 query_count,
635 );
636 }
637
638 Ok(())
639}