
Commit 35129f7

Add loop.pre_setup to allow fine-grained LitAPI validation based on inference loop (#393)
* pre_setup loop
* add test
* fix tests
* apply feedback
1 parent 75c6d0e commit 35129f7

File tree

6 files changed: +161 −130 lines changed


src/litserve/api.py

+1 −63
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 import json
 import warnings
 from abc import ABC, abstractmethod
@@ -113,76 +112,15 @@ def device(self):
     def device(self, value):
         self._device = value

-    def _sanitize(self, max_batch_size: int, spec: Optional[LitSpec]):
+    def pre_setup(self, max_batch_size: int, spec: Optional[LitSpec]):
         self.max_batch_size = max_batch_size
         if self.stream:
             self._default_unbatch = self._unbatch_stream
         else:
             self._default_unbatch = self._unbatch_no_stream

-        # we will sanitize regularly if no spec
-        # in case, we have spec then:
-        # case 1: spec implements a streaming API
-        # Case 2: spec implements a non-streaming API
         if spec:
-            # TODO: Implement sanitization
             self._spec = spec
-            return
-
-        original = self.unbatch.__code__ is LitAPI.unbatch.__code__
-        if (
-            self.stream
-            and max_batch_size > 1
-            and not all([
-                inspect.isgeneratorfunction(self.predict),
-                inspect.isgeneratorfunction(self.encode_response),
-                (original or inspect.isgeneratorfunction(self.unbatch)),
-            ])
-        ):
-            raise ValueError(
-                """When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and
-                `lit_api.unbatch` must generate values using `yield`.
-
-                Example:
-
-                def predict(self, inputs):
-                    ...
-                    for i in range(max_token_length):
-                        yield prediction
-
-                def encode_response(self, outputs):
-                    for output in outputs:
-                        encoded_output = ...
-                        yield encoded_output
-
-                def unbatch(self, outputs):
-                    for output in outputs:
-                        unbatched_output = ...
-                        yield unbatched_output
-                """
-            )
-
-        if self.stream and not all([
-            inspect.isgeneratorfunction(self.predict),
-            inspect.isgeneratorfunction(self.encode_response),
-        ]):
-            raise ValueError(
-                """When `stream=True` both `lit_api.predict` and
-                `lit_api.encode_response` must generate values using `yield`.
-
-                Example:
-
-                def predict(self, inputs):
-                    ...
-                    for i in range(max_token_length):
-                        yield prediction
-
-                def encode_response(self, outputs):
-                    for output in outputs:
-                        encoded_output = ...
-                        yield encoded_output
-                """
-            )

     def set_logger_queue(self, queue: Queue):
         """Set the queue for logging events."""

src/litserve/loops.py

+122 −48
@@ -441,6 +441,9 @@ def run(

        """

+    def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
+        pass
+
     def __call__(
         self,
         lit_api: LitAPI,
@@ -487,7 +490,109 @@ def run(
         raise NotImplementedError


-class SingleLoop(_BaseLoop):
+class LitLoop(_BaseLoop):
+    def __init__(self):
+        self._context = {}
+
+    def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
+        if max_batch_size <= 1:
+            raise ValueError("max_batch_size must be greater than 1")
+
+        batches, timed_out_uids = collate_requests(
+            lit_api,
+            request_queue,
+            max_batch_size,
+            batch_timeout,
+        )
+        return batches, timed_out_uids
+
+    def get_request(self, request_queue: Queue, timeout: float = 1.0):
+        response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout)
+        return response_queue_id, uid, timestamp, x_enc
+
+    def populate_context(self, lit_spec: LitSpec, request: Any):
+        if lit_spec and hasattr(lit_spec, "populate_context"):
+            lit_spec.populate_context(self._context, request)
+
+    def put_response(
+        self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
+    ) -> None:
+        response_queues[response_queue_id].put((uid, (response_data, status)))
+
+    def put_error_response(
+        self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
+    ) -> None:
+        response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR)))
+
+
+class DefaultLoop(LitLoop):
+    def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
+        # we will sanitize regularly if no spec
+        # in case, we have spec then:
+        # case 1: spec implements a streaming API
+        # Case 2: spec implements a non-streaming API
+        if spec:
+            # TODO: Implement sanitization
+            lit_api._spec = spec
+            return
+
+        original = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__
+        if (
+            lit_api.stream
+            and lit_api.max_batch_size > 1
+            and not all([
+                inspect.isgeneratorfunction(lit_api.predict),
+                inspect.isgeneratorfunction(lit_api.encode_response),
+                (original or inspect.isgeneratorfunction(lit_api.unbatch)),
+            ])
+        ):
+            raise ValueError(
+                """When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and
+                `lit_api.unbatch` must generate values using `yield`.
+
+                Example:
+
+                def predict(self, inputs):
+                    ...
+                    for i in range(max_token_length):
+                        yield prediction
+
+                def encode_response(self, outputs):
+                    for output in outputs:
+                        encoded_output = ...
+                        yield encoded_output
+
+                def unbatch(self, outputs):
+                    for output in outputs:
+                        unbatched_output = ...
+                        yield unbatched_output
+                """
+            )
+
+        if lit_api.stream and not all([
+            inspect.isgeneratorfunction(lit_api.predict),
+            inspect.isgeneratorfunction(lit_api.encode_response),
+        ]):
+            raise ValueError(
+                """When `stream=True` both `lit_api.predict` and
+                `lit_api.encode_response` must generate values using `yield`.
+
+                Example:
+
+                def predict(self, inputs):
+                    ...
+                    for i in range(max_token_length):
+                        yield prediction
+
+                def encode_response(self, outputs):
+                    for output in outputs:
+                        encoded_output = ...
+                        yield encoded_output
+                """
+            )
+
+
+class SingleLoop(DefaultLoop):
     def __call__(
         self,
         lit_api: LitAPI,
@@ -505,7 +610,7 @@ def __call__(
         run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)


-class BatchedLoop(_BaseLoop):
+class BatchedLoop(DefaultLoop):
     def __call__(
         self,
         lit_api: LitAPI,
@@ -531,7 +636,7 @@ def __call__(
         )


-class StreamingLoop(_BaseLoop):
+class StreamingLoop(DefaultLoop):
     def __call__(
         self,
         lit_api: LitAPI,
@@ -549,7 +654,7 @@ def __call__(
         run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)


-class BatchedStreamingLoop(_BaseLoop):
+class BatchedStreamingLoop(DefaultLoop):
     def __call__(
         self,
         lit_api: LitAPI,
@@ -593,41 +698,6 @@ class Output:
     status: LitAPIStatus


-class LitLoop(_BaseLoop):
-    def __init__(self):
-        self._context = {}
-
-    def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
-        if max_batch_size <= 1:
-            raise ValueError("max_batch_size must be greater than 1")
-
-        batches, timed_out_uids = collate_requests(
-            lit_api,
-            request_queue,
-            max_batch_size,
-            batch_timeout,
-        )
-        return batches, timed_out_uids
-
-    def get_request(self, request_queue: Queue, timeout: float = 1.0):
-        response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout)
-        return response_queue_id, uid, timestamp, x_enc
-
-    def populate_context(self, lit_spec: LitSpec, request: Any):
-        if lit_spec and hasattr(lit_spec, "populate_context"):
-            lit_spec.populate_context(self._context, request)
-
-    def put_response(
-        self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
-    ) -> None:
-        response_queues[response_queue_id].put((uid, (response_data, status)))
-
-    def put_error_response(
-        self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
-    ) -> None:
-        response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR)))
-
-
 class ContinuousBatchingLoop(LitLoop):
     def __init__(self, max_sequence_length: int = 2048):
         super().__init__()
@@ -840,15 +910,7 @@ def inference_worker(
         logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec")

     if loop == "auto":
-        loop = (
-            BatchedStreamingLoop()
-            if stream and max_batch_size > 1
-            else StreamingLoop()
-            if stream
-            else BatchedLoop()
-            if max_batch_size > 1
-            else SingleLoop()
-        )
+        loop = get_default_loop(stream, max_batch_size)

     loop(
         lit_api,
@@ -863,3 +925,15 @@ def inference_worker(
         workers_setup_status,
         callback_runner,
     )
+
+
+def get_default_loop(stream: bool, max_batch_size: int) -> _BaseLoop:
+    return (
+        BatchedStreamingLoop()
+        if stream and max_batch_size > 1
+        else StreamingLoop()
+        if stream
+        else BatchedLoop()
+        if max_batch_size > 1
+        else SingleLoop()
+    )
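The new `get_default_loop` helper simply factors out the conditional that previously lived inline in `inference_worker`, so loop selection is unchanged. A small usage sketch of the mapping it implements:

from litserve.loops import (
    BatchedLoop,
    BatchedStreamingLoop,
    SingleLoop,
    StreamingLoop,
    get_default_loop,
)

# same selection as the removed inline conditional
assert isinstance(get_default_loop(stream=False, max_batch_size=1), SingleLoop)
assert isinstance(get_default_loop(stream=False, max_batch_size=8), BatchedLoop)
assert isinstance(get_default_loop(stream=True, max_batch_size=1), StreamingLoop)
assert isinstance(get_default_loop(stream=True, max_batch_size=8), BatchedStreamingLoop)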

src/litserve/server.py

+7 −4
@@ -40,7 +40,7 @@
 from litserve.callbacks.base import Callback, CallbackRunner, EventTypes
 from litserve.connector import _Connector
 from litserve.loggers import Logger, _LoggerConnector
-from litserve.loops import _BaseLoop, inference_worker
+from litserve.loops import LitLoop, get_default_loop, inference_worker
 from litserve.middlewares import MaxSizeMiddleware, RequestCountMiddleware
 from litserve.python_client import client_template
 from litserve.specs import OpenAISpec
@@ -113,7 +113,7 @@ def __init__(
         spec: Optional[LitSpec] = None,
         max_payload_size=None,
         track_requests: bool = False,
-        loop: Optional[Union[str, _BaseLoop]] = "auto",
+        loop: Optional[Union[str, LitLoop]] = "auto",
         callbacks: Optional[Union[List[Callback], Callback]] = None,
         middlewares: Optional[list[Union[Callable, tuple[Callable, dict]]]] = None,
         loggers: Optional[Union[Logger, List[Logger]]] = None,
@@ -154,6 +154,8 @@ def __init__(

         if isinstance(loop, str) and loop != "auto":
             raise ValueError("loop must be an instance of _BaseLoop or 'auto'")
+        if loop == "auto":
+            loop = get_default_loop(stream, max_batch_size)

         if middlewares is None:
             middlewares = []
@@ -198,15 +200,16 @@ def __init__(
                 "but the max_batch_size parameter was not set."
             )

-        self._loop = loop
+        self._loop: LitLoop = loop
         self.api_path = api_path
         self.healthcheck_path = healthcheck_path
         self.info_path = info_path
         self.track_requests = track_requests
         self.timeout = timeout
         lit_api.stream = stream
         lit_api.request_timeout = self.timeout
-        lit_api._sanitize(max_batch_size, spec=spec)
+        lit_api.pre_setup(max_batch_size, spec=spec)
+        self._loop.pre_setup(lit_api, spec=spec)
         self.app = FastAPI(lifespan=self.lifespan)
         self.app.response_queue_id = None
         self.response_queue_id = None
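Because `LitServer.__init__` now resolves `loop="auto"` up front and then calls `loop.pre_setup(lit_api, spec=spec)`, a custom loop can run its own validation against the configured LitAPI before any worker starts. A hypothetical sketch (the `ValidatedLoop` class and its check are illustrative only, not part of this commit):

import litserve as ls
from litserve.loops import LitLoop


class ValidatedLoop(LitLoop):
    """Hypothetical loop that only accepts batched APIs."""

    def pre_setup(self, lit_api, spec=None):
        # called from LitServer.__init__ right after lit_api.pre_setup(...)
        if lit_api.max_batch_size <= 1:
            raise ValueError("ValidatedLoop requires max_batch_size > 1")

    def __call__(self, *args, **kwargs):
        ...  # the actual inference loop would go here


server = ls.LitServer(
    ls.test_examples.SimpleBatchedAPI(),
    max_batch_size=2,
    batch_timeout=0.05,
    loop=ValidatedLoop(),  # pre_setup runs during server construction
)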

tests/test_batch.py

+1 −1
@@ -154,7 +154,7 @@ def test_max_batch_size_warning():

 def test_batch_predict_string_warning():
     api = ls.test_examples.SimpleBatchedAPI()
-    api._sanitize(2, None)
+    api.pre_setup(2, None)
     api.predict = MagicMock(return_value="This is a string")

     mock_input = torch.tensor([[1.0], [2.0]])
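The renamed test keeps its behavior because `pre_setup(2, None)` still just configures batching on the LitAPI. The streaming/generator validation that `_sanitize` used to perform is now exercised through the loop; a hedged sketch of what that looks like (the direct `StreamingLoop` call here is illustrative, not taken from the test suite):

import litserve as ls
from litserve.loops import StreamingLoop

api = ls.test_examples.SimpleBatchedAPI()   # predict is not a generator
api.stream = True                           # normally assigned by LitServer.__init__
api.pre_setup(max_batch_size=1, spec=None)  # no longer raises on the LitAPI side

try:
    StreamingLoop().pre_setup(api, spec=None)  # DefaultLoop.pre_setup runs the generator checks
except ValueError as err:
    print("loop-level validation:", err)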
