Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Aug 31, 2024
1 parent 5ee1b08 commit 2463ab3
Showing 4 changed files with 30 additions and 29 deletions.
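
The hunks below are mechanical style fixes applied automatically by the pre-commit.ci bot: string literals normalized to double quotes, long signatures and calls wrapped one argument per line, and surplus blank lines collapsed. As a rough sketch of the kind of rewrite such hooks perform (this assumes a Black/Ruff-style formatter is configured; the repository's .pre-commit-config.yaml is not part of this commit, and the snippet below is purely hypothetical, not code from this repository):

# Hypothetical illustration of the formatting the hooks enforce.

# Before the hooks run: single quotes, one long signature on a single line.
def start_worker(api, accelerator='auto', devices=1, workers_per_device=1, max_batch_size=8, batch_timeout=0.0):
    print('starting worker on', accelerator)

# After the hooks run: double quotes and a wrapped signature with a trailing comma.
def start_worker(
    api,
    accelerator="auto",
    devices=1,
    workers_per_device=1,
    max_batch_size=8,
    batch_timeout=0.0,
):
    print("starting worker on", accelerator)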
6 changes: 2 additions & 4 deletions .gitignore
@@ -134,8 +134,6 @@ MNIST

src/litserve/server.log
src/client.py
src/start_server.py
src/litserve/start_server.py
src/start_server.py
src/litserve/start_server.py
src/litserve/client.py


1 change: 1 addition & 0 deletions src/litserve/api.py
@@ -62,6 +62,7 @@ def predict(self, x, **kwargs):
def process_on_gpu(self, batch):
"""Process the batch on the GPU."""
pass

def process_on_cpu(self, batch):
"""Process the batch on the CPU."""
pass
23 changes: 13 additions & 10 deletions src/litserve/loops.py
@@ -20,21 +20,18 @@
import time
from queue import Empty, Queue
from typing import Dict, List, Optional, Tuple, Union
from datetime import datetime
from fastapi import HTTPException
from starlette.formparsers import MultiPartParser

from litserve import LitAPI
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus

import logging

logging.basicConfig(filename='server.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logging.basicConfig(filename="server.log", level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)



mp.allow_connection_pickling()

try:
@@ -100,8 +97,15 @@ def collate_requests(
return payloads, timed_out_uids


async def run_heter_pipeline(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue],
max_batch_size: int, batch_timeout: float, heter_pipeline: Queue):
async def run_heter_pipeline(
lit_api: LitAPI,
lit_spec: LitSpec,
request_queue: Queue,
response_queues: List[Queue],
max_batch_size: int,
batch_timeout: float,
heter_pipeline: Queue,
):
cpu_batch = []
gpu_batch = []
cpu_to_gpu_queue = Queue()
@@ -175,9 +179,6 @@ async def move_cpu_results_to_gpu_batch():
print("All batches processed, exiting...")





def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
while True:
try:
@@ -427,7 +428,9 @@ async def inference_worker(

if heter_pipeline is not None:
print(f"Worker {worker_id} using heter pipeline")
await run_heter_pipeline(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout, heter_pipeline)
await run_heter_pipeline(
lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout, heter_pipeline
)
elif stream:
print(f"Worker {worker_id} using streaming")
if max_batch_size > 1:
29 changes: 14 additions & 15 deletions src/litserve/server.py
@@ -38,7 +38,7 @@

from litserve import LitAPI
from litserve.connector import _Connector
from litserve.loops import inference_worker,run_heter_pipeline,run_inference_worker
from litserve.loops import run_inference_worker
from litserve.specs import OpenAISpec
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus, MaxSizeMiddleware, load_and_raise
@@ -198,9 +198,9 @@ def __init__(
self.devices = [self.device_identifiers(accelerator, device) for device in device_list]

self.workers = self.devices * self.workers_per_device

self.use_heter_pipeline = use_heter_pipeline
print("use_heter_pipeline ",self.use_heter_pipeline)
print("use_heter_pipeline ", self.use_heter_pipeline)
self.heter_pipeline = None # Will be initialized in launch_inference_worker if needed

self.setup_server()
@@ -219,7 +219,6 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
self.heter_pipeline = manager.Queue()
print("Heterogeneous pipeline queue created")


for spec in self._specs:
# Objects of Server class are referenced (not copied)
logging.debug(f"shallow copy for Server is created for for spec {spec}")
@@ -323,12 +322,11 @@ async def health(request: Request) -> Response:
return Response(content="ok", status_code=200)

return Response(content="not ready", status_code=503)

async def predict(request: self.request_type) -> self.response_type:
try:
response_queue_id = self.app.response_queue_id


uid = uuid.uuid4()
event = asyncio.Event()
self.response_buffer[uid] = event
@@ -338,20 +336,20 @@ async def predict(request: self.request_type) -> self.response_type:
content_type = request.headers.get("Content-Type", "").lower()

if self.request_type == Request:
if content_type == "application/x-www-form-urlencoded":
payload = await request.form()
elif content_type.startswith("multipart/form-data"):
if content_type == "application/x-www-form-urlencoded" or content_type.startswith("multipart/form-data"):
payload = await request.form()
elif content_type == "application/json":
try:
payload = await request.json()
except Exception as e:
logger.exception(f"Failed to parse JSON for request uid={uid} with content_type={content_type}")
except Exception:
logger.exception(
f"Failed to parse JSON for request uid={uid} with content_type={content_type}"
)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
else:
logger.error(f"Unsupported Content-Type {content_type} for request uid={uid}")
raise HTTPException(status_code=415, detail="Unsupported Content-Type")
print("Request Queue ",self.request_queue)
print("Request Queue ", self.request_queue)
self.request_queue.put_nowait((response_queue_id, uid, time.monotonic(), payload))
# print(f"Request uid={uid} added to request queue")
await event.wait()
@@ -390,7 +388,6 @@ async def stream_predict(request: self.request_type) -> self.response_type:
logger.exception(f"An error occurred while processing request uid={uid}: {str(e)}")
raise HTTPException(status_code=500, detail="Internal Server Error")


async def stream_predict(request: self.request_type) -> self.response_type:
response_queue_id = self.app.response_queue_id
uid = uuid.uuid4()
@@ -458,7 +455,7 @@ def run(
api_server_worker_type: Optional[str] = None,
**kwargs,
):
try:
try:
if generate_client_file:
self.generate_client_file()

@@ -489,7 +486,9 @@ manager, litserve_workers = self.launch_inference_worker(num_api_servers)
manager, litserve_workers = self.launch_inference_worker(num_api_servers)

try:
servers = self._start_server(port, num_api_servers, log_level, sockets, api_server_worker_type, **kwargs)
servers = self._start_server(
port, num_api_servers, log_level, sockets, api_server_worker_type, **kwargs
)
print(f"Swagger UI is available at http://0.0.0.0:{port}/docs")
for s in servers:
s.join()
