|
14 | 14 | import asyncio
|
15 | 15 | import pickle
|
16 | 16 | import re
|
| 17 | +import sys |
| 18 | + |
17 | 19 | from asgi_lifespan import LifespanManager
|
18 | 20 | from litserve import LitAPI
|
19 | 21 | from fastapi import Request, Response, HTTPException
|
@@ -196,6 +198,64 @@ def test_server_run(mock_uvicorn):
|
196 | 198 | mock_uvicorn.Config.assert_called()
|
197 | 199 |
|
198 | 200 |
|
| 201 | +@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix") |
| 202 | +@patch("litserve.server.uvicorn") |
| 203 | +def test_start_server(mock_uvicon): |
| 204 | + server = LitServer(ls.examples.TestAPI(), spec=ls.OpenAISpec()) |
| 205 | + sockets = MagicMock() |
| 206 | + server._start_server(8000, 1, "info", sockets, "process") |
| 207 | + mock_uvicon.Server.assert_called() |
| 208 | + assert server.lit_spec.response_queue_id is not None, "response_queue_id must be generated" |
| 209 | + |
| 210 | + |
| 211 | +@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix") |
| 212 | +@patch("litserve.server.uvicorn") |
| 213 | +def test_server_run_with_api_server_worker_type(mock_uvicorn): |
| 214 | + api = ls.examples.SimpleLitAPI() |
| 215 | + server = ls.LitServer(api, devices=1) |
| 216 | + with pytest.raises(ValueError, match=r"Must be 'process' or 'thread'"): |
| 217 | + server.run(api_server_worker_type="invalid") |
| 218 | + |
| 219 | + with pytest.raises(ValueError, match=r"must be greater than 0"): |
| 220 | + server.run(num_api_servers=0) |
| 221 | + |
| 222 | + server.launch_inference_worker = MagicMock(return_value=[MagicMock(), [MagicMock()]]) |
| 223 | + server._start_server = MagicMock() |
| 224 | + |
| 225 | + # Running the method to test |
| 226 | + server.run(api_server_worker_type=None) |
| 227 | + server.launch_inference_worker.assert_called_with(1) |
| 228 | + actual = server._start_server.call_args |
| 229 | + assert actual[0][4] == "process", "Server should run in process mode" |
| 230 | + |
| 231 | + server.run(api_server_worker_type="thread") |
| 232 | + server.launch_inference_worker.assert_called_with(1) |
| 233 | + actual = server._start_server.call_args |
| 234 | + assert actual[0][4] == "thread", "Server should run in thread mode" |
| 235 | + |
| 236 | + server.run(api_server_worker_type="process") |
| 237 | + server.launch_inference_worker.assert_called_with(1) |
| 238 | + actual = server._start_server.call_args |
| 239 | + assert actual[0][4] == "process", "Server should run in process mode" |
| 240 | + |
| 241 | + server.run(api_server_worker_type="process", num_api_servers=10) |
| 242 | + server.launch_inference_worker.assert_called_with(10) |
| 243 | + |
| 244 | + |
| 245 | +@pytest.mark.skipif(sys.platform != "win32", reason="Test is only for Windows") |
| 246 | +@patch("litserve.server.uvicorn") |
| 247 | +def test_server_run_windows(mock_uvicorn): |
| 248 | + api = ls.examples.SimpleLitAPI() |
| 249 | + server = ls.LitServer(api) |
| 250 | + server.launch_inference_worker = MagicMock(return_value=[MagicMock(), [MagicMock()]]) |
| 251 | + server._start_server = MagicMock() |
| 252 | + |
| 253 | + # Running the method to test |
| 254 | + server.run(api_server_worker_type=None) |
| 255 | + actual = server._start_server.call_args |
| 256 | + assert actual[0][4] == "thread", "Windows only supports thread mode" |
| 257 | + |
| 258 | + |
199 | 259 | def test_server_terminate():
|
200 | 260 | server = LitServer(SimpleLitAPI())
|
201 | 261 | mock_manager = MagicMock()
|
|
0 commit comments