Skip to content

Commit

Permalink
add Starlette middleware support (#253)
Browse files Browse the repository at this point in the history
* add tests

* test update tests

* update type
  • Loading branch information
aniketmaurya authored Aug 30, 2024
1 parent 66109e1 commit b1aaee4
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,25 @@ def __init__(
stream: bool = False,
spec: Optional[LitSpec] = None,
max_payload_size=None,
middlewares: Optional[list[tuple[Callable, dict]]] = None,
middlewares: Optional[list[Union[Callable, tuple[Callable, dict]]]] = None,
):
if batch_timeout > timeout and timeout not in (False, -1):
raise ValueError("batch_timeout must be less than timeout")
if max_batch_size <= 0:
raise ValueError("max_batch_size must be greater than 0")
if isinstance(spec, OpenAISpec):
stream = True

if middlewares is None:
middlewares = []
if not isinstance(middlewares, list):
_msg = (
"middlewares must be a list of tuples"
" where each tuple contains a middleware and its arguments. For example:\n"
"server = ls.LitServer(ls.examples.SimpleLitAPI(), "
'middlewares=[(RequestIdMiddleware, {"length": 5})])'
)
raise ValueError(_msg)

if not api_path.startswith("/"):
raise ValueError(
Expand Down Expand Up @@ -364,8 +373,12 @@ async def stream_predict(request: self.request_type) -> self.response_type:
path, endpoint=endpoint, methods=methods, dependencies=[Depends(self.setup_auth())]
)

for middleware, kwargs in self.middlewares:
self.app.add_middleware(middleware, **kwargs)
for middleware in self.middlewares:
if isinstance(middleware, tuple):
middleware, kwargs = middleware
self.app.add_middleware(middleware, **kwargs)
elif callable(middleware):
self.app.add_middleware(middleware)

@staticmethod
def generate_client_file():
Expand Down
36 changes: 36 additions & 0 deletions tests/test_lit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import torch.nn as nn
from httpx import AsyncClient
from litserve.utils import wrap_litserve_start
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware

from unittest.mock import patch, MagicMock
import pytest
Expand Down Expand Up @@ -393,3 +395,37 @@ def test_custom_middleware():
assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}"
assert response.json() == {"output": 16.0}, "server didn't return expected output"
assert response.headers["X-Request-Id"] == "00000"


def test_starlette_middlewares():
middlewares = [
(
TrustedHostMiddleware,
{
"allowed_hosts": ["localhost", "127.0.0.1"],
},
),
HTTPSRedirectMiddleware,
]
server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=middlewares)
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0}, headers={"Host": "localhost"})
assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}"
assert response.json() == {"output": 16.0}, "server didn't return expected output"

response = client.post("/predict", json={"input": 4.0}, headers={"Host": "not-trusted-host"})
assert response.status_code == 400, f"Expected response to be 400 but got {response.status_code}"


def test_middlewares_inputs():
server = ls.LitServer(SimpleLitAPI(), middlewares=[])
assert len(server.middlewares) == 1, "Default middleware should be present"

server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=[], max_payload_size=1000)
assert len(server.middlewares) == 2, "Default middleware should be present"

server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=None)
assert len(server.middlewares) == 1, "Default middleware should be present"

with pytest.raises(ValueError, match="middlewares must be a list of tuples"):
ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=(RequestIdMiddleware, {"length": 5}))

0 comments on commit b1aaee4

Please sign in to comment.