Skip to content

Commit

Permalink
test collate_request w batch_timeout and batch_size (#238)
Browse files Browse the repository at this point in the history
* add batch timeout test

* fix batch timeout issue

* test extreme cases

* update

* fix
  • Loading branch information
aniketmaurya authored Aug 28, 2024
1 parent ebccad8 commit d597528
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
12 changes: 12 additions & 0 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ def collate_requests(
end_time = entered_at + batch_timeout
apply_timeout = lit_api.request_timeout not in (-1, False)

if batch_timeout == 0:
while len(payloads) < max_batch_size:
try:
response_queue_id, uid, timestamp, x_enc = request_queue.get_nowait()
if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout:
timed_out_uids.append((response_queue_id, uid))
else:
payloads.append((response_queue_id, uid, x_enc))
except Empty:
break
return payloads, timed_out_uids

while time.monotonic() < end_time and len(payloads) < max_batch_size:
remaining_time = end_time - time.monotonic()
if remaining_time <= 0:
Expand Down
26 changes: 25 additions & 1 deletion tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
from httpx import AsyncClient

from litserve import LitAPI, LitServer
from litserve.server import run_batched_loop
from litserve.server import run_batched_loop, collate_requests
from litserve.utils import wrap_litserve_start
import litserve as ls


class Linear(nn.Module):
Expand Down Expand Up @@ -171,3 +172,26 @@ def test_batched_loop():
lit_api_mock.batch.assert_called_once()
lit_api_mock.batch.assert_called_once_with([4.0, 5.0])
lit_api_mock.unbatch.assert_called_once()


@pytest.mark.parametrize(
("batch_timeout", "batch_size"),
[
pytest.param(0, 2),
pytest.param(0, 1000),
pytest.param(0.01, 2),
pytest.param(1000, 2),
pytest.param(0.01, 1000),
],
)
def test_collate_requests(batch_timeout, batch_size):
api = ls.examples.SimpleBatchedAPI()
api.request_timeout = 5
request_queue = Queue()
for i in range(batch_size):
request_queue.put((i, f"uuid-abc-{i}", time.monotonic(), i)) # response_queue_id, uid, timestamp, x_enc
payloads, timed_out_uids = collate_requests(
api, request_queue, max_batch_size=batch_size, batch_timeout=batch_timeout
)
assert len(payloads) == batch_size, f"Should have {batch_size} payloads, got {len(payloads)}"
assert len(timed_out_uids) == 0, "No timed out uids"

0 comments on commit d597528

Please sign in to comment.