
Commit e1247d1

correctly slice
1 parent c23a1c1 commit e1247d1

File tree

src/transformers/models/cohere2/modeling_cohere2.py
src/transformers/models/cohere2/modular_cohere2.py
src/transformers/models/gemma2/modeling_gemma2.py
src/transformers/models/gemma2/modular_gemma2.py

4 files changed: +28 -28 lines changed

src/transformers/models/cohere2/modeling_cohere2.py (+7 -7)
@@ -258,6 +258,11 @@ def forward(
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
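
The added block relies on a simple invariant: with a static cache, `past_key_value.update` returns key/value tensors padded out to the full pre-allocated cache length, while the 2D mask passed to FA2 only covers the tokens seen so far, so slicing to `attention_mask.shape[-1]` drops exactly the unused cache slots. A minimal standalone sketch of that idea with plain tensors (the shapes, `max_cache_len`, and `tokens_seen` below are illustrative, not taken from the models):

import torch

# Illustrative shapes: batch 1, 2 KV heads, head_dim 4.
max_cache_len = 16   # the static cache is pre-allocated to this length
tokens_seen = 5      # tokens actually written into the cache so far

# Pre-allocated static cache: positions >= tokens_seen are still zero padding.
key_states = torch.zeros(1, 2, max_cache_len, 4)
value_states = torch.zeros(1, 2, max_cache_len, 4)
key_states[:, :, :tokens_seen, :] = torch.randn(1, 2, tokens_seen, 4)
value_states[:, :, :tokens_seen, :] = torch.randn(1, 2, tokens_seen, 4)

# With FA2 the mask is 2D [bs, seq_len] and only covers real tokens.
attention_mask = torch.ones(1, tokens_seen, dtype=torch.long)

# Same slicing as in the diff: trim the static cache down to the mask length
# so flash-attn never sees the unused, zero-filled cache slots.
seq_len = attention_mask.shape[-1]
key_states = key_states[:, :, :seq_len, :]
value_states = value_states[:, :, :seq_len, :]

print(key_states.shape, value_states.shape)  # torch.Size([1, 2, 5, 4]) twice
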
@@ -344,18 +349,13 @@ def forward(
         """

         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
-            if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
-            else:
+            # For FA2, the mask is 2D and is of shape [bs, seq_len] (not [bs, fixed_cache_len])
+            if self.config._attn_implementation != "flash_attention_2":
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]

         residual = hidden_states
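
For the non-FA2 implementations this hunk keeps the existing 4D-mask treatment: every key position more than `sliding_window` tokens behind its query is set to the dtype minimum so softmax effectively ignores it, while FA2 keeps its 2D [bs, seq_len] mask untouched. A small self-contained sketch of that masking step (toy `seq_len` and `sliding_window`, zero-initialized mask; the values are illustrative only):

import torch

sliding_window = 3   # illustrative window size
seq_len = 6
min_dtype = torch.finfo(torch.float32).min

# 4D additive mask as used by eager/SDPA: [bs, 1, q_len, kv_len], 0 = attend.
attention_mask = torch.zeros(1, 1, seq_len, seq_len)

# Same construction as in the diff: True marks key positions that are more
# than `sliding_window` tokens behind the query and must be masked out.
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)

# Row i is True at columns j with j <= i - sliding_window, i.e. tokens outside
# the window; the FA2 path skips this and keeps its 2D [bs, seq_len] mask.
print((attention_mask == min_dtype).squeeze())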

src/transformers/models/cohere2/modular_cohere2.py (+7 -7)
@@ -296,6 +296,11 @@ def forward(
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -363,18 +368,13 @@ def forward(
         """

         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
-            if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
-            else:
+            # For FA2, the mask is 2D and is of shape [bs, seq_len] (not [bs, fixed_cache_len])
+            if self.config._attn_implementation != "flash_attention_2":
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]

         residual = hidden_states

src/transformers/models/gemma2/modeling_gemma2.py (+7 -7)
@@ -223,6 +223,11 @@ def forward(
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -278,18 +283,13 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
-            if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
-            else:
+            # For FA2, the mask is 2D and is of shape [bs, seq_len] (not [bs, fixed_cache_len])
+            if self.config._attn_implementation != "flash_attention_2":
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]

         residual = hidden_states

src/transformers/models/gemma2/modular_gemma2.py (+7 -7)
@@ -259,6 +259,11 @@ def forward(
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -314,18 +319,13 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
-            if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
-            else:
+            # For FA2, the mask is 2D and is of shape [bs, seq_len] (not [bs, fixed_cache_len])
+            if self.config._attn_implementation != "flash_attention_2":
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]

         residual = hidden_states
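
With this change, generation under `attn_implementation="flash_attention_2"` runs against the statically allocated cache and simply slices it to the mask length per layer. A hedged end-to-end usage sketch (the checkpoint, dtype, and prompt are illustrative; it assumes a CUDA GPU with flash-attn installed):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b"  # illustrative checkpoint; any Gemma2/Cohere2 model applies
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # code path touched by this commit
    device_map="auto",
)

inputs = tokenizer("Sliding-window attention lets the model", return_tensors="pt").to(model.device)
# Gemma2/Cohere2 allocate a fixed-length (hybrid/static) cache during generate(),
# which is why the FA2 path above slices key/value states to the mask length.
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))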
