@@ -12,16 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import asyncio
-import inspect
 import pickle
 import re
 from asgi_lifespan import LifespanManager
 from litserve import LitAPI
 from fastapi import Request, Response, HTTPException
-import time
 import torch
 import torch.nn as nn
-from queue import Queue
 from httpx import AsyncClient
 from litserve.utils import wrap_litserve_start
 
@@ -31,10 +28,6 @@
 from litserve.connector import _Connector
 from litserve.server import (
     inference_worker,
-    run_single_loop,
-    run_streaming_loop,
-    LitAPIStatus,
-    run_batched_streaming_loop,
 )
 from litserve.server import LitServer
 import litserve as ls
@@ -73,32 +66,6 @@ def test_inference_worker(mock_single_loop, mock_batched_loop):
     mock_single_loop.assert_called_once()
 
 
-@pytest.fixture()
-def loop_args():
-    requests_queue = Queue()
-    requests_queue.put((0, "uuid-123", time.monotonic(), 1))  # response_queue_id, uid, timestamp, x_enc
-    requests_queue.put((1, "uuid-234", time.monotonic(), 2))
-
-    lit_api_mock = MagicMock()
-    lit_api_mock.request_timeout = 1
-    lit_api_mock.decode_request = MagicMock(side_effect=lambda x: x["input"])
-    return lit_api_mock, requests_queue
-
-
-class FakeResponseQueue:
-    def put(self, item):
-        raise StopIteration("exit loop")
-
-
-def test_single_loop(loop_args):
-    lit_api_mock, requests_queue = loop_args
-    lit_api_mock.unbatch.side_effect = None
-    response_queues = [FakeResponseQueue()]
-
-    with pytest.raises(StopIteration, match="exit loop"):
-        run_single_loop(lit_api_mock, None, requests_queue, response_queues)
-
-
 @pytest.mark.asyncio()
 async def test_stream(simple_stream_api):
     server = LitServer(simple_stream_api, stream=True, timeout=10)
@@ -141,108 +108,6 @@ async def test_batched_stream_server(simple_batched_stream_api):
     ), "Server returns input prompt and generated output which didn't match."
 
 
-class FakeStreamResponseQueue:
-    def __init__(self, num_streamed_outputs):
-        self.num_streamed_outputs = num_streamed_outputs
-        self.count = 0
-
-    def put(self, item):
-        uid, args = item
-        response, status = args
-        if self.count >= self.num_streamed_outputs:
-            raise StopIteration("exit loop")
-        assert response == f"{self.count}", "This streaming loop generates number from 0 to 9 which is sent via Queue"
-        self.count += 1
-
-
-def test_streaming_loop():
-    num_streamed_outputs = 10
-
-    def fake_predict(inputs: str):
-        for i in range(num_streamed_outputs):
-            yield {"output": f"{i}"}
-
-    def fake_encode(output):
-        assert inspect.isgenerator(output), "predict function must be a generator when `stream=True`"
-        for out in output:
-            yield out["output"]
-
-    fake_stream_api = MagicMock()
-    fake_stream_api.request_timeout = 1
-    fake_stream_api.decode_request = MagicMock(side_effect=lambda x: x["prompt"])
-    fake_stream_api.predict = MagicMock(side_effect=fake_predict)
-    fake_stream_api.encode_response = MagicMock(side_effect=fake_encode)
-    fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x)
-
-    requests_queue = Queue()
-    requests_queue.put((0, "UUID-1234", time.monotonic(), {"prompt": "Hello"}))
-    response_queues = [FakeStreamResponseQueue(num_streamed_outputs)]
-
-    with pytest.raises(StopIteration, match="exit loop"):
-        run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues)
-
-    fake_stream_api.predict.assert_called_once_with("Hello")
-    fake_stream_api.encode_response.assert_called_once()
-
-
-class FakeBatchStreamResponseQueue:
-    def __init__(self, num_streamed_outputs):
-        self.num_streamed_outputs = num_streamed_outputs
-        self.count = 0
-
-    def put(self, item):
-        uid, args = item
-        response, status = args
-        if status == LitAPIStatus.FINISH_STREAMING:
-            raise StopIteration("interrupt iteration")
-        if status == LitAPIStatus.ERROR and b"interrupt iteration" in response:
-            assert self.count // 2 == self.num_streamed_outputs, (
-                f"Loop count must have incremented for " f"{self.num_streamed_outputs} times."
-            )
-            raise StopIteration("finish streaming")
-
-        assert (
-            response == f"{self.count // 2}"
-        ), f"streaming loop generates number from 0 to 9 which is sent via Queue. {args}, count:{self.count}"
-        self.count += 1
-
-
-def test_batched_streaming_loop():
-    num_streamed_outputs = 10
-
-    def fake_predict(inputs: list):
-        n = len(inputs)
-        assert n == 2, "Two requests has been simulated to batched."
-        for i in range(num_streamed_outputs):
-            yield [{"output": f"{i}"}] * n
-
-    def fake_encode(output_iter):
-        assert inspect.isgenerator(output_iter), "predict function must be a generator when `stream=True`"
-        for outputs in output_iter:
-            yield [output["output"] for output in outputs]
-
-    fake_stream_api = MagicMock()
-    fake_stream_api.request_timeout = 1
-    fake_stream_api.decode_request = MagicMock(side_effect=lambda x: x["prompt"])
-    fake_stream_api.batch = MagicMock(side_effect=lambda inputs: inputs)
-    fake_stream_api.predict = MagicMock(side_effect=fake_predict)
-    fake_stream_api.encode_response = MagicMock(side_effect=fake_encode)
-    fake_stream_api.unbatch = MagicMock(side_effect=lambda inputs: inputs)
-    fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x)
-
-    requests_queue = Queue()
-    requests_queue.put((0, "UUID-001", time.monotonic(), {"prompt": "Hello"}))
-    requests_queue.put((0, "UUID-002", time.monotonic(), {"prompt": "World"}))
-    response_queues = [FakeBatchStreamResponseQueue(num_streamed_outputs)]
-
-    with pytest.raises(StopIteration, match="finish streaming"):
-        run_batched_streaming_loop(
-            fake_stream_api, fake_stream_api, requests_queue, response_queues, max_batch_size=2, batch_timeout=2
-        )
-    fake_stream_api.predict.assert_called_once_with(["Hello", "World"])
-    fake_stream_api.encode_response.assert_called_once()
-
-
 def test_litapi_with_stream(simple_litapi):
     with pytest.raises(
         ValueError,