File tree Expand file tree Collapse file tree 1 file changed +11
-5
lines changed
src/transformers/integrations Expand file tree Collapse file tree 1 file changed +11
-5
lines changed Original file line number Diff line number Diff line change @@ -56,13 +56,19 @@ def __init__(
56
56
if not hasattr (model .config , "use_cache" ) or model .config .use_cache is False :
57
57
raise ValueError ("The model must have caching enabled to be performant." )
58
58
59
- if not hasattr (model .config , "layer_types " ):
60
- # If `layer_types ` is not specified explicitly in the config, there is only 1 type of layers, so
61
- # export will use `StaticCache` by default.
62
- logging .info ("Using `StaticCache` for export as `layer_types ` is not specified in the config." )
59
+ if not hasattr (model .config , "cache_implementation " ):
60
+ # If `cache_implementation ` is not specified explicitly in the config, `DynamicCache` will
61
+ # be used by default, so export will use `StaticCache` by default.
62
+ logging .info ("Using `StaticCache` for export as `cache_implementation ` is not specified in the config." )
63
63
self .model = TorchExportableModuleWithStaticCache (model )
64
64
else :
65
- self .model = TorchExportableModuleWithHybridCache (model , max_batch_size , max_cache_len )
65
+ if model .config .cache_implementation == "hybrid" :
66
+ self .model = TorchExportableModuleWithHybridCache (model , max_batch_size , max_cache_len )
67
+ else :
68
+ raise ValueError (
69
+ f"Unsupported cache implementation: { model .config .cache_implementation } . "
70
+ "Please use `hybrid` or `static`."
71
+ )
66
72
67
73
def forward (
68
74
self ,
You can’t perform that action at this time.
0 commit comments