Hi,

I am trying to figure out how to switch from the flash-attention interface to FlashInfer.

In flash-attention, the q, k, v tensors passed to flash_attn_func carry a batch dimension, e.g. q: (batch_size, seqlen, nheads, headdim). In FlashInfer, however, both the prefill and decode APIs appear to take tensors without that dimension, e.g. k/v: [kv_len, num_kv_heads, head_dim].
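For reference, here is a minimal sketch of the call I am trying to replace; the sizes are placeholders and the shapes follow the flash-attention documentation:

```python
import torch
from flash_attn import flash_attn_func

# Placeholder sizes for illustration only.
batch_size, seqlen, nheads, headdim = 2, 128, 8, 64

q = torch.randn(batch_size, seqlen, nheads, headdim,
                dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# flash-attention consumes the batched 4-D tensors directly.
out = flash_attn_func(q, k, v, causal=True)
# out: (batch_size, seqlen, nheads, headdim)
```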
So how can I replace flash_attn_func with the FlashInfer Python API?
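My best guess from the FlashInfer docs is that the batch dimension is supposed to be packed away: flatten (batch_size, seqlen, ...) into (batch_size * seqlen, ...) and describe the request boundaries with an indptr array for the ragged batch-prefill wrapper. Below is an untested sketch; the method names follow recent releases (plan/run), while older versions call these begin_forward/forward, so it may need adjusting:

```python
import torch
import flashinfer

batch_size, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch_size, seqlen, nheads, headdim,
                dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Pack (batch_size, seqlen, H, D) into (batch_size * seqlen, H, D).
# FlashInfer's batch APIs take ragged, packed tensors instead of a
# batch dimension; indptr marks where each request starts and ends.
q_flat = q.reshape(-1, nheads, headdim)
k_flat = k.reshape(-1, nheads, headdim)
v_flat = v.reshape(-1, nheads, headdim)
indptr = torch.arange(0, (batch_size + 1) * seqlen, seqlen,
                      dtype=torch.int32, device="cuda")  # [0, 128, 256]

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")
# Same indptr for q and kv here because every request is self-attending
# over its own full sequence (no paged KV cache involved).
wrapper.plan(indptr, indptr, nheads, nheads, headdim, causal=True)
out = wrapper.run(q_flat, k_flat, v_flat)  # (batch_size * seqlen, H, D)
out = out.reshape(batch_size, seqlen, nheads, headdim)
```

For a single request, it also looks like one can drop the batch dimension entirely and call flashinfer.single_prefill_with_kv_cache(q[0], k[0], v[0], causal=True). Is the packed ragged layout above the intended way to reproduce flash_attn_func's batched behavior, or is there a more direct equivalent?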
Jason