diff --git a/src/litserve/server.py b/src/litserve/server.py index 2413371c..bb2d551c 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -27,7 +27,7 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from queue import Empty -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Union import uvicorn from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Request, Response @@ -116,6 +116,7 @@ def __init__( stream: bool = False, spec: Optional[LitSpec] = None, max_payload_size=None, + middlewares: Optional[list[tuple[Callable, dict]]] = None, ): if batch_timeout > timeout and timeout not in (False, -1): raise ValueError("batch_timeout must be less than timeout") @@ -123,6 +124,8 @@ def __init__( raise ValueError("max_batch_size must be greater than 0") if isinstance(spec, OpenAISpec): stream = True + if middlewares is None: + middlewares = [] if not api_path.startswith("/"): raise ValueError( @@ -150,15 +153,17 @@ def __init__( self.response_buffer = {} # gzip does not play nicely with streaming, see https://github.com/tiangolo/fastapi/discussions/8448 if not stream: - self.app.add_middleware(GZipMiddleware, minimum_size=1000) + middlewares.append((GZipMiddleware, {"minimum_size": 1000})) if max_payload_size is not None: - self.app.add_middleware(MaxSizeMiddleware, max_size=max_payload_size) + middlewares.append((MaxSizeMiddleware, {"max_size": max_payload_size})) + self.middlewares = middlewares self.lit_api = lit_api self.lit_spec = spec self.workers_per_device = workers_per_device self.max_batch_size = max_batch_size self.batch_timeout = batch_timeout self.stream = stream + self.max_payload_size = max_payload_size self._connector = _Connector(accelerator=accelerator, devices=devices) specs = spec if spec is not None else [] @@ -362,7 +367,11 @@ async def stream_predict(request: self.request_type, background_tasks: Backgroun 
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp


class RequestIdMiddleware(BaseHTTPMiddleware):
    """Test-only middleware that adds a fixed ``X-Request-Id`` header to each response.

    The header value is the character ``"0"`` repeated ``length`` times.
    """

    def __init__(self, app: ASGIApp, length: int) -> None:
        # The extra ``length`` keyword is what the server forwards from the
        # ``(middleware_class, kwargs)`` pair given in ``middlewares``.
        super().__init__(app)
        self.length = length
        self.app = app

    async def dispatch(self, request, call_next):
        """Run the rest of the stack, then set ``X-Request-Id`` on its response."""
        downstream = await call_next(request)
        request_id = "0" * self.length
        downstream.headers["X-Request-Id"] = request_id
        return downstream


def test_custom_middleware():
    """Check that a middleware passed via ``middlewares`` adds its header to a /predict response."""
    custom_middlewares = [(RequestIdMiddleware, {"length": 5})]
    server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=custom_middlewares)
    with wrap_litserve_start(server) as server, TestClient(server.app) as client:
        response = client.post("/predict", json={"input": 4.0})
        # The prediction itself must be unaffected by the extra middleware.
        assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}"
        payload = response.json()
        assert payload == {"output": 16.0}, "server didn't return expected output"
        # Five zeros, matching the ``length`` configured above.
        request_id = response.headers["X-Request-Id"]
        assert request_id == "00000"