Skip to content

Commit 6bdd8f3

Browse files
authored
test LitServer.run (#248)
* remove dead code * add tests * update test * update * update * merge master * update * add start_server test * update msg * remove dead code * skip windows * revert deadcode removal
1 parent 686db0c commit 6bdd8f3

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

src/litserve/server.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -412,13 +412,17 @@ def run(
412412
if num_api_servers is None:
413413
num_api_servers = len(self.workers)
414414

415-
manager, litserve_workers = self.launch_inference_worker(num_api_servers)
415+
if num_api_servers < 1:
416+
raise ValueError("num_api_servers must be greater than 0")
416417

417418
if sys.platform == "win32":
419+
print("Windows does not support forking. Using threads api_server_worker_type will be set to 'thread'")
418420
api_server_worker_type = "thread"
419421
elif api_server_worker_type is None:
420422
api_server_worker_type = "process"
421423

424+
manager, litserve_workers = self.launch_inference_worker(num_api_servers)
425+
422426
try:
423427
servers = self._start_server(port, num_api_servers, log_level, sockets, api_server_worker_type, **kwargs)
424428
print(f"Swagger UI is available at http://0.0.0.0:{port}/docs")
@@ -447,7 +451,7 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_w
447451
elif uvicorn_worker_type == "thread":
448452
w = threading.Thread(target=server.run, args=(sockets,))
449453
else:
450-
raise ValueError("Invalid value for uvicorn_worker_type. Must be 'process' or 'thread'")
454+
raise ValueError("Invalid value for api_server_worker_type. Must be 'process' or 'thread'")
451455
w.start()
452456
servers.append(w)
453457
return servers

tests/test_lit_server.py

+60
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import asyncio
1515
import pickle
1616
import re
17+
import sys
18+
1719
from asgi_lifespan import LifespanManager
1820
from litserve import LitAPI
1921
from fastapi import Request, Response, HTTPException
@@ -196,6 +198,64 @@ def test_server_run(mock_uvicorn):
196198
mock_uvicorn.Config.assert_called()
197199

198200

201+
@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix")
202+
@patch("litserve.server.uvicorn")
203+
def test_start_server(mock_uvicon):
204+
server = LitServer(ls.examples.TestAPI(), spec=ls.OpenAISpec())
205+
sockets = MagicMock()
206+
server._start_server(8000, 1, "info", sockets, "process")
207+
mock_uvicon.Server.assert_called()
208+
assert server.lit_spec.response_queue_id is not None, "response_queue_id must be generated"
209+
210+
211+
@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix")
212+
@patch("litserve.server.uvicorn")
213+
def test_server_run_with_api_server_worker_type(mock_uvicorn):
214+
api = ls.examples.SimpleLitAPI()
215+
server = ls.LitServer(api, devices=1)
216+
with pytest.raises(ValueError, match=r"Must be 'process' or 'thread'"):
217+
server.run(api_server_worker_type="invalid")
218+
219+
with pytest.raises(ValueError, match=r"must be greater than 0"):
220+
server.run(num_api_servers=0)
221+
222+
server.launch_inference_worker = MagicMock(return_value=[MagicMock(), [MagicMock()]])
223+
server._start_server = MagicMock()
224+
225+
# Running the method to test
226+
server.run(api_server_worker_type=None)
227+
server.launch_inference_worker.assert_called_with(1)
228+
actual = server._start_server.call_args
229+
assert actual[0][4] == "process", "Server should run in process mode"
230+
231+
server.run(api_server_worker_type="thread")
232+
server.launch_inference_worker.assert_called_with(1)
233+
actual = server._start_server.call_args
234+
assert actual[0][4] == "thread", "Server should run in thread mode"
235+
236+
server.run(api_server_worker_type="process")
237+
server.launch_inference_worker.assert_called_with(1)
238+
actual = server._start_server.call_args
239+
assert actual[0][4] == "process", "Server should run in process mode"
240+
241+
server.run(api_server_worker_type="process", num_api_servers=10)
242+
server.launch_inference_worker.assert_called_with(10)
243+
244+
245+
@pytest.mark.skipif(sys.platform != "win32", reason="Test is only for Windows")
246+
@patch("litserve.server.uvicorn")
247+
def test_server_run_windows(mock_uvicorn):
248+
api = ls.examples.SimpleLitAPI()
249+
server = ls.LitServer(api)
250+
server.launch_inference_worker = MagicMock(return_value=[MagicMock(), [MagicMock()]])
251+
server._start_server = MagicMock()
252+
253+
# Running the method to test
254+
server.run(api_server_worker_type=None)
255+
actual = server._start_server.call_args
256+
assert actual[0][4] == "thread", "Windows only supports thread mode"
257+
258+
199259
def test_server_terminate():
200260
server = LitServer(SimpleLitAPI())
201261
mock_manager = MagicMock()

0 commit comments

Comments
 (0)