@@ -29,9 +29,15 @@ def _normalize_cuda_device(dev: Optional[Union[str, int]]) -> str:
     - 'cpu' -> 'cpu'
     - 'cuda', 'cuda:0', 0, 1 -> canonicalized CUDA forms
     """
-    # Handle 'auto' or None
     if dev is None or dev == "auto":
-        return "cuda" if torch.cuda.is_available() else "cpu"
+        if torch.cuda.is_available():
+            # Pick the current device if one is set, otherwise default to 0.
+            idx = torch.cuda.current_device() if torch.cuda.device_count() else 0
+            return f"cuda:{idx}"
+        return "cpu"
+
+    if dev == "cuda":
+        return "cuda:0"  # pin the bare "cuda" alias to an explicit index
 
     # Explicit CPU
     if dev == "cpu":
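For reference, a sketch of how the normalization now behaves (hypothetical calls; the index follows `torch.cuda.current_device()` at call time):

```python
_normalize_cuda_device(None)    # -> "cuda:0" if CUDA is available, else "cpu"
_normalize_cuda_device("auto")  # -> "cuda:0" (same path as None)
_normalize_cuda_device("cuda")  # -> "cuda:0" (bare alias now pinned to an index)
_normalize_cuda_device("cpu")   # -> "cpu"
```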
@@ -84,7 +90,6 @@ def _walk(v):
     _walk(dmap)
     return out
 
-
 # ------------------------------------------------------------------------------------
 # SINQ dynamic import
 # ------------------------------------------------------------------------------------
@@ -169,9 +174,14 @@ def validate_environment(self, dtype=None, device_map=None, weights_only=None, *
 
         device_str = _normalize_cuda_device(getattr(cfg, "device", "auto"))
         _validate_cuda_device_str(device_str)
+
+        logger.debug(f"Device string is: {device_str}")
 
         self._normalized_device_str = device_str
 
+        if device_str.startswith("cuda"):
+            torch.cuda.set_device(torch.device(device_str))
+
         # Not supported: multi-GPU sharding via device_map
         devs = _flatten_device_map(device_map)
         if devs:
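Pinning the current device during `validate_environment` matters because index-less `"cuda"` allocations made later resolve against PyTorch's current device. A minimal illustration of that behavior (assumes a machine with at least two GPUs):

```python
import torch

torch.cuda.set_device(torch.device("cuda:1"))
x = torch.empty(4, device="cuda")  # no index given
print(x.device)                    # cuda:1 -- follows the current device
```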
@@ -234,6 +244,7 @@ def _process_model_before_weight_loading(self, model: nn.Module, **kwargs) -> tu
         compute_dtype = self.update_dtype(None)
         to_skip = set(cfg.modules_to_not_convert or [])
         device_str = getattr(self, "_normalized_device_str", _normalize_cuda_device(cfg.device))
+        logger.debug(f"Device string in _process_model_before_weight_loading: {device_str}")
 
         def _convert(m: nn.Module, prefix: str = ""):
             for child_name, child in list(m.named_children()):
@@ -245,7 +256,7 @@ def _convert(m: nn.Module, prefix: str = ""):
                         use_bias=(child.bias is not None),
                         sinq_quant_dict=sinq_quant_dict,
                         compute_dtype=compute_dtype,
-                        device_str="cpu",
+                        device_str=device_str,
                     )
                     setattr(m, child_name, ph)
                 else:
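Passing the normalized `device_str` instead of the hard-coded `"cpu"` means each placeholder is constructed directly on the target device. A rough sketch of the difference, using `nn.Linear` as a stand-in for `_SinqLoadTimeLinear` (whose internals aren't shown in this diff):

```python
import torch.nn as nn

# Before: build on CPU, pay a host-to-device copy later.
ph = nn.Linear(4096, 4096, device="cpu").to("cuda:0")

# After: allocate parameters on the target device from the start.
ph = nn.Linear(4096, 4096, device="cuda:0")
```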
@@ -268,8 +279,12 @@ def _process_model_after_weight_loading(
         method = str(getattr(cfg, "method", "sinq")).lower()
         device_str = getattr(self, "_normalized_device_str", _normalize_cuda_device(cfg.device))
         device = torch.device(device_str)
+        if device.type == "cuda":
+            torch.cuda.set_device(device)
         model = model
 
+        logger.debug(f"Device string in _process_model_after_weight_loading: {device_str}")
+
         placeholders: list[tuple[str, _SinqLoadTimeLinear, nn.Module]] = []
 
         def _gather(m: nn.Module, prefix: str = "", parent: Optional[nn.Module] = None):
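`_gather` pairs with `_convert` from the earlier hunk: it walks the loaded model and records every placeholder together with its parent so the real quantized module can be swapped in afterwards. A minimal sketch of that collection pattern (names mirror the diff; `_SinqLoadTimeLinear` is the placeholder class):

```python
placeholders = []

def _gather_sketch(m, prefix="", parent=None):
    # Record (qualified_name, placeholder, parent) so setattr(parent, name, real)
    # can replace each placeholder once quantized weights are ready.
    for name, child in m.named_children():
        full = f"{prefix}.{name}" if prefix else name
        if isinstance(child, _SinqLoadTimeLinear):
            placeholders.append((full, child, m))
        else:
            _gather_sketch(child, full, m)
```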
@@ -478,4 +493,4 @@ def _lookup_acts(name: str) -> Optional[torch.Tensor]:
         if total > 0:
             logger.info(f"A-SINQ applied to {num_asinq}/{total} Linear layers ({100.0 * num_asinq / total:.1f}%).")
 
-        return model
+        return model