Skip to content

Unbreak optimum-executorch #38646

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions src/transformers/integrations/executorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,15 @@ def __init__(
if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
raise ValueError("The model must have caching enabled to be performant.")

if not hasattr(model.config, "layer_types"):
# If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so
# export will use `StaticCache` by default.
logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.")
self.model = TorchExportableModuleWithStaticCache(model)
else:
if hasattr(model.config, "layer_types") and getattr(model.config, "sliding_window", None) is not None:
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
else:
# If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
# there is only 1 type of layers, so export will use `StaticCache` by default.
logging.info(
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
)
self.model = TorchExportableModuleWithStaticCache(model)

def forward(
self,
Expand Down Expand Up @@ -400,12 +402,6 @@ def __init__(
if not self.model.config.use_cache:
raise AssertionError("Model must have caching enabled")

if (
not hasattr(self.model.config, "cache_implementation")
or self.model.config.cache_implementation != "hybrid"
):
raise AssertionError("Model must use 'hybrid' cache implementation")

# Initialize the HybridCache
self.cache = HybridCache(
config=self.model.config,
Expand Down
6 changes: 4 additions & 2 deletions tests/models/gemma/test_modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,6 @@ def test_export_static_cache(self):

from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
Expand Down Expand Up @@ -436,7 +435,10 @@ def test_export_static_cache(self):
self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)

# Static Cache + export
exported_program = convert_and_export_with_cache(model)
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
Expand Down
6 changes: 4 additions & 2 deletions tests/models/gemma2/test_modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ def test_export_static_cache(self):

from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", pad_token="</s>", padding_side="right")
Expand Down Expand Up @@ -363,7 +362,10 @@ def test_export_static_cache(self):
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

# Static Cache + export
exported_program = convert_and_export_with_cache(model)
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
Expand Down
6 changes: 4 additions & 2 deletions tests/models/llama/test_modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ def test_export_static_cache(self):

from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
)

llama_models = {
Expand Down Expand Up @@ -352,7 +351,10 @@ def test_export_static_cache(self):
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

# Static Cache + export
exported_program = convert_and_export_with_cache(model)
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
Expand Down
6 changes: 4 additions & 2 deletions tests/models/olmo/test_modeling_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,6 @@ def test_export_static_cache(self):

from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
)

olmo_model = "allenai/OLMo-1B-hf"
Expand Down Expand Up @@ -382,7 +381,10 @@ def test_export_static_cache(self):
self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)

# Static Cache + export
exported_program = convert_and_export_with_cache(model)
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
Expand Down
6 changes: 4 additions & 2 deletions tests/models/phi3/test_modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,6 @@ def test_export_static_cache(self):
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
)

model_id = "microsoft/Phi-4-mini-instruct"
Expand Down Expand Up @@ -399,7 +398,10 @@ def test_export_static_cache(self):
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

# Static Cache + export
exported_program = convert_and_export_with_cache(model)
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
Expand Down
11 changes: 7 additions & 4 deletions tests/models/qwen2/test_modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
slow,
torch_device,
)
from transformers.utils.import_utils import is_torch_greater_or_equal


if is_torch_available():
Expand Down Expand Up @@ -246,7 +245,6 @@ def test_export_static_cache(self):

from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
)

qwen_model = "Qwen/Qwen2-0.5B"
Expand Down Expand Up @@ -287,8 +285,13 @@ def test_export_static_cache(self):
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

# Static Cache + export
strict = is_torch_greater_or_equal("2.7.0") # Due to https://github.com/pytorch/pytorch/issues/150994
exported_program = convert_and_export_with_cache(model, strict=strict)
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
strict = version.parse(torch.__version__) != version.parse(
"2.7.0"
) # Due to https://github.com/pytorch/pytorch/issues/150994
exported_program = exportable_module.export(strict=strict)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
Expand Down
9 changes: 5 additions & 4 deletions tests/models/qwen3/test_modeling_qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
slow,
torch_device,
)
from transformers.utils.import_utils import is_torch_greater_or_equal


if is_torch_available():
Expand Down Expand Up @@ -240,13 +239,12 @@ def test_export_static_cache(self):

from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
)

qwen_model = "Qwen/Qwen3-0.6B-Base"

tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
if is_torch_greater_or_equal("2.7.0"):
if version.parse(torch.__version__) == version.parse("2.7.0"):
strict = False # Due to https://github.com/pytorch/pytorch/issues/150994
EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated."]
else:
Expand Down Expand Up @@ -285,7 +283,10 @@ def test_export_static_cache(self):
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

# Static Cache + export
exported_program = convert_and_export_with_cache(model, strict=strict)
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(strict=strict)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
Expand Down
33 changes: 24 additions & 9 deletions tests/utils/test_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import copy
import unittest

from packaging import version
from parameterized import parameterized

from transformers import set_seed
Expand Down Expand Up @@ -680,15 +681,27 @@ def test_static_cache_exportability(self):
self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)

# Export with dynamic shapes using Dim.AUTO
tokenizer = AutoTokenizer.from_pretrained(model_id)
input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
# Export with dynamic shapes
input_ids = torch.zeros((1, 3), dtype=torch.long)
cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
strict = version.parse(torch.__version__) != version.parse("2.7.0")
exported_program = convert_and_export_with_cache(
model,
example_input_ids=input_ids,
example_cache_position=cache_position,
dynamic_shapes=dynamic_shapes,
strict=strict,
)

from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
input_ids=input_ids,
cache_position=cache_position,
dynamic_shapes=dynamic_shapes,
strict=False,
strict=strict,
)

def test_hybrid_cache_exportability(self):
Expand Down Expand Up @@ -727,13 +740,15 @@ def test_hybrid_cache_exportability(self):
self.assertEqual(n_g_value_caches, model.config.num_hidden_layers)

# Export with dynamic shapes using Dim.AUTO
tokenizer = AutoTokenizer.from_pretrained(model_id)
input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
input_ids = torch.zeros((1, 3), dtype=torch.long)
cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
strict = version.parse(torch.__version__) != version.parse("2.7.0")
exported_program = exportable_module.export(
input_ids=input_ids,
cache_position=cache_position,
dynamic_shapes=dynamic_shapes,
strict=False,
strict=strict,
)


Expand Down