@@ -65,6 +65,20 @@ def after_request(response):
6565
6666 'gpto1' : 'gpto1' ,
6767 'o1' : 'gpto1' ,
68+
69+ 'gpto3' : 'gpto3' ,
70+ 'gpto4mini' : 'gpto4mini' ,
71+ 'gpt41' : 'gpt41' ,
72+ 'gpt41mini' : 'gpt41mini' ,
73+ 'gpt41nano' : 'gpt41nano' ,
74+
75+
76+ 'gemini25pro' : 'gemini25pro' ,
77+ 'gemini25flash' : 'gemini25flash' ,
78+ 'claudeopus4' : 'claudeopus4' ,
79+ 'claudesonnet4' : 'claudesonnet4' ,
80+ 'claudesonnet37' : 'claudesonnet37' ,
81+ 'claudesonnet35v2' : 'claudesonnet35v2' ,
6882}
6983
7084
@@ -108,9 +122,25 @@ def after_request(response):
108122 # Models using development environment
109123 'gpto3mini' : 'dev' ,
110124 'gpto1mini' : 'dev' ,
111- 'gpto1' : 'dev'
125+ 'gpto1' : 'dev' ,
126+ 'gemini25pro' : 'dev' ,
127+ 'gemini25flash' : 'dev' ,
128+ 'claudeopus4' : 'dev' ,
129+ 'claudesonnet4' : 'dev' ,
130+ 'claudesonnet37' : 'dev' ,
131+ 'claudesonnet35v2' : 'dev' ,
132+ 'gpto3' : 'dev' ,
133+ 'gpto4mini' : 'dev' ,
134+ 'gpt41' : 'dev' ,
135+ 'gpt41mini' : 'dev' ,
136+ 'gpt41nano' : 'dev' ,
112137}
113138
139+
140+ NON_STREAMING_MODELS = ['gemini25pro' , 'gemini25flash' ,
141+ 'claudeopus4' , 'claudesonnet4' , 'claudesonnet37' , 'claudesonnet35v2' ,
142+ 'gpto3' , 'gpto4mini' , 'gpt41' , 'gpt41mini' , 'gpt41nano' ,]
143+
114144# For models endpoint
115145MODELS = {
116146 "object" : "list" ,
@@ -131,10 +161,30 @@ def after_request(response):
131161EMBED_ENV = 'prod'
132162
133163DEFAULT_MODEL = "gpt4o"
134- ANL_USER = "APS "
164+ BRIDGE_USER = "ARGO_BRIDGE "
135165ANL_STREAM_URL = "https://apps-dev.inside.anl.gov/argoapi/api/v1/resource/streamchat/"
136166ANL_DEBUG_FP = 'log_bridge.log'
137167
168+
def get_user_from_auth_header():
    """
    Extract the requesting user from the Authorization header.

    Expects ``Authorization: Bearer <token>``. When the header is present
    and well-formed, the bearer token is used as the Argo user name,
    except for the sentinel token 'noop' (and an empty token), which map
    to the default bridge user. If the header is missing or malformed,
    the default bridge user is returned.

    Returns:
        str: The user name to forward to the Argo API.
    """
    auth_header = request.headers.get("Authorization")
    if auth_header and auth_header.startswith("Bearer "):
        # Split once: everything after "Bearer " is the token.
        token = auth_header.split(" ", 1)[1]
        # NOTE: do not log the raw header/token value — credentials must
        # not end up in the log file.
        logging.debug("Authorization header found; using bearer token as user")
        # 'noop' and empty tokens fall back to the default bridge user.
        if token and token != 'noop':
            return token
    # Return the default user if no valid header is found
    return BRIDGE_USER
186+
187+
138188def get_api_url (model , endpoint_type ):
139189 """
140190 Determine the correct API URL based on model and endpoint type
@@ -169,11 +219,18 @@ def chat_completions():
169219 logging .info ("Received chat completions request" )
170220
171221 data = request .get_json ()
222+ logging .info (f"Request Data: { data } " )
172223 model_base = data .get ("model" , DEFAULT_MODEL )
173224 is_streaming = data .get ("stream" , False )
174225 temperature = data .get ("temperature" , 0.1 )
175226 stop = data .get ("stop" , [])
176227
228+ # Force non-streaming for specific models. Remove once Argo supports streaming for all models.
229+ # TODO: TEMP Fake streaming for the new models until Argo supports it
230+ is_fake_stream = False
231+ if model_base in NON_STREAMING_MODELS and is_streaming :
232+ is_fake_stream = True
233+
177234 if model_base not in MODEL_MAPPING :
178235 return jsonify ({"error" : {
179236 "message" : f"Model '{ model_base } ' not supported."
@@ -183,8 +240,19 @@ def chat_completions():
183240
184241 logging .debug (f"Received request: { data } " )
185242
243+ # Process multimodal content for Gemini models
244+ if model_base .startswith ('gemini' ):
245+ try :
246+ data ['messages' ] = convert_multimodal_to_text (data ['messages' ], model_base )
247+ except ValueError as e :
248+ return jsonify ({"error" : {
249+ "message" : str (e )
250+ }}), 400
251+
252+ user = get_user_from_auth_header ()
253+
186254 req_obj = {
187- "user" : ANL_USER ,
255+ "user" : user ,
188256 "model" : model ,
189257 "messages" : data ['messages' ],
190258 "system" : "" ,
@@ -194,7 +262,22 @@ def chat_completions():
194262
195263 logging .debug (f"Argo Request { req_obj } " )
196264
197- if is_streaming :
265+ if is_fake_stream :
266+ logging .info (req_obj )
267+ response = requests .post (get_api_url (model , 'chat' ), json = req_obj )
268+
269+ if not response .ok :
270+ logging .error (f"Internal API error: { response .status_code } { response .reason } " )
271+ return jsonify ({"error" : {
272+ "message" : f"Internal API error: { response .status_code } { response .reason } "
273+ }}), 500
274+
275+ json_response = response .json ()
276+ text = json_response .get ("response" , "" )
277+ logging .debug (f"Response Text { text } " )
278+ return Response (_fake_stream_response (text , model ), mimetype = 'text/event-stream' )
279+
280+ elif is_streaming :
198281 return Response (_stream_chat_response (model , req_obj ), mimetype = 'text/event-stream' )
199282 else :
200283 response = requests .post (get_api_url (model , 'chat' ), json = req_obj )
@@ -280,6 +363,98 @@ def _static_chat_response(text, model):
280363 }]
281364 }
282365
366+ def _fake_stream_response (text , model ):
367+ begin_chunk = {
368+ "id" : 'abc' ,
369+ "object" : "chat.completion.chunk" ,
370+ "created" : int (datetime .datetime .now ().timestamp ()),
371+ "model" : model ,
372+ "choices" : [{
373+ "index" : 0 ,
374+ "delta" : {'role' : 'assistant' , 'content' :'' },
375+ "logprobs" : None ,
376+ "finish_reason" : None
377+ }]
378+ }
379+ yield f"data: { json .dumps (begin_chunk )} \n \n "
380+ chunk = {
381+ "id" : 'abc' ,
382+ "object" : "chat.completion.chunk" ,
383+ "created" : int (datetime .datetime .now ().timestamp ()),
384+ "model" : model ,
385+ "choices" : [{
386+ "index" : 0 ,
387+ "delta" : {'content' : text },
388+ "logprobs" : None ,
389+ "finish_reason" : None
390+ }]
391+ }
392+ yield f"data: { json .dumps (chunk )} \n \n "
393+ end_chunk = {
394+ "id" : 'argo' ,
395+ "object" : "chat.completion.chunk" ,
396+ "created" : int (datetime .datetime .now ().timestamp ()),
397+ "model" : model ,
398+ "system_fingerprint" : "fp_44709d6fcb" ,
399+ "choices" : [{
400+ "index" : 0 ,
401+ "delta" : {},
402+ "logprobs" : None ,
403+ "finish_reason" : "stop"
404+ }]
405+ }
406+ yield f"data: { json .dumps (end_chunk )} \n \n "
407+ yield "data: [DONE]\n \n "
408+
def convert_multimodal_to_text(messages, model_base):
    """
    Convert OpenAI multimodal content (list-of-parts) to plain text for
    Gemini models, which only accept text content here.

    Args:
        messages (list): List of chat message dicts.
        model_base (str): The requested model name.

    Returns:
        list: Messages whose list-form content is flattened to a single
        space-joined text string. Non-Gemini models get the input list
        back unchanged.

    Raises:
        ValueError: If a non-text content part is found.
    """
    # Gate on the 'gemini' prefix to match the call site, so new Gemini
    # variants are covered without editing a hard-coded model list.
    if not model_base.startswith('gemini'):
        return messages

    processed_messages = []

    for message in messages:
        processed_message = message.copy()
        content = message.get("content")

        # Multimodal format: content is a list of typed part objects.
        if isinstance(content, list):
            text_parts = []

            for part in content:
                if isinstance(part, dict):
                    part_type = part.get("type")

                    if part_type == "text":
                        text_parts.append(part.get("text", ""))
                    else:
                        # Images/audio/etc. are not supported — fail fast.
                        raise ValueError(f"Gemini models only support text content. Found unsupported content type: '{part_type}'")
                else:
                    # Non-dict part: fall back to its string form.
                    text_parts.append(str(part))

            # Flatten all text parts into one space-joined string.
            processed_message["content"] = " ".join(text_parts)

        processed_messages.append(processed_message)

    return processed_messages
457+
283458
284459"""
285460=================================
@@ -308,8 +483,10 @@ def completions():
308483
309484 logging .debug (f"Received request: { data } " )
310485
486+ user = get_user_from_auth_header ()
487+
311488 req_obj = {
312- "user" : ANL_USER ,
489+ "user" : user ,
313490 "model" : model ,
314491 "prompt" : [data ['prompt' ]],
315492 "system" : "" ,
@@ -389,7 +566,8 @@ def embeddings():
389566 if isinstance (input_data , str ):
390567 input_data = [input_data ]
391568
392- embedding_vectors = _get_embeddings_from_argo (input_data , model )
569+ user = get_user_from_auth_header ()
570+ embedding_vectors = _get_embeddings_from_argo (input_data , model , user )
393571
394572 response_data = {
395573 "object" : "list" ,
@@ -411,15 +589,15 @@ def embeddings():
411589 return jsonify (response_data )
412590
413591
414- def _get_embeddings_from_argo (texts , model ):
592+ def _get_embeddings_from_argo (texts , model , user = BRIDGE_USER ):
415593 BATCH_SIZE = 16
416594 all_embeddings = []
417595
418596 for i in range (0 , len (texts ), BATCH_SIZE ):
419597 batch_texts = texts [i :i + BATCH_SIZE ]
420598
421599 payload = {
422- "user" : ANL_USER ,
600+ "user" : user ,
423601 "model" : model ,
424602 "prompt" : batch_texts
425603 }
@@ -509,6 +687,7 @@ def parse_args():
509687 level = logging .DEBUG if debug_enabled else logging .INFO ,
510688 format = '%(asctime)s - %(levelname)s - %(message)s'
511689 )
690+ logging .getLogger ('watchdog' ).setLevel (logging .CRITICAL + 10 )
512691
513692 logging .info (f'Starting server with debug mode: { debug_enabled } ' )
514693 print (f'Starting server... | Port { args .port } | User { args .username } | Debug: { debug_enabled } ' )
0 commit comments