Skip to content

Commit 0aa2a5c

Browse files
author
yue kun
committed
refine evaluation scripts
1 parent fcfb558 commit 0aa2a5c

File tree

2 files changed

+28
-46
lines changed

2 files changed

+28
-46
lines changed

DocumentUnderstanding/LORE-TSR/src/eval.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,25 @@
3333
table_dict.append({'file_name': file_name, 'pred_table': pred_table, 'gt_table': gt_table})
3434

3535
acs = []
36+
bbox_recalls = []
37+
bbox_precisions = []
3638
for i in tqdm(range(len(table_dict))):
3739
pair = pairTab(table_dict[i]['pred_table'], table_dict[i]['gt_table'])
40+
#Acc of Logical Locations
3841
ac = pair.evalAxis()
3942
if ac != 'null':
4043
acs.append(ac)
4144

42-
print(np.array(acs).mean())
45+
#Recall of Cell Detection
46+
# recall = pair.evalBbox('recall')
47+
# bbox_recalls.append(recall)
48+
49+
# #Precision of Cell Detection
50+
# precision = pair.evalBbox('precision')
51+
# bbox_precisions.append(precision)
52+
53+
# det_precision = np.array(bbox_precisions).mean()
54+
# det_recall = np.array(bbox_recalls).mean()
55+
# f = 2 * det_precision * det_recall / (det_precision + det_recall)
56+
57+
print('Evaluation Results | Accuracy of Logical Location: {:.2f}.'.format(np.array(acs).mean()))

DocumentUnderstanding/LORE-TSR/src/lib/utils/eval_utils.py

Lines changed: 12 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ def coco_into_labels(annot_path, label_path):
2929
#file_names.append(file_name)
3030

3131
# using this for your dataset
32-
# center_file = '{}/gt_center/'.format(label_path) + file_name +'.txt'
33-
# logi_file = '{}/gt_logi/'.format(label_path) + file_name +'.txt'
32+
center_file = '{}/gt_center/'.format(label_path) + file_name +'.txt'
33+
logi_file = '{}/gt_logi/'.format(label_path) + file_name +'.txt'
3434

3535
#TODO: revise the file names in the annotation of PubTabNet
36-
center_file = gt_center_dir + file_name.replace('.jpg', '.png') +'.txt'
37-
logi_file = gt_logi_dir + file_name.replace('.jpg', '.png') +'.txt'
36+
# center_file = gt_center_dir + file_name.replace('.jpg', '.png') +'.txt'
37+
# logi_file = gt_logi_dir + file_name.replace('.jpg', '.png') +'.txt'
3838

3939

4040
ann_ids = coco_data.getAnnIds(imgIds=[img_id])
@@ -67,7 +67,9 @@ def matching(self):
6767
for tunit in self.gt_list:
6868
if_find = 0
6969
for sunit in self.pred_list:
70-
if self.compute_IOU(tunit.bbox, sunit.bbox) >= 0.1:
70+
#TODO: Adding Parameters for IOU threshold
71+
#Using IOU=0.5 as Default
72+
if self.compute_IOU(tunit.bbox, sunit.bbox) >= 0.5:
7173
self.match_list.append(sunit)
7274
if_find = 1
7375
break
@@ -84,12 +86,12 @@ def evalBbox(self, eval_type):
8486
at = len(self.gt_list)
8587
if eval_type == 'recall':
8688
if at == 0:
87-
return 1
89+
return 'null'
8890
else:
8991
return tp/at
9092
elif eval_type == 'precision':
9193
if ap == 0:
92-
return 0
94+
return 'null'
9395
else:
9496
return tp/ap
9597

@@ -145,54 +147,19 @@ def evalAxis(self):
145147
#return 0
146148
return 'null'
147149
else:
148-
return truep/tp #len(self.gt_list)
149-
150-
# def evalAxis(self):
151-
152-
# tp = 0
153-
# for u in self.match_list:
154-
# if u != 'empty':
155-
# tp = tp + 1.0
156-
157-
# truep = 0
158-
# for i in range(len(self.gt_list)):
159-
# sunit = self.match_list[i]
160-
# if sunit != 'empty':
161-
# tunit = self.gt_list[i]
162-
163-
# saxis = sunit.axis
164-
# taxis= tunit.axis
165-
166-
# flag = 1
167-
# for j in range(4):
168-
# if saxis[j] != taxis[j]:
169-
# flag = 0
170-
# break
171-
# if flag == 1:
172-
# truep = truep + 1.0
173-
# if len(self.gt_list) == 0:
174-
# return 0
175-
# else:
176-
# if tp == 0:
177-
# return 0
178-
# else:
179-
# return truep/tp #len(self.gt_list)
150+
return truep/tp
180151

181-
182-
183152
class Table():
184153
def __init__(self, bbox_path, axis_path, file_name):
185-
self.bbox_dir = os.path.join(bbox_path, file_name) #'./det_bcb_tsfm_4ps/center/' + file_name
186-
self.axis_dir = os.path.join(axis_path, file_name) #'./det_bcb_tsfm_4ps/logi/' + file_name
154+
self.bbox_dir = os.path.join(bbox_path, file_name)
155+
self.axis_dir = os.path.join(axis_path, file_name)
187156

188157
self.ulist = []
189158
self.load_tabu(self.bbox_dir, self.axis_dir)
190159
self.ulist = self.bubble_sort(self.ulist)
191160

192161
def load_tabu(self, bbox_dir, axis_dir):
193162

194-
#f_b = open(self.bbox_dir.replace('jpg', 'png'))
195-
#f_a = open(self.axis_dir.replace('jpg', 'png'))
196163
f_b = open(self.bbox_dir)
197164
f_a = open(self.axis_dir)
198165
bboxs = f_b.readlines()

0 commit comments

Comments
 (0)