|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import inspect |
| 15 | +import json |
| 16 | +import threading |
15 | 17 |
|
16 | 18 | import time |
17 | 19 | from queue import Queue |
18 | 20 |
|
19 | 21 | from unittest.mock import MagicMock, patch |
20 | 22 | import pytest |
| 23 | +from fastapi import HTTPException |
21 | 24 |
|
22 | 25 | from litserve.loops import ( |
23 | 26 | run_single_loop, |
24 | 27 | run_streaming_loop, |
25 | 28 | run_batched_streaming_loop, |
26 | 29 | inference_worker, |
| 30 | + run_batched_loop, |
27 | 31 | ) |
| 32 | +from litserve.test_examples.openai_spec_example import OpenAIBatchingWithUsage |
28 | 33 | from litserve.utils import LitAPIStatus |
| 34 | +import litserve as ls |
29 | 35 |
|
30 | 36 |
|
31 | 37 | @pytest.fixture |
@@ -164,3 +170,198 @@ def test_inference_worker(mock_single_loop, mock_batched_loop): |
164 | 170 |
|
165 | 171 | inference_worker(*[MagicMock()] * 6, max_batch_size=1, batch_timeout=0, stream=False) |
166 | 172 | mock_single_loop.assert_called_once() |
| 173 | + |
| 174 | + |
def test_run_single_loop():
    """run_single_loop picks up a queued request, runs predict, and emits the decoded response."""
    lit_api = ls.test_examples.SimpleLitAPI()
    lit_api.setup(None)
    lit_api.request_timeout = 1

    request_queue = Queue()
    # (response_queue_id, uid, timestamp, encoded_input)
    request_queue.put((0, "UUID-001", time.monotonic(), {"input": 4.0}))
    response_queues = [Queue()]

    # Run the loop in a separate thread to allow it to be stopped
    loop_thread = threading.Thread(target=run_single_loop, args=(lit_api, None, request_queue, response_queues))
    loop_thread.start()

    # Allow some time for the loop to process
    time.sleep(1)

    # Stop the loop by putting a sentinel value in the queue
    request_queue.put((None, None, None, None))
    loop_thread.join()

    # Bounded get: if the loop failed to respond, fail the test instead of
    # hanging the whole run (consistent with the batched/streaming tests).
    response = response_queues[0].get(timeout=10)
    assert response == ("UUID-001", ({"output": 16.0}, LitAPIStatus.OK))
| 197 | + |
| 198 | + |
def test_run_single_loop_timeout(caplog):
    """A request whose timestamp is older than request_timeout is rejected with an HTTPException."""
    lit_api = ls.test_examples.SimpleLitAPI()
    lit_api.setup(None)
    lit_api.request_timeout = 0.0001

    request_queue = Queue()
    # Timestamp is taken *before* the sleep so the request is already stale
    # (older than request_timeout) by the time the loop sees it.
    request = (0, "UUID-001", time.monotonic(), {"input": 4.0})
    time.sleep(0.1)
    request_queue.put(request)
    response_queues = [Queue()]

    # Run the loop in a separate thread to allow it to be stopped
    loop_thread = threading.Thread(target=run_single_loop, args=(lit_api, None, request_queue, response_queues))
    loop_thread.start()

    # Stop the loop by putting a sentinel value in the queue
    request_queue.put((None, None, None, None))
    loop_thread.join()
    assert "Request UUID-001 was waiting in the queue for too long" in caplog.text
    # Bounded get: fail fast instead of hanging if no response was produced
    # (consistent with the batched/streaming tests).
    assert isinstance(response_queues[0].get(timeout=10)[1][0], HTTPException), "Timeout should return an HTTPException"
| 218 | + |
| 219 | + |
def test_run_batched_loop():
    """Two queued requests are batched together and each gets its own decoded response."""
    api = ls.test_examples.SimpleBatchedAPI()
    api.setup(None)
    api._sanitize(2, None)
    assert api.model is not None, "Setup must initialize the model"
    api.request_timeout = 1

    requests = Queue()
    # Each entry is (response_queue_id, uid, timestamp, encoded_input).
    for uid, value in (("UUID-001", 4.0), ("UUID-002", 5.0)):
        requests.put((0, uid, time.monotonic(), {"input": value}))
    responses = [Queue()]

    # Drive the loop from a worker thread so the test can shut it down.
    worker = threading.Thread(target=run_batched_loop, args=(api, None, requests, responses, 2, 1))
    worker.start()

    # Give the loop a moment to drain the queue.
    time.sleep(1)

    # A tuple of Nones is the loop's shutdown sentinel.
    requests.put((None, None, None, None))
    worker.join()

    assert responses[0].get(timeout=10) == ("UUID-001", ({"output": 16.0}, LitAPIStatus.OK))
    assert responses[0].get(timeout=10) == ("UUID-002", ({"output": 25.0}, LitAPIStatus.OK))
