Skip to content

Commit 33ab644

Browse files
committed
init repository
0 parents  commit 33ab644

7 files changed

+219
-0
lines changed

.gitignore

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
*~
2+
data
3+
runs
4+
saved_model
5+
*.pyc
6+
*.pth
7+
*.csv
8+
.DS_Store

data_transformations.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import torchvision.transforms as transforms

__all__ = ['tensor_transform']

# A custom transformation class could be plugged in here instead; for now the
# pipeline only converts PIL images / ndarrays to float tensors in [0, 1].
tensor_transform = transforms.Compose([
    transforms.ToTensor(),
])

datasets/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .ssl_data import *

datasets/ssl_data.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from torchvision import datasets
2+
3+
def ssl_data(**kwargs):
    """Build the supervised SSL ImageFolder dataset for one split.

    Keyword args:
        is_train (bool): required; True selects the 'train' split, False 'val'.
        data_transforms: torchvision transform applied to each image.
        root (str, optional): dataset root directory. Defaults to the
            original hard-coded location so existing callers are unaffected.

    Returns:
        torchvision.datasets.ImageFolder over the selected split.
    """
    # Generalized: the two near-duplicate ImageFolder calls collapsed into
    # one, with the root made overridable (backward-compatible default).
    root = kwargs.get('root', 'data/ssl_data_96/supervised')
    split = 'train' if kwargs['is_train'] else 'val'
    return datasets.ImageFolder(root + '/' + split, transform=kwargs['data_transforms'])


# Attribute consumed by main.py to size the classifier head.
ssl_data.nclasses = 1000  # ugly but works

main.py

+160
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
from __future__ import print_function
2+
import argparse
3+
import torch
4+
import torch.nn as nn
5+
import torch.nn.functional as F
6+
import torch.optim as optim
7+
from torch.autograd import Variable
8+
import models
9+
import datasets
10+
import data_transformations
11+
from prettytable import PrettyTable
12+
import datetime
13+
import os
14+
import time
15+
import pdb
16+
17+
# sanity check for some arguments: collect the public callables exposed by the
# models / datasets / data_transformations packages so argparse can validate
# --arch / --dataset / --data_transforms against them.
model_names = sorted(
    name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))

dataset_names = sorted(
    name for name in datasets.__dict__
    if name.islower() and not name.startswith("__")
    and callable(datasets.__dict__[name]))

transformations_names = sorted(
    name for name in data_transformations.__dict__
    if name.islower() and not name.startswith("__")
    and callable(data_transformations.__dict__[name]))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NLLLoss pairs with the log_softmax output of the models.
criterion = nn.NLLLoss().to(device)

current_time = str(datetime.datetime.now().strftime("%d-%m-%Y_%H-%M-%S"))
# Fix: ensure the "runs" directory exists BEFORE opening the log file.
# Previously it was only created later inside __main__, so the very first
# run of the script crashed here with FileNotFoundError.
os.makedirs("runs", exist_ok=True)
file = open("runs/run-" + current_time, "w")
36+
37+
def make_loader(args):
    """Build the train and validation DataLoaders selected by ``args``.

    The dataset factory and transform are looked up by name from the
    datasets / data_transformations packages; training data is shuffled,
    validation data is not.
    """
    transform = data_transformations.__dict__[args.data_transforms]
    dataset_factory = datasets.__dict__[args.dataset]

    loaders = []
    for is_train in (True, False):
        split = dataset_factory(is_train=is_train, data_transforms=transform)
        loaders.append(torch.utils.data.DataLoader(
            split, batch_size=args.batch_size, shuffle=is_train, num_workers=1))
    train_loader, val_loader = loaders
    return train_loader, val_loader
44+
45+
def train(model, epoch, train_loader, optimizer):
    """Run one training epoch over ``train_loader``; return accumulated loss.

    NOTE(review): the loop stops after 11 batches (debug cap) yet divides by
    the full dataset size, so the returned value is a relative indicator,
    not a true mean per-sample loss — confirm this is intentional.
    """
    model.train()
    training_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        # Fix: Variable() is deprecated since PyTorch 0.4 — tensors carry
        # autograd state natively; just move them to the active device.
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # Fix: loss.item() instead of loss.data.item() (.data bypasses
        # autograd bookkeeping and is discouraged).
        training_loss += loss.item()
        if batch_idx == 10:  # debug cap: only 11 batches per epoch
            break
    training_loss /= len(train_loader.dataset)
    return training_loss
60+
61+
def validation(model, val_loader):
    """Evaluate ``model`` on (a capped slice of) the validation set.

    Returns:
        (validation_loss, correct, total) — like train(), the loop stops
        after 11 batches yet divides by the full dataset size.
    """
    model.eval()
    validation_loss = 0
    correct = 0
    # Fix: Variable(..., volatile=True) was removed in PyTorch 0.4 and is a
    # no-op in modern torch, so evaluation was tracking gradients and wasting
    # memory; torch.no_grad() is the supported replacement.
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(val_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            validation_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).cpu().sum()
            if batch_idx == 10:  # debug cap: only 11 batches
                break

    validation_loss /= len(val_loader.dataset)
    return validation_loss, correct.item(), len(val_loader.dataset)
