Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
aniketmaurya committed Dec 18, 2024
1 parent 06b484a commit af39d42
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions tests/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,49 @@ async def test_openai_token_usage(api, batch_size, openai_request_data, openai_r
assert result["usage"] == openai_response_data["usage"]


class OpenAIWithUsagePerToken(ls.LitAPI):
def setup(self, device):
self.model = None

def predict(self, x):
for i in range(1, 6):
yield {
"role": "assistant",
"content": f"{i}",
"prompt_tokens": 0,
"completion_tokens": 1,
"total_tokens": 1,
}


# OpenAIWithUsagePerToken
@pytest.mark.asyncio
@pytest.mark.parametrize(
("api", "batch_size"),
[
(OpenAIWithUsagePerToken(), 1),
],
)
async def test_openai_per_token_usage(api, batch_size, openai_request_data, openai_response_data):
server = ls.LitServer(api, spec=ls.OpenAISpec(), max_batch_size=batch_size, batch_timeout=0.01)
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(
transport=ASGITransport(app=manager.app), base_url="http://test"
) as ac:
resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
assert resp.status_code == 200, "Status code should be 200"
result = resp.json()
content = result["choices"][0]["message"]["content"]
assert content == "12345", "LitAPI predict response should match with the generated output"
assert result["usage"]["completion_tokens"] == 5, "API yields 5 tokens"

# with streaming
openai_request_data["stream"] = True
resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
assert resp.status_code == 200, "Status code should be 200"
assert result["usage"]["completion_tokens"] == 5, "API yields 5 tokens"


@pytest.mark.asyncio
async def test_openai_spec_with_image(openai_request_data_with_image):
server = ls.LitServer(TestAPI(), spec=OpenAISpec())
Expand Down

0 comments on commit af39d42

Please sign in to comment.