refactor(nn): add QModuleWrapper
Working for weights
dacorvo committed Nov 6, 2024
1 parent 7aaf99e commit 4b732d1
Showing 3 changed files with 291 additions and 1 deletion.
186 changes: 185 additions & 1 deletion optimum/quanto/nn/qmodule.py
@@ -13,6 +13,7 @@
# limitations under the License.

from abc import ABC
from contextlib import nullcontext
from typing import Optional, Union

import torch
@@ -22,6 +23,7 @@
ActivationQBytesTensor,
MaxOptimizer,
Optimizer,
QDynamicTensor,
QTensor,
SymmetricOptimizer,
WeightQBitsTensor,
@@ -35,7 +37,7 @@
)


__all__ = ["QModuleMixin", "register_qmodule", "quantize_module"]
__all__ = ["QModuleMixin", "QModuleWrapper", "register_qmodule", "quantize_module"]


_QMODULE_TABLE = {}
@@ -306,3 +308,185 @@ def freeze(self):
@property
def frozen(self):
return isinstance(self.weight, QTensor)


class QModuleWrapper(torch.nn.Module):
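    """A wrapper that adds weight and/or activation quantization on top of a plain torch.nn.Module."""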

def __init__(self,
module: torch.nn.Module,
weights: Optional[qtype] = None,
activations: Optional[qtype] = None,
optimizer: Optional[Optimizer] = None,
):
super().__init__()
self._wrapped = module
if weights is not None and not isinstance(weights, qtype):
weights = qtypes[weights]
if activations is not None and not isinstance(activations, qtype):
activations = qtypes[activations]
self.weight_qtype = weights
self.weight_group_size = None
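        # For 2/4-bit weights, pick the largest group size <= 128 that is a
        # multiple of 32 and divides in_features exactly (e.g. 4096 -> 128,
        # 160 -> 32); layers with in_features <= 128 stay ungrouped.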
        if self.weight_qtype in (qint2, qint4):
            out_features = module.weight.shape[0]
            in_features = module.weight.numel() // out_features
group_size = 128
if in_features > group_size:
while in_features % group_size != 0 and group_size > 32:
group_size -= 32
if in_features % group_size == 0:
self.weight_group_size = group_size
self.activation_qtype = activations
self._quantize_hooks = {}
if activations is not None:
if weights is not None:
self._quantize_hooks["input"] = self.register_forward_pre_hook(self.quantize_input)
self._quantize_hooks["output"] = self.register_forward_hook(self.quantize_output)
if optimizer is None and self.weight_qtype is not None:
optimizer = AbsmaxOptimizer() if self.weight_qtype.bits == 8 else MaxOptimizer()
self.optimizer = optimizer
weight = getattr(module, "weight", None)
scale_dtype = torch.float32 if weight is None else weight.dtype
device = torch.device('cpu') if weight is None else weight.device
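        # Per-module activation scales used by the quantization hooks below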
self.register_buffer("input_scale", torch.ones((), dtype=scale_dtype, device=device))
self.register_buffer("output_scale", torch.ones((), dtype=scale_dtype, device=device))

def forward(self, *args, **kwargs):
if self.weight_qtype is None or self.frozen:
qcontext = nullcontext()
else:
# The wrapped module weight must be dynamically quantized
qcontext = QDynamicTensor(self._wrapped.weight,
qtype=self.weight_qtype,
axis=0,
optimizer=self.optimizer)
with qcontext:
return self._wrapped.forward(*args, **kwargs)

def disable_output_quantization(self):
if "output" in self._quantize_hooks:
self._quantize_hooks["output"].remove()

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """TODO: fix this
        """
        weight = self._wrapped.weight
        if self.weight_qtype is None or not self.frozen:
            # Save standard weight Tensor
            destination[prefix + "weight"] = weight if keep_vars else weight.detach()
        else:
            # Save QTensor using dedicated method
            weight.save_to_state_dict(destination, prefix + "weight.", keep_vars)
        bias = getattr(self._wrapped, "bias", None)
        if bias is not None:
            destination[prefix + "bias"] = bias if keep_vars else bias.detach()
        destination[prefix + "input_scale"] = self.input_scale if keep_vars else self.input_scale.detach()
        destination[prefix + "output_scale"] = self.output_scale if keep_vars else self.output_scale.detach()

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
"""TODO: fix this
"""
weight_name = prefix + "weight"
if self.weight_qtype is not None and weight_name not in state_dict:
# The weight Tensor is not present because it is a flattened QTensor
weight_prefix = weight_name + "."
# note: deserialized_weight can be None if a key is missing in the state_dict
if self.weight_qtype.bits == 8:
deserialized_weight = WeightQBytesTensor.load_from_state_dict(
state_dict,
weight_prefix,
qtype=self.weight_qtype,
axis=0,
                    size=self._wrapped.weight.size(),
                    stride=self._wrapped.weight.stride(),
activation_qtype=self.activation_qtype,
missing_keys=missing_keys,
)
else:
deserialized_weight = WeightQBitsTensor.load_from_state_dict(
state_dict,
weight_prefix,
qtype=self.weight_qtype,
axis=0,
group_size=self.weight_group_size,
                    size=self._wrapped.weight.size(),
                    stride=self._wrapped.weight.stride(),
missing_keys=missing_keys,
)
if deserialized_weight is not None:
deserialized_weight = deserialized_weight.optimize()

            assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
            if assign_to_params_buffers and (deserialized_weight is not None):
                self._wrapped.weight = torch.nn.Parameter(deserialized_weight)
            elif deserialized_weight is not None:
                if type(self._wrapped.weight.data) is not type(deserialized_weight):
                    # Reloading frozen weights into an unfrozen module: move to the correct device and force assignment
                    self._wrapped.weight = torch.nn.Parameter(deserialized_weight.to(self._wrapped.weight.device))
                else:
                    # FIXME: here we should copy frozen weights into the frozen module, but this leads to a grad error
                    self._wrapped.weight = torch.nn.Parameter(deserialized_weight.to(self._wrapped.weight.device))

super()._load_from_state_dict(
state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs
)

@property
def qweight(self):
"""Return the module quantized weight
When the module is frozen or does not quantize its weight parameter, it simply
returns the weight.
When the module is not frozen, this property is required to add the dynamic quantization
of the weight parameter to the graph and allow gradients to be propagated to the
underlying weight float values.
"""
if isinstance(self._wrapped.weight, QTensor):
# Frozen QModule
return self._wrapped.weight
        # Dynamically quantize the weights per-axis
        if isinstance(self.optimizer, SymmetricOptimizer):
            scale = self.optimizer(self._wrapped.weight, qtype=self.weight_qtype, axis=0)
            shift = None
        else:
            scale, shift = self.optimizer(
                self._wrapped.weight, qtype=self.weight_qtype, axis=0, group_size=self.weight_group_size
            )
return quantize_weight(
self._wrapped.weight,
qtype=self.weight_qtype,
axis=0,
scale=scale,
shift=shift,
group_size=self.weight_group_size,
activation_qtype=self.activation_qtype,
)

def quantize_input(self, module: torch.nn.Module, input: torch.Tensor) -> torch.Tensor:
input = input[0]
if isinstance(input, ActivationQBytesTensor):
if input.qtype != self.activation_qtype:
raise ValueError(
"Models with heterogeneous quantized activations are not supported:"
f" expected {self.activation_qtype.name} input but got {input.qtype.name} instead."
)
else:
input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale)
return input

def quantize_output(
self,
module: torch.nn.Module,
input: torch.Tensor,
output: torch.Tensor,
) -> torch.Tensor:
return quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale)

def freeze(self):
qweight = self.qweight
if qweight is not None:
# Replace float weights by quantized weights
self._wrapped.weight = torch.nn.Parameter(qweight)

@property
def frozen(self):
return isinstance(self._wrapped.weight, QTensor)
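
Below is a minimal usage sketch for the new wrapper (not part of the diff); the
import paths are assumptions based on this repository's layout:

    import torch

    from optimum.quanto import qint8
    from optimum.quanto.nn import QModuleWrapper

    linear = torch.nn.Linear(64, 32)
    qlinear = QModuleWrapper(linear, weights=qint8)

    x = torch.randn(8, 64)
    y = qlinear(x)    # weights are quantized dynamically on each call
    qlinear.freeze()  # replaces the float weights with a static QTensor
    y = qlinear(x)    # now runs with the frozen quantized weights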
1 change: 1 addition & 0 deletions optimum/quanto/tensor/__init__.py
@@ -18,6 +18,7 @@
from .optimizers import *
from .qbits import *
from .qbytes import *
from .qdynamic import *
from .qtensor import *
from .qtype import *
from .weights import *
105 changes: 105 additions & 0 deletions optimum/quanto/tensor/qdynamic.py
@@ -0,0 +1,105 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from torch.overrides import TorchFunctionMode

from .optimizers import AbsmaxOptimizer, MaxOptimizer, SymmetricOptimizer
from .qtensor import QTensor
from .qtype import qint2, qint4, qtype
from .weights import quantize_weight


__all__ = ["QDynamicTensor"]


class QDynamicTensor(TorchFunctionMode):
"""A custom torch dispatch mode that uses dynamically quantized tensors.
Args:
tensor (`torch.Tensor`): the torch.Tensor that will be dynamically quantized
qtype (`qtype`): the qtype to use to quantize the Tensor.
axis (`int`): the quantization axis.
optimizer (`Optimizer`): the optimizer to use to get the quantization parameters.
"""

def __init__(self, tensor: torch.Tensor, qtype: qtype, axis: int, optimizer=None):
super().__init__()
assert not isinstance(tensor, QTensor)
self.tensor = tensor
self.qtype = qtype
self.axis = axis
self.group_size = None
if qtype in (qint2, qint4):
axis_dim = tensor.shape[axis]
other_dim = tensor.numel() // axis_dim
group_size = 128
if other_dim > group_size:
while other_dim % group_size != 0 and group_size > 32:
group_size -= 32
if other_dim % group_size == 0:
self.group_size = group_size
if optimizer is None:
optimizer = AbsmaxOptimizer() if qtype.bits == 8 else MaxOptimizer()
self.optimizer = optimizer

    def qtensor(self, other_qtype: Optional[qtype] = None):
        """Return the dynamically quantized QTensor."""
        # Dynamically quantize the tensor per-axis
if isinstance(self.optimizer, SymmetricOptimizer):
scale = self.optimizer(self.tensor, qtype=self.qtype, axis=self.axis)
shift = None
else:
scale, shift = self.optimizer(
self.tensor, qtype=self.qtype, axis=self.axis, group_size=self.group_size
)
return quantize_weight(
self.tensor,
qtype=self.qtype,
axis=self.axis,
scale=scale,
shift=shift,
group_size=self.group_size,
activation_qtype=other_qtype,
)

def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs is not None else {}
other_qtype = None
new_args = []
for arg in args:
new_arg = arg
if isinstance(arg, QTensor):
if other_qtype is None:
other_qtype = arg.qtype
else:
assert arg.qtype == other_qtype
else:
qtag = getattr(arg, "qtag", None)
if qtag == self.tensor.qtag:
# Replace the tensor by its dynamically quantized version
new_arg = self.qtensor(other_qtype)
new_args.append(new_arg)
return func(*new_args, **kwargs)

def __enter__(self):
super().__enter__()
# Tag the target Tensor to identify it when dispatching
self.tensor.qtag = id(self)

def __exit__(self, exc_type, exc_val, exc_tb):
super().__exit__(exc_type, exc_val, exc_tb)
# Untag the target Tensor
delattr(self.tensor, "qtag")
