train_semi_supervised.py
import os
import pdb
import pickle
import logging
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from argparse import ArgumentParser
import itertools
from torch.utils.data import random_split
from src.modules.trainer import OCRTrainer
from src.utils.utils import EarlyStopping, gmkdir
from src.models.crnn import CRNN
from src.options.ss_opts import base_opts
from src.data.pickle_dataset import PickleDataset
from src.data.synth_dataset import SynthDataset, SynthCollator
from src.criterions.ctc import CustomCTCLoss
from src.utils.top_sampler import SamplingTop
from main import Learner
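

# LearnerSemi extends the Learner class from main.py with utilities for
# freezing/unfreezing groups of model layers while fine-tuning the CRNN on
# the combined source + pseudo-labelled target data assembled below.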
class LearnerSemi(Learner):

    def __init__(self, model, optimizer, savepath=None, resume=False):
        self.model = model
        self.optimizer = optimizer
        self.savepath = os.path.join(savepath, 'finetuned.ckpt')
        self.cuda = torch.cuda.is_available()
        self.cuda_count = torch.cuda.device_count()
        if self.cuda:
            self.model = self.model.cuda()
        self.epoch = 0
        if self.cuda_count > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.model = nn.DataParallel(self.model)
        self.best_score = None

    def freeze(self, index, boolean=False):
        layer = self.get_layer_groups()[index]
        for params in layer.parameters():
            params.requires_grad = boolean

    def freeze_all_but(self, index):
        n_layers = len(self.get_layer_groups())
        for i in range(n_layers):
            self.freeze(i)
        self.freeze(index, boolean=True)

    def unfreeze(self, index):
        self.freeze(index, boolean=True)

    def unfreeze_all(self):
        n_layers = len(self.get_layer_groups())
        for i in range(n_layers):
            self.unfreeze(i)
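
    # The helpers below flatten the model into a list of leaf modules that
    # hold parameters; freeze()/unfreeze() index into this list as "layer
    # groups".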
    def child(self, x):
        return list(x.children())

    def recursive_(self, child):
        if hasattr(child, 'children'):
            if len(self.child(child)) != 0:
                child = self.child(child)
                return self.recursive_(child)
        return child

    def get_layer_groups(self):
        children = []
        for child in self.child(self.model):
            children.extend(self.recursive_(child))
        children = [child for child in children if list(child.parameters())]
        return children


if __name__ == '__main__':
    parser = ArgumentParser()
    base_opts(parser)
    args = parser.parse_args()
    # Loading source data
    args.imgdir = 'English_consortium'
    args.source_data = SynthDataset(args)
    args.collate_fn = SynthCollator()
    # Loading target data and splitting it into train and val
    args.imgdir = 'English_unannotated'
    target_data = SynthDataset(args)
    train_split = int(0.8 * len(target_data))
    val_split = len(target_data) - train_split
    args.data_train, args.data_val = random_split(target_data, (train_split, val_split))
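    # Model, CTC criterion and optimiser setup; nClasses follows the size of
    # the recognition alphabet defined just below.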
args.alphabet = """Only thewigsofrcvdampbkuq.$A-210xT5'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%"""
args.nClasses = len(args.alphabet)
model = CRNN(args)
model = model.cuda()
args.criterion = CustomCTCLoss()
savepath = os.path.join(args.save_dir, args.name)
gmkdir(savepath)
gmkdir(args.log_dir)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# Loading specific model to get top samples
resume_file = savepath + '/' + 'best.ckpt'
print('Loading model %s'%resume_file)
checkpoint = torch.load(resume_file)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['opt_state_dict'])
    # Generating top samples
    args.model = model
    args.imgdir = 'target_top'
    finetunepath = os.path.join(args.path, args.imgdir)
    gmkdir(finetunepath)
    sampler = SamplingTop(args)
    sampler.get_samples(train_on_pred=args.train_on_pred,
                        combine_scoring=args.combine_scoring)
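    # SamplingTop is expected to write the selected pseudo-labelled target
    # samples into finetunepath ('target_top'); they are read back as a
    # SynthDataset below and removed again once fine-tuning finishes.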
    # Joining source data and top samples
    args.top_samples = SynthDataset(args)
    args.data_train = torch.utils.data.ConcatDataset([args.source_data, args.top_samples])
    print('Training Data Size:{}\nVal Data Size:{}'.format(
        len(args.data_train), len(args.data_val)))
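    # Fine-tune the loaded model on the combined data, then clean up the
    # temporary pseudo-label directory.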
    learner = LearnerSemi(args.model, optimizer, savepath=savepath, resume=args.resume)
    learner.fit(args)
    shutil.rmtree(finetunepath)
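
# Example invocation (illustrative only; the exact flag names and defaults
# are defined in src/options/ss_opts.py):
#   python train_semi_supervised.py --name ssl_exp --path data \
#       --save_dir checkpoints --log_dir logs --lr 1e-4 --train_on_pred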