Skip to content

Commit

Permalink
Fix tests and imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 21, 2024
1 parent a053780 commit 5942658
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 86 deletions.
94 changes: 51 additions & 43 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import warnings

import packaging
import packaging.version
import torch
from torch.fx import GraphModule as TorchGraphModule
import torch.nn as nn

from brevitas import torch_version
from brevitas.fx import GraphModule
from brevitas.fx import Node
from brevitas.graph.base import GraphTransform
from brevitas.graph.base import InsertModuleCallAfter
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.hadamard import get_hadK
from brevitas.graph.hadamard import matmul_hadU
Expand All @@ -29,13 +33,18 @@
from brevitas.nn.quant_scale_bias import ScaleBias
from brevitas.utils.torch_utils import KwargsForwardHook

from .base import InsertModuleCallAfter

# External optional dependency
try:
import fast_hadamard_transform
except:
fast_hadamard_transform = None

# RMSNorm was introduced with torch 2.4
if torch_version >= packaging.version.parse('2.4'):
RMSNorm = nn.RMSNorm
else:
RMSNorm = object

__all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph']

EPSILON = 1e-9
Expand Down Expand Up @@ -77,14 +86,13 @@
operator.imul,
operator.__mul__,
operator.__imul__,
torch.nn.functional.interpolate)
nn.functional.interpolate)

_select_op = (operator.getitem, operator.__getitem__)

_reshaping_op = ('view', 'reshape', 'flatten', 'contiguous', 'to', torch.reshape, torch.flatten)

_scale_varying_activations = (
torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.ReLU6, torch.nn.GELU, torch.nn.SiLU)
_scale_varying_activations = (nn.Sigmoid, nn.Tanh, nn.ReLU6, nn.GELU, nn.SiLU)

_residual_methods = ('add', 'add_')

Expand All @@ -95,31 +103,6 @@
_ignore_ops = (getattr, 'size')


def _is_supported_module(
graph_model: GraphModule, node: Node, supported_layers: Set = _supported_layers) -> bool:
if node.op == 'call_module':
module = get_module(graph_model, node.target)
if isinstance(module, supported_layers):
# We support only self-attention
if isinstance(module, nn.MultiheadAttention):
kwargs = dict(node.kwargs)
# When using hf/accelerate, we need to check the signature of the original forward
forward_to_check = module._old_forward if hasattr(
module, '_old_forward') else module.forward
kwargs.update(zip(forward_to_check.__code__.co_varnames[1:], node.args))
return kwargs['query'].name == kwargs['key'].name == kwargs['value'].name
return True
return False


def _is_scale_invariant_module(
graph_model: GraphModule,
node: Node,
scale_invariant_layers=_scale_invariant_layers) -> bool:
return node.op == 'call_module' and isinstance(
get_module(graph_model, node.target), scale_invariant_layers)


