From 4b732d1372edf9a90e161dd403dec9df319a844c Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Wed, 6 Nov 2024 20:24:42 +0100
Subject: [PATCH] refactor(nn): add QModuleWrapper

Working for weights
---
 optimum/quanto/nn/qmodule.py      | 186 +++++++++++++++++++++++++++++-
 optimum/quanto/tensor/__init__.py |   1 +
 optimum/quanto/tensor/qdynamic.py | 105 +++++++++++++++++
 3 files changed, 291 insertions(+), 1 deletion(-)
 create mode 100644 optimum/quanto/tensor/qdynamic.py

diff --git a/optimum/quanto/nn/qmodule.py b/optimum/quanto/nn/qmodule.py
index 932d4ac4..c47b4bf3 100644
--- a/optimum/quanto/nn/qmodule.py
+++ b/optimum/quanto/nn/qmodule.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from abc import ABC
+from contextlib import nullcontext
 from typing import Optional, Union
 
 import torch
@@ -22,6 +23,7 @@
     ActivationQBytesTensor,
     MaxOptimizer,
     Optimizer,
+    QDynamicTensor,
     QTensor,
     SymmetricOptimizer,
     WeightQBitsTensor,
@@ -35,7 +37,7 @@
 )
 
 
-__all__ = ["QModuleMixin", "register_qmodule", "quantize_module"]
+__all__ = ["QModuleMixin", "QModuleWrapper", "register_qmodule", "quantize_module"]
 
 
 _QMODULE_TABLE = {}
@@ -306,3 +308,185 @@ def freeze(self):
     @property
     def frozen(self):
         return isinstance(self.weight, QTensor)
+
+
+class QModuleWrapper(torch.nn.Module):
+
+    def __init__(self,
+                 module: torch.nn.Module,
+                 weights: Optional[qtype] = None,
+                 activations: Optional[qtype] = None,
+                 optimizer: Optional[Optimizer] = None,
+                 ):
+        super().__init__()
+        self._wrapped = module
+        if weights is not None and not isinstance(weights, qtype):
+            weights = qtypes[weights]
+        if activations is not None and not isinstance(activations, qtype):
+            activations = qtypes[activations]
+        self.weight_qtype = weights
+        self.weight_group_size = None
+        if self.weight_qtype in (qint2, qint4):
+            out_features = module.weight.shape[0]
+            in_features = module.weight.numel() // out_features
+            group_size = 128
+            if in_features > group_size:
+                while in_features % group_size != 0 and group_size > 32:
+                    group_size -= 32
+                if in_features % group_size == 0:
+                    self.weight_group_size = group_size
+        self.activation_qtype = activations
+        self._quantize_hooks = {}
+        if activations is not None:
+            if weights is not None:
+                self._quantize_hooks["input"] = self.register_forward_pre_hook(self.quantize_input)
+            self._quantize_hooks["output"] = self.register_forward_hook(self.quantize_output)
+        if optimizer is None and self.weight_qtype is not None:
+            optimizer = AbsmaxOptimizer() if self.weight_qtype.bits == 8 else MaxOptimizer()
+        self.optimizer = optimizer
+        weight = getattr(module, "weight", None)
+        scale_dtype = torch.float32 if weight is None else weight.dtype
+        device = torch.device('cpu') if weight is None else weight.device
+        self.register_buffer("input_scale", torch.ones((), dtype=scale_dtype, device=device))
+        self.register_buffer("output_scale", torch.ones((), dtype=scale_dtype, device=device))
+
+    def forward(self, *args, **kwargs):
+        if self.weight_qtype is None or self.frozen:
+            qcontext = nullcontext()
+        else:
+            # The wrapped module weight must be dynamically quantized
+            qcontext = QDynamicTensor(self._wrapped.weight,
+                                      qtype=self.weight_qtype,
+                                      axis=0,
+                                      optimizer=self.optimizer)
+        with qcontext:
+            return self._wrapped.forward(*args, **kwargs)
+
+    def disable_output_quantization(self):
+        if "output" in self._quantize_hooks:
+            self._quantize_hooks["output"].remove()
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        """TODO: fix this
+        """
+        if self.weight_qtype is None or not self.frozen:
+            # Save standard weight Tensor
+            destination[prefix + "weight"] = self.weight if keep_vars else self.weight.detach()
+        else:
+            # Save QTensor using dedicated method
+            self.weight.save_to_state_dict(destination, prefix + "weight.", keep_vars)
+        if self.bias is not None:
+            destination[prefix + "bias"] = self.bias if keep_vars else self.bias.detach()
+        destination[prefix + "input_scale"] = self.input_scale if keep_vars else self.input_scale.detach()
+        destination[prefix + "output_scale"] = self.output_scale if keep_vars else self.output_scale.detach()
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        """TODO: fix this
+        """
+        weight_name = prefix + "weight"
+        if self.weight_qtype is not None and weight_name not in state_dict:
+            # The weight Tensor is not present because it is a flattened QTensor
+            weight_prefix = weight_name + "."
+            # note: deserialized_weight can be None if a key is missing in the state_dict
+            if self.weight_qtype.bits == 8:
+                deserialized_weight = WeightQBytesTensor.load_from_state_dict(
+                    state_dict,
+                    weight_prefix,
+                    qtype=self.weight_qtype,
+                    axis=0,
+                    size=self.weight.size(),
+                    stride=self.weight.stride(),
+                    activation_qtype=self.activation_qtype,
+                    missing_keys=missing_keys,
+                )
+            else:
+                deserialized_weight = WeightQBitsTensor.load_from_state_dict(
+                    state_dict,
+                    weight_prefix,
+                    qtype=self.weight_qtype,
+                    axis=0,
+                    group_size=self.weight_group_size,
+                    size=self.weight.size(),
+                    stride=self.weight.stride(),
+                    missing_keys=missing_keys,
+                )
+            if deserialized_weight is not None:
+                deserialized_weight = deserialized_weight.optimize()
+
+            assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
+            if assign_to_params_buffers and (deserialized_weight is not None):
+                self.weight = torch.nn.Parameter(deserialized_weight)
+            elif deserialized_weight is not None:
+                if type(self.weight.data) is not type(deserialized_weight):
+                    # Reloading frozen weights into unfrozen module: move to the correct device and force assignment
+                    self.weight = torch.nn.Parameter(deserialized_weight.to(self.weight.device))
+                else:
+                    # FIXME: here we should copy frozen weights into frozen module, but this leads to grad error
+                    self.weight = torch.nn.Parameter(deserialized_weight.to(self.weight.device))
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs
+        )
+
+    @property
+    def qweight(self):
+        """Return the module's quantized weight.
+
+        When the module is frozen or does not quantize its weight parameter, it simply
+        returns the weight.
+        When the module is not frozen, this property is required to add the dynamic quantization
+        of the weight parameter to the graph and allow gradients to be propagated to the
+        underlying weight float values.
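+
+        Example (illustrative sketch only, assuming the qint8 qtype):
+
+            qlinear = QModuleWrapper(torch.nn.Linear(16, 16), weights=qint8)
+            qw = qlinear.qweight  # dynamically quantized QTensor, even before freeze()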
+        """
+        if isinstance(self._wrapped.weight, QTensor):
+            # Frozen QModule
+            return self._wrapped.weight
+        # Quantize dynamically the weights per-axis
+        if isinstance(self.optimizer, SymmetricOptimizer):
+            scale = self.optimizer(self._wrapped.weight, qtype=self.weight_qtype, axis=0)
+            shift = None
+        else:
+            scale, shift = self.optimizer(
+                self._wrapped.weight, qtype=self.weight_qtype, axis=0, group_size=self.weight_group_size
+            )
+        return quantize_weight(
+            self._wrapped.weight,
+            qtype=self.weight_qtype,
+            axis=0,
+            scale=scale,
+            shift=shift,
+            group_size=self.weight_group_size,
+            activation_qtype=self.activation_qtype,
+        )
+
+    def quantize_input(self, module: torch.nn.Module, input: torch.Tensor) -> torch.Tensor:
+        input = input[0]
+        if isinstance(input, ActivationQBytesTensor):
+            if input.qtype != self.activation_qtype:
+                raise ValueError(
+                    "Models with heterogeneous quantized activations are not supported:"
+                    f" expected {self.activation_qtype.name} input but got {input.qtype.name} instead."
+                )
+        else:
+            input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale)
+        return input
+
+    def quantize_output(
+        self,
+        module: torch.nn.Module,
+        input: torch.Tensor,
+        output: torch.Tensor,
+    ) -> torch.Tensor:
+        return quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale)
+
+    def freeze(self):
+        qweight = self.qweight
+        if qweight is not None:
+            # Replace float weights by quantized weights
+            self._wrapped.weight = torch.nn.Parameter(qweight)
+
+    @property
+    def frozen(self):
+        return isinstance(self._wrapped.weight, QTensor)
diff --git a/optimum/quanto/tensor/__init__.py b/optimum/quanto/tensor/__init__.py
index ee5079a9..075cc853 100644
--- a/optimum/quanto/tensor/__init__.py
+++ b/optimum/quanto/tensor/__init__.py
@@ -18,6 +18,7 @@
 from .optimizers import *
 from .qbits import *
 from .qbytes import *
+from .qdynamic import *
 from .qtensor import *
 from .qtype import *
 from .weights import *
diff --git a/optimum/quanto/tensor/qdynamic.py b/optimum/quanto/tensor/qdynamic.py
new file mode 100644
index 00000000..f3509751
--- /dev/null
+++ b/optimum/quanto/tensor/qdynamic.py
@@ -0,0 +1,105 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch.overrides import TorchFunctionMode
+
+from .optimizers import AbsmaxOptimizer, MaxOptimizer, SymmetricOptimizer
+from .qtensor import QTensor
+from .qtype import qint2, qint4, qtype
+from .weights import quantize_weight
+
+
+__all__ = ["QDynamicTensor"]
+
+
+class QDynamicTensor(TorchFunctionMode):
+    """A custom torch function mode that replaces a tensor by its dynamically quantized version in torch operations.
+
+    Args:
+        tensor (`torch.Tensor`): the torch.Tensor that will be dynamically quantized.
+        qtype (`qtype`): the qtype to use to quantize the Tensor.
+        axis (`int`): the quantization axis.
+        optimizer (`Optimizer`): the optimizer to use to get the quantization parameters.
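+
+    Example (illustrative sketch, assuming the qint8 qtype; this mirrors how
+    `QModuleWrapper.forward` wraps the call to the underlying module):
+
+        weight = torch.randn(16, 16)
+        with QDynamicTensor(weight, qtype=qint8, axis=0):
+            # Inside the context, torch functions receiving `weight` are handed
+            # its dynamically quantized version instead of the float tensor.
+            y = torch.nn.functional.linear(torch.randn(4, 16), weight)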
+    """
+
+    def __init__(self, tensor: torch.Tensor, qtype: qtype, axis: int, optimizer=None):
+        super().__init__()
+        assert not isinstance(tensor, QTensor)
+        self.tensor = tensor
+        self.qtype = qtype
+        self.axis = axis
+        self.group_size = None
+        if qtype in (qint2, qint4):
+            axis_dim = tensor.shape[axis]
+            other_dim = tensor.numel() // axis_dim
+            group_size = 128
+            if other_dim > group_size:
+                while other_dim % group_size != 0 and group_size > 32:
+                    group_size -= 32
+                if other_dim % group_size == 0:
+                    self.group_size = group_size
+        if optimizer is None:
+            optimizer = AbsmaxOptimizer() if qtype.bits == 8 else MaxOptimizer()
+        self.optimizer = optimizer
+
+    def qtensor(self, other_qtype: qtype = None):
+        """Return the dynamically quantized QTensor
+        """
+        # Quantize dynamically the tensor per-axis
+        if isinstance(self.optimizer, SymmetricOptimizer):
+            scale = self.optimizer(self.tensor, qtype=self.qtype, axis=self.axis)
+            shift = None
+        else:
+            scale, shift = self.optimizer(
+                self.tensor, qtype=self.qtype, axis=self.axis, group_size=self.group_size
+            )
+        return quantize_weight(
+            self.tensor,
+            qtype=self.qtype,
+            axis=self.axis,
+            scale=scale,
+            shift=shift,
+            group_size=self.group_size,
+            activation_qtype=other_qtype,
+        )
+
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs if kwargs is not None else {}
+        other_qtype = None
+        new_args = []
+        for arg in args:
+            new_arg = arg
+            if isinstance(arg, QTensor):
+                if other_qtype is None:
+                    other_qtype = arg.qtype
+                else:
+                    assert arg.qtype == other_qtype
+            else:
+                qtag = getattr(arg, "qtag", None)
+                if qtag == self.tensor.qtag:
+                    # Replace the tensor by its dynamically quantized version
+                    new_arg = self.qtensor(other_qtype)
+            new_args.append(new_arg)
+        return func(*new_args, **kwargs)
+
+    def __enter__(self):
+        super().__enter__()
+        # Tag the target Tensor to identify it when dispatching
+        self.tensor.qtag = id(self)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        super().__exit__(exc_type, exc_val, exc_tb)
+        # Untag the target Tensor
+        delattr(self.tensor, "qtag")
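
Illustrative usage (not part of the patch): the sketch below assumes the usual qint8 qtype and
that QModuleWrapper ends up re-exported from optimum.quanto.nn alongside the existing
QModuleMixin.

    import torch

    from optimum.quanto import qint8
    from optimum.quanto.nn import QModuleWrapper

    # Wrap a plain module: its weight is quantized dynamically on every forward call
    # (through QDynamicTensor) until freeze() replaces it with a static QTensor.
    qlinear = QModuleWrapper(torch.nn.Linear(64, 64), weights=qint8)

    x = torch.randn(8, 64)
    y = qlinear(x)       # forward pass with a dynamically quantized weight

    qlinear.freeze()     # materialize the quantized weight inside the wrapped module
    assert qlinear.frozen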