Skip to content

Commit 5d9cf58

Browse files
michaelbenayounBernardZach
authored andcommitted
Fix torch.fx issue related to the new loss_kwargs keyword argument (huggingface#34380)
* Fix FX * Unskip tests
1 parent 2bfe54b commit 5d9cf58

File tree

6 files changed

+1
-6
lines changed

6 files changed

+1
-6
lines changed

src/transformers/utils/fx.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1416,7 +1416,7 @@ def keys(self, obj: "Proxy") -> Any:
14161416
your custom tracer.
14171417
"""
14181418
attribute = HFAttribute(obj, "keys")()
1419-
if obj.node.target == "**kwargs":
1419+
if obj.node.target.startswith("**"):
14201420
return attribute._metadata
14211421
return attribute
14221422

tests/models/cohere/test_modeling_cohere.py

-1
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,6 @@ def test_model_various_embeddings(self):
304304
config_and_inputs[0].position_embedding_type = type
305305
self.model_tester.create_and_check_model(*config_and_inputs)
306306

307-
@unittest.skip(reason="PR #34283 made changes to the forward function.")
308307
def test_torch_fx_output_loss(self):
309308
super().test_torch_fx_output_loss()
310309

tests/models/mistral/test_modeling_mistral.py

-1
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,6 @@ def test_model_various_embeddings(self):
356356
config_and_inputs[0].position_embedding_type = type
357357
self.model_tester.create_and_check_model(*config_and_inputs)
358358

359-
@unittest.skip(reason="PR #34283 made changes to the forward function.")
360359
def test_torch_fx_output_loss(self):
361360
super().test_torch_fx_output_loss()
362361

tests/models/mixtral/test_modeling_mixtral.py

-1
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,6 @@ def test_model_various_embeddings(self):
356356
config_and_inputs[0].position_embedding_type = type
357357
self.model_tester.create_and_check_model(*config_and_inputs)
358358

359-
@unittest.skip(reason="PR #34283 made changes to the forward function.")
360359
def test_torch_fx_output_loss(self):
361360
super().test_torch_fx_output_loss()
362361

tests/models/qwen2/test_modeling_qwen2.py

-1
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,6 @@ def test_model_various_embeddings(self):
368368
config_and_inputs[0].position_embedding_type = type
369369
self.model_tester.create_and_check_model(*config_and_inputs)
370370

371-
@unittest.skip(reason="PR #34283 made changes to the forward function.")
372371
def test_torch_fx_output_loss(self):
373372
super().test_torch_fx_output_loss()
374373

tests/models/qwen2_moe/test_modeling_qwen2_moe.py

-1
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,6 @@ def test_model_various_embeddings(self):
391391
config_and_inputs[0].position_embedding_type = type
392392
self.model_tester.create_and_check_model(*config_and_inputs)
393393

394-
@unittest.skip(reason="PR #34283 made changes to the forward function.")
395394
def test_torch_fx_output_loss(self):
396395
super().test_torch_fx_output_loss()
397396

0 commit comments

Comments
 (0)