Add loop.pre_setup to allow fine-grained LitAPI validation based on inference loop (#393)

* pre_setup loop

* add test

* fix tests

* apply feedback
aniketmaurya authored Dec 11, 2024
1 parent 75c6d0e commit 35129f7
Showing 6 changed files with 161 additions and 130 deletions.
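
In short, this commit moves the stream/batch validation out of LitAPI._sanitize (renamed to LitAPI.pre_setup) and into the loop via a new pre_setup hook, so each loop can enforce its own requirements on the LitAPI before serving. As a rough illustration only (not part of this commit), a custom loop might subclass one of the stock loops and extend pre_setup with an extra check. StrictSingleLoop, EchoAPI, and the generator check below are hypothetical; SingleLoop, the pre_setup hook, and the loop= argument to LitServer come from the diff itself.

import inspect

import litserve as ls
from litserve.loops import SingleLoop  # stock non-streaming loop from loops.py


class StrictSingleLoop(SingleLoop):
    """Hypothetical loop that adds an extra construction-time check."""

    def pre_setup(self, lit_api, spec=None):
        super().pre_setup(lit_api, spec)  # keep the default stream/batch validation
        # Loop-specific rule (illustrative only): refuse generator-based predict.
        if inspect.isgeneratorfunction(lit_api.predict):
            raise ValueError("StrictSingleLoop expects `predict` to return, not yield.")


class EchoAPI(ls.LitAPI):
    """Minimal LitAPI used only to exercise the custom loop."""

    def setup(self, device):
        self.model = lambda x: x

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        return self.model(x)

    def encode_response(self, output):
        return {"output": output}


if __name__ == "__main__":
    # After this commit, `loop=` accepts a loop instance; "auto" keeps the old selection logic.
    server = ls.LitServer(EchoAPI(), loop=StrictSingleLoop())
    server.run(port=8000)
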
64 changes: 1 addition & 63 deletions src/litserve/api.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import warnings
from abc import ABC, abstractmethod
@@ -113,76 +112,15 @@ def device(self):
def device(self, value):
self._device = value

def _sanitize(self, max_batch_size: int, spec: Optional[LitSpec]):
def pre_setup(self, max_batch_size: int, spec: Optional[LitSpec]):
self.max_batch_size = max_batch_size
if self.stream:
self._default_unbatch = self._unbatch_stream
else:
self._default_unbatch = self._unbatch_no_stream

# we will sanitize regularly if no spec
# in case, we have spec then:
# case 1: spec implements a streaming API
# Case 2: spec implements a non-streaming API
if spec:
# TODO: Implement sanitization
self._spec = spec
return

original = self.unbatch.__code__ is LitAPI.unbatch.__code__
if (
self.stream
and max_batch_size > 1
and not all([
inspect.isgeneratorfunction(self.predict),
inspect.isgeneratorfunction(self.encode_response),
(original or inspect.isgeneratorfunction(self.unbatch)),
])
):
raise ValueError(
"""When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and
`lit_api.unbatch` must generate values using `yield`.
Example:
def predict(self, inputs):
...
for i in range(max_token_length):
yield prediction
def encode_response(self, outputs):
for output in outputs:
encoded_output = ...
yield encoded_output
def unbatch(self, outputs):
for output in outputs:
unbatched_output = ...
yield unbatched_output
"""
)

if self.stream and not all([
inspect.isgeneratorfunction(self.predict),
inspect.isgeneratorfunction(self.encode_response),
]):
raise ValueError(
"""When `stream=True` both `lit_api.predict` and
`lit_api.encode_response` must generate values using `yield`.
Example:
def predict(self, inputs):
...
for i in range(max_token_length):
yield prediction
def encode_response(self, outputs):
for output in outputs:
encoded_output = ...
yield encoded_output
"""
)

def set_logger_queue(self, queue: Queue):
"""Set the queue for logging events."""
170 changes: 122 additions & 48 deletions src/litserve/loops.py
@@ -441,6 +441,9 @@ def run(
"""

def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
pass

def __call__(
self,
lit_api: LitAPI,
@@ -487,7 +490,109 @@ def run(
raise NotImplementedError


class SingleLoop(_BaseLoop):
class LitLoop(_BaseLoop):
def __init__(self):
self._context = {}

def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
if max_batch_size <= 1:
raise ValueError("max_batch_size must be greater than 1")

batches, timed_out_uids = collate_requests(
lit_api,
request_queue,
max_batch_size,
batch_timeout,
)
return batches, timed_out_uids

def get_request(self, request_queue: Queue, timeout: float = 1.0):
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout)
return response_queue_id, uid, timestamp, x_enc

def populate_context(self, lit_spec: LitSpec, request: Any):
if lit_spec and hasattr(lit_spec, "populate_context"):
lit_spec.populate_context(self._context, request)

def put_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
) -> None:
response_queues[response_queue_id].put((uid, (response_data, status)))

def put_error_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
) -> None:
response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR)))


class DefaultLoop(LitLoop):
def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
# If no spec is provided, run the standard validation below.
# If a spec is provided, skip validation for now:
# case 1: the spec implements a streaming API
# case 2: the spec implements a non-streaming API
if spec:
# TODO: Implement sanitization
lit_api._spec = spec
return

original = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__
if (
lit_api.stream
and lit_api.max_batch_size > 1
and not all([
inspect.isgeneratorfunction(lit_api.predict),
inspect.isgeneratorfunction(lit_api.encode_response),
(original or inspect.isgeneratorfunction(lit_api.unbatch)),
])
):
raise ValueError(
"""When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and
`lit_api.unbatch` must generate values using `yield`.
Example:
def predict(self, inputs):
...
for i in range(max_token_length):
yield prediction
def encode_response(self, outputs):
for output in outputs:
encoded_output = ...
yield encoded_output
def unbatch(self, outputs):
for output in outputs:
unbatched_output = ...
yield unbatched_output
"""
)

if lit_api.stream and not all([
inspect.isgeneratorfunction(lit_api.predict),
inspect.isgeneratorfunction(lit_api.encode_response),
]):
raise ValueError(
"""When `stream=True` both `lit_api.predict` and
`lit_api.encode_response` must generate values using `yield`.
Example:
def predict(self, inputs):
...
for i in range(max_token_length):
yield prediction
def encode_response(self, outputs):
for output in outputs:
encoded_output = ...
yield encoded_output
"""
)


class SingleLoop(DefaultLoop):
def __call__(
self,
lit_api: LitAPI,
@@ -505,7 +610,7 @@ def __call__(
run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)


class BatchedLoop(_BaseLoop):
class BatchedLoop(DefaultLoop):
def __call__(
self,
lit_api: LitAPI,
@@ -531,7 +636,7 @@ def __call__(
)


class StreamingLoop(_BaseLoop):
class StreamingLoop(DefaultLoop):
def __call__(
self,
lit_api: LitAPI,
@@ -549,7 +654,7 @@ def __call__(
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)


class BatchedStreamingLoop(_BaseLoop):
class BatchedStreamingLoop(DefaultLoop):
def __call__(
self,
lit_api: LitAPI,
@@ -593,41 +698,6 @@ class Output:
status: LitAPIStatus


class LitLoop(_BaseLoop):
def __init__(self):
self._context = {}

def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
if max_batch_size <= 1:
raise ValueError("max_batch_size must be greater than 1")

batches, timed_out_uids = collate_requests(
lit_api,
request_queue,
max_batch_size,
batch_timeout,
)
return batches, timed_out_uids

def get_request(self, request_queue: Queue, timeout: float = 1.0):
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout)
return response_queue_id, uid, timestamp, x_enc

def populate_context(self, lit_spec: LitSpec, request: Any):
if lit_spec and hasattr(lit_spec, "populate_context"):
lit_spec.populate_context(self._context, request)

def put_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
) -> None:
response_queues[response_queue_id].put((uid, (response_data, status)))

def put_error_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
) -> None:
response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR)))


class ContinuousBatchingLoop(LitLoop):
def __init__(self, max_sequence_length: int = 2048):
super().__init__()
@@ -840,15 +910,7 @@ def inference_worker(
logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec")

if loop == "auto":
loop = (
BatchedStreamingLoop()
if stream and max_batch_size > 1
else StreamingLoop()
if stream
else BatchedLoop()
if max_batch_size > 1
else SingleLoop()
)
loop = get_default_loop(stream, max_batch_size)

loop(
lit_api,
@@ -863,3 +925,15 @@
workers_setup_status,
callback_runner,
)


def get_default_loop(stream: bool, max_batch_size: int) -> _BaseLoop:
return (
BatchedStreamingLoop()
if stream and max_batch_size > 1
else StreamingLoop()
if stream
else BatchedLoop()
if max_batch_size > 1
else SingleLoop()
)
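
For reference, a quick (non-authoritative) check of how the new get_default_loop helper maps server settings onto the stock loops defined above; it performs the same selection the old inline conditional in inference_worker did.

from litserve.loops import (
    BatchedLoop,
    BatchedStreamingLoop,
    SingleLoop,
    StreamingLoop,
    get_default_loop,
)

# (stream, max_batch_size) pairs map onto the four stock loops.
assert isinstance(get_default_loop(stream=False, max_batch_size=1), SingleLoop)
assert isinstance(get_default_loop(stream=False, max_batch_size=8), BatchedLoop)
assert isinstance(get_default_loop(stream=True, max_batch_size=1), StreamingLoop)
assert isinstance(get_default_loop(stream=True, max_batch_size=8), BatchedStreamingLoop)
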
11 changes: 7 additions & 4 deletions src/litserve/server.py
@@ -40,7 +40,7 @@
from litserve.callbacks.base import Callback, CallbackRunner, EventTypes
from litserve.connector import _Connector
from litserve.loggers import Logger, _LoggerConnector
from litserve.loops import _BaseLoop, inference_worker
from litserve.loops import LitLoop, get_default_loop, inference_worker
from litserve.middlewares import MaxSizeMiddleware, RequestCountMiddleware
from litserve.python_client import client_template
from litserve.specs import OpenAISpec
@@ -113,7 +113,7 @@ def __init__(
spec: Optional[LitSpec] = None,
max_payload_size=None,
track_requests: bool = False,
loop: Optional[Union[str, _BaseLoop]] = "auto",
loop: Optional[Union[str, LitLoop]] = "auto",
callbacks: Optional[Union[List[Callback], Callback]] = None,
middlewares: Optional[list[Union[Callable, tuple[Callable, dict]]]] = None,
loggers: Optional[Union[Logger, List[Logger]]] = None,
@@ -154,6 +154,8 @@ def __init__(

if isinstance(loop, str) and loop != "auto":
raise ValueError("loop must be an instance of _BaseLoop or 'auto'")
if loop == "auto":
loop = get_default_loop(stream, max_batch_size)

if middlewares is None:
middlewares = []
@@ -198,15 +200,16 @@ def __init__(
"but the max_batch_size parameter was not set."
)

self._loop = loop
self._loop: LitLoop = loop
self.api_path = api_path
self.healthcheck_path = healthcheck_path
self.info_path = info_path
self.track_requests = track_requests
self.timeout = timeout
lit_api.stream = stream
lit_api.request_timeout = self.timeout
lit_api._sanitize(max_batch_size, spec=spec)
lit_api.pre_setup(max_batch_size, spec=spec)
self._loop.pre_setup(lit_api, spec=spec)
self.app = FastAPI(lifespan=self.lifespan)
self.app.response_queue_id = None
self.response_queue_id = None
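
One practical consequence of the server wiring above (a sketch, not from the commit): the stream/generator validation still fires while LitServer is being constructed, but it is now raised by the loop's pre_setup rather than by LitAPI, assuming no other constructor check trips first. NonStreamingAPI below is hypothetical.

import litserve as ls


class NonStreamingAPI(ls.LitAPI):
    def setup(self, device):
        pass

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):  # returns a value instead of yielding
        return x

    def encode_response(self, output):
        return {"output": output}


try:
    # "auto" resolves to StreamingLoop, whose inherited DefaultLoop.pre_setup rejects
    # a non-generator predict/encode_response when stream=True.
    ls.LitServer(NonStreamingAPI(), stream=True)
except ValueError as err:
    print(err)
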
2 changes: 1 addition & 1 deletion tests/test_batch.py
@@ -154,7 +154,7 @@ def test_max_batch_size_warning():

def test_batch_predict_string_warning():
api = ls.test_examples.SimpleBatchedAPI()
api._sanitize(2, None)
api.pre_setup(2, None)
api.predict = MagicMock(return_value="This is a string")

mock_input = torch.tensor([[1.0], [2.0]])
(2 of the 6 changed files are not expanded above.)
