Skip to content

add support for cudnn sdpa checkpoint #1805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions MaxText/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ def cudnn_jax_flash_attention(
if model_mode == MODEL_MODE_AUTOREGRESSIVE:
lengths = jnp.sum(decoder_segment_ids, axis=-1)

return dot_product_attention(
output, lse = dot_product_attention(
query,
key,
value,
Expand All @@ -901,7 +901,7 @@ def cudnn_jax_flash_attention(
return_residual=True
)
else:
return dot_product_attention(
output, lse = dot_product_attention(
query,
key,
value,
Expand All @@ -911,6 +911,9 @@ def cudnn_jax_flash_attention(
qkv_layout="BTNH",
return_residual=True
)
output = checkpoint_name(output, "context")
lse = checkpoint_name(lse, "context")
return output, lse
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this match the return type hint of Array?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ya this is good point, its failing our internal lint, probably need to change type hint


def compute_local_attention(
self, attn_weights: Array, value: Array | KVTensor, q_seq_len: int, model_mode: str
Expand Down
Loading