
Commit 5a557f4

adding require_hqq

1 parent 9f90e44 · commit 5a557f4

File tree

3 files changed: +16 -4 lines changed

src/transformers/testing_utils.py
src/transformers/utils/quantization_config.py
tests/quantization/hqq/test_hqq.py

src/transformers/testing_utils.py

+8
@@ -87,6 +87,7 @@
     is_gguf_available,
     is_grokadamw_available,
     is_hadamard_available,
+    is_hqq_available,
     is_ipex_available,
     is_jieba_available,
     is_jinja_available,
@@ -1213,6 +1214,13 @@ def require_auto_gptq(test_case):
     return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)
 
 
+def require_hqq(test_case):
+    """
+    Decorator for hqq dependency
+    """
+    return unittest.skipUnless(is_hqq_available(), "test requires hqq")(test_case)
+
+
 def require_auto_awq(test_case):
     """
     Decorator for auto_awq dependency
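A minimal usage sketch (not part of this commit): require_hqq simply wraps unittest.skipUnless around is_hqq_available(), so a decorated test runs only when hqq is installed and is otherwise reported as skipped with the reason "test requires hqq". The test class below is hypothetical and only illustrates the decorator.

import unittest

from transformers.testing_utils import require_hqq


@require_hqq
class ExampleHqqTest(unittest.TestCase):
    # Hypothetical test class for illustration only.
    def test_runs_only_with_hqq(self):
        import hqq  # safe: the decorator lets this run only when hqq is importable

        self.assertIsNotNone(hqq)


if __name__ == "__main__":
    unittest.main()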

src/transformers/utils/quantization_config.py

+3 -4
@@ -225,10 +225,9 @@ def __init__(
         if is_hqq_available():
             from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
         else:
-
-            class HQQBaseQuantizeConfig:
-                def __init__(self, *args, **kwargs):
-                    pass
+            raise ImportError(
+                "A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
+            )
 
         for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]:
             if deprecated_key in kwargs:
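A short sketch (not part of this commit) of the resulting behavior change: without hqq installed, building an HqqConfig now fails loudly with ImportError instead of silently falling back to a do-nothing stand-in for BaseQuantizeConfig. The nbits/group_size arguments are just illustrative values.

from transformers import HqqConfig

try:
    quant_config = HqqConfig(nbits=4, group_size=64)
except ImportError as err:
    # Raised when is_hqq_available() returns False; the message points to the
    # install instructions at https://github.com/mobiusml/hqq/.
    print(f"HQQ quantization needs the hqq package: {err}")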

tests/quantization/hqq/test_hqq.py

+5
@@ -23,6 +23,7 @@
     require_torch_multi_gpu,
     slow,
     torch_device,
+    require_hqq,
 )
 from transformers.utils import is_hqq_available, is_torch_available
 
@@ -86,6 +87,7 @@ def check_forward(test_module, model, batch_size=1, context_size=1024):
 
 
 @require_torch_gpu
+@require_hqq
 class HqqConfigTest(unittest.TestCase):
     def test_to_dict(self):
         """
@@ -100,6 +102,7 @@ def test_to_dict(self):
 @slow
 @require_torch_gpu
 @require_accelerate
+@require_hqq
 class HQQTest(unittest.TestCase):
     def tearDown(self):
         cleanup()
@@ -122,6 +125,7 @@ def test_fp16_quantized_model(self):
 @require_torch_gpu
 @require_torch_multi_gpu
 @require_accelerate
+@require_hqq
 class HQQTestMultiGPU(unittest.TestCase):
     def tearDown(self):
         cleanup()
@@ -144,6 +148,7 @@ def test_fp16_quantized_model_multipgpu(self):
 @slow
 @require_torch_gpu
 @require_accelerate
+@require_hqq
 class HQQSerializationTest(unittest.TestCase):
     def tearDown(self):
         cleanup()
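With require_hqq stacked on every HQQ test class, running this module on a machine without hqq (for example via python -m pytest tests/quantization/hqq/test_hqq.py) should report the classes above as skipped with the reason "test requires hqq" rather than failing during setup.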
