1
+ import datetime
2
+ import numpy as np
3
+ import os
4
+ from PIL import Image
5
+ import pytest
6
+ from pytest import fixture
7
+ from typing import Tuple , List
8
+
9
+ from cv2 import imread , cvtColor , COLOR_BGR2RGB
10
+ from skimage .metrics import structural_similarity as ssim
11
+
12
+
13
+ """
14
+ This test suite compares images in 2 directories by file name
15
+ The directories are specified by the command line arguments --baseline_dir and --test_dir
16
+
17
+ """
18
# ssim: Structural Similarity Index
def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
    """Compare two RGB images and return a tuple of (ssim score, diff image).

    The diff image produced by skimage is float-valued in [0, 1]; it is
    rescaled to uint8 [0, 255] so it can be saved as a regular image.
    """
    score, diff = ssim(img0, img1, channel_axis=-1, full=True)
    diff_u8 = (diff * 255).astype("uint8")
    return score, diff_u8
25
+
26
# Registry of comparison metrics. Each metric is a callable taking two
# images and returning a tuple of (score, diff_image).
METRICS = {"ssim": ssim_score}
# Minimum score a comparison must exceed for the test to pass.
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
29
+
30
+
31
class TestCompareImageMetrics:
    @fixture(scope="class")
    def test_file_names(self, args_pytest):
        """Collect the .png basenames from the test directory once per class."""
        names = self.gather_file_basenames(args_pytest['test_dir'])
        yield names
        del names
