Skip to content

Commit 933241a

Browse files
committed
fix padding
1 parent cf75f22 commit 933241a

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

demo_test.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,21 @@ def test_model_fn(args, data, model, save_path, device):
103103
# pad image such that the resolution is a multiple of 32
104104
w_pad = (math.ceil(w/32)*32 - w) // 2
105105
h_pad = (math.ceil(h/32)*32 - h) // 2
106-
in_img = img_pad(in_img, w_r=w_pad, h_r=h_pad)
106+
w_odd_pad = w_pad
107+
h_odd_pad = h_pad
108+
if w % 2 == 1:
109+
w_odd_pad += 1
110+
if h % 2 == 1:
111+
h_odd_pad += 1
112+
113+
in_img = img_pad(in_img, w_pad=w_pad, h_pad=h_pad, w_odd_pad=w_odd_pad, h_odd_pad=h_odd_pad)
107114

108115
with torch.no_grad():
109116
out_1, out_2, out_3 = model(in_img)
110117
if h_pad != 0:
111-
out_1 = out_1[:, :, h_pad:-h_pad, :]
118+
out_1 = out_1[:, :, h_pad:-h_odd_pad, :]
112119
if w_pad != 0:
113-
out_1 = out_1[:, :, :, w_pad:-w_pad]
120+
out_1 = out_1[:, :, :, w_pad:-w_odd_pad]
114121

115122
# save images
116123
if args.SAVE_IMG:
@@ -126,7 +133,7 @@ def _list_image_files_recursively(data_dir):
126133
for home, dirs, files in os.walk(data_dir):
127134
for filename in files:
128135
ext = filename.split(".")[-1]
129-
if ext.lower() in ["jpg", "jpeg", "png", "gif", "webp"] and filename[-5]=='e':
136+
if ext.lower() in ["jpg", "jpeg", "png", "gif", "webp"]:
130137
file_list.append(os.path.join(home, filename))
131138
file_list.sort()
132139
return file_list

0 commit comments

Comments
 (0)