Skip to content

Commit

Permalink
use standard nms as last step
Browse files Browse the repository at this point in the history
  • Loading branch information
zxytim committed Aug 2, 2017
1 parent e71a8e5 commit b82c728
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 80 deletions.
25 changes: 22 additions & 3 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
tf.app.flags.DEFINE_string('test_data_path', '/tmp/ch4_test_images/images/', '')
tf.app.flags.DEFINE_string('gpu_list', '0', '')
tf.app.flags.DEFINE_string('checkpoint_path', '/tmp/east_icdar2015_resnet_v1_50_rbox/', '')
tf.app.flags.DEFINE_string('output_path', '/tmp/ch4_test_images/images/', '')
tf.app.flags.DEFINE_string('output_dir', '/tmp/ch4_test_images/images/', '')
tf.app.flags.DEFINE_bool('no_write_images', False, 'do not write images')

import model
from icdar import restore_rectangle
Expand Down Expand Up @@ -123,6 +124,13 @@ def main(argv=None):
import os
os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list


try:
os.makedirs(FLAGS.output_dir)
except OSError as e:
if e.errno != 17:
raise

with tf.get_default_graph().as_default():
input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
Expand All @@ -141,6 +149,7 @@ def main(argv=None):
im_fn_list = get_images()
for im_fn in im_fn_list:
im = cv2.imread(im_fn)[:, :, ::-1]
start_time = time.time()
im_resized, (ratio_h, ratio_w) = resize_image(im)

timer = {'net': 0, 'restore': 0, 'nms': 0}
Expand All @@ -157,9 +166,17 @@ def main(argv=None):
boxes[:, :, 0] /= ratio_w
boxes[:, :, 1] /= ratio_h

duration = time.time() - start_time
print('[timing] {}'.format(duration))

# save to file
if boxes is not None:
with open(FLAGS.output_path + 'res_{}.txt'.format(os.path.basename(im_fn).split('.')[0]), 'w') as f:
res_file = os.path.join(
FLAGS.output_dir,
'{}.txt'.format(
os.path.basename(im_fn).split('.')[0]))

with open(res_file, 'w') as f:
for box in boxes:
# to avoid submitting errors
box = sort_poly(box.astype(np.int32))
Expand All @@ -169,7 +186,9 @@ def main(argv=None):
box[0, 0], box[0, 1], box[1, 0], box[1, 1], box[2, 0], box[2, 1], box[3, 0], box[3, 1],
))
cv2.polylines(im[:, :, ::-1], [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=1)
cv2.imwrite(os.path.join(FLAGS.output_path, os.path.basename(im_fn)), im[:, :, ::-1])
if not FLAGS.no_write_images:
img_path = os.path.join(FLAGS.output_dir, os.path.basename(im_fn))
cv2.imwrite(img_path, im[:, :, ::-1])

if __name__ == '__main__':
tf.app.run()
2 changes: 1 addition & 1 deletion lanms/Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags)
LDFLAGS = $(shell python3-config --ldflags)

DEPS = $(shell find include -xtype f)
DEPS = lanms.h $(shell find include -xtype f)
CXX_SOURCES = adaptor.cpp include/clipper/clipper.cpp

LIB_SO = adaptor.so
Expand Down
3 changes: 1 addition & 2 deletions lanms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
raise RuntimeError('Cannot compile lanms: {}'.format(BASE_DIR))


def merge_quadrangle_n9(polys, thres=0.3, precision=1000):
def merge_quadrangle_n9(polys, thres=0.3, precision=10000):
from .adaptor import merge_quadrangle_n9 as nms_impl
if len(polys) == 0:
return np.array([], dtype='float32')
Expand All @@ -18,4 +18,3 @@ def merge_quadrangle_n9(polys, thres=0.3, precision=1000):
ret[:,:8] /= precision
return ret


107 changes: 33 additions & 74 deletions lanms/lanms.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,25 @@ namespace lanms {
return poly_iou(a, b) > iou_threshold;
}

