Skip to content

Commit 9bd56c0

Browse files
committed
separate loops
1 parent 1104920 commit 9bd56c0

File tree

2 files changed

+155
-135
lines changed

2 files changed

+155
-135
lines changed

tests/test_lit_server.py

-135
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import asyncio
15-
import inspect
1615
import pickle
1716
import re
1817
from asgi_lifespan import LifespanManager
1918
from litserve import LitAPI
2019
from fastapi import Request, Response, HTTPException
21-
import time
2220
import torch
2321
import torch.nn as nn
24-
from queue import Queue
2522
from httpx import AsyncClient
2623
from litserve.utils import wrap_litserve_start
2724

@@ -31,10 +28,6 @@
3128
from litserve.connector import _Connector
3229
from litserve.server import (
3330
inference_worker,
34-
run_single_loop,
35-
run_streaming_loop,
36-
LitAPIStatus,
37-
run_batched_streaming_loop,
3831
)
3932
from litserve.server import LitServer
4033
import litserve as ls
@@ -73,32 +66,6 @@ def test_inference_worker(mock_single_loop, mock_batched_loop):
7366
mock_single_loop.assert_called_once()
7467

7568

76-
@pytest.fixture()
77-
def loop_args():
78-
requests_queue = Queue()
79-
requests_queue.put((0, "uuid-123", time.monotonic(), 1)) # response_queue_id, uid, timestamp, x_enc
80-
requests_queue.put((1, "uuid-234", time.monotonic(), 2))
81-
82-
lit_api_mock = MagicMock()
83-
lit_api_mock.request_timeout = 1
84-
lit_api_mock.decode_request = MagicMock(side_effect=lambda x: x["input"])
85-
return lit_api_mock, requests_queue
86-
87-
88-
class FakeResponseQueue:
89-
def put(self, item):
90-
raise StopIteration("exit loop")
91-
92-
93-
def test_single_loop(loop_args):
94-
lit_api_mock, requests_queue = loop_args
95-
lit_api_mock.unbatch.side_effect = None
96-
response_queues = [FakeResponseQueue()]
97-
98-
with pytest.raises(StopIteration, match="exit loop"):
99-
run_single_loop(lit_api_mock, None, requests_queue, response_queues)
100-
101-
10269
@pytest.mark.asyncio()
10370
async def test_stream(simple_stream_api):
10471
server = LitServer(simple_stream_api, stream=True, timeout=10)
@@ -141,108 +108,6 @@ async def test_batched_stream_server(simple_batched_stream_api):
141108
), "Server returns input prompt and generated output which didn't match."
142109

143110

