✨ Support prompt logprobs with static batching #274
Conversation
monkeypatch.setenv("VLLM_USE_V1", 1) | ||
monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend) | ||
monkeypatch.setenv("VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS", 1) | ||
llm = LLM(model) |
For all the other e2e tests we specify the following args explicitly:
vllm_model = LLM(
    model=model,
    tokenizer=model,
    max_model_len=max_model_len,
    max_num_seqs=max_num_seqs,
    block_size=block_size,
    tensor_parallel_size=tensor_parallel_size,
)
It might be nice to do this here too, to be a) consistent with the other tests, and b) safe if default values should change...
Yeah, I've seen tests fail in upstream vLLM because the default max_model_len is larger than what the hardware used for the tests supports. So I think it's a good idea to set these parameters to the minimum required for the test to pass. But isn't it an exceptional situation where the tokenizer is different from what vLLM would load by default for a specific model?
For max_model_len, max_num_seqs, and block_size, I'd actually rather not set them here, because right now this test is specifically only for static batching, so they're unused. Once we support continuous batching, though, I agree we can and should set those explicitly.
I could go ahead and parameterize this test for multi-AIU if we want; that would be helpful for it to run when it detects multiple cards.
For tokenizer, I agree we shouldn't ever need to set it differently, and if we did, it wouldn't work to set the same model name anyway. But... there's no harm in setting it for consistency with the other tests, I guess.
tensor parallel works 🎉
agree!
tests/expected_prompt_logprobs.json
Outdated
@@ -0,0 +1,744 @@
{
    "Hello darkness my old friend": [
I love this prompt :)
Replaced with chicken soup prompts :(
But I can add it to the list of stock prompts if you want to keep it
offset = hidden_states.shape[0] - num_prompt_tokens

prompt_hidden_states = hidden_states[offset:offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states, None)
We already compute the logits in the forward pass on line 407. Would it be possible to reuse that tensor instead of recomputing it?
Something like:
def _get_prompt_logprobs_dict(
    self,
    logits: torch.Tensor,
    model_inputs: ModelForwardInputs,
) -> dict[str, Optional[LogprobsTensors]]:
    ...
    for req_id in ...:  # loop over requests that asked for prompt logprobs
        ...
        # Get the logits corresponding to this req's prompt tokens.
        req_idx = self.get_req_id_to_index(model_inputs.is_prompt)[req_id]
        req_logits = logits[req_idx]
        # The offset needs to account for the left padding that static
        # batching applies.
        # TODO: To support continuous batching the offset needs to be
        # calculated differently.
        offset = req_logits.shape[0] - num_prompt_tokens
        req_logits = req_logits[offset:offset + num_logits]
        ...
Yup, this works!
def get_num_prompt_logprobs(self, is_prefill: bool) -> dict[str, int]:
    return (self.prefill_batch.num_prompt_logprobs
            if is_prefill else self.input_batch.num_prompt_logprobs)
Are the two branches here needed? Do we ever return self.input_batch.num_prompt_logprobs (for decode)? Asked differently, are we not exiting at line 295 if self.no_prompt_logprob(model_inputs.is_prompt) is True for decodes?
True true, this is very likely to only ever be called after a no_prompt_logprob guard, so it's probably fine to simplify.
(This also doesn't work on CB now anyway.)
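For illustration, the simplified accessor might look something like this (a sketch only, assuming the method is always reached behind the no_prompt_logprob guard so the decode branch never runs; this is not necessarily the code that was merged):

def get_num_prompt_logprobs(self) -> dict[str, int]:
    # Prompt logprobs are only produced during prefill, and callers are
    # expected to bail out earlier via no_prompt_logprob() for decodes,
    # so the self.input_batch branch is not needed.
    return self.prefill_batch.num_prompt_logprobs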
model = "ibm-ai-platform/micro-g3.3-8b-instruct-1b" | ||
num_prompt_logprobs = 5 | ||
|
||
json_path = Path(__file__).parent.parent / "expected_prompt_logprobs.json" |
Is there a way to get these prompt logprobs from the Hugging Face model instead of reading hard-coded values here?
If it's a small model, we could perhaps execute it on CPU with transformers as a reference implementation.
Yeah, I just didn't know how to implement that without reimplementing most of the code here again. I checked the upstream vLLM tests for prompt logprobs, and they run lm_eval with a test suite that uses prompt logprobs, which wouldn't be feasible for us to do. I figured that testing against known-good results from vLLM was a simple enough solution.
I'm down to pair on a solution; maybe I'll see if granite-3.3-8b can generate code to get prompt logprobs out of a transformers model.
Replaced with an HF implementation, which I hope is correct 😬
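For reference, a minimal sketch of how per-token prompt logprobs can be computed with transformers on CPU (the helper name, return structure, and top-k handling here are illustrative assumptions, not the code merged in this PR):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def hf_prompt_logprobs(model_name: str, prompt: str, top_k: int = 5):
    """Reference prompt logprobs from an HF causal LM on CPU."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        # Logits at position i are the distribution for token i + 1,
        # so drop the last position and align with input_ids[1:].
        logits = model(input_ids).logits[0, :-1]
    logprobs = torch.log_softmax(logits.float(), dim=-1)

    results = []
    for pos, token_id in enumerate(input_ids[0, 1:].tolist()):
        top = torch.topk(logprobs[pos], top_k)
        results.append({
            "token_id": token_id,
            "logprob": logprobs[pos, token_id].item(),
            "top_k": dict(zip(top.indices.tolist(), top.values.tolist())),
        })
    return results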
Nice, this is looking good. I mostly agree with Yannick's comments.
@yannicks1 @maxdebayser This should be ready for another look. I can follow up with enabling this for CB; that will require touching bits of the model that I think Yannick is currently working on to replace.
LGTM
@@ -1,4 +1,4 @@
-"""Verification of vLLM output by comparing with HF
+"""Tests validating the correctness and configuration of prompt_logprobs.
Did you replace the wrong comment here?
Hah, yeah, I clicked on the wrong file, sorry. I'll fix it.
monkeypatch.setenv("VLLM_USE_V1", 1) | ||
monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend) | ||
monkeypatch.setenv("VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS", 1) | ||
llm = LLM(model) |
agree!
I don't think logprobs for CB is a priority before CB is fully working. Currently we have
Description
This PR enables prompt logprobs with static batching, at batch size 1 only. This enables some experimentation and model-evaluation tasks on Spyre hardware.
For static batching, this requires us to warm the model up with only_last_token=False, which passes back the hidden state tensors for the entire (padded) prompt. This is a big performance penalty, and it is also only supported on Spyre cards with batch size 1 currently.
So, this PR introduces an environment flag, VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS, which must be set to 1 to enable prompt logprobs. At bootup, we check that we're running in static batching mode with max batch size 1, and fail to boot otherwise. All requests that ask for prompt_logprobs will be rejected unless prompt logprobs are enabled. This is different from the behavior today, where requests always return [None] for prompt logprobs.
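A usage sketch under those constraints (the prompt and logprob counts are illustrative; the model is the one used in the tests above, and the flag would normally be exported in the shell before launching rather than set inline):

import os
from vllm import LLM, SamplingParams

# Prompt logprobs must be opted into explicitly, and only work with
# static batching at max batch size 1.
os.environ["VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS"] = "1"

llm = LLM(model="ibm-ai-platform/micro-g3.3-8b-instruct-1b")
params = SamplingParams(max_tokens=5, prompt_logprobs=5)
outputs = llm.generate(["Hello darkness my old friend"], params)
# Each RequestOutput carries per-token prompt logprobs when enabled.
print(outputs[0].prompt_logprobs)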