
Conversation

YangKai0616

The main contributions of this PR:

  1. Implement flash_attn.fwd and flash_attn.varlen_fwd on XPU.
  2. Modify test_flash_attn.py to support XPU testing.

This PR passed local pip install, Nix compilation, and unit tests on XPU. It also passed pip install and unit tests on a CUDA GPU.
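For reference, here is a minimal sketch of how the new forward path could be exercised, assuming the kernel is fetched with the Hugging Face `kernels` library; the exact `fwd` argument names and defaults are assumptions based on the upstream flash-attn convention, not something this PR defines.

```python
# Minimal sketch (not part of this PR): calling the new fwd path on XPU.
# Assumptions: the kernel is loaded via the Hugging Face `kernels` library,
# and `fwd` accepts q/k/v plus an is_causal flag as in upstream flash-attn;
# the keyword names below are illustrative, not guaranteed.
import torch
from kernels import get_kernel

flash_attn = get_kernel("kernels-community/flash-attn")

device = "xpu" if torch.xpu.is_available() else "cuda"
batch, seqlen, nheads, headdim = 2, 128, 8, 64

q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# fwd returns the attention output first (plus auxiliary tensors such as the
# softmax LSE); the output shape should match (batch, seqlen, nheads, headdim).
out = flash_attn.fwd(q, k, v, is_causal=True)[0]
print(out.shape)
```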

Additional note:
On CUDA, test_flash_attn.py::test_flash_attn_kvcache fails with RuntimeError: out must have shape (batch_size, seqlen_q, num_heads, head_size_og). I see the same error with the original flash-attn kernel, so it is unrelated to this PR.

@YangKai0616
Author

@danieldk Could you please review? Thanks!

@danieldk
Member

CUDA flash-attn2 has been merged: https://github.com/huggingface/kernels-community/tree/main/flash-attn2. Could you rebase the PR on main?
