@@ -1,17 +1,45 @@
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
+from contextlib import asynccontextmanager
 from pydantic import BaseModel, Field
-from typing import Optional, List, Union, AsyncGenerator
-import asyncio
+from typing import Optional, List, Union, AsyncGenerator, Literal
 import json
+import logging
 import os
 import uvicorn
 from vllm import AsyncLLMEngine
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
+from utils import format_chat_prompt, create_error_response
 
-app = FastAPI(title="vLLM Load Balancing Server", version="1.0.0")
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.StreamHandler(),
+    ]
+)
+logger = logging.getLogger(__name__)
+
+@asynccontextmanager
+async def lifespan(_: FastAPI):
+    """Initialize the vLLM engine on startup and cleanup on shutdown"""
+    # Startup
+    await create_engine()
+    yield
+    # Shutdown cleanup
+    global engine, engine_ready
+    if engine:
+        logger.info("Shutting down vLLM engine...")
+        # vLLM AsyncLLMEngine doesn't have an explicit shutdown method,
+        # but we can clean up our references
+        engine = None
+        engine_ready = False
+        logger.info("vLLM engine shutdown complete")
+
+app = FastAPI(title="vLLM Load Balancing Server", version="1.0.0", lifespan=lifespan)
 
 # Global variables
 engine: Optional[AsyncLLMEngine] = None
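Note: `utils.py` is not part of this diff, so the real implementations of `format_chat_prompt` and `create_error_response` are not shown. The sketch below is only an illustration of the shapes the handlers appear to assume (an error object exposing `.model_dump()`, and a string prompt built from the chat messages); the actual module may differ.

```python
# Hypothetical sketch of the helpers imported from utils (not part of this diff).
from typing import List, Optional
from pydantic import BaseModel
from transformers import AutoTokenizer


class ErrorResponse(BaseModel):
    error: str          # e.g. "ServiceUnavailable", "GenerationError"
    message: str
    request_id: Optional[str] = None


def create_error_response(error: str, message: str,
                          request_id: Optional[str] = None) -> ErrorResponse:
    # The handlers call .model_dump() on the result before attaching it
    # to HTTPException(detail=...), so a Pydantic model fits that usage.
    return ErrorResponse(error=error, message=message, request_id=request_id)


def format_chat_prompt(messages, model_name: str) -> str:
    # `messages` are the ChatMessage models defined in the handler module.
    # If the model ships a chat template, apply it; otherwise fall back to a
    # plain role-prefixed transcript.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    as_dicts = [{"role": m.role, "content": m.content} for m in messages]
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            as_dicts, tokenize=False, add_generation_prompt=True
        )
    return "\n".join(f"{m['role']}: {m['content']}" for m in as_dicts) + "\nassistant: "
```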
@@ -35,6 +63,19 @@ class GenerationResponse(BaseModel):
     completion_tokens: int
     total_tokens: int
 
+class ChatMessage(BaseModel):
+    role: Literal["system", "user", "assistant"]
+    content: str
+
+class ChatCompletionRequest(BaseModel):
+    messages: List[ChatMessage]
+    max_tokens: int = Field(default=512, ge=1, le=4096)
+    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
+    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
+    stop: Optional[Union[str, List[str]]] = None
+    stream: bool = Field(default=False)
+
+
 async def create_engine():
     """Initialize the vLLM engine"""
     global engine, engine_ready
@@ -46,36 +87,34 @@ async def create_engine():
         # Configure engine arguments
         engine_args = AsyncEngineArgs(
             model=model_name,
-            tensor_parallel_size=1,  # Adjust based on your GPU setup
-            dtype="auto",
-            trust_remote_code=True,
-            max_model_len=None,  # Let vLLM decide based on model
-            gpu_memory_utilization=0.9,
-            enforce_eager=False,
+            tensor_parallel_size=int(os.getenv("TENSOR_PARALLEL_SIZE", "1")),
+            dtype=os.getenv("DTYPE", "auto"),
+            trust_remote_code=os.getenv("TRUST_REMOTE_CODE", "true").lower() == "true",
+            max_model_len=int(os.getenv("MAX_MODEL_LEN")) if os.getenv("MAX_MODEL_LEN") else None,
+            gpu_memory_utilization=float(os.getenv("GPU_MEMORY_UTILIZATION", "0.9")),
+            enforce_eager=os.getenv("ENFORCE_EAGER", "false").lower() == "true",
         )
 
         # Create the engine
         engine = AsyncLLMEngine.from_engine_args(engine_args)
         engine_ready = True
-        print(f"vLLM engine initialized successfully with model: {model_name}")
+        logger.info(f"vLLM engine initialized successfully with model: {model_name}")
 
     except Exception as e:
-        print(f"Failed to initialize vLLM engine: {str(e)}")
+        logger.error(f"Failed to initialize vLLM engine: {str(e)}")
         engine_ready = False
         raise
 
-@app.on_event("startup")
-async def startup_event():
-    """Initialize the vLLM engine on startup"""
-    await create_engine()
 
 @app.get("/ping")
 async def health_check():
     """Health check endpoint required by RunPod load balancer"""
     if not engine_ready:
+        logger.debug("Health check: Engine initializing")
         # Return 204 when initializing
         return {"status": "initializing"}, 204
 
+    logger.debug("Health check: Engine healthy")
     # Return 200 when healthy
     return {"status": "healthy"}
 
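One caveat in the untouched health-check code above: in FastAPI, `return {"status": "initializing"}, 204` does not set the status code. The tuple is serialized as a two-element JSON array and returned with a 200, so the load balancer never actually sees a 204 while the engine warms up. If a real 204 is required, the handler would have to build the response explicitly, along the lines of this sketch (a 204 must also carry no body); this is an editorial suggestion, not part of the commit:

```python
from fastapi import Response

@app.get("/ping")
async def health_check():
    """Health check endpoint required by RunPod load balancer"""
    if not engine_ready:
        logger.debug("Health check: Engine initializing")
        # 204 responses carry no body, so return a bare Response
        return Response(status_code=204)
    logger.debug("Health check: Engine healthy")
    return {"status": "healthy"}
```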
@@ -95,8 +134,12 @@ async def root():
 @app.post("/v1/completions", response_model=GenerationResponse)
 async def generate_completion(request: GenerationRequest):
     """Generate text completion"""
+    logger.info(f"Received completion request: max_tokens={request.max_tokens}, temperature={request.temperature}, stream={request.stream}")
+
     if not engine_ready or engine is None:
-        raise HTTPException(status_code=503, detail="Engine not ready")
+        logger.warning("Completion request rejected: Engine not ready")
+        error_response = create_error_response("ServiceUnavailable", "Engine not ready")
+        raise HTTPException(status_code=503, detail=error_response.model_dump())
 
     try:
         # Create sampling parameters
@@ -126,15 +169,23 @@ async def generate_completion(request: GenerationRequest):
             final_output = output
 
         if final_output is None:
-            raise HTTPException(status_code=500, detail="No output generated")
+            request_id = random_uuid()
+            error_response = create_error_response("GenerationError", "No output generated", request_id)
+            raise HTTPException(status_code=500, detail=error_response.model_dump())
 
         generated_text = final_output.outputs[0].text
         finish_reason = final_output.outputs[0].finish_reason
 
-        # Calculate token counts (approximate)
-        prompt_tokens = len(request.prompt.split())
-        completion_tokens = len(generated_text.split())
+        # Calculate token counts using actual token IDs when available
+        if hasattr(final_output, 'prompt_token_ids') and final_output.prompt_token_ids is not None:
+            prompt_tokens = len(final_output.prompt_token_ids)
+        else:
+            # Fallback to approximate word count
+            prompt_tokens = len(request.prompt.split())
 
+        completion_tokens = len(final_output.outputs[0].token_ids)
+
+        logger.info(f"Completion generated: {completion_tokens} tokens, finish_reason={finish_reason}")
         return GenerationResponse(
             text=generated_text,
             finish_reason=finish_reason,
@@ -144,7 +195,10 @@ async def generate_completion(request: GenerationRequest):
         )
 
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
+        request_id = random_uuid()
+        logger.error(f"Generation failed (request_id={request_id}): {str(e)}", exc_info=True)
+        error_response = create_error_response("GenerationError", f"Generation failed: {str(e)}", request_id)
+        raise HTTPException(status_code=500, detail=error_response.model_dump())
 
 async def stream_completion(prompt: str, sampling_params: SamplingParams, request_id: str) -> AsyncGenerator[str, None]:
     """Stream completion generator"""
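The body of `stream_completion` is elided between this hunk and the next; only its final error-yield line is visible below. For orientation, a generic server-sent-events generator over vLLM's async API typically looks something like the following sketch. This is illustrative only and is not the file's actual code; the delta-chunking logic and variable names are assumptions.

```python
# Illustrative only -- the actual generator body is elided between the hunks.
async def stream_completion_sketch(prompt: str, sampling_params: SamplingParams,
                                   request_id: str) -> AsyncGenerator[str, None]:
    try:
        previous_text = ""
        async for request_output in engine.generate(prompt, sampling_params, request_id):
            # Emit only the newly generated delta as a server-sent event
            text = request_output.outputs[0].text
            delta, previous_text = text[len(previous_text):], text
            yield f"data: {json.dumps({'text': delta})}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as e:
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
```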
@@ -160,31 +214,32 @@ async def stream_completion(prompt: str, sampling_params: SamplingParams, reques
         yield f"data: {json.dumps({'error': str(e)})}\n\n"
 
 @app.post("/v1/chat/completions")
-async def chat_completions(request: dict):
+async def chat_completions(request: ChatCompletionRequest):
     """OpenAI-compatible chat completions endpoint"""
+    logger.info(f"Received chat completion request: {len(request.messages)} messages, max_tokens={request.max_tokens}, temperature={request.temperature}")
+
     if not engine_ready or engine is None:
-        raise HTTPException(status_code=503, detail="Engine not ready")
+        logger.warning("Chat completion request rejected: Engine not ready")
+        error_response = create_error_response("ServiceUnavailable", "Engine not ready")
+        raise HTTPException(status_code=503, detail=error_response.model_dump())
 
     try:
         # Extract messages and convert to prompt
-        messages = request.get("messages", [])
+        messages = request.messages
         if not messages:
-            raise HTTPException(status_code=400, detail="No messages provided")
+            error_response = create_error_response("ValidationError", "No messages provided")
+            raise HTTPException(status_code=400, detail=error_response.model_dump())
 
-        # Simple conversion of messages to prompt (you may want to improve this)
-        prompt = ""
-        for message in messages:
-            role = message.get("role", "user")
-            content = message.get("content", "")
-            prompt += f"{role}: {content}\n"
-        prompt += "assistant: "
+        # Use proper chat template formatting
+        model_name = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
+        prompt = format_chat_prompt(messages, model_name)
 
         # Create sampling parameters from request
         sampling_params = SamplingParams(
-            max_tokens=request.get("max_tokens", 512),
-            temperature=request.get("temperature", 0.7),
-            top_p=request.get("top_p", 0.9),
-            stop=request.get("stop"),
+            max_tokens=request.max_tokens,
+            temperature=request.temperature,
+            top_p=request.top_p,
+            stop=request.stop,
         )
 
         # Generate
@@ -195,9 +250,12 @@ async def chat_completions(request: dict):
             final_output = output
 
         if final_output is None:
-            raise HTTPException(status_code=500, detail="No output generated")
+            error_response = create_error_response("GenerationError", "No output generated", request_id)
+            raise HTTPException(status_code=500, detail=error_response.model_dump())
 
         generated_text = final_output.outputs[0].text
+        completion_tokens = len(final_output.outputs[0].token_ids)
+        logger.info(f"Chat completion generated: {completion_tokens} tokens, finish_reason={final_output.outputs[0].finish_reason}")
 
         # Return OpenAI-compatible response
         return {
@@ -213,19 +271,22 @@ async def chat_completions(request: dict):
                 "finish_reason": final_output.outputs[0].finish_reason
             }],
             "usage": {
-                "prompt_tokens": len(prompt.split()),
-                "completion_tokens": len(generated_text.split()),
-                "total_tokens": len(prompt.split()) + len(generated_text.split())
+                "prompt_tokens": len(final_output.prompt_token_ids) if hasattr(final_output, 'prompt_token_ids') and final_output.prompt_token_ids is not None else len(prompt.split()),
+                "completion_tokens": len(final_output.outputs[0].token_ids),
+                "total_tokens": (len(final_output.prompt_token_ids) if hasattr(final_output, 'prompt_token_ids') and final_output.prompt_token_ids is not None else len(prompt.split())) + len(final_output.outputs[0].token_ids)
             }
         }
 
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Chat completion failed: {str(e)}")
+        request_id = random_uuid()
+        logger.error(f"Chat completion failed (request_id={request_id}): {str(e)}", exc_info=True)
+        error_response = create_error_response("ChatCompletionError", f"Chat completion failed: {str(e)}", request_id)
+        raise HTTPException(status_code=500, detail=error_response.model_dump())
 
 if __name__ == "__main__":
     # Get ports from environment variables
     port = int(os.getenv("PORT", 8000))
-    print(f"Starting vLLM server on port {port}")
+    logger.info(f"Starting vLLM server on port {port}")
 
     # If health port is different, you'd need to run a separate health server
     # For simplicity, we're using the same port here
@@ -235,4 +296,4 @@ async def chat_completions(request: dict):
         host="0.0.0.0",
         port=port,
         log_level="info"
-    )
+    )
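After this change, the chat endpoint validates its body against `ChatCompletionRequest` instead of accepting a raw dict. A client call therefore looks roughly like the example below. The endpoint path, fields, and defaults come from the code above; the client library, host, and port (the default `PORT` of 8000 on localhost) are assumptions for illustration.

```python
# Example client call against the updated /v1/chat/completions endpoint.
# Assumes the server from this diff is running locally on port 8000.
import requests

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ],
    "max_tokens": 64,
    "temperature": 0.7,
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=120)
resp.raise_for_status()
data = resp.json()
print(data["choices"][0])  # OpenAI-style choice containing the generated text
print(data["usage"])       # token counts now come from vLLM's token IDs
```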