Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aniketmaurya committed Aug 30, 2024
1 parent f2f70ea commit 8bae4bb
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import threading

import time
from queue import Queue
Expand All @@ -24,6 +25,7 @@
run_streaming_loop,
run_batched_streaming_loop,
inference_worker,
run_batched_loop,
)
from litserve.utils import LitAPIStatus

Expand Down Expand Up @@ -164,3 +166,90 @@ def test_inference_worker(mock_single_loop, mock_batched_loop):

inference_worker(*[MagicMock()] * 6, max_batch_size=1, batch_timeout=0, stream=False)
mock_single_loop.assert_called_once()


def test_run_single_loop():
    """Exercise ``run_single_loop`` with one queued request and check its response.

    The mocked API decodes the request to its ``"prompt"`` value, predicts
    ``"response to <prompt>"`` and encodes it as ``"encoded <prediction>"``.
    The test expects exactly one ``(uid, (payload, status))`` tuple on the
    first response queue.

    Every blocking call is bounded by a timeout so that a loop which never
    stops, or never produces a response, fails the test instead of hanging
    the whole test run.
    """
    wait_timeout = 5  # seconds; upper bound for each blocking wait below

    lit_api = MagicMock()
    lit_api.request_timeout = 1
    lit_api.decode_request = MagicMock(side_effect=lambda x: x["prompt"])
    lit_api.predict = MagicMock(side_effect=lambda x: f"response to {x}")
    lit_api.encode_response = MagicMock(side_effect=lambda x: f"encoded {x}")

    request_queue = Queue()
    request_queue.put((0, "UUID-001", time.monotonic(), {"prompt": "Hello"}))
    response_queues = [Queue()]

    # Run the loop in a separate thread to allow it to be stopped.
    # daemon=True: a loop that never terminates must not keep the interpreter alive.
    loop_thread = threading.Thread(
        target=run_single_loop,
        args=(lit_api, None, request_queue, response_queues),
        daemon=True,
    )
    loop_thread.start()

    # Allow some time for the loop to process
    time.sleep(1)

    # Stop the loop by putting a sentinel value in the queue.
    # NOTE(review): how the loop reacts to this tuple is defined in litserve,
    # not in this test - the liveness check below guards that assumption.
    request_queue.put((None, None, None, None))
    loop_thread.join(timeout=wait_timeout)
    assert not loop_thread.is_alive(), "run_single_loop did not stop after the sentinel"

    # Queue.get raises queue.Empty on timeout, which fails the test instead of blocking.
    response = response_queues[0].get(timeout=wait_timeout)
    assert response == ("UUID-001", ("encoded response to Hello", LitAPIStatus.OK))