144-
class FakeStreamResponseQueue:
145-
def __init__(self, num_streamed_outputs):
146-
self.num_streamed_outputs = num_streamed_outputs
147-
self.count = 0
148-
149-
def put(self, item):
150-
uid, args = item
151-
response, status = args
152-
if self.count >= self.num_streamed_outputs:
153-
raise StopIteration("exit loop")
154-
assert response == f"{self.count}", "This streaming loop generates number from 0 to 9 which is sent via Queue"
155-
self.count += 1
156-
157-
158-
def test_streaming_loop():
159-
num_streamed_outputs = 10
160-
161-
def fake_predict(inputs: str):
162-
for i in range(num_streamed_outputs):
163-
yield {"output": f"{i}"}
164-
165-
def fake_encode(output):
166-
assert inspect.isgenerator(output), "predict function must be a generator when `stream=True`"
167-
for out in output:
168-
yield out["output"]
169-
170-
fake_stream_api = MagicMock()
171-
fake_stream_api.request_timeout = 1
172-
fake_stream_api.decode_request = MagicMock(side_effect=lambda x: x["prompt"])
173-
fake_stream_api.predict = MagicMock(side_effect=fake_predict)
174-
fake_stream_api.encode_response = MagicMock(side_effect=fake_encode)
175-
fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x)
176-
177-
requests_queue = Queue()
178-
requests_queue.put((0, "UUID-1234", time.monotonic(), {"prompt": "Hello"}))
179-
response_queues = [FakeStreamResponseQueue(num_streamed_outputs)]
180-
181-
with pytest.raises(StopIteration, match="exit loop"):
182-
run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues)
183-
184-
fake_stream_api.predict.assert_called_once_with("Hello")
185-
fake_stream_api.encode_response.assert_called_once()
186-
187-
188-
class FakeBatchStreamResponseQueue:
189-
def __init__(self, num_streamed_outputs):
190-
self.num_streamed_outputs = num_streamed_outputs
191-
self.count = 0
192-
193-
def put(self, item):
194-
uid, args = item
195-
response, status = args
196-
if status == LitAPIStatus.FINISH_STREAMING:
197-
raise StopIteration("interrupt iteration")
198-
if status == LitAPIStatus.ERROR and b"interrupt iteration" in response:
199-
assert self.count // 2 == self.num_streamed_outputs, (
200-
f"Loop count must have incremented for " f"{self.num_streamed_outputs} times."
201-
)
202-
raise StopIteration("finish streaming")
203-
204-
assert (
205-
response == f"{self.count // 2}"
206-
), f"streaming loop generates number from 0 to 9 which is sent via Queue. {args}, count:{self.count}"
207-
self.count += 1
208-
209-
210-
def test_batched_streaming_loop():
211-
num_streamed_outputs = 10
212-
213-
def fake_predict(inputs: list):
214-
n = len(inputs)
215-
assert n == 2, "Two requests has been simulated to batched."
216-
for i in range(num_streamed_outputs):
217-
yield [{"output": f"{i}"}] * n
218-
219-
def fake_encode(output_iter):
220-
assert inspect.isgenerator(output_iter), "predict function must be a generator when `stream=True`"
221-
for outputs in output_iter:
222-
yield [output["output"] for output in outputs]
223-
224-
fake_stream_api = MagicMock()
225-
fake_stream_api.request_timeout = 1
226-
fake_stream_api.decode_request = MagicMock(side_effect=lambda x: x["prompt"])
227-
fake_stream_api.batch = MagicMock(side_effect=lambda inputs: inputs)
228-
fake_stream_api.predict = MagicMock(side_effect=fake_predict)
229-
fake_stream_api.encode_response = MagicMock(side_effect=fake_encode)
230-
fake_stream_api.unbatch = MagicMock(side_effect=lambda inputs: inputs)
231-
fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x)
232-
233-
requests_queue = Queue()
234-
requests_queue.put((0, "UUID-001", time.monotonic(), {"prompt": "Hello"}))
235-
requests_queue.put((0, "UUID-002", time.monotonic(), {"prompt": "World"}))
236-
response_queues = [FakeBatchStreamResponseQueue(num_streamed_outputs)]
237-
238-
with pytest.raises(StopIteration, match="finish streaming"):
239-
run_batched_streaming_loop(
240-
fake_stream_api, fake_stream_api, requests_queue, response_queues, max_batch_size=2, batch_timeout=2
241-
)
242-
fake_stream_api.predict.assert_called_once_with(["Hello", "World"])
243-
fake_stream_api.encode_response.assert_called_once()
244-
245-
246111
def test_litapi_with_stream(simple_litapi):
247112
with pytest.raises(
248113
ValueError,

tests/test_loops.py

+155
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright The Lightning AI team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import inspect
15+
16+
import time
17+
from queue import Queue
18+
19+
from unittest.mock import MagicMock
20+
import pytest
21+
22+
from litserve.loops import (
23+
run_single_loop,
24+
run_streaming_loop,
25+
run_batched_streaming_loop,
26+
)
27+
from litserve.utils import LitAPIStatus
28+
29+
30+
@pytest.fixture()
def loop_args():
    """Provide a mocked LitAPI plus a queue holding two pending requests.

    Each queued item is a ``(response_queue_id, uid, timestamp, x_enc)``
    tuple, mirroring what the server enqueues for the inference loops.
    """
    pending = Queue()
    for queue_id, uid, payload in ((0, "uuid-123", 1), (1, "uuid-234", 2)):
        pending.put((queue_id, uid, time.monotonic(), payload))  # response_queue_id, uid, timestamp, x_enc

    api = MagicMock()
    api.request_timeout = 1
    api.decode_request = MagicMock(side_effect=lambda request: request["input"])
    return api, pending
40+
41+
42+
class FakeResponseQueue:
    """Response-queue double whose ``put`` aborts the loop under test.

    The inference loops run forever; raising from ``put`` gives tests a
    deterministic way to break out as soon as the first response arrives.
    """

    def put(self, item):
        raise StopIteration("exit loop")
45+
46+
47+
def test_single_loop(loop_args):
    """run_single_loop should push one response; the fake response queue
    converts that ``put`` into StopIteration, ending the endless loop."""
    api, pending_requests = loop_args
    api.unbatch.side_effect = None

    with pytest.raises(StopIteration, match="exit loop"):
        run_single_loop(api, None, pending_requests, [FakeResponseQueue()])
54+
55+
56+
class FakeStreamResponseQueue:
    """Response-queue double that checks streamed chunks arrive in order.

    Accepts exactly ``num_streamed_outputs`` items, asserting the i-th
    payload equals ``str(i)``, then raises StopIteration so the streaming
    loop under test terminates.
    """

    def __init__(self, num_streamed_outputs):
        # total number of chunks expected before the loop is forced to exit
        self.num_streamed_outputs = num_streamed_outputs
        # chunks received so far; doubles as the expected next payload
        self.count = 0

    def put(self, item):
        _uid, (response, _status) = item
        if self.count >= self.num_streamed_outputs:
            raise StopIteration("exit loop")
        assert response == f"{self.count}", "This streaming loop generates number from 0 to 9 which is sent via Queue"
        self.count += 1
68+
69+
70+
def test_streaming_loop():
    """Drive run_streaming_loop with a mocked streaming API and verify all
    chunks are emitted in order and each hook is invoked exactly once."""
    num_streamed_outputs = 10

    def fake_predict(inputs: str):
        # Emit chunks "0" .. "9" as a generator, like a real streaming predict.
        yield from ({"output": f"{i}"} for i in range(num_streamed_outputs))

    def fake_encode(output):
        assert inspect.isgenerator(output), "predict function must be a generator when `stream=True`"
        yield from (out["output"] for out in output)

    stream_api = MagicMock()
    stream_api.request_timeout = 1
    stream_api.decode_request = MagicMock(side_effect=lambda request: request["prompt"])
    stream_api.predict = MagicMock(side_effect=fake_predict)
    stream_api.encode_response = MagicMock(side_effect=fake_encode)
    stream_api.format_encoded_response = MagicMock(side_effect=lambda chunk: chunk)

    pending = Queue()
    pending.put((0, "UUID-1234", time.monotonic(), {"prompt": "Hello"}))

    # The fake response queue raises StopIteration once every chunk arrived.
    with pytest.raises(StopIteration, match="exit loop"):
        run_streaming_loop(stream_api, stream_api, pending, [FakeStreamResponseQueue(num_streamed_outputs)])

    stream_api.predict.assert_called_once_with("Hello")
    stream_api.encode_response.assert_called_once()
98+
99+
100+
class FakeBatchStreamResponseQueue:
    """Response-queue double for the batched streaming loop.

    Two requests are batched together, so chunk i of the stream arrives
    twice (once per request); ``count // 2`` therefore tracks the chunk
    index shared by both requests.
    """

    def __init__(self, num_streamed_outputs):
        # number of chunks each request in the batch is expected to stream
        self.num_streamed_outputs = num_streamed_outputs
        # total items received so far (two per chunk — one per batched request)
        self.count = 0

    def put(self, item):
        uid, args = item
        response, status = args
        # First FINISH_STREAMING marker: break the loop. The loop is then
        # expected to report this interruption back as an ERROR response.
        if status == LitAPIStatus.FINISH_STREAMING:
            raise StopIteration("interrupt iteration")
        if status == LitAPIStatus.ERROR and b"interrupt iteration" in response:
            # NOTE(review): `response` appears to be a serialized form of the
            # StopIteration raised above (hence the bytes containment check) —
            # confirm against the loop's error-reporting path.
            assert self.count // 2 == self.num_streamed_outputs, (
                f"Loop count must have incremented for " f"{self.num_streamed_outputs} times."
            )
            raise StopIteration("finish streaming")

        # Normal chunk: both batched requests stream the same numbered payload,
        # so the expected value is the chunk index count // 2.
        assert (
            response == f"{self.count // 2}"
        ), f"streaming loop generates number from 0 to 9 which is sent via Queue. {args}, count:{self.count}"
        self.count += 1
