Skip to content

Commit

Permalink
Merge pull request #18 from msubedar/main
Browse files Browse the repository at this point in the history
Added support for arbitrary kernel sizes for Bayesian Conv layers
  • Loading branch information
ranganathkrishnan authored Dec 8, 2022
2 parents f6f516e + a8543ad commit 984fd10
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 89 deletions.
6 changes: 6 additions & 0 deletions bayesian_torch/layers/base_variational_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@
import torch
import torch.nn as nn
import torch.distributions as distributions
from itertools import repeat
import collections

def get_kernel_size(x, n):
    """Expand a kernel-size spec into an ``n``-tuple, one entry per spatial dim.

    Mirrors ``torch.nn.modules.utils._ntuple``: a scalar is repeated ``n``
    times, while an iterable (e.g. a per-dimension tuple such as ``(3, 5)``)
    is returned as a tuple.

    Args:
        x: int, or iterable of ints of length ``n``, specifying the kernel size.
        n: number of spatial dimensions (2 for Conv2d layers, 3 for Conv3d).

    Returns:
        Tuple of length ``n`` with one kernel size per spatial dimension.

    Raises:
        ValueError: if ``x`` is an iterable whose length is not ``n``.
    """
    if isinstance(x, collections.abc.Iterable):
        kernel = tuple(x)
        # Fail fast on a mis-sized spec instead of surfacing as a cryptic
        # tensor-shape error later inside the layer constructor.
        if len(kernel) != n:
            raise ValueError(
                "kernel_size must be an int or an iterable of length %d, "
                "got %r" % (n, kernel))
        return kernel
    return tuple(repeat(x, n))

class BaseVariationalLayer_(nn.Module):
def __init__(self):
Expand Down
90 changes: 46 additions & 44 deletions bayesian_torch/layers/flipout_layers/conv_flipout.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base_variational_layer import BaseVariationalLayer_
from ..base_variational_layer import BaseVariationalLayer_, get_kernel_size

from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform
Expand Down Expand Up @@ -263,28 +263,28 @@ def __init__(self,
self.bias = bias

self.kl = 0

kernel_size = get_kernel_size(kernel_size, 2)
self.mu_kernel = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size))
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1]))
self.rho_kernel = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size))
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1]))

self.register_buffer(
'eps_kernel',
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size),
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1]),
persistent=False)
self.register_buffer(
'prior_weight_mu',
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size),
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1]),
persistent=False)
self.register_buffer(
'prior_weight_sigma',
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size),
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1]),
persistent=False)

if self.bias:
Expand Down Expand Up @@ -430,27 +430,29 @@ def __init__(self,
self.posterior_mu_init = posterior_mu_init
self.posterior_rho_init = posterior_rho_init

kernel_size = get_kernel_size(kernel_size, 3)

self.mu_kernel = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size, kernel_size))
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]))
self.rho_kernel = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size, kernel_size))
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]))

self.register_buffer(
'eps_kernel',
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size, kernel_size),
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]),
persistent=False)
self.register_buffer(
'prior_weight_mu',
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size, kernel_size),
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]),
persistent=False)
self.register_buffer(
'prior_weight_sigma',
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size, kernel_size),
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]),
persistent=False)

if self.bias:
Expand Down Expand Up @@ -760,28 +762,28 @@ def __init__(self,
self.prior_variance = prior_variance
self.posterior_mu_init = posterior_mu_init
self.posterior_rho_init = posterior_rho_init

kernel_size = get_kernel_size(kernel_size, 2)
self.mu_kernel = nn.Parameter(
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size))
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1]))
self.rho_kernel = nn.Parameter(
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size))
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1]))

self.register_buffer(
'eps_kernel',
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size),
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1]),
persistent=False)
self.register_buffer(
'prior_weight_mu',
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size),
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1]),
persistent=False)
self.register_buffer(
'prior_weight_sigma',
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size),
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1]),
persistent=False)

if self.bias:
Expand Down Expand Up @@ -928,28 +930,28 @@ def __init__(self,
self.bias = bias

self.kl = 0

kernel_size = get_kernel_size(kernel_size, 3)
self.mu_kernel = nn.Parameter(
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size, kernel_size))
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]))
self.rho_kernel = nn.Parameter(
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size, kernel_size))
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]))

