diff --git a/src/litserve/loops.py b/src/litserve/loops.py
index 3a086f6e..93162842 100644
--- a/src/litserve/loops.py
+++ b/src/litserve/loops.py
@@ -919,6 +919,7 @@ def run(
             for uid, response_queue_id in self.response_queue_ids.items():
                 self.put_error_response(response_queues, response_queue_id, uid, e)
             self.response_queue_ids.clear()
+            self.active_sequences.clear()
 
 
 def inference_worker(
diff --git a/src/litserve/specs/openai.py b/src/litserve/specs/openai.py
index 90a63618..bf783ad5 100644
--- a/src/litserve/specs/openai.py
+++ b/src/litserve/specs/openai.py
@@ -443,6 +443,7 @@ async def non_streaming_completion(self, request: ChatCompletionRequest, generat
                 logger.debug(encoded_response)
                 chat_msg = ChatMessage(**encoded_response)
                 usage = UsageInfo(**encoded_response)
+                usage_infos.append(usage)  # Aggregate usage info across all choices
                 msgs.append(chat_msg.content)
                 if chat_msg.tool_calls:
                     tool_calls = chat_msg.tool_calls
@@ -451,6 +452,5 @@ async def non_streaming_completion(self, request: ChatCompletionRequest, generat
             msg = {"role": "assistant", "content": content, "tool_calls": tool_calls}
             choice = ChatCompletionResponseChoice(index=i, message=msg, finish_reason="stop")
             choices.append(choice)
-            usage_infos.append(usage)  # Only use the last item from encode_response
 
         return ChatCompletionResponse(model=model, choices=choices, usage=sum(usage_infos))
diff --git a/tests/test_specs.py b/tests/test_specs.py
index 78b71980..ff4a77fc 100644
--- a/tests/test_specs.py
+++ b/tests/test_specs.py
@@ -89,6 +89,49 @@ async def test_openai_token_usage(api, batch_size, openai_request_data, openai_r
         assert result["usage"] == openai_response_data["usage"]
 
 
+class OpenAIWithUsagePerToken(ls.LitAPI):
+    def setup(self, device):
+        self.model = None
+
+    def predict(self, x):
+        for i in range(1, 6):
+            yield {
+                "role": "assistant",
+                "content": f"{i}",
+                "prompt_tokens": 0,
+                "completion_tokens": 1,
+                "total_tokens": 1,
+            }
+
+
+# OpenAIWithUsagePerToken
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("api", "batch_size"),
+    [
+        (OpenAIWithUsagePerToken(), 1),
+    ],
+)
+async def test_openai_per_token_usage(api, batch_size, openai_request_data, openai_response_data):
+    server = ls.LitServer(api, spec=ls.OpenAISpec(), max_batch_size=batch_size, batch_timeout=0.01)
+    with wrap_litserve_start(server) as server:
+        async with LifespanManager(server.app) as manager, AsyncClient(
+            transport=ASGITransport(app=manager.app), base_url="http://test"
+        ) as ac:
+            resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
+            assert resp.status_code == 200, "Status code should be 200"
+            result = resp.json()
+            content = result["choices"][0]["message"]["content"]
+            assert content == "12345", "LitAPI predict response should match with the generated output"
+            assert result["usage"]["completion_tokens"] == 5, "API yields 5 tokens"
+
+            # with streaming
+            openai_request_data["stream"] = True
+            resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
+            assert resp.status_code == 200, "Status code should be 200"
+            assert result["usage"]["completion_tokens"] == 5, "API yields 5 tokens"
+
+
 @pytest.mark.asyncio
 async def test_openai_spec_with_image(openai_request_data_with_image):
     server = ls.LitServer(TestAPI(), spec=OpenAISpec())