1
+ from __future__ import print_function
2
+ import argparse
3
+ import torch
4
+ import torch .nn as nn
5
+ import torch .nn .functional as F
6
+ import torch .optim as optim
7
+ from torch .autograd import Variable
8
+ import models
9
+ import datasets
10
+ import data_transformations
11
+ from prettytable import PrettyTable
12
+ import datetime
13
+ import os
14
+ import time
15
+ import pdb
16
+
17
+ # sanity check for some arguments
18
+ model_names = sorted (name for name in models .__dict__
19
+ if name .islower () and not name .startswith ("__" )
20
+ and callable (models .__dict__ [name ]))
21
+
22
+ dataset_names = sorted (name for name in datasets .__dict__
23
+ if name .islower () and not name .startswith ("__" )
24
+ and callable (datasets .__dict__ [name ]))
25
+
26
+ transformations_names = sorted (name for name in data_transformations .__dict__
27
+ if name .islower () and not name .startswith ("__" )
28
+ and callable (data_transformations .__dict__ [name ]))
29
+
30
+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
31
+
32
+ criterion = nn .NLLLoss ().to (device )
33
+
34
+ current_time = str (datetime .datetime .now ().strftime ("%d-%m-%Y_%H-%M-%S" ))
35
+ file = open ("runs/run-" + current_time , "w" )
36
+
37
+ def make_loader (args ):
38
+ data_transforms = data_transformations .__dict__ [args .data_transforms ]
39
+ train_dataset = datasets .__dict__ [args .dataset ](is_train = True , data_transforms = data_transforms )
40
+ val_dataset = datasets .__dict__ [args .dataset ](is_train = False , data_transforms = data_transforms )
41
+ train_loader = torch .utils .data .DataLoader (train_dataset , batch_size = args .batch_size , shuffle = True , num_workers = 1 )
42
+ val_loader = torch .utils .data .DataLoader (val_dataset , batch_size = args .batch_size , shuffle = False , num_workers = 1 )
43
+ return train_loader , val_loader
44
+
45
+ def train (model , epoch , train_loader , optimizer ):
46
+ model .train ()
47
+ training_loss = 0
48
+ for batch_idx , (data , target ) in enumerate (train_loader ):
49
+ data , target = Variable (data .to (device )), Variable (target .to (device ))
50
+ optimizer .zero_grad ()
51
+ output = model (data )
52
+ loss = criterion (output , target )
53
+ loss .backward ()
54
+ optimizer .step ()
55
+ training_loss += loss .data .item ()
56
+ if batch_idx == 10 :
57
+ break
58
+ training_loss /= len (train_loader .dataset )
59
+ return training_loss
60
+
61
+ def validation (model , val_loader ):
62
+ model .eval ()
63
+ validation_loss = 0
64
+ correct = 0
65
+ for batch_idx , (data , target ) in enumerate (val_loader ):
66
+ data , target = Variable (data .to (device ), volatile = True ), Variable (target .to (device ))
67
+ output = model (data )
68
+ validation_loss += criterion (output , target ).data .item () # sum up batch loss
69
+ pred = output .data .max (1 , keepdim = True )[1 ] # get the index of the max log-probability
70
+ correct += pred .eq (target .data .view_as (pred )).cpu ().sum ()
71
+ if batch_idx == 10 :
72
+ break
73
+
74
+ validation_loss /= len (val_loader .dataset )
75
+ return validation_loss , correct .item (), len (val_loader .dataset )
76
+
77
+
78
+ def main (args ):
79
+ torch .manual_seed (args .seed )
80
+ nclasses = datasets .__dict__ [args .dataset ].nclasses
81
+ model = models .__dict__ [args .arch ](nclasses = nclasses )
82
+ # model = torch.nn.DataParallel(model).to(device)
83
+ model .to (device )
84
+ optimizer = optim .SGD (model .parameters (), lr = args .lr , momentum = args .momentum )
85
+ train_loader , val_loader = make_loader (args )
86
+ report = PrettyTable (['Epoch No #' , 'Training loss' , 'Validation loss' , 'Accuracy' , 'Correct' , 'Total' , 'Time in secs' ])
87
+ for epoch in range (1 , args .epochs + 1 ):
88
+ print ("processing epoch {} ..." .format (epoch ))
89
+ start_time = time .time ()
90
+ training_loss = train (model , epoch , train_loader , optimizer )
91
+ validation_loss , correct , total = validation (model , val_loader )
92
+ end_time = time .time ()
93
+ report .add_row ([epoch , round (training_loss , 4 ), round (validation_loss , 4 ), "{}%" .format (round (correct / total , 3 )), correct , total , round (end_time - start_time , 2 )])
94
+ if args .save_model == 'y' :
95
+ val_folder = "saved_model/" + current_time
96
+ if not os .path .isdir (val_folder ):
97
+ os .mkdir (val_folder )
98
+ save_model_file = val_folder + '/model_' + str (epoch ) + '.pth'
99
+ torch .save (model .state_dict (), save_model_file )
100
+ # print('\nSaved model to ' + model_file + '. You can run `python evaluate.py --model' + model_file + '` to generate the Kaggle formatted csv file')
101
+ file .write (report .get_string ())
102
+
103
+
104
+ if __name__ == '__main__' :
105
+ parser = argparse .ArgumentParser (description = 'PyTorch GTSRB example' )
106
+ parser .add_argument ('--batch-size' , type = int , default = 64 , metavar = 'N' ,
107
+ help = 'input batch size for training (default: 64)' )
108
+ parser .add_argument ('--epochs' , type = int , default = 3 , metavar = 'N' ,
109
+ help = 'number of epochs to train (default: 10)' )
110
+ parser .add_argument ('--lr' , type = float , default = 0.01 , metavar = 'LR' ,
111
+ help = 'learning rate (default: 0.01)' )
112
+ parser .add_argument ('--momentum' , type = float , default = 0.5 , metavar = 'M' ,
113
+ help = 'SGD momentum (default: 0.5)' )
114
+ parser .add_argument ('--seed' , type = int , default = 1 , metavar = 'S' ,
115
+ help = 'random seed (default: 1)' )
116
+
117
+ parser .add_argument ('--log_interval' , type = int , default = 10 , metavar = 'N' ,
118
+ help = 'how many batches to wait before logging training status' )
119
+
120
+ parser .add_argument ('--save_model' , type = str , default = 'n' , metavar = 'D' ,
121
+ help = "Do you want to save models for this run or not. (y) for saving the model" )
122
+
123
+ # Model structure
124
+ parser .add_argument ('--arch' , '-a' , metavar = 'ARCH' , default = 'conv_net' ,
125
+ choices = model_names ,
126
+ help = 'model architecture: ' +
127
+ ' | ' .join (model_names ) +
128
+ ' (default: conv_net)' )
129
+ # Dataset setting
130
+ parser .add_argument ('--dataset' , metavar = 'DATASET' , default = 'ssl_data' ,
131
+ choices = dataset_names ,
132
+ help = 'Datasets: ' +
133
+ ' | ' .join (dataset_names ) +
134
+ ' (default: ssl_data)' )
135
+ # Data Transformation setting
136
+ parser .add_argument ('--data_transforms' , metavar = 'DATA_TRANFORMS' , default = 'tensor_transform' ,
137
+ choices = transformations_names ,
138
+ help = 'Datasets: ' +
139
+ ' | ' .join (transformations_names ) +
140
+ ' (default: tensor_transform)' )
141
+ # Printing Information
142
+ args = parser .parse_args ()
143
+
144
+ options = PrettyTable (['option' , 'Value' ])
145
+ for key , val in vars (args ).items ():
146
+ options .add_row ([key , val ])
147
+ options .add_row (["save-model-folder" , current_time ])
148
+ file .write (options .get_string ())
149
+ file .write ("\n " )
150
+
151
+ # creating folders
152
+ if not os .path .isdir ("runs" ):
153
+ os .mkdir ("runs" )
154
+
155
+ if not os .path .isdir ("saved_model" ):
156
+ os .mkdir ("saved_model" )
157
+
158
+ main (parser .parse_args ())
159
+ file .write ("\n " )
160
+ file .close ()
0 commit comments