# Start and End identify the starting and ending channels of the weight matrix that need to be
# equalized.
# Offset refers to the relative position of these channels with respect to
Expand Down Expand Up @@ -334,7 +317,7 @@ def _get_input_axis(module: nn.Module) -> Optional[int]:
return 0
elif module.groups == module.out_channels:
return 1
elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
elif isinstance(module, (nn.LayerNorm, RMSNorm)):
# We assume normalization happens only along the channel dimension
if len(module.weight.shape) == 1:
return 0
Expand Down Expand Up @@ -362,7 +345,7 @@ def _get_output_axis(module: nn.Module) -> Optional[int]:
elif isinstance(module,
(nn.Embedding, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
return 1
elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
elif isinstance(module, (nn.LayerNorm, RMSNorm)):
# We assume normalization happens only along the channel dimension
if len(module.weight.shape) == 1:
return 0
Expand Down Expand Up @@ -687,6 +670,31 @@ def _equalize(
return model


def _is_supported_module(
graph_model: GraphModule, node: Node, supported_layers: Set = _supported_layers) -> bool:
if node.op == 'call_module':
module = get_module(graph_model, node.target)
if isinstance(module, supported_layers):
# We support only self-attention
if isinstance(module, nn.MultiheadAttention):
kwargs = dict(node.kwargs)
# When using hf/accelerate, we need to check the signature of the original forward
forward_to_check = module._old_forward if hasattr(
module, '_old_forward') else module.forward
kwargs.update(zip(forward_to_check.__code__.co_varnames[1:], node.args))
return kwargs['query'].name == kwargs['key'].name == kwargs['value'].name
return True
return False


def _is_scale_invariant_module(
graph_model: GraphModule,
node: Node,
scale_invariant_layers=_scale_invariant_layers) -> bool:
return node.op == 'call_module' and isinstance(
get_module(graph_model, node.target), scale_invariant_layers)


def _is_scale_varying_activation(graph_model, node):
return node.op == 'call_module' and isinstance(
get_module(graph_model, node.target), _scale_varying_activations)
Expand All @@ -696,7 +704,7 @@ def _is_scale_invariant_function(node: Node, scale_invariant_op: Set = _scale_in
out = node.op in (
'call_function',
'call_method') and node.target in scale_invariant_op + _select_op + _reshaping_op
if node.target == torch.nn.functional.interpolate:
if node.target == nn.functional.interpolate:
out &= node.kwargs.get('mode', None) == 'nearest'
return out

Expand Down Expand Up @@ -959,7 +967,7 @@ def apply(self,
graph_model: GraphModule) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]:
# It is not possible to equalize through LayerNorm/BatchNorm as sink
supported_sinks = tuple([
x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
x for x in _supported_layers if x not in (nn.LayerNorm, *_batch_norm)])
regions = _extract_regions(
graph_model, state_impl_kwargs={'supported_sinks': supported_sinks})
if len(regions) > 0:
Expand Down Expand Up @@ -1135,7 +1143,7 @@ def __init__(

# It is not possible to equalize through LayerNorm/BatchNorm as sink
supported_sinks = tuple([
x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
x for x in _supported_layers if x not in (nn.LayerNorm, *_batch_norm)])
self.regions = _extract_regions(
model,
add_mul_node=add_mul_node,
Expand Down Expand Up @@ -1305,9 +1313,9 @@ class GraphRotationEqualization(GraphTransform):
def __init__(self) -> None:
super(GraphRotationEqualization, self).__init__()

self.supported_srcs = (torch.nn.Linear, torch.nn.Embedding)
self.supported_sinks = (torch.nn.Linear)
self.scale_invariant_layers = (torch.nn.RMSNorm,)
self.supported_srcs = (nn.Linear, nn.Embedding)
self.supported_sinks = (nn.Linear)
self.scale_invariant_layers = (RMSNorm,)
self.scale_invariant_function = ()

def apply(self,
Expand All @@ -1332,7 +1340,7 @@ def _replace_bias(next_module, new_bias):
next_module.bias.data.copy_(new_bias)
else:
new_bias = new_bias.to(next_module.weight.device).to(next_module.weight.dtype)
next_module.register_parameter('bias', torch.nn.Parameter(new_bias))
next_module.register_parameter('bias', nn.Parameter(new_bias))


def _merge_ln(layer_norm, next_module, scale_bias_by_weight):
Expand All @@ -1342,7 +1350,7 @@ def _merge_ln(layer_norm, next_module, scale_bias_by_weight):
layer_norm.bias.data /= layer_norm.weight.data
# We can't do an inplace update as some layers we merge into like lm_head might share the weight tensor
scale = layer_norm.weight.data.view(view_shape).expand_as(next_module.weight)
next_module.weight = torch.nn.Parameter(next_module.weight.clone() * scale)
next_module.weight = nn.Parameter(next_module.weight.clone() * scale)

# Merge bias, new_bias includes the bias of next_module by going through its fwd
if hasattr(layer_norm, 'bias'):
Expand All @@ -1355,8 +1363,8 @@ class MergeLnAffine(GraphTransform):

def __init__(self) -> None:
super(MergeLnAffine, self).__init__()
self.supported_srcs = (torch.nn.RMSNorm, torch.nn.LayerNorm)
self.supported_sinks = (torch.nn.Linear)
self.supported_srcs = (RMSNorm, nn.LayerNorm)
self.supported_sinks = (nn.Linear)

def apply(self, graph_model: GraphModule) -> GraphModule:
regions = _extract_regions(
Expand Down Expand Up @@ -1388,7 +1396,7 @@ class LayerwiseActivationRotation(GraphTransform):
def __init__(self, blacklist_layer=None):
super(GraphTransform, self).__init__()

self.supported_sinks = (torch.nn.Linear)
self.supported_sinks = (nn.Linear)
self.blacklist_layers = blacklist_layer

def find_module(self, model, regions: List, prefix=''):
Expand Down
88 changes: 45 additions & 43 deletions tests/brevitas/graph/test_equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,52 +17,53 @@
from brevitas.graph.standardize import DuplicateSharedStatelessModule
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas.graph.utils import get_module
from tests.marker import requires_pt_ge

from .equalization_fixtures import *

# def test_resnet18_equalization():
# model = models.resnet18(pretrained=True)

# torch.manual_seed(SEED)
# inp = torch.randn(IN_SIZE_CONV)
# model.eval()
# model = symbolic_trace(model)
# expected_out = model(inp)

# model_orig = copy.deepcopy(model)
# supported_sinks = list(_supported_layers)
# supported_sinks = tuple([
# x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
# regions = _extract_regions(
# model, state_impl_kwargs={'supported_sinks': supported_sinks})
# _ = equalize_test(
# regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs')
# out = model(inp)

# regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names]))
# resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0])
# equalized_layers = set()
# for r in resnet_18_regions:
# equalized_layers.update(r[0])
# equalized_layers.update(r[1])

# # Check that we found all the expected regions
# for region, expected_region in zip(regions, resnet_18_regions):
# srcs = region.srcs_names
# sources_check = set(srcs) == set(expected_region[0])
# sinks = region.sinks_names
# sinks_check = set(sinks) == set(expected_region[1])
# assert sources_check
# assert sinks_check

# # Check that all layers were equalized and weights changed
# for layer in equalized_layers:
# eq_module = get_module(model, layer)
# orig_module = get_module(model_orig, layer)
# assert not torch.allclose(eq_module.weight, orig_module.weight)

# # Check that equalization is not introducing FP variations
# assert torch.allclose(expected_out, out, atol=ATOL)

def test_resnet18_equalization():
model = models.resnet18(pretrained=True)

torch.manual_seed(SEED)
inp = torch.randn(IN_SIZE_CONV)
model.eval()
model = symbolic_trace(model)
expected_out = model(inp)

model_orig = copy.deepcopy(model)
supported_sinks = list(_supported_layers)
supported_sinks = tuple([
x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
regions = _extract_regions(model, state_impl_kwargs={'supported_sinks': supported_sinks})
_ = equalize_test(
regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs')
out = model(inp)

regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names]))
resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0])
equalized_layers = set()
for r in resnet_18_regions:
equalized_layers.update(r[0])
equalized_layers.update(r[1])

# Check that we found all the expected regions
for region, expected_region in zip(regions, resnet_18_regions):
srcs = region.srcs_names
sources_check = set(srcs) == set(expected_region[0])
sinks = region.sinks_names
sinks_check = set(sinks) == set(expected_region[1])
assert sources_check
assert sinks_check

# Check that all layers were equalized and weights changed
for layer in equalized_layers:
eq_module = get_module(model, layer)
orig_module = get_module(model_orig, layer)
assert not torch.allclose(eq_module.weight, orig_module.weight)

# Check that equalization is not introducing FP variations
assert torch.allclose(expected_out, out, atol=ATOL)


@pytest_cases.parametrize("merge_bias", [True, False])
Expand Down Expand Up @@ -239,6 +240,7 @@ def test_act_equalization_torchvision_models(model_dict: dict, layerwise: bool):
assert any([shape != () for shape in shape_scale_regions])


@requires_pt_ge('2.4')
def test_models(rotation_fixtures):

in_shape = IN_SIZE_LINEAR
Expand Down

0 comments on commit 5942658

Please sign in to comment.