Commit 5942658

Fix tests and imports
1 parent a053780 commit 5942658

File tree

src/brevitas/graph/equalize.py
tests/brevitas/graph/test_equalization.py

2 files changed: 96 additions, 86 deletions

src/brevitas/graph/equalize.py

Lines changed: 51 additions & 43 deletions
@@ -10,13 +10,17 @@
 from typing import Callable, Dict, List, Optional, Set, Tuple, Union
 import warnings

+import packaging
+import packaging.version
 import torch
 from torch.fx import GraphModule as TorchGraphModule
 import torch.nn as nn

+from brevitas import torch_version
 from brevitas.fx import GraphModule
 from brevitas.fx import Node
 from brevitas.graph.base import GraphTransform
+from brevitas.graph.base import InsertModuleCallAfter
 from brevitas.graph.base import ModuleInstanceToModuleInstance
 from brevitas.graph.hadamard import get_hadK
 from brevitas.graph.hadamard import matmul_hadU
@@ -29,13 +33,18 @@
 from brevitas.nn.quant_scale_bias import ScaleBias
 from brevitas.utils.torch_utils import KwargsForwardHook

-from .base import InsertModuleCallAfter
-
+# External optional dependency
 try:
     import fast_hadamard_transform
 except:
     fast_hadamard_transform = None

+# RMSNorm was introduced with torch 2.4
+if torch_version >= packaging.version.parse('2.4'):
+    RMSNorm = nn.RMSNorm
+else:
+    RMSNorm = object
+
 __all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph']

 EPSILON = 1e-9
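
Note on the guard above: a minimal, standalone sketch of this torch-version fallback pattern is shown below. The names torch_version and RMSNorm mirror the diff; everything else (including the sentinel-class variant) is illustrative and not part of this commit.

# Standalone sketch of the version-guard pattern added in the hunk above.
import packaging.version
import torch
import torch.nn as nn

torch_version = packaging.version.parse(torch.__version__)

if torch_version >= packaging.version.parse('2.4'):
    RMSNorm = nn.RMSNorm
else:
    # Placeholder so the name RMSNorm always exists on older torch. The commit
    # itself uses `RMSNorm = object`, for which isinstance(x, RMSNorm) is true
    # for every object; a dedicated sentinel class never matches real modules.
    class RMSNorm(nn.Module):
        pass

# The rest of the module can then reference RMSNorm unconditionally, e.g.:
print(isinstance(nn.LayerNorm(8), (nn.LayerNorm, RMSNorm)))  # True
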
@@ -77,14 +86,13 @@
     operator.imul,
     operator.__mul__,
     operator.__imul__,
-    torch.nn.functional.interpolate)
+    nn.functional.interpolate)

 _select_op = (operator.getitem, operator.__getitem__)

 _reshaping_op = ('view', 'reshape', 'flatten', 'contiguous', 'to', torch.reshape, torch.flatten)

-_scale_varying_activations = (
-    torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.ReLU6, torch.nn.GELU, torch.nn.SiLU)
+_scale_varying_activations = (nn.Sigmoid, nn.Tanh, nn.ReLU6, nn.GELU, nn.SiLU)

 _residual_methods = ('add', 'add_')

@@ -95,31 +103,6 @@
 _ignore_ops = (getattr, 'size')


-def _is_supported_module(
-        graph_model: GraphModule, node: Node, supported_layers: Set = _supported_layers) -> bool:
-    if node.op == 'call_module':
-        module = get_module(graph_model, node.target)
-        if isinstance(module, supported_layers):
-            # We support only self-attention
-            if isinstance(module, nn.MultiheadAttention):
-                kwargs = dict(node.kwargs)
-                # When using hf/accelerate, we need to check the signature of the original forward
-                forward_to_check = module._old_forward if hasattr(
-                    module, '_old_forward') else module.forward
-                kwargs.update(zip(forward_to_check.__code__.co_varnames[1:], node.args))
-                return kwargs['query'].name == kwargs['key'].name == kwargs['value'].name
-            return True
-    return False
-
-
-def _is_scale_invariant_module(
-        graph_model: GraphModule,
-        node: Node,
-        scale_invariant_layers=_scale_invariant_layers) -> bool:
-    return node.op == 'call_module' and isinstance(
-        get_module(graph_model, node.target), scale_invariant_layers)
-
-
 # Start and End identify the starting and ending channels of the weight matrix that need to be
 # equalized.
 # Offset refers to the relative position of these channels with respect to
@@ -334,7 +317,7 @@ def _get_input_axis(module: nn.Module) -> Optional[int]:
             return 0
         elif module.groups == module.out_channels:
             return 1
-    elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
+    elif isinstance(module, (nn.LayerNorm, RMSNorm)):
         # We assume normalization happens only along the channel dimension
         if len(module.weight.shape) == 1:
             return 0
@@ -362,7 +345,7 @@ def _get_output_axis(module: nn.Module) -> Optional[int]:
     elif isinstance(module,
                     (nn.Embedding, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
         return 1
-    elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
+    elif isinstance(module, (nn.LayerNorm, RMSNorm)):
         # We assume normalization happens only along the channel dimension
         if len(module.weight.shape) == 1:
             return 0
@@ -687,6 +670,31 @@ def _equalize(
     return model


+def _is_supported_module(
+        graph_model: GraphModule, node: Node, supported_layers: Set = _supported_layers) -> bool:
+    if node.op == 'call_module':
+        module = get_module(graph_model, node.target)
+        if isinstance(module, supported_layers):
+            # We support only self-attention
+            if isinstance(module, nn.MultiheadAttention):
+                kwargs = dict(node.kwargs)
+                # When using hf/accelerate, we need to check the signature of the original forward
+                forward_to_check = module._old_forward if hasattr(
+                    module, '_old_forward') else module.forward
+                kwargs.update(zip(forward_to_check.__code__.co_varnames[1:], node.args))
+                return kwargs['query'].name == kwargs['key'].name == kwargs['value'].name
+            return True
+    return False
+
+
+def _is_scale_invariant_module(
+        graph_model: GraphModule,
+        node: Node,
+        scale_invariant_layers=_scale_invariant_layers) -> bool:
+    return node.op == 'call_module' and isinstance(
+        get_module(graph_model, node.target), scale_invariant_layers)
+
+
 def _is_scale_varying_activation(graph_model, node):
     return node.op == 'call_module' and isinstance(
         get_module(graph_model, node.target), _scale_varying_activations)
@@ -696,7 +704,7 @@ def _is_scale_invariant_function(node: Node, scale_invariant_op: Set = _scale_in
     out = node.op in (
         'call_function',
         'call_method') and node.target in scale_invariant_op + _select_op + _reshaping_op
-    if node.target == torch.nn.functional.interpolate:
+    if node.target == nn.functional.interpolate:
         out &= node.kwargs.get('mode', None) == 'nearest'
     return out

@@ -959,7 +967,7 @@ def apply(self,
               graph_model: GraphModule) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]:
         # It is not possible to equalize through LayerNorm/BatchNorm as sink
         supported_sinks = tuple([
-            x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
+            x for x in _supported_layers if x not in (nn.LayerNorm, *_batch_norm)])
         regions = _extract_regions(
             graph_model, state_impl_kwargs={'supported_sinks': supported_sinks})
         if len(regions) > 0:
@@ -1135,7 +1143,7 @@ def __init__(

         # It is not possible to equalize through LayerNorm/BatchNorm as sink
         supported_sinks = tuple([
-            x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
+            x for x in _supported_layers if x not in (nn.LayerNorm, *_batch_norm)])
         self.regions = _extract_regions(
             model,
             add_mul_node=add_mul_node,
@@ -1305,9 +1313,9 @@ class GraphRotationEqualization(GraphTransform):
     def __init__(self) -> None:
         super(GraphRotationEqualization, self).__init__()

-        self.supported_srcs = (torch.nn.Linear, torch.nn.Embedding)
-        self.supported_sinks = (torch.nn.Linear)
-        self.scale_invariant_layers = (torch.nn.RMSNorm,)
+        self.supported_srcs = (nn.Linear, nn.Embedding)
+        self.supported_sinks = (nn.Linear)
+        self.scale_invariant_layers = (RMSNorm,)
         self.scale_invariant_function = ()

     def apply(self,
@@ -1332,7 +1340,7 @@ def _replace_bias(next_module, new_bias):
         next_module.bias.data.copy_(new_bias)
     else:
         new_bias = new_bias.to(next_module.weight.device).to(next_module.weight.dtype)
-        next_module.register_parameter('bias', torch.nn.Parameter(new_bias))
+        next_module.register_parameter('bias', nn.Parameter(new_bias))


 def _merge_ln(layer_norm, next_module, scale_bias_by_weight):
@@ -1342,7 +1350,7 @@ def _merge_ln(layer_norm, next_module, scale_bias_by_weight):
         layer_norm.bias.data /= layer_norm.weight.data
     # We can't do an inplace update as some layers we merge into like lm_head might share the weight tensor
     scale = layer_norm.weight.data.view(view_shape).expand_as(next_module.weight)
-    next_module.weight = torch.nn.Parameter(next_module.weight.clone() * scale)
+    next_module.weight = nn.Parameter(next_module.weight.clone() * scale)

     # Merge bias, new_bias includes the bias of next_module by going through its fwd
     if hasattr(layer_norm, 'bias'):
@@ -1355,8 +1363,8 @@ class MergeLnAffine(GraphTransform):

     def __init__(self) -> None:
         super(MergeLnAffine, self).__init__()
-        self.supported_srcs = (torch.nn.RMSNorm, torch.nn.LayerNorm)
-        self.supported_sinks = (torch.nn.Linear)
+        self.supported_srcs = (RMSNorm, nn.LayerNorm)
+        self.supported_sinks = (nn.Linear)

     def apply(self, graph_model: GraphModule) -> GraphModule:
         regions = _extract_regions(
@@ -1388,7 +1396,7 @@ class LayerwiseActivationRotation(GraphTransform):
     def __init__(self, blacklist_layer=None):
         super(GraphTransform, self).__init__()

-        self.supported_sinks = (torch.nn.Linear)
+        self.supported_sinks = (nn.Linear)
         self.blacklist_layers = blacklist_layer

     def find_module(self, model, regions: List, prefix=''):
tests/brevitas/graph/test_equalization.py

Lines changed: 45 additions & 43 deletions
@@ -17,52 +17,53 @@
 from brevitas.graph.standardize import DuplicateSharedStatelessModule
 from brevitas.graph.standardize import TorchFunctionalToModule
 from brevitas.graph.utils import get_module
+from tests.marker import requires_pt_ge

 from .equalization_fixtures import *

-# def test_resnet18_equalization():
-#     model = models.resnet18(pretrained=True)
-
-#     torch.manual_seed(SEED)
-#     inp = torch.randn(IN_SIZE_CONV)
-#     model.eval()
-#     model = symbolic_trace(model)
-#     expected_out = model(inp)
-
-#     model_orig = copy.deepcopy(model)
-#     supported_sinks = list(_supported_layers)
-#     supported_sinks = tuple([
-#         x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
-#     regions = _extract_regions(
-#         model, state_impl_kwargs={'supported_sinks': supported_sinks})
-#     _ = equalize_test(
-#         regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs')
-#     out = model(inp)
-
-#     regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names]))
-#     resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0])
-#     equalized_layers = set()
-#     for r in resnet_18_regions:
-#         equalized_layers.update(r[0])
-#         equalized_layers.update(r[1])
-
-#     # Check that we found all the expected regions
-#     for region, expected_region in zip(regions, resnet_18_regions):
-#         srcs = region.srcs_names
-#         sources_check = set(srcs) == set(expected_region[0])
-#         sinks = region.sinks_names
-#         sinks_check = set(sinks) == set(expected_region[1])
-#         assert sources_check
-#         assert sinks_check
-
-#     # Check that all layers were equalized and weights changed
-#     for layer in equalized_layers:
-#         eq_module = get_module(model, layer)
-#         orig_module = get_module(model_orig, layer)
-#         assert not torch.allclose(eq_module.weight, orig_module.weight)
-
-#     # Check that equalization is not introducing FP variations
-#     assert torch.allclose(expected_out, out, atol=ATOL)
+
+def test_resnet18_equalization():
+    model = models.resnet18(pretrained=True)
+
+    torch.manual_seed(SEED)
+    inp = torch.randn(IN_SIZE_CONV)
+    model.eval()
+    model = symbolic_trace(model)
+    expected_out = model(inp)
+
+    model_orig = copy.deepcopy(model)
+    supported_sinks = list(_supported_layers)
+    supported_sinks = tuple([
+        x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
+    regions = _extract_regions(model, state_impl_kwargs={'supported_sinks': supported_sinks})
+    _ = equalize_test(
+        regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs')
+    out = model(inp)
+
+    regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names]))
+    resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0])
+    equalized_layers = set()
+    for r in resnet_18_regions:
+        equalized_layers.update(r[0])
+        equalized_layers.update(r[1])
+
+    # Check that we found all the expected regions
+    for region, expected_region in zip(regions, resnet_18_regions):
+        srcs = region.srcs_names
+        sources_check = set(srcs) == set(expected_region[0])
+        sinks = region.sinks_names
+        sinks_check = set(sinks) == set(expected_region[1])
+        assert sources_check
+        assert sinks_check
+
+    # Check that all layers were equalized and weights changed
+    for layer in equalized_layers:
+        eq_module = get_module(model, layer)
+        orig_module = get_module(model_orig, layer)
+        assert not torch.allclose(eq_module.weight, orig_module.weight)
+
+    # Check that equalization is not introducing FP variations
+    assert torch.allclose(expected_out, out, atol=ATOL)


 @pytest_cases.parametrize("merge_bias", [True, False])
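
The hunk above restores the previously commented-out ResNet-18 regression test verbatim. Assuming the repository's test dependencies are installed and pretrained torchvision weights are reachable, it can be exercised on its own with an invocation along these lines (illustrative only, not part of the commit):

# Illustrative only: run the restored test in isolation from the repo root.
import pytest

pytest.main([
    "tests/brevitas/graph/test_equalization.py",
    "-k", "test_resnet18_equalization",
    "-v"])
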
@@ -239,6 +240,7 @@ def test_act_equalization_torchvision_models(model_dict: dict, layerwise: bool):
     assert any([shape != () for shape in shape_scale_regions])


+@requires_pt_ge('2.4')
 def test_models(rotation_fixtures):

     in_shape = IN_SIZE_LINEAR
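
The new @requires_pt_ge('2.4') marker is imported from tests/marker.py, which this diff does not show. A hypothetical sketch of such a marker, typically a thin wrapper around pytest.mark.skipif, is given below; the names and behaviour are assumptions, not the actual Brevitas helper.

# Hypothetical sketch of a requires_pt_ge-style marker; the real tests/marker.py
# implementation is not part of this diff.
import packaging.version
import pytest
import torch


def requires_pt_ge(min_version: str):
    """Skip the decorated test when the installed torch is older than min_version."""
    installed = packaging.version.parse(torch.__version__)
    return pytest.mark.skipif(
        installed < packaging.version.parse(min_version),
        reason=f"requires torch >= {min_version}, found {torch.__version__}")
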
