diff --git a/tests/test_loops.py b/tests/test_loops.py
index 6b22df91..002ad431 100644
--- a/tests/test_loops.py
+++ b/tests/test_loops.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import threading
 import time
 from queue import Queue
 
@@ -24,6 +25,7 @@
     run_streaming_loop,
     run_batched_streaming_loop,
     inference_worker,
+    run_batched_loop,
 )
 
 from litserve.utils import LitAPIStatus
@@ -164,3 +166,96 @@ def test_inference_worker(mock_single_loop, mock_batched_loop):
 
     inference_worker(*[MagicMock()] * 6, max_batch_size=1, batch_timeout=0, stream=False)
     mock_single_loop.assert_called_once()
+
+
+def test_run_single_loop():
+    """Drive run_single_loop on a worker thread and check the single encoded response it enqueues."""
+    lit_api = MagicMock()
+    lit_api.request_timeout = 1
+    lit_api.decode_request = MagicMock(side_effect=lambda x: x["prompt"])
+    lit_api.predict = MagicMock(side_effect=lambda x: f"response to {x}")
+    lit_api.encode_response = MagicMock(side_effect=lambda x: f"encoded {x}")
+
+    request_queue = Queue()
+    request_queue.put((0, "UUID-001", time.monotonic(), {"prompt": "Hello"}))
+    response_queues = [Queue()]
+
+    # Run the loop in a separate thread to allow it to be stopped
+    loop_thread = threading.Thread(target=run_single_loop, args=(lit_api, None, request_queue, response_queues))
+    loop_thread.start()
+
+    # Allow some time for the loop to process
+    time.sleep(1)
+
+    # Stop the loop by putting a sentinel value in the queue
+    request_queue.put((None, None, None, None))
+    loop_thread.join()
+
+    # Bounded wait: a missing response should fail the test (queue.Empty) rather than block forever
+    response = response_queues[0].get(timeout=1)
+    assert response == ("UUID-001", ("encoded response to Hello", LitAPIStatus.OK))
+
+
+def test_run_batched_loop():
+    """Drive run_batched_loop on a worker thread with two requests and check both encoded responses."""
+    lit_api = MagicMock()
+    lit_api.request_timeout = 1
+    lit_api.decode_request = MagicMock(side_effect=lambda x: x["prompt"])
+    lit_api.batch = MagicMock(side_effect=lambda inputs: inputs)
+    lit_api.predict = MagicMock(side_effect=lambda x: [f"response to {i}" for i in x])
+    lit_api.unbatch = MagicMock(side_effect=lambda x: x)
+    lit_api.encode_response = MagicMock(side_effect=lambda x: f"encoded {x}")
+
+    request_queue = Queue()
+    request_queue.put((0, "UUID-001", time.monotonic(), {"prompt": "Hello"}))
+    request_queue.put((0, "UUID-002", time.monotonic(), {"prompt": "World"}))
+    response_queues = [Queue()]
+
+    # Run the loop in a separate thread to allow it to be stopped
+    loop_thread = threading.Thread(target=run_batched_loop, args=(lit_api, None, request_queue, response_queues, 2, 1))
+    loop_thread.start()
+
+    # Allow some time for the loop to process
+    time.sleep(1)
+
+    # Stop the loop by putting a sentinel value in the queue
+    request_queue.put((None, None, None, None))
+    loop_thread.join()
+
+    # Bounded wait: a missing response should fail the test (queue.Empty) rather than block forever
+    response_1 = response_queues[0].get(timeout=1)
+    response_2 = response_queues[0].get(timeout=1)
+    assert response_1 == ("UUID-001", ("encoded response to Hello", LitAPIStatus.OK))
+    assert response_2 == ("UUID-002", ("encoded response to World", LitAPIStatus.OK))
+
+
+def test_run_streaming_loop():
+    """Drive run_streaming_loop on a worker thread and check the three streamed chunks and the finish marker."""
+    lit_api = MagicMock()
+    lit_api.request_timeout = 1
+    lit_api.decode_request = MagicMock(side_effect=lambda x: x["prompt"])
+    lit_api.predict = MagicMock(side_effect=lambda x: (f"response to {i}" for i in range(3)))
+    lit_api.encode_response = MagicMock(side_effect=lambda x: (f"encoded {i}" for i in x))
+    lit_api.format_encoded_response = MagicMock(side_effect=lambda x: x)
+
+    request_queue = Queue()
+    request_queue.put((0, "UUID-001", time.monotonic(), {"prompt": "Hello"}))
+    response_queues = [Queue()]
+
+    # Run the loop in a separate thread to allow it to be stopped
+    loop_thread = threading.Thread(target=run_streaming_loop, args=(lit_api, None, request_queue, response_queues))
+    loop_thread.start()
+
+    # Allow some time for the loop to process
+    time.sleep(1)
+
+    # Stop the loop by putting a sentinel value in the queue
+    request_queue.put((None, None, None, None))
+    loop_thread.join()
+
+    # Bounded wait: a missing response should fail the test (queue.Empty) rather than block forever
+    for i in range(3):
+        response = response_queues[0].get(timeout=1)
+        assert response == ("UUID-001", (f"encoded response to {i}", LitAPIStatus.OK))
+    finish_response = response_queues[0].get(timeout=1)
+    assert finish_response == ("UUID-001", ("", LitAPIStatus.FINISH_STREAMING))