wgpu_core/command/
query.rs

1use alloc::{sync::Arc, vec, vec::Vec};
2use bit_vec::BitVec;
3use core::{iter, mem};
4
5use crate::{
6    command::{encoder::EncodingState, ArcCommand, EncoderStateError, InnerCommandEncoder},
7    device::{Device, DeviceError, MissingFeatures},
8    global::Global,
9    id,
10    init_tracker::MemoryInitKind,
11    resource::{
12        Buffer, DestroyedResourceError, InvalidResourceError, MissingBufferUsageError,
13        ParentDevice, QuerySet, RawResourceAccess, Trackable,
14    },
15    snatch::SnatchGuard,
16    track::{QuerySetTracker, TrackerIndex},
17    FastHashMap,
18};
19use thiserror::Error;
20use wgt::{
21    error::{ErrorType, WebGpuError},
22    BufferAddress,
23};
24
25pub(crate) struct DeferredQuerySetResolve {
26    pub(crate) query_set: Arc<QuerySet>,
27    pub(crate) query_set_writes: Option<BitVec>,
28    pub(crate) start_query: u32,
29    pub(crate) end_query: u32,
30    pub(crate) dst_buffer: Arc<Buffer>,
31    pub(crate) destination_offset: BufferAddress,
32    /// Bytes per query slot in the destination buffer
33    /// (accounts for pipeline-statistics element count * `QUERY_SIZE`).
34    pub(crate) stride: u64,
35    /// Index into [`InnerCommandEncoder::list`] at which a new command buffer
36    /// for the resolve operation must be inserted at submit time so that it
37    /// executes at exactly the position it was recorded.
38    pub(crate) insertion_point: usize,
39}
40
41pub(crate) type QuerySetWrites = FastHashMap<TrackerIndex, BitVec>;
42
43pub(super) fn record_pass_timestamp_writes(
44    tw: &crate::command::ArcPassTimestampWrites,
45    query_set_writes: &mut QuerySetWrites,
46) {
47    for index in tw
48        .beginning_of_pass_write_index
49        .into_iter()
50        .chain(tw.end_of_pass_write_index)
51    {
52        record_query_write(query_set_writes, &tw.query_set, index);
53    }
54}
55
56pub(crate) fn record_query_write(
57    query_set_writes: &mut QuerySetWrites,
58    query_set: &Arc<QuerySet>,
59    slot_index: u32,
60) {
61    query_set_writes
62        .entry(query_set.tracker_index())
63        .or_insert_with(|| BitVec::from_elem(query_set.desc.count as usize, false))
64        .set(slot_index as usize, true);
65}
66
67#[derive(Debug)]
68pub(crate) struct QueryResetMap {
69    map: FastHashMap<TrackerIndex, (Vec<bool>, Arc<QuerySet>)>,
70}
71impl QueryResetMap {
72    pub fn new() -> Self {
73        Self {
74            map: FastHashMap::default(),
75        }
76    }
77
78    pub fn use_query_set(&mut self, query_set: &Arc<QuerySet>, query: u32) -> bool {
79        let vec_pair = self
80            .map
81            .entry(query_set.tracker_index())
82            .or_insert_with(|| {
83                (
84                    vec![false; query_set.desc.count as usize],
85                    query_set.clone(),
86                )
87            });
88
89        mem::replace(&mut vec_pair.0[query as usize], true)
90    }
91
92    pub fn reset_queries(
93        &mut self,
94        raw_encoder: &mut dyn hal::DynCommandEncoder,
95        snatch_guard: &SnatchGuard<'_>,
96    ) -> Result<(), DestroyedResourceError> {
97        for (_, (state, query_set)) in self.map.drain() {
98            debug_assert_eq!(state.len(), query_set.desc.count as usize);
99
100            // Need to find all "runs" of values which need resets. If the state vector is:
101            // [false, true, true, false, true], we want to reset [1..3, 4..5]. This minimizes
102            // the amount of resets needed.
103            let mut run_start: Option<u32> = None;
104            for (idx, value) in state.into_iter().chain(iter::once(false)).enumerate() {
105                match (run_start, value) {
106                    // We're inside of a run, do nothing
107                    (Some(..), true) => {}
108                    // We've hit the end of a run, dispatch a reset
109                    (Some(start), false) => {
110                        run_start = None;
111                        unsafe {
112                            raw_encoder
113                                .reset_queries(query_set.try_raw(snatch_guard)?, start..idx as u32)
114                        };
115                    }
116                    // We're starting a run
117                    (None, true) => {
118                        run_start = Some(idx as u32);
119                    }
120                    // We're in a run of falses, do nothing.
121                    (None, false) => {}
122                }
123            }
124        }
125        Ok(())
126    }
127}
128
129#[derive(Debug, Copy, Clone, PartialEq, Eq)]
130pub enum SimplifiedQueryType {
131    Occlusion,
132    Timestamp,
133    PipelineStatistics,
134}
135impl From<wgt::QueryType> for SimplifiedQueryType {
136    fn from(q: wgt::QueryType) -> Self {
137        match q {
138            wgt::QueryType::Occlusion => SimplifiedQueryType::Occlusion,
139            wgt::QueryType::Timestamp => SimplifiedQueryType::Timestamp,
140            wgt::QueryType::PipelineStatistics(..) => SimplifiedQueryType::PipelineStatistics,
141        }
142    }
143}
144
145/// Error encountered when dealing with queries
146#[derive(Clone, Debug, Error)]
147#[non_exhaustive]
148pub enum QueryError {
149    #[error(transparent)]
150    Device(#[from] DeviceError),
151    #[error(transparent)]
152    EncoderState(#[from] EncoderStateError),
153    #[error(transparent)]
154    MissingFeature(#[from] MissingFeatures),
155    #[error("Error encountered while trying to use queries")]
156    Use(#[from] QueryUseError),
157    #[error("Error encountered while trying to resolve a query")]
158    Resolve(#[from] ResolveError),
159    #[error(transparent)]
160    DestroyedResource(#[from] DestroyedResourceError),
161    #[error(transparent)]
162    InvalidResource(#[from] InvalidResourceError),
163}
164
165impl WebGpuError for QueryError {
166    fn webgpu_error_type(&self) -> ErrorType {
167        match self {
168            Self::EncoderState(e) => e.webgpu_error_type(),
169            Self::Use(e) => e.webgpu_error_type(),
170            Self::Resolve(e) => e.webgpu_error_type(),
171            Self::InvalidResource(e) => e.webgpu_error_type(),
172            Self::Device(e) => e.webgpu_error_type(),
173            Self::MissingFeature(e) => e.webgpu_error_type(),
174            Self::DestroyedResource(e) => e.webgpu_error_type(),
175        }
176    }
177}
178
179/// Error encountered while trying to use queries
180#[derive(Clone, Debug, Error)]
181#[non_exhaustive]
182pub enum QueryUseError {
183    #[error(transparent)]
184    Device(#[from] DeviceError),
185    #[error("Query {query_index} is out of bounds for a query set of size {query_set_size}")]
186    OutOfBounds {
187        query_index: u32,
188        query_set_size: u32,
189    },
190    #[error("Query {query_index} has already been used within the same renderpass. Queries must only be used once per renderpass")]
191    UsedTwiceInsideRenderpass { query_index: u32 },
192    #[error("Query {new_query_index} was started while query {active_query_index} was already active. No more than one statistic or occlusion query may be active at once")]
193    AlreadyStarted {
194        active_query_index: u32,
195        new_query_index: u32,
196    },
197    #[error("Query was stopped while there was no active query")]
198    AlreadyStopped,
199    #[error("A query of type {query_type:?} was started using a query set of type {set_type:?}")]
200    IncompatibleType {
201        set_type: SimplifiedQueryType,
202        query_type: SimplifiedQueryType,
203    },
204    #[error("A query of type {query_type:?} was not ended before the encoder was finished")]
205    MissingEnd { query_type: SimplifiedQueryType },
206    #[error(transparent)]
207    DestroyedResource(#[from] DestroyedResourceError),
208}
209
210impl WebGpuError for QueryUseError {
211    fn webgpu_error_type(&self) -> ErrorType {
212        match self {
213            Self::Device(e) => e.webgpu_error_type(),
214            Self::DestroyedResource(e) => e.webgpu_error_type(),
215            Self::OutOfBounds { .. }
216            | Self::UsedTwiceInsideRenderpass { .. }
217            | Self::AlreadyStarted { .. }
218            | Self::AlreadyStopped
219            | Self::IncompatibleType { .. }
220            | Self::MissingEnd { .. } => ErrorType::Validation,
221        }
222    }
223}
224
225/// Error encountered while trying to resolve a query.
226#[derive(Clone, Debug, Error)]
227#[non_exhaustive]
228pub enum ResolveError {
229    #[error(transparent)]
230    MissingBufferUsage(#[from] MissingBufferUsageError),
231    #[error("Resolve buffer offset has to be aligned to `QUERY_RESOLVE_BUFFER_ALIGNMENT")]
232    BufferOffsetAlignment,
233    #[error("Resolving queries {start_query}..{end_query} would overrun the query set of size {query_set_size}")]
234    QueryOverrun {
235        start_query: u32,
236        end_query: u64,
237        query_set_size: u32,
238    },
239    #[error("Resolving queries {start_query}..{end_query} ({stride} byte queries) will end up overrunning the bounds of the destination buffer of size {buffer_size} using offsets {buffer_start_offset}..(<start> + {bytes_used})")]
240    BufferOverrun {
241        start_query: u32,
242        end_query: u32,
243        stride: u32,
244        buffer_size: BufferAddress,
245        buffer_start_offset: BufferAddress,
246        bytes_used: BufferAddress,
247    },
248}
249
250impl WebGpuError for ResolveError {
251    fn webgpu_error_type(&self) -> ErrorType {
252        match self {
253            Self::MissingBufferUsage(e) => e.webgpu_error_type(),
254            Self::BufferOffsetAlignment
255            | Self::QueryOverrun { .. }
256            | Self::BufferOverrun { .. } => ErrorType::Validation,
257        }
258    }
259}
260
261impl QuerySet {
262    pub(crate) fn validate_query(
263        self: &Arc<Self>,
264        query_type: SimplifiedQueryType,
265        query_index: u32,
266        reset_state: Option<&mut QueryResetMap>,
267    ) -> Result<(), QueryUseError> {
268        // NOTE: Further code assumes the index is good, so do this first.
269        if query_index >= self.desc.count {
270            return Err(QueryUseError::OutOfBounds {
271                query_index,
272                query_set_size: self.desc.count,
273            });
274        }
275
276        // We need to defer our resets because we are in a renderpass,
277        // add the usage to the reset map.
278        if let Some(reset) = reset_state {
279            let used = reset.use_query_set(self, query_index);
280            if used {
281                return Err(QueryUseError::UsedTwiceInsideRenderpass { query_index });
282            }
283        }
284
285        let simple_set_type = SimplifiedQueryType::from(self.desc.ty);
286        if simple_set_type != query_type {
287            return Err(QueryUseError::IncompatibleType {
288                query_type,
289                set_type: simple_set_type,
290            });
291        }
292
293        Ok(())
294    }
295
296    pub(super) fn validate_and_write_timestamp(
297        self: &Arc<Self>,
298        raw_encoder: &mut dyn hal::DynCommandEncoder,
299        query_index: u32,
300        reset_state: Option<&mut QueryResetMap>,
301        snatch_guard: &SnatchGuard<'_>,
302        query_set_writes: &mut QuerySetWrites,
303    ) -> Result<(), QueryUseError> {
304        let needs_reset = reset_state.is_none();
305        self.validate_query(SimplifiedQueryType::Timestamp, query_index, reset_state)?;
306
307        unsafe {
308            // If we don't have a reset state tracker which can defer resets, we must reset now.
309            if needs_reset {
310                raw_encoder
311                    .reset_queries(self.try_raw(snatch_guard)?, query_index..(query_index + 1));
312            }
313            raw_encoder.write_timestamp(self.try_raw(snatch_guard)?, query_index);
314        }
315
316        record_query_write(query_set_writes, self, query_index);
317        Ok(())
318    }
319}
320
321pub(super) fn validate_and_begin_occlusion_query(
322    query_set: Arc<QuerySet>,
323    raw_encoder: &mut dyn hal::DynCommandEncoder,
324    tracker: &mut QuerySetTracker,
325    query_index: u32,
326    reset_state: Option<&mut QueryResetMap>,
327    active_query: &mut Option<(Arc<QuerySet>, u32)>,
328    snatch_guard: &SnatchGuard<'_>,
329) -> Result<(), QueryUseError> {
330    let needs_reset = reset_state.is_none();
331    query_set.validate_query(SimplifiedQueryType::Occlusion, query_index, reset_state)?;
332
333    tracker.insert_single(query_set.clone());
334
335    if let Some((_old, old_idx)) = active_query.take() {
336        return Err(QueryUseError::AlreadyStarted {
337            active_query_index: old_idx,
338            new_query_index: query_index,
339        });
340    }
341    let (query_set, _) = &active_query.insert((query_set, query_index));
342
343    unsafe {
344        // If we don't have a reset state tracker which can defer resets, we must reset now.
345        if needs_reset {
346            raw_encoder.reset_queries(
347                query_set.try_raw(snatch_guard)?,
348                query_index..(query_index + 1),
349            );
350        }
351        raw_encoder.begin_query(query_set.try_raw(snatch_guard)?, query_index);
352    }
353
354    Ok(())
355}
356
357pub(super) fn end_occlusion_query(
358    raw_encoder: &mut dyn hal::DynCommandEncoder,
359    active_query: &mut Option<(Arc<QuerySet>, u32)>,
360    snatch_guard: &SnatchGuard<'_>,
361    query_set_writes: &mut QuerySetWrites,
362) -> Result<(), QueryUseError> {
363    if let Some((query_set, query_index)) = active_query.take() {
364        unsafe { raw_encoder.end_query(query_set.try_raw(snatch_guard)?, query_index) };
365        record_query_write(query_set_writes, &query_set, query_index);
366        Ok(())
367    } else {
368        Err(QueryUseError::AlreadyStopped)
369    }
370}
371
372pub(super) fn validate_and_begin_pipeline_statistics_query(
373    query_set: Arc<QuerySet>,
374    raw_encoder: &mut dyn hal::DynCommandEncoder,
375    tracker: &mut QuerySetTracker,
376    device: &Arc<Device>,
377    query_index: u32,
378    reset_state: Option<&mut QueryResetMap>,
379    active_query: &mut Option<(Arc<QuerySet>, u32)>,
380    snatch_guard: &SnatchGuard<'_>,
381) -> Result<(), QueryUseError> {
382    query_set.same_device(device)?;
383
384    let needs_reset = reset_state.is_none();
385    query_set.validate_query(
386        SimplifiedQueryType::PipelineStatistics,
387        query_index,
388        reset_state,
389    )?;
390
391    tracker.insert_single(query_set.clone());
392
393    if let Some((_old, old_idx)) = active_query.take() {
394        return Err(QueryUseError::AlreadyStarted {
395            active_query_index: old_idx,
396            new_query_index: query_index,
397        });
398    }
399    let (query_set, _) = &active_query.insert((query_set, query_index));
400
401    unsafe {
402        // If we don't have a reset state tracker which can defer resets, we must reset now.
403        if needs_reset {
404            raw_encoder.reset_queries(
405                query_set.try_raw(snatch_guard)?,
406                query_index..(query_index + 1),
407            );
408        }
409        raw_encoder.begin_query(query_set.try_raw(snatch_guard)?, query_index);
410    }
411
412    Ok(())
413}
414
415pub(super) fn end_pipeline_statistics_query(
416    raw_encoder: &mut dyn hal::DynCommandEncoder,
417    active_query: &mut Option<(Arc<QuerySet>, u32)>,
418    snatch_guard: &SnatchGuard<'_>,
419    query_set_writes: &mut QuerySetWrites,
420) -> Result<(), QueryUseError> {
421    if let Some((query_set, query_index)) = active_query.take() {
422        unsafe { raw_encoder.end_query(query_set.try_raw(snatch_guard)?, query_index) };
423        record_query_write(query_set_writes, &query_set, query_index);
424        Ok(())
425    } else {
426        Err(QueryUseError::AlreadyStopped)
427    }
428}
429
430impl Global {
431    pub fn command_encoder_write_timestamp(
432        &self,
433        command_encoder_id: id::CommandEncoderId,
434        query_set_id: id::QuerySetId,
435        query_index: u32,
436    ) -> Result<(), EncoderStateError> {
437        let hub = &self.hub;
438
439        let cmd_enc = hub.command_encoders.get(command_encoder_id);
440        let mut cmd_buf_data = cmd_enc.data.lock();
441
442        cmd_buf_data.push_with(|| -> Result<_, QueryError> {
443            let query_set = self.resolve_query_set_id(query_set_id);
444            query_set.check_is_valid()?;
445            Ok(ArcCommand::WriteTimestamp {
446                query_set,
447                query_index,
448            })
449        })
450    }
451
452    pub fn command_encoder_resolve_query_set(
453        &self,
454        command_encoder_id: id::CommandEncoderId,
455        query_set_id: id::QuerySetId,
456        start_query: u32,
457        query_count: u32,
458        destination: id::BufferId,
459        destination_offset: BufferAddress,
460    ) -> Result<(), EncoderStateError> {
461        let hub = &self.hub;
462
463        let cmd_enc = hub.command_encoders.get(command_encoder_id);
464        let mut cmd_buf_data = cmd_enc.data.lock();
465
466        cmd_buf_data.push_with(|| -> Result<_, QueryError> {
467            let query_set = self.resolve_query_set_id(query_set_id);
468            query_set.check_is_valid()?;
469            Ok(ArcCommand::ResolveQuerySet {
470                query_set,
471                start_query,
472                query_count,
473                destination: self.resolve_buffer_id(destination)?,
474                destination_offset,
475            })
476        })
477    }
478}
479
480pub(super) fn write_timestamp(
481    state: &mut EncodingState,
482    query_set: Arc<QuerySet>,
483    query_index: u32,
484) -> Result<(), QueryError> {
485    state
486        .device
487        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_ENCODERS)?;
488
489    query_set.same_device(state.device)?;
490
491    query_set.validate_and_write_timestamp(
492        state.raw_encoder,
493        query_index,
494        None,
495        state.snatch_guard,
496        state.query_set_writes,
497    )?;
498
499    state.tracker.query_sets.insert_single(query_set);
500
501    Ok(())
502}
503
504pub(super) fn resolve_query_set(
505    state: &mut EncodingState<'_, '_, InnerCommandEncoder>,
506    query_set: Arc<QuerySet>,
507    start_query: u32,
508    query_count: u32,
509    dst_buffer: Arc<Buffer>,
510    destination_offset: BufferAddress,
511) -> Result<(), QueryError> {
512    if !destination_offset.is_multiple_of(wgt::QUERY_RESOLVE_BUFFER_ALIGNMENT) {
513        return Err(QueryError::Resolve(ResolveError::BufferOffsetAlignment));
514    }
515
516    query_set.same_device(state.device)?;
517    dst_buffer.same_device(state.device)?;
518
519    dst_buffer.check_destroyed(state.snatch_guard)?;
520
521    let dst_pending = state
522        .tracker
523        .buffers
524        .set_single(&dst_buffer, wgt::BufferUses::COPY_DST);
525    let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer, state.snatch_guard));
526
527    dst_buffer
528        .check_usage(wgt::BufferUsages::QUERY_RESOLVE)
529        .map_err(ResolveError::MissingBufferUsage)?;
530
531    let end_query = u64::from(start_query)
532        .checked_add(u64::from(query_count))
533        .expect("`u64` overflow from adding two `u32`s, should be unreachable");
534    if end_query > u64::from(query_set.desc.count) {
535        return Err(ResolveError::QueryOverrun {
536            start_query,
537            end_query,
538            query_set_size: query_set.desc.count,
539        }
540        .into());
541    }
542    let end_query =
543        u32::try_from(end_query).expect("`u32` overflow for `end_query`, which should be `u32`");
544
545    let elements_per_query = match query_set.desc.ty {
546        wgt::QueryType::Occlusion => 1,
547        wgt::QueryType::PipelineStatistics(ps) => ps.bits().count_ones(),
548        wgt::QueryType::Timestamp => 1,
549    };
550    let stride = elements_per_query * wgt::QUERY_SIZE;
551    let bytes_used: BufferAddress = u64::from(stride)
552        .checked_mul(u64::from(query_count))
553        .expect("`stride` * `query_count` overflowed `u32`, should be unreachable");
554
555    let buffer_start_offset = destination_offset;
556    let buffer_end_offset = buffer_start_offset
557        .checked_add(bytes_used)
558        .filter(|buffer_end_offset| *buffer_end_offset <= dst_buffer.size)
559        .ok_or(ResolveError::BufferOverrun {
560            start_query,
561            end_query,
562            stride,
563            buffer_size: dst_buffer.size,
564            buffer_start_offset,
565            bytes_used,
566        })?;
567
568    let query_set = state.tracker.query_sets.insert_single(query_set);
569
570    state
571        .buffer_memory_init_actions
572        .extend(dst_buffer.initialization_status.read().create_action(
573            &dst_buffer,
574            buffer_start_offset..buffer_end_offset,
575            MemoryInitKind::ImplicitlyInitialized,
576        ));
577
578    let raw_encoder = state.raw_encoder.open_if_closed()?;
579    let raw_dst_buffer = dst_buffer.try_raw(state.snatch_guard)?;
580    unsafe {
581        raw_encoder.transition_buffers(dst_barrier.as_slice());
582    }
583
584    // Check if all slots in the range have been written within this encoder.
585    // If so we can emit `copy_query_results` directly.
586    // Otherwise defer to submit time where we have knowledge of
587    // the query set initialization state.
588    let query_set_writes = state.query_set_writes.get(&query_set.tracker_index());
589    let all_written =
590        query_set_writes.is_some_and(|slots| (start_query..end_query).all(|i| slots[i as usize]));
591
592    if all_written {
593        unsafe {
594            raw_encoder.copy_query_results(
595                query_set.try_raw(state.snatch_guard)?,
596                start_query..end_query,
597                raw_dst_buffer,
598                destination_offset,
599                wgt::BufferSize::new_unchecked(stride as u64),
600            );
601        }
602    } else {
603        state.raw_encoder.close_if_open()?;
604        let insertion_point = state.raw_encoder.list.len();
605
606        state
607            .deferred_query_set_resolves
608            .push(DeferredQuerySetResolve {
609                query_set: query_set.clone(),
610                query_set_writes: query_set_writes.cloned(),
611                start_query,
612                end_query,
613                dst_buffer: dst_buffer.clone(),
614                destination_offset,
615                stride: stride as u64,
616                insertion_point,
617            });
618    }
619
620    if matches!(query_set.desc.ty, wgt::QueryType::Timestamp) {
621        let raw_encoder = state.raw_encoder.open_if_closed()?;
622
623        // Timestamp normalization is only needed for timestamps.
624        state.device.timestamp_normalizer.get().unwrap().normalize(
625            state.snatch_guard,
626            raw_encoder,
627            &mut state.tracker.buffers,
628            dst_buffer
629                .timestamp_normalization_bind_group
630                .get(state.snatch_guard)
631                .unwrap(),
632            &dst_buffer,
633            destination_offset,
634            query_count,
635        );
636    }
637
638    Ok(())
639}