@@ -56,19 +56,15 @@ def __init__(
56
56
if not hasattr (model .config , "use_cache" ) or model .config .use_cache is False :
57
57
raise ValueError ("The model must have caching enabled to be performant." )
58
58
59
- if not hasattr (model .config , "cache_implementation" ):
60
- # If `cache_implementation` is not specified explicitly in the config, `DynamicCache` will
61
- # be used by default, so export will use `StaticCache` by default.
62
- logging .info ("Using `StaticCache` for export as `cache_implementation` is not specified in the config." )
63
- self .model = TorchExportableModuleWithStaticCache (model )
59
+ if hasattr (model .config , "layer_types" ) and getattr (model .config , "sliding_window" , None ) is not None :
60
+ self .model = TorchExportableModuleWithHybridCache (model , max_batch_size , max_cache_len )
64
61
else :
65
- if model .config .cache_implementation == "hybrid" :
66
- self .model = TorchExportableModuleWithHybridCache (model , max_batch_size , max_cache_len )
67
- else :
68
- raise ValueError (
69
- f"Unsupported cache implementation: { model .config .cache_implementation } . "
70
- "Please use `hybrid` or `static`."
71
- )
62
+ # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
63
+ # there is only 1 type of layers, so export will use `StaticCache` by default.
64
+ logging .info (
65
+ "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
66
+ )
67
+ self .model = TorchExportableModuleWithStaticCache (model )
72
68
73
69
def forward (
74
70
self ,
@@ -406,12 +402,6 @@ def __init__(
406
402
if not self .model .config .use_cache :
407
403
raise AssertionError ("Model must have caching enabled" )
408
404
409
- if (
410
- not hasattr (self .model .config , "cache_implementation" )
411
- or self .model .config .cache_implementation != "hybrid"
412
- ):
413
- raise AssertionError ("Model must use 'hybrid' cache implementation" )
414
-
415
405
# Initialize the HybridCache
416
406
self .cache = HybridCache (
417
407
config = self .model .config ,
0 commit comments