120+
121+
122+
def test_batched_streaming_loop():
    """Exercise run_batched_streaming_loop end to end with two queued prompts,
    verifying batching, per-chunk streaming, and the final finish signal."""
    num_streamed_outputs = 10

    def fake_predict(batch: list):
        assert len(batch) == 2, "Two requests has been simulated to batched."
        # Every chunk is replicated once per request in the batch.
        for i in range(num_streamed_outputs):
            yield [{"output": f"{i}"}] * len(batch)

    def fake_encode(output_iter):
        assert inspect.isgenerator(output_iter), "predict function must be a generator when `stream=True`"
        for outputs in output_iter:
            yield [entry["output"] for entry in outputs]

    batched_api = MagicMock()
    batched_api.request_timeout = 1
    batched_api.decode_request = MagicMock(side_effect=lambda request: request["prompt"])
    batched_api.batch = MagicMock(side_effect=lambda items: items)
    batched_api.predict = MagicMock(side_effect=fake_predict)
    batched_api.encode_response = MagicMock(side_effect=fake_encode)
    batched_api.unbatch = MagicMock(side_effect=lambda items: items)
    batched_api.format_encoded_response = MagicMock(side_effect=lambda chunk: chunk)

    pending = Queue()
    pending.put((0, "UUID-001", time.monotonic(), {"prompt": "Hello"}))
    pending.put((0, "UUID-002", time.monotonic(), {"prompt": "World"}))

    # The fake response queue escalates FINISH_STREAMING into StopIteration.
    with pytest.raises(StopIteration, match="finish streaming"):
        run_batched_streaming_loop(
            batched_api,
            batched_api,
            pending,
            [FakeBatchStreamResponseQueue(num_streamed_outputs)],
            max_batch_size=2,
            batch_timeout=2,
        )
    batched_api.predict.assert_called_once_with(["Hello", "World"])
    batched_api.encode_response.assert_called_once()

0 commit comments

Comments (0)