|
28 | 28 | from litserve import LitAPI
|
29 | 29 | from litserve.callbacks import CallbackRunner
|
30 | 30 | from litserve.loops import (
|
| 31 | + LitLoop, |
31 | 32 | _BaseLoop,
|
32 | 33 | inference_worker,
|
33 | 34 | run_batched_loop,
|
@@ -495,3 +496,41 @@ def test_get_default_loop():
|
495 | 496 | assert isinstance(loop, ls.loops.BatchedStreamingLoop), (
|
496 | 497 | "BatchedStreamingLoop must be returned when stream=True and max_batch_size>1"
|
497 | 498 | )
|
| 499 | + |
| 500 | + |
| 501 | +@pytest.fixture |
| 502 | +def lit_loop_setup(): |
| 503 | + lit_loop = LitLoop() |
| 504 | + lit_api = MagicMock(request_timeout=0.1) |
| 505 | + request_queue = Queue() |
| 506 | + return lit_loop, lit_api, request_queue |
| 507 | + |
| 508 | + |
| 509 | +def test_lit_loop_get_batch_requests(lit_loop_setup): |
| 510 | + lit_loop, lit_api, request_queue = lit_loop_setup |
| 511 | + request_queue.put((0, "UUID-001", time.monotonic(), {"input": 4.0})) |
| 512 | + request_queue.put((0, "UUID-002", time.monotonic(), {"input": 5.0})) |
| 513 | + batches, timed_out_uids = lit_loop.get_batch_requests(lit_api, request_queue, 2, 0.001) |
| 514 | + assert len(batches) == 2 |
| 515 | + assert batches == [(0, "UUID-001", {"input": 4.0}), (0, "UUID-002", {"input": 5.0})] |
| 516 | + assert timed_out_uids == [] |
| 517 | + |
| 518 | + |
| 519 | +def test_lit_loop_get_request(lit_loop_setup): |
| 520 | + lit_loop, _, request_queue = lit_loop_setup |
| 521 | + t = time.monotonic() |
| 522 | + request_queue.put((0, "UUID-001", t, {"input": 4.0})) |
| 523 | + response_queue_id, uid, timestamp, x_enc = lit_loop.get_request(request_queue, timeout=1) |
| 524 | + assert uid == "UUID-001" |
| 525 | + assert response_queue_id == 0 |
| 526 | + assert timestamp == t |
| 527 | + assert x_enc == {"input": 4.0} |
| 528 | + assert lit_loop.get_request(request_queue, timeout=0.001) is None |
| 529 | + |
| 530 | + |
| 531 | +def test_lit_loop_put_response(lit_loop_setup): |
| 532 | + lit_loop, _, request_queue = lit_loop_setup |
| 533 | + response_queues = [Queue()] |
| 534 | + lit_loop.put_response(response_queues, 0, "UUID-001", {"output": 16.0}, LitAPIStatus.OK) |
| 535 | + response = response_queues[0].get() |
| 536 | + assert response == ("UUID-001", ({"output": 16.0}, LitAPIStatus.OK)) |
0 commit comments