
Commit

Merge pull request #9 from piEsposito/main
Let the models return prediction only, saving KL Divergence as an attribute
ranganathkrishnan authored Dec 2, 2021
2 parents 7abcfe7 + f1fc4e5 commit daa7292
Showing 6 changed files with 121 additions and 32 deletions.
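This commit adds a return_kl flag to each Bayesian layer's forward and stores the computed KL divergence on the layer as self.kl, so callers can obtain the prediction alone. A minimal usage sketch of the new behaviour (the constructor arguments and tensor shapes below are illustrative, not taken from this diff):

import torch
from bayesian_torch.layers.flipout_layers.linear_flipout import LinearFlipout

layer = LinearFlipout(in_features=16, out_features=8)
x = torch.randn(4, 16)

# Default behaviour is unchanged: forward still returns (output, kl).
out, kl = layer(x, return_kl=True)

# New behaviour: request the prediction only; the KL term stays on the layer.
out = layer(x, return_kl=False)
kl = layer.kl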
54 changes: 42 additions & 12 deletions bayesian_torch/layers/flipout_layers/conv_flipout.py
@@ -100,6 +100,8 @@ def __init__(self,
self.posterior_rho_init = posterior_rho_init
self.bias = bias

self.kl = 0

self.mu_kernel = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups, kernel_size))
self.rho_kernel = nn.Parameter(
@@ -150,7 +152,7 @@ def init_parameters(self):
self.prior_bias_mu.data.fill_(self.prior_mean)
self.prior_bias_sigma.data.fill_(self.prior_variance)

def forward(self, x):
def forward(self, x, return_kl=True):

# linear outputs
outputs = F.conv1d(x,
@@ -191,8 +193,11 @@ def forward(self, x):
dilation=self.dilation,
groups=self.groups) * sign_output

self.kl = kl
# returning outputs + perturbations
return outputs + perturbed_outputs, kl
if return_kl:
return outputs + perturbed_outputs, kl
return outputs + perturbed_outputs


class Conv2dFlipout(BaseVariationalLayer_):
@@ -244,6 +249,8 @@ def __init__(self,
self.posterior_rho_init = posterior_rho_init
self.bias = bias

self.kl = 0

self.mu_kernel = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size))
@@ -299,7 +306,7 @@ def init_parameters(self):
self.prior_bias_mu.data.fill_(self.prior_mean)
self.prior_bias_sigma.data.fill_(self.prior_variance)

def forward(self, x):
def forward(self, x, return_kl=True):

# linear outputs
outputs = F.conv2d(x,
@@ -340,8 +347,11 @@ def forward(self, x):
dilation=self.dilation,
groups=self.groups) * sign_output

self.kl = kl
# returning outputs + perturbations
return outputs + perturbed_outputs, kl
if return_kl:
return outputs + perturbed_outputs, kl
return outputs + perturbed_outputs


class Conv3dFlipout(BaseVariationalLayer_):
@@ -388,6 +398,8 @@ def __init__(self,
self.groups = groups
self.bias = bias

self.kl = 0

self.prior_mean = prior_mean
self.prior_variance = prior_variance
self.posterior_mu_init = posterior_mu_init
@@ -448,7 +460,7 @@ def init_parameters(self):
self.prior_bias_mu.data.fill_(self.prior_mean)
self.prior_bias_sigma.data.fill_(self.prior_variance)

def forward(self, x):
def forward(self, x, return_kl=True):

# linear outputs
outputs = F.conv3d(x,
@@ -489,8 +501,11 @@ def forward(self, x):
dilation=self.dilation,
groups=self.groups) * sign_output

self.kl = kl
# returning outputs + perturbations
return outputs + perturbed_outputs, kl
if return_kl:
return outputs + perturbed_outputs, kl
return outputs + perturbed_outputs


class ConvTranspose1dFlipout(BaseVariationalLayer_):
@@ -537,6 +552,8 @@ def __init__(self,
self.groups = groups
self.bias = bias

self.kl = 0

self.prior_mean = prior_mean
self.prior_variance = prior_variance
self.posterior_mu_init = posterior_mu_init
@@ -593,7 +610,7 @@ def init_parameters(self):
self.prior_bias_mu.data.fill_(self.prior_mean)
self.prior_bias_sigma.data.fill_(self.prior_variance)

def forward(self, x):
def forward(self, x, return_kl=True):

# linear outputs
outputs = F.conv_transpose1d(x,
@@ -635,8 +652,11 @@ def forward(self, x):
dilation=self.dilation,
groups=self.groups) * sign_output

self.kl = kl
# returning outputs + perturbations
return outputs + perturbed_outputs, kl
if return_kl:
return outputs + perturbed_outputs, kl
return outputs + perturbed_outputs


class ConvTranspose2dFlipout(BaseVariationalLayer_):
@@ -683,6 +703,8 @@ def __init__(self,
self.groups = groups
self.bias = bias

self.kl = 0

self.prior_mean = prior_mean
self.prior_variance = prior_variance
self.posterior_mu_init = posterior_mu_init
@@ -743,7 +765,7 @@ def init_parameters(self):
self.prior_bias_mu.data.fill_(self.prior_mean)
self.prior_bias_sigma.data.fill_(self.prior_variance)

def forward(self, x):
def forward(self, x, return_kl=True):

# linear outputs
outputs = F.conv_transpose2d(x,
@@ -785,8 +807,11 @@ def forward(self, x):
dilation=self.dilation,
groups=self.groups) * sign_output

self.kl = kl
# returning outputs + perturbations
return outputs + perturbed_outputs, kl
if return_kl:
return outputs + perturbed_outputs, kl
return outputs + perturbed_outputs


class ConvTranspose3dFlipout(BaseVariationalLayer_):
@@ -838,6 +863,8 @@ def __init__(self,
self.posterior_rho_init = posterior_rho_init
self.bias = bias

self.kl = 0

self.mu_kernel = nn.Parameter(
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size, kernel_size))
@@ -893,7 +920,7 @@ def init_parameters(self):
self.prior_bias_mu.data.fill_(self.prior_mean)
self.prior_bias_sigma.data.fill_(self.prior_variance)

def forward(self, x):
def forward(self, x, return_kl=True):

# linear outputs
outputs = F.conv_transpose3d(x,
@@ -935,5 +962,8 @@ def forward(self, x):
dilation=self.dilation,
groups=self.groups) * sign_output

self.kl = kl
# returning outputs + perturbations
return outputs + perturbed_outputs, kl
if return_kl:
return outputs + perturbed_outputs, kl
return outputs + perturbed_outputs
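Because every convolutional flipout layer now keeps its KL term in self.kl, the total KL for the loss can still be recovered when layers are called with return_kl=False. A possible helper, not part of this PR (get_kl_sum is a hypothetical name):

import torch

def get_kl_sum(model):
    # Sum the kl attribute stored on each Bayesian layer after its forward pass.
    kl = torch.tensor(0.0)
    for module in model.modules():
        if hasattr(module, 'kl'):
            kl = kl + module.kl
    return kl

Note that a composite layer such as the LSTM below, which itself contains Bayesian sub-layers, would be counted together with its children by a helper like this, so a real implementation may need to filter modules more carefully.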
10 changes: 8 additions & 2 deletions bayesian_torch/layers/flipout_layers/linear_flipout.py
@@ -90,6 +90,8 @@ def __init__(self,
torch.Tensor(out_features, in_features),
persistent=False)

self.kl = 0

if bias:
self.mu_bias = nn.Parameter(torch.Tensor(out_features))
self.rho_bias = nn.Parameter(torch.Tensor(out_features))
@@ -123,7 +125,7 @@ def init_parameters(self):
self.mu_bias.data.normal_(mean=self.posterior_mu_init, std=0.1)
self.rho_bias.data.normal_(mean=self.posterior_rho_init, std=0.1)

def forward(self, x):
def forward(self, x, return_kl=True):
# sampling delta_W
sigma_weight = torch.log1p(torch.exp(self.rho_weight))
delta_weight = (sigma_weight * self.eps_weight.data.normal_())
@@ -148,5 +150,9 @@ def forward(self, x):
perturbed_outputs = F.linear(x * sign_input, delta_weight,
bias) * sign_output

self.kl = kl

# returning outputs + perturbations
return outputs + perturbed_outputs, kl
if return_kl:
return outputs + perturbed_outputs, kl
return outputs + perturbed_outputs
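With the prediction-only call path, a training step could combine the data loss with the stored KL. A sketch using the hypothetical get_kl_sum helper above; model, criterion, x, target, and batch_size are placeholders, not part of this diff:

output = model(x)                        # Bayesian layers called with return_kl=False inside the model
kl = get_kl_sum(model)
loss = criterion(output, target) + kl / batch_size
loss.backward()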
9 changes: 7 additions & 2 deletions bayesian_torch/layers/flipout_layers/rnn_flipout.py
@@ -76,6 +76,8 @@ def __init__(self,
self.posterior_rho_init = posterior_rho_init, # variance of weight --> sigma = log (1 + exp(rho))
self.bias = bias

self.kl = 0

self.ih = LinearFlipout(prior_mean=prior_mean,
prior_variance=prior_variance,
posterior_mu_init=posterior_mu_init,
@@ -92,7 +94,7 @@ def __init__(self,
out_features=out_features * 4,
bias=bias)

def forward(self, X, hidden_states=None):
def forward(self, X, hidden_states=None, return_kl=True):

batch_size, seq_size, _ = X.size()

@@ -137,4 +139,7 @@ def forward(self, X, hidden_states=None):
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
c_ts = c_ts.transpose(0, 1).contiguous()

return hidden_seq, (hidden_seq, c_ts), kl
self.kl = kl
if return_kl:
return hidden_seq, (hidden_seq, c_ts), kl
return hidden_seq, (hidden_seq, c_ts)
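The recurrent layer follows the same pattern. A usage sketch, assuming the LSTM class defined in rnn_flipout.py is named LSTMFlipout (shapes are illustrative):

import torch
from bayesian_torch.layers.flipout_layers.rnn_flipout import LSTMFlipout

rnn = LSTMFlipout(in_features=10, out_features=20)
x = torch.randn(2, 5, 10)  # (batch, seq, features)

hidden_seq, states, kl = rnn(x)               # default: KL is returned
hidden_seq, states = rnn(x, return_kl=False)  # new: KL kept on the layer
kl = rnn.kl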
