Skip to content

Commit

Permalink
Fix: take into account meta device (huggingface#34134)
Browse files Browse the repository at this point in the history
* Do not load for meta device

* Make some minor improvements

* Add test

* Update tests/utils/test_modeling_utils.py

Update test parameters

Co-authored-by: Marc Sun <[email protected]>

* Make the test simpler

---------

Co-authored-by: Marc Sun <[email protected]>
  • Loading branch information
2 people authored and BernardZach committed Dec 5, 2024
1 parent 2f564c7 commit 9ec16a6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,9 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi
Note: We fully disable this if we are using `deepspeed`
"""
if model_to_load.device.type == "meta":
return False

if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
return False

Expand All @@ -375,7 +378,7 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi
return False

# If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
first_key = list(model_to_load.state_dict().keys())[0]
first_key = next(iter(model_to_load.state_dict().keys()))
if start_prefix + first_key in state_dict:
return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype

Expand Down
14 changes: 14 additions & 0 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
import copy
import glob
import itertools
import json
import os
import os.path
Expand Down Expand Up @@ -459,6 +460,19 @@ def test_model_from_config_torch_dtype_str(self):
with self.assertRaises(ValueError):
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")

@require_torch
def test_model_from_pretrained_meta_device(self):
def is_on_meta(model_id, dtype):
with torch.device("meta"):
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
return all(value.device.type == "meta" for value in model.state_dict().values())

model_ids = ("fxmarty/tiny-llama-fast-tokenizer", "fxmarty/small-llama-testing")
dtypes = (None, "auto", torch.float16)

for model_id, dtype in itertools.product(model_ids, dtypes):
self.assertTrue(is_on_meta(model_id, dtype))

def test_model_from_pretrained_torch_dtype(self):
# test that the model can be instantiated with dtype of either
# 1. explicit from_pretrained's torch_dtype argument
Expand Down

0 comments on commit 9ec16a6

Please sign in to comment.