@@ -56,19 +56,13 @@ def __init__(
56
56
if not hasattr (model .config , "use_cache" ) or model .config .use_cache is False :
57
57
raise ValueError ("The model must have caching enabled to be performant." )
58
58
59
- if not hasattr (model .config , "cache_implementation" ) :
60
- # If `cache_implementation ` is not specified explicitly in the config, `DynamicCache` will
61
- # be used by default, so export will use `StaticCache` by default.
62
- logging .info ("Using `StaticCache` for export as `cache_implementation ` is not specified in the config." )
59
+ if not hasattr (model .config , "layer_types" ) and model . config . get ( "sliding_window" , None ) is not None :
60
+ # If `layer_types ` is not specified explicitly in the config, there is only 1 type of layers, so
61
+ # export will use `StaticCache` by default.
62
+ logging .info ("Using `StaticCache` for export as `layer_types ` is not specified in the config." )
63
63
self .model = TorchExportableModuleWithStaticCache (model )
64
64
else :
65
- if model .config .cache_implementation == "hybrid" :
66
- self .model = TorchExportableModuleWithHybridCache (model , max_batch_size , max_cache_len )
67
- else :
68
- raise ValueError (
69
- f"Unsupported cache implementation: { model .config .cache_implementation } . "
70
- "Please use `hybrid` or `static`."
71
- )
65
+ self .model = TorchExportableModuleWithHybridCache (model , max_batch_size , max_cache_len )
72
66
73
67
def forward (
74
68
self ,
@@ -406,12 +400,6 @@ def __init__(
406
400
if not self .model .config .use_cache :
407
401
raise AssertionError ("Model must have caching enabled" )
408
402
409
- if (
410
- not hasattr (self .model .config , "cache_implementation" )
411
- or self .model .config .cache_implementation != "hybrid"
412
- ):
413
- raise AssertionError ("Model must use 'hybrid' cache implementation" )
414
-
415
403
# Initialize the HybridCache
416
404
self .cache = HybridCache (
417
405
config = self .model .config ,
0 commit comments