From a3f46963508b6de6272b2c9f234cfab954959dfb Mon Sep 17 00:00:00 2001
From: lvliang-intel
Date: Sat, 10 Aug 2024 23:51:57 +0800
Subject: [PATCH] Support multiple image sources for LVM microservice (#451)

Signed-off-by: lvliang-intel
---
 comps/cores/mega/gateway.py | 25 ++++++++++++++++++++++---
 comps/lvms/lvm_tgi.py       |  2 +-
 requirements.txt            |  1 +
 3 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py
index 324f7081e..8ad31c841 100644
--- a/comps/cores/mega/gateway.py
+++ b/comps/cores/mega/gateway.py
@@ -2,10 +2,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import base64
+import os
+from io import BytesIO
 
 import requests
 from fastapi import Request
 from fastapi.responses import StreamingResponse
+from PIL import Image
 
 from ..proto.api_protocol import (
     AudioChatCompletionRequest,
@@ -74,6 +77,7 @@ def list_parameter(self):
         pass
 
     def _handle_message(self, messages):
+        images = []
         if isinstance(messages, str):
             prompt = messages
         else:
@@ -104,7 +108,6 @@ def _handle_message(self, messages):
                     raise ValueError(f"Unknown role: {msg_role}")
             if system_prompt:
                 prompt = system_prompt + "\n"
-            images = []
             for role, message in messages_dict.items():
                 if isinstance(message, tuple):
                     text, image_list = message
@@ -113,8 +116,24 @@ def _handle_message(self, messages):
                 else:
                     prompt += role + ":"
                 for img in image_list:
-                    response = requests.get(img)
-                    images.append(base64.b64encode(response.content).decode("utf-8"))
+                    # URL
+                    if img.startswith("http://") or img.startswith("https://"):
+                        response = requests.get(img)
+                        image = Image.open(BytesIO(response.content)).convert("RGBA")
+                        image_bytes = BytesIO()
+                        image.save(image_bytes, format="PNG")
+                        img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                    # Local Path
+                    elif os.path.exists(img):
+                        image = Image.open(img).convert("RGBA")
+                        image_bytes = BytesIO()
+                        image.save(image_bytes, format="PNG")
+                        img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                    # Bytes
+                    else:
+                        img_b64_str = img
+
+                    images.append(img_b64_str)
             else:
                 if message:
                     prompt += role + ": " + message + "\n"
diff --git a/comps/lvms/lvm_tgi.py b/comps/lvms/lvm_tgi.py
index b7383fa0c..b2eddf9f1 100644
--- a/comps/lvms/lvm_tgi.py
+++ b/comps/lvms/lvm_tgi.py
@@ -48,7 +48,7 @@ async def lvm(request: LVMDoc):
     async def stream_generator():
         chat_response = ""
         text_generation = await lvm_client.text_generation(
-            prompt=prompt,
+            prompt=image_prompt,
             stream=streaming,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
diff --git a/requirements.txt b/requirements.txt
index 53bfbf8d4..ef12b2fc4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,7 @@ httpx
 opentelemetry-api
 opentelemetry-exporter-otlp
 opentelemetry-sdk
+Pillow
 prometheus-fastapi-instrumentator
 pyyaml
 requests