| 248 | + |
| 249 | + |
def test_run_batched_loop_timeout(caplog):
    """Within one batch, the stale request times out while the fresh one is still served."""
    api = ls.test_examples.SimpleBatchedAPI()
    api.setup(None)
    api._sanitize(2, None)
    assert api.model is not None, "Setup must initialize the model"
    api.request_timeout = 0.1

    requests = Queue()
    # (response_queue_id, uid, timestamp, encoded_input). The first timestamp is
    # taken *before* sleeping, so that request is already stale on arrival; the
    # second is created after the sleep and is fresh.
    stale = (0, "UUID-001", time.monotonic(), {"input": 4.0})
    time.sleep(0.1)
    requests.put(stale)
    fresh = (0, "UUID-002", time.monotonic(), {"input": 5.0})
    requests.put(fresh)
    responses = [Queue()]

    # Drive the loop from a worker thread so the test can shut it down.
    worker = threading.Thread(
        target=run_batched_loop, args=(api, None, requests, responses, 2, 0.001)
    )
    worker.start()

    # Give the loop a moment to drain the queue.
    time.sleep(1)

    assert "Request UUID-001 was waiting in the queue for too long" in caplog.text
    first = responses[0].get(timeout=10)[1]
    second = responses[0].get(timeout=10)[1]
    assert isinstance(first[0], HTTPException), "First request was timed out"
    assert second[0] == {"output": 25.0}, "Second request wasn't timed out"

    # A tuple of Nones is the loop's shutdown sentinel.
    requests.put((None, None, None, None))
    worker.join()
| 284 | + |
| 285 | + |
def test_run_streaming_loop():
    """A streaming request yields three JSON-encoded chunks, one per generated item."""
    api = ls.test_examples.SimpleStreamAPI()
    api.setup(None)
    api.request_timeout = 1

    requests = Queue()
    # (response_queue_id, uid, timestamp, encoded_input)
    requests.put((0, "UUID-001", time.monotonic(), {"input": "Hello"}))
    responses = [Queue()]

    # Drive the loop from a worker thread so the test can shut it down.
    worker = threading.Thread(target=run_streaming_loop, args=(api, None, requests, responses))
    worker.start()

    # Give the loop a moment to stream out all chunks.
    time.sleep(1)

    # A tuple of Nones is the loop's shutdown sentinel.
    requests.put((None, None, None, None))
    worker.join()

    for idx in range(3):
        item = responses[0].get(timeout=10)
        assert json.loads(item[1][0]) == {"output": f"{idx}: Hello"}
| 310 | + |
| 311 | + |
def test_run_streaming_loop_timeout(caplog):
    """A streaming request enqueued with an already-expired timestamp is rejected."""
    api = ls.test_examples.SimpleStreamAPI()
    api.setup(None)
    api.request_timeout = 0.1

    requests = Queue()
    # Back-date the timestamp by 5s so the request is stale well beyond request_timeout.
    requests.put((0, "UUID-001", time.monotonic() - 5, {"input": "Hello"}))
    responses = [Queue()]

    # Drive the loop from a worker thread so the test can shut it down.
    worker = threading.Thread(target=run_streaming_loop, args=(api, None, requests, responses))
    worker.start()

    # Give the loop a moment to process.
    time.sleep(1)

    # A tuple of Nones is the loop's shutdown sentinel.
    requests.put((None, None, None, None))
    worker.join()

    assert "Request UUID-001 was waiting in the queue for too long" in caplog.text
    payload = responses[0].get(timeout=10)[1]
    assert isinstance(payload[0], HTTPException), "request was timed out"
| 335 | + |
| 336 | + |
# NOTE(review): the "off_" prefix keeps pytest from collecting this test, so it is
# currently disabled — presumably flaky or work-in-progress; confirm before renaming
# it back to "test_...".
def off_test_run_batched_streaming_loop(openai_request_data):
    """Batched streaming through OpenAISpec: two identical chat requests are
    enqueued, processed as a batch of 2 by run_batched_streaming_loop, and the
    first decoded streamed payload is checked.

    Relies on the ``openai_request_data`` fixture defined elsewhere in the suite.
    """
    lit_api = OpenAIBatchingWithUsage()
    lit_api.setup(None)
    lit_api.request_timeout = 1
    lit_api.stream = True
    spec = ls.OpenAISpec()
    lit_api._sanitize(2, spec)

    request_queue = Queue()
    # response_queue_id, uid, timestamp, x_enc
    r1 = (0, "UUID-001", time.monotonic(), openai_request_data)
    r2 = (0, "UUID-002", time.monotonic(), openai_request_data)
    request_queue.put(r1)
    request_queue.put(r2)
    response_queues = [Queue()]

    # Run the loop in a separate thread to allow it to be stopped
    loop_thread = threading.Thread(
        target=run_batched_streaming_loop, args=(lit_api, spec, request_queue, response_queues, 2, 0.1)
    )
    loop_thread.start()

    # Allow some time for the loop to process
    time.sleep(1)

    # Stop the loop by putting a sentinel value in the queue
    request_queue.put((None, None, None, None))
    loop_thread.join()

    response = response_queues[0].get(timeout=5)[1]
    assert response[0] == {"role": "assistant", "content": "10 + 6 is equal to 16."}
0 commit comments