Commit 8b9b498
test OpenAISpec
Parent: 07d60df

2 files changed: +101, -3 lines

src/litserve/examples/simple_example.py (+3, -3)

@@ -1,3 +1,5 @@
+import torch
+
 import litserve as ls


@@ -54,8 +56,6 @@ def forward(self, x):
         self.model = Linear().to(device)

     def decode_request(self, request):
-        import torch
-
         # get the input and create a 1D tensor on the correct device
         content = request["input"]
         return torch.tensor([content], device=self.device)
@@ -95,7 +95,7 @@ def decode_request(self, request):

     def predict(self, x):
         for i in range(3):
-            yield self.model(i, x.encode("utf-8").decode())
+            yield self.model(i, x)

     def encode_response(self, output_stream):
         for output in output_stream:
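
The first two hunks hoist "import torch" to module scope instead of importing it lazily inside decode_request, and the third hunk drops a pointless encode("utf-8").decode() round trip so predict streams the decoded input as-is. A minimal sketch of the resulting SimpleStreamAPI, consistent with this diff and with the new stream test added below (the lambda "model" is an assumption about the example's internals, not copied from the repo):

import litserve as ls


class SimpleStreamAPI(ls.LitAPI):
    def setup(self, device):
        # Assumed toy "model": tags the input with the current step index.
        self.model = lambda step, x: f"{step}: {x}"

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        # After the fix: stream the decoded input directly, no string round trip.
        for i in range(3):
            yield self.model(i, x)

    def encode_response(self, output_stream):
        for output in output_stream:
            yield {"output": output}

With that shape, predict(4) yields "0: 4", "1: 4", "2: 4" and encode_response wraps each chunk as {"output": ...}, which is exactly what test_simple_stream_api_without_server asserts.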

tests/test_examples.py (+98)

@@ -1,6 +1,14 @@
 import pytest
+import torch.nn
 from asgi_lifespan import LifespanManager
 from httpx import AsyncClient
+
+from litserve.examples.openai_spec_example import (
+    OpenAIWithUsage,
+    OpenAIWithUsageEncodeResponse,
+    OpenAIBatchingWithUsage,
+)
+from litserve.examples.simple_example import SimpleStreamAPI
 from litserve.utils import wrap_litserve_start
 import litserve as ls

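The three OpenAI* classes imported here are litserve's OpenAI-spec examples. In normal use they sit behind ls.LitServer with ls.OpenAISpec; the tests in the next hunk bypass the HTTP layer and call the LitAPI hooks (setup, predict, encode_response) in-process. A serving sketch for context (the port and script layout are assumptions):

# Sketch: serving one of the OpenAI-spec examples over HTTP.
# The tests below skip this and call the hooks directly.
import litserve as ls
from litserve.examples.openai_spec_example import OpenAIWithUsage

if __name__ == "__main__":
    server = ls.LitServer(OpenAIWithUsage(), spec=ls.OpenAISpec())
    server.run(port=8000)  # assumed port
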
@@ -33,3 +41,93 @@ async def test_simple_api():
     async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
         response = await ac.post("/predict", json={"input": 4.0})
         assert response.json() == {"output": 16.0}
+
+
+@pytest.mark.asyncio()
+async def test_simple_api_without_server():
+    api = ls.examples.SimpleLitAPI()
+    api.setup(None)
+    assert api.model is not None, "Model should be loaded after setup"
+    assert api.predict(4) == 16, "Model should be able to predict"
+
+
+@pytest.mark.asyncio()
+async def test_simple_pytorch_api_without_server():
+    api = ls.examples.SimpleTorchAPI()
+    api.setup("cpu")
+    assert api.model is not None, "Model should be loaded after setup"
+    assert isinstance(api.model, torch.nn.Module)
+    assert api.decode_request({"input": 4}) == 4, "Request should be decoded"
+    assert api.predict(torch.Tensor([4])).cpu() == 9, "Model should be able to predict"
+    assert api.encode_response(9) == {"output": 9}, "Response should be encoded"
+
+
+@pytest.mark.asyncio()
+async def test_simple_stream_api_without_server():
+    api = SimpleStreamAPI()
+    api.setup(None)
+    assert api.model is not None, "Model should be loaded after setup"
+    assert api.decode_request({"input": 4}) == 4, "Request should be decoded"
+    assert list(api.predict(4)) == ["0: 4", "1: 4", "2: 4"], "Model should be able to predict"
+    assert list(api.encode_response(["0: 4", "1: 4", "2: 4"])) == [
+        {"output": "0: 4"},
+        {"output": "1: 4"},
+        {"output": "2: 4"},
+    ], "Response should be encoded"
+
+
+@pytest.mark.asyncio()
+async def test_openai_with_usage():
+    api = OpenAIWithUsage()
+    api.setup(None)
+    response = list(api.predict("10 + 6"))
+    assert response == [
+        {
+            "role": "assistant",
+            "content": "10 + 6 is equal to 16.",
+            "prompt_tokens": 25,
+            "completion_tokens": 10,
+            "total_tokens": 35,
+        }
+    ], "Response should match expected output"
+
+
+@pytest.mark.asyncio()
+async def test_openai_with_usage_encode_response():
+    api = OpenAIWithUsageEncodeResponse()
+    api.setup(None)
+    response = list(api.predict("10 + 6"))
+    encoded_response = list(api.encode_response(response))
+    assert encoded_response == [
+        {"role": "assistant", "content": "10"},
+        {"role": "assistant", "content": " +"},
+        {"role": "assistant", "content": " "},
+        {"role": "assistant", "content": "6"},
+        {"role": "assistant", "content": " is"},
+        {"role": "assistant", "content": " equal"},
+        {"role": "assistant", "content": " to"},
+        {"role": "assistant", "content": " "},
+        {"role": "assistant", "content": "16"},
+        {"role": "assistant", "content": "."},
+        {"role": "assistant", "content": "", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35},
+    ], "Encoded response should match expected output"
+
+
+@pytest.mark.asyncio()
+async def test_openai_batching_with_usage():
+    api = OpenAIBatchingWithUsage()
+    api.setup(None)
+    inputs = ["10 + 6", "10 + 6"]
+    batched_response = list(api.predict(inputs))
+    assert batched_response == [["10 + 6 is equal to 16."] * 2], "Batched response should match expected output"
+    encoded_response = list(api.encode_response(batched_response, [{"temperature": 1.0}, {"temperature": 1.0}]))
+    assert encoded_response == [
+        [
+            {"role": "assistant", "content": "10 + 6 is equal to 16."},
+            {"role": "assistant", "content": "10 + 6 is equal to 16."},
+        ],
+        [
+            {"role": "assistant", "content": "", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35},
+            {"role": "assistant", "content": "", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35},
+        ],
+    ], "Encoded batched response should match expected output"
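
For reference, test_openai_with_usage expects predict to yield a single assistant message that already carries token-usage counts. A sketch with that shape, mirroring the asserted output (the hard-coded reply and token counts come from the test; treating the example as model-free is an assumption about its internals):

import litserve as ls


class OpenAIWithUsage(ls.LitAPI):
    def setup(self, device):
        # The example needs no real model; the reply below is hard-coded.
        self.model = "canned-response"

    def predict(self, x):
        # One chunk carrying both the message and the usage counters,
        # matching the dict asserted in test_openai_with_usage.
        yield {
            "role": "assistant",
            "content": "10 + 6 is equal to 16.",
            "prompt_tokens": 25,
            "completion_tokens": 10,
            "total_tokens": 35,
        }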

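The new tests are all decorated with @pytest.mark.asyncio(), so pytest-asyncio must be installed alongside asgi_lifespan and httpx (the pre-existing async test already depends on it). To run only the additions:

pytest tests/test_examples.py -k "without_server or openai" -v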