import pytest
+import torch.nn
from asgi_lifespan import LifespanManager
from httpx import AsyncClient
+
+from litserve.examples.openai_spec_example import (
+    OpenAIWithUsage,
+    OpenAIWithUsageEncodeResponse,
+    OpenAIBatchingWithUsage,
+)
+from litserve.examples.simple_example import SimpleStreamAPI
from litserve.utils import wrap_litserve_start
import litserve as ls

@@ -33,3 +41,93 @@ async def test_simple_api():
        async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
            response = await ac.post("/predict", json={"input": 4.0})
            assert response.json() == {"output": 16.0}
+
+
+@pytest.mark.asyncio()
+async def test_simple_api_without_server():
+    api = ls.examples.SimpleLitAPI()
+    api.setup(None)
+    assert api.model is not None, "Model should be loaded after setup"
+    assert api.predict(4) == 16, "Model should be able to predict"
+
+
+@pytest.mark.asyncio()
+async def test_simple_pytorch_api_without_server():
+    api = ls.examples.SimpleTorchAPI()
+    api.setup("cpu")
+    assert api.model is not None, "Model should be loaded after setup"
+    assert isinstance(api.model, torch.nn.Module)
+    assert api.decode_request({"input": 4}) == 4, "Request should be decoded"
+    assert api.predict(torch.Tensor([4])).cpu() == 9, "Model should be able to predict"
+    assert api.encode_response(9) == {"output": 9}, "Response should be encoded"
+
+
+@pytest.mark.asyncio()
+async def test_simple_stream_api_without_server():
+    api = SimpleStreamAPI()
+    api.setup(None)
+    assert api.model is not None, "Model should be loaded after setup"
+    assert api.decode_request({"input": 4}) == 4, "Request should be decoded"
+    assert list(api.predict(4)) == ["0: 4", "1: 4", "2: 4"], "Model should be able to predict"
+    assert list(api.encode_response(["0: 4", "1: 4", "2: 4"])) == [
+        {"output": "0: 4"},
+        {"output": "1: 4"},
+        {"output": "2: 4"},
+    ], "Response should be encoded"
+
+
+@pytest.mark.asyncio()
+async def test_openai_with_usage():
+    api = OpenAIWithUsage()
+    api.setup(None)
+    response = list(api.predict("10 + 6"))
+    assert response == [
+        {
+            "role": "assistant",
+            "content": "10 + 6 is equal to 16.",
+            "prompt_tokens": 25,
+            "completion_tokens": 10,
+            "total_tokens": 35,
+        }
+    ], "Response should match expected output"
+
+
+@pytest.mark.asyncio()
+async def test_openai_with_usage_encode_response():
+    api = OpenAIWithUsageEncodeResponse()
+    api.setup(None)
+    response = list(api.predict("10 + 6"))
+    encoded_response = list(api.encode_response(response))
+    # The response streams token by token; the final chunk has empty
+    # content and carries the usage stats.
+    assert encoded_response == [
+        {"role": "assistant", "content": "10"},
+        {"role": "assistant", "content": " +"},
+        {"role": "assistant", "content": " "},
+        {"role": "assistant", "content": "6"},
+        {"role": "assistant", "content": " is"},
+        {"role": "assistant", "content": " equal"},
+        {"role": "assistant", "content": " to"},
+        {"role": "assistant", "content": " "},
+        {"role": "assistant", "content": "16"},
+        {"role": "assistant", "content": "."},
+        {"role": "assistant", "content": "", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35},
+    ], "Encoded response should match expected output"
+
+
+@pytest.mark.asyncio()
+async def test_openai_batching_with_usage():
+    api = OpenAIBatchingWithUsage()
+    api.setup(None)
+    inputs = ["10 + 6", "10 + 6"]
+    batched_response = list(api.predict(inputs))
+    assert batched_response == [["10 + 6 is equal to 16."] * 2], "Batched response should match expected output"
+    encoded_response = list(api.encode_response(batched_response, [{"temperature": 1.0}, {"temperature": 1.0}]))
+    # The first yielded list holds the content for both batched requests;
+    # the last holds their usage-only chunks.
+    assert encoded_response == [
+        [
+            {"role": "assistant", "content": "10 + 6 is equal to 16."},
+            {"role": "assistant", "content": "10 + 6 is equal to 16."},
+        ],
+        [
+            {"role": "assistant", "content": "", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35},
+            {"role": "assistant", "content": "", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35},
+        ],
+    ], "Encoded batched response should match expected output"
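Note: the streaming assertions above depend on the SimpleStreamAPI example. A minimal sketch consistent with those assertions, assuming the standard ls.LitAPI hooks (illustrative only, not the actual litserve.examples.simple_example source):

import litserve as ls


class SimpleStreamAPI(ls.LitAPI):
    def setup(self, device):
        # A trivial "model" that tags each streamed chunk with its index.
        self.model = lambda i, x: f"{i}: {x}"

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        # Stream three chunks, e.g. "0: 4", "1: 4", "2: 4" for input 4.
        for i in range(3):
            yield self.model(i, x)

    def encode_response(self, output_stream):
        # Wrap each streamed chunk in the response schema the test expects.
        for chunk in output_stream:
            yield {"output": chunk}

Because the tests call setup/decode_request/predict/encode_response directly, no HTTP server or client is needed.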
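Likewise, a hypothetical sketch of the usage-reporting pattern that test_openai_with_usage exercises (the real class lives in litserve.examples.openai_spec_example; the names and values here simply mirror the test's assertions):

import litserve as ls


class OpenAIWithUsage(ls.LitAPI):
    def setup(self, device):
        self.model = None

    def predict(self, x):
        # Yield one OpenAI-style message with token accounting attached,
        # matching the single-chunk response the test asserts.
        yield {
            "role": "assistant",
            "content": "10 + 6 is equal to 16.",
            "prompt_tokens": 25,
            "completion_tokens": 10,
            "total_tokens": 35,
        }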