Skip to content

Commit 76fd034

Browse files
committed
use static cache if has layer_types but no sliding_window
1 parent c5d9f34 commit 76fd034

File tree

1 file changed

+8
-18
lines changed

1 file changed

+8
-18
lines changed

src/transformers/integrations/executorch.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,15 @@ def __init__(
5656
if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
5757
raise ValueError("The model must have caching enabled to be performant.")
5858

59-
if not hasattr(model.config, "cache_implementation"):
60-
# If `cache_implementation` is not specified explicitly in the config, `DynamicCache` will
61-
# be used by default, so export will use `StaticCache` by default.
62-
logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.")
63-
self.model = TorchExportableModuleWithStaticCache(model)
59+
if hasattr(model.config, "layer_types") and getattr(model.config, "sliding_window", None) is not None:
60+
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
6461
else:
65-
if model.config.cache_implementation == "hybrid":
66-
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
67-
else:
68-
raise ValueError(
69-
f"Unsupported cache implementation: {model.config.cache_implementation}. "
70-
"Please use `hybrid` or `static`."
71-
)
62+
# If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
63+
# there is only 1 type of layers, so export will use `StaticCache` by default.
64+
logging.info(
65+
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
66+
)
67+
self.model = TorchExportableModuleWithStaticCache(model)
7268

7369
def forward(
7470
self,
@@ -406,12 +402,6 @@ def __init__(
406402
if not self.model.config.use_cache:
407403
raise AssertionError("Model must have caching enabled")
408404

409-
if (
410-
not hasattr(self.model.config, "cache_implementation")
411-
or self.model.config.cache_implementation != "hybrid"
412-
):
413-
raise AssertionError("Model must use 'hybrid' cache implementation")
414-
415405
# Initialize the HybridCache
416406
self.cache = HybridCache(
417407
config=self.model.config,

0 commit comments

Comments
 (0)