|
35 | 35 | from fastapi.security import APIKeyHeader
|
36 | 36 | from starlette.formparsers import MultiPartParser
|
37 | 37 | from starlette.middleware.gzip import GZipMiddleware
|
| 38 | +from starlette.middleware import Middleware |
38 | 39 |
|
39 | 40 | from litserve import LitAPI
|
40 | 41 | from litserve.connector import _Connector
|
@@ -113,16 +114,25 @@ def __init__(
|
113 | 114 | stream: bool = False,
|
114 | 115 | spec: Optional[LitSpec] = None,
|
115 | 116 | max_payload_size=None,
|
116 |
| - middlewares: Optional[list[tuple[Callable, dict]]] = None, |
| 117 | + middlewares: Optional[list[Union[Middleware, tuple[Callable, dict]]]] = None, |
117 | 118 | ):
|
118 | 119 | if batch_timeout > timeout and timeout not in (False, -1):
|
119 | 120 | raise ValueError("batch_timeout must be less than timeout")
|
120 | 121 | if max_batch_size <= 0:
|
121 | 122 | raise ValueError("max_batch_size must be greater than 0")
|
122 | 123 | if isinstance(spec, OpenAISpec):
|
123 | 124 | stream = True
|
| 125 | + |
124 | 126 | if middlewares is None:
|
125 | 127 | middlewares = []
|
| 128 | + if not isinstance(middlewares, list): |
| 129 | + _msg = ( |
| 130 | + "middlewares must be a list of tuples" |
| 131 | + " where each tuple contains a middleware and its arguments. For example:\n" |
| 132 | + "server = ls.LitServer(ls.examples.SimpleLitAPI(), " |
| 133 | + 'middlewares=[(RequestIdMiddleware, {"length": 5})])' |
| 134 | + ) |
| 135 | + raise ValueError(_msg) |
126 | 136 |
|
127 | 137 | if not api_path.startswith("/"):
|
128 | 138 | raise ValueError(
|
@@ -364,8 +374,12 @@ async def stream_predict(request: self.request_type) -> self.response_type:
|
364 | 374 | path, endpoint=endpoint, methods=methods, dependencies=[Depends(self.setup_auth())]
|
365 | 375 | )
|
366 | 376 |
|
367 |
| - for middleware, kwargs in self.middlewares: |
368 |
| - self.app.add_middleware(middleware, **kwargs) |
| 377 | + for middleware in self.middlewares: |
| 378 | + if isinstance(middleware, tuple): |
| 379 | + middleware, kwargs = middleware |
| 380 | + self.app.add_middleware(middleware, **kwargs) |
| 381 | + elif callable(middleware): |
| 382 | + self.app.add_middleware(middleware) |
369 | 383 |
|
370 | 384 | @staticmethod
|
371 | 385 | def generate_client_file():
|
|
0 commit comments