def test_run_batched_loop():
    """Exercise ``run_batched_loop`` with two queued requests and check both responses.

    The mocked API decodes each request to its ``"prompt"`` value, uses
    identity ``batch``/``unbatch`` functions, predicts one
    ``"response to <prompt>"`` per batch item and encodes each as
    ``"encoded <prediction>"``. The loop is started with the trailing
    positional arguments ``2`` and ``1``.
    NOTE(review): presumably max_batch_size and batch_timeout - confirm
    against the ``run_batched_loop`` signature.

    Every blocking call is bounded by a timeout so that a loop which never
    stops, or never produces a response, fails the test instead of hanging
    the whole test run.
    """
    wait_timeout = 5  # seconds; upper bound for each blocking wait below

    lit_api = MagicMock()
    lit_api.request_timeout = 1
    lit_api.decode_request = MagicMock(side_effect=lambda x: x["prompt"])
    lit_api.batch = MagicMock(side_effect=lambda inputs: inputs)
    lit_api.predict = MagicMock(side_effect=lambda x: [f"response to {i}" for i in x])
    lit_api.unbatch = MagicMock(side_effect=lambda x: x)
    lit_api.encode_response = MagicMock(side_effect=lambda x: f"encoded {x}")

    request_queue = Queue()
    request_queue.put((0, "UUID-001", time.monotonic(), {"prompt": "Hello"}))
    request_queue.put((0, "UUID-002", time.monotonic(), {"prompt": "World"}))
    response_queues = [Queue()]

    # Run the loop in a separate thread to allow it to be stopped.
    # daemon=True: a loop that never terminates must not keep the interpreter alive.
    loop_thread = threading.Thread(
        target=run_batched_loop,
        args=(lit_api, None, request_queue, response_queues, 2, 1),
        daemon=True,
    )
    loop_thread.start()

    # Allow some time for the loop to process
    time.sleep(1)

    # Stop the loop by putting a sentinel value in the queue.
    # NOTE(review): how the loop reacts to this tuple is defined in litserve,
    # not in this test - the liveness check below guards that assumption.
    request_queue.put((None, None, None, None))
    loop_thread.join(timeout=wait_timeout)
    assert not loop_thread.is_alive(), "run_batched_loop did not stop after the sentinel"

    # Queue.get raises queue.Empty on timeout, which fails the test instead of blocking.
    response_1 = response_queues[0].get(timeout=wait_timeout)
    response_2 = response_queues[0].get(timeout=wait_timeout)
    assert response_1 == ("UUID-001", ("encoded response to Hello", LitAPIStatus.OK))
    assert response_2 == ("UUID-002", ("encoded response to World", LitAPIStatus.OK))


def test_run_streaming_loop():
    """Exercise ``run_streaming_loop`` with one request and check the streamed chunks.

    The mocked ``predict`` ignores its input and yields three items
    (``"response to 0"`` .. ``"response to 2"``); ``encode_response`` prefixes
    each with ``"encoded "`` and ``format_encoded_response`` is the identity.
    The test expects three ``LitAPIStatus.OK`` chunks followed by one
    ``LitAPIStatus.FINISH_STREAMING`` marker with an empty payload, all on the
    first response queue.

    Every blocking call is bounded by a timeout so that a loop which never
    stops, or produces fewer chunks than expected, fails the test instead of
    hanging the whole test run.
    """
    wait_timeout = 5  # seconds; upper bound for each blocking wait below

    lit_api = MagicMock()
    lit_api.request_timeout = 1
    lit_api.decode_request = MagicMock(side_effect=lambda x: x["prompt"])
    lit_api.predict = MagicMock(side_effect=lambda x: (f"response to {i}" for i in range(3)))
    lit_api.encode_response = MagicMock(side_effect=lambda x: (f"encoded {i}" for i in x))
    lit_api.format_encoded_response = MagicMock(side_effect=lambda x: x)

    request_queue = Queue()
    request_queue.put((0, "UUID-001", time.monotonic(), {"prompt": "Hello"}))
    response_queues = [Queue()]

    # Run the loop in a separate thread to allow it to be stopped.
    # daemon=True: a loop that never terminates must not keep the interpreter alive.
    loop_thread = threading.Thread(
        target=run_streaming_loop,
        args=(lit_api, None, request_queue, response_queues),
        daemon=True,
    )
    loop_thread.start()

    # Allow some time for the loop to process
    time.sleep(1)

    # Stop the loop by putting a sentinel value in the queue.
    # NOTE(review): how the loop reacts to this tuple is defined in litserve,
    # not in this test - the liveness check below guards that assumption.
    request_queue.put((None, None, None, None))
    loop_thread.join(timeout=wait_timeout)
    assert not loop_thread.is_alive(), "run_streaming_loop did not stop after the sentinel"

    # Queue.get raises queue.Empty on timeout, which fails the test instead of blocking.
    for i in range(3):
        response = response_queues[0].get(timeout=wait_timeout)
        assert response == ("UUID-001", (f"encoded response to {i}", LitAPIStatus.OK))
    finish_response = response_queues[0].get(timeout=wait_timeout)
    assert finish_response == ("UUID-001", ("", LitAPIStatus.FINISH_STREAMING))

0 comments on commit 8bae4bb

Please sign in to comment.