main.py
import os
import pickle
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split
from argparse import ArgumentParser
from src.modules.trainer import OCRTrainer
from src.utils.utils import EarlyStopping, gmkdir
from src.models.crnn import CRNN
from src.options.opts import base_opts
from src.data.synth_dataset import SynthDataset, SynthCollator
from src.data.pickle_dataset import PickleDataset
from src.criterions.ctc import CustomCTCLoss
from src.utils.top_sampler import SamplingTop
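
# Learner wraps a model and an optimizer: it moves the model to the GPU when one
# is available (wrapping it in nn.DataParallel on multi-GPU machines), optionally
# resumes from a saved 'best.ckpt', and drives OCRTrainer epoch by epoch with
# early stopping on the validation loss.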
class Learner(object):
    def __init__(self, model, optimizer, savepath=None, resume=False):
        self.model = model
        self.optimizer = optimizer
        self.savepath = os.path.join(savepath, 'best.ckpt')
        self.cuda = torch.cuda.is_available()
        self.cuda_count = torch.cuda.device_count()
        if self.cuda:
            self.model = self.model.cuda()
        self.epoch = 0
        if self.cuda_count > 1:
            # Spread the model across all visible GPUs.
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model)
        self.best_score = None
        if resume and os.path.exists(self.savepath):
            self.checkpoint = torch.load(self.savepath)
            self.epoch = self.checkpoint['epoch']
            self.best_score = self.checkpoint['best']
            self.load()
        elif resume:
            print('checkpoint does not exist')
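
    # fit() copies the model, optimizer, and device flag onto the parsed-options
    # namespace, hands that namespace to OCRTrainer, then alternates a training
    # and a validation pass per epoch, appending the metrics to a CSV log and
    # checkpointing through EarlyStopping.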
    def fit(self, opt):
        opt.cuda = self.cuda
        opt.model = self.model
        opt.optimizer = self.optimizer
        logging.basicConfig(filename="%s/%s.csv" % (opt.log_dir, opt.name), level=logging.INFO)
        self.saver = EarlyStopping(self.savepath, patience=15, verbose=True, best_score=self.best_score)
        opt.epoch = self.epoch
        trainer = OCRTrainer(opt)
        for epoch in range(opt.epoch, opt.epochs):
            train_result = trainer.run_epoch()
            val_result = trainer.run_epoch(validation=True)
            trainer.count = epoch
            # CSV columns: epoch, train/val loss, train/val character accuracy (ca),
            # train/val word accuracy (wa).
            info = '%d, %.6f, %.6f, %.6f, %.6f, %.6f, %.6f' % (
                epoch, train_result['train_loss'], val_result['val_loss'],
                train_result['train_ca'], val_result['val_ca'],
                train_result['train_wa'], val_result['val_wa'])
            logging.info(info)
            self.val_loss = val_result['val_loss']
            print(self.val_loss)
            if self.savepath:
                self.save(epoch)
            if self.saver.early_stop:
                print("Early stopping")
                break
    def load(self):
        print('Loading checkpoint at {} trained for {} epochs'.format(self.savepath, self.checkpoint['epoch']))
        self.model.load_state_dict(self.checkpoint['state_dict'])
        if 'opt_state_dict' in self.checkpoint.keys():
            print('Loading optimizer')
            self.optimizer.load_state_dict(self.checkpoint['opt_state_dict'])

    def save(self, epoch):
        self.saver(self.val_loss, epoch, self.model, self.optimizer)
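
# Note: Learner depends only on a small surface of the EarlyStopping helper imported
# from src.utils.utils: it is constructed with the checkpoint path, a patience, and
# the best score so far; it is called as saver(val_loss, epoch, model, optimizer)
# after every epoch (see save()); and it exposes an early_stop flag that fit() checks
# to break out of training. The checkpoint dict it writes is expected to carry
# 'epoch', 'best', 'state_dict', and optionally 'opt_state_dict', matching what
# __init__ and load() read back when resume=True.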
if __name__ == '__main__':
    parser = ArgumentParser()
    base_opts(parser)
    args = parser.parse_args()
    data = PickleDataset(args)
    args.collate_fn = SynthCollator()
    # 80/20 train/validation split.
    train_split = int(0.8 * len(data))
    val_split = len(data) - train_split
    args.data_train, args.data_val = random_split(data, (train_split, val_split))
    print('Training Data Size:{}\nVal Data Size:{}'.format(
        len(args.data_train), len(args.data_val)))
    args.alphabet = """Only thewigsofrcvdampbkuq.$A-210xT5'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%"""
    args.nClasses = len(args.alphabet)
    model = CRNN(args)
    args.criterion = CustomCTCLoss()
    savepath = os.path.join(args.save_dir, args.name)
    gmkdir(savepath)
    gmkdir(args.log_dir)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    learner = Learner(model, optimizer, savepath=savepath, resume=args.resume)
    learner.fit(args)
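
# Example invocation (a sketch only: the actual flag names are defined in
# src.options.opts.base_opts, which is not shown here; the options below are
# assumptions inferred from the attributes this script reads, e.g. args.name,
# args.lr, args.epochs, args.save_dir, args.log_dir, and args.resume):
#
#   python main.py --name crnn_ocr --lr 1e-3 --epochs 100 \
#       --save_dir checkpoints --log_dir logs
#
# Passing a resume flag would continue from <save_dir>/<name>/best.ckpt if it exists.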