Enable FlashInfer support for encoder models and add head_dim padding workaround #6230
Conversation
Hi, I noticed some tests failed. There seem to be a couple of issues:
Could you kindly re-run the CI workflow on the latest commit (…)? Thanks for your help!
Thanks for your contribution! But I still have some confusion. In the benchmark results, the performance of the FlashInfer backend is similar to the Triton backend (even a little slower). Usually FlashInfer should be significantly faster than Triton, so I guess this is probably due to the padding process, which wastes a lot of computation. A better way might be to raise an issue in the FlashInfer repo and push them to implement head_dim=32. Otherwise I don't see any reason to use the FlashInfer backend for encoder models instead of Triton, since Triton is better in both flexibility and performance. The padding adds code complexity and makes the code harder for us to maintain.
Thanks for your review, @Fridge003. You're right, the padding likely wastes computation, and I'm happy to remove it. This PR will then focus on enabling non-causal attention for encoders (the `AttentionType.ENCODER_ONLY` / `causal=False` change). The FlashInfer `head_dim < 64` limitation can be addressed upstream instead. If this revised approach is acceptable, I'll update the PR. Thanks!
Thanks for your update~ You can remove the padding logic first and add a comment noting that it waits for an update from FlashInfer. On the CI side, you can skip models with head_dim lower than 64 for the FlashInfer backend.
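A minimal sketch of such a CI guard (the helper name and the way head_dim is obtained are assumptions, not code from this PR):

```python
import pytest

# Hypothetical threshold: FlashInfer's ragged prefill does not yet
# support head dimensions below 64.
MIN_FLASHINFER_HEAD_DIM = 64

def maybe_skip_for_flashinfer(attention_backend: str, head_dim: int) -> None:
    """Skip a test when the FlashInfer backend cannot handle the model's head_dim."""
    if attention_backend == "flashinfer" and head_dim < MIN_FLASHINFER_HEAD_DIM:
        pytest.skip(
            f"head_dim={head_dim} < {MIN_FLASHINFER_HEAD_DIM} is not yet "
            "supported by the FlashInfer backend"
        )
```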
Hi, I tried removing the padding workaround and re-ran my test (this time testing 100,000 requests on BGE-m3 with async POST).
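A rough sketch of this kind of async-POST benchmark (the endpoint URL, payload, and concurrency settings are assumptions, not taken from this PR):

```python
import asyncio
import time

import aiohttp

URL = "http://localhost:30000/v1/embeddings"          # assumed embedding endpoint
PAYLOAD = {"model": "BAAI/bge-m3", "input": "hello"}   # assumed payload shape
NUM_REQUESTS = 100_000
CONCURRENCY = 256

async def one_request(session: aiohttp.ClientSession, sem: asyncio.Semaphore) -> None:
    # Bound concurrency with a semaphore so the server is not flooded at once.
    async with sem:
        async with session.post(URL, json=PAYLOAD) as resp:
            resp.raise_for_status()
            await resp.json()

async def main() -> None:
    sem = asyncio.Semaphore(CONCURRENCY)
    start = time.perf_counter()
    async with aiohttp.ClientSession() as session:
        await asyncio.gather(*(one_request(session, sem) for _ in range(NUM_REQUESTS)))
    print(f"{NUM_REQUESTS} requests in {time.perf_counter() - start:.1f}s")

if __name__ == "__main__":
    asyncio.run(main())
```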
LGTM
Hi @Fridge003 and team, it seems the CI checks (…) failed. After reviewing the logs, the failure occurred in the (…) test. Given that my PR focuses on enabling FlashInfer for encoder models, and this test evaluates the general accuracy of a large decoder model (…), the failure does not appear to be related to my changes. Could you please help confirm this or suggest how to proceed? Perhaps the CI job could be re-run? Thanks for your help!
Motivation
This PR aims to enhance the FlashInfer attention backend in SGLang to address two primary goals:

1. Enable FlashInfer to serve encoder-only models, which require non-causal (bidirectional) attention.
2. Work around a FlashInfer limitation for small head dimensions: when models with `head_dim=32` (as found in `Supabase/gte-small` and potentially other BGE-like models) are run with FlashInfer's ragged prefill operations, an internal error is triggered, preventing these models from running.

The original issue is: #6050
Modifications
This PR introduces the following key changes:
**1. Encoder Model Support (Non-Causal Attention):**
In `FlashInferAttnBackend.forward_extend`, the `causal` flag is now dynamically determined. For layers with `layer.attn_type == AttentionType.ENCODER_ONLY`, `causal` is set to `False` to enable bidirectional (non-causal) attention. `save_kv_cache` is also appropriately set to `False`, as KV caching across layers is typically not used in the same way as in decoders.
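A simplified sketch of this selection logic (not the actual diff; the import path is an assumption):

```python
from sglang.srt.layers.radix_attention import AttentionType  # assumed import path

def prefill_causal_and_cache_flags(layer):
    """Decide causal masking and KV-cache saving for forward_extend.

    Encoder-only layers use bidirectional attention and skip KV-cache writes.
    """
    is_encoder_only = getattr(layer, "attn_type", None) == AttentionType.ENCODER_ONLY
    causal = not is_encoder_only          # False => bidirectional attention
    save_kv_cache = not is_encoder_only   # encoders do not save the KV cache
    return causal, save_kv_cache
```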
**2. Workaround for the FlashInfer `head_dim` Limitation (e.g., for `head_dim=32`):**
FlashInfer currently fails when using `BatchPrefillWithRaggedKVCacheWrapper` with `head_dim < 64` (e.g., 32). To work around this, we pad the head dimension up to 64 during prefill and forward steps:

- `global_fake_head_dim` (default: 64) controls the padded size.
- When `head_dim` is less than `global_fake_head_dim`, we use the padded `fake_head_dim` for planning (`begin_forward`), while `sm_scale` remains based on the original `head_dim` for correctness.

This workaround is temporary until native support for `head_dim < 64` is available in FlashInfer.
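A condensed sketch of the padding idea (illustrative only; the tensor layout `(num_tokens, num_heads, head_dim)` and the helper name are assumptions):

```python
import torch
import torch.nn.functional as F

global_fake_head_dim = 64  # padded planning size described above

def pad_head_dim(x: torch.Tensor, real_head_dim: int) -> torch.Tensor:
    """Zero-pad the last (head_dim) dimension of q/k/v up to global_fake_head_dim."""
    pad = global_fake_head_dim - real_head_dim
    return F.pad(x, (0, pad)) if pad > 0 else x

real_head_dim = 32
# sm_scale stays tied to the *original* head_dim so attention scores are
# unchanged: zero-padded q/k columns add nothing to q @ k^T, and zero-padded
# v columns only produce zeros that are sliced off afterwards.
sm_scale = 1.0 / (real_head_dim ** 0.5)

# Conceptually (wrapper usage elided):
#   plan / begin_forward with head_dim = global_fake_head_dim
#   out = attention(pad_head_dim(q, 32), pad_head_dim(k, 32), pad_head_dim(v, 32))
#   out = out[..., :real_head_dim]
```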
**3. Verification and Results:**
The effectiveness of these changes, particularly the padding workaround for `gte-small` (or a similar model with `head_dim=32`), was verified by comparing the FlashInfer backend's output (final embedding logits, e.g., shape `(10000, 768)`) against Triton and a native PyTorch attention implementation (`torch_native`).

Numerical similarity (vs `torch_native`, for a `gte-small`-like model), checked with `torch.allclose` at (rtol=0.01, atol=0.001) and (rtol=0.001, atol=0.0001):

| Backend | MAE | Max Abs Error |
| --- | --- | --- |
| FlashInfer (padded) | 1.89077000e-05 | 9.76562500e-04 |
| Triton | 1.78243699e-05 | 9.76562500e-04 |
These results show that the padded FlashInfer backend achieves an MAE on the order of ~1.8e-5 relative to the native PyTorch version, similar to Triton. The slightly larger maximum error and the failure at the tighter `allclose` tolerance are common for optimized kernels, especially with `float16`/`bfloat16` dtypes, and are considered within acceptable limits.

Performance (seconds / 10,000 requests, for a `gte-small`-like model): the padded FlashInfer backend demonstrates performance comparable to Triton and significantly improves over the native PyTorch implementation.
I'm open to discussing whether the current solution is appropriate. It might be better to remove the temporary workaround and retain only the `causal` check, especially if full FlashInfer support is expected soon. That said, I'm also happy to keep the workaround in place while we wait for FlashInfer support to land.

Thank you for taking the time to review this -- I'm open to any suggestions.
Checklist