/**
* Incrementally merge polygons
*/
class PolyMerger {
public:
PolyMerger(): score(0), nr_polys(0) {
memset(data, 0, sizeof(data));
}

/**
* Add a new polygon to be merged.
*/
void add(const Polygon &p_given) {
Polygon p;
if (nr_polys > 0) {
// the vertices of the two polygons being merged may not be in the same order;
// we match their vertices by choosing the ordering that
// minimizes the total squared distance.
// see function normalize_poly for details.
p = normalize_poly(get(), p_given);
} else {
p = p_given;
Expand Down Expand Up @@ -156,86 +166,35 @@ namespace lanms {
};


/**
 * Disjoint-set (union-find) structure over the indices [0, size()).
 *
 * Every index starts as the root of its own singleton set. Sets are joined
 * with merge() and queried with test() / get_root(). Lookups compress the
 * visited path so that later lookups on the same elements are short.
 */
class DisjointSet {
public:
    /**
     * Create `size` singleton sets: element i is initially its own parent.
     */
    DisjointSet(size_t size): m_parent(size) {
        // Seed with a size_t (not an int literal) so the generated sequence
        // has the same type as the stored indices.
        std::iota(std::begin(m_parent), std::end(m_parent), size_t{0});
    }

    /**
     * Return true when `a` and `b` currently belong to the same set.
     * Both indices must be smaller than size().
     */
    bool test(size_t a, size_t b) {
        assert(a < size() && b < size());
        return get_root(a) == get_root(b);
    }

    /**
     * Join the set containing `a` into the set containing `b`; the root of
     * `b`'s set becomes the root of the union.
     * Both indices must be smaller than size().
     */
    void merge(size_t a, size_t b) {
        assert(a < size() && b < size());
        const size_t root_a = get_root(a);
        const size_t root_b = get_root(b);
        m_parent[root_a] = root_b;
    }

    /**
     * Return the representative (root) of the set containing `x` and point
     * every element visited on the way directly at that root.
     *
     * Implemented iteratively: merge() links roots without any rank/size
     * heuristic, so parent chains can become as long as the number of
     * elements, and a recursive walk could exhaust the call stack.
     */
    size_t get_root(size_t x) {
        assert(x < size());

        // Pass 1: follow parent links up to the root.
        size_t root = x;
        while (root != m_parent[root]) {
            root = m_parent[root];
        }

        // Pass 2: path compression - re-point every visited node at the root.
        while (x != root) {
            const size_t next = m_parent[x];
            m_parent[x] = root;
            x = next;
        }
        return root;
    }

    /**
     * Return the current partition as a list of groups.
     *
     * Groups are ordered by ascending root index, and the members of each
     * group are listed in ascending index order.
     */
    std::vector<std::vector<size_t>> get_groups() {
        std::vector<std::pair<size_t, size_t>> root2id;
        root2id.reserve(size());
        for (size_t i = 0; i < size(); i ++) {
            root2id.emplace_back(get_root(i), i);
        }
        // Sorting by (root, id) makes members of the same set contiguous.
        std::sort(std::begin(root2id), std::end(root2id));

        std::vector<std::vector<size_t>> groups;
        size_t last_root = 0;
        for (const auto &p: root2id) {
            // Open a new group whenever the root changes (or for the first entry).
            if (groups.empty() || last_root != p.first) {
                groups.emplace_back();
            }
            groups.back().emplace_back(p.second);
            last_root = p.first;
        }
        return groups;
    }

    /** Number of elements managed by this structure. */
    inline size_t size() const {
        return m_parent.size();
    }


private:
    // m_parent[i] is the parent of element i; a root is its own parent.
    std::vector<size_t> m_parent;
};


std::vector<Polygon> naive_merge(std::vector<Polygon> &polys, float iou_threshold) {
std::sort(std::begin(polys), std::end(polys),
[](const Polygon &a, const Polygon &b) {
return a.score > b.score;
});
auto n = polys.size();
DisjointSet ds(n);
for (size_t i = 0; i < n; i ++) {
for (size_t j = i + 1; j < n; j ++) {
if (ds.test(i, j))
continue;
if (should_merge(polys[i], polys[j], iou_threshold)) {
ds.merge(i, j);
/**
* The standard NMS algorithm.
*/
std::vector<Polygon> standard_nms(std::vector<Polygon> &polys, float iou_threshold) {
size_t n = polys.size();
if (n == 0)
return {};
std::vector<size_t> indices(n);
std::iota(std::begin(indices), std::end(indices), 0);
std::sort(std::begin(indices), std::end(indices), [&](size_t i, size_t j) { return polys[i].score > polys[j].score; });

std::vector<size_t> keep;
while (indices.size()) {
size_t p = 0, cur = indices[0];
keep.emplace_back(cur);
for (size_t i = 1; i < indices.size(); i ++) {
if (!should_merge(polys[cur], polys[indices[i]], iou_threshold)) {
indices[p ++] = indices[i];
}
}
indices.resize(p);
}

auto groups = ds.get_groups();
std::vector<Polygon> ret;
for (auto &&g: groups) {
PolyMerger merger;
for (auto &&i: g) {
merger.add(polys[i]);
}
ret.emplace_back(merger.get());
for (auto &&i: keep) {
ret.emplace_back(polys[i]);
}

return ret;
}
}

std::vector<Polygon>
merge_quadrangle_n9(const float *data, size_t n, float iou_threshold) {
Expand Down Expand Up @@ -270,6 +229,6 @@ namespace lanms {
polys.emplace_back(poly);
}
}
return naive_merge(polys, iou_threshold);
return standard_nms(polys, iou_threshold);
}
}

0 comments on commit b82c728

Please sign in to comment.