Commit e576a83

fix: address PR feedback - add env config, fix imports, improve docs
1 parent 771ea50 commit e576a83

4 files changed (+191 lines, -52 lines)


Dockerfile

Lines changed: 10 additions & 8 deletions
```diff
@@ -11,17 +11,19 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install --upgrade pip && \
     python3 -m pip install --upgrade -r /requirements.txt

-# Install vLLM (switching back to pip installs since issues that required building fork are fixed and space optimization is not as important since caching) and FlashInfer
-RUN python3 -m pip install vllm==0.9.1 && \
-    python3 -m pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3
-
-# Setup for Option 2: Building the Image with the Model included
+# Pin vLLM version for stability - 0.9.1 is latest stable as of 2024-07
+# FlashInfer provides optimized attention for better performance
+ARG VLLM_VERSION=0.9.1
+ARG CUDA_VERSION=cu121
+ARG TORCH_VERSION=torch2.3

+RUN python3 -m pip install vllm==${VLLM_VERSION} && \
+    python3 -m pip install flashinfer -i https://flashinfer.ai/whl/${CUDA_VERSION}/${TORCH_VERSION}

 ENV PYTHONPATH="/:/vllm-workspace"

-
 COPY src /src

-# Start the handler
-CMD ["python3", "/src/handler.py"]
+WORKDIR /src
+
+CMD ["python3", "handler.py"]
```

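The build arguments introduced above can be overridden at build time, and the image can be smoke-tested locally before deploying. A rough sketch, not part of this commit — the `vllm-loadbalancer:local` tag and the example model are placeholders, and a CUDA-capable host with the NVIDIA Container Toolkit is assumed:

```bash
# Build, overriding the pinned versions if desired (defaults match the Dockerfile ARGs)
docker build \
  --build-arg VLLM_VERSION=0.9.1 \
  --build-arg CUDA_VERSION=cu121 \
  --build-arg TORCH_VERSION=torch2.3 \
  -t vllm-loadbalancer:local .

# Quick local smoke test; MODEL_NAME is the only required variable,
# the rest fall back to the defaults read in src/handler.py
docker run --gpus all -p 8000:8000 \
  -e MODEL_NAME="microsoft/DialoGPT-medium" \
  -e GPU_MEMORY_UTILIZATION=0.8 \
  vllm-loadbalancer:local
```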
README.md

Lines changed: 25 additions & 1 deletion
```diff
@@ -2,15 +2,39 @@

 A FastAPI-based load balancer for serving vLLM models with RunPod integration. Provides OpenAI-compatible APIs with streaming and non-streaming text generation.

+## Prerequisites
+
+Before you begin, make sure you have:
+
+- A RunPod account (sign up at [runpod.io](https://runpod.io))
+- RunPod API key (available in your RunPod dashboard)
+- Basic understanding of REST APIs and HTTP requests
+- `curl` or a similar tool for testing API endpoints
+
 ## Docker Image

 Use the pre-built Docker image: `runpod/vllm-loadbalancer:dev`

+## Environment Variables
+
+Configure these environment variables in your RunPod endpoint:
+
+| Variable | Required | Description | Default | Example |
+|----------|----------|-------------|---------|---------|
+| `MODEL_NAME` | **Yes** | HuggingFace model identifier | None | `microsoft/DialoGPT-medium` |
+| `TENSOR_PARALLEL_SIZE` | No | Number of GPUs for model parallelism | `1` | `2` |
+| `DTYPE` | No | Model precision type | `auto` | `float16` |
+| `TRUST_REMOTE_CODE` | No | Allow remote code execution | `true` | `false` |
+| `MAX_MODEL_LEN` | No | Maximum sequence length | None (auto) | `2048` |
+| `GPU_MEMORY_UTILIZATION` | No | GPU memory usage ratio | `0.9` | `0.8` |
+| `ENFORCE_EAGER` | No | Disable CUDA graphs | `false` | `true` |
+
 ## Deployment on RunPod

 1. Create a new serverless endpoint
 2. Use Docker image: `runpod/vllm-loadbalancer:dev`
-3. Set environment variable: `MODEL_NAME` (e.g., "microsoft/DialoGPT-medium")
+3. Set required environment variable: `MODEL_NAME` (e.g., "microsoft/DialoGPT-medium")
+4. Optional: Configure additional environment variables as needed

 ## API Usage with curl

```

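For context on the curl usage section referenced above, a request shaped like the `GenerationRequest` model in src/handler.py might look as follows — an illustrative sketch only; `localhost:8000` assumes a locally running container (the handler defaults to `PORT=8000`), and on RunPod you would substitute your endpoint's URL:

```bash
# Non-streaming text completion; field names follow the GenerationRequest model
curl -X POST http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"prompt": "Hello, how are you?", "max_tokens": 64, "temperature": 0.7}'
```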
src/handler.py

Lines changed: 104 additions & 43 deletions
```diff
@@ -1,17 +1,45 @@
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
+from contextlib import asynccontextmanager
 from pydantic import BaseModel, Field
-from typing import Optional, List, Union, AsyncGenerator
-import asyncio
+from typing import Optional, List, Union, AsyncGenerator, Literal
 import json
+import logging
 import os
 import uvicorn
 from vllm import AsyncLLMEngine
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
+from utils import format_chat_prompt, create_error_response

-app = FastAPI(title="vLLM Load Balancing Server", version="1.0.0")
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.StreamHandler(),
+    ]
+)
+logger = logging.getLogger(__name__)
+
+@asynccontextmanager
+async def lifespan(_: FastAPI):
+    """Initialize the vLLM engine on startup and cleanup on shutdown"""
+    # Startup
+    await create_engine()
+    yield
+    # Shutdown cleanup
+    global engine, engine_ready
+    if engine:
+        logger.info("Shutting down vLLM engine...")
+        # vLLM AsyncLLMEngine doesn't have an explicit shutdown method,
+        # but we can clean up our references
+        engine = None
+        engine_ready = False
+        logger.info("vLLM engine shutdown complete")
+
+app = FastAPI(title="vLLM Load Balancing Server", version="1.0.0", lifespan=lifespan)

 # Global variables
 engine: Optional[AsyncLLMEngine] = None
@@ -35,6 +63,19 @@ class GenerationResponse(BaseModel):
     completion_tokens: int
     total_tokens: int

+class ChatMessage(BaseModel):
+    role: Literal["system", "user", "assistant"]
+    content: str
+
+class ChatCompletionRequest(BaseModel):
+    messages: List[ChatMessage]
+    max_tokens: int = Field(default=512, ge=1, le=4096)
+    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
+    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
+    stop: Optional[Union[str, List[str]]] = None
+    stream: bool = Field(default=False)
+
+
 async def create_engine():
     """Initialize the vLLM engine"""
     global engine, engine_ready
@@ -46,36 +87,34 @@ async def create_engine():
         # Configure engine arguments
         engine_args = AsyncEngineArgs(
             model=model_name,
-            tensor_parallel_size=1,  # Adjust based on your GPU setup
-            dtype="auto",
-            trust_remote_code=True,
-            max_model_len=None,  # Let vLLM decide based on model
-            gpu_memory_utilization=0.9,
-            enforce_eager=False,
+            tensor_parallel_size=int(os.getenv("TENSOR_PARALLEL_SIZE", "1")),
+            dtype=os.getenv("DTYPE", "auto"),
+            trust_remote_code=os.getenv("TRUST_REMOTE_CODE", "true").lower() == "true",
+            max_model_len=int(os.getenv("MAX_MODEL_LEN")) if os.getenv("MAX_MODEL_LEN") else None,
+            gpu_memory_utilization=float(os.getenv("GPU_MEMORY_UTILIZATION", "0.9")),
+            enforce_eager=os.getenv("ENFORCE_EAGER", "false").lower() == "true",
         )

         # Create the engine
         engine = AsyncLLMEngine.from_engine_args(engine_args)
         engine_ready = True
-        print(f"vLLM engine initialized successfully with model: {model_name}")
+        logger.info(f"vLLM engine initialized successfully with model: {model_name}")

     except Exception as e:
-        print(f"Failed to initialize vLLM engine: {str(e)}")
+        logger.error(f"Failed to initialize vLLM engine: {str(e)}")
         engine_ready = False
         raise

-@app.on_event("startup")
-async def startup_event():
-    """Initialize the vLLM engine on startup"""
-    await create_engine()

 @app.get("/ping")
 async def health_check():
     """Health check endpoint required by RunPod load balancer"""
     if not engine_ready:
+        logger.debug("Health check: Engine initializing")
         # Return 204 when initializing
         return {"status": "initializing"}, 204

+    logger.debug("Health check: Engine healthy")
     # Return 200 when healthy
     return {"status": "healthy"}

@@ -95,8 +134,12 @@ async def root():
 @app.post("/v1/completions", response_model=GenerationResponse)
 async def generate_completion(request: GenerationRequest):
     """Generate text completion"""
+    logger.info(f"Received completion request: max_tokens={request.max_tokens}, temperature={request.temperature}, stream={request.stream}")
+
     if not engine_ready or engine is None:
-        raise HTTPException(status_code=503, detail="Engine not ready")
+        logger.warning("Completion request rejected: Engine not ready")
+        error_response = create_error_response("ServiceUnavailable", "Engine not ready")
+        raise HTTPException(status_code=503, detail=error_response.model_dump())

     try:
         # Create sampling parameters
@@ -126,15 +169,23 @@ async def generate_completion(request: GenerationRequest):
             final_output = output

         if final_output is None:
-            raise HTTPException(status_code=500, detail="No output generated")
+            request_id = random_uuid()
+            error_response = create_error_response("GenerationError", "No output generated", request_id)
+            raise HTTPException(status_code=500, detail=error_response.model_dump())

         generated_text = final_output.outputs[0].text
         finish_reason = final_output.outputs[0].finish_reason

-        # Calculate token counts (approximate)
-        prompt_tokens = len(request.prompt.split())
-        completion_tokens = len(generated_text.split())
+        # Calculate token counts using actual token IDs when available
+        if hasattr(final_output, 'prompt_token_ids') and final_output.prompt_token_ids is not None:
+            prompt_tokens = len(final_output.prompt_token_ids)
+        else:
+            # Fallback to approximate word count
+            prompt_tokens = len(request.prompt.split())

+        completion_tokens = len(final_output.outputs[0].token_ids)
+
+        logger.info(f"Completion generated: {completion_tokens} tokens, finish_reason={finish_reason}")
         return GenerationResponse(
             text=generated_text,
             finish_reason=finish_reason,
@@ -144,7 +195,10 @@ async def generate_completion(request: GenerationRequest):
         )

     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
+        request_id = random_uuid()
+        logger.error(f"Generation failed (request_id={request_id}): {str(e)}", exc_info=True)
+        error_response = create_error_response("GenerationError", f"Generation failed: {str(e)}", request_id)
+        raise HTTPException(status_code=500, detail=error_response.model_dump())

 async def stream_completion(prompt: str, sampling_params: SamplingParams, request_id: str) -> AsyncGenerator[str, None]:
     """Stream completion generator"""
@@ -160,31 +214,32 @@ async def stream_completion(prompt: str, sampling_params: SamplingParams, reques
         yield f"data: {json.dumps({'error': str(e)})}\n\n"

 @app.post("/v1/chat/completions")
-async def chat_completions(request: dict):
+async def chat_completions(request: ChatCompletionRequest):
     """OpenAI-compatible chat completions endpoint"""
+    logger.info(f"Received chat completion request: {len(request.messages)} messages, max_tokens={request.max_tokens}, temperature={request.temperature}")
+
     if not engine_ready or engine is None:
-        raise HTTPException(status_code=503, detail="Engine not ready")
+        logger.warning("Chat completion request rejected: Engine not ready")
+        error_response = create_error_response("ServiceUnavailable", "Engine not ready")
+        raise HTTPException(status_code=503, detail=error_response.model_dump())

     try:
         # Extract messages and convert to prompt
-        messages = request.get("messages", [])
+        messages = request.messages
         if not messages:
-            raise HTTPException(status_code=400, detail="No messages provided")
+            error_response = create_error_response("ValidationError", "No messages provided")
+            raise HTTPException(status_code=400, detail=error_response.model_dump())

-        # Simple conversion of messages to prompt (you may want to improve this)
-        prompt = ""
-        for message in messages:
-            role = message.get("role", "user")
-            content = message.get("content", "")
-            prompt += f"{role}: {content}\n"
-        prompt += "assistant: "
+        # Use proper chat template formatting
+        model_name = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
+        prompt = format_chat_prompt(messages, model_name)

         # Create sampling parameters from request
         sampling_params = SamplingParams(
-            max_tokens=request.get("max_tokens", 512),
-            temperature=request.get("temperature", 0.7),
-            top_p=request.get("top_p", 0.9),
-            stop=request.get("stop"),
+            max_tokens=request.max_tokens,
+            temperature=request.temperature,
+            top_p=request.top_p,
+            stop=request.stop,
         )

         # Generate
@@ -195,9 +250,12 @@ async def chat_completions(request: dict):
             final_output = output

         if final_output is None:
-            raise HTTPException(status_code=500, detail="No output generated")
+            error_response = create_error_response("GenerationError", "No output generated", request_id)
+            raise HTTPException(status_code=500, detail=error_response.model_dump())

         generated_text = final_output.outputs[0].text
+        completion_tokens = len(final_output.outputs[0].token_ids)
+        logger.info(f"Chat completion generated: {completion_tokens} tokens, finish_reason={final_output.outputs[0].finish_reason}")

         # Return OpenAI-compatible response
         return {
@@ -213,19 +271,22 @@ async def chat_completions(request: dict):
                 "finish_reason": final_output.outputs[0].finish_reason
             }],
             "usage": {
-                "prompt_tokens": len(prompt.split()),
-                "completion_tokens": len(generated_text.split()),
-                "total_tokens": len(prompt.split()) + len(generated_text.split())
+                "prompt_tokens": len(final_output.prompt_token_ids) if hasattr(final_output, 'prompt_token_ids') and final_output.prompt_token_ids is not None else len(prompt.split()),
+                "completion_tokens": len(final_output.outputs[0].token_ids),
+                "total_tokens": (len(final_output.prompt_token_ids) if hasattr(final_output, 'prompt_token_ids') and final_output.prompt_token_ids is not None else len(prompt.split())) + len(final_output.outputs[0].token_ids)
             }
         }

     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Chat completion failed: {str(e)}")
+        request_id = random_uuid()
+        logger.error(f"Chat completion failed (request_id={request_id}): {str(e)}", exc_info=True)
+        error_response = create_error_response("ChatCompletionError", f"Chat completion failed: {str(e)}", request_id)
+        raise HTTPException(status_code=500, detail=error_response.model_dump())

 if __name__ == "__main__":
     # Get ports from environment variables
     port = int(os.getenv("PORT", 8000))
-    print(f"Starting vLLM server on port {port}")
+    logger.info(f"Starting vLLM server on port {port}")

     # If health port is different, you'd need to run a separate health server
     # For simplicity, we're using the same port here
@@ -235,4 +296,4 @@ async def chat_completions(request: dict):
         host="0.0.0.0",
         port=port,
         log_level="info"
-    )
+    )
```

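Two illustrative calls against the endpoints touched by this commit — again assuming a locally running container on port 8000; substitute your RunPod endpoint URL when deployed:

```bash
# Health check endpoint polled by the RunPod load balancer
curl -i http://localhost:8000/ping

# Chat completion matching the new typed ChatCompletionRequest model
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "messages": [
          {"role": "system", "content": "You are a helpful assistant."},
          {"role": "user", "content": "Say hello."}
        ],
        "max_tokens": 128,
        "temperature": 0.7
      }'
```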
src/utils.py

Lines changed: 52 additions & 0 deletions
```diff
@@ -0,0 +1,52 @@
+from typing import List, Optional
+from transformers import AutoTokenizer
+from pydantic import BaseModel
+
+
+class ChatMessage:
+    """Chat message for type hints in utils functions"""
+    def __init__(self, role: str, content: str):
+        self.role = role
+        self.content = content
+
+
+class ErrorResponse(BaseModel):
+    error: str
+    detail: str
+    request_id: Optional[str] = None
+
+
+def get_tokenizer(model_name: str):
+    """Get tokenizer for the given model"""
+    return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+
+def format_chat_prompt(messages: List[ChatMessage], model_name: str) -> str:
+    """Format messages using the model's chat template"""
+    tokenizer = get_tokenizer(model_name)
+
+    # Use model's built-in chat template if available
+    if hasattr(tokenizer, 'apply_chat_template'):
+        message_dicts = [{"role": msg.role, "content": msg.content} for msg in messages]
+        return tokenizer.apply_chat_template(
+            message_dicts,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+
+    # Fallback to common format
+    formatted_prompt = ""
+    for message in messages:
+        if message.role == "system":
+            formatted_prompt += f"System: {message.content}\n\n"
+        elif message.role == "user":
+            formatted_prompt += f"Human: {message.content}\n\n"
+        elif message.role == "assistant":
+            formatted_prompt += f"Assistant: {message.content}\n\n"
+
+    formatted_prompt += "Assistant: "
+    return formatted_prompt
+
+
+def create_error_response(error: str, detail: str, request_id: str = None) -> ErrorResponse:
+    return ErrorResponse(error=error, detail=detail, request_id=request_id)
```
