
Commit 9c59b35

committed
check mask
1 parent e1247d1 commit 9c59b35

File tree

4 files changed: +4 -4 lines changed


src/transformers/models/cohere2/modeling_cohere2.py

+1 -1

@@ -259,7 +259,7 @@ def forward(
         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         # Here we need to slice as we use a static cache by default, but FA2 does not support it
-        if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
             seq_len = attention_mask.shape[-1]
             key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

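The change is identical in all four files: the FA2 slice path previously assumed attention_mask was always set, but calling .shape on None raises an AttributeError. A minimal sketch of the added guard, using hypothetical tensor sizes rather than the real transformers internals:

import torch

batch, num_heads, max_cache_len, head_dim = 1, 8, 1024, 64

# With a static cache (the default here), key/value states come back padded
# to the full cache length rather than the actual sequence length.
key_states = torch.zeros(batch, num_heads, max_cache_len, head_dim)
value_states = torch.zeros(batch, num_heads, max_cache_len, head_dim)

# attention_mask can legitimately be None, e.g. when no padding mask is passed in.
attention_mask = None

# Before this commit, attention_mask.shape[-1] would raise AttributeError on None.
# The added `attention_mask is not None` check simply skips the slice in that case.
if attention_mask is not None:
    seq_len = attention_mask.shape[-1]
    key_states = key_states[:, :, :seq_len, :]
    value_states = value_states[:, :, :seq_len, :]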
src/transformers/models/cohere2/modular_cohere2.py

+1 -1

@@ -297,7 +297,7 @@ def forward(
         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         # Here we need to slice as we use a static cache by default, but FA2 does not support it
-        if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
             seq_len = attention_mask.shape[-1]
             key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

src/transformers/models/gemma2/modeling_gemma2.py

+1 -1

@@ -224,7 +224,7 @@ def forward(
         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         # Here we need to slice as we use a static cache by default, but FA2 does not support it
-        if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
             seq_len = attention_mask.shape[-1]
             key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

src/transformers/models/gemma2/modular_gemma2.py

+1 -1

@@ -260,7 +260,7 @@ def forward(
         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         # Here we need to slice as we use a static cache by default, but FA2 does not support it
-        if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
             seq_len = attention_mask.shape[-1]
             key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

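When a mask is present, the slice behaves as before: it trims the padded static-cache states down to the mask length so FA2 never sees the padding. A quick check of the resulting shapes, again with hypothetical sizes:

import torch

batch, num_heads, max_cache_len, head_dim = 1, 8, 1024, 64
key_states = torch.zeros(batch, num_heads, max_cache_len, head_dim)
value_states = torch.zeros(batch, num_heads, max_cache_len, head_dim)

attention_mask = torch.ones(batch, 7)  # mask covering 7 valid positions

if attention_mask is not None:
    seq_len = attention_mask.shape[-1]
    key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

print(key_states.shape)    # torch.Size([1, 8, 7, 64])
print(value_states.shape)  # torch.Size([1, 8, 7, 64])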