Skip to content

Commit 15b2a4e

Browse files
committed
modify script of check dataset
1 parent 8587c1f commit 15b2a4e

File tree

2 files changed

+30
-20
lines changed

2 files changed

+30
-20
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ I recommand you to check the information of your dataset with the script:
113113
```
114114
$ python tools/check_dataset_info.py --im_root /path/to/your/data_root --im_anns /path/to/your/anno_file
115115
```
116+
This will print some of the information of your dataset.
116117
Then you need to change the field of `im_root` and `train/val_im_anns` in the config file. I prepared a demo config file for you named [`bisenet_customer.py`](./configs/bisenet_customer.py). You can start from this conig file.
117118

118119

tools/check_dataset_info.py

+29-20
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
parse = argparse.ArgumentParser()
1212
parse.add_argument('--im_root', dest='im_root', type=str, default='./datasets/cityscapes',)
1313
parse.add_argument('--im_anns', dest='im_anns', type=str, default='./datasets/cityscapes/train.txt',)
14+
parse.add_argument('--lb_ignore', dest='lb_ignore', type=int, default=255)
1415
args = parse.parse_args()
1516

17+
lb_ignore = args.lb_ignore
18+
1619

1720
with open(args.im_anns, 'r') as fr:
1821
lines = fr.read().splitlines()
@@ -54,15 +57,20 @@
5457
if shape[1] < min_shape_width[1]:
5558
min_shape_width = shape
5659

57-
max_lb_val = max(max_lb_val, np.max(lb.ravel()))
58-
min_lb_val = min(min_lb_val, np.min(lb.ravel()))
59-
60+
lb = lb[lb != lb_ignore]
61+
if lb.size > 0:
62+
max_lb_val = max(max_lb_val, np.max(lb))
63+
min_lb_val = min(min_lb_val, np.min(lb))
6064

65+
min_lb_val = 0
66+
max_lb_val = 181
67+
lb_minlength = 182
6168
## label info
6269
lb_minlength = max_lb_val+1-min_lb_val
6370
lb_hist = np.zeros(lb_minlength)
64-
for impth in tqdm(impaths):
65-
lb = cv2.imread(lbpth, 0).ravel() + min_lb_val
71+
for lbpth in tqdm(lbpaths):
72+
lb = cv2.imread(lbpth, 0)
73+
lb = lb[lb != lb_ignore] + min_lb_val
6674
lb_hist += np.bincount(lb, minlength=lb_minlength)
6775

6876
lb_missing_vals = [ind + min_lb_val
@@ -75,38 +83,39 @@
7583
n_pixels = 0
7684
for impth in tqdm(impaths):
7785
im = cv2.imread(impth)[:, :, ::-1].astype(np.float32)
78-
im = im.reshape(-1, 3)
86+
im = im.reshape(-1, 3) / 255.
7987
n_pixels += im.shape[0]
8088
rgb_mean += im.sum(axis=0)
81-
rgb_mean = rgb_mean / n_pixels
89+
rgb_mean = (rgb_mean / n_pixels)
8290

8391
rgb_std = np.zeros(3).astype(np.float32)
8492
for impth in tqdm(impaths):
8593
im = cv2.imread(impth)[:, :, ::-1].astype(np.float32)
86-
im = im.reshape(-1, 3)
94+
im = im.reshape(-1, 3) / 255.
8795

8896
a = (im - rgb_mean.reshape(1, 3)) ** 2
8997
rgb_std += a.sum(axis=0)
90-
rgb_std = (rgb_std / n_pixels) ** (0.5)
98+
rgb_std = (rgb_std / n_pixels) ** 0.5
99+
100+
rgb_mean = rgb_mean.tolist()
101+
rgb_std = rgb_std.tolist()
91102

92103

104+
print('\n')
93105
print(f'there are {n_pairs} lines in {args.im_anns}, which means {n_pairs} image/label image pairs')
94106
print('\n')
95107

96-
print('max and min image shapes by area are: ')
97-
print(f'\t{max_shape_area}, {min_shape_area}')
98-
print('max and min image shapes by height are: ')
99-
print(f'\t{max_shape_height}, {min_shape_height}')
100-
print('max and min image shapes by width are: ')
101-
print(f'\t{max_shape_width}, {min_shape_width}')
108+
print(f'max and min image shapes by area are: {max_shape_area}, {min_shape_area}')
109+
print(f'max and min image shapes by height are: {max_shape_height}, {min_shape_height}')
110+
print(f'max and min image shapes by width are: {max_shape_width}, {min_shape_width}')
102111
print('\n')
103112

104-
print(f'label values are within range of ({min_lb_val}, {max_lb_val})')
105-
print('label values that are missing: ')
106-
print('\t', lb_missing_vals)
113+
print(f'we ignore label value of {args.lb_ignore} in label images')
114+
print(f'label values are within range of [{min_lb_val}, {max_lb_val}]')
115+
print(f'label values that are missing: {lb_missing_vals}')
107116
print('ratios of each label value: ')
108117
print('\t', lb_ratios)
109118
print('\n')
110119

111-
print('pixel mean rgb: ', mean)
112-
print('pixel std rgb: ', std)
120+
print('pixel mean rgb: ', rgb_mean)
121+
print('pixel std rgb: ', rgb_std)

0 commit comments

Comments
 (0)