Skip to content

Commit 8ee83c2

Browse files
committed
fix a bug on postprocessing
1 parent 2f731c6 commit 8ee83c2

File tree

4 files changed

+5
-7
lines changed

4 files changed

+5
-7
lines changed

Diff for: base/base_trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(self, config, model, criterion):
5353
self.device = torch.device("cpu")
5454
self.logger_info('train with device {} and pytorch {}'.format(self.device, torch.__version__))
5555
# metrics
56-
self.metrics = {'recall': 0, 'precision': 0, 'hmean': 0, 'train_loss': float('inf'), 'best_model': ''}
56+
self.metrics = {'recall': 0, 'precision': 0, 'hmean': 0, 'train_loss': float('inf')}
5757

5858
self.optimizer = self._initialize('optimizer', torch.optim, model.parameters())
5959

Diff for: eval.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py --model_path '' --img_folder '' --gt_folder '' --save_folder ''
1+
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py --model_path ''

Diff for: post_processing/seg_detector_representer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
6666
# _, sside = self.get_mini_boxes(contour)
6767
# if sside < self.min_size:
6868
# continue
69-
score = self.box_score_fast(pred, points.reshape(-1, 2))
69+
score = self.box_score_fast(pred, contour)
7070
if self.box_thresh > score:
7171
continue
7272

@@ -112,7 +112,7 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
112112
if sside < self.min_size:
113113
continue
114114
points = np.array(points)
115-
score = self.box_score_fast(pred, points.reshape(-1, 2))
115+
score = self.box_score_fast(pred, contour)
116116
if self.box_thresh > score:
117117
continue
118118

Diff for: trainer/trainer.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ def _on_epoch_finish(self):
149149
self.epoch_result['lr']))
150150
net_save_path = '{}/model_latest.pth'.format(self.checkpoint_dir)
151151

152-
save_best = False
153152
if self.config['local_rank'] == 0:
153+
save_best = False
154154
if self.config['trainer']['metrics'] == 'hmean': # 使用f1作为最优模型指标
155155
recall, precision, hmean = self._eval(self.epoch_result['epoch'])
156156

@@ -166,12 +166,10 @@ def _on_epoch_finish(self):
166166
self.metrics['hmean'] = hmean
167167
self.metrics['precision'] = precision
168168
self.metrics['recall'] = recall
169-
self.metrics['best_model'] = net_save_path
170169
else:
171170
if self.epoch_result['train_loss'] < self.metrics['train_loss']:
172171
save_best = True
173172
self.metrics['train_loss'] = self.epoch_result['train_loss']
174-
self.metrics['best_model'] = net_save_path
175173
self._save_checkpoint(self.epoch_result['epoch'], net_save_path, save_best)
176174

177175
def _on_train_finish(self):

0 commit comments

Comments
 (0)