Skip to content

Commit 0ee8185

Browse files
authored
Merge pull request #685 from euxx/feat/xinference
feat: add Xinference LLM support
2 parents ac1a841 + 2bd2717 commit 0ee8185

File tree

4 files changed

+58
-1
lines changed

4 files changed

+58
-1
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
3333
snowflake = ["snowflake-connector-python"]
3434
duckdb = ["duckdb"]
3535
google = ["google-generativeai", "google-cloud-aiplatform"]
36-
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres"]
36+
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres", "xinference-client"]
3737
test = ["tox"]
3838
chromadb = ["chromadb"]
3939
openai = ["openai"]
@@ -56,3 +56,4 @@ azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fast
5656
pgvector = ["langchain-postgres>=0.0.12"]
5757
faiss-cpu = ["faiss-cpu"]
5858
faiss-gpu = ["faiss-gpu"]
59+
xinference-client = ["xinference-client"]

src/vanna/xinference/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .xinference import Xinference

src/vanna/xinference/xinference.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from xinference_client.client.restful.restful_client import (
2+
Client,
3+
RESTfulChatModelHandle,
4+
)
5+
6+
from ..base import VannaBase
7+
8+
9+
class Xinference(VannaBase):
10+
def __init__(self, config=None):
11+
VannaBase.__init__(self, config=config)
12+
13+
if not config or "base_url" not in config:
14+
raise ValueError("config must contain at least Xinference base_url")
15+
16+
base_url = config["base_url"]
17+
api_key = config.get("api_key", "not empty")
18+
self.xinference_client = Client(base_url=base_url, api_key=api_key)
19+
20+
def system_message(self, message: str) -> any:
21+
return {"role": "system", "content": message}
22+
23+
def user_message(self, message: str) -> any:
24+
return {"role": "user", "content": message}
25+
26+
def assistant_message(self, message: str) -> any:
27+
return {"role": "assistant", "content": message}
28+
29+
def submit_prompt(self, prompt, **kwargs) -> str:
30+
if prompt is None:
31+
raise Exception("Prompt is None")
32+
33+
if len(prompt) == 0:
34+
raise Exception("Prompt is empty")
35+
36+
num_tokens = 0
37+
for message in prompt:
38+
num_tokens += len(message["content"]) / 4
39+
40+
model_uid = kwargs.get("model_uid") or self.config.get("model_uid", None)
41+
if model_uid is None:
42+
raise ValueError("model_uid is required")
43+
44+
xinference_model = self.xinference_client.get_model(model_uid)
45+
if isinstance(xinference_model, RESTfulChatModelHandle):
46+
print(
47+
f"Using model_uid {model_uid} for {num_tokens} tokens (approx)"
48+
)
49+
50+
response = xinference_model.chat(prompt)
51+
return response["choices"][0]["message"]["content"]
52+
else:
53+
raise NotImplementedError(f"Xinference model handle type {type(xinference_model)} is not supported, required RESTfulChatModelHandle")

tests/test_imports.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def test_regular_imports():
2828
from vanna.remote import VannaDefault
2929
from vanna.vannadb.vannadb_vector import VannaDB_VectorStore
3030
from vanna.weaviate.weaviate_vector import WeaviateDatabase
31+
from vanna.xinference.xinference import Xinference
3132
from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat
3233
from vanna.ZhipuAI.ZhipuAI_embeddings import ZhipuAI_Embeddings
3334

@@ -52,4 +53,5 @@ def test_shortcut_imports():
5253
from vanna.vannadb import VannaDB_VectorStore
5354
from vanna.vllm import Vllm
5455
from vanna.weaviate import WeaviateDatabase
56+
from vanna.xinference import Xinference
5557
from vanna.ZhipuAI import ZhipuAI_Chat, ZhipuAI_Embeddings

0 commit comments

Comments
 (0)