76+
77+
78+
def main(args):
    """Train for ``args.epochs`` epochs, log a per-epoch report to ``file``,
    and optionally checkpoint the model after each epoch."""
    torch.manual_seed(args.seed)
    nclasses = datasets.__dict__[args.dataset].nclasses
    model = models.__dict__[args.arch](nclasses=nclasses)
    # model = torch.nn.DataParallel(model).to(device)
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    train_loader, val_loader = make_loader(args)
    report = PrettyTable(['Epoch No #', 'Training loss', 'Validation loss', 'Accuracy', 'Correct', 'Total', 'Time in secs'])
    for epoch in range(1, args.epochs + 1):
        print("processing epoch {} ...".format(epoch))
        start_time = time.time()
        training_loss = train(model, epoch, train_loader, optimizer)
        validation_loss, correct, total = validation(model, val_loader)
        end_time = time.time()
        # Fix: report accuracy as a percentage. The original formatted the raw
        # fraction (e.g. 0.85) with a '%' suffix, understating accuracy 100x.
        accuracy = "{}%".format(round(100.0 * correct / total, 3))
        report.add_row([epoch, round(training_loss, 4), round(validation_loss, 4),
                        accuracy, correct, total, round(end_time - start_time, 2)])
        if args.save_model == 'y':
            val_folder = "saved_model/" + current_time
            # makedirs(exist_ok=True) is race-free vs. the isdir+mkdir pair.
            os.makedirs(val_folder, exist_ok=True)
            save_model_file = val_folder + '/model_' + str(epoch) + '.pth'
            torch.save(model.state_dict(), save_model_file)
            # print('\nSaved model to ' + model_file + '. You can run `python evaluate.py --model' + model_file + '` to generate the Kaggle formatted csv file')
    file.write(report.get_string())
102+
103+
104+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch GTSRB example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    # Fix: help text claimed "default: 10" while the actual default is 3.
    parser.add_argument('--epochs', type=int, default=3, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')

    parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save_model', type=str, default='n', metavar='D',
                        help="Do you want to save models for this run or not. (y) for saving the model")

    # Model structure
    parser.add_argument('--arch', '-a', metavar='ARCH', default='conv_net',
                        choices=model_names,
                        help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: conv_net)')
    # Dataset setting
    parser.add_argument('--dataset', metavar='DATASET', default='ssl_data',
                        choices=dataset_names,
                        help='Datasets: ' +
                        ' | '.join(dataset_names) +
                        ' (default: ssl_data)')
    # Data Transformation setting
    parser.add_argument('--data_transforms', metavar='DATA_TRANFORMS', default='tensor_transform',
                        choices=transformations_names,
                        help='Datasets: ' +
                        ' | '.join(transformations_names) +
                        ' (default: tensor_transform)')
    # Printing Information
    args = parser.parse_args()

    # Record the run configuration in the log file.
    options = PrettyTable(['option', 'Value'])
    for key, val in vars(args).items():
        options.add_row([key, val])
    options.add_row(["save-model-folder", current_time])
    file.write(options.get_string())
    file.write("\n")

    # creating folders (NOTE(review): the module-level log-file open above
    # already requires "runs" to exist; these are kept as a safety net)
    if not os.path.isdir("runs"):
        os.mkdir("runs")

    if not os.path.isdir("saved_model"):
        os.mkdir("saved_model")

    # Fix: reuse the already-parsed args instead of parsing argv a second time.
    main(args)
    file.write("\n")
    file.close()

models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .simple_convnet import *

models/simple_convnet.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
import pdb
5+
6+
__all__ = ['conv_net']
7+
8+
class ConvNet(nn.Module):
    """Small 3-conv / 2-fc classifier returning per-class log-probabilities.

    The flatten size 25600 corresponds to 256 channels x 10 x 10 spatial
    resolution after three (conv, 2x2 max-pool) stages on a 96x96 input
    (96 -> 46 -> 22 -> 10), matching the 96px SSL dataset used by main.py.
    """

    def __init__(self, nclasses):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 128, kernel_size=3)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3)
        # One Dropout2d module shared by all three conv stages (despite the
        # "conv2_" name); kept for state_dict compatibility.
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(25600, 1024)
        self.fc2 = nn.Linear(1024, nclasses)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv1(x)), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv3(x)), 2))
        # Fix: flatten per-sample instead of inferring the batch dimension.
        # view(-1, 25600) silently re-batches if the spatial size is wrong;
        # view(x.size(0), -1) keeps the batch intact and fails loudly at fc1.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
27+
28+
# model-name - should be same as value provided in the argument for model
def conv_net(**kwargs):
    """Factory registered under the name expected by the --arch argument."""
    return ConvNet(kwargs['nclasses'])

0 commit comments

Comments
 (0)