Commit c5d9f34

Guang Yang (guangy10) authored and committed

Unbreak optimum-executorch
1 parent aa798b7 commit c5d9f34

File tree

9 files changed: +67, -32 lines changed

src/transformers/integrations/executorch.py

Lines changed: 11 additions & 5 deletions

@@ -56,13 +56,19 @@ def __init__(
         if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
             raise ValueError("The model must have caching enabled to be performant.")
 
-        if not hasattr(model.config, "layer_types"):
-            # If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so
-            # export will use `StaticCache` by default.
-            logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.")
+        if not hasattr(model.config, "cache_implementation"):
+            # If `cache_implementation` is not specified explicitly in the config, `DynamicCache` will
+            # be used by default, so export will use `StaticCache` by default.
+            logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.")
             self.model = TorchExportableModuleWithStaticCache(model)
         else:
-            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+            if model.config.cache_implementation == "hybrid":
+                self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+            else:
+                raise ValueError(
+                    f"Unsupported cache implementation: {model.config.cache_implementation}. "
+                    "Please use `hybrid` or `static`."
+                )
 
     def forward(
         self,
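
The wrapper now dispatches on `model.config.cache_implementation` instead of `layer_types`. As a hedged illustration of the three possible outcomes, here is a small standalone mirror of that dispatch; `FakeConfig` and `pick_cache_wrapper` are stand-ins invented for this sketch, not part of transformers.

# Standalone mirror of the dispatch above; FakeConfig stands in for `model.config`.
class FakeConfig:
    def __init__(self, use_cache=True, cache_implementation=None):
        self.use_cache = use_cache
        if cache_implementation is not None:
            self.cache_implementation = cache_implementation


def pick_cache_wrapper(config) -> str:
    if not hasattr(config, "use_cache") or config.use_cache is False:
        raise ValueError("The model must have caching enabled to be performant.")
    if not hasattr(config, "cache_implementation"):
        # No explicit `cache_implementation` -> the StaticCache export path.
        return "TorchExportableModuleWithStaticCache"
    if config.cache_implementation == "hybrid":
        return "TorchExportableModuleWithHybridCache"
    # Note: as in the diff, an explicit "static" value also reaches this error branch.
    raise ValueError(
        f"Unsupported cache implementation: {config.cache_implementation}. Please use `hybrid` or `static`."
    )


print(pick_cache_wrapper(FakeConfig()))                               # StaticCache path
print(pick_cache_wrapper(FakeConfig(cache_implementation="hybrid")))  # HybridCache path
# pick_cache_wrapper(FakeConfig(cache_implementation="sliding_window"))  # raises ValueError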

tests/models/gemma/test_modeling_gemma.py

Lines changed: 4 additions & 2 deletions

@@ -390,7 +390,6 @@ def test_export_static_cache(self):
 
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
@@ -436,7 +435,10 @@ def test_export_static_cache(self):
         self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export()
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
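
The same swap from `convert_and_export_with_cache` to `TorchExportableModuleForDecoderOnlyLM` is repeated in the gemma2, llama, olmo, phi3 and (with a `strict` flag) qwen tests below. Pulled out of the test boilerplate, the new export-then-generate flow looks roughly like this; the checkpoint name, prompt and `GenerationConfig` kwargs are illustrative assumptions that mirror what these tests set up elsewhere, not part of this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.integrations.executorch import (
    TorchExportableModuleForDecoderOnlyLM,
    TorchExportableModuleWithStaticCache,
)

model_id = "google/gemma-2b"  # any of the checkpoints exercised by these tests
max_generation_length = 30

tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="</s>", padding_side="right")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=max_generation_length,
        cache_config={"batch_size": 1, "max_cache_len": max_generation_length},
    ),
)

prompt_token_ids = tokenizer("Hello I am doing", return_tensors="pt").input_ids
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

# New path: wrap the model, export it, then generate with the exported program.
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
print(tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True))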

tests/models/gemma2/test_modeling_gemma2.py

Lines changed: 4 additions & 2 deletions

@@ -313,7 +313,6 @@ def test_export_static_cache(self):
 
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", pad_token="</s>", padding_side="right")
@@ -363,7 +362,10 @@ def test_export_static_cache(self):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export()
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/llama/test_modeling_llama.py

Lines changed: 4 additions & 2 deletions

@@ -306,7 +306,6 @@ def test_export_static_cache(self):
 
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         llama_models = {
@@ -352,7 +351,10 @@ def test_export_static_cache(self):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export()
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/olmo/test_modeling_olmo.py

Lines changed: 4 additions & 2 deletions

@@ -334,7 +334,6 @@ def test_export_static_cache(self):
 
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         olmo_model = "allenai/OLMo-1B-hf"
@@ -382,7 +381,10 @@ def test_export_static_cache(self):
         self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export()
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/phi3/test_modeling_phi3.py

Lines changed: 4 additions & 2 deletions

@@ -347,7 +347,6 @@ def test_export_static_cache(self):
         from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         model_id = "microsoft/Phi-4-mini-instruct"
@@ -399,7 +398,10 @@ def test_export_static_cache(self):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export()
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/qwen2/test_modeling_qwen2.py

Lines changed: 7 additions & 4 deletions

@@ -31,7 +31,6 @@
     slow,
     torch_device,
 )
-from transformers.utils.import_utils import is_torch_greater_or_equal
 
 
 if is_torch_available():
@@ -246,7 +245,6 @@ def test_export_static_cache(self):
 
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         qwen_model = "Qwen/Qwen2-0.5B"
@@ -287,8 +285,13 @@ def test_export_static_cache(self):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        strict = is_torch_greater_or_equal("2.7.0")  # Due to https://github.com/pytorch/pytorch/issues/150994
-        exported_program = convert_and_export_with_cache(model, strict=strict)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        strict = version.parse(torch.__version__) != version.parse(
+            "2.7.0"
+        )  # Due to https://github.com/pytorch/pytorch/issues/150994
+        exported_program = exportable_module.export(strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
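
Both qwen tests drop the `is_torch_greater_or_equal` helper in favour of a direct `packaging.version` comparison, so non-strict export is now applied only on the torch release affected by pytorch/pytorch#150994. Condensed, the version gate is:

import torch
from packaging import version

# Strict export everywhere except torch 2.7.0, which is affected by
# https://github.com/pytorch/pytorch/issues/150994.
strict = version.parse(torch.__version__) != version.parse("2.7.0")

Note that this is an exact-version comparison: builds carrying a local version tag (for example `2.7.0+cu126`) compare unequal under PEP 440 and therefore keep strict export; the commit does not say whether that is intended.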

tests/models/qwen3/test_modeling_qwen3.py

Lines changed: 5 additions & 4 deletions

@@ -31,7 +31,6 @@
    slow,
    torch_device,
 )
-from transformers.utils.import_utils import is_torch_greater_or_equal
 
 
 if is_torch_available():
@@ -240,13 +239,12 @@ def test_export_static_cache(self):
 
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         qwen_model = "Qwen/Qwen3-0.6B-Base"
 
         tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
-        if is_torch_greater_or_equal("2.7.0"):
+        if version.parse(torch.__version__) == version.parse("2.7.0"):
             strict = False  # Due to https://github.com/pytorch/pytorch/issues/150994
             EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated."]
         else:
@@ -285,7 +283,10 @@ def test_export_static_cache(self):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model, strict=strict)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export(strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/utils/test_cache_utils.py

Lines changed: 24 additions & 9 deletions

@@ -15,6 +15,7 @@
 import copy
 import unittest
 
+from packaging import version
 from parameterized import parameterized
 
 from transformers import set_seed
@@ -680,15 +681,27 @@ def test_static_cache_exportability(self):
         self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
         self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
 
-        # Export with dynamic shapes using Dim.AUTO
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
-        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
+        # Export with dynamic shapes
+        input_ids = torch.zeros((1, 3), dtype=torch.long)
+        cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
+        dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
+        strict = version.parse(torch.__version__) != version.parse("2.7.0")
         exported_program = convert_and_export_with_cache(
             model,
             example_input_ids=input_ids,
+            example_cache_position=cache_position,
+            dynamic_shapes=dynamic_shapes,
+            strict=strict,
+        )
+
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export(
+            input_ids=input_ids,
+            cache_position=cache_position,
             dynamic_shapes=dynamic_shapes,
-            strict=False,
+            strict=strict,
         )
 
     def test_hybrid_cache_exportability(self):
@@ -727,13 +740,15 @@ def test_hybrid_cache_exportability(self):
         self.assertEqual(n_g_value_caches, model.config.num_hidden_layers)
 
         # Export with dynamic shapes using Dim.AUTO
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
-        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
+        input_ids = torch.zeros((1, 3), dtype=torch.long)
+        cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
+        dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
+        strict = version.parse(torch.__version__) != version.parse("2.7.0")
         exported_program = exportable_module.export(
             input_ids=input_ids,
+            cache_position=cache_position,
             dynamic_shapes=dynamic_shapes,
-            strict=False,
+            strict=strict,
        )
739754
