From fcce97384b1d74905fa4b42a36b7c9a0e429e276 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Mon, 16 Dec 2024 18:16:06 +0000 Subject: [PATCH] add tests for continuous batching and Default loops (#396) * add test * update * fix * addt test * update * bump version * Update src/litserve/loops.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --------- Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/litserve/__about__.py | 2 +- src/litserve/loops.py | 11 +-- tests/test_loops.py | 156 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 163 insertions(+), 6 deletions(-) diff --git a/src/litserve/__about__.py b/src/litserve/__about__.py index cef9982b..cdc5141a 100644 --- a/src/litserve/__about__.py +++ b/src/litserve/__about__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.2.6.dev1" +__version__ = "0.2.6.dev2" __author__ = "Lightning-AI et al." __author_email__ = "community@lightning.ai" __license__ = "Apache-2.0" diff --git a/src/litserve/loops.py b/src/litserve/loops.py index 2f6828be..3a086f6e 100644 --- a/src/litserve/loops.py +++ b/src/litserve/loops.py @@ -495,9 +495,6 @@ def __init__(self): self._context = {} def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float): - if max_batch_size <= 1: - raise ValueError("max_batch_size must be greater than 1") - batches, timed_out_uids = collate_requests( lit_api, request_queue, @@ -507,8 +504,10 @@ def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_si return batches, timed_out_uids def get_request(self, request_queue: Queue, timeout: float = 1.0): - response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout) - return response_queue_id, uid, timestamp, x_enc + try: + return request_queue.get(timeout=timeout) + except Empty: + return None def populate_context(self, lit_spec: LitSpec, request: Any): if lit_spec and hasattr(lit_spec, "populate_context"): @@ -751,6 +750,8 @@ def has_finished(self, uid: str, token: str, max_sequence_length: int) -> bool: def add_request(self, uid: str, request: Any, lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> None: """Add a new sequence to active sequences and perform any action before prediction such as filling the cache.""" + if hasattr(lit_api, "add_request"): + lit_api.add_request(uid, request) decoded_request = lit_api.decode_request(request) self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_sequence": []} diff --git a/tests/test_loops.py b/tests/test_loops.py index 96581ef3..76532e5d 100644 --- a/tests/test_loops.py +++ b/tests/test_loops.py @@ -14,6 +14,7 @@ import inspect import io import json +import re import threading import time from queue import Queue @@ -28,8 +29,13 @@ from litserve import LitAPI from litserve.callbacks import CallbackRunner from litserve.loops import ( + ContinuousBatchingLoop, + DefaultLoop, + LitLoop, + Output, _BaseLoop, inference_worker, + notify_timed_out_requests, run_batched_loop, run_batched_streaming_loop, run_single_loop, @@ -495,3 +501,153 @@ def test_get_default_loop(): assert isinstance(loop, ls.loops.BatchedStreamingLoop), ( "BatchedStreamingLoop must be returned when stream=True and max_batch_size>1" ) + + +@pytest.fixture +def lit_loop_setup(): + lit_loop = LitLoop() + lit_api = 
MagicMock(request_timeout=0.1) + request_queue = Queue() + return lit_loop, lit_api, request_queue + + +def test_lit_loop_get_batch_requests(lit_loop_setup): + lit_loop, lit_api, request_queue = lit_loop_setup + request_queue.put((0, "UUID-001", time.monotonic(), {"input": 4.0})) + request_queue.put((0, "UUID-002", time.monotonic(), {"input": 5.0})) + batches, timed_out_uids = lit_loop.get_batch_requests(lit_api, request_queue, 2, 0.001) + assert len(batches) == 2 + assert batches == [(0, "UUID-001", {"input": 4.0}), (0, "UUID-002", {"input": 5.0})] + assert timed_out_uids == [] + + +def test_lit_loop_get_request(lit_loop_setup): + lit_loop, _, request_queue = lit_loop_setup + t = time.monotonic() + request_queue.put((0, "UUID-001", t, {"input": 4.0})) + response_queue_id, uid, timestamp, x_enc = lit_loop.get_request(request_queue, timeout=1) + assert uid == "UUID-001" + assert response_queue_id == 0 + assert timestamp == t + assert x_enc == {"input": 4.0} + assert lit_loop.get_request(request_queue, timeout=0.001) is None + + +def test_lit_loop_put_response(lit_loop_setup): + lit_loop, _, request_queue = lit_loop_setup + response_queues = [Queue()] + lit_loop.put_response(response_queues, 0, "UUID-001", {"output": 16.0}, LitAPIStatus.OK) + response = response_queues[0].get() + assert response == ("UUID-001", ({"output": 16.0}, LitAPIStatus.OK)) + + +def test_notify_timed_out_requests(): + response_queues = [Queue()] + + # Simulate timed out requests + timed_out_uids = [(0, "UUID-001"), (0, "UUID-002")] + + # Call the function to notify timed out requests + notify_timed_out_requests(response_queues, timed_out_uids) + + # Check the responses in the response queue + response_1 = response_queues[0].get() + response_2 = response_queues[0].get() + + assert response_1[0] == "UUID-001" + assert response_1[1][1] == LitAPIStatus.ERROR + assert isinstance(response_1[1][0], HTTPException) + assert response_2[0] == "UUID-002" + assert isinstance(response_2[1][0], HTTPException) + assert response_2[1][1] == LitAPIStatus.ERROR + + +class ContinuousBatchingAPI(ls.LitAPI): + def setup(self, spec: Optional[LitSpec]): + self.model = {} + + def add_request(self, uid: str, request): + self.model[uid] = {"outputs": list(range(5))} + + def decode_request(self, input: str): + return input + + def encode_response(self, output: str): + return {"output": output} + + def step(self, prev_outputs: Optional[List[Output]]) -> List[Output]: + outputs = [] + for k in self.model: + v = self.model[k] + if v["outputs"]: + o = v["outputs"].pop(0) + outputs.append(Output(k, o, LitAPIStatus.OK)) + keys = list(self.model.keys()) + for k in keys: + if k not in [o.uid for o in outputs]: + outputs.append(Output(k, "", LitAPIStatus.FINISH_STREAMING)) + del self.model[k] + return outputs + + +@pytest.mark.parametrize( + ("stream", "max_batch_size", "error_msg"), + [ + (True, 4, "`lit_api.unbatch` must generate values using `yield`."), + (True, 1, "`lit_api.encode_response` must generate values using `yield`."), + ], +) +def test_default_loop_pre_setup_error(stream, max_batch_size, error_msg): + lit_api = ls.test_examples.SimpleLitAPI() + lit_api.stream = stream + lit_api.max_batch_size = max_batch_size + loop = DefaultLoop() + with pytest.raises(ValueError, match=error_msg): + loop.pre_setup(lit_api, None) + + +@pytest.fixture +def continuous_batching_setup(): + lit_api = ContinuousBatchingAPI() + lit_api.stream = True + lit_api.request_timeout = 0.1 + lit_api.pre_setup(2, None) + lit_api.setup(None) + request_queue = Queue() + 
response_queues = [Queue()] + loop = ContinuousBatchingLoop() + return lit_api, loop, request_queue, response_queues + + +def test_continuous_batching_pre_setup(continuous_batching_setup): + lit_api, loop, request_queue, response_queues = continuous_batching_setup + lit_api.stream = False + with pytest.raises( + ValueError, + match=re.escape( + "Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)" + ), + ): + loop.pre_setup(lit_api, None) + + +def test_continuous_batching_run(continuous_batching_setup): + lit_api, loop, request_queue, response_queues = continuous_batching_setup + request_queue.put((0, "UUID-001", time.monotonic(), {"input": "Hello"})) + loop.run(lit_api, None, "cpu", 0, request_queue, response_queues, 2, 0.1, True, {}, NOOP_CB_RUNNER) + + results = [] + for i in range(5): + response = response_queues[0].get() + uid, (response_data, status) = response + o = json.loads(response_data)["output"] + assert o == i + assert status == LitAPIStatus.OK + assert uid == "UUID-001" + results.append(o) + assert results == list(range(5)), "API must return a sequence of numbers from 0 to 4" + response = response_queues[0].get() + uid, (response_data, status) = response + o = json.loads(response_data)["output"] + assert o == "" + assert status == LitAPIStatus.FINISH_STREAMING
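
Note (illustrative, not part of the patch): with this change, LitLoop.get_request()
returns None when the request queue stays empty for the full timeout, instead of
letting queue.Empty propagate to the caller. A minimal caller sketch; the
(response_queue_id, uid, timestamp, x_enc) tuple layout matches the tests above:

    import time
    from queue import Queue

    from litserve.loops import LitLoop

    loop = LitLoop()
    request_queue = Queue()
    request_queue.put((0, "UUID-001", time.monotonic(), {"input": 4.0}))

    item = loop.get_request(request_queue, timeout=0.1)
    if item is None:
        # The queue stayed empty for the whole timeout: skip this scheduling tick.
        pass
    else:
        response_queue_id, uid, timestamp, x_enc = item

Returning a sentinel instead of raising keeps exception handling out of the hot
scheduling path for the common empty-queue case.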
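
Note (illustrative, not part of the patch): the loops.py hunk also forwards each
incoming (uid, request) pair to lit_api.add_request() before decode_request(),
whenever the API defines that method; the hasattr guard keeps the hook opt-in, so
existing LitAPI subclasses are unaffected. A sketch of an API that uses the hook to
seed per-sequence state, modeled on the ContinuousBatchingAPI fixture above (the
EchoBatchingAPI name and the dict-shaped request are assumptions for the example):

    from typing import List, Optional

    import litserve as ls
    from litserve.loops import Output
    from litserve.utils import LitAPIStatus


    class EchoBatchingAPI(ls.LitAPI):
        def setup(self, device):
            self.sequences = {}

        def add_request(self, uid: str, request) -> None:
            # Called by ContinuousBatchingLoop before decode_request(), so
            # per-sequence state can be seeded under the request uid.
            self.sequences[uid] = {"pending": list(str(request.get("input", "")))}

        def decode_request(self, request):
            return request

        def encode_response(self, output):
            return {"output": output}

        def step(self, prev_outputs: Optional[List[Output]]) -> List[Output]:
            # Emit one character per active sequence per step; when a sequence
            # runs out, signal FINISH_STREAMING and drop it, as in the fixture.
            outputs = []
            for uid in list(self.sequences):
                pending = self.sequences[uid]["pending"]
                if pending:
                    outputs.append(Output(uid, pending.pop(0), LitAPIStatus.OK))
                else:
                    outputs.append(Output(uid, "", LitAPIStatus.FINISH_STREAMING))
                    del self.sequences[uid]
            return outputs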