Description
I'm trying to use ray.data.llm to do data-parallel (DP) inference across multiple GPUs, following this guide: https://docs.ray.io/en/latest/data/working-with-llms.html
I would like to get the logprobs of the output tokens, but the response does not contain them, even though I set logprobs in sampling_params (I believe this works in native vLLM):
```python
from ray.data.llm import build_llm_processor

guided_decoding_params = dict(choice=["true", "false"])
self.sampling_params = dict(
    max_tokens=16,
    temperature=0,
    stop_token_ids=stop_token_ids,
    guided_decoding=guided_decoding_params,
    logprobs=1,  # requested, but nothing shows up in the output rows
)
processor = build_llm_processor(
    self.vllm_engine_processor_config,
    preprocess=lambda row: dict(
        messages=[
            {"role": "user", "content": self._record_to_prompt(row)},
        ],
        sampling_params=self.sampling_params,
    ),
    postprocess=lambda row: self._get_llm_prediction(row),
)
```
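For context, self.vllm_engine_processor_config above is built roughly like the example in the guide; a minimal sketch (the model name, parallelism, concurrency, and batch size below are placeholders, not my exact values):

```python
from ray.data.llm import vLLMEngineProcessorConfig

vllm_engine_processor_config = vLLMEngineProcessorConfig(
    model_source="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    engine_kwargs={
        "max_model_len": 4096,
        "tensor_parallel_size": 1,  # TP=1; DP comes from running several replicas
    },
    concurrency=4,  # number of vLLM engine replicas (data parallelism)
    batch_size=64,
)
```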
I think it would make sense to pass the logprobs through to the Ray Data response, since vLLM already supports them. If there is a way to access them that I missed, please point it out as well, thanks.
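For comparison, this is roughly how logprobs come back in native vLLM (a minimal sketch with a placeholder model; the point is that SamplingParams(logprobs=1) makes per-token logprobs available on the completion output):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")  # placeholder model
params = SamplingParams(max_tokens=16, temperature=0, logprobs=1)
outputs = llm.generate(["Is the sky blue? Answer true or false."], params)

completion = outputs[0].outputs[0]
print(completion.text)
# completion.logprobs is a list with one entry per generated token, each a dict
# mapping token id -> Logprob; this is what I'd like ray.data.llm to surface.
print(completion.logprobs)
```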
Use case
Being able to access logprobs is important for offline inference with Ray Data, for example to attach a confidence score to each prediction.
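Concretely, if the output row carried the logprobs (the "logprobs" key below is hypothetical, it is exactly what this issue is asking for), a postprocess could turn the constrained "true"/"false" completion into a calibrated score:

```python
import math

def _get_llm_prediction(self, row):
    # Hypothetical: assumes ray.data.llm passes per-token logprobs through
    # under a "logprobs" key alongside the existing "generated_text" column.
    first_token_logprob = row["logprobs"][0]  # logprob of the first generated token
    return {
        "prediction": row["generated_text"].strip(),
        "confidence": math.exp(first_token_logprob),
    }
```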