Skip to content

Commit

Permalink
Support multiple image sources for LVM microservice (#451)
Browse files Browse the repository at this point in the history
Signed-off-by: lvliang-intel <[email protected]>
  • Loading branch information
lvliang-intel authored Aug 10, 2024
1 parent 2098b91 commit ed776ac
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
25 changes: 22 additions & 3 deletions comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
# SPDX-License-Identifier: Apache-2.0

import base64
import os
from io import BytesIO

import requests
from fastapi import Request
from fastapi.responses import StreamingResponse
from PIL import Image

from ..proto.api_protocol import (
AudioChatCompletionRequest,
Expand Down Expand Up @@ -74,6 +77,7 @@ def list_parameter(self):
pass

def _handle_message(self, messages):
images = []
if isinstance(messages, str):
prompt = messages
else:
Expand Down Expand Up @@ -104,7 +108,6 @@ def _handle_message(self, messages):
raise ValueError(f"Unknown role: {msg_role}")
if system_prompt:
prompt = system_prompt + "\n"
images = []
for role, message in messages_dict.items():
if isinstance(message, tuple):
text, image_list = message
Expand All @@ -113,8 +116,24 @@ def _handle_message(self, messages):
else:
prompt += role + ":"
for img in image_list:
response = requests.get(img)
images.append(base64.b64encode(response.content).decode("utf-8"))
# URL
if img.startswith("http://") or img.startswith("https://"):
response = requests.get(img)
image = Image.open(BytesIO(response.content)).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
# Local Path
elif os.path.exists(img):
image = Image.open(img).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
# Bytes
else:
img_b64_str = img

images.append(img_b64_str)
else:
if message:
prompt += role + ": " + message + "\n"
Expand Down
2 changes: 1 addition & 1 deletion comps/lvms/lvm_tgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def lvm(request: LVMDoc):
async def stream_generator():
chat_response = ""
text_generation = await lvm_client.text_generation(
prompt=prompt,
prompt=image_prompt,
stream=streaming,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ httpx
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-sdk
Pillow
prometheus-fastapi-instrumentator
pyyaml
requests
Expand Down

0 comments on commit ed776ac

Please sign in to comment.