diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index d68166d5268..6f2c6c194f2 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -361,6 +361,9 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi
 
     Note: We fully disable this if we are using `deepspeed`
     """
+    if model_to_load.device.type == "meta":
+        return False
+
     if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
         return False
 
@@ -375,7 +378,7 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi
         return False
 
     # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
-    first_key = list(model_to_load.state_dict().keys())[0]
+    first_key = next(iter(model_to_load.state_dict().keys()))
     if start_prefix + first_key in state_dict:
         return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
 
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 96a30df7e55..85e7c20dd52 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import copy
 import glob
+import itertools
 import json
 import os
 import os.path
@@ -459,6 +460,19 @@ def test_model_from_config_torch_dtype_str(self):
         with self.assertRaises(ValueError):
             model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
 
+    @require_torch
+    def test_model_from_pretrained_meta_device(self):
+        def is_on_meta(model_id, dtype):
+            with torch.device("meta"):
+                model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
+            return all(value.device.type == "meta" for value in model.state_dict().values())
+
+        model_ids = ("fxmarty/tiny-llama-fast-tokenizer", "fxmarty/small-llama-testing")
+        dtypes = (None, "auto", torch.float16)
+
+        for model_id, dtype in itertools.product(model_ids, dtypes):
+            self.assertTrue(is_on_meta(model_id, dtype))
+
     def test_model_from_pretrained_torch_dtype(self):
         # test that the model can be instantiated with dtype of either
         # 1. explicit from_pretrained's torch_dtype argument