Skip to content

Commit 1501526

Browse files
committed
use static cache if has layer_types but no sliding_window
1 parent c5d9f34 commit 1501526

File tree

1 file changed

+5
-17
lines changed

1 file changed

+5
-17
lines changed

src/transformers/integrations/executorch.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,13 @@ def __init__(
5656
if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
5757
raise ValueError("The model must have caching enabled to be performant.")
5858

59-
if not hasattr(model.config, "cache_implementation"):
60-
# If `cache_implementation` is not specified explicitly in the config, `DynamicCache` will
61-
# be used by default, so export will use `StaticCache` by default.
62-
logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.")
59+
if not hasattr(model.config, "layer_types") and model.config.get("sliding_window", None) is not None:
60+
# If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so
61+
# export will use `StaticCache` by default.
62+
logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.")
6363
self.model = TorchExportableModuleWithStaticCache(model)
6464
else:
65-
if model.config.cache_implementation == "hybrid":
66-
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
67-
else:
68-
raise ValueError(
69-
f"Unsupported cache implementation: {model.config.cache_implementation}. "
70-
"Please use `hybrid` or `static`."
71-
)
65+
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
7266

7367
def forward(
7468
self,
@@ -406,12 +400,6 @@ def __init__(
406400
if not self.model.config.use_cache:
407401
raise AssertionError("Model must have caching enabled")
408402

409-
if (
410-
not hasattr(self.model.config, "cache_implementation")
411-
or self.model.config.cache_implementation != "hybrid"
412-
):
413-
raise AssertionError("Model must use 'hybrid' cache implementation")
414-
415403
# Initialize the HybridCache
416404
self.cache = HybridCache(
417405
config=self.model.config,

0 commit comments

Comments
 (0)