Support multiple image sources for LVM microservice #451

Merged · 6 commits · Aug 10, 2024
25 changes: 22 additions & 3 deletions comps/cores/mega/gateway.py
@@ -2,10 +2,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import base64
+import os
+from io import BytesIO
 
 import requests
 from fastapi import Request
 from fastapi.responses import StreamingResponse
+from PIL import Image
 
 from ..proto.api_protocol import (
     AudioChatCompletionRequest,
@@ -74,6 +77,7 @@
         pass
 
     def _handle_message(self, messages):
+        images = []
         if isinstance(messages, str):
             prompt = messages
         else:
@@ -104,7 +108,6 @@
                     raise ValueError(f"Unknown role: {msg_role}")
             if system_prompt:
                 prompt = system_prompt + "\n"
-            images = []
             for role, message in messages_dict.items():
                 if isinstance(message, tuple):
                     text, image_list = message
@@ -113,8 +116,24 @@
                     else:
                         prompt += role + ":"
                     for img in image_list:
-                        response = requests.get(img)
-                        images.append(base64.b64encode(response.content).decode("utf-8"))
+                        # URL
+                        if img.startswith("http://") or img.startswith("https://"):
+                            response = requests.get(img)
+                            image = Image.open(BytesIO(response.content)).convert("RGBA")
+                            image_bytes = BytesIO()
+                            image.save(image_bytes, format="PNG")
+                            img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                        # Local Path
+                        elif os.path.exists(img):
+                            image = Image.open(img).convert("RGBA")
+                            image_bytes = BytesIO()
+                            image.save(image_bytes, format="PNG")
+                            img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                        # Bytes
+                        else:
+                            img_b64_str = img
+
+                        images.append(img_b64_str)
                 else:
                     if message:
                         prompt += role + ": " + message + "\n"
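For reference, the new gateway logic accepts an image reference given as an HTTP(S) URL, a local file path, or an already base64-encoded string, and normalizes all three cases to a base64-encoded PNG before appending it to `images`. The following is a minimal, self-contained sketch of that normalization; the helper name `image_to_base64_png` and the `timeout` argument are illustrative additions, not part of this PR.

```python
import base64
import os
from io import BytesIO

import requests
from PIL import Image


def image_to_base64_png(img: str) -> str:
    """Return a base64-encoded PNG for an image given as a URL, a local path,
    or an already base64-encoded string (sketch of the gateway behaviour)."""
    if img.startswith(("http://", "https://")):
        # Remote URL: download the bytes, then re-encode as PNG.
        response = requests.get(img, timeout=30)
        image = Image.open(BytesIO(response.content)).convert("RGBA")
    elif os.path.exists(img):
        # Local file path: open the image directly.
        image = Image.open(img).convert("RGBA")
    else:
        # Anything else is assumed to already be base64-encoded image bytes.
        return img

    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()


# Illustrative usage (URL and path are placeholders):
# image_to_base64_png("https://example.com/cat.png")
# image_to_base64_png("/tmp/cat.jpg")
# image_to_base64_png("iVBORw0KGgo...")  # passed through unchanged
```

Re-encoding every source to PNG keeps the payload sent to the LVM microservice uniform regardless of where the image came from.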
2 changes: 1 addition & 1 deletion comps/lvms/lvm_tgi.py
@@ -48,7 +48,7 @@ async def lvm(request: LVMDoc):
         async def stream_generator():
             chat_response = ""
             text_generation = await lvm_client.text_generation(
-                prompt=prompt,
+                prompt=image_prompt,
                 stream=streaming,
                 max_new_tokens=max_new_tokens,
                 repetition_penalty=repetition_penalty,
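The one-line fix above makes the streaming branch pass `image_prompt` (which carries the image) to TGI instead of the text-only `prompt`. As a rough sketch of how such an image-bearing prompt is commonly assembled for a LLaVA-style TGI endpoint, see below; the prompt template and endpoint URL are assumptions, since the construction of `image_prompt` is not shown in this diff.

```python
# Hedged sketch: the prompt template and endpoint URL are assumptions,
# not taken from lvm_tgi.py.
from huggingface_hub import AsyncInferenceClient

lvm_client = AsyncInferenceClient("http://localhost:8399")  # illustrative TGI endpoint


async def generate(img_b64_str: str, prompt: str) -> str:
    # One common LLaVA/TGI convention: embed the image as a markdown data URI,
    # then append the user text.
    image_prompt = f"![](data:image/png;base64,{img_b64_str})\n{prompt}\nASSISTANT:"
    return await lvm_client.text_generation(
        prompt=image_prompt,
        stream=False,
        max_new_tokens=512,
        repetition_penalty=1.03,
    )


# Inside an async context:
# answer = await generate(img_b64_str, "What is in this image?")
```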
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,6 +5,7 @@ httpx
 opentelemetry-api
 opentelemetry-exporter-otlp
 opentelemetry-sdk
+Pillow
 prometheus-fastapi-instrumentator
 pyyaml
 requests