# l0module.py
# Forked from moskomule/l0.pytorch

import math

import torch
from torch import nn
from torch.nn import functional as F


def hard_sigmoid(x):
    # Piecewise-linear "hard" sigmoid: clamp x element-wise to the interval [0, 1].
    return torch.min(torch.max(x, torch.zeros_like(x)), torch.ones_like(x))


class _L0Norm(nn.Module):

    def __init__(self, origin, loc_mean=0, loc_sdev=0.01, beta=2 / 3, gamma=-0.1,
                 zeta=1.1, fix_temp=True):
        """
        Base class for layers regularized with an L0 norm penalty.

        :param origin: original layer such as nn.Linear(..) or nn.Conv2d(..)
        :param loc_mean: mean of the normal distribution that initializes the location parameters
        :param loc_sdev: standard deviation of the normal distribution that initializes the location parameters
        :param beta: initial temperature parameter
        :param gamma: lower bound of the "stretched" variable s
        :param zeta: upper bound of the "stretched" variable s
        :param fix_temp: True if the temperature is kept fixed (not learned)
        """
        super().__init__()
        self._origin = origin
        self._size = self._origin.weight.size()
        # Location parameters of the gates, one per weight of the wrapped layer.
        self.loc = nn.Parameter(torch.zeros(self._size).normal_(loc_mean, loc_sdev))
        self.temp = beta if fix_temp else nn.Parameter(torch.zeros(1).fill_(beta))
        # Buffer reused to draw uniform noise for the stochastic gates.
        self.register_buffer("uniform", torch.zeros(self._size))
        self.gamma = gamma
        self.zeta = zeta
        self.gamma_zeta_ratio = math.log(-gamma / zeta)
    def _get_mask(self):
        if self.training:
            # Sample a stochastic gate: u ~ U(0, 1) is passed through a
            # temperature-scaled sigmoid, then stretched to (gamma, zeta).
            self.uniform.uniform_()
            u = self.uniform
            s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + self.loc) / self.temp)
            s = s * (self.zeta - self.gamma) + self.gamma
            # Expected L0 norm: probability that each gate is non-zero, summed over weights.
            penalty = torch.sigmoid(self.loc - self.temp * self.gamma_zeta_ratio).sum()
        else:
            # Deterministic gate at evaluation time; no penalty is accumulated.
            s = torch.sigmoid(self.loc) * (self.zeta - self.gamma) + self.gamma
            penalty = 0
        return hard_sigmoid(s), penalty
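
# A hedged note on the math above (after Louizos, Welling & Kingma, 2018,
# "Learning Sparse Neural Networks through L0 Regularization"): during training
# a hard concrete gate z = hard_sigmoid(s) is sampled per weight, and `penalty`
# is the expected number of non-zero gates,
#     sum_j sigmoid(loc_j - temp * log(-gamma / zeta)),
# which is differentiable in `loc`. At evaluation time the gate is the
# deterministic stretched-and-clipped sigmoid of `loc`, so a rough sparsity
# estimate can be read off the mask directly. Illustrative snippet (`layer`
# stands for any _L0Norm subclass instance; it is not a name defined here):
#     layer.eval()
#     mask, _ = layer._get_mask()
#     sparsity = (mask == 0).float().mean()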


class L0Linear(_L0Norm):

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__(nn.Linear(in_features, out_features, bias=bias), **kwargs)

    def forward(self, input):
        mask, penalty = self._get_mask()
        return F.linear(input, self._origin.weight * mask, self._origin.bias), penalty


class L0Conv2d(_L0Norm):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 **kwargs):
        super().__init__(nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                                   dilation=dilation, groups=groups, bias=bias), **kwargs)

    def forward(self, input):
        mask, penalty = self._get_mask()
        conv = F.conv2d(input, self._origin.weight * mask, self._origin.bias, stride=self._origin.stride,
                        padding=self._origin.padding, dilation=self._origin.dilation, groups=self._origin.groups)
        return conv, penalty


class L0Sequential(nn.Sequential):

    def forward(self, input):
        # Accumulate the L0 penalties of all child modules that return one;
        # plain modules (e.g. activations) pass their output straight through.
        penalty = 0
        for module in self._modules.values():
            output = module(input)
            if isinstance(output, tuple):
                input = output[0]
                penalty += output[1]
            else:
                input = output
        return input, penalty
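

# A minimal usage sketch, not part of the original module: `l0_lambda` and the
# synthetic data below are illustrative assumptions showing how the penalty
# returned by the forward pass is typically added to the task loss so that
# sparsity is optimized jointly with it.
if __name__ == "__main__":
    model = L0Sequential(L0Linear(20, 50), nn.ReLU(), L0Linear(50, 2))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    l0_lambda = 1e-2  # assumed strength of the sparsity penalty

    x = torch.randn(64, 20)
    target = torch.randint(0, 2, (64,))

    model.train()
    logits, penalty = model(x)
    loss = F.cross_entropy(logits, target) + l0_lambda * penalty
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()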