Skip to content

Commit

Permalink
add validation for stream=False with yield usage (#402)
Browse files Browse the repository at this point in the history
* add validation for stream=False and yield usage

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update base.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update streaming_loops.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
aniketmaurya and pre-commit-ci[bot] authored Jan 9, 2025
1 parent e9b1ee1 commit 747308a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 69 deletions.
22 changes: 22 additions & 0 deletions src/litserve/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,28 @@ def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
return

original = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__
if not lit_api.stream and any([
inspect.isgeneratorfunction(lit_api.predict),
inspect.isgeneratorfunction(lit_api.encode_response),
]):
raise ValueError(
"""When `stream=False`, `lit_api.predict`, `lit_api.encode_response` must not be
generator functions.
Correct usage:
def predict(self, inputs):
...
return {"output": output}
Incorrect usage:
def predict(self, inputs):
...
for i in range(max_token_length):
yield prediction
"""
)
if (
lit_api.stream
and lit_api.max_batch_size > 1
Expand Down
70 changes: 1 addition & 69 deletions src/litserve/loops/streaming_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import time
from queue import Empty, Queue
Expand All @@ -21,7 +20,7 @@

from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.loops.base import LitLoop, _inject_context, collate_requests
from litserve.loops.base import DefaultLoop, _inject_context, collate_requests
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus, PickleableHTTPException

Expand Down Expand Up @@ -174,73 +173,6 @@ def run_batched_streaming_loop(
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))


class DefaultLoop(LitLoop):
    def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
        """Validate that the LitAPI callbacks are compatible with the streaming config.

        With a spec attached, validation is delegated to the spec (sanitization is
        still a TODO) and the spec is recorded on the API. Without a spec, verify
        that `predict`/`encode_response` (and `unbatch` when batching) are generator
        functions whenever `stream=True`; raise ``ValueError`` otherwise.
        """
        # we will sanitize regularly if no spec
        # in case, we have spec then:
        # case 1: spec implements a streaming API
        # Case 2: spec implements a non-streaming API
        if spec:
            # TODO: Implement sanitization
            lit_api._spec = spec
            return

        # True when the user kept the default `unbatch` implementation; the
        # default is always acceptable, an override must itself be a generator.
        unbatch_is_default = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__
        predict_is_generator = inspect.isgeneratorfunction(lit_api.predict)
        encode_is_generator = inspect.isgeneratorfunction(lit_api.encode_response)
        unbatch_is_generator = unbatch_is_default or inspect.isgeneratorfunction(lit_api.unbatch)

        if (
            lit_api.stream
            and lit_api.max_batch_size > 1
            and not (predict_is_generator and encode_is_generator and unbatch_is_generator)
        ):
            raise ValueError(
                """When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and
                `lit_api.unbatch` must generate values using `yield`.

             Example:

                def predict(self, inputs):
                    ...
                    for i in range(max_token_length):
                        yield prediction

                def encode_response(self, outputs):
                    for output in outputs:
                        encoded_output = ...
                        yield encoded_output

                def unbatch(self, outputs):
                    for output in outputs:
                        unbatched_output = ...
                        yield unbatched_output
             """
            )

        if lit_api.stream and not (predict_is_generator and encode_is_generator):
            raise ValueError(
                """When `stream=True` both `lit_api.predict` and
             `lit_api.encode_response` must generate values using `yield`.

             Example:

                def predict(self, inputs):
                    ...
                    for i in range(max_token_length):
                        yield prediction

                def encode_response(self, outputs):
                    for output in outputs:
                        encoded_output = ...
                        yield encoded_output
             """
            )


class StreamingLoop(DefaultLoop):
def __call__(
self,
Expand Down

0 comments on commit 747308a

Please sign in to comment.