from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import warnings

+import packaging
+import packaging.version
import torch
from torch.fx import GraphModule as TorchGraphModule
import torch.nn as nn

+from brevitas import torch_version
from brevitas.fx import GraphModule
from brevitas.fx import Node
from brevitas.graph.base import GraphTransform
+from brevitas.graph.base import InsertModuleCallAfter
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.hadamard import get_hadK
from brevitas.graph.hadamard import matmul_hadU
@@ -29 +33 @@
from brevitas.nn.quant_scale_bias import ScaleBias
from brevitas.utils.torch_utils import KwargsForwardHook

-from .base import InsertModuleCallAfter
-
+# External optional dependency
try:
    import fast_hadamard_transform
except:
    fast_hadamard_transform = None

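# Editor's sketch (not part of the diff; `_hadamard_rotate` and `h` are
# illustrative names): a typical call-site guard for an optional dependency
# imported as above -- use the fast kernel when the import succeeded,
# otherwise fall back to a plain matmul against an explicit Hadamard matrix.
def _hadamard_rotate(x, h):
    if fast_hadamard_transform is not None:
        return fast_hadamard_transform.hadamard_transform(x)
    return x @ h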
+# RMSNorm was introduced with torch 2.4
+if torch_version >= packaging.version.parse('2.4'):
+    RMSNorm = nn.RMSNorm
+else:
+    RMSNorm = object
+
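# Editor's note (not part of the diff): the `object` fallback only keeps the
# name `RMSNorm` defined on torch < 2.4 so that isinstance tuples such as
# (nn.LayerNorm, RMSNorm) used later in this file remain valid Python; since
# isinstance(x, object) is True for every x, the placeholder itself never
# narrows those checks.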
__all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph']

EPSILON = 1e-9
@@ -77 +86 @@
    operator.imul,
    operator.__mul__,
    operator.__imul__,
-    torch.nn.functional.interpolate)
+    nn.functional.interpolate)

_select_op = (operator.getitem, operator.__getitem__)

_reshaping_op = ('view', 'reshape', 'flatten', 'contiguous', 'to', torch.reshape, torch.flatten)

-_scale_varying_activations = (
-    torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.ReLU6, torch.nn.GELU, torch.nn.SiLU)
+_scale_varying_activations = (nn.Sigmoid, nn.Tanh, nn.ReLU6, nn.GELU, nn.SiLU)

_residual_methods = ('add', 'add_')

@@ -95 +103 @@
_ignore_ops = (getattr, 'size')


-def _is_supported_module(
-        graph_model: GraphModule, node: Node, supported_layers: Set = _supported_layers) -> bool:
-    if node.op == 'call_module':
-        module = get_module(graph_model, node.target)
-        if isinstance(module, supported_layers):
-            # We support only self-attention
-            if isinstance(module, nn.MultiheadAttention):
-                kwargs = dict(node.kwargs)
-                # When using hf/accelerate, we need to check the signature of the original forward
-                forward_to_check = module._old_forward if hasattr(
-                    module, '_old_forward') else module.forward
-                kwargs.update(zip(forward_to_check.__code__.co_varnames[1:], node.args))
-                return kwargs['query'].name == kwargs['key'].name == kwargs['value'].name
-            return True
-    return False
-
-
-def _is_scale_invariant_module(
-        graph_model: GraphModule,
-        node: Node,
-        scale_invariant_layers=_scale_invariant_layers) -> bool:
-    return node.op == 'call_module' and isinstance(
-        get_module(graph_model, node.target), scale_invariant_layers)
-
-
# Start and End identify the starting and ending channels of the weight matrix that need to be
# equalized.
# Offset refers to the relative position of these channels with respect to
@@ -334,7 +317,7 @@ def _get_input_axis(module: nn.Module) -> Optional[int]:
            return 0
        elif module.groups == module.out_channels:
            return 1
-    elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
+    elif isinstance(module, (nn.LayerNorm, RMSNorm)):
        # We assume normalization happens only along the channel dimension
        if len(module.weight.shape) == 1:
            return 0
@@ -362,7 +345,7 @@ def _get_output_axis(module: nn.Module) -> Optional[int]:
    elif isinstance(module,
                    (nn.Embedding, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
        return 1
-    elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
+    elif isinstance(module, (nn.LayerNorm, RMSNorm)):
        # We assume normalization happens only along the channel dimension
        if len(module.weight.shape) == 1:
            return 0
@@ -687,6 +670,31 @@ def _equalize(
    return model


+def _is_supported_module(
+        graph_model: GraphModule, node: Node, supported_layers: Set = _supported_layers) -> bool:
+    if node.op == 'call_module':
+        module = get_module(graph_model, node.target)
+        if isinstance(module, supported_layers):
+            # We support only self-attention
+            if isinstance(module, nn.MultiheadAttention):
+                kwargs = dict(node.kwargs)
+                # When using hf/accelerate, we need to check the signature of the original forward
+                forward_to_check = module._old_forward if hasattr(
+                    module, '_old_forward') else module.forward
+                kwargs.update(zip(forward_to_check.__code__.co_varnames[1:], node.args))
+                return kwargs['query'].name == kwargs['key'].name == kwargs['value'].name
+            return True
+    return False
+
+
+def _is_scale_invariant_module(
+        graph_model: GraphModule,
+        node: Node,
+        scale_invariant_layers=_scale_invariant_layers) -> bool:
+    return node.op == 'call_module' and isinstance(
+        get_module(graph_model, node.target), scale_invariant_layers)
+
+
def _is_scale_varying_activation(graph_model, node):
    return node.op == 'call_module' and isinstance(
        get_module(graph_model, node.target), _scale_varying_activations)
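# Editor's sketch (not part of the diff; `SelfAttn` and the traced module are
# illustrative): why comparing the query/key/value node names identifies
# self-attention. When the same tensor feeds all three inputs, the fx trace
# records one Node three times, so the .name values coincide.
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class SelfAttn(nn.Module):

    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

    def forward(self, x):
        return self.mha(x, x, x)[0]


gm = symbolic_trace(SelfAttn())
mha_node = next(n for n in gm.graph.nodes if n.op == 'call_module' and n.target == 'mha')
q, k, v = mha_node.args
assert q.name == k.name == v.name  # same source Node feeds all three inputs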
@@ -696,7 +704,7 @@ def _is_scale_invariant_function(node: Node, scale_invariant_op: Set = _scale_in
    out = node.op in (
        'call_function',
        'call_method') and node.target in scale_invariant_op + _select_op + _reshaping_op
-    if node.target == torch.nn.functional.interpolate:
+    if node.target == nn.functional.interpolate:
        out &= node.kwargs.get('mode', None) == 'nearest'
    return out

@@ -959,7 +967,7 @@ def apply(self,
              graph_model: GraphModule) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]:
        # It is not possible to equalize through LayerNorm/BatchNorm as sink
        supported_sinks = tuple([
-            x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
+            x for x in _supported_layers if x not in (nn.LayerNorm, *_batch_norm)])
        regions = _extract_regions(
            graph_model, state_impl_kwargs={'supported_sinks': supported_sinks})
        if len(regions) > 0:
@@ -1135,7 +1143,7 @@ def __init__(

        # It is not possible to equalize through LayerNorm/BatchNorm as sink
        supported_sinks = tuple([
-            x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
+            x for x in _supported_layers if x not in (nn.LayerNorm, *_batch_norm)])
        self.regions = _extract_regions(
            model,
            add_mul_node=add_mul_node,
@@ -1305,9 +1313,9 @@ class GraphRotationEqualization(GraphTransform):
    def __init__(self) -> None:
        super(GraphRotationEqualization, self).__init__()

-        self.supported_srcs = (torch.nn.Linear, torch.nn.Embedding)
-        self.supported_sinks = (torch.nn.Linear)
-        self.scale_invariant_layers = (torch.nn.RMSNorm,)
+        self.supported_srcs = (nn.Linear, nn.Embedding)
+        self.supported_sinks = (nn.Linear)
+        self.scale_invariant_layers = (RMSNorm,)
        self.scale_invariant_function = ()

    def apply(self,
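# Editor's note (not part of the diff): `(nn.Linear)` above is not a one-element
# tuple (that would be `(nn.Linear,)`); the parentheses are plain grouping, so
# the attribute holds the class itself. isinstance() accepts a single class as
# well as a tuple, so checks against it behave the same either way.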
@@ -1332,7 +1340,7 @@ def _replace_bias(next_module, new_bias):
        next_module.bias.data.copy_(new_bias)
    else:
        new_bias = new_bias.to(next_module.weight.device).to(next_module.weight.dtype)
-        next_module.register_parameter('bias', torch.nn.Parameter(new_bias))
+        next_module.register_parameter('bias', nn.Parameter(new_bias))


def _merge_ln(layer_norm, next_module, scale_bias_by_weight):
@@ -1342,7 +1350,7 @@ def _merge_ln(layer_norm, next_module, scale_bias_by_weight):
        layer_norm.bias.data /= layer_norm.weight.data
    # We can't do an inplace update as some layers we merge into like lm_head might share the weight tensor
    scale = layer_norm.weight.data.view(view_shape).expand_as(next_module.weight)
-    next_module.weight = torch.nn.Parameter(next_module.weight.clone() * scale)
+    next_module.weight = nn.Parameter(next_module.weight.clone() * scale)

    # Merge bias, new_bias includes the bias of next_module by going through its fwd
    if hasattr(layer_norm, 'bias'):
@@ -1355,8 +1363,8 @@ class MergeLnAffine(GraphTransform):

    def __init__(self) -> None:
        super(MergeLnAffine, self).__init__()
-        self.supported_srcs = (torch.nn.RMSNorm, torch.nn.LayerNorm)
-        self.supported_sinks = (torch.nn.Linear)
+        self.supported_srcs = (RMSNorm, nn.LayerNorm)
+        self.supported_sinks = (nn.Linear)

    def apply(self, graph_model: GraphModule) -> GraphModule:
        regions = _extract_regions(
@@ -1388,7 +1396,7 @@ class LayerwiseActivationRotation(GraphTransform):
    def __init__(self, blacklist_layer=None):
        super(GraphTransform, self).__init__()

-        self.supported_sinks = (torch.nn.Linear)
+        self.supported_sinks = (nn.Linear)
        self.blacklist_layers = blacklist_layer

    def find_module(self, model, regions: List, prefix=''):