wgpu_core/command/query.rs

1use alloc::{sync::Arc, vec, vec::Vec};
2use core::{iter, mem};
3
4use crate::{
5    command::{encoder::EncodingState, ArcCommand, EncoderStateError},
6    device::{Device, DeviceError, MissingFeatures},
7    global::Global,
8    id,
9    init_tracker::MemoryInitKind,
10    resource::{
11        Buffer, DestroyedResourceError, InvalidResourceError, MissingBufferUsageError,
12        ParentDevice, QuerySet, RawResourceAccess, Trackable,
13    },
14    track::{StatelessTracker, TrackerIndex},
15    FastHashMap,
16};
17use thiserror::Error;
18use wgt::{
19    error::{ErrorType, WebGpuError},
20    BufferAddress,
21};
22
23#[derive(Debug)]
24pub(crate) struct QueryResetMap {
25    map: FastHashMap<TrackerIndex, (Vec<bool>, Arc<QuerySet>)>,
26}
27impl QueryResetMap {
28    pub fn new() -> Self {
29        Self {
30            map: FastHashMap::default(),
31        }
32    }
33
34    pub fn use_query_set(&mut self, query_set: &Arc<QuerySet>, query: u32) -> bool {
35        let vec_pair = self
36            .map
37            .entry(query_set.tracker_index())
38            .or_insert_with(|| {
39                (
40                    vec![false; query_set.desc.count as usize],
41                    query_set.clone(),
42                )
43            });
44
45        mem::replace(&mut vec_pair.0[query as usize], true)
46    }
47
48    pub fn reset_queries(&mut self, raw_encoder: &mut dyn hal::DynCommandEncoder) {
49        for (_, (state, query_set)) in self.map.drain() {
50            debug_assert_eq!(state.len(), query_set.desc.count as usize);
51
52            // Need to find all "runs" of values which need resets. If the state vector is:
53            // [false, true, true, false, true], we want to reset [1..3, 4..5]. This minimizes
54            // the amount of resets needed.
55            let mut run_start: Option<u32> = None;
56            for (idx, value) in state.into_iter().chain(iter::once(false)).enumerate() {
57                match (run_start, value) {
58                    // We're inside of a run, do nothing
59                    (Some(..), true) => {}
60                    // We've hit the end of a run, dispatch a reset
61                    (Some(start), false) => {
62                        run_start = None;
63                        unsafe { raw_encoder.reset_queries(query_set.raw(), start..idx as u32) };
64                    }
65                    // We're starting a run
66                    (None, true) => {
67                        run_start = Some(idx as u32);
68                    }
69                    // We're in a run of falses, do nothing.
70                    (None, false) => {}
71                }
72            }
73        }
74    }
75}
76
77#[derive(Debug, Copy, Clone, PartialEq, Eq)]
78pub enum SimplifiedQueryType {
79    Occlusion,
80    Timestamp,
81    PipelineStatistics,
82}
83impl From<wgt::QueryType> for SimplifiedQueryType {
84    fn from(q: wgt::QueryType) -> Self {
85        match q {
86            wgt::QueryType::Occlusion => SimplifiedQueryType::Occlusion,
87            wgt::QueryType::Timestamp => SimplifiedQueryType::Timestamp,
88            wgt::QueryType::PipelineStatistics(..) => SimplifiedQueryType::PipelineStatistics,
89        }
90    }
91}
92
93/// Error encountered when dealing with queries
94#[derive(Clone, Debug, Error)]
95#[non_exhaustive]
96pub enum QueryError {
97    #[error(transparent)]
98    Device(#[from] DeviceError),
99    #[error(transparent)]
100    EncoderState(#[from] EncoderStateError),
101    #[error(transparent)]
102    MissingFeature(#[from] MissingFeatures),
103    #[error("Error encountered while trying to use queries")]
104    Use(#[from] QueryUseError),
105    #[error("Error encountered while trying to resolve a query")]
106    Resolve(#[from] ResolveError),
107    #[error(transparent)]
108    DestroyedResource(#[from] DestroyedResourceError),
109    #[error(transparent)]
110    InvalidResource(#[from] InvalidResourceError),
111}
112
113impl WebGpuError for QueryError {
114    fn webgpu_error_type(&self) -> ErrorType {
115        let e: &dyn WebGpuError = match self {
116            Self::EncoderState(e) => e,
117            Self::Use(e) => e,
118            Self::Resolve(e) => e,
119            Self::InvalidResource(e) => e,
120            Self::Device(e) => e,
121            Self::MissingFeature(e) => e,
122            Self::DestroyedResource(e) => e,
123        };
124        e.webgpu_error_type()
125    }
126}
127
128/// Error encountered while trying to use queries
129#[derive(Clone, Debug, Error)]
130#[non_exhaustive]
131pub enum QueryUseError {
132    #[error(transparent)]
133    Device(#[from] DeviceError),
134    #[error("Query {query_index} is out of bounds for a query set of size {query_set_size}")]
135    OutOfBounds {
136        query_index: u32,
137        query_set_size: u32,
138    },
139    #[error("Query {query_index} has already been used within the same renderpass. Queries must only be used once per renderpass")]
140    UsedTwiceInsideRenderpass { query_index: u32 },
141    #[error("Query {new_query_index} was started while query {active_query_index} was already active. No more than one statistic or occlusion query may be active at once")]
142    AlreadyStarted {
143        active_query_index: u32,
144        new_query_index: u32,
145    },
146    #[error("Query was stopped while there was no active query")]
147    AlreadyStopped,
148    #[error("A query of type {query_type:?} was started using a query set of type {set_type:?}")]
149    IncompatibleType {
150        set_type: SimplifiedQueryType,
151        query_type: SimplifiedQueryType,
152    },
153    #[error("A query of type {query_type:?} was not ended before the encoder was finished")]
154    MissingEnd { query_type: SimplifiedQueryType },
155}
156
157impl WebGpuError for QueryUseError {
158    fn webgpu_error_type(&self) -> ErrorType {
159        match self {
160            Self::Device(e) => e.webgpu_error_type(),
161            Self::OutOfBounds { .. }
162            | Self::UsedTwiceInsideRenderpass { .. }
163            | Self::AlreadyStarted { .. }
164            | Self::AlreadyStopped
165            | Self::IncompatibleType { .. }
166            | Self::MissingEnd { .. } => ErrorType::Validation,
167        }
168    }
169}
170
171/// Error encountered while trying to resolve a query.
172#[derive(Clone, Debug, Error)]
173#[non_exhaustive]
174pub enum ResolveError {
175    #[error(transparent)]
176    MissingBufferUsage(#[from] MissingBufferUsageError),
177    #[error("Resolve buffer offset has to be aligned to `QUERY_RESOLVE_BUFFER_ALIGNMENT")]
178    BufferOffsetAlignment,
179    #[error("Resolving queries {start_query}..{end_query} would overrun the query set of size {query_set_size}")]
180    QueryOverrun {
181        start_query: u32,
182        end_query: u64,
183        query_set_size: u32,
184    },
185    #[error("Resolving queries {start_query}..{end_query} ({stride} byte queries) will end up overrunning the bounds of the destination buffer of size {buffer_size} using offsets {buffer_start_offset}..(<start> + {bytes_used})")]
186    BufferOverrun {
187        start_query: u32,
188        end_query: u32,
189        stride: u32,
190        buffer_size: BufferAddress,
191        buffer_start_offset: BufferAddress,
192        bytes_used: BufferAddress,
193    },
194}
195
196impl WebGpuError for ResolveError {
197    fn webgpu_error_type(&self) -> ErrorType {
198        match self {
199            Self::MissingBufferUsage(e) => e.webgpu_error_type(),
200            Self::BufferOffsetAlignment
201            | Self::QueryOverrun { .. }
202            | Self::BufferOverrun { .. } => ErrorType::Validation,
203        }
204    }
205}
206
207impl QuerySet {
208    pub(crate) fn validate_query(
209        self: &Arc<Self>,
210        query_type: SimplifiedQueryType,
211        query_index: u32,
212        reset_state: Option<&mut QueryResetMap>,
213    ) -> Result<(), QueryUseError> {
214        // NOTE: Further code assumes the index is good, so do this first.
215        if query_index >= self.desc.count {
216            return Err(QueryUseError::OutOfBounds {
217                query_index,
218                query_set_size: self.desc.count,
219            });
220        }
221
222        // We need to defer our resets because we are in a renderpass,
223        // add the usage to the reset map.
224        if let Some(reset) = reset_state {
225            let used = reset.use_query_set(self, query_index);
226            if used {
227                return Err(QueryUseError::UsedTwiceInsideRenderpass { query_index });
228            }
229        }
230
231        let simple_set_type = SimplifiedQueryType::from(self.desc.ty);
232        if simple_set_type != query_type {
233            return Err(QueryUseError::IncompatibleType {
234                query_type,
235                set_type: simple_set_type,
236            });
237        }
238
239        Ok(())
240    }
241
242    pub(super) fn validate_and_write_timestamp(
243        self: &Arc<Self>,
244        raw_encoder: &mut dyn hal::DynCommandEncoder,
245        query_index: u32,
246        reset_state: Option<&mut QueryResetMap>,
247    ) -> Result<(), QueryUseError> {
248        let needs_reset = reset_state.is_none();
249        self.validate_query(SimplifiedQueryType::Timestamp, query_index, reset_state)?;
250
251        unsafe {
252            // If we don't have a reset state tracker which can defer resets, we must reset now.
253            if needs_reset {
254                raw_encoder.reset_queries(self.raw(), query_index..(query_index + 1));
255            }
256            raw_encoder.write_timestamp(self.raw(), query_index);
257        }
258
259        Ok(())
260    }
261}
262
263pub(super) fn validate_and_begin_occlusion_query(
264    query_set: Arc<QuerySet>,
265    raw_encoder: &mut dyn hal::DynCommandEncoder,
266    tracker: &mut StatelessTracker<QuerySet>,
267    query_index: u32,
268    reset_state: Option<&mut QueryResetMap>,
269    active_query: &mut Option<(Arc<QuerySet>, u32)>,
270) -> Result<(), QueryUseError> {
271    let needs_reset = reset_state.is_none();
272    query_set.validate_query(SimplifiedQueryType::Occlusion, query_index, reset_state)?;
273
274    tracker.insert_single(query_set.clone());
275
276    if let Some((_old, old_idx)) = active_query.take() {
277        return Err(QueryUseError::AlreadyStarted {
278            active_query_index: old_idx,
279            new_query_index: query_index,
280        });
281    }
282    let (query_set, _) = &active_query.insert((query_set, query_index));
283
284    unsafe {
285        // If we don't have a reset state tracker which can defer resets, we must reset now.
286        if needs_reset {
287            raw_encoder.reset_queries(query_set.raw(), query_index..(query_index + 1));
288        }
289        raw_encoder.begin_query(query_set.raw(), query_index);
290    }
291
292    Ok(())
293}
294
295pub(super) fn end_occlusion_query(
296    raw_encoder: &mut dyn hal::DynCommandEncoder,
297    active_query: &mut Option<(Arc<QuerySet>, u32)>,
298) -> Result<(), QueryUseError> {
299    if let Some((query_set, query_index)) = active_query.take() {
300        unsafe { raw_encoder.end_query(query_set.raw(), query_index) };
301        Ok(())
302    } else {
303        Err(QueryUseError::AlreadyStopped)
304    }
305}
306
307pub(super) fn validate_and_begin_pipeline_statistics_query(
308    query_set: Arc<QuerySet>,
309    raw_encoder: &mut dyn hal::DynCommandEncoder,
310    tracker: &mut StatelessTracker<QuerySet>,
311    device: &Arc<Device>,
312    query_index: u32,
313    reset_state: Option<&mut QueryResetMap>,
314    active_query: &mut Option<(Arc<QuerySet>, u32)>,
315) -> Result<(), QueryUseError> {
316    query_set.same_device(device)?;
317
318    let needs_reset = reset_state.is_none();
319    query_set.validate_query(
320        SimplifiedQueryType::PipelineStatistics,
321        query_index,
322        reset_state,
323    )?;
324
325    tracker.insert_single(query_set.clone());
326
327    if let Some((_old, old_idx)) = active_query.take() {
328        return Err(QueryUseError::AlreadyStarted {
329            active_query_index: old_idx,
330            new_query_index: query_index,
331        });
332    }
333    let (query_set, _) = &active_query.insert((query_set, query_index));
334
335    unsafe {
336        // If we don't have a reset state tracker which can defer resets, we must reset now.
337        if needs_reset {
338            raw_encoder.reset_queries(query_set.raw(), query_index..(query_index + 1));
339        }
340        raw_encoder.begin_query(query_set.raw(), query_index);
341    }
342
343    Ok(())
344}
345
346pub(super) fn end_pipeline_statistics_query(
347    raw_encoder: &mut dyn hal::DynCommandEncoder,
348    active_query: &mut Option<(Arc<QuerySet>, u32)>,
349) -> Result<(), QueryUseError> {
350    if let Some((query_set, query_index)) = active_query.take() {
351        unsafe { raw_encoder.end_query(query_set.raw(), query_index) };
352        Ok(())
353    } else {
354        Err(QueryUseError::AlreadyStopped)
355    }
356}
357
358impl Global {
359    pub fn command_encoder_write_timestamp(
360        &self,
361        command_encoder_id: id::CommandEncoderId,
362        query_set_id: id::QuerySetId,
363        query_index: u32,
364    ) -> Result<(), EncoderStateError> {
365        let hub = &self.hub;
366
367        let cmd_enc = hub.command_encoders.get(command_encoder_id);
368        let mut cmd_buf_data = cmd_enc.data.lock();
369
370        cmd_buf_data.push_with(|| -> Result<_, QueryError> {
371            Ok(ArcCommand::WriteTimestamp {
372                query_set: self.resolve_query_set(query_set_id)?,
373                query_index,
374            })
375        })
376    }
377
378    pub fn command_encoder_resolve_query_set(
379        &self,
380        command_encoder_id: id::CommandEncoderId,
381        query_set_id: id::QuerySetId,
382        start_query: u32,
383        query_count: u32,
384        destination: id::BufferId,
385        destination_offset: BufferAddress,
386    ) -> Result<(), EncoderStateError> {
387        let hub = &self.hub;
388
389        let cmd_enc = hub.command_encoders.get(command_encoder_id);
390        let mut cmd_buf_data = cmd_enc.data.lock();
391
392        cmd_buf_data.push_with(|| -> Result<_, QueryError> {
393            Ok(ArcCommand::ResolveQuerySet {
394                query_set: self.resolve_query_set(query_set_id)?,
395                start_query,
396                query_count,
397                destination: self.resolve_buffer_id(destination)?,
398                destination_offset,
399            })
400        })
401    }
402}
403
404pub(super) fn write_timestamp(
405    state: &mut EncodingState,
406    query_set: Arc<QuerySet>,
407    query_index: u32,
408) -> Result<(), QueryError> {
409    state
410        .device
411        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_ENCODERS)?;
412
413    query_set.same_device(state.device)?;
414
415    query_set.validate_and_write_timestamp(state.raw_encoder, query_index, None)?;
416
417    state.tracker.query_sets.insert_single(query_set);
418
419    Ok(())
420}
421
422pub(super) fn resolve_query_set(
423    state: &mut EncodingState,
424    query_set: Arc<QuerySet>,
425    start_query: u32,
426    query_count: u32,
427    dst_buffer: Arc<Buffer>,
428    destination_offset: BufferAddress,
429) -> Result<(), QueryError> {
430    if destination_offset % wgt::QUERY_RESOLVE_BUFFER_ALIGNMENT != 0 {
431        return Err(QueryError::Resolve(ResolveError::BufferOffsetAlignment));
432    }
433
434    query_set.same_device(state.device)?;
435    dst_buffer.same_device(state.device)?;
436
437    dst_buffer.check_destroyed(state.snatch_guard)?;
438
439    let dst_pending = state
440        .tracker
441        .buffers
442        .set_single(&dst_buffer, wgt::BufferUses::COPY_DST);
443    let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer, state.snatch_guard));
444
445    dst_buffer
446        .check_usage(wgt::BufferUsages::QUERY_RESOLVE)
447        .map_err(ResolveError::MissingBufferUsage)?;
448
449    let end_query = u64::from(start_query)
450        .checked_add(u64::from(query_count))
451        .expect("`u64` overflow from adding two `u32`s, should be unreachable");
452    if end_query > u64::from(query_set.desc.count) {
453        return Err(ResolveError::QueryOverrun {
454            start_query,
455            end_query,
456            query_set_size: query_set.desc.count,
457        }
458        .into());
459    }
460    let end_query =
461        u32::try_from(end_query).expect("`u32` overflow for `end_query`, which should be `u32`");
462
463    let elements_per_query = match query_set.desc.ty {
464        wgt::QueryType::Occlusion => 1,
465        wgt::QueryType::PipelineStatistics(ps) => ps.bits().count_ones(),
466        wgt::QueryType::Timestamp => 1,
467    };
468    let stride = elements_per_query * wgt::QUERY_SIZE;
469    let bytes_used: BufferAddress = u64::from(stride)
470        .checked_mul(u64::from(query_count))
471        .expect("`stride` * `query_count` overflowed `u32`, should be unreachable");
472
473    let buffer_start_offset = destination_offset;
474    let buffer_end_offset = buffer_start_offset
475        .checked_add(bytes_used)
476        .filter(|buffer_end_offset| *buffer_end_offset <= dst_buffer.size)
477        .ok_or(ResolveError::BufferOverrun {
478            start_query,
479            end_query,
480            stride,
481            buffer_size: dst_buffer.size,
482            buffer_start_offset,
483            bytes_used,
484        })?;
485
486    // TODO(https://github.com/gfx-rs/wgpu/issues/3993): Need to track initialization state.
487    state
488        .buffer_memory_init_actions
489        .extend(dst_buffer.initialization_status.read().create_action(
490            &dst_buffer,
491            buffer_start_offset..buffer_end_offset,
492            MemoryInitKind::ImplicitlyInitialized,
493        ));
494
495    let raw_dst_buffer = dst_buffer.try_raw(state.snatch_guard)?;
496    unsafe {
497        state.raw_encoder.transition_buffers(dst_barrier.as_slice());
498        state.raw_encoder.copy_query_results(
499            query_set.raw(),
500            start_query..end_query,
501            raw_dst_buffer,
502            destination_offset,
503            wgt::BufferSize::new_unchecked(stride as u64),
504        );
505    }
506
507    if matches!(query_set.desc.ty, wgt::QueryType::Timestamp) {
508        // Timestamp normalization is only needed for timestamps.
509        state.device.timestamp_normalizer.get().unwrap().normalize(
510            state.snatch_guard,
511            state.raw_encoder,
512            &mut state.tracker.buffers,
513            dst_buffer
514                .timestamp_normalization_bind_group
515                .get(state.snatch_guard)
516                .unwrap(),
517            &dst_buffer,
518            destination_offset,
519            query_count,
520        );
521    }
522
523    state.tracker.query_sets.insert_single(query_set);
524
525    Ok(())
526}