forked from moskomule/l0.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist.py
109 lines (87 loc) · 3.65 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import tensorboardX
from trainer import Trainer
from l0module import L0Linear, L0Conv2d
def get_loader(batch_size):
data_root = os.path.expanduser('~/.torch/data/mnist')
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_root, train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])), shuffle=True, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_root, train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])), batch_size=batch_size)
return train_loader, test_loader
class L0Net(nn.Module):
"""
based on Caffe LeNet (https://github.com/BVLC/caffe/blob/master/examples/mnist/lenet.prototxt)
"""
def __init__(self, mean):
super(L0Net, self).__init__()
self.conv1 = L0Conv2d(1, 20, kernel_size=5, stride=1, loc_mean=mean)
self.conv2 = L0Conv2d(20, 50, kernel_size=5, stride=1, loc_mean=mean)
self.dense1 = L0Linear(800, 500, loc_mean=mean)
self.dense2 = L0Linear(500, 10, loc_mean=mean)
def forward(self, x):
x, z1 = self.conv1(x)
x = F.max_pool2d(x, 2, stride=2)
x, z2 = self.conv2(x)
x = F.max_pool2d(x, 2, stride=2)
x = x.view(x.shape[0], -1)
x, z3 = self.dense1(x)
x = F.relu(x)
x, z4 = self.dense2(x)
penalty = z1 + z2 + z3 + z4
return x, penalty
class LeNet(nn.Module):
"""
based on Caffe LeNet (https://github.com/BVLC/caffe/blob/master/examples/mnist/lenet.prototxt)
"""
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1)
self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
self.dense1 = nn.Linear(800, 500)
self.dense2 = nn.Linear(500, 10)
def forward(self, x):
x = self.conv1(x)
x = F.max_pool2d(x, 2, stride=2)
x = self.conv2(x)
x = F.max_pool2d(x, 2, stride=2)
x = x.view(x.shape[0], -1)
x = self.dense1(x)
x = F.relu(x)
x = self.dense2(x)
return x
def main(epochs, batch_size, coef, mean, baseline):
train_loader, test_loader = get_loader(batch_size)
logger = tensorboardX.SummaryWriter()
model = LeNet() if baseline else L0Net(mean)
model.cuda()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
l0_loss = lambda output, target: F.cross_entropy(output[0], target) + coef / len(train_loader.dataset) * output[1]
loss_f = F.cross_entropy if baseline else l0_loss
trainer = Trainer(model, optimizer, loss_f, logger)
if not baseline:
correct = lambda output, target: (output[0].max(dim=1)[1] == target).sum().data[0]
trainer.correct = correct
trainer.start(epochs, train_loader, test_loader)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--epochs", type=int, default=200)
parser.add_argument("--coef", type=float, default=1e-1)
parser.add_argument("--mean", type=float, default=1)
parser.add_argument("--baseline", action="store_true")
main(**vars(parser.parse_args()))