From d597528951c1db5c98f62a3c3e62e74d323dff3d Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 28 Aug 2024 17:00:08 +0100 Subject: [PATCH] test `collate_request` w batch_timeout and batch_size (#238) * add batch timeout test * fix batch timeout issue * test extreme cases * update * fix --- src/litserve/server.py | 12 ++++++++++++ tests/test_batch.py | 26 +++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index e0935c4b..506d5873 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -92,6 +92,18 @@ def collate_requests( end_time = entered_at + batch_timeout apply_timeout = lit_api.request_timeout not in (-1, False) + if batch_timeout == 0: + while len(payloads) < max_batch_size: + try: + response_queue_id, uid, timestamp, x_enc = request_queue.get_nowait() + if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout: + timed_out_uids.append((response_queue_id, uid)) + else: + payloads.append((response_queue_id, uid, x_enc)) + except Empty: + break + return payloads, timed_out_uids + while time.monotonic() < end_time and len(payloads) < max_batch_size: remaining_time = end_time - time.monotonic() if remaining_time <= 0: diff --git a/tests/test_batch.py b/tests/test_batch.py index f027b899..048bc3c7 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -24,8 +24,9 @@ from httpx import AsyncClient from litserve import LitAPI, LitServer -from litserve.server import run_batched_loop +from litserve.server import run_batched_loop, collate_requests from litserve.utils import wrap_litserve_start +import litserve as ls class Linear(nn.Module): @@ -171,3 +172,26 @@ def test_batched_loop(): lit_api_mock.batch.assert_called_once() lit_api_mock.batch.assert_called_once_with([4.0, 5.0]) lit_api_mock.unbatch.assert_called_once() + + +@pytest.mark.parametrize( + ("batch_timeout", "batch_size"), + [ + pytest.param(0, 2), + pytest.param(0, 1000), + pytest.param(0.01, 2), + pytest.param(1000, 2), + pytest.param(0.01, 1000), + ], +) +def test_collate_requests(batch_timeout, batch_size): + api = ls.examples.SimpleBatchedAPI() + api.request_timeout = 5 + request_queue = Queue() + for i in range(batch_size): + request_queue.put((i, f"uuid-abc-{i}", time.monotonic(), i)) # response_queue_id, uid, timestamp, x_enc + payloads, timed_out_uids = collate_requests( + api, request_queue, max_batch_size=batch_size, batch_timeout=batch_timeout + ) + assert len(payloads) == batch_size, f"Should have {batch_size} payloads, got {len(payloads)}" + assert len(timed_out_uids) == 0, "No timed out uids"