wgpu_test/
image.rs

1//! Image comparison utilities
2
3use std::{borrow::Cow, ffi::OsStr, path::Path};
4
5use wgpu::util::{align_to, DeviceExt};
6use wgpu::*;
7
8use crate::TestingContext;
9
10#[cfg(not(any(target_arch = "wasm32", miri)))]
11async fn read_png(path: impl AsRef<Path>, width: u32, height: u32) -> Option<Vec<u8>> {
12    let data = match std::fs::read(&path) {
13        Ok(f) => f,
14        Err(e) => {
15            log::warn!(
16                "image comparison invalid: file io error when comparing {}: {}",
17                path.as_ref().display(),
18                e
19            );
20            return None;
21        }
22    };
23    let decoder = png::Decoder::new(std::io::Cursor::new(data));
24    let mut reader = decoder.read_info().ok()?;
25
26    let mut buffer = vec![0; reader.output_buffer_size()];
27    let info = reader.next_frame(&mut buffer).ok()?;
28    if info.width != width {
29        log::warn!("image comparison invalid: size mismatch");
30        return None;
31    }
32    if info.height != height {
33        log::warn!("image comparison invalid: size mismatch");
34        return None;
35    }
36    if info.color_type != png::ColorType::Rgba {
37        log::warn!("image comparison invalid: color type mismatch");
38        return None;
39    }
40    if info.bit_depth != png::BitDepth::Eight {
41        log::warn!("image comparison invalid: bit depth mismatch");
42        return None;
43    }
44
45    Some(buffer)
46}
47
48#[cfg(not(any(target_arch = "wasm32", miri)))]
49async fn write_png(
50    path: impl AsRef<Path>,
51    width: u32,
52    height: u32,
53    data: &[u8],
54    compression: png::Compression,
55) {
56    let file = std::io::BufWriter::new(std::fs::File::create(path).unwrap());
57
58    let mut encoder = png::Encoder::new(file, width, height);
59    encoder.set_color(png::ColorType::Rgba);
60    encoder.set_depth(png::BitDepth::Eight);
61    encoder.set_compression(compression);
62    let mut writer = encoder.write_header().unwrap();
63
64    writer.write_image_data(data).unwrap();
65}
66
/// Expands tightly packed RGB bytes into RGBA, giving every pixel an alpha of 255.
///
/// Trailing bytes that do not form a complete RGB triple are dropped.
#[cfg_attr(any(target_arch = "wasm32", miri), allow(unused))]
fn add_alpha(input: &[u8]) -> Vec<u8> {
    let mut rgba = Vec::with_capacity(input.len() / 3 * 4);
    for rgb in input.chunks_exact(3) {
        rgba.extend_from_slice(rgb);
        rgba.push(255);
    }
    rgba
}
74
/// Strips the alpha channel from tightly packed RGBA bytes, producing packed RGB.
///
/// Trailing bytes that do not form a complete RGBA pixel are dropped.
#[cfg_attr(any(target_arch = "wasm32", miri), allow(unused))]
fn remove_alpha(input: &[u8]) -> Vec<u8> {
    let mut rgb = Vec::with_capacity(input.len() / 4 * 3);
    for pixel in input.chunks_exact(4) {
        rgb.extend_from_slice(&pixel[..3]);
    }
    rgb
}
83
84#[cfg(not(any(target_arch = "wasm32", miri)))]
85fn print_flip(pool: &mut nv_flip::FlipPool) {
86    println!("\tMean: {:.6}", pool.mean());
87    println!("\tMin Value: {:.6}", pool.min_value());
88    for percentile in [25, 50, 75, 95, 99] {
89        println!(
90            "\t      {percentile}%: {:.6}",
91            pool.get_percentile(percentile as f32 / 100.0, true)
92        );
93    }
94    println!("\tMax Value: {:.6}", pool.max_value());
95}
96
/// A pass/fail criterion applied to the FLIP error map between a test and reference image.
///
/// The FLIP library generates a per-pixel error map where 0.0 represents "no error"
/// and 1.0 represents "maximum error" between the images. This is then put into
/// a weighted histogram, which we query to determine whether the errors between
/// the test and reference images are high enough to count as "different".
///
/// Error thresholds will differ for every test, but good initial values
/// to look at are in the [0.01, 0.1] range. The larger the area that might have
/// inherent variance, the larger this base value should be. Using a high-percentile
/// comparison (e.g. 95% or 99%) is good for images that are likely to have a lot of
/// error in a small area when they fail.
#[derive(Debug, Clone, Copy)]
pub enum ComparisonType {
    /// Fails the test if the mean error is greater than the given value.
    Mean(f32),
    /// Fails the test if the error at `percentile` is greater than `threshold`.
    Percentile {
        /// The percentile to query, in the range [0, 1] (e.g. `0.99` for the 99th).
        percentile: f32,
        /// The maximum allowed error at that percentile.
        threshold: f32,
    },
}
116
117impl ComparisonType {
118    #[cfg(not(any(target_arch = "wasm32", miri)))]
119    fn check(&self, pool: &mut nv_flip::FlipPool) -> bool {
120        match *self {
121            ComparisonType::Mean(v) => {
122                let mean = pool.mean();
123                let within = mean <= v;
124                println!(
125                    "\tExpected Mean ({:.6}) to be under expected maximum ({}): {}",
126                    mean,
127                    v,
128                    if within { "PASS" } else { "FAIL" }
129                );
130                within
131            }
132            ComparisonType::Percentile {
133                percentile: p,
134                threshold: v,
135            } => {
136                let percentile = pool.get_percentile(p, true);
137                let within = percentile <= v;
138                println!(
139                    "\tExpected {}% ({:.6}) to be under expected maximum ({}): {}",
140                    p * 100.0,
141                    percentile,
142                    v,
143                    if within { "PASS" } else { "FAIL" }
144                );
145                within
146            }
147        }
148    }
149}
150
151#[cfg(not(any(target_arch = "wasm32", miri)))]
152pub async fn compare_image_output(
153    path: impl AsRef<Path> + AsRef<OsStr>,
154    adapter_info: &wgpu::AdapterInfo,
155    width: u32,
156    height: u32,
157    test_with_alpha: &[u8],
158    checks: &[ComparisonType],
159) {
160    use std::{ffi::OsString, str::FromStr};
161
162    let reference_path = Path::new(&path);
163    let reference_with_alpha = read_png(&path, width, height).await;
164
165    let reference = match reference_with_alpha {
166        Some(v) => remove_alpha(&v),
167        None => {
168            write_png(
169                &path,
170                width,
171                height,
172                test_with_alpha,
173                png::Compression::Best,
174            )
175            .await;
176            return;
177        }
178    };
179    let test = remove_alpha(test_with_alpha);
180
181    assert_eq!(reference.len(), test.len());
182
183    let file_stem = reference_path.file_stem().unwrap().to_string_lossy();
184    let renderer = format!(
185        "{}-{}-{}",
186        adapter_info.backend,
187        sanitize_for_path(&adapter_info.name),
188        sanitize_for_path(&adapter_info.driver)
189    );
190    // Determine the paths to write out the various intermediate files
191    let actual_path = Path::new(&path)
192        .with_file_name(OsString::from_str(&format!("{file_stem}-{renderer}-actual.png")).unwrap());
193    let difference_path = Path::new(&path).with_file_name(
194        OsString::from_str(&format!("{file_stem}-{renderer}-difference.png",)).unwrap(),
195    );
196
197    let mut all_passed;
198    let magma_image_with_alpha;
199    {
200        let reference_flip = nv_flip::FlipImageRgb8::with_data(width, height, &reference);
201        let test_flip = nv_flip::FlipImageRgb8::with_data(width, height, &test);
202
203        let error_map_flip = nv_flip::flip(
204            reference_flip,
205            test_flip,
206            nv_flip::DEFAULT_PIXELS_PER_DEGREE,
207        );
208        let mut pool = nv_flip::FlipPool::from_image(&error_map_flip);
209
210        println!(
211            "Starting image comparison test with reference image \"{}\"",
212            reference_path.display()
213        );
214
215        print_flip(&mut pool);
216
217        // If there are no checks, we want to fail the test.
218        all_passed = !checks.is_empty();
219        // We always iterate all of these, as the call to check prints
220        for check in checks {
221            all_passed &= check.check(&mut pool);
222        }
223
224        // Convert the error values to a false color representation
225        let magma_image = error_map_flip
226            .apply_color_lut(&nv_flip::magma_lut())
227            .to_vec();
228        magma_image_with_alpha = add_alpha(&magma_image);
229    }
230
231    write_png(
232        actual_path,
233        width,
234        height,
235        test_with_alpha,
236        png::Compression::Fast,
237    )
238    .await;
239    write_png(
240        &difference_path,
241        width,
242        height,
243        &magma_image_with_alpha,
244        png::Compression::Fast,
245    )
246    .await;
247
248    if !all_passed {
249        panic!("Image data mismatch: {}", difference_path.display())
250    }
251}
252
/// No-op stand-in for image comparison on wasm32 and under Miri.
///
/// Accepts the same arguments as the native version and checks nothing.
#[cfg(any(target_arch = "wasm32", miri))]
pub async fn compare_image_output(
    path: impl AsRef<Path> + AsRef<OsStr>,
    adapter_info: &wgpu::AdapterInfo,
    width: u32,
    height: u32,
    test_with_alpha: &[u8],
    checks: &[ComparisonType],
) {
    // The function itself is already gated on this cfg, so no inner gate is needed;
    // this only silences unused-argument lints.
    let _ = (path, adapter_info, width, height, test_with_alpha, checks);
}
267
/// Makes `s` safe to embed in a file name.
///
/// Every character other than an ASCII letter or digit is replaced with `_`.
#[cfg_attr(any(target_arch = "wasm32", miri), allow(unused))]
fn sanitize_for_path(s: &str) -> String {
    let mut sanitized = String::with_capacity(s.len());
    for c in s.chars() {
        sanitized.push(if c.is_ascii_alphanumeric() { c } else { '_' });
    }
    sanitized
}
274
275fn copy_via_compute(
276    device: &Device,
277    encoder: &mut CommandEncoder,
278    texture: &Texture,
279    buffer: &Buffer,
280    aspect: TextureAspect,
281) {
282    let bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
283        label: None,
284        entries: &[
285            BindGroupLayoutEntry {
286                binding: 0,
287                visibility: ShaderStages::COMPUTE,
288                ty: BindingType::Texture {
289                    sample_type: match aspect {
290                        TextureAspect::DepthOnly => TextureSampleType::Float { filterable: false },
291                        TextureAspect::StencilOnly => TextureSampleType::Uint,
292                        _ => unreachable!(),
293                    },
294                    view_dimension: TextureViewDimension::D2Array,
295                    multisampled: false,
296                },
297                count: None,
298            },
299            BindGroupLayoutEntry {
300                binding: 1,
301                visibility: ShaderStages::COMPUTE,
302                ty: BindingType::Buffer {
303                    ty: BufferBindingType::Storage { read_only: false },
304                    has_dynamic_offset: false,
305                    min_binding_size: None,
306                },
307                count: None,
308            },
309        ],
310    });
311
312    let view = texture.create_view(&TextureViewDescriptor {
313        aspect,
314        dimension: Some(TextureViewDimension::D2Array),
315        ..Default::default()
316    });
317
318    let output_buffer = device.create_buffer(&BufferDescriptor {
319        label: Some("output buffer"),
320        size: buffer.size(),
321        usage: BufferUsages::COPY_SRC | BufferUsages::STORAGE,
322        mapped_at_creation: false,
323    });
324
325    let bg = device.create_bind_group(&BindGroupDescriptor {
326        label: None,
327        layout: &bgl,
328        entries: &[
329            BindGroupEntry {
330                binding: 0,
331                resource: BindingResource::TextureView(&view),
332            },
333            BindGroupEntry {
334                binding: 1,
335                resource: BindingResource::Buffer(BufferBinding {
336                    buffer: &output_buffer,
337                    offset: 0,
338                    size: None,
339                }),
340            },
341        ],
342    });
343
344    let pll = device.create_pipeline_layout(&PipelineLayoutDescriptor {
345        label: None,
346        bind_group_layouts: &[&bgl],
347        push_constant_ranges: &[],
348    });
349
350    let source = String::from(include_str!("copy_texture_to_buffer.wgsl"));
351
352    let processed_source = source.replace(
353        "{{type}}",
354        match aspect {
355            TextureAspect::DepthOnly => "f32",
356            TextureAspect::StencilOnly => "u32",
357            _ => unreachable!(),
358        },
359    );
360
361    let sm = device.create_shader_module(ShaderModuleDescriptor {
362        label: Some("shader copy_texture_to_buffer.wgsl"),
363        source: ShaderSource::Wgsl(Cow::Borrowed(&processed_source)),
364    });
365
366    let pipeline_copy = device.create_compute_pipeline(&ComputePipelineDescriptor {
367        label: Some("pipeline read"),
368        layout: Some(&pll),
369        module: &sm,
370        entry_point: Some("copy_texture_to_buffer"),
371        compilation_options: Default::default(),
372        cache: None,
373    });
374
375    {
376        let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
377
378        pass.set_pipeline(&pipeline_copy);
379        pass.set_bind_group(0, &bg, &[]);
380        pass.dispatch_workgroups(1, 1, 1);
381    }
382
383    encoder.copy_buffer_to_buffer(&output_buffer, 0, buffer, 0, buffer.size());
384}
385
386fn copy_texture_to_buffer_with_aspect(
387    encoder: &mut CommandEncoder,
388    texture: &Texture,
389    buffer: &Buffer,
390    buffer_stencil: &Option<Buffer>,
391    aspect: TextureAspect,
392) {
393    let (block_width, block_height) = texture.format().block_dimensions();
394    let block_size = texture.format().block_copy_size(Some(aspect)).unwrap();
395    let bytes_per_row = align_to(
396        (texture.width() / block_width) * block_size,
397        COPY_BYTES_PER_ROW_ALIGNMENT,
398    );
399    let mip_level = 0;
400    encoder.copy_texture_to_buffer(
401        TexelCopyTextureInfo {
402            texture,
403            mip_level,
404            origin: Origin3d::ZERO,
405            aspect,
406        },
407        TexelCopyBufferInfo {
408            buffer: match aspect {
409                TextureAspect::StencilOnly => buffer_stencil.as_ref().unwrap(),
410                _ => buffer,
411            },
412            layout: TexelCopyBufferLayout {
413                offset: 0,
414                bytes_per_row: Some(bytes_per_row),
415                rows_per_image: Some(texture.height() / block_height),
416            },
417        },
418        texture
419            .size()
420            .mip_level_size(mip_level, texture.dimension()),
421    );
422}
423
424fn copy_texture_to_buffer(
425    device: &Device,
426    encoder: &mut CommandEncoder,
427    texture: &Texture,
428    buffer: &Buffer,
429    buffer_stencil: &Option<Buffer>,
430) {
431    match texture.format() {
432        TextureFormat::Depth24Plus => {
433            copy_via_compute(device, encoder, texture, buffer, TextureAspect::DepthOnly);
434        }
435        TextureFormat::Depth24PlusStencil8 => {
436            copy_via_compute(device, encoder, texture, buffer, TextureAspect::DepthOnly);
437            copy_texture_to_buffer_with_aspect(
438                encoder,
439                texture,
440                buffer,
441                buffer_stencil,
442                TextureAspect::StencilOnly,
443            );
444        }
445        TextureFormat::Depth32FloatStencil8 => {
446            copy_texture_to_buffer_with_aspect(
447                encoder,
448                texture,
449                buffer,
450                buffer_stencil,
451                TextureAspect::DepthOnly,
452            );
453            copy_texture_to_buffer_with_aspect(
454                encoder,
455                texture,
456                buffer,
457                buffer_stencil,
458                TextureAspect::StencilOnly,
459            );
460        }
461        _ => {
462            copy_texture_to_buffer_with_aspect(
463                encoder,
464                texture,
465                buffer,
466                buffer_stencil,
467                TextureAspect::All,
468            );
469        }
470    }
471}
472
473pub struct ReadbackBuffers {
474    /// texture format
475    texture_format: TextureFormat,
476    /// texture width
477    texture_width: u32,
478    /// texture height
479    texture_height: u32,
480    /// texture depth or array layer count
481    texture_depth_or_array_layers: u32,
482    /// buffer for color or depth aspects
483    buffer: Buffer,
484    /// buffer for stencil aspect
485    buffer_stencil: Option<Buffer>,
486}
487
488impl ReadbackBuffers {
489    pub fn new(device: &Device, texture: &Texture) -> Self {
490        let (block_width, block_height) = texture.format().block_dimensions();
491        const SKIP_ALIGNMENT_FORMATS: [TextureFormat; 2] = [
492            TextureFormat::Depth24Plus,
493            TextureFormat::Depth24PlusStencil8,
494        ];
495        let should_align_buffer_size = !SKIP_ALIGNMENT_FORMATS.contains(&texture.format());
496        if texture.format().is_combined_depth_stencil_format() {
497            let mut buffer_depth_bytes_per_row = (texture.width() / block_width)
498                * texture
499                    .format()
500                    .block_copy_size(Some(TextureAspect::DepthOnly))
501                    .unwrap_or(4);
502            if should_align_buffer_size {
503                buffer_depth_bytes_per_row =
504                    align_to(buffer_depth_bytes_per_row, COPY_BYTES_PER_ROW_ALIGNMENT);
505            }
506            let buffer_size = buffer_depth_bytes_per_row
507                * (texture.height() / block_height)
508                * texture.depth_or_array_layers();
509
510            let buffer_stencil_bytes_per_row = align_to(
511                (texture.width() / block_width)
512                    * texture
513                        .format()
514                        .block_copy_size(Some(TextureAspect::StencilOnly))
515                        .unwrap_or(4),
516                COPY_BYTES_PER_ROW_ALIGNMENT,
517            );
518            let buffer_stencil_size = buffer_stencil_bytes_per_row
519                * (texture.height() / block_height)
520                * texture.depth_or_array_layers();
521
522            let buffer = device.create_buffer_init(&util::BufferInitDescriptor {
523                label: Some("Texture Readback"),
524                usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
525                contents: &vec![255; buffer_size as usize],
526            });
527            let buffer_stencil = device.create_buffer_init(&util::BufferInitDescriptor {
528                label: Some("Texture Stencil-Aspect Readback"),
529                usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
530                contents: &vec![255; buffer_stencil_size as usize],
531            });
532            ReadbackBuffers {
533                texture_format: texture.format(),
534                texture_width: texture.width(),
535                texture_height: texture.height(),
536                texture_depth_or_array_layers: texture.depth_or_array_layers(),
537                buffer,
538                buffer_stencil: Some(buffer_stencil),
539            }
540        } else {
541            let mut bytes_per_row = (texture.width() / block_width)
542                * texture.format().block_copy_size(None).unwrap_or(4);
543            if should_align_buffer_size {
544                bytes_per_row = align_to(bytes_per_row, COPY_BYTES_PER_ROW_ALIGNMENT);
545            }
546            let buffer_size =
547                bytes_per_row * (texture.height() / block_height) * texture.depth_or_array_layers();
548            let buffer = device.create_buffer_init(&util::BufferInitDescriptor {
549                label: Some("Texture Readback"),
550                usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
551                contents: &vec![255; buffer_size as usize],
552            });
553            ReadbackBuffers {
554                texture_format: texture.format(),
555                texture_width: texture.width(),
556                texture_height: texture.height(),
557                texture_depth_or_array_layers: texture.depth_or_array_layers(),
558                buffer,
559                buffer_stencil: None,
560            }
561        }
562    }
563
564    // TODO: also copy and check mips
565    pub fn copy_from(&self, device: &Device, encoder: &mut CommandEncoder, texture: &Texture) {
566        copy_texture_to_buffer(device, encoder, texture, &self.buffer, &self.buffer_stencil);
567    }
568
569    async fn retrieve_buffer(
570        &self,
571        ctx: &TestingContext,
572        buffer: &Buffer,
573        aspect: Option<TextureAspect>,
574    ) -> Vec<u8> {
575        let buffer_slice = buffer.slice(..);
576        buffer_slice.map_async(MapMode::Read, |_| ());
577        ctx.async_poll(PollType::wait()).await.unwrap();
578        let (block_width, block_height) = self.texture_format.block_dimensions();
579        let expected_bytes_per_row = (self.texture_width / block_width)
580            * self.texture_format.block_copy_size(aspect).unwrap_or(4);
581        let expected_buffer_size = expected_bytes_per_row
582            * (self.texture_height / block_height)
583            * self.texture_depth_or_array_layers;
584        let data: BufferView = buffer_slice.get_mapped_range();
585        if expected_buffer_size as usize == data.len() {
586            data.to_vec()
587        } else {
588            bytemuck::cast_slice(&data)
589                .chunks_exact(
590                    align_to(expected_bytes_per_row, COPY_BYTES_PER_ROW_ALIGNMENT) as usize,
591                )
592                .flat_map(|x| x.iter().take(expected_bytes_per_row as usize))
593                .copied()
594                .collect()
595        }
596    }
597
598    fn buffer_aspect(&self) -> Option<TextureAspect> {
599        if self.texture_format.is_combined_depth_stencil_format() {
600            Some(TextureAspect::DepthOnly)
601        } else {
602            None
603        }
604    }
605
606    async fn is_zero(
607        &self,
608        ctx: &TestingContext,
609        buffer: &Buffer,
610        aspect: Option<TextureAspect>,
611    ) -> bool {
612        let is_zero = self
613            .retrieve_buffer(ctx, buffer, aspect)
614            .await
615            .iter()
616            .all(|b| *b == 0);
617        buffer.unmap();
618        is_zero
619    }
620
621    pub async fn are_zero(&self, ctx: &TestingContext) -> bool {
622        let buffer_zero = self.is_zero(ctx, &self.buffer, self.buffer_aspect()).await;
623        let mut stencil_buffer_zero = true;
624        if let Some(buffer) = &self.buffer_stencil {
625            stencil_buffer_zero = self
626                .is_zero(ctx, buffer, Some(TextureAspect::StencilOnly))
627                .await;
628        };
629        buffer_zero && stencil_buffer_zero
630    }
631
632    pub async fn assert_buffer_contents(&self, ctx: &TestingContext, expected_data: &[u8]) {
633        let result_buffer = self
634            .retrieve_buffer(ctx, &self.buffer, self.buffer_aspect())
635            .await;
636        assert!(
637            result_buffer.len() >= expected_data.len(),
638            "Result buffer ({}) smaller than expected buffer ({})",
639            result_buffer.len(),
640            expected_data.len()
641        );
642        let result_buffer = &result_buffer[..expected_data.len()];
643        assert_eq!(result_buffer, expected_data);
644        self.buffer.unmap();
645    }
646}