@@ -196,6 +196,18 @@ def seen_tokens(self):
         else:
             return None
 
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        """
+        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+        the given layer at `layer_idx`.
+        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+        for each layer.
+        """
+        query_length = cache_position.shape[0]
+        past_seen_tokens = self.get_seq_length()
+        kv_length = query_length + past_seen_tokens
+        return kv_length, 0
+
 
 @dataclass
 class CacheConfig:
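To make the new contract concrete, here is a minimal standalone sketch (not part of the diff; `ToyDynamicCache` and its `past_seen_tokens` attribute are invented for illustration) of what the dynamic-cache variant computes: the mask must cover every past token plus the current query block, and since a dynamic cache never evicts anything, the offset is always 0.

```python
import torch

# Toy stand-in for a growing (dynamic) cache; names are illustrative only.
class ToyDynamicCache:
    def __init__(self):
        self.past_seen_tokens = 0

    def get_seq_length(self) -> int:
        return self.past_seen_tokens

    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
        # Same arithmetic as the hunk above: all past tokens plus the query block.
        query_length = cache_position.shape[0]
        return query_length + self.get_seq_length(), 0

cache = ToyDynamicCache()
cache.past_seen_tokens = 10            # pretend 10 tokens were already cached
cache_position = torch.arange(10, 14)  # a 4-token query block
print(cache.get_mask_sizes(cache_position, layer_idx=0))  # (14, 0)
```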
@@ -1084,8 +1096,6 @@ class SinkCache(Cache):
     ```
     """
 
-    is_sliding = True
-
     def __init__(self, window_length: int, num_sink_tokens: int) -> None:
         super().__init__()
         self.key_cache: List[torch.Tensor] = []
@@ -1390,6 +1400,16 @@ def reset(self):
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
 
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        """
+        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+        the given layer at `layer_idx`.
+        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+        for each layer.
+        """
+        kv_length = self.get_max_cache_shape()
+        return kv_length, 0
+
 
 class SlidingWindowCache(StaticCache):
     """
@@ -1446,7 +1466,6 @@ class SlidingWindowCache(StaticCache):
     ```
     """
 
-    is_sliding = True
     is_compileable = True
 
     def __init__(
@@ -1465,6 +1484,7 @@ def __init__(
                 "config and it's not set to None."
             )
         max_cache_len = min(config.sliding_window, max_cache_len)
+        self.sliding_window = config.sliding_window
         super().__init__(
             config=config,
             max_batch_size=max_batch_size,
@@ -1509,6 +1529,21 @@ def reset(self):
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
 
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        """
+        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+        the given layer at `layer_idx`.
+        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+        for each layer.
+        """
+        query_length = cache_position.shape[0]
+        first_cache_position = cache_position[0]
+        # torch.clamp() is equivalent to max() but should be compile-friendly/exportable as first_cache_position is a Tensor
+        kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
+        # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
+        kv_length = max(query_length, self.get_max_cache_shape())
+        return kv_length, kv_offset
+
 
 class EncoderDecoderCache(Cache):
     """
@@ -1761,12 +1796,17 @@ def __init__(
             else config.num_key_value_heads
         )
 
-        layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
-        self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
+        # If the attribute does not exist in the config, fall back to a simple StaticCache
+        if hasattr(config, "layer_types"):
+            self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
+        else:
+            self.is_sliding = [False] * config.num_hidden_layers
+
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim)
+        self.sliding_window = min(config.sliding_window, max_cache_len)
         device = torch.device(device) if device is not None else None
         for i in range(config.num_hidden_layers):
             if layer_device_map is not None:
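For reference, a small standalone sketch of what the new `layer_types` detection yields (the `layer_types` value here is an invented example; any label other than `"full_attention"` counts as sliding):

```python
# Invented example config values, for illustration only.
layer_types = ["sliding_attention", "full_attention", "sliding_attention", "full_attention"]
is_sliding = [layer_type != "full_attention" for layer_type in layer_types]
print(is_sliding)  # [True, False, True, False]

# Without `layer_types`, every layer is treated as full attention, so the
# hybrid cache effectively degenerates to a StaticCache:
num_hidden_layers = 4
print([False] * num_hidden_layers)  # [False, False, False, False]
```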
@@ -1775,7 +1815,7 @@ def __init__(
                 layer_device = device
             # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
             # breaks when updating the cache.
-            cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape
+            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
             new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             torch._dynamo.mark_static_address(new_layer_key_cache)
@@ -1796,7 +1836,7 @@ def update(
         if cache_position is None:
             raise ValueError("`cache_position` must be provided for HybridCache.")
 
-        is_sliding_layer = self.is_sliding_list[layer_idx]
+        is_sliding_layer = self.is_sliding[layer_idx]
 
         # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
         # when the cache is initialized in the forward pass (e.g. Gemma2)
@@ -1843,6 +1883,26 @@ def reset(self):
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
 
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        """
+        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+        the given layer at `layer_idx`.
+        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+        for each layer.
+        """
+        if self.is_sliding[layer_idx]:
+            query_length = cache_position.shape[0]
+            first_cache_position = cache_position[0]
+
+            local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
+            # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
+            local_mask_kv_length = max(query_length, self.sliding_window)
+            return local_mask_kv_length, local_mask_kv_offset
+
+        full_mask_kv_offset = 0
+        full_mask_kv_length = self.get_max_cache_shape()
+        return full_mask_kv_length, full_mask_kv_offset
+
 
 class HybridChunkedCache(Cache):
     """
@@ -1912,11 +1972,11 @@ def __init__(
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self._dtype = dtype
 
-        if hasattr(config.get_text_config(), "no_rope_layers"):
-            self.is_sliding = config.no_rope_layers
+        # If the attribute does not exist in the config, fall back to a simple StaticCache
+        if hasattr(config, "layer_types"):
+            self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
         else:
-            layer_switch = getattr(config, "sliding_window_pattern", 2)
-            self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
+            self.is_sliding = [False] * config.num_hidden_layers
 
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
@@ -1999,11 +2059,7 @@ def update(
         key_states = key_states.to(k_out.dtype)
         value_states = value_states.to(v_out.dtype)
 
-        if self.is_sliding[layer_idx]:
-            update_fn = self._sliding_update
-        else:
-            update_fn = self._static_update
-
+        update_fn = self._sliding_update if self.is_sliding[layer_idx] else self._static_update
         return update_fn(
             cache_position,
             layer_idx,
@@ -2038,6 +2094,37 @@ def reset(self):
            self.value_cache[layer_idx].zero_()
         self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
 
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        """
+        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+        the given layer at `layer_idx`.
+        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+        for each layer.
+        """
+        if self.is_sliding[layer_idx]:
+            query_length = cache_position.shape[0]
+            first_cache_position = cache_position[0]
+
+            local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
+            # This is the true general case for any Cache using local attention (sliding or chunked)
+            if first_cache_position >= self.sliding_window:
+                # Here the Cache is already full
+                local_mask_kv_length = self.sliding_window + query_length - 1
+            elif (
+                first_cache_position < self.sliding_window
+                and first_cache_position + query_length > self.sliding_window
+            ):
+                # Here the Cache becomes full with the new input
+                local_mask_kv_length = first_cache_position + query_length
+            else:
+                # Here the Cache is still smaller than the local size, but we return the local size as it's static
+                local_mask_kv_length = self.sliding_window
+            return local_mask_kv_length, local_mask_kv_offset
+
+        full_mask_kv_offset = 0
+        full_mask_kv_length = self.get_max_cache_shape()
+        return full_mask_kv_length, full_mask_kv_offset
+
 
 class OffloadedHybridCache(HybridChunkedCache):
     def __init__(
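The three `local_mask_kv_length` cases in the chunked hunk above are easiest to verify with concrete numbers. A standalone trace (window and query sizes invented) of the general rule; the explicit `first_cache_position < sliding_window` check from the diff is implied here by failing the first branch:

```python
sliding_window, query_length = 4, 2  # invented sizes, for illustration only
for first_cache_position in [0, 3, 7]:
    if first_cache_position >= sliding_window:
        # Cache already full: the window plus the queries beyond the first
        kv_length = sliding_window + query_length - 1
    elif first_cache_position + query_length > sliding_window:
        # Cache becomes full while this block is processed
        kv_length = first_cache_position + query_length
    else:
        # Cache still smaller than the window; return the static local size
        kv_length = sliding_window
    print(first_cache_position, "->", kv_length)
# 0 -> 4 (still filling), 3 -> 5 (becomes full), 7 -> 5 (already full)
```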