Commit f28c816

fix openai usage info for non-streaming response (#399)
* fix openai usage
* add test
1 parent 636c9fd commit f28c816

File tree

src/litserve/loops.py
src/litserve/specs/openai.py
tests/test_specs.py

3 files changed: 45 additions, 1 deletion

src/litserve/loops.py

Lines changed: 1 addition & 0 deletions
@@ -919,6 +919,7 @@ def run(
             for uid, response_queue_id in self.response_queue_ids.items():
                 self.put_error_response(response_queues, response_queue_id, uid, e)
             self.response_queue_ids.clear()
+            self.active_sequences.clear()


 def inference_worker(
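
For context, this `clear()` sits in the loop's error path: when a batched step raises, an error response is flushed to every in-flight request and the per-request bookkeeping is reset so the next iteration starts clean. Below is a minimal sketch of that pattern; the queue plumbing and the `put_error_response` signature are simplified assumptions for illustration, not LitServe's actual internals.

from queue import Queue


class StreamingLoopSketch:
    """Illustrative only: mirrors the error-path cleanup added in loops.py."""

    def __init__(self):
        self.response_queue_ids = {}  # uid -> index of the client's response queue
        self.active_sequences = {}    # uid -> partially generated sequence state

    def put_error_response(self, response_queues, response_queue_id, uid, error):
        # Simplified: the real loop wraps the error in a status payload.
        response_queues[response_queue_id].put((uid, error))

    def run_step(self, response_queues, batch_step):
        try:
            batch_step()
        except Exception as e:
            # Flush the error to every request that was in flight...
            for uid, response_queue_id in self.response_queue_ids.items():
                self.put_error_response(response_queues, response_queue_id, uid, e)
            self.response_queue_ids.clear()
            # ...and drop partial sequences too, so a failed batch can't leak
            # stale state into the next one (the one-line fix above).
            self.active_sequences.clear()


if __name__ == "__main__":
    def failing_step():
        raise RuntimeError("boom")

    queues = [Queue()]
    loop = StreamingLoopSketch()
    loop.response_queue_ids["req-1"] = 0
    loop.active_sequences["req-1"] = "partial output"
    loop.run_step(queues, failing_step)
    assert not loop.response_queue_ids and not loop.active_sequences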

src/litserve/specs/openai.py

Lines changed: 1 addition & 1 deletion
@@ -443,6 +443,7 @@ async def non_streaming_completion(self, request: ChatCompletionRequest, generat
                 logger.debug(encoded_response)
                 chat_msg = ChatMessage(**encoded_response)
                 usage = UsageInfo(**encoded_response)
+                usage_infos.append(usage)  # Aggregate usage info across all choices
                 msgs.append(chat_msg.content)
                 if chat_msg.tool_calls:
                     tool_calls = chat_msg.tool_calls
@@ -451,6 +452,5 @@ async def non_streaming_completion(self, request: ChatCompletionRequest, generat
             msg = {"role": "assistant", "content": content, "tool_calls": tool_calls}
             choice = ChatCompletionResponseChoice(index=i, message=msg, finish_reason="stop")
             choices.append(choice)
-            usage_infos.append(usage)  # Only use the last item from encode_response

         return ChatCompletionResponse(model=model, choices=choices, usage=sum(usage_infos))
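
The effect of the change: `usage_infos` now collects one entry per encoded chunk instead of only the last chunk of each choice, and `sum(usage_infos)` folds them into the final usage. Below is a rough model of that aggregation, using a hypothetical `UsageSketch` dataclass in place of LitServe's `UsageInfo`; the addition semantics shown are an assumption for illustration, not the library's code.

from dataclasses import dataclass


@dataclass
class UsageSketch:
    # Field names mirror the keys yielded by the new test below.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    def __add__(self, other: "UsageSketch") -> "UsageSketch":
        return UsageSketch(
            self.prompt_tokens + other.prompt_tokens,
            self.completion_tokens + other.completion_tokens,
            self.total_tokens + other.total_tokens,
        )

    def __radd__(self, other):
        # sum() seeds with the int 0, so pass through on the first element.
        return self if other == 0 else self.__add__(other)


# Before the fix, only the final chunk per choice was appended, so a response
# built from 5 yielded chunks would likely have reported completion_tokens == 1.
# After the fix, every chunk is appended and summed.
chunks = [UsageSketch(prompt_tokens=0, completion_tokens=1, total_tokens=1) for _ in range(5)]
assert sum(chunks).completion_tokens == 5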

tests/test_specs.py

Lines changed: 43 additions & 0 deletions
@@ -89,6 +89,49 @@ async def test_openai_token_usage(api, batch_size, openai_request_data, openai_r
             assert result["usage"] == openai_response_data["usage"]


+class OpenAIWithUsagePerToken(ls.LitAPI):
+    def setup(self, device):
+        self.model = None
+
+    def predict(self, x):
+        for i in range(1, 6):
+            yield {
+                "role": "assistant",
+                "content": f"{i}",
+                "prompt_tokens": 0,
+                "completion_tokens": 1,
+                "total_tokens": 1,
+            }
+
+
+# OpenAIWithUsagePerToken
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("api", "batch_size"),
+    [
+        (OpenAIWithUsagePerToken(), 1),
+    ],
+)
+async def test_openai_per_token_usage(api, batch_size, openai_request_data, openai_response_data):
+    server = ls.LitServer(api, spec=ls.OpenAISpec(), max_batch_size=batch_size, batch_timeout=0.01)
+    with wrap_litserve_start(server) as server:
+        async with LifespanManager(server.app) as manager, AsyncClient(
+            transport=ASGITransport(app=manager.app), base_url="http://test"
+        ) as ac:
+            resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
+            assert resp.status_code == 200, "Status code should be 200"
+            result = resp.json()
+            content = result["choices"][0]["message"]["content"]
+            assert content == "12345", "LitAPI predict response should match with the generated output"
+            assert result["usage"]["completion_tokens"] == 5, "API yields 5 tokens"
+
+            # with streaming
+            openai_request_data["stream"] = True
+            resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
+            assert resp.status_code == 200, "Status code should be 200"
+            assert result["usage"]["completion_tokens"] == 5, "API yields 5 tokens"
+
+
 @pytest.mark.asyncio
 async def test_openai_spec_with_image(openai_request_data_with_image):
     server = ls.LitServer(TestAPI(), spec=OpenAISpec())
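
As a quick end-to-end sanity check outside pytest, one could hit a running LitServe OpenAI-compatible endpoint and inspect the aggregated usage directly. The snippet below is a hypothetical client-side check: the host, port, and model name are assumptions, and the server is expected to be serving an API like `OpenAIWithUsagePerToken` above.

import httpx

# Hypothetical check against a locally running LitServe server with an OpenAI spec.
resp = httpx.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "test",
        "messages": [{"role": "user", "content": "count to five"}],
    },
    timeout=10,
)
resp.raise_for_status()
usage = resp.json()["usage"]
# With this commit, completion_tokens reflects every yielded chunk
# (5 for the per-token test API), not just the last one.
print(usage)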
