From 747308adde767d0e96cf10a252d2d8f419a9f5aa Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 9 Jan 2025 14:51:22 +0000 Subject: [PATCH] add validation for `stream=False` with `yield` usage (#402) * add validation for stream=False and yield usage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update base.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update streaming_loops.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/litserve/loops/base.py | 22 +++++++++ src/litserve/loops/streaming_loops.py | 70 +-------------------------- 2 files changed, 23 insertions(+), 69 deletions(-) diff --git a/src/litserve/loops/base.py b/src/litserve/loops/base.py index 87ac535d..d465da22 100644 --- a/src/litserve/loops/base.py +++ b/src/litserve/loops/base.py @@ -264,6 +264,28 @@ def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]): return original = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__ + if not lit_api.stream and any([ + inspect.isgeneratorfunction(lit_api.predict), + inspect.isgeneratorfunction(lit_api.encode_response), + ]): + raise ValueError( + """When `stream=False`, `lit_api.predict`, `lit_api.encode_response` must not be + generator functions. + + Correct usage: + + def predict(self, inputs): + ... + return {"output": output} + + Incorrect usage: + + def predict(self, inputs): + ... + for i in range(max_token_length): + yield prediction + """ + ) if ( lit_api.stream and lit_api.max_batch_size > 1 diff --git a/src/litserve/loops/streaming_loops.py b/src/litserve/loops/streaming_loops.py index 12ddb479..2ed57491 100644 --- a/src/litserve/loops/streaming_loops.py +++ b/src/litserve/loops/streaming_loops.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import time from queue import Empty, Queue @@ -21,7 +20,7 @@ from litserve import LitAPI from litserve.callbacks import CallbackRunner, EventTypes -from litserve.loops.base import LitLoop, _inject_context, collate_requests +from litserve.loops.base import DefaultLoop, _inject_context, collate_requests from litserve.specs.base import LitSpec from litserve.utils import LitAPIStatus, PickleableHTTPException @@ -174,73 +173,6 @@ def run_batched_streaming_loop( response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR))) -class DefaultLoop(LitLoop): - def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]): - # we will sanitize regularly if no spec - # in case, we have spec then: - # case 1: spec implements a streaming API - # Case 2: spec implements a non-streaming API - if spec: - # TODO: Implement sanitization - lit_api._spec = spec - return - - original = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__ - if ( - lit_api.stream - and lit_api.max_batch_size > 1 - and not all([ - inspect.isgeneratorfunction(lit_api.predict), - inspect.isgeneratorfunction(lit_api.encode_response), - (original or inspect.isgeneratorfunction(lit_api.unbatch)), - ]) - ): - raise ValueError( - """When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and - `lit_api.unbatch` must generate values using `yield`. - - Example: - - def predict(self, inputs): - ... - for i in range(max_token_length): - yield prediction - - def encode_response(self, outputs): - for output in outputs: - encoded_output = ... - yield encoded_output - - def unbatch(self, outputs): - for output in outputs: - unbatched_output = ... - yield unbatched_output - """ - ) - - if lit_api.stream and not all([ - inspect.isgeneratorfunction(lit_api.predict), - inspect.isgeneratorfunction(lit_api.encode_response), - ]): - raise ValueError( - """When `stream=True` both `lit_api.predict` and - `lit_api.encode_response` must generate values using `yield`. - - Example: - - def predict(self, inputs): - ... - for i in range(max_token_length): - yield prediction - - def encode_response(self, outputs): - for output in outputs: - encoded_output = ... - yield encoded_output - """ - ) - - class StreamingLoop(DefaultLoop): def __call__( self,