feat: add varlen attention on cpu #777
base: main
kozistr left a comment
aside from one thing, everything else looks good to me!
let causal_mask = create_causal_mask_batch(seq_len_q, seq_len_k, num_heads, device)?;
attention_scores = attention_scores.add(&causal_mask)?;
I was just wondering: it looks like causal_mask and window_mask below are always fp32, while attention_scores could be fp16. I'm not sure if I'm right, but the add might fail due to a dtype mismatch!
Thanks, you are right. I figured that out too.
In this case it's always fp32, but for the Apple Metal backend it could be fp16, afaik.
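The dtype concern in this thread can be illustrated with a minimal, standard-library-only sketch. It is not the PR's code: the `Elem` trait, `causal_mask`, and `add_mask` are hypothetical names, `f64` merely stands in for a second dtype (stable Rust has no `f16`), and aligning queries to the end of the keys is an assumption. The point is only that the additive mask is built in the same element type as the scores it is added to. In candle, the analogous step would presumably be casting the mask with `Tensor::to_dtype` to `attention_scores.dtype()` before the add.

```rust
/// Hypothetical element trait: the mask only needs a zero and a -inf value.
trait Elem: Copy + std::ops::Add<Output = Self> {
    const ZERO: Self;
    const NEG_INF: Self;
}

impl Elem for f32 {
    const ZERO: f32 = 0.0;
    const NEG_INF: f32 = f32::NEG_INFINITY;
}

// f64 stands in for "another dtype"; stable Rust has no f16.
impl Elem for f64 {
    const ZERO: f64 = 0.0;
    const NEG_INF: f64 = f64::NEG_INFINITY;
}

/// Row-major [seq_len_q x seq_len_k] additive causal mask:
/// ZERO where key j is visible to query i, NEG_INF otherwise.
/// Assumption: queries are aligned to the end of the keys.
fn causal_mask<T: Elem>(seq_len_q: usize, seq_len_k: usize) -> Vec<T> {
    let offset = seq_len_k.saturating_sub(seq_len_q);
    let mut mask = vec![T::ZERO; seq_len_q * seq_len_k];
    for i in 0..seq_len_q {
        for j in 0..seq_len_k {
            if j > i + offset {
                mask[i * seq_len_k + j] = T::NEG_INF;
            }
        }
    }
    mask
}

/// Element-wise add. Scores and mask share the type parameter T,
/// so a dtype mismatch is rejected at compile time instead of at run time.
fn add_mask<T: Elem>(scores: &mut [T], mask: &[T]) {
    assert_eq!(scores.len(), mask.len());
    for (s, m) in scores.iter_mut().zip(mask) {
        *s = *s + *m;
    }
}
```

Calling `causal_mask::<f32>(2, 3)` and passing the result to `add_mask` together with an `f32` score buffer leaves visible positions unchanged and sets masked positions to negative infinity, which a subsequent softmax maps to zero weight.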
I am going to mark this PR as draft. I implemented a pretty fast attention primitive here: huggingface/candle#3250. Once that is merged (which I am eagerly waiting for), we can do a simple copy of the function here (without the tests).
What does this PR do?
This PR brings varlen-flash-attention to CPU/Metal. It's not softmax-fused (so not really Flash), but at least it's not padded, so we don't run out of memory.
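As a rough illustration of what "not padded" means here, the following is a minimal single-head sketch in plain Rust, not the PR's implementation. It assumes self-attention with one shared `cu_seqlens` array of cumulative sequence boundaries (the PR itself distinguishes `seq_len_q` and `seq_len_k`), and the function name `varlen_attention` and its parameters are hypothetical. Sequences are packed back to back, and each query only scores keys from its own sequence, so no memory is spent on padding tokens.

```rust
/// Naive (unfused) attention over packed variable-length sequences.
/// q, k, v: row-major [total_tokens x head_dim], sequences concatenated without padding.
/// cu_seqlens: cumulative boundaries, e.g. [0, 3, 5] for sequence lengths 3 and 2.
/// Assumption: q and k/v share the same boundaries (self-attention, single head).
fn varlen_attention(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    cu_seqlens: &[usize],
    head_dim: usize,
    causal: bool,
) -> Vec<f32> {
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut out = vec![0.0f32; q.len()];
    for w in cu_seqlens.windows(2) {
        let (start, end) = (w[0], w[1]);
        for i in start..end {
            let qi = &q[i * head_dim..(i + 1) * head_dim];
            // With causal masking, query i sees keys start..=i of its own sequence.
            let last = if causal { i + 1 } else { end };
            // Scaled dot-product scores against keys of the same sequence only.
            let mut scores: Vec<f32> = (start..last)
                .map(|j| {
                    let kj = &k[j * head_dim..(j + 1) * head_dim];
                    qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale
                })
                .collect();
            // Numerically stable softmax: subtract the row maximum before exp.
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let mut denom = 0.0f32;
            for s in scores.iter_mut() {
                *s = (*s - max).exp();
                denom += *s;
            }
            // Weighted sum of the value rows.
            let oi = &mut out[i * head_dim..(i + 1) * head_dim];
            for (p, j) in scores.iter().zip(start..last) {
                let vj = &v[j * head_dim..(j + 1) * head_dim];
                for (o, x) in oi.iter_mut().zip(vj) {
                    *o += (p / denom) * x;
                }
            }
        }
    }
    out
}
```

The scratch buffer for scores is sized per query row rather than for the whole padded batch, which is where the memory saving comes from. A fused (Flash-style) kernel would additionally avoid materializing each score row; this sketch does not attempt that.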
Fixes # (issue)
Before submitting
insta snapshots?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.