38
+
39
+ @fixture (scope = "class" , autouse = True )
40
+ def teardown (self , args_pytest ):
41
+ yield
42
+ # Runs after all tests are complete
43
+ # Aggregate output files into a grid of images
44
+ baseline_dir = args_pytest ['baseline_dir' ]
45
+ test_dir = args_pytest ['test_dir' ]
46
+ img_output_dir = args_pytest ['img_output_dir' ]
47
+ metrics_file = args_pytest ['metrics_file' ]
48
+
49
+ grid_dir = os .path .join (img_output_dir , "grid" )
50
+ os .makedirs (grid_dir , exist_ok = True )
51
+
52
+ for metric_dir in METRICS .keys ():
53
+ metric_path = os .path .join (img_output_dir , metric_dir )
54
+ for file in os .listdir (metric_path ):
55
+ if file .endswith (".png" ):
56
+ score = self .lookup_score_from_fname (file , metrics_file )
57
+ image_file_list = []
58
+ image_file_list .append ([
59
+ os .path .join (baseline_dir , file ),
60
+ os .path .join (test_dir , file ),
61
+ os .path .join (metric_path , file )
62
+ ])
63
+ # Create grid
64
+ image_list = [[Image .open (file ) for file in files ] for files in image_file_list ]
65
+ grid = self .image_grid (image_list )
66
+ grid .save (os .path .join (grid_dir , f"{ metric_dir } _{ score :.3f} _{ file } " ))
67
+
68
+ # Tests run for each baseline file name
69
+ @fixture ()
70
+ def fname (self , baseline_fname ):
71
+ yield baseline_fname
72
+ del baseline_fname
73
+
74
+ def test_directories_not_empty (self , args_pytest ):
75
+ baseline_dir = args_pytest ['baseline_dir' ]
76
+ test_dir = args_pytest ['test_dir' ]
77
+ assert len (os .listdir (baseline_dir )) != 0 , f"Baseline directory { baseline_dir } is empty"
78
+ assert len (os .listdir (test_dir )) != 0 , f"Test directory { test_dir } is empty"
79
+
80
+ def test_dir_has_all_matching_metadata (self , fname , test_file_names , args_pytest ):
81
+ # Check that all files in baseline_dir have a file in test_dir with matching metadata
82
+ baseline_file_path = os .path .join (args_pytest ['baseline_dir' ], fname )
83
+ file_paths = [os .path .join (args_pytest ['test_dir' ], f ) for f in test_file_names ]
84
+ file_match = self .find_file_match (baseline_file_path , file_paths )
85
+ assert file_match is not None , f"Could not find a file in { args_pytest ['test_dir' ]} with matching metadata to { baseline_file_path } "
86
+
87
+ # For a baseline image file, finds the corresponding file name in test_dir and
88
+ # compares the images using the metrics in METRICS
89
+ @pytest .mark .parametrize ("metric" , METRICS .keys ())
90
+ def test_pipeline_compare (
91
+ self ,
92
+ args_pytest ,
93
+ fname ,
94
+ test_file_names ,
95
+ metric ,
96
+ ):
97
+ baseline_dir = args_pytest ['baseline_dir' ]
98
+ test_dir = args_pytest ['test_dir' ]
99
+ metrics_output_file = args_pytest ['metrics_file' ]
100
+ img_output_dir = args_pytest ['img_output_dir' ]
101
+
102
+ baseline_file_path = os .path .join (baseline_dir , fname )
103
+
104
+ # Find file match
105
+ file_paths = [os .path .join (test_dir , f ) for f in test_file_names ]
106
+ test_file = self .find_file_match (baseline_file_path , file_paths )
107
+
108
+ # Run metrics
109
+ sample_baseline = self .read_img (baseline_file_path )
110
+ sample_secondary = self .read_img (test_file )
111
+
112
+ score , metric_img = METRICS [metric ](sample_baseline , sample_secondary )
113
+ metric_status = score > METRICS_PASS_THRESHOLD [metric ]
114
+
115
+ # Save metric values
116
+ with open (metrics_output_file , 'a' ) as f :
117
+ run_info = os .path .splitext (fname )[0 ]
118
+ metric_status_str = "PASS ✅" if metric_status else "FAIL ❌"
119
+ date_str = datetime .datetime .now ().strftime ("%Y-%m-%d %H:%M:%S" )
120
+ f .write (f"| { date_str } | { run_info } | { metric } | { metric_status_str } | { score } | \n " )
121
+
122
+ # Save metric image
123
+ metric_img_dir = os .path .join (img_output_dir , metric )
124
+ os .makedirs (metric_img_dir , exist_ok = True )
125
+ output_filename = f'{ fname } '
126
+ Image .fromarray (metric_img ).save (os .path .join (metric_img_dir , output_filename ))
127
+
128
+ assert score > METRICS_PASS_THRESHOLD [metric ]
129
+
130
+ def read_img (self , filename : str ) -> np .ndarray :
131
+ cvImg = imread (filename )
132
+ cvImg = cvtColor (cvImg , COLOR_BGR2RGB )
133
+ return cvImg
134
+
135
+ def image_grid (self , img_list : list [list [Image .Image ]]):
136
+ # imgs is a 2D list of images
137
+ # Assumes the input images are a rectangular grid of equal sized images
138
+ rows = len (img_list )
139
+ cols = len (img_list [0 ])
140
+
141
+ w , h = img_list [0 ][0 ].size
142
+ grid = Image .new ('RGB' , size = (cols * w , rows * h ))
143
+
144
+ for i , row in enumerate (img_list ):
145
+ for j , img in enumerate (row ):
146
+ grid .paste (img , box = (j * w , i * h ))
147
+ return grid
148
+
149
+ def lookup_score_from_fname (self ,
150
+ fname : str ,
151
+ metrics_output_file : str
152
+ ) -> float :
153
+ fname_basestr = os .path .splitext (fname )[0 ]
154
+ with open (metrics_output_file , 'r' ) as f :
155
+ for line in f :
156
+ if fname_basestr in line :
157
+ score = float (line .split ('|' )[5 ])
158
+ return score
159
+ raise ValueError (f"Could not find score for { fname } in { metrics_output_file } " )
160
+
161
+ def gather_file_basenames (self , directory : str ):
162
+ files = []
163
+ for file in os .listdir (directory ):
164
+ if file .endswith (".png" ):
165
+ files .append (file )
166
+ return files
167
+
168
+ def read_file_prompt (self , fname :str ) -> str :
169
+ # Read prompt from image file metadata
170
+ img = Image .open (fname )
171
+ img .load ()
172
+ return img .info ['prompt' ]
173
+
174
+ def find_file_match (self , baseline_file : str , file_paths : List [str ]):
175
+ # Find a file in file_paths with matching metadata to baseline_file
176
+ baseline_prompt = self .read_file_prompt (baseline_file )
177
+
178
+ # Do not match empty prompts
179
+ if baseline_prompt is None or baseline_prompt == "" :
180
+ return None
181
+
182
+ # Find file match
183
+ # Reorder test_file_names so that the file with matching name is first
184
+ # This is an optimization because matching file names are more likely
185
+ # to have matching metadata if they were generated with the same script
186
+ basename = os .path .basename (baseline_file )
187
+ file_path_basenames = [os .path .basename (f ) for f in file_paths ]
188
+ if basename in file_path_basenames :
189
+ match_index = file_path_basenames .index (basename )
190
+ file_paths .insert (0 , file_paths .pop (match_index ))
191
+
192
+ for f in file_paths :
193
+ test_file_prompt = self .read_file_prompt (f )
194
+ if baseline_prompt == test_file_prompt :
195
+ return f
0 commit comments