1 | | -from fastapi import FastAPI, HTTPException |
2 | | -from fastapi.responses import StreamingResponse |
| 1 | +from fastapi import FastAPI, HTTPException, status |
| 2 | +from fastapi.responses import StreamingResponse, JSONResponse |
3 | 3 | from contextlib import asynccontextmanager |
4 | | -from pydantic import BaseModel, Field |
5 | | -from typing import Optional, List, Union, AsyncGenerator, Literal |
| 4 | +from typing import Optional, AsyncGenerator |
6 | 5 | import json |
7 | 6 | import logging |
8 | 7 | import os |
12 | 11 | from vllm.sampling_params import SamplingParams |
13 | 12 | from vllm.utils import random_uuid |
14 | 13 | from utils import format_chat_prompt, create_error_response |
| 14 | +from .models import GenerationRequest, GenerationResponse, ChatCompletionRequest |
15 | 15 |
16 | 16 | # Configure logging |
17 | 17 | logging.basicConfig( |
@@ -39,42 +39,14 @@ async def lifespan(_: FastAPI): |
39 | 39 | engine_ready = False |
40 | 40 | logger.info("vLLM engine shutdown complete") |
41 | 41 |
| 42 | + |
42 | 43 | app = FastAPI(title="vLLM Load Balancing Server", version="1.0.0", lifespan=lifespan) |
43 | 44 |
| 45 | + |
44 | 46 | # Global variables |
45 | 47 | engine: Optional[AsyncLLMEngine] = None |
46 | 48 | engine_ready = False |
47 | 49 |
48 | | -class GenerationRequest(BaseModel): |
49 | | - prompt: str |
50 | | - max_tokens: int = Field(default=512, ge=1, le=4096) |
51 | | - temperature: float = Field(default=0.7, ge=0.0, le=2.0) |
52 | | - top_p: float = Field(default=0.9, ge=0.0, le=1.0) |
53 | | - top_k: int = Field(default=-1, ge=-1) |
54 | | - frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0) |
55 | | - presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0) |
56 | | - stop: Optional[Union[str, List[str]]] = None |
57 | | - stream: bool = Field(default=False) |
58 | | - |
59 | | -class GenerationResponse(BaseModel): |
60 | | - text: str |
61 | | - finish_reason: str |
62 | | - prompt_tokens: int |
63 | | - completion_tokens: int |
64 | | - total_tokens: int |
65 | | - |
66 | | -class ChatMessage(BaseModel): |
67 | | - role: Literal["system", "user", "assistant"] |
68 | | - content: str |
69 | | - |
70 | | -class ChatCompletionRequest(BaseModel): |
71 | | - messages: List[ChatMessage] |
72 | | - max_tokens: int = Field(default=512, ge=1, le=4096) |
73 | | - temperature: float = Field(default=0.7, ge=0.0, le=2.0) |
74 | | - top_p: float = Field(default=0.9, ge=0.0, le=1.0) |
75 | | - stop: Optional[Union[str, List[str]]] = None |
76 | | - stream: bool = Field(default=False) |
77 | | - |
78 | 50 |
79 | 51 | async def create_engine(): |
80 | 52 | """Initialize the vLLM engine""" |
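Note: the Pydantic request/response models deleted in the hunk above are presumably relocated to the new `models` module imported at the top of the file (ChatMessage is not re-imported, so it presumably moves along with ChatCompletionRequest). A minimal sketch of that module, reconstructed from the removed definitions, might look like this:

# models.py (sketch, reconstructed from the definitions removed in this diff)
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field

class GenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = Field(default=512, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
    top_k: int = Field(default=-1, ge=-1)
    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    stop: Optional[Union[str, List[str]]] = None
    stream: bool = Field(default=False)

class GenerationResponse(BaseModel):
    text: str
    finish_reason: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str

class ChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    max_tokens: int = Field(default=512, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
    stop: Optional[Union[str, List[str]]] = None
    stream: bool = Field(default=False)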
@@ -111,8 +83,11 @@ async def health_check(): |
111 | 83 | """Health check endpoint required by RunPod load balancer""" |
112 | 84 | if not engine_ready: |
113 | 85 | logger.debug("Health check: Engine initializing") |
114 | | - # Return 204 when initializing |
115 | | - return {"status": "initializing"}, 204 |
| 86 | + # Return 503 when initializing |
| 87 | + return JSONResponse( |
| 88 | + content={"status": "initializing"}, |
| 89 | + status_code=status.HTTP_503_SERVICE_UNAVAILABLE |
| 90 | + ) |
116 | 91 |
117 | 92 | logger.debug("Health check: Engine healthy") |
118 | 93 | # Return 200 when healthy |
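Note on the health-check change: FastAPI does not honor the Flask-style `(body, status)` tuple the removed line used; the tuple is simply serialized as the JSON body of a default 200 response, so the load balancer never actually saw a 204. Returning an explicit JSONResponse sets the status code for real, and 503 signals "not ready yet" to any probe that gates on a 2xx. An illustrative readiness poll against this endpoint (the route path and port are assumptions, since the decorator isn't visible in this hunk):

import time
import requests  # assumed available in the probing environment

HEALTH_URL = "http://localhost:8000/health"  # hypothetical path and port

def wait_until_ready(timeout_s: float = 300.0, interval_s: float = 2.0) -> bool:
    """Poll the health endpoint until it returns 200 (engine ready) or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            resp = requests.get(HEALTH_URL, timeout=5)
            if resp.status_code == 200:
                return True  # engine reports healthy
            # 503 means the engine is still initializing; keep polling
        except requests.RequestException:
            pass  # server not accepting connections yet
        time.sleep(interval_s)
    return False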