@@ -3897,7 +3897,20 @@ def get_memory_footprint(self, return_buffers=True):
     @wraps(torch.nn.Module.cuda)
     def cuda(self, *args, **kwargs):
         if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
-            raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
+            from hqq.core.quantize import HQQLinear
+
+            # Since HQQLinear stores some tensors in the 'meta' attribute,
+            # it's necessary to manually call the `cuda` method on HQQLinear layers.
+            super().cuda(*args, **kwargs)
+            for module in self.modules():
+                if isinstance(module, HQQLinear):
+                    if len(args) > 0:
+                        device = args[0]
+                    else:
+                        device = kwargs.get("device", "cuda")
+                    module.cuda(device)
+            return self
+
         # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
             if getattr(self, "is_loaded_in_8bit", False):
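For context, a minimal usage sketch of the new `.cuda()` behavior for HQQ models. The checkpoint and `HqqConfig` settings below are illustrative placeholders, not taken from this diff:

```python
# Sketch of the new `.cuda()` path for HQQ-quantized models.
import torch
from transformers import AutoModelForCausalLM, HqqConfig

quant_config = HqqConfig(nbits=4, group_size=64)  # assumed example settings
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder checkpoint
    torch_dtype=torch.float16,
    quantization_config=quant_config,
)

# Before this patch: ValueError("`.cuda` is not supported for HQQ-quantized models.")
# After this patch: each HQQLinear layer (including tensors kept in its 'meta'
# attribute) is moved along with the regular parameters.
model = model.cuda()
```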
@@ -3910,8 +3923,7 @@ def cuda(self, *args, **kwargs):
                     "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                 )
-        else:
-            return super().cuda(*args, **kwargs)
+        return super().cuda(*args, **kwargs)
 
     @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
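Note on the control-flow change above: dropping the `else` means `super().cuda()` now also runs for bitsandbytes-quantized models once the 8-bit and version checks pass, instead of only for unquantized models. A sketch of the resulting behavior, with a placeholder checkpoint and assuming bitsandbytes >= 0.43.2 as the error message requires:

```python
# Sketch: with bitsandbytes >= 0.43.2 installed, a 4-bit model now reaches
# `super().cuda()` instead of falling through without being moved.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(load_in_4bit=True)  # illustrative settings
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder checkpoint
    quantization_config=bnb_config,
)
model = model.cuda()  # 8-bit models still raise; 4-bit needs bnb >= 0.43.2
```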
@@ -3926,7 +3938,30 @@ def to(self, *args, **kwargs):
                     break
 
         if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
-            raise ValueError("`.to` is not supported for HQQ-quantized models.")
+            from hqq.core.quantize import HQQLinear
+
+            # Since HQQLinear stores some tensors in the 'meta' attribute, we must
+            # explicitly move the parameters to the target device for each HQQLinear layer after `to`.
+            super().to(*args, **kwargs)
+            for module in self.modules():
+                if isinstance(module, HQQLinear):
+                    if "device" in kwargs:
+                        device = kwargs["device"]
+                    else:
+                        device = args[0]
+                    if "dtype" in kwargs:
+                        dtype = kwargs["dtype"]
+                    elif dtype_present_in_args:
+                        dtype = arg
+                    else:
+                        dtype = None
+                    # Due to the current messy implementation of HQQLinear, updating `compute_dtype`
+                    # followed by calling the `cuda` method achieves the intended behavior of `to`,
+                    # even when the target device is CPU.
+                    if dtype is not None:
+                        module.compute_dtype = dtype
+                    module.cuda(device)
+            return self
 
         if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
             raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")