From a8543ad20672c7f4890741ff2fd9d8b0197d5570 Mon Sep 17 00:00:00 2001 From: msubedar Date: Wed, 7 Dec 2022 13:42:09 -0800 Subject: [PATCH] Added support for arbitrary kernel sizes for Bayesian Conv layers --- .../layers/base_variational_layer.py | 6 ++ .../layers/flipout_layers/conv_flipout.py | 90 ++++++++++--------- .../variational_layers/conv_variational.py | 90 ++++++++++--------- bayesian_torch/models/dnn_to_bnn.py | 2 +- 4 files changed, 99 insertions(+), 89 deletions(-) diff --git a/bayesian_torch/layers/base_variational_layer.py b/bayesian_torch/layers/base_variational_layer.py index 4d63cc9..8263e82 100644 --- a/bayesian_torch/layers/base_variational_layer.py +++ b/bayesian_torch/layers/base_variational_layer.py @@ -29,7 +29,13 @@ import torch import torch.nn as nn import torch.distributions as distributions +from itertools import repeat +import collections +def get_kernel_size(x, n): + if isinstance(x, collections.abc.Iterable): + return tuple(x) + return tuple(repeat(x, n)) class BaseVariationalLayer_(nn.Module): def __init__(self): diff --git a/bayesian_torch/layers/flipout_layers/conv_flipout.py b/bayesian_torch/layers/flipout_layers/conv_flipout.py index ce13897..311462a 100644 --- a/bayesian_torch/layers/flipout_layers/conv_flipout.py +++ b/bayesian_torch/layers/flipout_layers/conv_flipout.py @@ -36,7 +36,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from ..base_variational_layer import BaseVariationalLayer_ +from ..base_variational_layer import BaseVariationalLayer_, get_kernel_size from torch.distributions.normal import Normal from torch.distributions.uniform import Uniform @@ -263,28 +263,28 @@ def __init__(self, self.bias = bias self.kl = 0 - + kernel_size = get_kernel_size(kernel_size, 2) self.mu_kernel = nn.Parameter( - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size)) + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1])) self.rho_kernel = nn.Parameter( - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size)) + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1])) self.register_buffer( 'eps_kernel', - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size), + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1]), persistent=False) self.register_buffer( 'prior_weight_mu', - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size), + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1]), persistent=False) self.register_buffer( 'prior_weight_sigma', - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size), + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1]), persistent=False) if self.bias: @@ -430,27 +430,29 @@ def __init__(self, self.posterior_mu_init = posterior_mu_init self.posterior_rho_init = posterior_rho_init + kernel_size = get_kernel_size(kernel_size, 3) + self.mu_kernel = nn.Parameter( - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size, kernel_size)) + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2])) self.rho_kernel = nn.Parameter( - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size, kernel_size)) + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2])) self.register_buffer( 'eps_kernel', - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size, kernel_size), + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2]), persistent=False) self.register_buffer( 'prior_weight_mu', - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size, kernel_size), + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2]), persistent=False) self.register_buffer( 'prior_weight_sigma', - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size, kernel_size), + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2]), persistent=False) if self.bias: @@ -760,28 +762,28 @@ def __init__(self, self.prior_variance = prior_variance self.posterior_mu_init = posterior_mu_init self.posterior_rho_init = posterior_rho_init - + kernel_size = get_kernel_size(kernel_size, 2) self.mu_kernel = nn.Parameter( - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size)) + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1])) self.rho_kernel = nn.Parameter( - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size)) + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1])) self.register_buffer( 'eps_kernel', - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size), + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1]), persistent=False) self.register_buffer( 'prior_weight_mu', - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size), + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1]), persistent=False) self.register_buffer( 'prior_weight_sigma', - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size), + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1]), persistent=False) if self.bias: @@ -928,28 +930,28 @@ def __init__(self, self.bias = bias self.kl = 0 - + kernel_size = get_kernel_size(kernel_size, 3) self.mu_kernel = nn.Parameter( - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size, kernel_size)) + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2])) self.rho_kernel = nn.Parameter( - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size, kernel_size)) + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2])) self.register_buffer( 'eps_kernel', - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size, kernel_size), + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2]), persistent=False) self.register_buffer( 'prior_weight_mu', - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size, kernel_size), + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2]), persistent=False) self.register_buffer( 'prior_weight_sigma', - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size, kernel_size), + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2]), persistent=False) if self.bias: diff --git a/bayesian_torch/layers/variational_layers/conv_variational.py b/bayesian_torch/layers/variational_layers/conv_variational.py index 7855ad8..0d2ebfd 100644 --- a/bayesian_torch/layers/variational_layers/conv_variational.py +++ b/bayesian_torch/layers/variational_layers/conv_variational.py @@ -46,7 +46,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn import Parameter -from ..base_variational_layer import BaseVariationalLayer_ +from ..base_variational_layer import BaseVariationalLayer_, get_kernel_size import math __all__ = [ @@ -255,26 +255,28 @@ def __init__(self, self.posterior_rho_init = posterior_rho_init, self.bias = bias + kernel_size = get_kernel_size(kernel_size, 2) + self.mu_kernel = Parameter( - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size)) + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1])) self.rho_kernel = Parameter( - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size)) + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1])) self.register_buffer( 'eps_kernel', - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size), + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1]), persistent=False) self.register_buffer( 'prior_weight_mu', - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size), + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1]), persistent=False) self.register_buffer( 'prior_weight_sigma', - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size), + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1]), persistent=False) if self.bias: @@ -403,27 +405,27 @@ def __init__(self, # variance of weight --> sigma = log (1 + exp(rho)) self.posterior_rho_init = posterior_rho_init, self.bias = bias - + kernel_size = get_kernel_size(kernel_size, 3) self.mu_kernel = Parameter( - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size, kernel_size)) + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2])) self.rho_kernel = Parameter( - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size, kernel_size)) + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2])) self.register_buffer( 'eps_kernel', - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size, kernel_size), + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2]), persistent=False) self.register_buffer( 'prior_weight_mu', - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size, kernel_size), + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2]), persistent=False) self.register_buffer( 'prior_weight_sigma', - torch.Tensor(out_channels, in_channels // groups, kernel_size, - kernel_size, kernel_size), + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2]), persistent=False) if self.bias: @@ -698,27 +700,27 @@ def __init__(self, # variance of weight --> sigma = log (1 + exp(rho)) self.posterior_rho_init = posterior_rho_init, self.bias = bias - + kernel_size = get_kernel_size(kernel_size, 2) self.mu_kernel = Parameter( - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size)) + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1])) self.rho_kernel = Parameter( - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size)) + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1])) self.register_buffer( 'eps_kernel', - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size), + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1]), persistent=False) self.register_buffer( 'prior_weight_mu', - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size), + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1]), persistent=False) self.register_buffer( 'prior_weight_sigma', - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size), + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1]), persistent=False) if self.bias: @@ -850,27 +852,27 @@ def __init__(self, # variance of weight --> sigma = log (1 + exp(rho)) self.posterior_rho_init = posterior_rho_init, self.bias = bias - + kernel_size = get_kernel_size(kernel_size, 3) self.mu_kernel = Parameter( - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size, kernel_size)) + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2])) self.rho_kernel = Parameter( - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size, kernel_size)) + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2])) self.register_buffer( 'eps_kernel', - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size, kernel_size), + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2]), persistent=False) self.register_buffer( 'prior_weight_mu', - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size, kernel_size), + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2]), persistent=False) self.register_buffer( 'prior_weight_sigma', - torch.Tensor(in_channels, out_channels // groups, kernel_size, - kernel_size, kernel_size), + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], + kernel_size[1], kernel_size[2]), persistent=False) if self.bias: diff --git a/bayesian_torch/models/dnn_to_bnn.py b/bayesian_torch/models/dnn_to_bnn.py index 18b9b51..92e18b4 100644 --- a/bayesian_torch/models/dnn_to_bnn.py +++ b/bayesian_torch/models/dnn_to_bnn.py @@ -79,7 +79,7 @@ def bnn_conv_layer(params, d): bnn_layer = layer_fn( in_channels=d.in_channels, out_channels=d.out_channels, - kernel_size=d.kernel_size[0], + kernel_size=d.kernel_size, stride=d.stride, padding=d.padding, dilation=d.dilation,