Skip to content

Commit 4c82bff

Browse files
FIX Multi GPU tests: explicit device map (huggingface#2484)
Some multi-GPU tests had device_map="auto", but recent changes in accelerate resulted in parameters being moved to a single device. The device map is now set explicitly to avoid that. A more rigorous check is also added to ensure that the parameters really are spread across multiple devices.
1 parent 87cffd5 commit 4c82bff

File tree

1 file changed

+89
-4
lines changed

1 file changed

+89
-4
lines changed

tests/test_gpu_examples.py

Lines changed: 89 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1649,16 +1649,37 @@ def test_causal_lm_training_multi_gpu(self):
16491649
Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set
16501650
correctly.
16511651
"""
1652+
device_map = {
1653+
"model.decoder.embed_tokens": 0,
1654+
"lm_head": 0,
1655+
"model.decoder.embed_positions": 0,
1656+
"model.decoder.project_out": 0,
1657+
"model.decoder.project_in": 0,
1658+
"model.decoder.layers.0": 0,
1659+
"model.decoder.layers.1": 0,
1660+
"model.decoder.layers.2": 0,
1661+
"model.decoder.layers.3": 0,
1662+
"model.decoder.layers.4": 0,
1663+
"model.decoder.layers.5": 0,
1664+
"model.decoder.layers.6": 1,
1665+
"model.decoder.layers.7": 1,
1666+
"model.decoder.layers.8": 1,
1667+
"model.decoder.layers.9": 1,
1668+
"model.decoder.layers.10": 1,
1669+
"model.decoder.layers.11": 1,
1670+
"model.decoder.final_layer_norm": 1,
1671+
}
16521672

16531673
with tempfile.TemporaryDirectory() as tmp_dir:
16541674
model = AutoModelForCausalLM.from_pretrained(
16551675
self.causal_lm_model_id,
16561676
torch_dtype=torch.float16,
1657-
device_map="auto",
1677+
device_map=device_map,
16581678
quantization_config=self.quantization_config,
16591679
)
16601680

16611681
assert set(model.hf_device_map.values()) == set(range(device_count))
1682+
assert {p.device.index for p in model.parameters()} == set(range(device_count))
16621683

16631684
model = prepare_model_for_kbit_training(model)
16641685

@@ -3182,14 +3203,35 @@ def test_causal_lm_training_multi_gpu(self):
31823203
Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set
31833204
correctly.
31843205
"""
3206+
device_map = {
3207+
"model.decoder.embed_tokens": 0,
3208+
"lm_head": 0,
3209+
"model.decoder.embed_positions": 0,
3210+
"model.decoder.project_out": 0,
3211+
"model.decoder.project_in": 0,
3212+
"model.decoder.layers.0": 0,
3213+
"model.decoder.layers.1": 0,
3214+
"model.decoder.layers.2": 0,
3215+
"model.decoder.layers.3": 0,
3216+
"model.decoder.layers.4": 0,
3217+
"model.decoder.layers.5": 0,
3218+
"model.decoder.layers.6": 1,
3219+
"model.decoder.layers.7": 1,
3220+
"model.decoder.layers.8": 1,
3221+
"model.decoder.layers.9": 1,
3222+
"model.decoder.layers.10": 1,
3223+
"model.decoder.layers.11": 1,
3224+
"model.decoder.final_layer_norm": 1,
3225+
}
31853226

31863227
with tempfile.TemporaryDirectory() as tmp_dir:
31873228
model = AutoModelForCausalLM.from_pretrained(
31883229
self.causal_lm_model_id,
3189-
device_map="auto",
3230+
device_map=device_map,
31903231
)
31913232

31923233
assert set(model.hf_device_map.values()) == set(range(device_count))
3234+
assert {p.device.index for p in model.parameters()} == set(range(device_count))
31933235

31943236
model = prepare_model_for_kbit_training(model)
31953237

@@ -3579,16 +3621,38 @@ def test_causal_lm_training_single_gpu_torchao_int4_raises(self):
35793621
def test_causal_lm_training_multi_gpu_torchao(self, quant_type):
35803622
from transformers import TorchAoConfig
35813623

3624+
device_map = {
3625+
"model.decoder.embed_tokens": 0,
3626+
"lm_head": 0,
3627+
"model.decoder.embed_positions": 0,
3628+
"model.decoder.project_out": 0,
3629+
"model.decoder.project_in": 0,
3630+
"model.decoder.layers.0": 0,
3631+
"model.decoder.layers.1": 0,
3632+
"model.decoder.layers.2": 0,
3633+
"model.decoder.layers.3": 0,
3634+
"model.decoder.layers.4": 0,
3635+
"model.decoder.layers.5": 0,
3636+
"model.decoder.layers.6": 1,
3637+
"model.decoder.layers.7": 1,
3638+
"model.decoder.layers.8": 1,
3639+
"model.decoder.layers.9": 1,
3640+
"model.decoder.layers.10": 1,
3641+
"model.decoder.layers.11": 1,
3642+
"model.decoder.final_layer_norm": 1,
3643+
}
3644+
35823645
with tempfile.TemporaryDirectory() as tmp_dir:
35833646
quantization_config = TorchAoConfig(quant_type=quant_type)
35843647
model = AutoModelForCausalLM.from_pretrained(
35853648
self.causal_lm_model_id,
3586-
device_map="auto",
3649+
device_map=device_map,
35873650
quantization_config=quantization_config,
35883651
torch_dtype=torch.bfloat16,
35893652
)
35903653

35913654
assert set(model.hf_device_map.values()) == set(range(device_count))
3655+
assert {p.device.index for p in model.parameters()} == set(range(device_count))
35923656

35933657
model = prepare_model_for_kbit_training(model)
35943658
model.model_parallel = True
@@ -3640,15 +3704,36 @@ def test_causal_lm_training_multi_gpu_torchao_int4_raises(self):
36403704
# TODO: Once proper torchao support for int4 is added, remove this test and add int4 to supported_quant_types
36413705
from transformers import TorchAoConfig
36423706

3707+
device_map = {
3708+
"model.decoder.embed_tokens": 0,
3709+
"lm_head": 0,
3710+
"model.decoder.embed_positions": 0,
3711+
"model.decoder.project_out": 0,
3712+
"model.decoder.project_in": 0,
3713+
"model.decoder.layers.0": 0,
3714+
"model.decoder.layers.1": 0,
3715+
"model.decoder.layers.2": 0,
3716+
"model.decoder.layers.3": 0,
3717+
"model.decoder.layers.4": 0,
3718+
"model.decoder.layers.5": 0,
3719+
"model.decoder.layers.6": 1,
3720+
"model.decoder.layers.7": 1,
3721+
"model.decoder.layers.8": 1,
3722+
"model.decoder.layers.9": 1,
3723+
"model.decoder.layers.10": 1,
3724+
"model.decoder.layers.11": 1,
3725+
"model.decoder.final_layer_norm": 1,
3726+
}
36433727
quantization_config = TorchAoConfig(quant_type="int4_weight_only")
36443728
model = AutoModelForCausalLM.from_pretrained(
36453729
self.causal_lm_model_id,
3646-
device_map="auto",
3730+
device_map=device_map,
36473731
quantization_config=quantization_config,
36483732
torch_dtype=torch.bfloat16,
36493733
)
36503734

36513735
assert set(model.hf_device_map.values()) == set(range(device_count))
3736+
assert {p.device.index for p in model.parameters()} == set(range(device_count))
36523737

36533738
model = prepare_model_for_kbit_training(model)
36543739
model.model_parallel = True

0 commit comments

Comments (0)