
Commit ac184bc

Refactor model module, add new net
1 parent e46a28b commit ac184bc

File tree: 13 files changed, +210 -39 lines

demos/BIN_CIFAR10.py

Lines changed: 3 additions & 3 deletions
@@ -3,16 +3,16 @@
 sys.path.insert(0, os.path.abspath('..'))
 
 import torch
-from models.Alexnet_Bin import *
-from models.Alexnet_BinTest import *
+from models.Alexnet.Alexnet_Bin import *
+from models.Alexnet.Alexnet_BinTest import *
 from torchvision import transforms
 import torchvision
 
 
 from device import device
 
 
-BATCH_SIZE = 128
+BATCH_SIZE = 3
 
 model = AlexNetBin().to(device)
 
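Note that BATCH_SIZE drops from 128 to 3 in this hunk. The constant presumably feeds the demo's CIFAR-10 loaders; a minimal sketch of that wiring (the loader calls sit outside this hunk, so the exact arguments are an assumption):

train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                         transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE,
                                           shuffle=True)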
Lines changed: 3 additions & 2 deletions
@@ -2,12 +2,13 @@
 import sys
 sys.path.insert(0, os.path.abspath('..'))
 
-
-from models.binMNIST import BinMNIST
+#.9833
+from models.ConvNet.binMNIST_conv import BinMNIST
 import torch
 import torch.utils.data
 import numpy as np
 import torchvision
+from progress.bar import ChargingBar
 
 
 
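ChargingBar comes from the third-party progress package and is presumably used to report per-batch training progress. Its basic API, for reference (num_batches is a placeholder value):

from progress.bar import ChargingBar

num_batches = 469                      # e.g. 60000 MNIST images / batch size 128
bar = ChargingBar('Training', max=num_batches)
for _ in range(num_batches):
    bar.next()                         # advance once per processed batch
bar.finish()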
Lines changed: 2 additions & 2 deletions
@@ -2,8 +2,8 @@
 import sys
 sys.path.insert(0, os.path.abspath('..'))
 
-
-from models.binMNIST_conv import BinMNIST
+# 0.9623
+from models.FullNet.binMNIST import BinMNIST
 import torch
 import torch.utils.data
 import numpy as np

demos/LIN_CIFAR10.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 sys.path.insert(0, os.path.abspath('..'))
 # 82
 import torch
-from models.VGG_LinQuant import VGGLinQuant
+from models.VGG.VGG_LinQuant import VGGLinQuant
 from torchvision import transforms
 import torchvision
 from progress.bar import ChargingBar

demos/LIN_MNIST.py

Lines changed: 3 additions & 2 deletions
@@ -3,13 +3,12 @@
 sys.path.insert(0, os.path.abspath('..'))
 
 
-from models.linQuantMNIST import LinQuantMNIST
+from models.FullNet.linQuantMNIST import LinQuantMNIST
 import torch
 import torch.utils.data
 import numpy as np
 import torchvision
 
-
 from device import device
 
 
@@ -18,6 +17,8 @@
 MOMENTUM = 0.6
 EPOCH = 600
 DATASET_SIZE = 60000
+
+
 def adjust_learning_rate(optimizer, epoch):
     global LEARNING_RATE
     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
Lines changed: 0 additions & 2 deletions
@@ -36,8 +36,6 @@ def __init__(self, num_classes=10):
             nn.MaxPool2d(kernel_size=3, stride=2),
             torch.nn.BatchNorm2d(256),
             nn.Hardtanh(inplace=True),
-
-
         )
         self.classifieur = nn.Sequential(
             BinaryConnect(stochastic=False),

models/Alexnet/Alexnet_BinTest.py

Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
+
+
+
+
+import torch.nn as nn
+import torchvision.transforms as transforms
+import torch
+import pdb
+import torch.nn as nn
+import math
+from torch.autograd import Variable
+from torch.autograd import Function
+import numpy as np
+
+
+def Binarize(tensor,quant_mode='det'):
+    if quant_mode=='det':
+        return tensor.sign()
+    else:
+        return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)
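Binarize maps real values onto {-1, +1}: deterministically via sign(), or stochastically with P(+1) = (x+1)/2 through the hard-sigmoid chain in the else branch (which also mutates its argument in place). A quick illustration, assuming the module is importable under its new path:

import torch
from models.Alexnet.Alexnet_BinTest import Binarize

w = torch.tensor([-0.7, 0.2, 0.9])
print(Binarize(w.clone()))                      # tensor([-1., 1., 1.]): elementwise sign
print(Binarize(w.clone(), quant_mode='stoch'))  # random signs with P(+1) = (w+1)/2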
+
+
+
+
+class HingeLoss(nn.Module):
+    def __init__(self):
+        super(HingeLoss,self).__init__()
+        self.margin=1.0
+
+    def hinge_loss(self,input,target):
+        #import pdb; pdb.set_trace()
+        output=self.margin-input.mul(target)
+        output[output.le(0)]=0
+        return output.mean()
+
+    def forward(self, input, target):
+        return self.hinge_loss(input,target)
+
+class SqrtHingeLossFunction(Function):
+    def __init__(self):
+        super(SqrtHingeLossFunction,self).__init__()
+        self.margin=1.0
+
+    def forward(self, input, target):
+        output=self.margin-input.mul(target)
+        output[output.le(0)]=0
+        self.save_for_backward(input, target)
+        loss=output.mul(output).sum(0).sum(1).div(target.numel())
+        return loss
+
+    def backward(self,grad_output):
+        input, target = self.saved_tensors
+        output=self.margin-input.mul(target)
+        output[output.le(0)]=0
+        import pdb; pdb.set_trace()
+        grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output)
+        grad_output.mul_(output.ne(0).float())
+        grad_output.div_(input.numel())
+        return grad_output,grad_output
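HingeLoss is the element-wise hinge max(0, margin - y*yhat) averaged over all entries, with targets encoded as +/-1; SqrtHingeLossFunction is its squared variant (note that its backward still contains a live pdb.set_trace(), so it drops into the debugger if ever invoked). A worked example for the plain hinge:

import torch

loss_fn = HingeLoss()
pred   = torch.tensor([0.8, -0.3, 2.0])
target = torch.tensor([1.0,  1.0, -1.0])
# max(0, 1-0.8) = 0.2, max(0, 1+0.3) = 1.3, max(0, 1+2.0) = 3.0
print(loss_fn(pred, target))           # tensor(1.5000) = (0.2 + 1.3 + 3.0) / 3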
+
+def Quantize(tensor,quant_mode='det', params=None, numBits=8):
+    tensor.clamp_(-2**(numBits-1),2**(numBits-1))
+    if quant_mode=='det':
+        tensor=tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))
+    else:
+        tensor=tensor.mul(2**(numBits-1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2**(numBits-1))
+        quant_fixed(tensor, params)
+    return tensor
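In deterministic mode, Quantize snaps each value to numBits-bit fixed point, i.e. to the nearest multiple of 2^-(numBits-1) (1/128 for the default 8 bits): 0.123 becomes 0.1250, since round(0.123 * 128) = 16 and 16/128 = 0.125. The stochastic branch calls quant_fixed, which is not defined in this file, so only the default mode is usable as committed.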
+
+import torch.nn._functions as tnnf
+
+
+class BinarizeLinear(nn.Linear):
+
+    def __init__(self, *kargs, **kwargs):
+        super(BinarizeLinear, self).__init__(*kargs, **kwargs)
+
+    def forward(self, input):
+
+        if input.size(1) != 784:
+            input.data=Binarize(input.data)
+        if not hasattr(self.weight,'org'):
+            self.weight.org=self.weight.data.clone()
+        self.weight.data=Binarize(self.weight.org)
+        out = nn.functional.linear(input, self.weight)
+        if not self.bias is None:
+            self.bias.org=self.bias.data.clone()
+            out += self.bias.view(1, -1).expand_as(out)
+
+        return out
+
+
+class BinarizeConv2d(nn.Conv2d):
+
+    def __init__(self, *kargs, **kwargs):
+        super(BinarizeConv2d, self).__init__(*kargs, **kwargs)
+
+    def forward(self, input):
+        if input.size(1) != 3:
+            input.data = Binarize(input.data)
+        if not hasattr(self.weight,'org'):
+            self.weight.org=self.weight.data.clone()
+        self.weight.data=Binarize(self.weight.org)
+
+        out = nn.functional.conv2d(input, self.weight, None, self.stride,
+                                   self.padding, self.dilation, self.groups)
+
+        if not self.bias is None:
+            self.bias.org=self.bias.data.clone()
+            out += self.bias.view(1, -1, 1, 1).expand_as(out)
+
+        return out
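Both binarized layers follow the BinaryConnect pattern: the full-precision weights are stashed once in weight.org, and every forward pass runs on Binarize(weight.org). The training loop therefore has to round-trip between the two copies around each optimizer step; a sketch of what that implies (the loop itself is not part of this commit; this mirrors the reference BinaryNet recipe):

output = model(input)                    # forward pass sees binarized weights
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
for p in model.parameters():
    if hasattr(p, 'org'):
        p.data.copy_(p.org)              # restore full-precision weights
optimizer.step()                         # update the full-precision copy
for p in model.parameters():
    if hasattr(p, 'org'):
        p.org.copy_(p.data.clamp_(-1, 1))  # re-stash, clipped to [-1, 1]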
+
+
+class AlexNetOWT_BN(nn.Module):
+
+    def __init__(self, num_classes=10):
+        super(AlexNetOWT_BN, self).__init__()
+        print("lol")
+        self.ratioInfl=3
+        self.features = nn.Sequential(
+            BinarizeConv2d(3, int(64*self.ratioInfl), kernel_size=11, stride=4, padding=2),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.BatchNorm2d(int(64*self.ratioInfl)),
+            nn.Hardtanh(inplace=True),
+            BinarizeConv2d(int(64*self.ratioInfl), int(192*self.ratioInfl), kernel_size=5, padding=2),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.BatchNorm2d(int(192*self.ratioInfl)),
+            nn.Hardtanh(inplace=True),
+
+            BinarizeConv2d(int(192*self.ratioInfl), int(384*self.ratioInfl), kernel_size=3, padding=1),
+            nn.BatchNorm2d(int(384*self.ratioInfl)),
+            nn.Hardtanh(inplace=True),
+
+            BinarizeConv2d(int(384*self.ratioInfl), int(256*self.ratioInfl), kernel_size=3, padding=1),
+            nn.BatchNorm2d(int(256*self.ratioInfl)),
+            nn.Hardtanh(inplace=True),
+
+            BinarizeConv2d(int(256*self.ratioInfl), 256, kernel_size=3, padding=1),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.BatchNorm2d(256),
+            nn.Hardtanh(inplace=True)
+
+        )
+        self.classifier = nn.Sequential(
+            BinarizeLinear(256 * 6 * 6, 4096),
+            nn.BatchNorm1d(4096),
+            nn.Hardtanh(inplace=True),
+            #nn.Dropout(0.5),
+            BinarizeLinear(4096, 4096),
+            nn.BatchNorm1d(4096),
+            nn.Hardtanh(inplace=True),
+            #nn.Dropout(0.5),
+            BinarizeLinear(4096, 10),
+            nn.BatchNorm1d(10),
+            nn.LogSoftmax()
+        )
+
+        #self.regime = {
+        #    0: {'optimizer': 'SGD', 'lr': 1e-2,
+        #        'weight_decay': 5e-4, 'momentum': 0.9},
+        #    10: {'lr': 5e-3},
+        #    15: {'lr': 1e-3, 'weight_decay': 0},
+        #    20: {'lr': 5e-4},
+        #    25: {'lr': 1e-4}
+        #}
+        self.regime = {
+            0: {'optimizer': 'Adam', 'lr': 5e-3},
+            20: {'lr': 1e-3},
+            30: {'lr': 5e-4},
+            35: {'lr': 1e-4},
+            40: {'lr': 1e-5}
+        }
+        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                         std=[0.229, 0.224, 0.225])
+        self.input_transform = {
+            'train': transforms.Compose([
+                transforms.Scale(256),
+                transforms.RandomCrop(224),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                normalize
+            ]),
+            'eval': transforms.Compose([
+                transforms.Scale(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                normalize
+            ])
+        }
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(-1, 256 * 6 * 6)
+        x = self.classifier(x)
+        return x
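A minimal smoke test for the new net, assuming an ImageNet-sized input (note that num_classes is accepted but the classifier hardcodes 10 outputs, and the constructor still prints a leftover debug message):

import torch
from models.Alexnet.Alexnet_BinTest import AlexNetOWT_BN

model = AlexNetOWT_BN(num_classes=10)
x = torch.randn(2, 3, 224, 224)   # the conv stack reduces 224x224 to 256 x 6 x 6
logits = model(x)
print(logits.shape)               # torch.Size([2, 10])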
Lines changed: 2 additions & 2 deletions
@@ -32,8 +32,8 @@ def reset(self):
 
 
     def clamp(self):
-        #self.linear1.clamp()
-        #self.linear2.clamp()
+        self.conv1.clamp()
+        self.conv2.clamp()
         self.linear3.clamp()
         self.linear4.clamp()
 
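Swapping the commented-out linear1/linear2 clamps for conv1/conv2 matches the move to a convolutional front end. Each layer's clamp() presumably applies the usual BinaryConnect weight clipping; a hypothetical layer-level implementation:

def clamp(self):
    # keep the real-valued shadow weights inside the binarization range
    self.weight.data.clamp_(-1, 1)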