Skip to content

Commit

Permalink
feat: middlewares and exception handlers in LitServer. (#241)
Browse files Browse the repository at this point in the history
* feat: middlewares and exception handlers in LitServer.
test: e2e for a server with custom middlewares and exception handlers.

* feat: exception handlers removed.

* refactor: gzip and maxsize appended to middlewares.

* refactor: middleware test moved from e2e to test_lit_server.

* chore: removed SimpleExceptionApi example.

* fix: Callable as type hint for middlewares param.
  • Loading branch information
lorenzomassimiani authored Aug 30, 2024
1 parent 15e5f3d commit 686db0c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from queue import Empty
from typing import Dict, Optional, Sequence, Tuple, Union
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import uvicorn
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Request, Response
Expand Down Expand Up @@ -116,13 +116,16 @@ def __init__(
stream: bool = False,
spec: Optional[LitSpec] = None,
max_payload_size=None,
middlewares: Optional[list[tuple[Callable, dict]]] = None,
):
if batch_timeout > timeout and timeout not in (False, -1):
raise ValueError("batch_timeout must be less than timeout")
if max_batch_size <= 0:
raise ValueError("max_batch_size must be greater than 0")
if isinstance(spec, OpenAISpec):
stream = True
if middlewares is None:
middlewares = []

if not api_path.startswith("/"):
raise ValueError(
Expand Down Expand Up @@ -150,15 +153,17 @@ def __init__(
self.response_buffer = {}
# gzip does not play nicely with streaming, see https://github.com/tiangolo/fastapi/discussions/8448
if not stream:
self.app.add_middleware(GZipMiddleware, minimum_size=1000)
middlewares.append((GZipMiddleware, {"minimum_size": 1000}))
if max_payload_size is not None:
self.app.add_middleware(MaxSizeMiddleware, max_size=max_payload_size)
middlewares.append((MaxSizeMiddleware, {"max_size": max_payload_size}))
self.middlewares = middlewares
self.lit_api = lit_api
self.lit_spec = spec
self.workers_per_device = workers_per_device
self.max_batch_size = max_batch_size
self.batch_timeout = batch_timeout
self.stream = stream
self.max_payload_size = max_payload_size
self._connector = _Connector(accelerator=accelerator, devices=devices)

specs = spec if spec is not None else []
Expand Down Expand Up @@ -362,7 +367,11 @@ async def stream_predict(request: self.request_type, background_tasks: Backgroun
path, endpoint=endpoint, methods=methods, dependencies=[Depends(self.setup_auth())]
)

def generate_client_file(self):
for middleware, kwargs in self.middlewares:
self.app.add_middleware(middleware, **kwargs)

@staticmethod
def generate_client_file():
src_path = os.path.join(os.path.dirname(__file__), "python_client.py")
dest_path = os.path.join(os.getcwd(), "client.py")

Expand Down
23 changes: 23 additions & 0 deletions tests/test_lit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from litserve.server import LitServer
import litserve as ls
from fastapi.testclient import TestClient
from starlette.types import ASGIApp
from starlette.middleware.base import BaseHTTPMiddleware


def test_index(sync_testclient):
Expand Down Expand Up @@ -310,3 +312,24 @@ def test_http_exception():
response = client.post("/predict", json={"input": 4.0})
assert response.status_code == 501, "Server raises 501 error"
assert response.text == '{"detail":"decode request is bad"}', "decode request is bad"


class RequestIdMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, length: int) -> None:
self.app = app
self.length = length
super().__init__(app)

async def dispatch(self, request, call_next):
response = await call_next(request)
response.headers["X-Request-Id"] = "0" * self.length
return response


def test_custom_middleware():
server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=[(RequestIdMiddleware, {"length": 5})])
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0})
assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}"
assert response.json() == {"output": 16.0}, "server didn't return expected output"
assert response.headers["X-Request-Id"] == "00000"

0 comments on commit 686db0c

Please sign in to comment.