Skip to content

Commit d20aa4e

Browse files
author
yue kun
committed
rm some useless code
1 parent 8ca7e7e commit d20aa4e

File tree

2 files changed

+4
-12
lines changed

2 files changed

+4
-12
lines changed

OCR/MGP-STR/test_final.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):
239239
if out_pred == gt:
240240
out_n_correct += 1
241241

242-
# calculate confidence score (= multiply of pred_max_prob)
243-
try:
244-
confidence_score = char_preds_max_prob[index].cumprod(dim=0)[-1]
245-
except:
246-
confidence_score = 0 # for empty pred case, when prune after "end of sentence" token ([s])
247-
confidence_score_list.append(confidence_score)
242+
confidence_score_list.append(char_confidence_score)
248243

249244
elif opt.Transformer in ["char-str"]:
250245
attens, char_preds = model(image, is_eval=True) # final
@@ -393,8 +388,8 @@ def test(opt):
393388
_, accuracy_by_best_model, _, _, _, _, _, _ = validation(
394389
model, criterion, evaluation_loader, converter, opt)
395390
log.write(eval_data_log)
396-
print(f'{accuracy_by_best_model:0.3f}')
397-
log.write(f'{accuracy_by_best_model:0.3f}\n')
391+
print(f'{accuracy_by_best_model[0]:0.3f}')
392+
log.write(f'{accuracy_by_best_model[0]:0.3f}\n')
398393
log.close()
399394

400395
# https://github.com/clovaai/deep-text-recognition-benchmark/issues/125

OCR/MGP-STR/train_final_dist.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,7 @@ def train(opt):
7979

8080
if opt.saved_model != '':
8181
print(f'loading pretrained model from {opt.saved_model}')
82-
if opt.FT:
83-
model.load_state_dict(torch.load(opt.saved_model, map_location='cpu'), strict=True)
84-
else:
85-
model.load_state_dict(torch.load(opt.saved_model, map_location='cpu'), strict=True)
82+
model.load_state_dict(torch.load(opt.saved_model, map_location='cpu'), strict=True)
8683

8784
""" setup loss """
8885
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0

0 commit comments

Comments
 (0)