Skip to content
This repository was archived by the owner on Feb 10, 2026. It is now read-only.

Commit 46bc7df

Browse files
authored
Merge pull request #3 from AdvancedPhotonSource/new_models
Add new models to bridge.
2 parents 2bc0460 + dfdb786 commit 46bc7df

File tree

2 files changed

+189
-9
lines changed

2 files changed

+189
-9
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ metrics
44
.env
55
prometheus.yml
66
myserver.crt
7-
myserver.key
7+
myserver.key
8+
sandbox.py

argo_bridge.py

Lines changed: 187 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,20 @@ def after_request(response):
6565

6666
'gpto1': 'gpto1',
6767
'o1': 'gpto1',
68+
69+
'gpto3': 'gpto3',
70+
'gpto4mini': 'gpto4mini',
71+
'gpt41': 'gpt41',
72+
'gpt41mini' : 'gpt41mini',
73+
'gpt41nano' : 'gpt41nano',
74+
75+
76+
'gemini25pro': 'gemini25pro',
77+
'gemini25flash': 'gemini25flash',
78+
'claudeopus4': 'claudeopus4',
79+
'claudesonnet4': 'claudesonnet4',
80+
'claudesonnet37': 'claudesonnet37',
81+
'claudesonnet35v2': 'claudesonnet35v2',
6882
}
6983

7084

@@ -108,9 +122,25 @@ def after_request(response):
108122
# Models using development environment
109123
'gpto3mini': 'dev',
110124
'gpto1mini': 'dev',
111-
'gpto1': 'dev'
125+
'gpto1': 'dev',
126+
'gemini25pro': 'dev',
127+
'gemini25flash': 'dev',
128+
'claudeopus4': 'dev',
129+
'claudesonnet4': 'dev',
130+
'claudesonnet37': 'dev',
131+
'claudesonnet35v2': 'dev',
132+
'gpto3': 'dev',
133+
'gpto4mini': 'dev',
134+
'gpt41': 'dev',
135+
'gpt41mini' : 'dev',
136+
'gpt41nano' : 'dev',
112137
}
113138

