Skip to content

Commit 559e1d9

Browse files
committed
Minor changes to SINQ integration
1 parent 7cc0c19 commit 559e1d9

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

src/transformers/quantizers/quantizer_sinq.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,15 @@ def _normalize_cuda_device(dev: Optional[Union[str, int]]) -> str:
2929
- 'cpu' -> 'cpu'
3030
- 'cuda', 'cuda:0', 0, 1 -> canonicalized CUDA forms
3131
"""
32-
# Handle 'auto' or None
3332
if dev is None or dev == "auto":
34-
return "cuda" if torch.cuda.is_available() else "cpu"
33+
if torch.cuda.is_available():
34+
# use the currently selected CUDA device; fall back to index 0 if no devices are reported
35+
idx = torch.cuda.current_device() if torch.cuda.device_count() else 0
36+
return f"cuda:{idx}"
37+
return "cpu"
38+
39+
if dev == "cuda":
40+
return "cuda:0" # explicit index
3541

3642
# Explicit CPU
3743
if dev == "cpu":
@@ -84,7 +90,6 @@ def _walk(v):
8490
_walk(dmap)
8591
return out
8692

87-
8893
# ------------------------------------------------------------------------------------
8994
# SINQ dynamic import
9095
# ------------------------------------------------------------------------------------
@@ -169,9 +174,14 @@ def validate_environment(self, dtype=None, device_map=None, weights_only=None, *
169174

170175
device_str = _normalize_cuda_device(getattr(cfg, "device", "auto"))
171176
_validate_cuda_device_str(device_str)
177+
178+
print(f'Device string is: {device_str}')
172179

173180
self._normalized_device_str = device_str
174181

182+
if device_str.startswith("cuda"):
183+
torch.cuda.set_device(torch.device(device_str))
184+
175185
# Not supported: multi-GPU sharding via device_map
176186
devs = _flatten_device_map(device_map)
177187
if devs:
@@ -234,6 +244,7 @@ def _process_model_before_weight_loading(self, model: nn.Module, **kwargs) -> tu
234244
compute_dtype = self.update_dtype(None)
235245
to_skip = set(cfg.modules_to_not_convert or [])
236246
device_str = getattr(self, "_normalized_device_str", _normalize_cuda_device(cfg.device))
247+
print(f'Device string in process model before_weights: {device_str}')
237248

238249
def _convert(m: nn.Module, prefix: str = ""):
239250
for child_name, child in list(m.named_children()):
@@ -245,7 +256,7 @@ def _convert(m: nn.Module, prefix: str = ""):
245256
use_bias=(child.bias is not None),
246257
sinq_quant_dict=sinq_quant_dict,
247258
compute_dtype=compute_dtype,
248-
device_str="cpu",
259+
device_str=device_str,
249260
)
250261
setattr(m, child_name, ph)
251262
else:
@@ -268,8 +279,12 @@ def _process_model_after_weight_loading(
268279
method = str(getattr(cfg, "method", "sinq")).lower()
269280
device_str = getattr(self, "_normalized_device_str", _normalize_cuda_device(cfg.device))
270281
device = torch.device(device_str)
282+
if device.type == "cuda":
283+
torch.cuda.set_device(device)
271284
model = model
272285

286+
print(f'Device string in process model after weights: {device_str}')
287+
273288
placeholders: list[tuple[str, _SinqLoadTimeLinear, nn.Module]] = []
274289

275290
def _gather(m: nn.Module, prefix: str = "", parent: Optional[nn.Module] = None):
@@ -478,4 +493,4 @@ def _lookup_acts(name: str) -> Optional[torch.Tensor]:
478493
if total > 0:
479494
logger.info(f"A-SINQ applied to {num_asinq}/{total} Linear layers ({100.0 * num_asinq / total:.1f}%).")
480495

481-
return model
496+
return model

0 commit comments

Comments
 (0)