wgpu_test/
image.rs

1//! Image comparison utilities
2
3use std::{borrow::Cow, ffi::OsStr, path::Path};
4
5use wgpu::util::{align_to, DeviceExt};
6use wgpu::*;
7
8use crate::TestingContext;
9
10#[cfg(not(any(target_arch = "wasm32", miri)))]
11async fn read_png(path: impl AsRef<Path>, width: u32, height: u32) -> Option<Vec<u8>> {
12    let data = match std::fs::read(&path) {
13        Ok(f) => f,
14        Err(e) => {
15            log::warn!(
16                "image comparison invalid: file io error when comparing {}: {}",
17                path.as_ref().display(),
18                e
19            );
20            return None;
21        }
22    };
23    let decoder = png::Decoder::new(std::io::Cursor::new(data));
24    let mut reader = decoder.read_info().ok()?;
25
26    let buffer_len = reader
27        .output_buffer_size()
28        .expect("output buffer would not fit in memory");
29    let mut buffer = vec![0; buffer_len];
30    let info = reader.next_frame(&mut buffer).ok()?;
31    if info.width != width {
32        log::warn!("image comparison invalid: size mismatch");
33        return None;
34    }
35    if info.height != height {
36        log::warn!("image comparison invalid: size mismatch");
37        return None;
38    }
39    if info.color_type != png::ColorType::Rgba {
40        log::warn!("image comparison invalid: color type mismatch");
41        return None;
42    }
43    if info.bit_depth != png::BitDepth::Eight {
44        log::warn!("image comparison invalid: bit depth mismatch");
45        return None;
46    }
47
48    Some(buffer)
49}
50
51#[cfg(not(any(target_arch = "wasm32", miri)))]
52async fn write_png(
53    path: impl AsRef<Path>,
54    width: u32,
55    height: u32,
56    data: &[u8],
57    compression: png::Compression,
58) {
59    let file = std::io::BufWriter::new(std::fs::File::create(path).unwrap());
60
61    let mut encoder = png::Encoder::new(file, width, height);
62    encoder.set_color(png::ColorType::Rgba);
63    encoder.set_depth(png::BitDepth::Eight);
64    encoder.set_compression(compression);
65    let mut writer = encoder.write_header().unwrap();
66
67    writer.write_image_data(data).unwrap();
68}
69
/// Expands tightly packed RGB8 pixels to RGBA8 with a fully opaque alpha channel.
///
/// Trailing bytes that do not form a whole RGB triple are ignored.
#[cfg_attr(any(target_arch = "wasm32", miri), allow(unused))]
fn add_alpha(input: &[u8]) -> Vec<u8> {
    let mut rgba = Vec::with_capacity(input.len() / 3 * 4);
    for rgb in input.chunks_exact(3) {
        rgba.extend_from_slice(rgb);
        rgba.push(u8::MAX);
    }
    rgba
}
77
/// Drops the alpha channel of tightly packed RGBA8 pixels, producing RGB8.
///
/// Trailing bytes that do not form a whole RGBA pixel are ignored.
#[cfg_attr(any(target_arch = "wasm32", miri), allow(unused))]
fn remove_alpha(input: &[u8]) -> Vec<u8> {
    let mut rgb = Vec::with_capacity(input.len() / 4 * 3);
    for pixel in input.chunks_exact(4) {
        rgb.extend_from_slice(&pixel[..3]);
    }
    rgb
}
86
87#[cfg(not(any(target_arch = "wasm32", miri)))]
88fn print_flip(pool: &mut nv_flip::FlipPool) {
89    println!("\tMean: {:.6}", pool.mean());
90    println!("\tMin Value: {:.6}", pool.min_value());
91    for percentile in [25, 50, 75, 95, 99] {
92        println!(
93            "\t      {percentile}%: {:.6}",
94            pool.get_percentile(percentile as f32 / 100.0, true)
95        );
96    }
97    println!("\tMax Value: {:.6}", pool.max_value());
98}
99
/// The FLIP library generates a per-pixel error map where 0.0 represents "no error"
/// and 1.0 represents "maximum error" between the images. This is then put into
/// a weighted histogram, which we query to determine whether the errors between
/// the test and reference images are high enough to count as "different".
///
/// Error thresholds will be different for every test, but good initial values
/// to look at are in the [0.01, 0.1] range. The larger the area that might have
/// inherent variance, the larger this base value is. Using a high percentile comparison
/// (e.g. 95% or 99%) is good for images that are likely to have a lot of error
/// in a small area when they fail.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ComparisonType {
    /// If the mean error is greater than the given value, the test will fail.
    Mean(f32),
    /// If the error value at `percentile` is greater than `threshold`, the test will fail.
    Percentile {
        /// Which percentile of the error distribution to inspect, in the range [0, 1].
        percentile: f32,
        /// Maximum allowed error value at that percentile.
        threshold: f32,
    },
}
119
120impl ComparisonType {
121    #[cfg(not(any(target_arch = "wasm32", miri)))]
122    fn check(&self, pool: &mut nv_flip::FlipPool) -> bool {
123        match *self {
124            ComparisonType::Mean(v) => {
125                let mean = pool.mean();
126                let within = mean <= v;
127                println!(
128                    "\tExpected Mean ({:.6}) to be under expected maximum ({}): {}",
129                    mean,
130                    v,
131                    if within { "PASS" } else { "FAIL" }
132                );
133                within
134            }
135            ComparisonType::Percentile {
136                percentile: p,
137                threshold: v,
138            } => {
139                let percentile = pool.get_percentile(p, true);
140                let within = percentile <= v;
141                println!(
142                    "\tExpected {}% ({:.6}) to be under expected maximum ({}): {}",
143                    p * 100.0,
144                    percentile,
145                    v,
146                    if within { "PASS" } else { "FAIL" }
147                );
148                within
149            }
150        }
151    }
152}
153
154#[cfg(not(any(target_arch = "wasm32", miri)))]
155pub async fn compare_image_output(
156    path: impl AsRef<Path> + AsRef<OsStr>,
157    adapter_info: &wgpu::AdapterInfo,
158    width: u32,
159    height: u32,
160    test_with_alpha: &[u8],
161    checks: &[ComparisonType],
162) {
163    use std::{ffi::OsString, str::FromStr};
164
165    let reference_path = Path::new(&path);
166    let reference_with_alpha = read_png(&path, width, height).await;
167
168    let reference = match reference_with_alpha {
169        Some(v) => remove_alpha(&v),
170        None => {
171            write_png(
172                &path,
173                width,
174                height,
175                test_with_alpha,
176                png::Compression::High,
177            )
178            .await;
179            return;
180        }
181    };
182    let test = remove_alpha(test_with_alpha);
183
184    assert_eq!(reference.len(), test.len());
185
186    let file_stem = reference_path.file_stem().unwrap().to_string_lossy();
187    let renderer = format!(
188        "{}-{}-{}",
189        adapter_info.backend,
190        sanitize_for_path(&adapter_info.name),
191        sanitize_for_path(&adapter_info.driver)
192    );
193    // Determine the paths to write out the various intermediate files
194    let actual_path = Path::new(&path)
195        .with_file_name(OsString::from_str(&format!("{file_stem}-{renderer}-actual.png")).unwrap());
196    let difference_path = Path::new(&path).with_file_name(
197        OsString::from_str(&format!("{file_stem}-{renderer}-difference.png",)).unwrap(),
198    );
199
200    let mut all_passed;
201    let magma_image_with_alpha;
202    {
203        let reference_flip = nv_flip::FlipImageRgb8::with_data(width, height, &reference);
204        let test_flip = nv_flip::FlipImageRgb8::with_data(width, height, &test);
205
206        let error_map_flip = nv_flip::flip(
207            reference_flip,
208            test_flip,
209            nv_flip::DEFAULT_PIXELS_PER_DEGREE,
210        );
211        let mut pool = nv_flip::FlipPool::from_image(&error_map_flip);
212
213        println!(
214            "Starting image comparison test with reference image \"{}\"",
215            reference_path.display()
216        );
217
218        print_flip(&mut pool);
219
220        // If there are no checks, we want to fail the test.
221        all_passed = !checks.is_empty();
222        // We always iterate all of these, as the call to check prints
223        for check in checks {
224            all_passed &= check.check(&mut pool);
225        }
226
227        // Convert the error values to a false color representation
228        let magma_image = error_map_flip
229            .apply_color_lut(&nv_flip::magma_lut())
230            .to_vec();
231        magma_image_with_alpha = add_alpha(&magma_image);
232    }
233
234    write_png(
235        actual_path,
236        width,
237        height,
238        test_with_alpha,
239        png::Compression::Fast,
240    )
241    .await;
242    write_png(
243        &difference_path,
244        width,
245        height,
246        &magma_image_with_alpha,
247        png::Compression::Fast,
248    )
249    .await;
250
251    if !all_passed {
252        panic!("Image data mismatch: {}", difference_path.display())
253    }
254}
255
/// No-op stand-in for image comparison on wasm32 and under miri.
///
/// The native implementation and its PNG helpers are compiled out for these
/// targets, so this variant only consumes its arguments to avoid unused warnings.
#[cfg(any(target_arch = "wasm32", miri))]
pub async fn compare_image_output(
    path: impl AsRef<Path> + AsRef<OsStr>,
    adapter_info: &wgpu::AdapterInfo,
    width: u32,
    height: u32,
    test_with_alpha: &[u8],
    checks: &[ComparisonType],
) {
    let _ = (path, adapter_info, width, height, test_with_alpha, checks);
}
270
/// Replaces every character that is not an ASCII letter or digit with `_`,
/// making `s` safe to embed in a file name.
#[cfg_attr(any(target_arch = "wasm32", miri), allow(unused))]
fn sanitize_for_path(s: &str) -> String {
    let mut sanitized = String::with_capacity(s.len());
    for ch in s.chars() {
        sanitized.push(match ch {
            'a'..='z' | 'A'..='Z' | '0'..='9' => ch,
            _ => '_',
        });
    }
    sanitized
}
277
278fn copy_via_compute(
279    device: &Device,
280    encoder: &mut CommandEncoder,
281    texture: &Texture,
282    buffer: &Buffer,
283    aspect: TextureAspect,
284) {
285    let bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
286        label: None,
287        entries: &[
288            BindGroupLayoutEntry {
289                binding: 0,
290                visibility: ShaderStages::COMPUTE,
291                ty: BindingType::Texture {
292                    sample_type: match aspect {
293                        TextureAspect::DepthOnly => TextureSampleType::Float { filterable: false },
294                        TextureAspect::StencilOnly => TextureSampleType::Uint,
295                        _ => unreachable!(),
296                    },
297                    view_dimension: TextureViewDimension::D2Array,
298                    multisampled: false,
299                },
300                count: None,
301            },
302            BindGroupLayoutEntry {
303                binding: 1,
304                visibility: ShaderStages::COMPUTE,
305                ty: BindingType::Buffer {
306                    ty: BufferBindingType::Storage { read_only: false },
307                    has_dynamic_offset: false,
308                    min_binding_size: None,
309                },
310                count: None,
311            },
312        ],
313    });
314
315    let view = texture.create_view(&TextureViewDescriptor {
316        aspect,
317        dimension: Some(TextureViewDimension::D2Array),
318        ..Default::default()
319    });
320
321    let output_buffer = device.create_buffer(&BufferDescriptor {
322        label: Some("output buffer"),
323        size: buffer.size(),
324        usage: BufferUsages::COPY_SRC | BufferUsages::STORAGE,
325        mapped_at_creation: false,
326    });
327
328    let bg = device.create_bind_group(&BindGroupDescriptor {
329        label: None,
330        layout: &bgl,
331        entries: &[
332            BindGroupEntry {
333                binding: 0,
334                resource: BindingResource::TextureView(&view),
335            },
336            BindGroupEntry {
337                binding: 1,
338                resource: BindingResource::Buffer(BufferBinding {
339                    buffer: &output_buffer,
340                    offset: 0,
341                    size: None,
342                }),
343            },
344        ],
345    });
346
347    let pll = device.create_pipeline_layout(&PipelineLayoutDescriptor {
348        label: None,
349        bind_group_layouts: &[&bgl],
350        push_constant_ranges: &[],
351    });
352
353    let source = String::from(include_str!("copy_texture_to_buffer.wgsl"));
354
355    let processed_source = source.replace(
356        "{{type}}",
357        match aspect {
358            TextureAspect::DepthOnly => "f32",
359            TextureAspect::StencilOnly => "u32",
360            _ => unreachable!(),
361        },
362    );
363
364    let sm = device.create_shader_module(ShaderModuleDescriptor {
365        label: Some("shader copy_texture_to_buffer.wgsl"),
366        source: ShaderSource::Wgsl(Cow::Borrowed(&processed_source)),
367    });
368
369    let pipeline_copy = device.create_compute_pipeline(&ComputePipelineDescriptor {
370        label: Some("pipeline read"),
371        layout: Some(&pll),
372        module: &sm,
373        entry_point: Some("copy_texture_to_buffer"),
374        compilation_options: Default::default(),
375        cache: None,
376    });
377
378    {
379        let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
380
381        pass.set_pipeline(&pipeline_copy);
382        pass.set_bind_group(0, &bg, &[]);
383        pass.dispatch_workgroups(1, 1, 1);
384    }
385
386    encoder.copy_buffer_to_buffer(&output_buffer, 0, buffer, 0, buffer.size());
387}
388
389fn copy_texture_to_buffer_with_aspect(
390    encoder: &mut CommandEncoder,
391    texture: &Texture,
392    buffer: &Buffer,
393    buffer_stencil: &Option<Buffer>,
394    aspect: TextureAspect,
395) {
396    let (block_width, block_height) = texture.format().block_dimensions();
397    let block_size = texture.format().block_copy_size(Some(aspect)).unwrap();
398    let bytes_per_row = align_to(
399        (texture.width() / block_width) * block_size,
400        COPY_BYTES_PER_ROW_ALIGNMENT,
401    );
402    let mip_level = 0;
403    encoder.copy_texture_to_buffer(
404        TexelCopyTextureInfo {
405            texture,
406            mip_level,
407            origin: Origin3d::ZERO,
408            aspect,
409        },
410        TexelCopyBufferInfo {
411            buffer: match aspect {
412                TextureAspect::StencilOnly => buffer_stencil.as_ref().unwrap(),
413                _ => buffer,
414            },
415            layout: TexelCopyBufferLayout {
416                offset: 0,
417                bytes_per_row: Some(bytes_per_row),
418                rows_per_image: Some(texture.height() / block_height),
419            },
420        },
421        texture
422            .size()
423            .mip_level_size(mip_level, texture.dimension()),
424    );
425}
426
427fn copy_texture_to_buffer(
428    device: &Device,
429    encoder: &mut CommandEncoder,
430    texture: &Texture,
431    buffer: &Buffer,
432    buffer_stencil: &Option<Buffer>,
433) {
434    match texture.format() {
435        TextureFormat::Depth24Plus => {
436            copy_via_compute(device, encoder, texture, buffer, TextureAspect::DepthOnly);
437        }
438        TextureFormat::Depth24PlusStencil8 => {
439            copy_via_compute(device, encoder, texture, buffer, TextureAspect::DepthOnly);
440            copy_texture_to_buffer_with_aspect(
441                encoder,
442                texture,
443                buffer,
444                buffer_stencil,
445                TextureAspect::StencilOnly,
446            );
447        }
448        TextureFormat::Depth32FloatStencil8 => {
449            copy_texture_to_buffer_with_aspect(
450                encoder,
451                texture,
452                buffer,
453                buffer_stencil,
454                TextureAspect::DepthOnly,
455            );
456            copy_texture_to_buffer_with_aspect(
457                encoder,
458                texture,
459                buffer,
460                buffer_stencil,
461                TextureAspect::StencilOnly,
462            );
463        }
464        _ => {
465            copy_texture_to_buffer_with_aspect(
466                encoder,
467                texture,
468                buffer,
469                buffer_stencil,
470                TextureAspect::All,
471            );
472        }
473    }
474}
475
476pub struct ReadbackBuffers {
477    /// texture format
478    texture_format: TextureFormat,
479    /// texture width
480    texture_width: u32,
481    /// texture height
482    texture_height: u32,
483    /// texture depth or array layer count
484    texture_depth_or_array_layers: u32,
485    /// buffer for color or depth aspects
486    buffer: Buffer,
487    /// buffer for stencil aspect
488    buffer_stencil: Option<Buffer>,
489}
490
491impl ReadbackBuffers {
492    pub fn new(device: &Device, texture: &Texture) -> Self {
493        let (block_width, block_height) = texture.format().block_dimensions();
494        const SKIP_ALIGNMENT_FORMATS: [TextureFormat; 2] = [
495            TextureFormat::Depth24Plus,
496            TextureFormat::Depth24PlusStencil8,
497        ];
498        let should_align_buffer_size = !SKIP_ALIGNMENT_FORMATS.contains(&texture.format());
499        if texture.format().is_combined_depth_stencil_format() {
500            let mut buffer_depth_bytes_per_row = (texture.width() / block_width)
501                * texture
502                    .format()
503                    .block_copy_size(Some(TextureAspect::DepthOnly))
504                    .unwrap_or(4);
505            if should_align_buffer_size {
506                buffer_depth_bytes_per_row =
507                    align_to(buffer_depth_bytes_per_row, COPY_BYTES_PER_ROW_ALIGNMENT);
508            }
509            let buffer_size = buffer_depth_bytes_per_row
510                * (texture.height() / block_height)
511                * texture.depth_or_array_layers();
512
513            let buffer_stencil_bytes_per_row = align_to(
514                (texture.width() / block_width)
515                    * texture
516                        .format()
517                        .block_copy_size(Some(TextureAspect::StencilOnly))
518                        .unwrap_or(4),
519                COPY_BYTES_PER_ROW_ALIGNMENT,
520            );
521            let buffer_stencil_size = buffer_stencil_bytes_per_row
522                * (texture.height() / block_height)
523                * texture.depth_or_array_layers();
524
525            let buffer = device.create_buffer_init(&util::BufferInitDescriptor {
526                label: Some("Texture Readback"),
527                usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
528                contents: &vec![255; buffer_size as usize],
529            });
530            let buffer_stencil = device.create_buffer_init(&util::BufferInitDescriptor {
531                label: Some("Texture Stencil-Aspect Readback"),
532                usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
533                contents: &vec![255; buffer_stencil_size as usize],
534            });
535            ReadbackBuffers {
536                texture_format: texture.format(),
537                texture_width: texture.width(),
538                texture_height: texture.height(),
539                texture_depth_or_array_layers: texture.depth_or_array_layers(),
540                buffer,
541                buffer_stencil: Some(buffer_stencil),
542            }
543        } else {
544            let mut bytes_per_row = (texture.width() / block_width)
545                * texture.format().block_copy_size(None).unwrap_or(4);
546            if should_align_buffer_size {
547                bytes_per_row = align_to(bytes_per_row, COPY_BYTES_PER_ROW_ALIGNMENT);
548            }
549            let buffer_size =
550                bytes_per_row * (texture.height() / block_height) * texture.depth_or_array_layers();
551            let buffer = device.create_buffer_init(&util::BufferInitDescriptor {
552                label: Some("Texture Readback"),
553                usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
554                contents: &vec![255; buffer_size as usize],
555            });
556            ReadbackBuffers {
557                texture_format: texture.format(),
558                texture_width: texture.width(),
559                texture_height: texture.height(),
560                texture_depth_or_array_layers: texture.depth_or_array_layers(),
561                buffer,
562                buffer_stencil: None,
563            }
564        }
565    }
566
567    // TODO: also copy and check mips
568    pub fn copy_from(&self, device: &Device, encoder: &mut CommandEncoder, texture: &Texture) {
569        copy_texture_to_buffer(device, encoder, texture, &self.buffer, &self.buffer_stencil);
570    }
571
572    async fn retrieve_buffer(
573        &self,
574        ctx: &TestingContext,
575        buffer: &Buffer,
576        aspect: Option<TextureAspect>,
577    ) -> Vec<u8> {
578        let buffer_slice = buffer.slice(..);
579        buffer_slice.map_async(MapMode::Read, |_| ());
580        ctx.async_poll(PollType::wait()).await.unwrap();
581        let (block_width, block_height) = self.texture_format.block_dimensions();
582        let expected_bytes_per_row = (self.texture_width / block_width)
583            * self.texture_format.block_copy_size(aspect).unwrap_or(4);
584        let expected_buffer_size = expected_bytes_per_row
585            * (self.texture_height / block_height)
586            * self.texture_depth_or_array_layers;
587        let data: BufferView = buffer_slice.get_mapped_range();
588        if expected_buffer_size as usize == data.len() {
589            data.to_vec()
590        } else {
591            bytemuck::cast_slice(&data)
592                .chunks_exact(
593                    align_to(expected_bytes_per_row, COPY_BYTES_PER_ROW_ALIGNMENT) as usize,
594                )
595                .flat_map(|x| x.iter().take(expected_bytes_per_row as usize))
596                .copied()
597                .collect()
598        }
599    }
600
601    fn buffer_aspect(&self) -> Option<TextureAspect> {
602        if self.texture_format.is_combined_depth_stencil_format() {
603            Some(TextureAspect::DepthOnly)
604        } else {
605            None
606        }
607    }
608
609    async fn is_zero(
610        &self,
611        ctx: &TestingContext,
612        buffer: &Buffer,
613        aspect: Option<TextureAspect>,
614    ) -> bool {
615        let is_zero = self
616            .retrieve_buffer(ctx, buffer, aspect)
617            .await
618            .iter()
619            .all(|b| *b == 0);
620        buffer.unmap();
621        is_zero
622    }
623
624    pub async fn are_zero(&self, ctx: &TestingContext) -> bool {
625        let buffer_zero = self.is_zero(ctx, &self.buffer, self.buffer_aspect()).await;
626        let mut stencil_buffer_zero = true;
627        if let Some(buffer) = &self.buffer_stencil {
628            stencil_buffer_zero = self
629                .is_zero(ctx, buffer, Some(TextureAspect::StencilOnly))
630                .await;
631        };
632        buffer_zero && stencil_buffer_zero
633    }
634
635    pub async fn assert_buffer_contents(&self, ctx: &TestingContext, expected_data: &[u8]) {
636        let result_buffer = self
637            .retrieve_buffer(ctx, &self.buffer, self.buffer_aspect())
638            .await;
639        assert!(
640            result_buffer.len() >= expected_data.len(),
641            "Result buffer ({}) smaller than expected buffer ({})",
642            result_buffer.len(),
643            expected_data.len()
644        );
645        let result_buffer = &result_buffer[..expected_data.len()];
646        assert_eq!(result_buffer, expected_data);
647        self.buffer.unmap();
648    }
649}