Commit cb0f604

HighCWu and SunMarc authored
Fix HQQ model param device transfer issue (#38466)
* Fix HQQ model param device transfer issue
* modify a comment
* clear the code and add test for hqq device/dtype
* fix test hqq code quality of imports

Co-authored-by: Marc Sun <[email protected]>
1 parent c77bcd8 commit cb0f604

File tree

3 files changed: +85 -4 lines changed
  src/transformers/modeling_utils.py
  src/transformers/quantizers/quantizer_hqq.py
  tests/quantization/hqq/test_hqq.py

src/transformers/modeling_utils.py

Lines changed: 39 additions & 4 deletions
@@ -3897,7 +3897,20 @@ def get_memory_footprint(self, return_buffers=True):
     @wraps(torch.nn.Module.cuda)
     def cuda(self, *args, **kwargs):
         if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
-            raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
+            from hqq.core.quantize import HQQLinear
+
+            # Since HQQLinear stores some tensors in the 'meta' attribute,
+            # it's necessary to manually call the `cuda` method on HQQLinear layers.
+            super().cuda(*args, **kwargs)
+            for module in self.modules():
+                if isinstance(module, HQQLinear):
+                    if len(args) > 0:
+                        device = args[0]
+                    else:
+                        device = kwargs.get("device", "cuda")
+                    module.cuda(device)
+            return self
         # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
             if getattr(self, "is_loaded_in_8bit", False):
@@ -3910,8 +3923,7 @@ def cuda(self, *args, **kwargs):
                     "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                 )
-        else:
-            return super().cuda(*args, **kwargs)
+        return super().cuda(*args, **kwargs)
 
     @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
@@ -3926,7 +3938,30 @@ def to(self, *args, **kwargs):
                     break
 
         if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
-            raise ValueError("`.to` is not supported for HQQ-quantized models.")
+            from hqq.core.quantize import HQQLinear
+
+            # Since HQQLinear stores some tensors in the 'meta' attribute, we must
+            # explicitly move the parameters to the target device for each HQQLinear layer after `to`.
+            super().to(*args, **kwargs)
+            for module in self.modules():
+                if isinstance(module, HQQLinear):
+                    if "device" in kwargs:
+                        device = kwargs["device"]
+                    else:
+                        device = args[0]
+                    if "dtype" in kwargs:
+                        dtype = kwargs["dtype"]
+                    elif dtype_present_in_args:
+                        dtype = arg
+                    else:
+                        dtype = None
+                    # Due to the current messy implementation of HQQLinear, updating `compute_dtype`
+                    # followed by calling the `cuda` method achieves the intended behavior of `to`,
+                    # even when the target device is CPU.
+                    if dtype is not None:
+                        module.compute_dtype = dtype
+                    module.cuda(device)
+            return self
 
         if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
             raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")

src/transformers/quantizers/quantizer_hqq.py

Lines changed: 9 additions & 0 deletions
@@ -202,6 +202,15 @@ def create_quantized_param(
         if is_hqq_available():
             from hqq.core.quantize import HQQLinear
 
+            # TODO: This is a compatibility hack. HQQ-quantized linear layers do not have a `weight` attribute,
+            # but some models attempt to access `weight.dtype` during the forward pass. To prevent runtime errors,
+            # we patch HQQLinear with a dummy `weight` property that returns an empty tensor with the correct dtype and device.
+            @property
+            def weight(_self: HQQLinear):
+                return torch.empty(0, dtype=_self.compute_dtype, device=_self.device)
+
+            HQQLinear.weight = weight
+
         module, tensor_name = get_module_from_name(model, param_name)
         layer_name = ".".join(param_name.split(".")[:-1])
         parent_module = find_parent(model, layer_name)
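The patch above works because assigning a `property` object to a class attribute exposes it on every instance of that class. A standalone sketch of the same pattern, using a hypothetical stand-in class rather than the real HQQLinear:

import torch


class FakeQuantLinear:
    # Hypothetical stand-in: like HQQLinear, it holds no real `weight` tensor,
    # only a compute dtype and a device.
    def __init__(self, compute_dtype=torch.float16, device="cpu"):
        self.compute_dtype = compute_dtype
        self.device = device


@property
def weight(_self):
    # Dummy tensor whose dtype and device callers can inspect safely.
    return torch.empty(0, dtype=_self.compute_dtype, device=_self.device)


# Patching the class (not an instance) makes the property visible on all layers.
FakeQuantLinear.weight = weight

layer = FakeQuantLinear()
print(layer.weight.dtype)  # torch.float16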

tests/quantization/hqq/test_hqq.py

Lines changed: 37 additions & 0 deletions
@@ -15,6 +15,8 @@
 import gc
 import unittest
 
+import accelerate
+
 from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
 from transformers.testing_utils import (
     backend_empty_cache,
@@ -119,6 +121,41 @@ def test_fp16_quantized_model(self):
         check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
         check_forward(self, hqq_runner.model)
 
+    def test_quantized_model_to_new_device_and_new_dtype(self):
+        """
+        Simple LLM model testing different devices and dtypes
+        """
+        quant_config = HqqConfig(nbits=8, group_size=64)
+
+        hqq_runner = HQQLLMRunner(
+            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
+        )
+
+        original_device = hqq_runner.model.model.layers[0].self_attn.v_proj.device
+        check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
+        check_forward(self, hqq_runner.model)
+
+        # Remove `accelerate` hooks so the model can be moved to a new device
+        accelerate.hooks.remove_hook_from_module(hqq_runner.model, recurse=True)
+
+        hqq_runner.model.to("cpu", torch.bfloat16)
+        check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
+        check_forward(self, hqq_runner.model)
+
+        hqq_runner.model.cuda(original_device)
+        check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
+        check_forward(self, hqq_runner.model)
+
+    def test_quantized_model_fake_weight_dtype(self):
+        quant_config = HqqConfig(nbits=8, group_size=64)
+
+        hqq_runner = HQQLLMRunner(
+            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
+        )
+
+        # We use a hack to inject a fake weight into HQQLinear. Check that it works
+        self.assertEqual(hqq_runner.model.model.layers[0].self_attn.v_proj.weight.dtype, torch.float16)
+
 
 @slow
 @require_torch_gpu