139+
140+
NON_STREAMING_MODELS = ['gemini25pro', 'gemini25flash',
141+
'claudeopus4', 'claudesonnet4', 'claudesonnet37', 'claudesonnet35v2',
142+
'gpto3', 'gpto4mini', 'gpt41', 'gpt41mini', 'gpt41nano',]
143+
114144
# For models endpoint
115145
MODELS = {
116146
"object": "list",
@@ -131,10 +161,30 @@ def after_request(response):
131161
EMBED_ENV = 'prod'
132162

133163
DEFAULT_MODEL = "gpt4o"
134-
ANL_USER = "APS"
164+
BRIDGE_USER = "ARGO_BRIDGE"
135165
ANL_STREAM_URL = "https://apps-dev.inside.anl.gov/argoapi/api/v1/resource/streamchat/"
136166
ANL_DEBUG_FP = 'log_bridge.log'
137167

168+
169+
def get_user_from_auth_header():
    """
    Extract the user identity from the request's Authorization header.

    If a "Bearer <token>" header is present, the token is used as the
    user name; the sentinel token 'noop' maps back to the default bridge
    user. Without a valid header, the default user is returned.

    Returns:
        str: The bearer token, or BRIDGE_USER when the header is absent,
        malformed, or carries the 'noop' sentinel.
    """
    auth_header = request.headers.get("Authorization")
    if auth_header and auth_header.startswith("Bearer "):
        # "Bearer <token>" -> everything after the first space.
        token = auth_header.split(" ")[1]
        # Log only the *presence* of the header — never the header value
        # itself — so bearer tokens cannot leak into the log file.
        logging.debug("Authorization header found (Bearer token redacted)")
        if token == 'noop':
            return BRIDGE_USER
        # Return the token extracted above instead of re-splitting.
        return token
    # Return the default user if no valid header is found
    return BRIDGE_USER
186+
187+
138188
def get_api_url(model, endpoint_type):
139189
"""
140190
Determine the correct API URL based on model and endpoint type
@@ -169,11 +219,18 @@ def chat_completions():
169219
logging.info("Received chat completions request")
170220

171221
data = request.get_json()
222+
logging.info(f"Request Data: {data}")
172223
model_base = data.get("model", DEFAULT_MODEL)
173224
is_streaming = data.get("stream", False)
174225
temperature = data.get("temperature", 0.1)
175226
stop = data.get("stop", [])
176227

228+
# Force non-streaming for specific models. Remove once Argo supports streaming for all models.
229+
# TODO: TEMP Fake streaming for the new models until Argo supports it
230+
is_fake_stream = False
231+
if model_base in NON_STREAMING_MODELS and is_streaming:
232+
is_fake_stream = True
233+
177234
if model_base not in MODEL_MAPPING:
178235
return jsonify({"error": {
179236
"message": f"Model '{model_base}' not supported."
@@ -183,8 +240,19 @@ def chat_completions():
183240

184241
logging.debug(f"Received request: {data}")
185242

243+
# Process multimodal content for Gemini models
244+
if model_base.startswith('gemini'):
245+
try:
246+
data['messages'] = convert_multimodal_to_text(data['messages'], model_base)
247+
except ValueError as e:
248+
return jsonify({"error": {
249+
"message": str(e)
250+
}}), 400
251+
252+
user = get_user_from_auth_header()
253+
186254
req_obj = {
187-
"user": ANL_USER,
255+
"user": user,
188256
"model": model,
189257
"messages": data['messages'],
190258
"system": "",
@@ -194,7 +262,22 @@ def chat_completions():
194262

195263
logging.debug(f"Argo Request {req_obj}")
196264

197-
if is_streaming:
265+
if is_fake_stream:
266+
logging.info(req_obj)
267+
response = requests.post(get_api_url(model, 'chat'), json=req_obj)
268+
269+
if not response.ok:
270+
logging.error(f"Internal API error: {response.status_code} {response.reason}")
271+
return jsonify({"error": {
272+
"message": f"Internal API error: {response.status_code} {response.reason}"
273+
}}), 500
274+
275+
json_response = response.json()
276+
text = json_response.get("response", "")
277+
logging.debug(f"Response Text {text}")
278+
return Response(_fake_stream_response(text, model), mimetype='text/event-stream')
279+
280+
elif is_streaming:
198281
return Response(_stream_chat_response(model, req_obj), mimetype='text/event-stream')
199282
else:
200283
response = requests.post(get_api_url(model, 'chat'), json=req_obj)
@@ -280,6 +363,98 @@ def _static_chat_response(text, model):
280363
}]
281364
}
282365

366+
def _fake_stream_response(text, model):
367+
begin_chunk = {
368+
"id": 'abc',
369+
"object": "chat.completion.chunk",
370+
"created": int(datetime.datetime.now().timestamp()),
371+
"model": model,
372+
"choices": [{
373+
"index": 0,
374+
"delta": {'role': 'assistant', 'content':''},
375+
"logprobs": None,
376+
"finish_reason": None
377+
}]
378+
}
379+
yield f"data: {json.dumps(begin_chunk)}\n\n"
380+
chunk = {
381+
"id": 'abc',
382+
"object": "chat.completion.chunk",
383+
"created": int(datetime.datetime.now().timestamp()),
384+
"model": model,
385+
"choices": [{
386+
"index": 0,
387+
"delta": {'content': text},
388+
"logprobs": None,
389+
"finish_reason": None
390+
}]
391+
}
392+
yield f"data: {json.dumps(chunk)}\n\n"
393+
end_chunk = {
394+
"id": 'argo',
395+
"object": "chat.completion.chunk",
396+
"created": int(datetime.datetime.now().timestamp()),
397+
"model": model,
398+
"system_fingerprint": "fp_44709d6fcb",
399+
"choices": [{
400+
"index": 0,
401+
"delta": {},
402+
"logprobs": None,
403+
"finish_reason": "stop"
404+
}]
405+
}
406+
yield f"data: {json.dumps(end_chunk)}\n\n"
407+
yield "data: [DONE]\n\n"
408+
409+
def convert_multimodal_to_text(messages, model_base):
    """
    Convert multimodal content format to plain text for Gemini models.

    Args:
        messages (list): List of message objects
        model_base (str): The model being used

    Returns:
        list: Processed messages with text-only content. Non-Gemini
        models receive the original list back unchanged.

    Raises:
        ValueError: If non-text content is found in multimodal format
    """
    # Only Gemini models need flattening; everything else passes through.
    if model_base not in ('gemini25pro', 'gemini25flash'):
        return messages

    def flatten(part):
        # Map one content item to its text; reject typed non-text items.
        if isinstance(part, dict):
            kind = part.get("type")
            if kind != "text":
                raise ValueError(
                    f"Gemini models only support text content. "
                    f"Found unsupported content type: '{kind}'"
                )
            return part.get("text", "")
        # Non-dict items are coerced to plain text.
        return str(part)

    result = []
    for message in messages:
        updated = message.copy()
        content = message.get("content")
        # Only the multimodal list form needs conversion; string content
        # (or anything else) is kept as-is on the shallow copy.
        if isinstance(content, list):
            updated["content"] = " ".join(flatten(p) for p in content)
        result.append(updated)

    return result
457+
283458

284459
"""
285460
=================================
@@ -308,8 +483,10 @@ def completions():
308483

309484
logging.debug(f"Received request: {data}")
310485

486+
user = get_user_from_auth_header()
487+
311488
req_obj = {
312-
"user": ANL_USER,
489+
"user": user,
313490
"model": model,
314491
"prompt": [data['prompt']],
315492
"system": "",
@@ -389,7 +566,8 @@ def embeddings():
389566
if isinstance(input_data, str):
390567
input_data = [input_data]
391568

392-
embedding_vectors = _get_embeddings_from_argo(input_data, model)
569+
user = get_user_from_auth_header()
570+
embedding_vectors = _get_embeddings_from_argo(input_data, model, user)
393571

394572
response_data = {
395573
"object": "list",
@@ -411,15 +589,15 @@ def embeddings():
411589
return jsonify(response_data)
412590

413591

414-
def _get_embeddings_from_argo(texts, model):
592+
def _get_embeddings_from_argo(texts, model, user=BRIDGE_USER):
415593
BATCH_SIZE = 16
416594
all_embeddings = []
417595

418596
for i in range(0, len(texts), BATCH_SIZE):
419597
batch_texts = texts[i:i + BATCH_SIZE]
420598

421599
payload = {
422-
"user": ANL_USER,
600+
"user": user,
423601
"model": model,
424602
"prompt": batch_texts
425603
}
@@ -509,6 +687,7 @@ def parse_args():
509687
level=logging.DEBUG if debug_enabled else logging.INFO,
510688
format='%(asctime)s - %(levelname)s - %(message)s'
511689
)
690+
logging.getLogger('watchdog').setLevel(logging.CRITICAL+10)
512691

513692
logging.info(f'Starting server with debug mode: {debug_enabled}')
514693
print(f'Starting server... | Port {args.port} | User {args.username} | Debug: {debug_enabled}')

0 commit comments

Comments
 (0)