@@ -1649,16 +1649,37 @@ def test_causal_lm_training_multi_gpu(self):
         Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set
         correctly.
         """
+        device_map = {
+            "model.decoder.embed_tokens": 0,
+            "lm_head": 0,
+            "model.decoder.embed_positions": 0,
+            "model.decoder.project_out": 0,
+            "model.decoder.project_in": 0,
+            "model.decoder.layers.0": 0,
+            "model.decoder.layers.1": 0,
+            "model.decoder.layers.2": 0,
+            "model.decoder.layers.3": 0,
+            "model.decoder.layers.4": 0,
+            "model.decoder.layers.5": 0,
+            "model.decoder.layers.6": 1,
+            "model.decoder.layers.7": 1,
+            "model.decoder.layers.8": 1,
+            "model.decoder.layers.9": 1,
+            "model.decoder.layers.10": 1,
+            "model.decoder.layers.11": 1,
+            "model.decoder.final_layer_norm": 1,
+        }

         with tempfile.TemporaryDirectory() as tmp_dir:
             model = AutoModelForCausalLM.from_pretrained(
                 self.causal_lm_model_id,
                 torch_dtype=torch.float16,
-                device_map="auto",
+                device_map=device_map,
                 quantization_config=self.quantization_config,
             )

             assert set(model.hf_device_map.values()) == set(range(device_count))
+            assert {p.device.index for p in model.parameters()} == set(range(device_count))

             model = prepare_model_for_kbit_training(model)

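Why an explicit device_map: with device_map="auto", accelerate is free to pack the whole model onto a single GPU when it fits, in which case set(model.hf_device_map.values()) is {0} and the multi-GPU assertion fails even though the adapters are fine. The map above pins the embeddings, projections and LM head plus the first six decoder layers to GPU 0 and the rest to GPU 1. A minimal standalone sketch of the same split (the checkpoint "facebook/opt-125m" is an assumption; the module names match an OPT-style 12-layer decoder):

    import torch
    from transformers import AutoModelForCausalLM

    # Embeddings, input/output projections and the LM head go to GPU 0;
    # decoder layers are split 6/6; the final layer norm lands on GPU 1.
    device_map = {f"model.decoder.layers.{i}": int(i >= 6) for i in range(12)}
    device_map.update({
        "model.decoder.embed_tokens": 0,
        "lm_head": 0,
        "model.decoder.embed_positions": 0,
        "model.decoder.project_out": 0,
        "model.decoder.project_in": 0,
        "model.decoder.final_layer_norm": 1,
    })

    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",  # assumption: any OPT-style checkpoint with 12 decoder layers
        torch_dtype=torch.float16,
        device_map=device_map,
    )
    # Both checks together: the planned placement covers both GPUs, and the
    # materialized parameters actually live on them.
    assert set(model.hf_device_map.values()) == {0, 1}
    assert {p.device.index for p in model.parameters()} == {0, 1}
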
@@ -3182,14 +3203,35 @@ def test_causal_lm_training_multi_gpu(self):
         Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set
         correctly.
         """
+        device_map = {
+            "model.decoder.embed_tokens": 0,
+            "lm_head": 0,
+            "model.decoder.embed_positions": 0,
+            "model.decoder.project_out": 0,
+            "model.decoder.project_in": 0,
+            "model.decoder.layers.0": 0,
+            "model.decoder.layers.1": 0,
+            "model.decoder.layers.2": 0,
+            "model.decoder.layers.3": 0,
+            "model.decoder.layers.4": 0,
+            "model.decoder.layers.5": 0,
+            "model.decoder.layers.6": 1,
+            "model.decoder.layers.7": 1,
+            "model.decoder.layers.8": 1,
+            "model.decoder.layers.9": 1,
+            "model.decoder.layers.10": 1,
+            "model.decoder.layers.11": 1,
+            "model.decoder.final_layer_norm": 1,
+        }

         with tempfile.TemporaryDirectory() as tmp_dir:
             model = AutoModelForCausalLM.from_pretrained(
                 self.causal_lm_model_id,
-                device_map="auto",
+                device_map=device_map,
             )

             assert set(model.hf_device_map.values()) == set(range(device_count))
+            assert {p.device.index for p in model.parameters()} == set(range(device_count))

             model = prepare_model_for_kbit_training(model)

@@ -3579,16 +3621,38 @@ def test_causal_lm_training_single_gpu_torchao_int4_raises(self):
     def test_causal_lm_training_multi_gpu_torchao(self, quant_type):
         from transformers import TorchAoConfig

+        device_map = {
+            "model.decoder.embed_tokens": 0,
+            "lm_head": 0,
+            "model.decoder.embed_positions": 0,
+            "model.decoder.project_out": 0,
+            "model.decoder.project_in": 0,
+            "model.decoder.layers.0": 0,
+            "model.decoder.layers.1": 0,
+            "model.decoder.layers.2": 0,
+            "model.decoder.layers.3": 0,
+            "model.decoder.layers.4": 0,
+            "model.decoder.layers.5": 0,
+            "model.decoder.layers.6": 1,
+            "model.decoder.layers.7": 1,
+            "model.decoder.layers.8": 1,
+            "model.decoder.layers.9": 1,
+            "model.decoder.layers.10": 1,
+            "model.decoder.layers.11": 1,
+            "model.decoder.final_layer_norm": 1,
+        }
+
         with tempfile.TemporaryDirectory() as tmp_dir:
             quantization_config = TorchAoConfig(quant_type=quant_type)
             model = AutoModelForCausalLM.from_pretrained(
                 self.causal_lm_model_id,
-                device_map="auto",
+                device_map=device_map,
                 quantization_config=quantization_config,
                 torch_dtype=torch.bfloat16,
             )

             assert set(model.hf_device_map.values()) == set(range(device_count))
+            assert {p.device.index for p in model.parameters()} == set(range(device_count))

             model = prepare_model_for_kbit_training(model)
             model.model_parallel = True
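The torchao hunks apply the same placement; only the quantization config differs. A hedged sketch of loading with torchao int8 weight-only quantization on top of the explicit split (same assumed checkpoint as above; requires the torchao package):

    import torch
    from transformers import AutoModelForCausalLM, TorchAoConfig

    # Rebuild the 6/6 split from the previous sketch.
    device_map = {f"model.decoder.layers.{i}": int(i >= 6) for i in range(12)}
    device_map.update({
        "model.decoder.embed_tokens": 0, "lm_head": 0,
        "model.decoder.embed_positions": 0,
        "model.decoder.project_out": 0, "model.decoder.project_in": 0,
        "model.decoder.final_layer_norm": 1,
    })

    quantization_config = TorchAoConfig(quant_type="int8_weight_only")
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",  # assumption, as above
        device_map=device_map,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
    )
    # Verify the quantized parameters actually span both GPUs.
    assert {p.device.index for p in model.parameters()} == {0, 1}
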
@@ -3640,15 +3704,36 @@ def test_causal_lm_training_multi_gpu_torchao_int4_raises(self):
         # TODO: Once proper torchao support for int4 is added, remove this test and add int4 to supported_quant_types
         from transformers import TorchAoConfig

+        device_map = {
+            "model.decoder.embed_tokens": 0,
+            "lm_head": 0,
+            "model.decoder.embed_positions": 0,
+            "model.decoder.project_out": 0,
+            "model.decoder.project_in": 0,
+            "model.decoder.layers.0": 0,
+            "model.decoder.layers.1": 0,
+            "model.decoder.layers.2": 0,
+            "model.decoder.layers.3": 0,
+            "model.decoder.layers.4": 0,
+            "model.decoder.layers.5": 0,
+            "model.decoder.layers.6": 1,
+            "model.decoder.layers.7": 1,
+            "model.decoder.layers.8": 1,
+            "model.decoder.layers.9": 1,
+            "model.decoder.layers.10": 1,
+            "model.decoder.layers.11": 1,
+            "model.decoder.final_layer_norm": 1,
+        }
         quantization_config = TorchAoConfig(quant_type="int4_weight_only")
         model = AutoModelForCausalLM.from_pretrained(
             self.causal_lm_model_id,
-            device_map="auto",
+            device_map=device_map,
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )

         assert set(model.hf_device_map.values()) == set(range(device_count))
+        assert {p.device.index for p in model.parameters()} == set(range(device_count))

         model = prepare_model_for_kbit_training(model)
         model.model_parallel = True