self.register_buffer(
'eps_kernel',
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size, kernel_size),
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]),
persistent=False)
self.register_buffer(
'prior_weight_mu',
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size, kernel_size),
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]),
persistent=False)
self.register_buffer(
'prior_weight_sigma',
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size, kernel_size),
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]),
persistent=False)

if self.bias:
Expand Down
90 changes: 46 additions & 44 deletions bayesian_torch/layers/variational_layers/conv_variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from ..base_variational_layer import BaseVariationalLayer_
from ..base_variational_layer import BaseVariationalLayer_, get_kernel_size
import math

__all__ = [
Expand Down Expand Up @@ -255,26 +255,28 @@ def __init__(self,
self.posterior_rho_init = posterior_rho_init,
self.bias = bias

kernel_size = get_kernel_size(kernel_size, 2)

self.mu_kernel = Parameter(
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size))
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1]))
self.rho_kernel = Parameter(
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size))
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1]))
self.register_buffer(
'eps_kernel',
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size),
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1]),
persistent=False)
self.register_buffer(
'prior_weight_mu',
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size),
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1]),
persistent=False)
self.register_buffer(
'prior_weight_sigma',
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size),
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1]),
persistent=False)

if self.bias:
Expand Down Expand Up @@ -403,27 +405,27 @@ def __init__(self,
# variance of weight --> sigma = log (1 + exp(rho))
self.posterior_rho_init = posterior_rho_init,
self.bias = bias

kernel_size = get_kernel_size(kernel_size, 3)
self.mu_kernel = Parameter(
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size, kernel_size))
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]))
self.rho_kernel = Parameter(
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size, kernel_size))
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]))
self.register_buffer(
'eps_kernel',
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size, kernel_size),
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]),
persistent=False)
self.register_buffer(
'prior_weight_mu',
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size, kernel_size),
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]),
persistent=False)
self.register_buffer(
'prior_weight_sigma',
torch.Tensor(out_channels, in_channels // groups, kernel_size,
kernel_size, kernel_size),
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]),
persistent=False)

if self.bias:
Expand Down Expand Up @@ -698,27 +700,27 @@ def __init__(self,
# variance of weight --> sigma = log (1 + exp(rho))
self.posterior_rho_init = posterior_rho_init,
self.bias = bias

kernel_size = get_kernel_size(kernel_size, 2)
self.mu_kernel = Parameter(
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size))
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1]))
self.rho_kernel = Parameter(
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size))
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1]))
self.register_buffer(
'eps_kernel',
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size),
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1]),
persistent=False)
self.register_buffer(
'prior_weight_mu',
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size),
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1]),
persistent=False)
self.register_buffer(
'prior_weight_sigma',
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size),
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1]),
persistent=False)

if self.bias:
Expand Down Expand Up @@ -850,27 +852,27 @@ def __init__(self,
# variance of weight --> sigma = log (1 + exp(rho))
self.posterior_rho_init = posterior_rho_init,
self.bias = bias

kernel_size = get_kernel_size(kernel_size, 3)
self.mu_kernel = Parameter(
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size, kernel_size))
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]))
self.rho_kernel = Parameter(
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size, kernel_size))
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]))
self.register_buffer(
'eps_kernel',
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size, kernel_size),
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]),
persistent=False)
self.register_buffer(
'prior_weight_mu',
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size, kernel_size),
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]),
persistent=False)
self.register_buffer(
'prior_weight_sigma',
torch.Tensor(in_channels, out_channels // groups, kernel_size,
kernel_size, kernel_size),
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
kernel_size[1], kernel_size[2]),
persistent=False)

if self.bias:
Expand Down
2 changes: 1 addition & 1 deletion bayesian_torch/models/dnn_to_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def bnn_conv_layer(params, d):
bnn_layer = layer_fn(
in_channels=d.in_channels,
out_channels=d.out_channels,
kernel_size=d.kernel_size[0],
kernel_size=d.kernel_size,
stride=d.stride,
padding=d.padding,
dilation=d.dilation,
Expand Down

0 comments on commit 984fd10

Please sign in to comment.