diff --git a/src/litserve/loops/base.py b/src/litserve/loops/base.py index 87ac535d..d465da22 100644 --- a/src/litserve/loops/base.py +++ b/src/litserve/loops/base.py @@ -264,6 +264,28 @@ def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]): return original = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__ + if not lit_api.stream and any([ + inspect.isgeneratorfunction(lit_api.predict), + inspect.isgeneratorfunction(lit_api.encode_response), + ]): + raise ValueError( + """When `stream=False`, `lit_api.predict`, `lit_api.encode_response` must not be + generator functions. + + Correct usage: + + def predict(self, inputs): + ... + return {"output": output} + + Incorrect usage: + + def predict(self, inputs): + ... + for i in range(max_token_length): + yield prediction + """ + ) if ( lit_api.stream and lit_api.max_batch_size > 1 diff --git a/src/litserve/loops/streaming_loops.py b/src/litserve/loops/streaming_loops.py index 12ddb479..2ed57491 100644 --- a/src/litserve/loops/streaming_loops.py +++ b/src/litserve/loops/streaming_loops.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import time from queue import Empty, Queue @@ -21,7 +20,7 @@ from litserve import LitAPI from litserve.callbacks import CallbackRunner, EventTypes -from litserve.loops.base import LitLoop, _inject_context, collate_requests +from litserve.loops.base import DefaultLoop, _inject_context, collate_requests from litserve.specs.base import LitSpec from litserve.utils import LitAPIStatus, PickleableHTTPException @@ -174,73 +173,6 @@ def run_batched_streaming_loop( response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR))) -class DefaultLoop(LitLoop): - def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]): - # we will sanitize regularly if no spec - # in case, we have spec then: - # case 1: spec implements a streaming API - # Case 2: spec implements a non-streaming API - if spec: - # TODO: Implement sanitization - lit_api._spec = spec - return - - original = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__ - if ( - lit_api.stream - and lit_api.max_batch_size > 1 - and not all([ - inspect.isgeneratorfunction(lit_api.predict), - inspect.isgeneratorfunction(lit_api.encode_response), - (original or inspect.isgeneratorfunction(lit_api.unbatch)), - ]) - ): - raise ValueError( - """When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and - `lit_api.unbatch` must generate values using `yield`. - - Example: - - def predict(self, inputs): - ... - for i in range(max_token_length): - yield prediction - - def encode_response(self, outputs): - for output in outputs: - encoded_output = ... - yield encoded_output - - def unbatch(self, outputs): - for output in outputs: - unbatched_output = ... - yield unbatched_output - """ - ) - - if lit_api.stream and not all([ - inspect.isgeneratorfunction(lit_api.predict), - inspect.isgeneratorfunction(lit_api.encode_response), - ]): - raise ValueError( - """When `stream=True` both `lit_api.predict` and - `lit_api.encode_response` must generate values using `yield`. - - Example: - - def predict(self, inputs): - ... - for i in range(max_token_length): - yield prediction - - def encode_response(self, outputs): - for output in outputs: - encoded_output = ... - yield encoded_output - """ - ) - - class StreamingLoop(DefaultLoop): def __call__( self,