add tests for continuous batching and Default loops (#396)
* add test

* update

* fix

* add test

* update

* bump version

* Update src/litserve/loops.py

Co-authored-by: Jirka Borovec <[email protected]>

---------

Co-authored-by: Jirka Borovec <[email protected]>
aniketmaurya and Borda authored Dec 16, 2024
1 parent e71f38b commit fcce973
Showing 3 changed files with 163 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/litserve/__about__.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-__version__ = "0.2.6.dev1"
+__version__ = "0.2.6.dev2"
 __author__ = "Lightning-AI et al."
 __author_email__ = "[email protected]"
 __license__ = "Apache-2.0"
11 changes: 6 additions & 5 deletions src/litserve/loops.py

@@ -495,9 +495,6 @@ def __init__(self):
         self._context = {}
 
     def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
-        if max_batch_size <= 1:
-            raise ValueError("max_batch_size must be greater than 1")
-
         batches, timed_out_uids = collate_requests(
             lit_api,
             request_queue,
@@ -507,8 +504,10 @@ def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
         return batches, timed_out_uids
 
     def get_request(self, request_queue: Queue, timeout: float = 1.0):
-        response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout)
-        return response_queue_id, uid, timestamp, x_enc
+        try:
+            return request_queue.get(timeout=timeout)
+        except Empty:
+            return None
 
     def populate_context(self, lit_spec: LitSpec, request: Any):
         if lit_spec and hasattr(lit_spec, "populate_context"):
@@ -751,6 +750,8 @@ def has_finished(self, uid: str, token: str, max_sequence_length: int) -> bool:
 
     def add_request(self, uid: str, request: Any, lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> None:
         """Add a new sequence to active sequences and perform any action before prediction such as filling the cache."""
+        if hasattr(lit_api, "add_request"):
+            lit_api.add_request(uid, request)
         decoded_request = lit_api.decode_request(request)
         self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_sequence": []}
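Two behavioral notes on the loops.py hunks above: `LitLoop.get_request` now returns `None` when the queue is empty instead of letting `queue.Empty` propagate, and `ContinuousBatchingLoop.add_request` forwards each incoming request to a user-defined `lit_api.add_request(uid, request)` hook when one exists. Below is a minimal sketch of the new polling contract; the 4-tuple layout on the queue is taken from the new tests further down.

```python
import time
from queue import Queue

from litserve.loops import LitLoop

# Sketch of the new return contract: an empty queue yields None instead of
# raising queue.Empty, so a polling caller branches on the return value.
loop = LitLoop()
request_queue: Queue = Queue()
request_queue.put((0, "UUID-001", time.monotonic(), {"input": 4.0}))

item = loop.get_request(request_queue, timeout=0.001)
if item is not None:
    # The queued item is a (response_queue_id, uid, timestamp, x_enc) 4-tuple.
    response_queue_id, uid, timestamp, x_enc = item
```

Returning `None` keeps the worker loop's hot path free of try/except handling when the queue is merely idle.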
156 changes: 156 additions & 0 deletions tests/test_loops.py

@@ -14,6 +14,7 @@
 import inspect
 import io
 import json
+import re
 import threading
 import time
 from queue import Queue
@@ -28,8 +29,13 @@
 from litserve import LitAPI
 from litserve.callbacks import CallbackRunner
 from litserve.loops import (
+    ContinuousBatchingLoop,
+    DefaultLoop,
+    LitLoop,
+    Output,
     _BaseLoop,
     inference_worker,
+    notify_timed_out_requests,
     run_batched_loop,
     run_batched_streaming_loop,
     run_single_loop,
@@ -495,3 +501,153 @@ def test_get_default_loop():
     assert isinstance(loop, ls.loops.BatchedStreamingLoop), (
         "BatchedStreamingLoop must be returned when stream=True and max_batch_size>1"
     )
+
+
+@pytest.fixture
+def lit_loop_setup():
+    lit_loop = LitLoop()
+    lit_api = MagicMock(request_timeout=0.1)
+    request_queue = Queue()
+    return lit_loop, lit_api, request_queue
+
+
+def test_lit_loop_get_batch_requests(lit_loop_setup):
+    lit_loop, lit_api, request_queue = lit_loop_setup
+    request_queue.put((0, "UUID-001", time.monotonic(), {"input": 4.0}))
+    request_queue.put((0, "UUID-002", time.monotonic(), {"input": 5.0}))
+    batches, timed_out_uids = lit_loop.get_batch_requests(lit_api, request_queue, 2, 0.001)
+    assert len(batches) == 2
+    assert batches == [(0, "UUID-001", {"input": 4.0}), (0, "UUID-002", {"input": 5.0})]
+    assert timed_out_uids == []
+
+
+def test_lit_loop_get_request(lit_loop_setup):
+    lit_loop, _, request_queue = lit_loop_setup
+    t = time.monotonic()
+    request_queue.put((0, "UUID-001", t, {"input": 4.0}))
+    response_queue_id, uid, timestamp, x_enc = lit_loop.get_request(request_queue, timeout=1)
+    assert uid == "UUID-001"
+    assert response_queue_id == 0
+    assert timestamp == t
+    assert x_enc == {"input": 4.0}
+    assert lit_loop.get_request(request_queue, timeout=0.001) is None
+
+
+def test_lit_loop_put_response(lit_loop_setup):
+    lit_loop, _, request_queue = lit_loop_setup
+    response_queues = [Queue()]
+    lit_loop.put_response(response_queues, 0, "UUID-001", {"output": 16.0}, LitAPIStatus.OK)
+    response = response_queues[0].get()
+    assert response == ("UUID-001", ({"output": 16.0}, LitAPIStatus.OK))
+
+
+def test_notify_timed_out_requests():
+    response_queues = [Queue()]
+
+    # Simulate timed out requests
+    timed_out_uids = [(0, "UUID-001"), (0, "UUID-002")]
+
+    # Call the function to notify timed out requests
+    notify_timed_out_requests(response_queues, timed_out_uids)
+
+    # Check the responses in the response queue
+    response_1 = response_queues[0].get()
+    response_2 = response_queues[0].get()
+
+    assert response_1[0] == "UUID-001"
+    assert response_1[1][1] == LitAPIStatus.ERROR
+    assert isinstance(response_1[1][0], HTTPException)
+    assert response_2[0] == "UUID-002"
+    assert isinstance(response_2[1][0], HTTPException)
+    assert response_2[1][1] == LitAPIStatus.ERROR
+
+
+class ContinuousBatchingAPI(ls.LitAPI):
+    def setup(self, spec: Optional[LitSpec]):
+        self.model = {}
+
+    def add_request(self, uid: str, request):
+        self.model[uid] = {"outputs": list(range(5))}
+
+    def decode_request(self, input: str):
+        return input
+
+    def encode_response(self, output: str):
+        return {"output": output}
+
+    def step(self, prev_outputs: Optional[List[Output]]) -> List[Output]:
+        outputs = []
+        for k in self.model:
+            v = self.model[k]
+            if v["outputs"]:
+                o = v["outputs"].pop(0)
+                outputs.append(Output(k, o, LitAPIStatus.OK))
+        keys = list(self.model.keys())
+        for k in keys:
+            if k not in [o.uid for o in outputs]:
+                outputs.append(Output(k, "", LitAPIStatus.FINISH_STREAMING))
+                del self.model[k]
+        return outputs
+
+
+@pytest.mark.parametrize(
+    ("stream", "max_batch_size", "error_msg"),
+    [
+        (True, 4, "`lit_api.unbatch` must generate values using `yield`."),
+        (True, 1, "`lit_api.encode_response` must generate values using `yield`."),
+    ],
+)
+def test_default_loop_pre_setup_error(stream, max_batch_size, error_msg):
+    lit_api = ls.test_examples.SimpleLitAPI()
+    lit_api.stream = stream
+    lit_api.max_batch_size = max_batch_size
+    loop = DefaultLoop()
+    with pytest.raises(ValueError, match=error_msg):
+        loop.pre_setup(lit_api, None)
+
+
+@pytest.fixture
+def continuous_batching_setup():
+    lit_api = ContinuousBatchingAPI()
+    lit_api.stream = True
+    lit_api.request_timeout = 0.1
+    lit_api.pre_setup(2, None)
+    lit_api.setup(None)
+    request_queue = Queue()
+    response_queues = [Queue()]
+    loop = ContinuousBatchingLoop()
+    return lit_api, loop, request_queue, response_queues
+
+
+def test_continuous_batching_pre_setup(continuous_batching_setup):
+    lit_api, loop, request_queue, response_queues = continuous_batching_setup
+    lit_api.stream = False
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)"
+        ),
+    ):
+        loop.pre_setup(lit_api, None)
+
+
+def test_continuous_batching_run(continuous_batching_setup):
+    lit_api, loop, request_queue, response_queues = continuous_batching_setup
+    request_queue.put((0, "UUID-001", time.monotonic(), {"input": "Hello"}))
+    loop.run(lit_api, None, "cpu", 0, request_queue, response_queues, 2, 0.1, True, {}, NOOP_CB_RUNNER)
+
+    results = []
+    for i in range(5):
+        response = response_queues[0].get()
+        uid, (response_data, status) = response
+        o = json.loads(response_data)["output"]
+        assert o == i
+        assert status == LitAPIStatus.OK
+        assert uid == "UUID-001"
+        results.append(o)
+    assert results == list(range(5)), "API must return a sequence of numbers from 0 to 4"
+    response = response_queues[0].get()
+    uid, (response_data, status) = response
+    o = json.loads(response_data)["output"]
+    assert o == ""
+    assert status == LitAPIStatus.FINISH_STREAMING

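The `ContinuousBatchingAPI` test class above doubles as a compact illustration of the step-based contract: each `step()` call pops one token per active sequence, and a sequence that has run dry gets a `FINISH_STREAMING` output and is dropped. The sketch below drives it directly, without the queue plumbing. Two details are assumptions rather than facts shown in the hunks: the `from litserve.utils import LitAPIStatus` import path, and the `Output.output` field name (the visible code only constructs `Output(uid, value, status)` and reads `.uid`).

```python
from litserve.utils import LitAPIStatus  # assumed import path; see note above

# Drive the test API to completion without the request/response queues.
api = ContinuousBatchingAPI()
api.setup(None)
api.add_request("UUID-001", {"input": "Hello"})

tokens = []
while api.model:  # step() deletes a uid once its sequence is exhausted
    for out in api.step(prev_outputs=None):
        if out.status == LitAPIStatus.OK:
            tokens.append(out.output)  # field name assumed, not shown in the diff

assert tokens == [0, 1, 2, 3, 4]  # one token per step, matching the run test
```

This is the same sequence the `test_continuous_batching_run` test asserts end to end, minus the JSON encoding that `encode_response` adds on the queue path.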