
Commit 8f11b60

Author: Guang Yang
Parent: 1094dd3

Unbreak optimum-executorch

1 file changed: src/transformers/integrations/executorch.py (11 additions, 5 deletions)
@@ -56,13 +56,19 @@ def __init__(
         if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
             raise ValueError("The model must have caching enabled to be performant.")
 
-        if not hasattr(model.config, "layer_types"):
-            # If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so
-            # export will use `StaticCache` by default.
-            logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.")
+        if not hasattr(model.config, "cache_implementation"):
+            # If `cache_implementation` is not specified explicitly in the config, `DynamicCache` will
+            # be used by default, so export will use `StaticCache` by default.
+            logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.")
             self.model = TorchExportableModuleWithStaticCache(model)
         else:
-            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+            if model.config.cache_implementation == "hybrid":
+                self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+            else:
+                raise ValueError(
+                    f"Unsupported cache implementation: {model.config.cache_implementation}. "
+                    "Please use `hybrid` or `static`."
+                )
 
     def forward(
         self,
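
The net effect of the change is a three-way dispatch on the model config: a config without a `cache_implementation` attribute selects `TorchExportableModuleWithStaticCache`, the value "hybrid" selects `TorchExportableModuleWithHybridCache`, and any other value now raises a ValueError. Below is a minimal, hypothetical sketch of that dispatch as a standalone helper; `wrap_for_export` and the import path are assumptions for illustration, while the wrapper classes, config attributes, and messages come from the diff above.

# Hypothetical helper mirroring the dispatch introduced by this commit.
# The import path is assumed from the file being patched.
import logging

from transformers.integrations.executorch import (
    TorchExportableModuleWithHybridCache,
    TorchExportableModuleWithStaticCache,
)


def wrap_for_export(model, max_batch_size, max_cache_len):
    """Pick an exportable wrapper based on the model's cache configuration."""
    if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
        raise ValueError("The model must have caching enabled to be performant.")

    if not hasattr(model.config, "cache_implementation"):
        # No explicit `cache_implementation` in the config: export falls back
        # to `StaticCache`.
        logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.")
        return TorchExportableModuleWithStaticCache(model)

    if model.config.cache_implementation == "hybrid":
        return TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)

    raise ValueError(
        f"Unsupported cache implementation: {model.config.cache_implementation}. "
        "Please use `hybrid` or `static`."
    )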
