481 changes: 401 additions & 80 deletions src/brevitas/graph/equalize.py

Large diffs are not rendered by default.

96,198 changes: 96,198 additions & 0 deletions src/brevitas/graph/hadamard.py

Large diffs are not rendered by default.

50 changes: 50 additions & 0 deletions src/brevitas/nn/equalized_layer.py
@@ -2,8 +2,16 @@

import torch

from brevitas.graph.hadamard import get_hadK
from brevitas.graph.hadamard import matmul_hadU
from brevitas.graph.hadamard import matmul_hadU_cuda
from brevitas.nn.quant_mha import QuantMultiheadAttention

try:
import fast_hadamard_transform
except ImportError:
fast_hadamard_transform = None

INPUT_NAMES = ['input', 'inp', 'query', 'x', 'hidden_states']


@@ -41,3 +49,45 @@ def forward(self, *args, **kwargs):
# We convert everything to args so that hooks can work correctly
out = self.layer(*kwargs.values())
return out


class RotatedModule(torch.nn.Module):

def __init__(self, layer, had_mat=None, k=None) -> None:
super().__init__()
if had_mat is not None:
self.had_mat = torch.nn.Parameter(had_mat).cpu()
else:
self.had_mat = None
self.layer = layer
self.k = k

def forward(self, inp, **kwargs):
is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
if is_cuda and fast_hadamard_transform is not None:
if self.had_mat is None or self.k is None:
had_K, K = get_hadK(inp.shape[-1])
else:
had_K = self.had_mat
K = self.k
inp = matmul_hadU_cuda(inp, had_K, K)
else:
inp = matmul_hadU(inp)
o = self.layer(inp)

return o


def functional_rotate_input(inp, transpose=False):
is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
if transpose:
inp = inp.t()
if is_cuda and fast_hadamard_transform is not None:
had_K, K = get_hadK(inp.shape[-1])
inp = matmul_hadU_cuda(inp, had_K, K)
else:
inp = matmul_hadU(inp)

if transpose:
inp = inp.t()
return inp
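
As background for `RotatedModule` and `functional_rotate_input` above, here is a minimal sketch (not part of this diff; the Sylvester construction and the power-of-two width are assumptions for illustration) of why rotating a linear layer's input with a normalized Hadamard matrix is lossless once the weight has been rotated by the same matrix:

```python
import torch


def hadamard_matrix(n: int) -> torch.Tensor:
    # Sylvester construction; n is assumed to be a power of two in this sketch.
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h / torch.tensor(float(n)).sqrt()  # normalized so that H @ H.T == I


n = 64
H = hadamard_matrix(n)
layer = torch.nn.Linear(n, n, bias=False)
x = torch.randn(8, n)

# (x H) (W H)^T = x H H^T W^T = x W^T, so the rotated pair reproduces the original output.
rotated_weight = layer.weight @ H
out_ref = layer(x)
out_rot = (x @ H) @ rotated_weight.t()
assert torch.allclose(out_ref, out_rot, atol=1e-5)
```

This is the invariance the wrapper relies on: the wrapped layer's weight is presumed to have been rotated offline by the equalization pass, and `RotatedModule` only applies the matching input-side transform, via `fast_hadamard_transform` on CUDA when available and `matmul_hadU` otherwise.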
13 changes: 12 additions & 1 deletion src/brevitas_examples/llm/README.md
@@ -16,6 +16,7 @@ Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Curren
```bash
usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}]
[--gpxq-block-name GPXQ_BLOCK_NAME]
[--weight-bit-width WEIGHT_BIT_WIDTH]
[--weight-param-method {stats,mse,hqo}]
[--weight-scale-precision {float_scale,po2_scale}]
@@ -53,7 +54,10 @@ options:
--seqlen SEQLEN Sequence length. Default: 2048.
--eval Eval model PPL on the chosen Dataset.
--dataset {wikitext2,c4}
Dataset to use for quantization (default: wikitext2)
Dataset to use for quantization (default: c4)
--gpxq-block-name GPXQ_BLOCK_NAME
Block name for faster GPxQ optimization. It works only
if FX is not needed (default: None)
--weight-bit-width WEIGHT_BIT_WIDTH
Weight bit width. Default: 8.
--weight-param-method {stats,mse,hqo}
@@ -121,6 +125,7 @@ options:
--act-calibration Apply activation calibration.
--bias-corr Apply bias correction.
--ln-affine-merge Merge LN affine params.
--replace-rmsnorm Replace HF RMSNorms with the Torch equivalent.
--no-quantize Disable quantization.
--no-float16 Disable float16 as base datatype and switch to
float32.
@@ -129,6 +134,12 @@ options:
--weight-equalization
Apply weight equalization. Relevant to ReLU based
models (e.g. OPT).
--graph-rotation Apply graph rotation equalization.
--graph-rotation-mode {had,ort}
If graph rotation is enabled, decide how to compute the
fully fused random rotation matrix ('had': Hadamard,
'ort': random orthogonal). Online or partial rotations
are always Hadamard.
--layerwise-rotation Apply layerwise rotation equalization.
--act-equalization {None,layerwise,fx}
Apply activation equalization (SmoothQuant). Layerwise
introduces standalone mul nodes, while fx merges them
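
To make the `--graph-rotation-mode` choices concrete, a small sketch follows (independent of Brevitas internals; the reading of 'had' as Hadamard and 'ort' as random orthogonal is an assumption based on the flag names). Both kinds of matrices satisfy R Rᵀ = I, which is what keeps a fully fused rotation lossless:

```python
import torch


def random_orthogonal(n: int) -> torch.Tensor:
    # 'ort' (assumed meaning): orthogonal factor of a random Gaussian matrix via QR.
    q, r = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
    # Sign correction so the columns of Q are uniformly distributed (Haar measure).
    return q * torch.sign(torch.diagonal(r))


n = 128
rot = random_orthogonal(n)
assert torch.allclose(rot @ rot.t(), torch.eye(n, dtype=torch.float64), atol=1e-10)
```

A normalized Hadamard matrix (the 'had' default, sketched earlier for `RotatedModule`) is the same kind of orthogonal object but highly structured, so it can be applied with a fast transform instead of a dense matmul, which is presumably why online or partial rotations always use Hadamard.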
28 changes: 20 additions & 8 deletions src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
@@ -6,10 +6,23 @@
import torch
from torch import nn

from brevitas.graph.equalize import _is_reshaping_op
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.equalize import _is_scale_invariant_module
from brevitas.graph.equalize import MergeLnAffine
from brevitas.graph.utils import get_module
from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32


def replace_rmsnorm_with_torch(model, config):
set_of_layers = set(type(x) for x in model.modules() if 'RMS' in type(x).__name__)
rewriters = [
ModuleToModuleByClass(
rms_cls, torch.nn.RMSNorm, normalized_shape=config.hidden_size, eps=config.rms_norm_eps)
for rms_cls in set_of_layers]
dtype = next(iter(model.parameters())).dtype
for r in rewriters:
model = r.apply(model)
model = model.to(dtype)
return model


def replace_bias(next_module, new_bias):
@@ -49,7 +62,7 @@ def merge_layernorm_affine_params(graph_model):
module = get_module(graph_model, node.target)
if isinstance(module, nn.LayerNorm):
for next in node.users:
while (_is_reshaping_op(next) or _is_scale_invariant_module(graph_model, next)):
while (_is_scale_invariant_module(graph_model, next)):
next = node.next
if next.op == 'call_module':
next_module = get_module(graph_model, next.target)
@@ -83,8 +96,7 @@ def merge_layernorm_affine_params(graph_model):


@torch.no_grad()
def apply_layernorm_affine_merge(graph_model, dtype):
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply merging, and then cast back
with cast_to_float32(graph_model, dtype):
merge_layernorm_affine_params(graph_model)
def apply_layernorm_affine_merge(graph_model):
eq = MergeLnAffine()
graph_model = eq.apply(graph_model)
return graph_model
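
For context on `replace_rmsnorm_with_torch` above, a minimal sketch of the numerical equivalence the swap to `torch.nn.RMSNorm` (available in PyTorch 2.4+) relies on; the `HFStyleRMSNorm` class below is a hypothetical stand-in for an HF implementation, not code from this diff:

```python
import torch


class HFStyleRMSNorm(torch.nn.Module):
    # Hypothetical stand-in mirroring the usual HF RMSNorm formula.

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * (x * torch.rsqrt(variance + self.variance_epsilon))


hidden = 16
hf_norm = HFStyleRMSNorm(hidden)
hf_norm.weight.data = torch.randn(hidden)  # pretend learned parameters
torch_norm = torch.nn.RMSNorm(hidden, eps=1e-6)
torch_norm.weight.data.copy_(hf_norm.weight.data)

x = torch.randn(2, 8, hidden)
assert torch.allclose(hf_norm(x), torch_norm(x), atol=1e-6)
```

Given that equivalence, `replace_rmsnorm_with_torch` only has to carry the class swap across the model with `ModuleToModuleByClass`, passing `config.hidden_size` and `config.rms_norm_eps` so the new modules match in shape and epsilon.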
36 changes: 34 additions & 2 deletions src/brevitas_examples/llm/main.py
@@ -14,6 +14,8 @@

from brevitas.export import export_torch_qcdq
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import LayerwiseActivationRotation
from brevitas.graph.quantize import layerwise_quantize
from brevitas_examples.common.accelerate_utils.accelerate import offload_model
from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
@@ -30,6 +32,7 @@
from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq
from brevitas_examples.llm.llm_quant.gpxq import apply_gptq
from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge
from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch
from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear
from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers
from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32
@@ -196,18 +199,34 @@ def main(args):
remove_hooks(model)
print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}")

if args.replace_rmsnorm:
model = replace_rmsnorm_with_torch(model, model.config)

if require_fx:
model = get_fx(model)
with torch.no_grad():
model, guards = torch._dynamo.export(model)(**calibration_loader[0])
# Blockwise optimization does not work with FX at the moment
args.gpxq_block_name = None

# Apply LN affine merging before inserting MHA layers
# since currently there is support only for merging into Linear
if args.ln_affine_merge:
print("Apply LN affine merge...")
apply_layernorm_affine_merge(model, dtype)
apply_layernorm_affine_merge(model)
print("LN affine merge applied.")

if args.graph_rotation:
assert args.ln_affine_merge
assert args.replace_rmsnorm
model = offload_model(model)
eq = GraphRotationEqualization(
orphan_sink=True, full_rotation_method=args.graph_rotation_mode)
model = eq.apply(model)
remove_hooks(model)
elif args.layerwise_rotation:
eq = LayerwiseActivationRotation()
model = eq.apply(model)

# Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
# with all the variability in HF implementations
if args.replace_mha:
@@ -497,6 +516,8 @@ def parse_args(args):
'--act-calibration', action='store_true', help='Apply activation calibration.')
parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.')
parser.add_argument('--ln-affine-merge', action='store_true', help='Merge LN affine params.')
parser.add_argument(
'--replace-rmsnorm', action='store_true', help='Replace HF RMSNorms with the Torch equivalent.')
parser.add_argument('--no-quantize', action='store_true', help='Disable quantization.')
parser.add_argument(
'--no-float16',
@@ -510,6 +531,17 @@
'--weight-equalization',
action='store_true',
help='Apply weight equalization. Relevant to ReLU based models (e.g. OPT).')
parser.add_argument(
'--graph-rotation', action='store_true', help='Apply graph rotation equalization.')
parser.add_argument(
'--graph-rotation-mode',
default='had',
choices=['had', 'ort'],
help=
"If graph rotation is enabled, decide how to compute the fully fused random rotation matrix ('had': Hadamard, 'ort': random orthogonal). Online or partial rotations are always Hadamard."
)
parser.add_argument(
'--layerwise-rotation', action='store_true', help='Apply layerwise rotation equalization.')
parser.add_argument(
'--act-equalization',
default=None,
37 changes: 37 additions & 0 deletions tests/brevitas/graph/equalization_fixtures.py
@@ -491,3 +491,40 @@ def forward(self, x):

toy_quant_model = fixture_union(
'toy_quant_model', list_of_quant_fixtures, ids=list_of_quant_fixtures)

## List of Rotation fixtures


@pytest_cases.fixture
def linear_rms():

class LinearRMSModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(3, 4, bias=True)
self.linear.weight.data.fill_(2.)
self.linear.bias.data.fill_(1.)
self.rms = nn.RMSNorm(4)
self.rms.weight.data = torch.randn_like(
self.rms.weight.data) # Change learned parameters
self.linear_1 = nn.Linear(4, 8, bias=False)
self.linear_1.weight.data.fill_(2.)
self.linear_2 = nn.Linear(8, 8, bias=False)

def forward(self, x):
x = self.linear(x)
x = self.rms(x)
x = self.linear_1(x)
x = self.linear_2(x) * x
x = torch.matmul(x.flatten(1), x.flatten(1).t())

return x

return LinearRMSModel


list_of_rotation_fixtures = ['linear_rms']

rotation_fixtures = fixture_union(
'rotation_fixtures', list_of_rotation_fixtures, ids=list_of_rotation_fixtures)
63 changes: 57 additions & 6 deletions tests/brevitas/graph/test_equalization.py
@@ -10,10 +10,14 @@
from brevitas.graph.equalize import _batch_norm
from brevitas.graph.equalize import _extract_regions
from brevitas.graph.equalize import _is_supported_module
from brevitas.graph.equalize import _supported_layers
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import MergeLnAffine
from brevitas.graph.standardize import DuplicateSharedStatelessModule
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas.graph.utils import get_module
from tests.marker import requires_pt_ge

from .equalization_fixtures import *

@@ -28,14 +32,14 @@ def test_resnet18_equalization():
expected_out = model(inp)

model_orig = copy.deepcopy(model)
regions = _extract_regions(model)
supported_sinks = tuple(
x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm))
regions = _extract_regions(model, state_impl_kwargs={'supported_sinks': supported_sinks})
_ = equalize_test(
regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs')
out = model(inp)

# Check that equalization is not introducing FP variations
assert torch.allclose(expected_out, out, atol=ATOL)

regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names]))
resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0])
equalized_layers = set()
@@ -58,6 +62,9 @@ def test_resnet18_equalization():
orig_module = get_module(model_orig, layer)
assert not torch.allclose(eq_module.weight, orig_module.weight)

# Check that equalization is not introducing FP variations
assert torch.allclose(expected_out, out, atol=ATOL)


@pytest_cases.parametrize("merge_bias", [True, False])
def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool):
@@ -73,7 +80,10 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool

expected_out = model(inp)

regions = _extract_regions(model)
supported_sinks = tuple(
x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm))
regions = _extract_regions(model, state_impl_kwargs={'supported_sinks': supported_sinks})
scale_factor_regions = equalize_test(
regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs')
shape_scale_regions = [scale.shape for scale in scale_factor_regions]
@@ -126,7 +136,10 @@ def test_models(toy_model, merge_bias, request):
expected_out = model(inp)

model = symbolic_trace(model)
regions = _extract_regions(model)
supported_sinks = tuple(
x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm))
regions = _extract_regions(model, state_impl_kwargs={'supported_sinks': supported_sinks})
scale_factor_regions = equalize_test(
regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs')
shape_scale_regions = [scale.shape for scale in scale_factor_regions]
@@ -225,3 +238,41 @@ def test_act_equalization_torchvision_models(model_dict: dict, layerwise: bool):
# Check that at least one region performs "true" equalization
# If all shapes are scalar, no equalization has been performed
assert any([shape != () for shape in shape_scale_regions])


@requires_pt_ge('2.4')
@pytest_cases.parametrize('partial_had', [True, False])
def test_rotation_models(rotation_fixtures, partial_had):

in_shape = IN_SIZE_LINEAR

model_class = rotation_fixtures
model = model_class()
inp = torch.ones(in_shape)

model.eval()
penultimate_weight = model.linear_1.weight.data
last_weight = model.linear_2.weight.data
with torch.no_grad():
expected_out = model(inp)

model = symbolic_trace(model)
merge = MergeLnAffine()
model = merge.apply(model)
eq = GraphRotationEqualization(orphan_sink=partial_had)
model = eq.apply(model)

with torch.no_grad():
out = model(inp)

penultimate_weight_new = model.linear_1.weight.data

# Invariance of the output
assert torch.allclose(out, expected_out, atol=ATOL)
# Rotated weights must be different
assert not torch.allclose(penultimate_weight, penultimate_weight_new)
# Merging affine parameters of RMS
assert torch.allclose(model.rms.weight.data, torch.ones_like(model.rms.weight.data))
if partial_had:
last_weight_new = model.linear_2.layer.weight.data
assert not torch.allclose(last_weight, last_weight_new)