From 4d7bd369c3ff7d6a8f0588af4ccffd7fe2e0df1c Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Wed, 18 Sep 2024 17:04:03 +0000 Subject: [PATCH] feat(qtensor): add MarlinQBitsTensor Adding more tests revealed a bug in the Marlin int4 kernel when the weights and inputs are large enough. Failing configurations are marked as xfail. --- .../tensor/weights/marlin/int4/__init__.py | 1 + .../tensor/weights/marlin/int4/qbits.py | 168 ++++++++++++++++++ .../test_marlin_int4_weight_qbits_tensor.py | 150 ++++++++++++++++ 3 files changed, 319 insertions(+) create mode 100644 optimum/quanto/tensor/weights/marlin/int4/qbits.py create mode 100644 test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py diff --git a/optimum/quanto/tensor/weights/marlin/int4/__init__.py b/optimum/quanto/tensor/weights/marlin/int4/__init__.py index 0fdfc89c..2d213c0a 100644 --- a/optimum/quanto/tensor/weights/marlin/int4/__init__.py +++ b/optimum/quanto/tensor/weights/marlin/int4/__init__.py @@ -1 +1,2 @@ from .packed import * +from .qbits import * diff --git a/optimum/quanto/tensor/weights/marlin/int4/qbits.py b/optimum/quanto/tensor/weights/marlin/int4/qbits.py new file mode 100644 index 00000000..4a3aaf1a --- /dev/null +++ b/optimum/quanto/tensor/weights/marlin/int4/qbits.py @@ -0,0 +1,168 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast + +import torch +from torch.autograd import Function + +from ....function import QuantizedLinearFunction +from ....grouped import group, ungroup +from ....qtype import qtypes +from ...qbits import WeightQBitsTensor +from ..permutations import marlin_permute +from .packed import MarlinInt4PackedTensor + + +__all__ = ["MarlinInt4WeightQBitsTensor"] + + +class MarlinQBitsDequantizer(Function): + @staticmethod + def forward(ctx, t): + unpacked = t._data.unpack() + scale = t._scale + shift = t._shift + unpacked = group(unpacked, axis=0, group_size=t._group_size) + # Apply inverted permutations + scale = marlin_permute(scale, reverse=True) + shift = marlin_permute(shift, reverse=True) + n_scales = scale.numel() + scale = scale.t().reshape((n_scales, 1)) + shift = shift.t().reshape((n_scales, 1)) + # Shift is already scaled and negated + dqt = scale * unpacked + shift + return ungroup(dqt, axis=t.axis, orig_shape=t.shape) + + @staticmethod + def backward(ctx, gO): + return gO + + +class MarlinQBitsLinearFunction(QuantizedLinearFunction): + @staticmethod + def forward(ctx, input, other, bias): + ctx.save_for_backward(input, other) + if type(input) is not torch.Tensor: + input = input.dequantize() + out_features, in_features = other.shape + output = torch.ops.quanto.gemm_f16i4_marlin( + input, + other._data._data, + other._scale, + other._shift, + other._workspace, + ) + if bias is not None: + output = output + bias + return output + + +class MarlinInt4WeightQBitsTensor(WeightQBitsTensor): + @staticmethod + def __new__(cls, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False): + assert data.device.type == "cuda" + assert data.device == scale.device + assert data.device == shift.device + return torch.Tensor._make_wrapper_subclass( + cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad + ) + + def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False): + assert axis == 0 + out_features, in_features = size + if not isinstance(data, MarlinInt4PackedTensor): + assert type(data) is torch.Tensor + # Format data, scale and shift for optimized CUDA gemm + ungrouped = ungroup(data, axis=0, orig_shape=size) + data = MarlinInt4PackedTensor.pack(ungrouped) + scale = scale.reshape(out_features, in_features // group_size).t().contiguous() + shift = shift.reshape(out_features, in_features // group_size).t() + if not shift.dtype.is_floating_point: + # Integer shift must be scaled + shift = scale * shift + # Shift must be negated + shift = -shift.contiguous() + # Finally, apply scale and shift permutations + scale = marlin_permute(scale) + shift = marlin_permute(shift) + super().__init__(qtype, axis, group_size, size, stride, data, scale, shift) + self._workspace = torch.zeros(out_features // 128 * 16, dtype=torch.int, device=data.device) + + def dequantize(self): + return MarlinQBitsDequantizer.apply(self) + + def weight_qbits_tensor(self): + """Convert back to a WeightQBitsTensor + + This is required to make sure only standard packing is used when serializing. + """ + data = group(self._data.unpack(), axis=self.axis, group_size=self._group_size) + scale = marlin_permute(self._scale, reverse=True) + shift = marlin_permute(self._shift, reverse=True) + n_scales = scale.numel() + scale = scale.t().reshape((n_scales, 1)) + shift = -shift.t().reshape((n_scales, 1)) + return WeightQBitsTensor( + self._qtype, self._axis, self._group_size, self.size(), self.stride(), data, scale, shift + ) + + def __tensor_flatten__(self): + inner_tensors = ["_data", "_scale", "_shift"] + # Since meta can be used for serialization, use only strings + meta = { + "qtype": self._qtype.name, + "axis": str(self._axis), + "group_size": str(self._group_size), + "size": str(list(self.size())), + "stride": str(list(self.stride())), + } + return inner_tensors, meta + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + assert len(inner_tensors) == 3 + assert len(meta) == 5 + data, scale, shift = inner_tensors["_data"], inner_tensors["_scale"], inner_tensors["_shift"] + # Meta should only contain strings, AST compatible except qtype + qtype = qtypes[meta["qtype"]] + axis = ast.literal_eval(meta["axis"]) + group_size = ast.literal_eval(meta["group_size"]) + size = ast.literal_eval(meta["size"]) + stride = ast.literal_eval(meta["stride"]) + return MarlinInt4WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + """Dispatch torch functions applied on this subtensor + + This method is called whenever a torch function (such as `torch.nn.functional.linear`) + is called with at least one parameter coresponding to this subtensor: + + - if a quantized implementation exists for the selected function, it is called, + - otherwise, the original implementation is called, deactivating further functional dispatch. + + During the execution of the standard torch function, a second-level of dispatch will + happen, but this time directly on individual torch Tensor operations (mainly ATEN). + """ + kwargs = kwargs or {} + if func is torch.nn.functional.linear: + + def qlinear(input, other, bias=None): + return MarlinQBitsLinearFunction.apply(input, other, bias) + + return qlinear(*args, **kwargs) + # Defer to operations dispatcher + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) diff --git a/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py b/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py new file mode 100644 index 00000000..a44db5b2 --- /dev/null +++ b/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py @@ -0,0 +1,150 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from helpers import device_eq, random_qweight +from tensor.weights.weight_helpers import check_weight_qtensor_linear + +from optimum.quanto import qint4 +from optimum.quanto.library.extensions import is_extension_available +from optimum.quanto.tensor.weights import WeightQBitsTensor +from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4WeightQBitsTensor + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8, reason="CUDA >= sm80 not available" +) +@pytest.mark.parametrize("in_features", [128, 256, 512, 1024]) +@pytest.mark.parametrize("out_features", [128, 256, 512, 1024]) +def test_marlin_int4_weight_qbits_tensor_from_qbits_tensor(in_features, out_features): + qtype = qint4 + group_size = 128 + dtype = torch.float16 + shape = (out_features, in_features) + device = torch.device("cuda") + qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=device) + # Create a MarlinInt4WeightQBitsTensor from the WeightQBitsTensor members + marlinqbt = MarlinInt4WeightQBitsTensor( + qtype=qbt.qtype, + axis=qbt.axis, + group_size=qbt._group_size, + size=qbt.size(), + stride=qbt.stride(), + data=qbt._data.unpack(), + scale=qbt._scale, + shift=qbt._shift, + ) + assert marlinqbt.dtype == dtype + assert marlinqbt.qtype == qtype + assert marlinqbt.shape == shape + assert device_eq(marlinqbt.device, device) + # Verify the dequantized tensors are identical + assert torch.equal(marlinqbt.dequantize(), qbt.dequantize()) + # Now verify that we can reconstruct the WeightQBitsTensor + new_qbt = marlinqbt.weight_qbits_tensor() + assert type(new_qbt) is WeightQBitsTensor + assert new_qbt.dtype == dtype + assert new_qbt.qtype == qtype + assert new_qbt.shape == shape + assert torch.equal(new_qbt._data, qbt._data) + assert torch.equal(new_qbt._scale, qbt._scale) + assert torch.equal(new_qbt._shift, qbt._shift) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_marlin_int4_weight_qbits_tensor_move(device): + qtype = qint4 + group_size = 128 + dtype = torch.float16 + shape = (1024, 1024) + device = torch.device("cuda") + # Create an MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA + qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=torch.device("cuda")) + marlinqbt = MarlinInt4WeightQBitsTensor( + qtype=qbt.qtype, + axis=qbt.axis, + group_size=qbt._group_size, + size=qbt.size(), + stride=qbt.stride(), + data=qbt._data.unpack(), + scale=qbt._scale, + shift=qbt._shift, + ) + # Move to device, dequantize and compare + moved_qbt = marlinqbt.to(device) + assert isinstance(moved_qbt, WeightQBitsTensor) + if device.type != "cuda": + assert type(moved_qbt) is not MarlinInt4WeightQBitsTensor + assert marlinqbt.dtype == moved_qbt.dtype + assert marlinqbt.qtype == moved_qbt.qtype + assert marlinqbt.shape == moved_qbt.shape + assert torch.equal(marlinqbt.dequantize().to(device), moved_qbt.dequantize()) + + +def _test_marlin_int4_weight_qbits_tensor_linear( + dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias +): + # Create an MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA + qbt = random_qweight( + (out_features, in_features), weight_qtype, dtype, group_size=group_size, device=torch.device("cuda") + ) + marlin_qweight = MarlinInt4WeightQBitsTensor( + qtype=qbt.qtype, + axis=qbt.axis, + group_size=qbt._group_size, + size=qbt.size(), + stride=qbt.stride(), + data=qbt._data.unpack(), + scale=qbt._scale, + shift=qbt._shift, + ) + check_weight_qtensor_linear(marlin_qweight, batch_size, tokens, use_bias) + + +@pytest.mark.skipif( + not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8, + reason="CUDA >= sm80 not available", +) +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("tokens", [16, 32]) +@pytest.mark.parametrize("in_features", [1024]) +@pytest.mark.parametrize("out_features", [1024, 2048, 4096]) +@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) +def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, in_features, out_features, use_bias): + dtype = torch.float16 + weight_qtype = qint4 + group_size = 128 + _test_marlin_int4_weight_qbits_tensor_linear( + dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias + ) + + +@pytest.mark.xfail(reason="Bug in Marlin kernel", strict=False) +@pytest.mark.skipif( + not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8, + reason="CUDA >= sm80 not available", +) +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("tokens", [48, 64]) +# @pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384]) +@pytest.mark.parametrize("in_features", [4096, 16384]) +@pytest.mark.parametrize("out_features", [2048, 4096]) +def test_marlin_int4_weight_qbits_tensor_linear_failing(batch_size, tokens, in_features, out_features): + dtype = torch.float16 + weight_qtype = qint4 + group_size = 128 + _test_marlin_int4_weight_qbits_tensor_linear( + dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias=False + )