Skip to content

Commit 9e34faa

Browse files
committed
add tests
1 parent 07d60df commit 9e34faa

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

src/litserve/server.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from fastapi.security import APIKeyHeader
3636
from starlette.formparsers import MultiPartParser
3737
from starlette.middleware.gzip import GZipMiddleware
38+
from starlette.middleware import Middleware
3839

3940
from litserve import LitAPI
4041
from litserve.connector import _Connector
@@ -113,16 +114,25 @@ def __init__(
113114
stream: bool = False,
114115
spec: Optional[LitSpec] = None,
115116
max_payload_size=None,
116-
middlewares: Optional[list[tuple[Callable, dict]]] = None,
117+
middlewares: Optional[list[Union[Middleware, tuple[Callable, dict]]]] = None,
117118
):
118119
if batch_timeout > timeout and timeout not in (False, -1):
119120
raise ValueError("batch_timeout must be less than timeout")
120121
if max_batch_size <= 0:
121122
raise ValueError("max_batch_size must be greater than 0")
122123
if isinstance(spec, OpenAISpec):
123124
stream = True
125+
124126
if middlewares is None:
125127
middlewares = []
128+
if not isinstance(middlewares, list):
129+
_msg = (
130+
"middlewares must be a list of tuples"
131+
" where each tuple contains a middleware and its arguments. For example:\n"
132+
"server = ls.LitServer(ls.examples.SimpleLitAPI(), "
133+
'middlewares=[(RequestIdMiddleware, {"length": 5})])'
134+
)
135+
raise ValueError(_msg)
126136

127137
if not api_path.startswith("/"):
128138
raise ValueError(
@@ -364,8 +374,12 @@ async def stream_predict(request: self.request_type) -> self.response_type:
364374
path, endpoint=endpoint, methods=methods, dependencies=[Depends(self.setup_auth())]
365375
)
366376

367-
for middleware, kwargs in self.middlewares:
368-
self.app.add_middleware(middleware, **kwargs)
377+
for middleware in self.middlewares:
378+
if isinstance(middleware, tuple):
379+
middleware, kwargs = middleware
380+
self.app.add_middleware(middleware, **kwargs)
381+
elif callable(middleware):
382+
self.app.add_middleware(middleware)
369383

370384
@staticmethod
371385
def generate_client_file():

tests/test_lit_server.py

+22
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import torch.nn as nn
2424
from httpx import AsyncClient
2525
from litserve.utils import wrap_litserve_start
26+
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
27+
from starlette.middleware.trustedhost import TrustedHostMiddleware
2628

2729
from unittest.mock import patch, MagicMock
2830
import pytest
@@ -393,3 +395,23 @@ def test_custom_middleware():
393395
assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}"
394396
assert response.json() == {"output": 16.0}, "server didn't return expected output"
395397
assert response.headers["X-Request-Id"] == "00000"
398+
399+
400+
def test_starlette_middlewares():
401+
middlewares = [
402+
(
403+
TrustedHostMiddleware,
404+
{
405+
"allowed_hosts": ["localhost", "127.0.0.1"],
406+
},
407+
),
408+
HTTPSRedirectMiddleware,
409+
]
410+
server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=middlewares)
411+
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
412+
response = client.post("/predict", json={"input": 4.0}, headers={"Host": "localhost"})
413+
assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}"
414+
assert response.json() == {"output": 16.0}, "server didn't return expected output"
415+
416+
response = client.post("/predict", json={"input": 4.0}, headers={"Host": "not-trusted-host"})
417+
assert response.status_code == 400, f"Expected response to be 400 but got {response.status_code}"

0 commit comments

Comments
 (0)