Multiple models support for LLM TGI #835

Open · wants to merge 7 commits into base: main
1 change: 1 addition & 0 deletions comps/cores/mega/gateway.py
@@ -169,6 +169,7 @@ async def handle_request(self, request: Request):
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
model=chat_request.model if chat_request.model else None,
)
retriever_parameters = RetrieverParms(
search_type=chat_request.search_type if chat_request.search_type else "similarity",
41 changes: 41 additions & 0 deletions comps/cores/mega/utils.py
@@ -2,12 +2,15 @@
# SPDX-License-Identifier: Apache-2.0

import ipaddress
import json
import multiprocessing
import os
import random
from socket import AF_INET, SOCK_STREAM, socket
from typing import List, Optional, Union

from .logger import CustomLogger


def is_port_free(host: str, port: int) -> bool:
"""Check if a given port on a host is free.
@@ -183,6 +186,44 @@
return _random_port()


class ConfigError(Exception):
    """Custom exception for configuration errors."""

    pass


def load_model_configs(model_configs: str) -> dict:
    """Load and validate the model configurations.

    If valid, return a map from each model name to its configuration.
    """
    logger = CustomLogger("models_loader")
    try:
        configs = json.loads(model_configs)
        if not isinstance(configs, list) or not configs:
            raise ConfigError("MODEL_CONFIGS must be a non-empty JSON array.")
        required_keys = {"model_name", "displayName", "endpoint", "minToken", "maxToken"}
        configs_map = {}
        for config in configs:
            missing_keys = [key for key in required_keys if key not in config]
            if missing_keys:
                raise ConfigError(f"Missing required configuration fields: {missing_keys}")
            empty_keys = [key for key in required_keys if not config.get(key)]
            if empty_keys:
                raise ConfigError(f"Empty values found for configuration fields: {empty_keys}")
            model_name = config["model_name"]
            configs_map[model_name] = config
        if not configs_map:
            raise ConfigError("No valid configurations found.")
        return configs_map
    except json.JSONDecodeError:
        logger.error("Error parsing MODEL_CONFIGS environment variable as JSON.")
        raise ConfigError("MODEL_CONFIGS is not valid JSON.")
    except ConfigError as e:
        logger.error(str(e))
        raise

Codecov / codecov/patch warning: added lines comps/cores/mega/utils.py#L200-L224 were not covered by tests.


class SafeContextManager:
"""This context manager ensures that the `__exit__` method of the
sub context is called, even when there is an Exception in the
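For reference, a minimal sketch of how the `load_model_configs` helper above might be exercised. The model names, display names, and endpoint URLs are illustrative placeholders, not values defined by this PR, and the snippet assumes GenAIComps is installed so that `comps.cores.mega.utils` is importable.

```python
import json
import os

from comps.cores.mega.utils import ConfigError, load_model_configs

# Hypothetical MODEL_CONFIGS value; every entry must carry the required keys
# model_name, displayName, endpoint, minToken and maxToken.
os.environ["MODEL_CONFIGS"] = json.dumps(
    [
        {
            "model_name": "Intel/neural-chat-7b-v3-3",
            "displayName": "Neural Chat 7B",
            "endpoint": "http://tgi-neural-chat:80",
            "minToken": 1,
            "maxToken": 1024,
        },
        {
            "model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "displayName": "Llama 3.1 8B Instruct",
            "endpoint": "http://tgi-llama:80",
            "minToken": 1,
            "maxToken": 2048,
        },
    ]
)

try:
    configs_map = load_model_configs(os.environ["MODEL_CONFIGS"])
    # configs_map is keyed by model_name, e.g.
    # configs_map["meta-llama/Meta-Llama-3.1-8B-Instruct"]["endpoint"]
except ConfigError as err:
    print(f"Invalid MODEL_CONFIGS: {err}")
```

Invalid JSON, a missing key, or an empty value raises `ConfigError`, which is what llm.py surfaces at startup when it fails to load the configurations.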
1 change: 1 addition & 0 deletions comps/cores/proto/docarray.py
@@ -203,6 +203,7 @@ def chat_template_must_contain_variables(cls, v):


class LLMParams(BaseDoc):
    model: Optional[str] = None
    max_tokens: int = 1024
    max_new_tokens: int = 1024
    top_k: int = 10
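The new `model` field is optional and defaults to `None`, so existing callers are unaffected. A small sketch, assuming the other `LLMParams` fields keep their defaults (the model name is only an example):

```python
from comps.cores.proto.docarray import LLMParams

# Omitting `model` keeps today's behavior: the LLM microservice falls back
# to the single TGI_LLM_ENDPOINT.
default_params = LLMParams(max_new_tokens=256)
assert default_params.model is None

# Route the request to a specific configured model; the name must match a
# "model_name" entry in MODEL_CONFIGS.
llama_params = LLMParams(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    max_new_tokens=256,
    top_k=10,
)
```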
38 changes: 30 additions & 8 deletions comps/llms/text-generation/tgi/llm.py
@@ -22,16 +22,27 @@
    register_statistics,
    statistics_dict,
)
from comps.cores.mega.utils import ConfigError, load_model_configs
from comps.cores.proto.api_protocol import ChatCompletionRequest

logger = CustomLogger("llm_tgi")
logflag = os.getenv("LOGFLAG", False)

llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
llm = AsyncInferenceClient(
    model=llm_endpoint,
    timeout=600,
)
# Environment variables
MODEL_CONFIGS = os.getenv("MODEL_CONFIGS")
DEFAULT_ENDPOINT = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")

# Extract the model endpoint
llm_endpoint = ""
configs_map = {}
if not MODEL_CONFIGS:
    llm_endpoint = DEFAULT_ENDPOINT
else:
    try:
        configs_map = load_model_configs(MODEL_CONFIGS)
    except ConfigError as e:
        logger.error(f"Failed to load model configurations: {e}")
        raise ConfigError(f"Failed to load model configurations: {e}")


@register_microservice(
@@ -45,6 +56,17 @@
async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]):
    if logflag:
        logger.info(input)

    if input.model and MODEL_CONFIGS and configs_map:
        if configs_map.get(input.model):
            config = configs_map.get(input.model)
            llm_endpoint = config.get("endpoint")
        else:
            logger.error(f"Input model {input.model} not present in model_configs")
            raise ConfigError(f"Input model {input.model} not present in model_configs")

    llm = AsyncInferenceClient(model=llm_endpoint, timeout=600)
Collaborator:
Please check the CI test error, looks like llm_endpoint is empty string for this case, please fix it.
[image]

Contributor Author:
Thanks for the review. Fixed it.
    prompt_template = None
    if not isinstance(input, SearchedDoc) and input.chat_template:
        prompt_template = PromptTemplate.from_template(input.chat_template)
@@ -61,7 +83,7 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, Searche
docs = [doc.text for doc in input.retrieved_docs]
if logflag:
logger.info(f"[ SearchedDoc ] combined retrieved docs: {docs}")
prompt = ChatTemplate.generate_rag_prompt(input.initial_query, docs)
prompt = ChatTemplate.generate_rag_prompt(input.initial_query, docs, input.model)
# use default llm parameters for inferencing
new_input = LLMParamsDoc(query=prompt)
if logflag:
@@ -114,7 +136,7 @@ async def stream_generator():
else:
if input.documents:
# use rag default template
prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents)
prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents, input.model)

text_generation = await llm.text_generation(
prompt=prompt,
@@ -170,7 +192,7 @@ async def stream_generator():
else:
if input.documents:
# use rag default template
prompt = ChatTemplate.generate_rag_prompt(input.messages, input.documents)
prompt = ChatTemplate.generate_rag_prompt(input.messages, input.documents, input.model)

chat_completion = client.completions.create(
model="tgi",
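As the review comment above points out, `llm_endpoint` can be left as an empty string when MODEL_CONFIGS is set but the incoming request does not name a model. A minimal sketch of endpoint resolution with an explicit fallback; the helper name `resolve_endpoint` is hypothetical and not part of this PR:

```python
from typing import Optional


def resolve_endpoint(model: Optional[str], configs_map: dict, default_endpoint: str) -> str:
    """Pick the TGI endpoint for one request: use the configured endpoint when a
    known model is requested, otherwise fall back to the default endpoint."""
    if model and configs_map:
        config = configs_map.get(model)
        if config is None:
            # Unknown model names are rejected, mirroring the ConfigError raised in llm.py.
            raise ValueError(f"Input model {model} not present in model_configs")
        return config["endpoint"]
    return default_endpoint
```

With this shape, a request that omits `model` keeps the single-endpoint behavior even when MODEL_CONFIGS is configured.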
41 changes: 24 additions & 17 deletions comps/llms/text-generation/tgi/template.py
@@ -6,24 +6,31 @@

class ChatTemplate:
@staticmethod
def generate_rag_prompt(question, documents):
def generate_rag_prompt(question, documents, model):
context_str = "\n".join(documents)
if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3:
# chinese context
if model == "meta-llama/Meta-Llama-3.1-70B-Instruct" or model == "meta-llama/Meta-Llama-3.1-8B-Instruct":
template = """
### 你将扮演一个乐于助人、尊重他人并诚实的助手,你的目标是帮助用户解答问题。有效地利用来自本地知识库的搜索结果。确保你的回答中只包含相关信息。如果你不确定问题的答案,请避免分享不准确的信息。
### 搜索结果:{context}
### 问题:{question}
### 回答:
"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise <|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Context: {context}
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
else:
template = """
### You are a helpful, respectful and honest assistant to help the user with questions. \
Please refer to the search results obtained from the local knowledge base. \
But be careful to not incorporate the information that you think is not relevant to the question. \
If you don't know the answer to a question, please don't share false information. \n
### Search results: {context} \n
### Question: {question} \n
### Answer:
"""
if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3:
# chinese context
template = """
### 你将扮演一个乐于助人、尊重他人并诚实的助手,你的目标是帮助用户解答问题。有效地利用来自本地知识库的搜索结果。确保你的回答中只包含相关信息。如果你不确定问题的答案,请避免分享不准确的信息。
### 搜索结果:{context}
### 问题:{question}
### 回答:
"""
else:
template = """
### You are a helpful, respectful and honest assistant to help the user with questions. \
Please refer to the search results obtained from the local knowledge base. \
But be careful to not incorporate the information that you think is not relevant to the question. \
If you don't know the answer to a question, please don't share false information. \n
### Search results: {context} \n
### Question: {question} \n
### Answer:
"""
return template.format(context=context_str, question=question)
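A short usage sketch of the updated template selection. It assumes `ChatTemplate` is importable as a local module, the way llm.py is taken to use it; the documents and question are made up for illustration:

```python
from template import ChatTemplate  # local import, assumed to match llm.py

docs = [
    "OPEA provides composable GenAI microservices.",
    "The LLM microservice wraps a TGI endpoint.",
]
question = "What does the LLM microservice wrap?"

# Llama 3.1 Instruct models get the Llama chat-format prompt with special tokens.
llama_prompt = ChatTemplate.generate_rag_prompt(
    question, docs, "meta-llama/Meta-Llama-3.1-8B-Instruct"
)

# Any other model (or model=None) falls back to the generic RAG template,
# with the Chinese variant chosen when the context is mostly Chinese.
default_prompt = ChatTemplate.generate_rag_prompt(question, docs, None)
```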