From c62b2e5e90401655bc85e40d80b59cf1ff300a80 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 9 Jan 2025 15:22:59 +0000 Subject: [PATCH] update tests --- tests/parity_fastapi/ls-server.py | 1 + tests/perf_test/bert/server.py | 1 + tests/perf_test/stream/stream_speed/server.py | 1 + tests/test_lit_server.py | 12 ++++++++---- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/parity_fastapi/ls-server.py b/tests/parity_fastapi/ls-server.py index 2a653889..c98fda40 100644 --- a/tests/parity_fastapi/ls-server.py +++ b/tests/parity_fastapi/ls-server.py @@ -64,6 +64,7 @@ def main(batch_size: int, workers_per_device: int): batch_timeout=0.01, timeout=10, workers_per_device=workers_per_device, + use_zmq=True, ) server.run(port=8000, log_level="warning") diff --git a/tests/perf_test/bert/server.py b/tests/perf_test/bert/server.py index 3a7c825d..abbcce2f 100644 --- a/tests/perf_test/bert/server.py +++ b/tests/perf_test/bert/server.py @@ -59,6 +59,7 @@ def main( devices=devices, batch_timeout=batch_timeout, timeout=200, + use_zmq=True, ) server.run(log_level="warning", num_api_servers=4, generate_client_file=False) diff --git a/tests/perf_test/stream/stream_speed/server.py b/tests/perf_test/stream/stream_speed/server.py index 03e5dc80..11ddfaf5 100644 --- a/tests/perf_test/stream/stream_speed/server.py +++ b/tests/perf_test/stream/stream_speed/server.py @@ -21,5 +21,6 @@ def encode_response(self, output_stream): server = ls.LitServer( api, stream=True, + use_zmq=True, ) server.run(port=8000, generate_client_file=False) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index c3fc6637..47e1af19 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -61,9 +61,10 @@ def test_device_identifiers_error(simple_litapi, devices): LitServer(simple_litapi, accelerator="cuda", devices=devices, timeout=10) +@pytest.mark.parametrize("use_zmq", [True, False]) @pytest.mark.asyncio -async def test_stream(simple_stream_api): - server = LitServer(simple_stream_api, stream=True, timeout=10) +async def test_stream(simple_stream_api, use_zmq): + server = LitServer(simple_stream_api, stream=True, timeout=10, use_zmq=use_zmq) expected_output1 = "prompt=Hello generated_output=LitServe is streaming output".lower().replace(" ", "") expected_output2 = "prompt=World generated_output=LitServe is streaming output".lower().replace(" ", "") @@ -84,9 +85,12 @@ async def test_stream(simple_stream_api): ) +@pytest.mark.parametrize("use_zmq", [True, False]) @pytest.mark.asyncio -async def test_batched_stream_server(simple_batched_stream_api): - server = LitServer(simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30) +async def test_batched_stream_server(simple_batched_stream_api, use_zmq): + server = LitServer( + simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30, use_zmq=use_zmq + ) expected_output1 = "Hello LitServe is streaming output".lower().replace(" ", "") expected_output2 = "World LitServe is streaming output".lower().replace(" ", "")