diff --git a/gen_anchors.py b/gen_anchors.py index bb3b45a8a..e1c172b13 100644 --- a/gen_anchors.py +++ b/gen_anchors.py @@ -1,8 +1,9 @@ +import sys import random import argparse import numpy as np -from voc import parse_voc_annotation +from voc import parse_voc_annotation, parse_boxpoints_annotation import json def IOU(ann, centroids): @@ -86,26 +87,56 @@ def run_kmeans(ann_dims, anchor_num): def _main_(argv): config_path = args.conf - num_anchors = args.anchors + num_anchors = int(args.anchors) + xywh = not args.xy with open(config_path) as config_buffer: config = json.loads(config_buffer.read()) - train_imgs, train_labels = parse_voc_annotation( - config['train']['train_annot_folder'], - config['train']['train_image_folder'], - config['train']['cache_name'], - config['model']['labels'] - ) + if xywh: + print('xywh format parser') + train_imgs, train_labels = parse_voc_annotation( + config['train']['train_annot_folder'], + config['train']['train_image_folder'], + config['train']['cache_name'], + config['model']['labels'] + ) + else: + print('boxpoints format parser') + train_imgs, train_labels = parse_boxpoints_annotation( + config['train']['train_annot_folder'], + config['train']['train_image_folder'], + config['train']['cache_name'], + config['model']['labels'] + ) # run k_mean to find the anchors annotation_dims = [] + print(len(train_imgs)) + if len(train_imgs) < 1: + print('empty train_imgs') + sys.exit() for image in train_imgs: print(image['filename']) - for obj in image['object']: - relative_w = (float(obj['xmax']) - float(obj['xmin']))/image['width'] - relatice_h = (float(obj["ymax"]) - float(obj['ymin']))/image['height'] - annotation_dims.append(tuple(map(float, (relative_w,relatice_h)))) + if xywh: + for image in train_imgs: + print(image['filename']) + for obj in image['object']: + relative_w = (float(obj['xmax']) - float(obj['xmin']))/image['width'] + relative_h = (float(obj["ymax"]) - float(obj['ymin']))/image['height'] + 
annotation_dims.append(tuple(map(float, (relative_w,relative_h)))) + else: + for image in train_imgs: + print(image['filename']) + for obj in image['object']: + # numpy.linalg.norm(a-b) + relative_w = float(np.linalg.norm( + np.array((float(obj["x1"]),float(obj["y1"])))-np.array((float(obj["x2"]),float(obj["y2"]))) + ))/image['width'] + relative_h = float(np.linalg.norm( + np.array((float(obj["x1"]),float(obj["y1"])))-np.array((float(obj["x3"]),float(obj["y3"]))) + ))/image['height'] + annotation_dims.append(tuple(map(float, (relative_w,relative_h)))) annotation_dims = np.array(annotation_dims) centroids = run_kmeans(annotation_dims, num_anchors) @@ -127,6 +158,8 @@ def _main_(argv): '--anchors', default=9, help='number of anchors to use') - + argparser.add_argument( + '-xy', + action="store_true") args = argparser.parse_args() _main_(args) diff --git a/voc.py b/voc.py index f51e5fd4a..0d0748140 100644 --- a/voc.py +++ b/voc.py @@ -64,4 +64,81 @@ def parse_voc_annotation(ann_dir, img_dir, cache_name, labels=[]): with open(cache_name, 'wb') as handle: pickle.dump(cache, handle, protocol=pickle.HIGHEST_PROTOCOL) + return all_insts, seen_labels + +def parse_boxpoints_annotation(ann_dir, img_dir, cache_name, labels=[]): + if os.path.exists(cache_name): + with open(cache_name, 'rb') as handle: + cache = pickle.load(handle) + all_insts, seen_labels = cache['all_insts'], cache['seen_labels'] + else: + all_insts = [] + seen_labels = {} + + for ann in sorted(os.listdir(ann_dir)): + img = {'object':[]} + + try: + tree = ET.parse(ann_dir + ann) + except Exception as e: + print(e) + print('Ignore this bad annotation: ' + ann_dir + ann) + continue + + for elem in tree.iter(): + if 'filename' in elem.tag: + img['filename'] = img_dir + elem.text + if 'width' in elem.tag: + img['width'] = int(elem.text) + if 'height' in elem.tag: + img['height'] = int(elem.text) + if 'object' in elem.tag or 'part' in elem.tag: + obj = {} + + for attr in list(elem): + if 'name' in attr.tag: + 
obj['name'] = attr.text + + if obj['name'] in seen_labels: + seen_labels[obj['name']] += 1 + else: + seen_labels[obj['name']] = 1 + + if len(labels) > 0 and obj['name'] not in labels: + break + else: + img['object'] += [obj] + + if 'bndbox' in attr.tag: + for dim in list(attr): + if 'x0' in dim.tag: + obj['x0'] = int(round(float(dim.text))) + else: + if 'x4' in dim.tag: + obj['x4'] = int(round(float(dim.text))) + if 'y0' in dim.tag: + obj['y0'] = int(round(float(dim.text))) + else: + if 'y4' in dim.tag: + obj['y4'] = int(round(float(dim.text))) + if 'x1' in dim.tag: + obj['x1'] = int(round(float(dim.text))) + if 'y1' in dim.tag: + obj['y1'] = int(round(float(dim.text))) + if 'x2' in dim.tag: + obj['x2'] = int(round(float(dim.text))) + if 'y2' in dim.tag: + obj['y2'] = int(round(float(dim.text))) + if 'x3' in dim.tag: + obj['x3'] = int(round(float(dim.text))) + if 'y3' in dim.tag: + obj['y3'] = int(round(float(dim.text))) + + if len(img['object']) > 0: + all_insts += [img] + + cache = {'all_insts': all_insts, 'seen_labels': seen_labels} + with open(cache_name, 'wb') as handle: + pickle.dump(cache, handle, protocol=pickle.HIGHEST_PROTOCOL) + return all_insts, seen_labels