@@ -29,12 +29,12 @@ def coco_into_labels(annot_path, label_path):
2929 #file_names.append(file_name)
3030
3131 # using this for your dataset
32- # center_file = '{}/gt_center/'.format(label_path) + file_name +'.txt'
33- # logi_file = '{}/gt_logi/'.format(label_path) + file_name +'.txt'
32+ center_file = '{}/gt_center/' .format (label_path ) + file_name + '.txt'
33+ logi_file = '{}/gt_logi/' .format (label_path ) + file_name + '.txt'
3434
3535 #TODO: revise the file names in the annotation of PubTabNet
36- center_file = gt_center_dir + file_name .replace ('.jpg' , '.png' ) + '.txt'
37- logi_file = gt_logi_dir + file_name .replace ('.jpg' , '.png' ) + '.txt'
36+ # center_file = gt_center_dir + file_name.replace('.jpg', '.png') +'.txt'
37+ # logi_file = gt_logi_dir + file_name.replace('.jpg', '.png') +'.txt'
3838
3939
4040 ann_ids = coco_data .getAnnIds (imgIds = [img_id ])
@@ -67,7 +67,9 @@ def matching(self):
6767 for tunit in self .gt_list :
6868 if_find = 0
6969 for sunit in self .pred_list :
70- if self .compute_IOU (tunit .bbox , sunit .bbox ) >= 0.1 :
70+ #TODO: Adding Parameters for IOU threshold
71+ #Using IOU=0.5 as Default
72+ if self .compute_IOU (tunit .bbox , sunit .bbox ) >= 0.5 :
7173 self .match_list .append (sunit )
7274 if_find = 1
7375 break
@@ -84,12 +86,12 @@ def evalBbox(self, eval_type):
8486 at = len (self .gt_list )
8587 if eval_type == 'recall' :
8688 if at == 0 :
87- return 1
89+ return 'null'
8890 else :
8991 return tp / at
9092 elif eval_type == 'precision' :
9193 if ap == 0 :
92- return 0
94+ return 'null'
9395 else :
9496 return tp / ap
9597
@@ -145,54 +147,19 @@ def evalAxis(self):
145147 #return 0
146148 return 'null'
147149 else :
148- return truep / tp #len(self.gt_list)
149-
150- # def evalAxis(self):
151-
152- # tp = 0
153- # for u in self.match_list:
154- # if u != 'empty':
155- # tp = tp + 1.0
156-
157- # truep = 0
158- # for i in range(len(self.gt_list)):
159- # sunit = self.match_list[i]
160- # if sunit != 'empty':
161- # tunit = self.gt_list[i]
162-
163- # saxis = sunit.axis
164- # taxis= tunit.axis
165-
166- # flag = 1
167- # for j in range(4):
168- # if saxis[j] != taxis[j]:
169- # flag = 0
170- # break
171- # if flag == 1:
172- # truep = truep + 1.0
173- # if len(self.gt_list) == 0:
174- # return 0
175- # else:
176- # if tp == 0:
177- # return 0
178- # else:
179- # return truep/tp #len(self.gt_list)
150+ return truep / tp
180151
181-
182-
183152class Table ():
184153 def __init__ (self , bbox_path , axis_path , file_name ):
185- self .bbox_dir = os .path .join (bbox_path , file_name ) #'./det_bcb_tsfm_4ps/center/' + file_name
186- self .axis_dir = os .path .join (axis_path , file_name ) #'./det_bcb_tsfm_4ps/logi/' + file_name
154+ self .bbox_dir = os .path .join (bbox_path , file_name )
155+ self .axis_dir = os .path .join (axis_path , file_name )
187156
188157 self .ulist = []
189158 self .load_tabu (self .bbox_dir , self .axis_dir )
190159 self .ulist = self .bubble_sort (self .ulist )
191160
192161 def load_tabu (self , bbox_dir , axis_dir ):
193162
194- #f_b = open(self.bbox_dir.replace('jpg', 'png'))
195- #f_a = open(self.axis_dir.replace('jpg', 'png'))
196163 f_b = open (self .bbox_dir )
197164 f_a = open (self .axis_dir )
198165 bboxs = f_b .readlines ()
0 commit comments