Densenet canonizations #171
base: master
@@ -28,6 +28,7 @@ class Canonizer(metaclass=ABCMeta):
    '''Canonizer Base class.
    Canonizers modify modules temporarily such that certain attribution rules can properly be applied.
    '''

    @abstractmethod
    def apply(self, root_module):
        '''Apply this canonizer recursively on all applicable modules.

@@ -70,7 +71,7 @@ def __init__(self):
        super().__init__()
        self.linears = None
        self.batch_norm = None

        self.batch_norm_eps = None
        self.linear_params = None
        self.batch_norm_params = None

@@ -94,6 +95,7 @@ def register(self, linears, batch_norm):
            key: getattr(self.batch_norm, key).data for key in ('weight', 'bias', 'running_mean', 'running_var')
        }

        self.batch_norm_eps = batch_norm.eps
        self.merge_batch_norm(self.linears, self.batch_norm)

    def remove(self):

@@ -110,6 +112,8 @@ def remove(self):
        for key, value in self.batch_norm_params.items():
            getattr(self.batch_norm, key).data = value

        self.batch_norm.eps = self.batch_norm_eps

    @staticmethod
    def merge_batch_norm(modules, batch_norm):
        '''Update parameters of a linear layer to additionally include a Batch Normalization operation and update the

@@ -148,6 +152,7 @@ def merge_batch_norm(modules, batch_norm):
        batch_norm.running_var.data = torch.ones_like(batch_norm.running_var.data)
        batch_norm.bias.data = torch.zeros_like(batch_norm.bias.data)
        batch_norm.weight.data = torch.ones_like(batch_norm.weight.data)
        batch_norm.eps = 0.


class SequentialMergeBatchNorm(MergeBatchNorm):

@@ -162,6 +167,7 @@ class SequentialMergeBatchNorm(MergeBatchNorm):
    to properly detect when there is an activation function between linear and batch-norm modules.

    '''

    def apply(self, root_module):
        '''Finds a batch norm following right after a linear layer, and creates a copy of this instance to merge
        them by fusing the batch norm parameters into the linear layer and reducing the batch norm to the identity.

@@ -181,13 +187,256 @@ def apply(self, root_module):
        last_leaf = None
        for leaf in collect_leaves(root_module):
            if isinstance(last_leaf, self.linear_type) and isinstance(leaf, self.batch_norm_type):
                if last_leaf.weight.shape[0] == leaf.weight.shape[0]:
                    instance = self.copy()
                    instance.register((last_leaf,), leaf)
                    instances.append(instance)
            last_leaf = leaf

        return instances


class MergeBatchNormtoRight(MergeBatchNorm):
    '''Canonizer to merge the parameters of all batch norms that appear sequentially right before a linear module.

    Note
    ----
    MergeBatchNormtoRight traverses the tree of children of the provided module depth-first and in-order.
    This means that child-modules must be assigned to their parent module in the order they are visited in the forward
    pass to correctly identify adjacent modules.
    This also means that activation functions must be assigned in their module-form as a child to their parent-module
    to properly detect when there is an activation function between linear and batch-norm modules.

    '''

    @staticmethod
    def convhook(module, x, y):
Comment on lines +212 to +213:
I suggested to make this a wrapper instead (see respective line where the hook is registered). Also, try to have descriptive names (i.e., avoid single character names like
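A minimal sketch of that suggestion, assuming the hook is built as a closure over the bias kernel; all names here (including the helper `expand_bias_kernel`) are illustrative and not part of the PR:

```python
# Sketch only: a hook factory with descriptive argument names instead of a static method.
def make_bias_hook(bias_kernel):
    '''Return a forward hook that adds the padding-dependent bias map to the convolution output.'''
    def bias_hook(module, inputs, output):
        conv_input = inputs[0]
        # expand_bias_kernel is a hypothetical helper doing what convhook currently does inline
        return output + expand_bias_kernel(module, bias_kernel, conv_input)
    return bias_hook
```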
        x = x[0]
        bias_kernel = module.canonization_params["bias_kernel"]
        pad1, pad2 = module.padding
        # ASSUMING module.kernel_size IS ALWAYS STRICTLY GREATER THAN module.padding
let's raise Exceptions when assumptions fail
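A possible guard for this, assuming the assumption in question is that each kernel dimension is strictly larger than the corresponding padding (a sketch; exception type, message and function name are suggestions only):

```python
def check_padding_assumption(module):
    '''Raise instead of silently relying on kernel_size > padding (illustrative only).'''
    if not all(kernel > pad for kernel, pad in zip(module.kernel_size, module.padding)):
        raise RuntimeError(
            f'MergeBatchNormtoRight expects kernel_size {module.kernel_size} to be strictly '
            f'greater than padding {module.padding} in every dimension'
        )
```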
        if pad1 > 0:
            left_margin = bias_kernel[:, :, 0:pad1, :]
            right_margin = bias_kernel[:, :, pad1 + 1:, :]
            middle = bias_kernel[:, :, pad1:pad1 + 1, :].expand(
                1,
                bias_kernel.shape[1],
                x.shape[2] - module.weight.shape[2] + 1,
                bias_kernel.shape[-1]
            )
            bias_kernel = torch.cat((left_margin, middle, right_margin), dim=2)

        if pad2 > 0:
            left_margin = bias_kernel[:, :, :, 0:pad2]
            right_margin = bias_kernel[:, :, :, pad2 + 1:]
            middle = bias_kernel[:, :, :, pad2:pad2 + 1].expand(
                1,
                bias_kernel.shape[1],
                bias_kernel.shape[-2],
                x.shape[3] - module.weight.shape[3] + 1
            )
            bias_kernel = torch.cat((left_margin, middle, right_margin), dim=3)
Comment on lines +218 to +238:
These two are the same except for the dim. Let's write this only once without copying too much code.

        if module.stride[0] > 1 or module.stride[1] > 1:
            indices1 = [i for i in range(0, bias_kernel.shape[2]) if i % module.stride[0] == 0]
            indices2 = [i for i in range(0, bias_kernel.shape[3]) if i % module.stride[1] == 0]
            bias_kernel = bias_kernel[:, :, indices1, :]
            bias_kernel = bias_kernel[:, :, :, indices2]
        ynew = y + bias_kernel
        return ynew

    def __init__(self):
        super().__init__()
        self.handles = []
let's make handles a

    def apply(self, root_module):
missing docstring
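A docstring along the lines of the SequentialMergeBatchNorm counterpart could fit here; the wording below is only a suggestion, not taken from the PR:

```python
        '''Finds a linear layer that follows right after a batch norm, and creates a copy of this instance
        to merge them by fusing the batch norm parameters into the linear layer to its right and reducing
        the batch norm to the identity.

        Parameters
        ----------
        root_module: obj:`torch.nn.Module`
            A module of which the leaves will be searched and if applicable, canonized.

        Returns
        -------
        instances: list of obj:`MergeBatchNormtoRight`
            The canonizer instances which were applied and are needed to revert the canonization later.
        '''
```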
        instances = []
        last_leaf = None
        for leaf in collect_leaves(root_module):
            if isinstance(last_leaf, self.batch_norm_type) and isinstance(leaf, self.linear_type):
                instance = self.copy()
                instance.register((leaf,), last_leaf)
                instances.append(instance)
            last_leaf = leaf

        return instances

    def register(self, linears, batch_norm):
This code is almost identical to the super class. Can we maybe write this in a way that does not result in so much copied code? I'm okay with returning the result of

I'm okay with making the function signature mismatch its super class for the sake of not copying
        '''Store the parameters of the linear modules and the batch norm module and apply the merge.

        Parameters
        ----------
        linears: list of obj:`torch.nn.Module`
            List of linear layers with mandatory attributes `weight` and `bias`.
        batch_norm: obj:`torch.nn.Module`
            Batch Normalization module with mandatory attributes
            `running_mean`, `running_var`, `weight`, `bias` and `eps`
        '''
        self.linears = linears
        self.batch_norm = batch_norm
        self.batch_norm_eps = self.batch_norm.eps

        self.linear_params = [(linear.weight.data, getattr(linear.bias, 'data', None)) for linear in linears]

        self.batch_norm_params = {
            key: getattr(self.batch_norm, key).data for key in ('weight', 'bias', 'running_mean', 'running_var')
        }
        returned_handles = self.merge_batch_norm(self.linears, self.batch_norm)
        if returned_handles != []:
            self.handles = self.handles + returned_handles
Comment on lines +284 to +286:
With the other changes, this should become only
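One possible reading of these two comments, sketched under the assumption that MergeBatchNorm.register is changed to return the result of merge_batch_norm; this is not the PR's actual implementation:

```python
# Sketch: if MergeBatchNorm.register ended with
#     return self.merge_batch_norm(self.linears, self.batch_norm)
# this subclass would only need to keep the returned hook handles around.
class MergeBatchNormtoRight(MergeBatchNorm):
    def register(self, linears, batch_norm):
        '''Store the parameters, apply the merge and keep any forward-hook handles (sketch).'''
        returned_handles = super().register(linears, batch_norm)
        self.handles += returned_handles or []
```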

    def remove(self):
        '''Undo the merge by reverting the parameters of both the linear and the batch norm modules to the state before
        the merge.
        '''
        super().remove()
        for h in self.handles:
            h.remove()
Comment on lines +292 to +294:
when

        for module in self.linears:
            if isinstance(module, torch.nn.Conv2d):
                if module.padding != (0, 0):
                    delattr(module, "canonization_params")
Comment on lines +296 to +298:
when we store

    def merge_batch_norm(self, modules, batch_norm):
Missing docstring
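A possible docstring, phrased after the base class; the wording and the stated return type are suggestions only:

```python
        '''Update the parameters of the linear layers to additionally include the Batch Normalization
        operation that precedes them, reduce the batch norm to the identity and, for padded convolutions,
        register a forward hook that adds the padding-dependent bias map.

        Parameters
        ----------
        modules: list of obj:`torch.nn.Module`
            Linear layers with mandatory attributes `weight` and `bias`.
        batch_norm: obj:`torch.nn.Module`
            Batch Normalization module with mandatory attributes
            `running_mean`, `running_var`, `weight`, `bias` and `eps`.

        Returns
        -------
        list of obj:`torch.utils.hooks.RemovableHandle`
            Handles of the forward hooks registered on padded convolution layers.
        '''
```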
        return_handles = []
        denominator = (batch_norm.running_var + batch_norm.eps) ** .5

        # Weight of the batch norm layer when seen as an affine transformation
        scale = (batch_norm.weight / denominator)

        # bias of the batch norm layer when seen as an affine transformation
        shift = batch_norm.bias - batch_norm.running_mean * scale

        for module in modules:
Again, I would prefer less code copying. You can change the parent class to make this work
            original_weight = module.weight.data
            if module.bias is None:
                module.bias = torch.nn.Parameter(
                    torch.zeros(module.out_channels, device=original_weight.device, dtype=original_weight.dtype)
                )
            original_bias = module.bias.data

            if isinstance(module, ConvolutionTranspose):
                index = (slice(None), *((None,) * (original_weight.ndim - 1)))
            else:
                index = (None, slice(None), *((None,) * (original_weight.ndim - 2)))

            # merge batch_norm into linear layer to the right
            module.weight.data = (original_weight * scale[index])

            if isinstance(module, torch.nn.Conv2d):
                if module.padding == (0, 0):
                    module.bias.data = (original_weight * shift[index]).sum(dim=[1, 2, 3]) + original_bias
this needs to be adapted to

                else:
                    # We calculate a bias kernel, which is the output of the conv layer, without the bias, and with maximum padding,
                    # applied to feature maps of the same size as the convolution kernel, with values given by the batch norm biases.
                    # This produces a mostly constant feature map, which is not constant near the edges due to padding.
                    # We then attach a forward hook to the conv layer to compute from this bias_kernel the feature map to be added
                    # after the convolution due to the batch norm bias, depending on the given input's shape
                    bias_kernel = shift[index].expand(*(shift[index].shape[0:-2] + original_weight.shape[-2:]))
                    temp_module = torch.nn.Conv2d(in_channels=module.in_channels, out_channels=module.out_channels,
                                                  kernel_size=module.kernel_size, padding=module.padding,
                                                  padding_mode=module.padding_mode, bias=False)
Comment on lines +336 to +338:
let's indent one line per kwarg

                    temp_module.weight.data = original_weight
                    bias_kernel = temp_module(bias_kernel).detach()

                    module.canonization_params = {}
                    module.canonization_params["bias_kernel"] = bias_kernel
Comment on lines +342 to +343:
let's store these in the canonizer itself, similar to
                    return_handles.append(module.register_forward_hook(MergeBatchNormtoRight.convhook))
For the sake of not using Hooks, maybe we can wrap and overwrite the forward function (similar to the ResNet Canonizer)?
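A rough sketch of that alternative, assuming the original forward is stored so that remove() can restore it; the helper names (`wrap_conv_forward`, `expand_bias_kernel`) are hypothetical:

```python
import types

def wrap_conv_forward(module, bias_kernel):
    '''Overwrite module.forward to add the bias map, instead of registering a hook (sketch).'''
    original_forward = module.forward

    def forward_with_bias(self, conv_input):
        output = original_forward(conv_input)
        # expand_bias_kernel is the same hypothetical helper as in the hook-based sketch above
        return output + expand_bias_kernel(self, bias_kernel, conv_input)

    module.forward = types.MethodType(forward_with_bias, module)
    return original_forward  # kept by the canonizer so remove() can restore it
```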

            elif isinstance(module, torch.nn.Linear):
                module.bias.data = (original_weight * shift).sum(dim=1) + original_bias

        # change batch_norm parameters to produce identity
        batch_norm.running_mean.data = torch.zeros_like(batch_norm.running_mean.data)
        batch_norm.running_var.data = torch.ones_like(batch_norm.running_var.data)
        batch_norm.bias.data = torch.zeros_like(batch_norm.bias.data)
        batch_norm.weight.data = torch.ones_like(batch_norm.weight.data)
        batch_norm.eps = 0.
Comment on lines +346 to +353:
these need to be adapted to the new approach (see current version of

        return return_handles


class ThreshReLUMergeBatchNorm(MergeBatchNormtoRight):
    '''Canonizer to canonize BatchNorm -> ReLU -> Linear chains, modifying the ReLU as explained in
    https://github.com/AlexBinder/LRP_Pytorch_Resnets_Densenet/blob/master/canonization_doc.pdf
    '''

    @staticmethod
    def prehook(module, x):
        module.canonization_params["original_x"] = x[0].clone()

    @staticmethod
    def fwdhook(module, x, y):
        x = module.canonization_params["original_x"]
        index = (None, slice(None), *((None,) * (module.canonization_params['weights'].ndim + 1)))
        y = module.canonization_params['weights'][index] * x + module.canonization_params['biases'][index]
        baseline_vals = -1. * (module.canonization_params['biases'] / module.canonization_params['weights'])[index]
        return torch.where(y > 0, x, baseline_vals)
Comment on lines +363 to +372:
missing docstrings and signatures are non-descriptive

    def __init__(self):
        super().__init__()
        self.relu = None

    def apply(self, root_module):
missing docstring
        instances = []
        oldest_leaf = None
        old_leaf = None
        mid_leaf = None
        for leaf in collect_leaves(root_module):
            if (
                    isinstance(old_leaf, self.batch_norm_type)
                    and isinstance(mid_leaf, ReLU)
                    and isinstance(leaf, self.linear_type)
Comment on lines +385 to +387:
this looks like one indentation level too deep
            ):
                instance = self.copy()
                instance.register((leaf,), old_leaf, mid_leaf)
                instances.append(instance)
            elif (
                    isinstance(oldest_leaf, self.batch_norm_type)
                    and isinstance(old_leaf, ReLU)
                    and isinstance(mid_leaf, AdaptiveAvgPool2d)
                    and isinstance(leaf, self.linear_type)
Comment on lines +393 to +396:
this looks like one indentation level too deep
            ):
                instance = self.copy()
                instance.register((leaf,), oldest_leaf, old_leaf)
                instances.append(instance)
            oldest_leaf = old_leaf
            old_leaf = mid_leaf
            mid_leaf = leaf

        return instances

    def register(self, linears, batch_norm, relu):
        '''Store the parameters of the linear modules and the batch norm module and apply the merge.

        Parameters
        ----------
        linears: list of obj:`torch.nn.Module`
            List of linear layers with mandatory attributes `weight` and `bias`.
        batch_norm: obj:`torch.nn.Module`
            Batch Normalization module with mandatory attributes
            `running_mean`, `running_var`, `weight`, `bias` and `eps`
        relu: obj:`torch.nn.Module`
            The activation unit between the Batch Normalization and Linear modules.
        '''
        self.relu = relu

        denominator = (batch_norm.running_var + batch_norm.eps) ** .5
        scale = (batch_norm.weight / denominator)  # Weight of the batch norm layer when seen as a linear layer
        shift = batch_norm.bias - batch_norm.running_mean * scale  # bias of the batch norm layer when seen as a linear layer
        self.relu.canonization_params = {}
        self.relu.canonization_params['weights'] = scale
        self.relu.canonization_params['biases'] = shift
Comment on lines +425 to +427:
let's store these inside the canonizer

        super().register(linears, batch_norm)
        self.handles.append(self.relu.register_forward_pre_hook(ThreshReLUMergeBatchNorm.prehook))
        self.handles.append(self.relu.register_forward_hook(ThreshReLUMergeBatchNorm.fwdhook))
Comment on lines +430 to +431:
maybe we can do this with a wrapped .forward similar to how I suggested for the

    def remove(self):
        '''Undo the merge by reverting the parameters of both the linear and the batch norm modules to the state before
        the merge.
        '''
        super().remove()
        delattr(self.relu, "canonization_params")
Comment on lines +433 to +438:
If we store the params inside the canonizer, we do not need to delete them, thus we can simply remove this to use the parent class' remove
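Sketch of how the parameters could live on the canonizer instead of on the ReLU module, assuming the hooks are turned into closures as suggested above; the attribute names are illustrative, not from the PR:

```python
class ThreshReLUMergeBatchNorm(MergeBatchNormtoRight):
    def register(self, linears, batch_norm, relu):
        '''Keep the affine batch-norm parameters on the canonizer itself (sketch).'''
        self.relu = relu
        denominator = (batch_norm.running_var + batch_norm.eps) ** .5
        self.relu_weights = batch_norm.weight / denominator
        self.relu_biases = batch_norm.bias - batch_norm.running_mean * self.relu_weights
        # hooks built as closures can read self.relu_weights / self.relu_biases directly,
        # so remove() no longer needs the delattr call and can fall back to the parent class
        super().register(linears, batch_norm)
```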


class NamedMergeBatchNorm(MergeBatchNorm):
    '''Canonizer to merge the parameters of all batch norms into linear modules, specified by their respective names.

@@ -197,6 +446,7 @@ class NamedMergeBatchNorm(MergeBatchNorm):
    name_map: list[tuple[string], string]
        List of which linear layer names belong to which batch norm name.
    '''

    def __init__(self, name_map):
        super().__init__()
        self.name_map = name_map

@@ -239,6 +489,7 @@ class AttributeCanonizer(Canonizer):
        overload for a module. The function signature is (name: string, module: type) -> None or
        dict.
    '''

    def __init__(self, attribute_map):
        self.attribute_map = attribute_map
        self.attribute_keys = None

@@ -308,6 +559,7 @@ class CompositeCanonizer(Canonizer):
    canonizers : list of obj:`Canonizer`
        Canonizers of which to build a Composite of.
    '''

    def __init__(self, canonizers):
        self.canonizers = canonizers

@@ -327,6 +579,7 @@ def apply(self, root_module):
        instances = []
        for canonizer in self.canonizers:
            instances += canonizer.apply(root_module)
        instances.reverse()
        return instances

    def register(self):

The convhook strictly expects a 2d Convolution. We can either make convhook more general, or call this MergeBatchNormConv2dRight
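If the hook stays 2d-only for now, an explicit check at the start of convhook (sketch, not from the PR) would make the restriction visible until it is generalized or renamed:

```python
# Illustrative only: fail loudly for unsupported module types instead of assuming 4d inputs.
if not isinstance(module, torch.nn.Conv2d):
    raise NotImplementedError('convhook currently supports only torch.nn.Conv2d modules')
```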