Conversation

@michaelfeil (Contributor) commented on Dec 17, 2025

What does this PR do?

This PR brings varlen flash-attention to CPU/Metal. (It's not softmax-fused / Flash), but at least it's not padded, so we don't OOM.
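
At a high level, the idea is to keep q/k/v packed along the token axis with no padding and run attention per sequence using cumulative sequence lengths. Below is a minimal candle sketch of that idea; the helper name, shapes, and the omission of masking are assumptions for illustration, not the code in this PR.

use candle_core::{Result, Tensor};

// Sketch of varlen attention: iterate over cumulative sequence lengths so that
// q/k/v stay packed as (total_tokens, num_heads, head_dim) with no padding.
// cu_seqlens holds cumulative lengths, e.g. [0, 3, 8] for sequences of length 3 and 5.
fn naive_varlen_attention(q: &Tensor, k: &Tensor, v: &Tensor, cu_seqlens: &[usize]) -> Result<Tensor> {
    let mut outputs = Vec::new();
    for w in cu_seqlens.windows(2) {
        let (start, len) = (w[0], w[1] - w[0]);
        // Slice one sequence out of the packed batch and move heads first: (heads, len, dim).
        let qi = q.narrow(0, start, len)?.transpose(0, 1)?.contiguous()?;
        let ki = k.narrow(0, start, len)?.transpose(0, 1)?.contiguous()?;
        let vi = v.narrow(0, start, len)?.transpose(0, 1)?.contiguous()?;
        let scale = (qi.dim(2)? as f64).sqrt();
        // (heads, len, len) scores for this sequence only, so memory scales with sum(len_i^2),
        // not batch * max_len^2 as it would with padding. Masking omitted for brevity.
        let scores = (qi.matmul(&ki.t()?.contiguous()?)? / scale)?;
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        // Back to (len, heads, dim) so outputs can be re-packed along the token axis.
        outputs.push(probs.matmul(&vi)?.transpose(0, 1)?.contiguous()?);
    }
    Tensor::cat(&outputs, 0)
}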

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests? If applicable, did you include or update the insta snapshots?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kozistr (Contributor) left a comment

aside from one thing, everything else looks good to me!

Comment on lines +114 to +115
let causal_mask = create_causal_mask_batch(seq_len_q, seq_len_k, num_heads, device)?;
attention_scores = attention_scores.add(&causal_mask)?;
@kozistr (Contributor)

I was just wondering: it looks like causal_mask and window_mask below are always fp32, while attention_scores could be fp16. I'm not sure if I'm right, but it might fail due to a dtype mismatch!

@michaelfeil (Contributor, Author)

Thanks, you are right. I figured that out too.

@michaelfeil (Contributor, Author)

In this case it's always fp32, but for the Apple Metal backend it could be fp16, AFAIK.
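
For reference, a sketch of the kind of fix discussed here: cast the mask to the scores' dtype before adding it. This mirrors the snippet above; the cast is the assumed change, not a quote from the PR.

// Cast the mask to the scores' dtype (e.g. F16 on Metal) before the add,
// so the element-wise addition does not fail on a dtype mismatch.
let causal_mask = create_causal_mask_batch(seq_len_q, seq_len_k, num_heads, device)?
    .to_dtype(attention_scores.dtype())?;
attention_scores = attention_scores.add(&causal_mask)?;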

@michaelfeil (Contributor, Author)

I am going to mark this PR as draft. I implemented a pretty fast attention primitive here: huggingface/candle#3250. Once that is merged (which I am eagerly awaiting), we can do a simple copy of the function here (without the tests).

@michaelfeil michaelfeil marked this pull request as draft December 19, 2025 02:57