
Commit 1fcd9ba

fix: avoid circular imports with models to house BaseModel classes

1 parent: e576a83
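
For context, the shape of the fix: the Pydantic schemas move out of handler.py into a new leaf module, src/models.py, which imports nothing from the project, so handler.py and utils.py can both depend on it without depending on each other. A sketch of the resulting import graph, read off the diffs below:

    # handler.py -> utils.py    (format_chat_prompt, create_error_response)
    # handler.py -> models.py   (GenerationRequest, GenerationResponse, ChatCompletionRequest)
    # utils.py   -> models.py   (ChatMessage, ErrorResponse)
    # models.py  -> nothing in src/; pydantic schemas only, so no cycle remains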

File tree: src/handler.py, src/models.py, src/utils.py

3 files changed: +55 −51 lines

src/handler.py

Lines changed: 11 additions & 36 deletions
@@ -1,8 +1,7 @@
-from fastapi import FastAPI, HTTPException
-from fastapi.responses import StreamingResponse
+from fastapi import FastAPI, HTTPException, status
+from fastapi.responses import StreamingResponse, JSONResponse
 from contextlib import asynccontextmanager
-from pydantic import BaseModel, Field
-from typing import Optional, List, Union, AsyncGenerator, Literal
+from typing import Optional, AsyncGenerator
 import json
 import logging
 import os
@@ -12,6 +11,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 from utils import format_chat_prompt, create_error_response
+from .models import GenerationRequest, GenerationResponse, ChatCompletionRequest

 # Configure logging
 logging.basicConfig(
@@ -39,42 +39,14 @@ async def lifespan(_: FastAPI):
     engine_ready = False
     logger.info("vLLM engine shutdown complete")

+
 app = FastAPI(title="vLLM Load Balancing Server", version="1.0.0", lifespan=lifespan)

+
 # Global variables
 engine: Optional[AsyncLLMEngine] = None
 engine_ready = False

-class GenerationRequest(BaseModel):
-    prompt: str
-    max_tokens: int = Field(default=512, ge=1, le=4096)
-    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
-    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
-    top_k: int = Field(default=-1, ge=-1)
-    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
-    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
-    stop: Optional[Union[str, List[str]]] = None
-    stream: bool = Field(default=False)
-
-class GenerationResponse(BaseModel):
-    text: str
-    finish_reason: str
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-
-class ChatMessage(BaseModel):
-    role: Literal["system", "user", "assistant"]
-    content: str
-
-class ChatCompletionRequest(BaseModel):
-    messages: List[ChatMessage]
-    max_tokens: int = Field(default=512, ge=1, le=4096)
-    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
-    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
-    stop: Optional[Union[str, List[str]]] = None
-    stream: bool = Field(default=False)
-

 async def create_engine():
     """Initialize the vLLM engine"""
@@ -111,8 +83,11 @@ async def health_check():
     """Health check endpoint required by RunPod load balancer"""
     if not engine_ready:
         logger.debug("Health check: Engine initializing")
-        # Return 204 when initializing
-        return {"status": "initializing"}, 204
+        # Return 503 when initializing
+        return JSONResponse(
+            content={"status": "initializing"},
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE
+        )

     logger.debug("Health check: Engine healthy")
     # Return 200 when healthy
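
Beyond the import cleanup, the health-check hunk fixes a real bug: FastAPI does not interpret Flask-style (body, status) tuples, so the old return {"status": "initializing"}, 204 serialized the whole tuple as a 200 JSON response. A JSONResponse carries the status code explicitly, and 503 (rather than 204) is what tells a load balancer to keep the worker out of rotation while the engine warms up. A minimal self-contained sketch of the pattern; the /health route path and app wiring are assumptions for illustration, not the repo's exact code:

    from fastapi import FastAPI, status
    from fastapi.responses import JSONResponse

    app = FastAPI()
    engine_ready = False  # flipped to True once the engine finishes loading

    @app.get("/health")  # route path is an assumption; the decorator is outside this hunk
    async def health_check():
        if not engine_ready:
            # 503 keeps this worker out of the balancer's rotation
            return JSONResponse(
                content={"status": "initializing"},
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            )
        return {"status": "healthy"}  # plain dict is returned as 200 by default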

src/models.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+from typing import Optional, List, Union, Literal
+from pydantic import BaseModel, Field
+
+
+class ChatMessage(BaseModel):
+    role: Literal["system", "user", "assistant"]
+    content: str
+
+
+class GenerationRequest(BaseModel):
+    prompt: str
+    max_tokens: int = Field(default=512, ge=1, le=4096)
+    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
+    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
+    top_k: int = Field(default=-1, ge=-1)
+    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
+    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
+    stop: Optional[Union[str, List[str]]] = None
+    stream: bool = Field(default=False)
+
+
+class GenerationResponse(BaseModel):
+    text: str
+    finish_reason: str
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class ChatCompletionRequest(BaseModel):
+    messages: List[ChatMessage]
+    max_tokens: int = Field(default=512, ge=1, le=4096)
+    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
+    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
+    stop: Optional[Union[str, List[str]]] = None
+    stream: bool = Field(default=False)
+
+
+class ErrorResponse(BaseModel):
+    error: str
+    detail: str
+    request_id: Optional[str] = None
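
The Field bounds moved over verbatim, so validation behavior is unchanged. A quick sketch of how those constraints behave at parse time; it assumes the interpreter runs from src/ so that models is importable, and the error-inspection style shown works in both pydantic v1 and v2:

    from pydantic import ValidationError
    from models import GenerationRequest

    # The only required field is prompt; everything else takes its default
    req = GenerationRequest(prompt="Hello")
    print(req.max_tokens, req.temperature)  # 512 0.7

    # Out-of-range values are rejected at parse time (temperature has le=2.0)
    try:
        GenerationRequest(prompt="Hi", temperature=5.0)
    except ValidationError as exc:
        print("rejected:", exc.errors()[0]["loc"])  # ('temperature',)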

src/utils.py

Lines changed: 2 additions & 15 deletions
@@ -1,19 +1,6 @@
-from typing import List, Optional
+from typing import List
 from transformers import AutoTokenizer
-from pydantic import BaseModel
-
-
-class ChatMessage:
-    """Chat message for type hints in utils functions"""
-    def __init__(self, role: str, content: str):
-        self.role = role
-        self.content = content
-
-
-class ErrorResponse(BaseModel):
-    error: str
-    detail: str
-    request_id: Optional[str] = None
+from .models import ChatMessage, ErrorResponse


 def get_tokenizer(model_name: str):
