diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 7b6eca65f1e..85e7c20dd52 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -467,16 +467,11 @@ def is_on_meta(model_id, dtype): model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype) return all(value.device.type == "meta" for value in model.state_dict().values()) - result = {} - expected = {} model_ids = ("fxmarty/tiny-llama-fast-tokenizer", "fxmarty/small-llama-testing") dtypes = (None, "auto", torch.float16) for model_id, dtype in itertools.product(model_ids, dtypes): - result[(model_id, dtype)] = is_on_meta(model_id, dtype) - expected[(model_id, dtype)] = True - - assert result == expected + self.assertTrue(is_on_meta(model_id, dtype)) def test_model_from_pretrained_torch_dtype(self): # test that the model can be instantiated with dtype of either