
Commit 57d0308

committed
resolve conflict
2 parents 6a0817e + 15e5f3d

File tree

5 files changed: +543 -459 lines changed

src/litserve/loops.py

+372
@@ -0,0 +1,372 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import inspect
import logging
import multiprocessing as mp
import os
import pickle
import sys
import time
from queue import Empty, Queue
from typing import Dict, List, Optional, Tuple, Union

from fastapi import HTTPException
from starlette.formparsers import MultiPartParser

from litserve import LitAPI
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus

mp.allow_connection_pickling()

try:
    import uvloop

    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

except ImportError:
    print(
        "uvloop is not installed. Falling back to the default asyncio event loop. "
        "Please install uvloop for better performance using `pip install uvloop`."
    )

logger = logging.getLogger(__name__)

# if defined, it will require clients to auth with X-API-Key in the header
LIT_SERVER_API_KEY = os.environ.get("LIT_SERVER_API_KEY")

# timeout when we need to poll or wait indefinitely for a result in a loop.
LONG_TIMEOUT = 100

# FastAPI writes form files to disk over 1MB by default, which prevents serialization by multiprocessing
MultiPartParser.max_file_size = sys.maxsize

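# Forward `context` to a user hook only when the hook's signature declares a
# `context` parameter, so hooks that ignore context keep working unchanged.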
def _inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
    sig = inspect.signature(func)
    if "context" in sig.parameters:
        return func(*args, **kwargs, context=context)
    return func(*args, **kwargs)

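# Pop the buffered payload for each uid; uids no longer in the buffer are skipped.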
def get_batch_from_uid(uids, lit_api, request_buffer):
    batches = []
    for uid in uids:
        try:
            x_enc, pipe_s = request_buffer.pop(uid)
        except KeyError:
            continue
        batches.append((x_enc, pipe_s))
    return batches

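# Gather up to `max_batch_size` requests within `batch_timeout` seconds and
# separate out requests that exceeded `lit_api.request_timeout` while queued.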
def collate_requests(
    lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float
) -> Tuple[List, List]:
    payloads = []
    timed_out_uids = []
    entered_at = time.monotonic()
    end_time = entered_at + batch_timeout
    apply_timeout = lit_api.request_timeout not in (-1, False)

    if batch_timeout == 0:
        while len(payloads) < max_batch_size:
            try:
                response_queue_id, uid, timestamp, x_enc = request_queue.get_nowait()
                if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout:
                    timed_out_uids.append((response_queue_id, uid))
                else:
                    payloads.append((response_queue_id, uid, x_enc))
            except Empty:
                break
        return payloads, timed_out_uids

    while time.monotonic() < end_time and len(payloads) < max_batch_size:
        remaining_time = end_time - time.monotonic()
        if remaining_time <= 0:
            break

        try:
            response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001))
            if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout:
                timed_out_uids.append((response_queue_id, uid))
            else:
                payloads.append((response_queue_id, uid, x_enc))

        except Empty:
            continue

    return payloads, timed_out_uids

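# Unbatched worker loop: decode, predict, and encode one request at a time,
# pushing the result (or a pickled exception) to the originating response queue.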
def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
    while True:
        try:
            response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
        except (Empty, ValueError):
            continue

        if (lit_api.request_timeout and lit_api.request_timeout != -1) and (
            time.monotonic() - timestamp > lit_api.request_timeout
        ):
            logger.error(
                f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and "
                "has been timed out. "
                "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
            )
            response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
            continue
        try:
            context = {}
            if hasattr(lit_spec, "populate_context"):
                lit_spec.populate_context(context, x_enc)
            x = _inject_context(
                context,
                lit_api.decode_request,
                x_enc,
            )
            y = _inject_context(
                context,
                lit_api.predict,
                x,
            )
            y_enc = _inject_context(
                context,
                lit_api.encode_response,
                y,
            )
            response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
        except Exception as e:
            logger.exception(
                "LitAPI ran into an error while processing the request uid=%s.\n"
                "Please check the error trace for more details.",
                uid,
            )
            err_pkl = pickle.dumps(e)
            response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))

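# Batched worker loop: collect a batch, run decode -> batch -> predict ->
# unbatch -> encode, then fan the results back out to each client's queue.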
def run_batched_loop(
    lit_api: LitAPI,
    lit_spec: LitSpec,
    request_queue: Queue,
    response_queues: List[Queue],
    max_batch_size: int,
    batch_timeout: float,
):
    while True:
        batches, timed_out_uids = collate_requests(
            lit_api,
            request_queue,
            max_batch_size,
            batch_timeout,
        )

        for response_queue_id, uid in timed_out_uids:
            logger.error(
                f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and "
                "has been timed out. "
                "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
            )
            response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))

        if not batches:
            continue
        logger.debug(f"{len(batches)} batched requests received")
        response_queue_ids, uids, inputs = zip(*batches)
        try:
            # one independent context dict per request; `[{}] * n` would alias a single dict
            contexts = [{} for _ in range(len(inputs))]
            if hasattr(lit_spec, "populate_context"):
                for input, context in zip(inputs, contexts):
                    lit_spec.populate_context(context, input)

            x = [
                _inject_context(
                    context,
                    lit_api.decode_request,
                    input,
                )
                for input, context in zip(inputs, contexts)
            ]
            x = lit_api.batch(x)
            y = _inject_context(contexts, lit_api.predict, x)
            outputs = lit_api.unbatch(y)
            for response_queue_id, y, uid, context in zip(response_queue_ids, outputs, uids, contexts):
                y_enc = _inject_context(context, lit_api.encode_response, y)

                response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))

        except Exception as e:
            logger.exception(
                "LitAPI ran into an error while processing the batched request.\n"
                "Please check the error trace for more details."
            )
            err_pkl = pickle.dumps(e)
            for response_queue_id, uid in zip(response_queue_ids, uids):
                response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))

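# Streaming worker loop: `predict` and `encode_response` return generators;
# each encoded chunk is forwarded as produced, followed by a FINISH_STREAMING marker.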
def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
    while True:
        try:
            response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
            logger.debug("uid=%s", uid)
        except (Empty, ValueError):
            continue

        if (lit_api.request_timeout and lit_api.request_timeout != -1) and (
            time.monotonic() - timestamp > lit_api.request_timeout
        ):
            logger.error(
                f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and "
                "has been timed out. "
                "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
            )
            response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
            continue

        try:
            context = {}
            if hasattr(lit_spec, "populate_context"):
                lit_spec.populate_context(context, x_enc)
            x = _inject_context(
                context,
                lit_api.decode_request,
                x_enc,
            )
            y_gen = _inject_context(
                context,
                lit_api.predict,
                x,
            )
            y_enc_gen = _inject_context(
                context,
                lit_api.encode_response,
                y_gen,
            )
            for y_enc in y_enc_gen:
                y_enc = lit_api.format_encoded_response(y_enc)
                response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
            response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING)))
        except Exception as e:
            logger.exception(
                "LitAPI ran into an error while processing the streaming request uid=%s.\n"
                "Please check the error trace for more details.",
                uid,
            )
            response_queues[response_queue_id].put((uid, (pickle.dumps(e), LitAPIStatus.ERROR)))

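# Batched streaming worker loop: chunks from the batched generator are
# de-interleaved per request and streamed to each client as they arrive.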
def run_batched_streaming_loop(
    lit_api: LitAPI,
    lit_spec: LitSpec,
    request_queue: Queue,
    response_queues: List[Queue],
    max_batch_size: int,
    batch_timeout: float,
):
    while True:
        batches, timed_out_uids = collate_requests(
            lit_api,
            request_queue,
            max_batch_size,
            batch_timeout,
        )
        for response_queue_id, uid in timed_out_uids:
            logger.error(
                f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and "
                "has been timed out. "
                "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
            )
            response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))

        if not batches:
            continue
        response_queue_ids, uids, inputs = zip(*batches)
        try:
            # one independent context dict per request; `[{}] * n` would alias a single dict
            contexts = [{} for _ in range(len(inputs))]
            if hasattr(lit_spec, "populate_context"):
                for input, context in zip(inputs, contexts):
                    lit_spec.populate_context(context, input)

            x = [
                _inject_context(
                    context,
                    lit_api.decode_request,
                    input,
                )
                for input, context in zip(inputs, contexts)
            ]
            x = lit_api.batch(x)
            y_iter = _inject_context(contexts, lit_api.predict, x)
            unbatched_iter = lit_api.unbatch(y_iter)
            y_enc_iter = _inject_context(contexts, lit_api.encode_response, unbatched_iter)

            # y_enc_iter -> [[response-1, response-2], [response-1, response-2]]
            for y_batch in y_enc_iter:
                for response_queue_id, y_enc, uid in zip(response_queue_ids, y_batch, uids):
                    y_enc = lit_api.format_encoded_response(y_enc)
                    response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))

            for response_queue_id, uid in zip(response_queue_ids, uids):
                response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING)))

        except Exception as e:
            logger.exception(
                "LitAPI ran into an error while processing the streaming batched request.\n"
                "Please check the error trace for more details."
            )
            err_pkl = pickle.dumps(e)
            # notify every request in the batch, not just the last uid seen
            for response_queue_id, uid in zip(response_queue_ids, uids):
                response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))

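# Process entry point: run `setup` on the assigned device, then hand control to
# the loop matching the (stream, max_batch_size) configuration; loops run forever.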
def inference_worker(
    lit_api: LitAPI,
    lit_spec: Optional[LitSpec],
    device: str,
    worker_id: int,
    request_queue: Queue,
    response_queues: List[Queue],
    max_batch_size: int,
    batch_timeout: float,
    stream: bool,
    workers_setup_status: Optional[Dict[int, bool]] = None,
):
    lit_api.setup(device)
    lit_api.device = device

    print(f"Setup complete for worker {worker_id}.")

    if workers_setup_status:
        workers_setup_status[worker_id] = True

    if lit_spec:
        logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec")
    if stream:
        if max_batch_size > 1:
            run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout)
        else:
            run_streaming_loop(lit_api, lit_spec, request_queue, response_queues)
        return

    if max_batch_size > 1:
        run_batched_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout)
    else:
        run_single_loop(
            lit_api,
            lit_spec,
            request_queue,
            response_queues,
        )
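
For orientation, here is a minimal sketch of how a server process might drive `inference_worker` from this file. `EchoAPI`, the hand-built queues, and the `request_timeout` assignment are assumptions made for illustration; in LitServe the server performs this wiring before spawning workers.

import multiprocessing as mp

from litserve import LitAPI
from litserve.loops import inference_worker


class EchoAPI(LitAPI):
    # hypothetical minimal API used only for this sketch
    def setup(self, device):
        self.model = lambda x: x

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        return self.model(x)

    def encode_response(self, output):
        return {"output": output}


if __name__ == "__main__":
    api = EchoAPI()
    api.request_timeout = 30  # normally set by LitServer before spawning workers

    manager = mp.Manager()
    request_queue = manager.Queue()
    response_queues = [manager.Queue()]  # one response queue per server process

    # stream=False and max_batch_size=1 select run_single_loop inside the worker
    worker = mp.Process(
        target=inference_worker,
        args=(api, None, "cpu", 0, request_queue, response_queues, 1, 0.0, False),
    )
    worker.start()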
