
✨ Support prompt logprobs with static batching #274


Merged: 11 commits merged from prompt-logprobs into main on Jul 7, 2025

Conversation

@joerunde (Collaborator) commented Jul 1, 2025

Description

This PR enables prompt logprobs with static batching, at batch size 1 only. This unlocks some experimentation and model-evaluation tasks on Spyre hardware.

For static batching, this requires us to warm the model up with only_last_token=False, which passes back the hidden state tensors for the entire (padded) prompt. This carries a large performance penalty and is currently only supported on Spyre cards at batch size 1.

So, this PR introduces an environment flag VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS, which must be set to 1 to enable prompt logprobs. At bootup, we check that we're running in static batching mode with a max batch size of 1, and fail to boot otherwise. Any request that asks for prompt_logprobs will be rejected unless prompt logprobs are enabled. This differs from the behavior today, where requests always return [None] for prompt logprobs.
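
For reference, a minimal usage sketch (the model name and logprob count are illustrative, and other Spyre-specific settings such as the dynamo backend are omitted; the env flag and the static-batching / batch-size-1 constraint are what this PR adds):

import os

# Opt in before the engine boots; at startup the plugin checks for static
# batching with max batch size 1 and refuses to start otherwise.
os.environ["VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="ibm-ai-platform/micro-g3.3-8b-instruct-1b", max_num_seqs=1)

# Ask for the top-5 logprobs over the prompt tokens. Without the env flag
# above, this request is now rejected instead of returning [None].
params = SamplingParams(max_tokens=8, prompt_logprobs=5)
outputs = llm.generate(["Hello darkness my old friend"], params)
print(outputs[0].prompt_logprobs)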


github-actions bot commented Jul 1, 2025

👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: make sure your code passes all the linting checks, otherwise your PR can't be merged. To do so, first install the linting requirements, then run format.sh and commit the changes. This can be done with uv directly:

uv sync --frozen --group lint --active --inexact

Or this can be done with pip:

uv pip compile --group lint > requirements-lint.txt
pip install -r requirements-lint.txt
bash format.sh

Now you are good to go 🚀

monkeypatch.setenv("VLLM_USE_V1", 1)
monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)
monkeypatch.setenv("VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS", 1)
llm = LLM(model)
Collaborator

for all the other e2e tests we specify the following args explicitly:

vllm_model = LLM(
        model=model,
        tokenizer=model,
        max_model_len=max_model_len,
        max_num_seqs=max_num_seqs,
        block_size=block_size,
        tensor_parallel_size=tensor_parallel_size,
    )

might be nice to do this here too, to be a) consistent with the other tests, and b) safe if default values should change...

Collaborator

Yeah, I've seen tests fail in upstream vLLM because the default max_model_len is larger than what the hardware used for the tests supports. So I think it's a good idea to set these parameters to the minimum required for the test to pass. But isn't it an exceptional situation where the tokenizer is different from what vLLM would load by default for a specific model?

Collaborator Author

For max_model_len, max_num_seqs, and block_size I'd actually rather not set them here because right now this test is specifically only for static batching, so they're unused. Once we support continuous batching though, I agree we can and should set those explicitly.

I could go ahead and parameterize this test for multi-AIU if we want; that would let it run when multiple cards are detected.

For tokenizer, I agree we shouldn't need to ever set it differently, and if we did it wouldn't work to set the same model name anyway. But... there's no harm in setting it for consistency with other tests I guess

Collaborator Author

tensor parallel works 🎉

Collaborator

agree!

@@ -0,0 +1,744 @@
{
"Hello darkness my old friend": [
Collaborator

I love this prompt:)

Collaborator Author

Replaced with chicken soup prompts :(

But I can add it to the list of stock prompts if you want to keep it

offset = hidden_states.shape[0] - num_prompt_tokens

prompt_hidden_states = hidden_states[offset:offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states, None)
Collaborator

we already compute the logits in the forward pass on line 407. Would it be possible to reuse this tensor instead of recomputing?
something like:

def _get_prompt_logprobs_dict(
        self,
        logits: torch.Tensor,
        model_inputs: ModelForwardInputs,
    ) -> dict[str, Optional[LogprobsTensors]]:
        ...
        for req_id in ...:  # loop over the requests in the batch
            ...
            # Get the logits corresponding to this req's prompt tokens.
            req_idx = self.get_req_id_to_index(model_inputs.is_prompt)[req_id]
            logits = logits[req_idx]
            # The offset needs to account for the left padding that static
            # batching applies.
            # TODO: To support continuous batching the offset needs to be
            # calculated differently.
            offset = logits.shape[0] - num_prompt_tokens

            logits = logits[offset:offset + num_logits]
            ...

Collaborator Author

Yup, this works!


def get_num_prompt_logprobs(self, is_prefill: bool) -> dict[str, int]:
    return (self.prefill_batch.num_prompt_logprobs
            if is_prefill else self.input_batch.num_prompt_logprobs)
Collaborator

Are the two branches here needed? Do we ever return self.input_batch.num_prompt_logprobs (for decode)? Asked differently: are we not exiting in line 295 if self.no_prompt_logprob(model_inputs.is_prompt) is True for decodes?

Collaborator Author

True true, this is very likely to only ever be called after a no_prompt_logprob guard so it's probably fine to simplify

Collaborator Author

(this also doesn't work on CB now anyway)
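
For the record, the simplification discussed above would look roughly like this (a sketch only, assuming the method is always called behind the no_prompt_logprob guard, i.e. during prefill):

def get_num_prompt_logprobs(self) -> dict[str, int]:
    # Only reachable after the no_prompt_logprob guard, so the prefill
    # batch is the only batch we ever need to consult here.
    return self.prefill_batch.num_prompt_logprobs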

model = "ibm-ai-platform/micro-g3.3-8b-instruct-1b"
num_prompt_logprobs = 5

json_path = Path(__file__).parent.parent / "expected_prompt_logprobs.json"
Collaborator

is there a way to get these logprobs from the Hugging Face model instead of reading hard-coded values here?

Collaborator

If it's a small model, we could perhaps run it on CPU with transformers as a reference implementation.

Collaborator Author

Yeah, I just didn't know how to implement that without reimplementing most of the code here again. I checked the upstream vllm tests for prompt logprobs and they run lm_eval with a test suite that uses prompt logprobs, which wouldn't be feasible for us to do. I figured that testing against known good results from vllm was a simple enough solution.

I'm down to pair on a solution, maybe I'll see if granite-3.3-8b can generate code to get prompt logprobs out of a transformers model.

Collaborator Author

replaced with hf implementation, which I hope is correct 😬
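
For anyone wanting to regenerate the reference values, computing prompt logprobs directly with transformers looks roughly like this (a sketch, not the exact code used in the test; the model name is taken from the test above):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ibm-ai-platform/micro-g3.3-8b-instruct-1b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

prompt = "Hello darkness my old friend"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    logits = model(input_ids).logits  # [1, seq_len, vocab_size]

# Position i predicts token i + 1, so the first prompt token gets no
# logprob (vLLM reports None for it).
logprobs = torch.log_softmax(logits[0, :-1].float(), dim=-1)
target_ids = input_ids[0, 1:]
token_logprobs = logprobs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)

# Top-k alternatives per position, analogous to prompt_logprobs=5 in vLLM.
topk_logprobs, topk_ids = logprobs.topk(5, dim=-1)
print(token_logprobs, topk_ids)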

@maxdebayser (Collaborator)

Nice, this is looking good. I mostly agree with Yannick's comments.

@joerunde (Collaborator Author) commented Jul 3, 2025

@yannicks1 @maxdebayser This should be ready for another look

I can follow up with enabling this for CB; that will require touching bits of the model that I think Yannick is currently working on, to replace only_last_token=False with the token indices that we want to get back from the model.

@yannicks1 (Collaborator) left a comment

LGTM

@@ -1,4 +1,4 @@
-"""Verification of vLLM output by comparing with HF
+"""Tests validating the correctness and configuration of prompt_logprobs.
Collaborator

did you replace the wrong comment here?

Collaborator Author

hah, yeah I clicked on the wrong file, sorry. I'll fix it

monkeypatch.setenv("VLLM_USE_V1", 1)
monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)
monkeypatch.setenv("VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS", 1)
llm = LLM(model)
Collaborator

agree!

@yannicks1 (Collaborator)

I can follow up with enabling this for CB, that will require touching bits of the model that I think Yannick is currently working on to replace only_last_token=False with the token indices that we want to get back from the model

I don't think logprobs for CB is a priority before CB itself is fully working. Currently we have only_last_token=False for CB, so it could indeed be added without graph changes, but you're right that we're looking into replacing that. The final decision on when to incorporate that will be made soon; once it's made, we can address logprobs for CB in another PR. I'm advocating to get this in for SB now.

@joerunde joerunde enabled auto-merge (squash) July 7, 2025 19:22
@github-actions github-actions bot added the ready label Jul 7, 2025
@joerunde joerunde merged commit 4190580 into main Jul 7, 2025
16 of 19 checks passed
@joerunde joerunde deleted the prompt-logprobs branch July 7, 2025 19:25