
Enable FlashInfer support for encoder models and add head_dim padding workaround #6230


Merged
24 commits merged into sgl-project:main on Jul 20, 2025

Conversation

ccs96307
Contributor

Motivation

This PR aims to enhance the FlashInfer attention backend in SGLang to address two primary goals:

  1. Enable support for encoder-only models: Currently, the FlashInfer backend needs adjustments to correctly handle non-causal attention required by encoder architectures.
  2. Resolve an "Invalid configuration" error for specific head dimensions: When using encoder models with certain head dimensions (e.g., head_dim=32 as found in Supabase/gte-small and potentially other BGE-like models) with FlashInfer's ragged prefill operations, an internal error is triggered, preventing these models from running.

The original issue is: #6050

Modifications

This PR introduces the following key changes:

1. Encoder Model Support (Non-Causal Attention):

  • In FlashInferAttnBackend.forward_extend, the causal flag is now determined dynamically: for layers with layer.attn_type == AttentionType.ENCODER_ONLY, causal is set to False to enable bidirectional (non-causal) attention (see the sketch below).
  • For encoder self-attention, save_kv_cache is also set to False, since there is no autoregressive decode phase that would reuse the cached keys and values the way a decoder does.
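
A minimal sketch of this selection logic, assuming SGLang's AttentionType enum and the layer.attn_type attribute (the import path and helper name are illustrative, and forward_extend's surrounding arguments are omitted):

```python
# Sketch only: the real change lives inside FlashInferAttnBackend.forward_extend.
from sglang.srt.layers.radix_attention import AttentionType  # assumed import path


def _select_attention_flags(layer):
    """Return (causal, save_kv_cache) for a given attention layer."""
    if getattr(layer, "attn_type", None) == AttentionType.ENCODER_ONLY:
        # Encoder self-attention is bidirectional and has no autoregressive
        # decode phase, so nothing is written to the KV cache.
        return False, False
    # Decoder layers keep causal masking and KV caching.
    return True, True
```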

2. Workaround for FlashInfer head_dim Limitation (e.g., for head_dim=32):
FlashInfer currently fails when using BatchPrefillWithRaggedKVCacheWrapper with head_dim < 64 (e.g., 32). To work around this, we pad the head dimension up to 64 during prefill and forward steps:

  • A global variable global_fake_head_dim (default: 64) controls the padded size.
  • During prefill:
    • If the model’s head_dim is less than global_fake_head_dim, we use the padded fake_head_dim for planning (begin_forward), but keep sm_scale based on the original head_dim for correctness.
  • During forward:
    • Q, K, and V tensors are padded along the head dimension.
    • sm_scale remains based on the original head_dim.
    • FlashInfer returns output with the padded size, which we truncate back to the original shape.

This workaround is temporary until native support for head_dim < 64 is available in FlashInfer.
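
A minimal sketch of the padding path, assuming (num_tokens, num_heads, head_dim) Q/K/V layouts; the helper name is hypothetical and the wrapper call is abbreviated, since the exact BatchPrefillWithRaggedKVCacheWrapper signature varies across FlashInfer versions:

```python
import torch
import torch.nn.functional as F

global_fake_head_dim = 64  # padded head_dim used when the real head_dim is smaller


def padded_ragged_prefill(wrapper, q, k, v, head_dim):
    """Run ragged prefill with the head dimension padded to global_fake_head_dim."""
    # The softmax scale must reflect the *original* head_dim, not the padded one.
    sm_scale = 1.0 / (head_dim ** 0.5)

    if head_dim < global_fake_head_dim:
        pad = global_fake_head_dim - head_dim
        # Zero-padding the last dimension is safe: zero K columns add nothing to
        # QK^T, and zero V columns only produce output columns we truncate below.
        q = F.pad(q, (0, pad))
        k = F.pad(k, (0, pad))
        v = F.pad(v, (0, pad))

    out = wrapper.forward(q, k, v, sm_scale=sm_scale)  # call abbreviated
    return out[..., :head_dim]  # drop the padded columns
```

The wrapper is assumed to have been planned (begin_forward) with the padded head dimension, as described above.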

3. Verification and Results:
The effectiveness of these changes, particularly the padding workaround for gte-small (or a similar model with head_dim=32), was verified by comparing the FlashInfer backend's output (final embedding logits, e.g., shape (10000, 768)) against Triton and a native PyTorch attention implementation (torch_native).

Numerical Similarity (vs torch_native for gte-small like model):

  • torch.allclose (rtol=0.01, atol=0.001):
    • FlashInfer: True
    • Triton: True
  • torch.allclose (rtol=0.001, atol=0.0001):
    • FlashInfer: False
    • Triton: False
  • Mean Absolute Error (MAE):
    • FlashInfer: 1.89077000e-05
    • Triton: 1.78243699e-05
  • Maximum Absolute Error:
    • FlashInfer: 9.76562500e-04
    • Triton: 9.76562500e-04

These results show that the padded FlashInfer backend achieves an MAE of about 1.9e-5 against the native PyTorch reference, essentially matching Triton (about 1.8e-5). The slightly larger maximum error and the failure at the tighter allclose tolerances are common for optimized kernels, especially with float16/bfloat16 dtypes, and are considered within acceptable limits.
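
For reference, a comparison of this kind can be reproduced roughly as follows; emb_flashinfer, emb_triton, and emb_ref are hypothetical (10000, 768) embedding tensors collected from the same prompts on each backend:

```python
import torch


def compare(out: torch.Tensor, ref: torch.Tensor) -> dict:
    """Report the same metrics as the table above for two embedding tensors."""
    diff = (out.float() - ref.float()).abs()
    return {
        "allclose(rtol=0.01, atol=0.001)": torch.allclose(out, ref, rtol=0.01, atol=0.001),
        "allclose(rtol=0.001, atol=0.0001)": torch.allclose(out, ref, rtol=0.001, atol=0.0001),
        "MAE": diff.mean().item(),
        "max_abs_err": diff.max().item(),
    }


# print(compare(emb_flashinfer, emb_ref))
# print(compare(emb_triton, emb_ref))
```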

Performance (seconds / 10,000 requests, for gte-small like model):

  • FlashInfer (padded): 39.551 seconds
  • Triton: 39.144 seconds
  • Torch Native: 46.192 seconds

The padded FlashInfer backend demonstrates performance comparable to Triton and significantly improves over the native PyTorch implementation.


I'm open to discussing whether the current solution is appropriate. It might be better to remove the temporary workaround and retain only the causal check, especially if full FlashInfer support is expected soon.

That said, I'm also happy to keep the workaround in place while we wait for FlashInfer support to land.
Thank you for taking the time to review this -- I'm open to any suggestions.


@Fridge003 Fridge003 self-assigned this May 12, 2025
@ccs96307
Contributor Author

Hi,

I noticed some tests failed. There seem to be a couple of issues:

  1. One error is Error: fatal: remote error: upload-pack: not our ref d9e280a70f7be9b97bc7ba2fcd3bc17c2dbf23cc. I've recently updated this PR branch by merging the latest changes from main (the current head of this PR is d86966e), so this checkout error might have occurred if the test was running on a previous, now-stale reference.

  2. Another issue is ValueError: Unrecognized model in neuralmagic/Qwen2-7B-Instruct-FP8.... I suspect this might be unrelated to the changes in my PR.

Could you kindly re-run the CI workflow on the latest commit (d86966e) of this PR when you get a chance?

Thanks for your help!

@Fridge003
Collaborator

Fridge003 commented May 14, 2025

Thanks for your contribution! But I still have some confusion.

In the benchmark results, the performance of the FlashInfer backend is similar to the Triton backend (even a little slower). FlashInfer is usually significantly faster than Triton, so I suspect this is due to the padding, which wastes a lot of computational resources.

A better way might be to raise an issue in the flashinfer repo and push them to implement head_dim=32. Otherwise I don't see a reason to use the FlashInfer backend for encoder models instead of Triton, since Triton is better in both flexibility and performance here. The padding adds code complexity and makes the code harder for us to maintain.

@ccs96307
Contributor Author

Thanks for your review, @Fridge003. You're right, the head_dim padding workaround adds complexity without a clear performance win over Triton in this scenario.

I'm happy to remove the padding. This PR will then focus on enabling non-causal attention for encoders (the causal flag logic), allowing FlashInfer to be used with encoder models that have natively supported head dimensions (in my testing, at least BGE-M3 works).

The FlashInfer head_dim limitation itself is tracked here: flashinfer-ai/flashinfer#1048.

If this revised approach is acceptable, I'll update the PR. Thanks!

@Fridge003
Collaborator


Thanks for your update~ You can remove the padding logic first and add a comment noting that we are waiting for an update from flashinfer. On the CI side, you can skip models with head_dim lower than 64 for the flashinfer backend.
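
Something along these lines would be enough for the skip (a pytest-style guard; the helper name and backend/head_dim arguments are illustrative, not SGLang's actual test utilities):

```python
import pytest

FLASHINFER_MIN_HEAD_DIM = 64  # ragged prefill currently requires head_dim >= 64


def maybe_skip_for_flashinfer(attention_backend: str, head_dim: int) -> None:
    """Skip encoder-model tests whose head_dim FlashInfer cannot handle yet."""
    if attention_backend == "flashinfer" and head_dim < FLASHINFER_MIN_HEAD_DIM:
        pytest.skip(
            f"head_dim={head_dim} < {FLASHINFER_MIN_HEAD_DIM} is not supported by "
            "FlashInfer's ragged prefill yet (flashinfer-ai/flashinfer#1048)"
        )
```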

@ccs96307
Contributor Author

ccs96307 commented May 15, 2025

Hi, I tried removing the padding workaround and re-ran my test (this time testing 100,000 requests on BGE-m3 with async POST):

  • flashinfer: 169 seconds / 100,000 requests (0.00169 seconds per request)
  • triton: 185 seconds / 100,000 requests (0.00185 seconds per request)
  • torch_native: 405 seconds / 100,000 requests (0.00405 seconds per request)
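
For context, these numbers come from firing the requests asynchronously; a rough sketch of that style of benchmark, assuming an OpenAI-compatible /v1/embeddings endpoint (URL, model name, and concurrency below are illustrative):

```python
import asyncio
import time

import aiohttp

URL = "http://127.0.0.1:30000/v1/embeddings"  # illustrative endpoint
PAYLOAD = {"model": "BAAI/bge-m3", "input": "hello world"}


async def one_request(session: aiohttp.ClientSession, sem: asyncio.Semaphore) -> None:
    async with sem:
        async with session.post(URL, json=PAYLOAD) as resp:
            await resp.json()


async def main(num_requests: int = 100_000, concurrency: int = 256) -> None:
    sem = asyncio.Semaphore(concurrency)
    async with aiohttp.ClientSession() as session:
        start = time.perf_counter()
        await asyncio.gather(*(one_request(session, sem) for _ in range(num_requests)))
        elapsed = time.perf_counter() - start
    print(f"{elapsed:.1f} s / {num_requests} requests "
          f"({elapsed / num_requests:.5f} s per request)")


if __name__ == "__main__":
    asyncio.run(main())
```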


@Fridge003 Fridge003 left a comment


LGTM

@Fridge003 Fridge003 requested a review from BBuf as a code owner May 17, 2025 19:16
@Fridge003 Fridge003 added the ready-to-merge The PR is ready to merge after the CI is green. label May 19, 2025
@Fridge003 Fridge003 removed the ready-to-merge The PR is ready to merge after the CI is green. label May 20, 2025
@ccs96307
Contributor Author

Hi @Fridge003 and team,

It seems the CI checks (amd_ci_exec.sh python3 test_eval_accuracy_large.py) failed.

After reviewing the logs, I found that the failure occurred in the test_human_eval benchmark. The model scored 0.639, just slightly below the required threshold of 0.64. The test was retried once but failed again with a similar score.

Given that my PR focuses on enabling FlashInfer for encoder models, and this test evaluates the general accuracy of a large decoder model (Llama-3.1-8B-Instruct) on an AMD/ROCm platform, I suspect this might be a flaky test and likely unrelated to my changes.

Could you please help confirm this or suggest how to proceed? Perhaps the CI job could be re-run?

Thanks for your help!

@Fridge003 Fridge003 added the ready-to-merge The PR is ready to merge after the CI is green. label Jul 12, 2025
@zhyncs zhyncs merged commit cbdfb77 into sgl-project:main Jul 20, 2025
22 of 60 checks passed