Densenet canonizations #171

Draft: wants to merge 11 commits into master
257 changes: 255 additions & 2 deletions src/zennit/canonizers.py
@@ -28,6 +28,7 @@ class Canonizer(metaclass=ABCMeta):
'''Canonizer Base class.
Canonizers modify modules temporarily such that certain attribution rules can properly be applied.
'''

@abstractmethod
def apply(self, root_module):
'''Apply this canonizer recursively on all applicable modules.
@@ -70,7 +71,7 @@ def __init__(self):
super().__init__()
self.linears = None
self.batch_norm = None

self.batch_norm_eps = None
self.linear_params = None
self.batch_norm_params = None

@@ -94,6 +95,7 @@ def register(self, linears, batch_norm):
key: getattr(self.batch_norm, key).data for key in ('weight', 'bias', 'running_mean', 'running_var')
}

self.batch_norm_eps = batch_norm.eps
self.merge_batch_norm(self.linears, self.batch_norm)

def remove(self):
@@ -110,6 +112,8 @@ def remove(self):
for key, value in self.batch_norm_params.items():
getattr(self.batch_norm, key).data = value

self.batch_norm.eps = self.batch_norm_eps

@staticmethod
def merge_batch_norm(modules, batch_norm):
'''Update parameters of a linear layer to additionally include a Batch Normalization operation and update the
@@ -148,6 +152,7 @@ def merge_batch_norm(modules, batch_norm):
batch_norm.running_var.data = torch.ones_like(batch_norm.running_var.data)
batch_norm.bias.data = torch.zeros_like(batch_norm.bias.data)
batch_norm.weight.data = torch.ones_like(batch_norm.weight.data)
batch_norm.eps = 0.


class SequentialMergeBatchNorm(MergeBatchNorm):
@@ -162,6 +167,7 @@ class SequentialMergeBatchNorm(MergeBatchNorm):
to properly detect when there is an activation function between linear and batch-norm modules.

'''

def apply(self, root_module):
'''Finds a batch norm following right after a linear layer, and creates a copy of this instance to merge
them by fusing the batch norm parameters into the linear layer and reducing the batch norm to the identity.
@@ -181,13 +187,256 @@ def apply(self, root_module):
last_leaf = None
for leaf in collect_leaves(root_module):
if isinstance(last_leaf, self.linear_type) and isinstance(leaf, self.batch_norm_type):
if last_leaf.weight.shape[0] == leaf.weight.shape[0]:
instance = self.copy()
instance.register((last_leaf,), leaf)
instances.append(instance)
last_leaf = leaf

return instances


class MergeBatchNormtoRight(MergeBatchNorm):
Owner comment: The convhook strictly expects a 2d Convolution. We can either make convhook more general, or call this MergeBatchNormConv2dRight.

'''Canonizer to merge the parameters of all batch norms that appear sequentially right before a linear module.

Note
----
MergeBatchNormtoRight traverses the tree of children of the provided module depth-first and in-order.
This means that child-modules must be assigned to their parent module in the order they are visited in the forward
pass to correctly identify adjacent modules.
This also means that activation functions must be assigned in their module-form as a child to their parent-module
to properly detect when there is an activation function between linear and batch-norm modules.

'''
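As a usage illustration only (not part of the diff): with this PR's branch installed, the new canonizer could be exercised on a toy BatchNorm -> Conv2d chain roughly like the sketch below; the model here is made up, and the class name assumes the draft keeps MergeBatchNormtoRight.

import torch
from zennit.canonizers import MergeBatchNormtoRight  # available only on this PR's branch

model = torch.nn.Sequential(
    torch.nn.BatchNorm2d(3),
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
)
model.eval()

canonizer = MergeBatchNormtoRight()
instances = canonizer.apply(model)   # fuses the batch norm into the conv to its right

data = torch.randn(1, 3, 32, 32)
output = model(data)                 # the batch norm is now the identity

for instance in instances:
    instance.remove()                # restore the original parameters and hooks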

@staticmethod
def convhook(module, x, y):
Owner comment on lines +212 to +213: I suggested making this a wrapper instead (see the respective line where the hook is registered). Also, try to use descriptive names (i.e., avoid single-character names like x and y).

x = x[0]
bias_kernel = module.canonization_params["bias_kernel"]
pad1, pad2 = module.padding
# ASSUMING module.kernel_size IS ALWAYS STRICTLY GREATER THAN module.padding
Owner comment: let's raise Exceptions when assumptions fail.

if pad1 > 0:
left_margin = bias_kernel[:, :, 0:pad1, :]
right_margin = bias_kernel[:, :, pad1 + 1:, :]
middle = bias_kernel[:, :, pad1:pad1 + 1, :].expand(
1,
bias_kernel.shape[1],
x.shape[2] - module.weight.shape[2] + 1,
bias_kernel.shape[-1]
)
bias_kernel = torch.cat((left_margin, middle, right_margin), dim=2)

if pad2 > 0:
left_margin = bias_kernel[:, :, :, 0:pad2]
right_margin = bias_kernel[:, :, :, pad2 + 1:]
middle = bias_kernel[:, :, :, pad2:pad2 + 1].expand(
1,
bias_kernel.shape[1],
bias_kernel.shape[-2],
x.shape[3] - module.weight.shape[3] + 1
)
bias_kernel = torch.cat((left_margin, middle, right_margin), dim=3)
Owner comment on lines +218 to +238: These two are the same except for the dim. Let's write this only once without copying too much code.
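One way the two padding branches could be folded into a single helper, as a rough sketch (the function name is made up and torch is the module-level import of canonizers.py):

def _expand_bias_kernel(bias_kernel, dim, pad, size):
    '''Expand the interior of the padded bias kernel along `dim` to `size` positions.

    Sketch of the deduplication suggested above; `pad` is the padding on this axis.
    Mirrors the existing assumption that kernel_size > padding on this axis.
    '''
    if pad <= 0:
        return bias_kernel
    index = [slice(None)] * bias_kernel.ndim
    index[dim] = slice(0, pad)
    left_margin = bias_kernel[tuple(index)]            # border positions affected by padding
    index[dim] = slice(pad + 1, None)
    right_margin = bias_kernel[tuple(index)]
    index[dim] = slice(pad, pad + 1)
    shape = list(bias_kernel.shape)
    shape[dim] = size
    middle = bias_kernel[tuple(index)].expand(*shape)  # constant interior, broadcast to size
    return torch.cat((left_margin, middle, right_margin), dim=dim)

convhook would then call this twice, e.g. bias_kernel = _expand_bias_kernel(bias_kernel, 2, pad1, x.shape[2] - module.weight.shape[2] + 1) and the analogue for dim 3.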


if module.stride[0] > 1 or module.stride[1] > 1:
indices1 = [i for i in range(0, bias_kernel.shape[2]) if i % module.stride[0] == 0]
indices2 = [i for i in range(0, bias_kernel.shape[3]) if i % module.stride[1] == 0]
bias_kernel = bias_kernel[:, :, indices1, :]
bias_kernel = bias_kernel[:, :, :, indices2]
ynew = y + bias_kernel
return ynew

def __init__(self):
super().__init__()
self.handles = []
Owner comment: let's make handles a RemovableHandleList. We can add super(A, self) to the RemovableHandleList, then we only need to do self.handles.remove() in def remove.
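A sketch of what that could look like, assuming zennit.core.RemovableHandleList is a list subclass whose remove() detaches every stored handle (verify the import in the targeted zennit version):

from zennit.core import RemovableHandleList  # assumed available

class MergeBatchNormtoRight(MergeBatchNorm):
    def __init__(self):
        super().__init__()
        # hook handles collected during merge_batch_norm; one call removes them all
        self.handles = RemovableHandleList()

    def remove(self):
        '''Revert the linear and batch norm parameters and detach all forward hooks.'''
        super().remove()
        self.handles.remove()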


def apply(self, root_module):
Owner comment: missing docstring.

instances = []
last_leaf = None
for leaf in collect_leaves(root_module):
if isinstance(last_leaf, self.batch_norm_type) and isinstance(leaf, self.linear_type):
instance = self.copy()
instance.register((leaf,), last_leaf)
instances.append(instance)
last_leaf = leaf

return instances

def register(self, linears, batch_norm):
Owner comment: This code is almost identical to the super class. Can we maybe write this in a way that does not result in so much copied code? I'm okay with returning the result of self.merge_batch_norm inside MergeBatchNorm.register, which will be None anyways.
Owner comment: I'm okay with making the function signature mismatch its super class for the sake of not copying apply.

'''Store the parameters of the linear modules and the batch norm module and apply the merge.

Parameters
----------
linears: list of obj:`torch.nn.Module`
List of linear layers with mandatory attributes `weight` and `bias`.
batch_norm: obj:`torch.nn.Module`
Batch Normalization module with mandatory attributes
`running_mean`, `running_var`, `weight`, `bias` and `eps`
'''
self.linears = linears
self.batch_norm = batch_norm
self.batch_norm_eps = self.batch_norm.eps

self.linear_params = [(linear.weight.data, getattr(linear.bias, 'data', None)) for linear in linears]

self.batch_norm_params = {
key: getattr(self.batch_norm, key).data for key in ('weight', 'bias', 'running_mean', 'running_var')
}
returned_handles = self.merge_batch_norm(self.linears, self.batch_norm)
if returned_handles != []:
self.handles = self.handles + returned_handles
Owner comment on lines +284 to +286: With the other changes, this should become only self.handles = self.merge_batch_norm(self.linears, self.batch_norm).
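Taken together, the register-related suggestions above could look roughly like this sketch (a possible refactor, not the code in this PR): the base class returns whatever merge_batch_norm returns, and the subclass only keeps the returned handles.

# In MergeBatchNorm (base class):
def register(self, linears, batch_norm):
    '''Store the parameters of the linear and batch norm modules and apply the merge.'''
    self.linears = linears
    self.batch_norm = batch_norm
    self.batch_norm_eps = batch_norm.eps
    self.linear_params = [(linear.weight.data, getattr(linear.bias, 'data', None)) for linear in linears]
    self.batch_norm_params = {
        key: getattr(batch_norm, key).data for key in ('weight', 'bias', 'running_mean', 'running_var')
    }
    # None for the plain merge, a list of hook handles for MergeBatchNormtoRight
    return self.merge_batch_norm(self.linears, self.batch_norm)

# In MergeBatchNormtoRight: no copied body, only the handle bookkeeping.
def register(self, linears, batch_norm):
    '''Reuse the base class register and keep any forward-hook handles it returns.'''
    self.handles = super().register(linears, batch_norm) or []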


def remove(self):
'''Undo the merge by reverting the parameters of both the linear and the batch norm modules to the state before
the merge.
'''
super().remove()
for h in self.handles:
h.remove()
Owner comment on lines +292 to +294: when self.handles is a RemovableHandleList, we only need self.handles.remove().

for module in self.linears:
if isinstance(module, torch.nn.Conv2d):
if module.padding != (0, 0):
delattr(module, "canonization_params")
Owner comment on lines +296 to +298: when we store canonization_params (maybe give it a different name) inside the canonizer, we do not need to do anything (maybe we can clear it, but we do not do it for the linear and batchnorm params either).


def merge_batch_norm(self, modules, batch_norm):
Owner comment: missing docstring.

return_handles = []
denominator = (batch_norm.running_var + batch_norm.eps) ** .5

# Weight of the batch norm layer when seen as an affine transformation
scale = (batch_norm.weight / denominator)

# bias of the batch norm layer when seen as an affine transformation
shift = batch_norm.bias - batch_norm.running_mean * scale

for module in modules:
Owner comment: Again, I would prefer less code copying. You can change the parent class to make this work.

original_weight = module.weight.data
if module.bias is None:
module.bias = torch.nn.Parameter(
torch.zeros(module.out_channels, device=original_weight.device, dtype=original_weight.dtype)
)
original_bias = module.bias.data

if isinstance(module, ConvolutionTranspose):
index = (slice(None), *((None,) * (original_weight.ndim - 1)))
else:
index = (None, slice(None), *((None,) * (original_weight.ndim - 2)))

# merge batch_norm into linear layer to the right
module.weight.data = (original_weight * scale[index])
Owner suggested change:
- module.weight.data = (original_weight * scale[index])
+ object.__setattr__(module, 'weight', original_weight * scale[index])


if isinstance(module, torch.nn.Conv2d):
if module.padding == (0, 0):
module.bias.data = (original_weight * shift[index]).sum(dim=[1, 2, 3]) + original_bias
Owner comment: this needs to be adapted to object.__setattr__(module, 'bias', (original_weight * shift[index]).sum(dim=[1, 2, 3]) + original_bias).

else:
# We calculate a bias kernel, which is the output of the conv layer, without the bias, and with maximum padding,
# applied to feature maps of the same size as the convolution kernel, with values given by the batch norm biases.
# This produces a mostly constant feature map, which is not constant near the edges due to padding.
# We then attach a forward hook to the conv layer to compute from this bias_kernel the feature map to be added
# after the convolution due to the batch norm bias, depending on the given input's shape
bias_kernel = shift[index].expand(*(shift[index].shape[0:-2] + original_weight.shape[-2:]))
temp_module = torch.nn.Conv2d(in_channels=module.in_channels, out_channels=module.out_channels,
kernel_size=module.kernel_size, padding=module.padding,
padding_mode=module.padding_mode, bias=False)
Owner comment on lines +336 to +338: let's indent one line per kwarg.

temp_module.weight.data = original_weight
bias_kernel = temp_module(bias_kernel).detach()

module.canonization_params = {}
module.canonization_params["bias_kernel"] = bias_kernel
Owner comment on lines +342 to +343: let's store these in the canonizer itself, similar to MergeBatchNorm.linear_params and .batch_norm_params.

return_handles.append(module.register_forward_hook(MergeBatchNormtoRight.convhook))
Owner comment: For the sake of not using Hooks, maybe we can wrap and overwrite the forward function (similar to the ResNet Canonizer)?

elif isinstance(module, torch.nn.Linear):
module.bias.data = (original_weight * shift).sum(dim=1) + original_bias

# change batch_norm parameters to produce identity
batch_norm.running_mean.data = torch.zeros_like(batch_norm.running_mean.data)
batch_norm.running_var.data = torch.ones_like(batch_norm.running_var.data)
batch_norm.bias.data = torch.zeros_like(batch_norm.bias.data)
batch_norm.weight.data = torch.ones_like(batch_norm.weight.data)
batch_norm.eps = 0.
Owner comment on lines +346 to +353: these need to be adapted to the new approach (see current version of MergeBatchNorm).

return return_handles
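Following up on the reviewer's suggestion to wrap and overwrite forward instead of registering a hook, a rough sketch of that alternative, assuming it lives in this module next to MergeBatchNormtoRight (convhook is reused here only for brevity; the helper name is made up):

def overwrite_conv_forward(module):
    '''Replace module.forward with a wrapper that adds the padding-aware bias map.'''
    original_forward = module.forward

    def wrapped_forward(data):
        output = original_forward(data)
        # same arithmetic as the forward hook, just without registering a hook
        return MergeBatchNormtoRight.convhook(module, (data,), output)

    module.forward = wrapped_forward
    # returning the original lets remove() restore it with `module.forward = original_forward`
    return original_forward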


class ThreshReLUMergeBatchNorm(MergeBatchNormtoRight):
'''Canonizer to canonize BatchNorm -> ReLU -> Linear chains, modifying the ReLU as explained in
https://github.com/AlexBinder/LRP_Pytorch_Resnets_Densenet/blob/master/canonization_doc.pdf
'''
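The rewrite this class performs can be checked numerically with a small standalone sketch (the tensors and the toy Linear below are made up for illustration): writing the batch norm as the per-channel affine map scale * x + shift, the merged linear layer must receive x unchanged where the batch norm output is positive, and the baseline -shift / scale elsewhere, so that the overall BatchNorm -> ReLU -> Linear function is preserved.

import torch

torch.manual_seed(0)
scale = torch.rand(4) + 0.5          # batch norm weight / sqrt(running_var + eps)
shift = torch.randn(4)               # batch norm bias - running_mean * scale
x = torch.randn(8, 4)                # activations entering the BatchNorm -> ReLU -> Linear chain
linear = torch.nn.Linear(4, 2)

# original chain: Linear(ReLU(BN(x))), with BN written as the affine map scale * x + shift
original = linear(torch.relu(scale * x + shift))

# canonized chain: the batch norm is merged into the Linear to its right ...
merged = torch.nn.Linear(4, 2)
merged.weight.data = linear.weight.data * scale
merged.bias.data = linear.bias.data + linear.weight.data @ shift
# ... and the modified ReLU forwards x where BN(x) > 0, and the baseline -shift/scale elsewhere
bn_out = scale * x + shift
thresh_relu = torch.where(bn_out > 0, x, -shift / scale)

assert torch.allclose(original, merged(thresh_relu), atol=1e-5)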

@staticmethod
def prehook(module, x):
module.canonization_params["original_x"] = x[0].clone()

@staticmethod
def fwdhook(module, x, y):
x = module.canonization_params["original_x"]
index = (None, slice(None), *((None,) * (module.canonization_params['weights'].ndim + 1)))
y = module.canonization_params['weights'][index] * x + module.canonization_params['biases'][index]
baseline_vals = -1. * (module.canonization_params['biases'] / module.canonization_params['weights'])[index]
return torch.where(y > 0, x, baseline_vals)
Owner comment on lines +363 to +372: missing docstrings and signatures are non-descriptive.


def __init__(self):
super().__init__()
self.relu = None

def apply(self, root_module):
Owner comment: missing docstring.

instances = []
oldest_leaf = None
old_leaf = None
mid_leaf = None
for leaf in collect_leaves(root_module):
if (
isinstance(old_leaf, self.batch_norm_type)
and isinstance(mid_leaf, ReLU)
and isinstance(leaf, self.linear_type)
Owner comment on lines +385 to +387: this looks like one indentation level too deep.

):
instance = self.copy()
instance.register((leaf,), old_leaf, mid_leaf)
instances.append(instance)
elif (
isinstance(oldest_leaf, self.batch_norm_type)
and isinstance(old_leaf, ReLU)
and isinstance(mid_leaf, AdaptiveAvgPool2d)
and isinstance(leaf, self.linear_type)
Owner comment on lines +393 to +396: this looks like one indentation level too deep.

):
instance = self.copy()
instance.register((leaf,), oldest_leaf, old_leaf)
instances.append(instance)
oldest_leaf = old_leaf
old_leaf = mid_leaf
mid_leaf = leaf

return instances

def register(self, linears, batch_norm, relu):
'''Store the parameters of the linear modules and the batch norm module and apply the merge.

Parameters
----------
linears: list of obj:`torch.nn.Module`
List of linear layers with mandatory attributes `weight` and `bias`.
batch_norm: obj:`torch.nn.Module`
Batch Normalization module with mandatory attributes
`running_mean`, `running_var`, `weight`, `bias` and `eps`
relu: obj:`torch.nn.Module`
The activation unit between the Batch Normalization and Linear modules.
'''
self.relu = relu

denominator = (batch_norm.running_var + batch_norm.eps) ** .5
scale = (batch_norm.weight / denominator) # Weight of the batch norm layer when seen as a linear layer
shift = batch_norm.bias - batch_norm.running_mean * scale # bias of the batch norm layer when seen as a linear layer
self.relu.canonization_params = {}
self.relu.canonization_params['weights'] = scale
self.relu.canonization_params['biases'] = shift
Owner comment on lines +425 to +427: let's store these inside the canonizer.


super().register(linears, batch_norm)
self.handles.append(self.relu.register_forward_pre_hook(ThreshReLUMergeBatchNorm.prehook))
self.handles.append(self.relu.register_forward_hook(ThreshReLUMergeBatchNorm.fwdhook))
Owner comment on lines +430 to +431: maybe we can do this with a wrapped .forward similar to how I suggested for the MergeBatchNormtoRight?


def remove(self):
'''Undo the merge by reverting the parameters of both the linear and the batch norm modules to the state before
the merge.
'''
super().remove()
delattr(self.relu, "canonization_params")
Owner comment on lines +433 to +438: If we store the params inside the canonizer, we do not need to delete them, thus we can simply remove this to use the parent class' remove.



class NamedMergeBatchNorm(MergeBatchNorm):
'''Canonizer to merge the parameters of all batch norms into linear modules, specified by their respective names.
@@ -197,6 +446,7 @@ class NamedMergeBatchNorm(MergeBatchNorm):
name_map: list[tuple[string], string]
List of which linear layer names belong to which batch norm name.
'''

def __init__(self, name_map):
super().__init__()
self.name_map = name_map
@@ -239,6 +489,7 @@ class AttributeCanonizer(Canonizer):
overload for a module. The function signature is (name: string, module: type) -> None or
dict.
'''

def __init__(self, attribute_map):
self.attribute_map = attribute_map
self.attribute_keys = None
@@ -308,6 +559,7 @@ class CompositeCanonizer(Canonizer):
canonizers : list of obj:`Canonizer`
Canonizers of which to build a Composite of.
'''

def __init__(self, canonizers):
self.canonizers = canonizers

@@ -327,6 +579,7 @@ def apply(self, root_module):
instances = []
for canonizer in self.canonizers:
instances += canonizer.apply(root_module)
instances.reverse()
return instances

def register(self):