EAGLE speculative decoding support for Llama4 #20403
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs will not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Summary of Changes
Hello @zixi-qi, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces comprehensive support for EAGLE speculative decoding with Llama4 models. It defines a new draft model architecture tailored for this purpose, including its own forward pass, weight loading, and configuration handling. The changes enable vLLM to leverage speculative decoding for Llama4, potentially improving inference performance by generating speculative tokens with a smaller, faster draft model.
Highlights
- New Model Support: I've added a new model definition, `EagleLlama4ForCausalLM`, to support EAGLE speculative decoding specifically for Llama4 models. This new class inherits from `Llama4ForCausalLM` and integrates the draft model logic.
- Draft Model Implementation: A new `LlamaModel` class has been introduced within `llama4_eagle.py` to serve as the dense-only draft model for EAGLE. This model includes its own embedding layer, Llama4 decoder layers, and a unique `fc` layer for combining input embeddings and hidden states (see the sketch after this list). It also handles specific weight loading and configuration validation for its role as a draft model.
- Configuration and Weight Loading for Draft Model: The draft model's configuration is validated and updated to account for its position relative to the target model (using `start_layer_id`) and its specific quantization requirements (currently supporting TorchAO). Weight loading for the draft model is carefully managed, including handling stacked parameters and ensuring the `lm_head` is shared with the target model.
- Model Registry Integration: The new `EagleLlama4ForCausalLM` has been registered in `vllm/model_executor/models/registry.py`, making it discoverable and loadable by the vLLM system.
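To make the draft-model wiring above concrete, here is a minimal, hypothetical sketch of an EAGLE-style draft step: the token embedding is concatenated with the target model's hidden state and projected by the `fc` layer before entering the decoder layers. The names are simplified placeholders, not the actual classes in `llama4_eagle.py`.

```python
import torch
from torch import nn

class EagleDraftSketch(nn.Module):
    """Hypothetical, simplified EAGLE draft step (not the vLLM implementation)."""

    def __init__(self, hidden_size: int, vocab_size: int, decoder: nn.Module):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # fc fuses [token embedding ; target hidden state] back down to hidden_size
        self.fc = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.decoder = decoder  # stand-in for the dense Llama4 decoder layers

    def forward(self, input_ids: torch.Tensor,
                target_hidden: torch.Tensor) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        fused = self.fc(torch.cat([inputs_embeds, target_hidden], dim=-1))
        return self.decoder(fused)

# Example usage with a trivial stand-in decoder:
draft = EagleDraftSketch(hidden_size=8, vocab_size=32, decoder=nn.Identity())
tokens = torch.tensor([[1, 2, 3]])
hidden = torch.zeros(1, 3, 8)
print(draft(tokens, hidden).shape)  # torch.Size([1, 3, 8])
```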
Force-pushed from 86d83f6 to fb51dfe.
Code Review
This pull request adds support for EAGLE speculative decoding for Llama4 models. The implementation introduces a new `EagleLlama4ForCausalLM` model, which correctly follows the EAGLE architecture by combining embeddings with hidden states.
The changes are well-structured. I've identified a few areas for improvement:
- A potential side effect due to in-place modification of a shared configuration object.
- A robustness improvement in the model's forward pass.
- A memory efficiency improvement in the weight loading process.
Once these points are addressed, the code should be in great shape. The addition to the model registry is correct.
```python
self.config = vllm_config. \
    speculative_config.draft_model_config.hf_config
```
Modifying `vllm_config` in-place can lead to unexpected side effects, as this configuration object might be shared and used in other parts of the application. Specifically, `self.config.no_rope_layers` is modified in `validate_and_update_config`, which can affect other components that rely on the original configuration.
To avoid this, it's safer to work with a deep copy of the configuration object. Please add `import copy` at the top of the file and use `copy.deepcopy` here.
Suggested change:
```python
self.config = copy.deepcopy(
    vllm_config.speculative_config.draft_model_config.hf_config)
```
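As a small illustration of the shared-state concern (using a toy config class, not vLLM's actual config types), mutating a deep copy leaves the original object that other components hold untouched:

```python
import copy

class ToyConfig:
    def __init__(self):
        self.no_rope_layers = [0, 1, 2]

shared = ToyConfig()              # imagine the target model also holds this object
draft_cfg = copy.deepcopy(shared)
draft_cfg.no_rope_layers = []     # draft-specific adjustment
assert shared.no_rope_layers == [0, 1, 2]  # the original config is unaffected
```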
```python
if inputs_embeds is None:
    inputs_embeds = self.get_input_embeddings(input_ids)
```
The `forward` method's signature allows `input_ids` to be `Optional[torch.Tensor]`. If `inputs_embeds` is `None`, the code proceeds to call `self.get_input_embeddings(input_ids)`, which expects a `torch.Tensor`, not `None`. This could lead to a `TypeError` if `LlamaModel.forward` is called with both `inputs_embeds=None` and `input_ids=None`.
While the current call chain from `EagleLlama4ForCausalLM` ensures `input_ids` is not `None`, adding an explicit check would make `LlamaModel` more robust and prevent potential issues if it's used in other contexts in the future.
Suggested change:
```python
if inputs_embeds is None:
    assert input_ids is not None, (
        "input_ids must be provided if inputs_embeds is not.")
    inputs_embeds = self.get_input_embeddings(input_ids)
```
```python
weights = [
    self.permute_qk_weight_for_rotary(name, loaded_weight)
    for name, loaded_weight in weights
]
```
Using a list comprehension here (`[...]`) will create a new list in memory containing all the processed weights. For large models, this can consume a significant amount of memory unnecessarily.
Using a generator expression (`(...)`) instead will process the weights one by one, avoiding the creation of the intermediate list and making the process more memory-efficient.
Suggested change:
```python
weights = (
    self.permute_qk_weight_for_rotary(name, loaded_weight)
    for name, loaded_weight in weights
)
```
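For intuition (with plain Python values standing in for weight tensors), the generator version yields transformed items lazily as the loader consumes them, while the list version materializes every result up front. One caveat worth noting: a generator can be iterated only once, so this is only safe if the downstream loader makes a single pass over the weights.

```python
weights = [("layers.0.q_proj.weight", 1.0), ("layers.1.q_proj.weight", 2.0)]

def permute(name, w):
    # stand-in for permute_qk_weight_for_rotary
    return name, w * 2

eager = [permute(n, w) for n, w in weights]  # every result held in memory at once
lazy = (permute(n, w) for n, w in weights)   # produced one at a time on demand

for name, w in lazy:  # consumes the generator; a second pass would see nothing
    print(name, w)
```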
Force-pushed from fb51dfe to 15638a9.
Signed-off-by: qizixi <[email protected]>
Force-pushed from 15638a9 to baaafc3.
Purpose
Support EAGLE speculative decoding with a dense-only draft model for Llama4.
Test Plan
WIP; an open-source EAGLE draft checkpoint for Llama4 still needs to be identified (a hypothetical invocation is sketched below).
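A hypothetical sketch of how this could be exercised once a draft checkpoint exists. The model IDs and paths are placeholders, and the exact `speculative_config` keys should be verified against the vLLM documentation for the installed version:

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",   # example target model
    speculative_config={
        "method": "eagle",                        # EAGLE-style drafting
        "model": "/path/to/llama4-eagle-draft",   # placeholder draft checkpoint
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```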
Test Result