From 1ff5a07aa5df4a4c199dd5e534d41d1e0765e00d Mon Sep 17 00:00:00 2001 From: Tibor Reiss Date: Sun, 13 Oct 2024 11:51:28 +0200 Subject: [PATCH 1/5] Do not load for meta device --- src/transformers/modeling_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d68166d5268..7f5a498342f 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -361,6 +361,9 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi Note: We fully disable this if we are using `deepspeed` """ + if model_to_load.device.type == "meta": + return False + if len([key for key in state_dict if key.startswith(start_prefix)]) == 0: return False From 936087cfe05885b1553be8d9d276e0d0848a1069 Mon Sep 17 00:00:00 2001 From: Tibor Reiss Date: Sun, 13 Oct 2024 12:00:20 +0200 Subject: [PATCH 2/5] Make some minor improvements --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7f5a498342f..6f2c6c194f2 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -378,7 +378,7 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi return False # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype - first_key = list(model_to_load.state_dict().keys())[0] + first_key = next(iter(model_to_load.state_dict().keys())) if start_prefix + first_key in state_dict: return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype From 984892590558becd783ba00bc203e101fafb31ec Mon Sep 17 00:00:00 2001 From: Tibor Reiss Date: Tue, 15 Oct 2024 20:59:38 +0200 Subject: [PATCH 3/5] Add test --- tests/utils/test_modeling_utils.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 96a30df7e55..1f725d29658 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -14,6 +14,7 @@ # limitations under the License. import copy import glob +import itertools import json import os import os.path @@ -459,6 +460,24 @@ def test_model_from_config_torch_dtype_str(self): with self.assertRaises(ValueError): model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64") + @slow + def test_model_from_pretrained_meta_device(self): + def is_on_meta(model_id, dtype): + with torch.device("meta"): + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype) + return all(value.device.type == "meta" for value in model.state_dict().values()) + + result = {} + expected = {} + model_ids = ("fxmarty/tiny-llama-fast-tokenizer", "fxmarty/small-llama-testing") + dtypes = (None, "auto", torch.float16) + + for model_id, dtype in itertools.product(model_ids, dtypes): + result[(model_id, dtype)] = is_on_meta(model_id, dtype) + expected[(model_id, dtype)] = True + + assert result == expected + def test_model_from_pretrained_torch_dtype(self): # test that the model can be instantiated with dtype of either # 1. explicit from_pretrained's torch_dtype argument From a83e2346ada3d051cc126d332463ed49cd64f47a Mon Sep 17 00:00:00 2001 From: Tibor Reiss <75096465+tibor-reiss@users.noreply.github.com> Date: Wed, 16 Oct 2024 19:34:09 +0200 Subject: [PATCH 4/5] Update tests/utils/test_modeling_utils.py Update test parameters Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- tests/utils/test_modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 1f725d29658..7b6eca65f1e 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -460,7 +460,7 @@ def test_model_from_config_torch_dtype_str(self): with self.assertRaises(ValueError): model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64") - @slow + @require_torch def test_model_from_pretrained_meta_device(self): def is_on_meta(model_id, dtype): with torch.device("meta"): From b782bfbae64b88f0a40abb155f96bb22a17d972c Mon Sep 17 00:00:00 2001 From: Tibor Reiss Date: Wed, 16 Oct 2024 19:41:48 +0200 Subject: [PATCH 5/5] Make the test simpler --- tests/utils/test_modeling_utils.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 7b6eca65f1e..85e7c20dd52 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -467,16 +467,11 @@ def is_on_meta(model_id, dtype): model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype) return all(value.device.type == "meta" for value in model.state_dict().values()) - result = {} - expected = {} model_ids = ("fxmarty/tiny-llama-fast-tokenizer", "fxmarty/small-llama-testing") dtypes = (None, "auto", torch.float16) for model_id, dtype in itertools.product(model_ids, dtypes): - result[(model_id, dtype)] = is_on_meta(model_id, dtype) - expected[(model_id, dtype)] = True - - assert result == expected + self.assertTrue(is_on_meta(model_id, dtype)) def test_model_from_pretrained_torch_dtype(self): # test that the model can be instantiated with dtype of either