@@ -103,14 +103,21 @@ def test_model_fn(args, data, model, save_path, device):
103103 # pad image such that the resolution is a multiple of 32
104104 w_pad = (math .ceil (w / 32 )* 32 - w ) // 2
105105 h_pad = (math .ceil (h / 32 )* 32 - h ) // 2
106- in_img = img_pad (in_img , w_r = w_pad , h_r = h_pad )
106+ w_odd_pad = w_pad
107+ h_odd_pad = h_pad
108+ if w % 2 == 1 :
109+ w_odd_pad += 1
110+ if h % 2 == 1 :
111+ h_odd_pad += 1
112+
113+ in_img = img_pad (in_img , w_pad = w_pad , h_pad = h_pad , w_odd_pad = w_odd_pad , h_odd_pad = h_odd_pad )
107114
108115 with torch .no_grad ():
109116 out_1 , out_2 , out_3 = model (in_img )
110117 if h_pad != 0 :
111- out_1 = out_1 [:, :, h_pad :- h_pad , :]
118+ out_1 = out_1 [:, :, h_pad :- h_odd_pad , :]
112119 if w_pad != 0 :
113- out_1 = out_1 [:, :, :, w_pad :- w_pad ]
120+ out_1 = out_1 [:, :, :, w_pad :- w_odd_pad ]
114121
115122 # save images
116123 if args .SAVE_IMG :
@@ -126,7 +133,7 @@ def _list_image_files_recursively(data_dir):
126133 for home , dirs , files in os .walk (data_dir ):
127134 for filename in files :
128135 ext = filename .split ("." )[- 1 ]
129- if ext .lower () in ["jpg" , "jpeg" , "png" , "gif" , "webp" ] and filename [ - 5 ] == 'e' :
136+ if ext .lower () in ["jpg" , "jpeg" , "png" , "gif" , "webp" ]:
130137 file_list .append (os .path .join (home , filename ))
131138 file_list .sort ()
132139 return file_list
0 commit comments