Skip to content

Commit 7bf1953

Browse files
XinyaoWapre-commit-ci[bot]chensuyueZePan110
authored
Embedding compatible with OpenAI API (#892)
* Embedding TEI Langchain compatible with OpenAI API Signed-off-by: Xinyao Wang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * TextDoc support list Signed-off-by: Xinyao Wang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support tei llama index openai compatible API Signed-off-by: Xinyao Wang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support mosec langchain openai compatible API Signed-off-by: Xinyao Wang <[email protected]> * update UT for embedding tests Signed-off-by: Xinyao Wang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ut bug Signed-off-by: Xinyao Wang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support embedding predictionguard openai compatible API Signed-off-by: Xinyao Wang <[email protected]> * support embedding multimodal clip OpenAI compatible API Signed-off-by: Xinyao Wang <[email protected]> * fix bug Signed-off-by: Xinyao Wang <[email protected]> * enable debug mode for embedding UT Signed-off-by: Xinyao Wang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Xinyao Wang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: chen, suyue <[email protected]> Co-authored-by: ZePan110 <[email protected]>
1 parent 4418824 commit 7bf1953

16 files changed

+429
-58
lines changed

comps/cores/proto/docarray.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class TopologyInfo:
1717

1818

1919
class TextDoc(BaseDoc, TopologyInfo):
20-
text: str = None
20+
text: Union[str, List[str]] = None
2121

2222

2323
class Audio2text(BaseDoc, TopologyInfo):
@@ -93,15 +93,15 @@ class DocPath(BaseDoc):
9393

9494

9595
class EmbedDoc(BaseDoc):
96-
text: str
97-
embedding: conlist(float, min_length=0)
96+
text: Union[str, List[str]]
97+
embedding: Union[conlist(float, min_length=0), List[conlist(float, min_length=0)]]
9898
search_type: str = "similarity"
9999
k: int = 4
100100
distance_threshold: Optional[float] = None
101101
fetch_k: int = 20
102102
lambda_mult: float = 0.5
103103
score_threshold: float = 0.2
104-
constraints: Optional[Union[Dict[str, Any], None]] = None
104+
constraints: Optional[Union[Dict[str, Any], List[Dict[str, Any]], None]] = None
105105

106106

107107
class EmbedMultimodalDoc(EmbedDoc):

comps/embeddings/mosec/langchain/README.md

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,34 @@ docker run -d --name="embedding-langchain-mosec-server" -e http_proxy=$http_prox
2525

2626
## run client test
2727

28-
```
29-
curl localhost:6000/v1/embeddings \
30-
-X POST \
31-
-d '{"text":"Hello, world!"}' \
32-
-H 'Content-Type: application/json'
28+
Use our basic API.
29+
30+
```bash
31+
## query with single text
32+
curl http://localhost:6000/v1/embeddings\
33+
-X POST \
34+
-d '{"text":"Hello, world!"}' \
35+
-H 'Content-Type: application/json'
36+
37+
## query with multiple texts
38+
curl http://localhost:6000/v1/embeddings\
39+
-X POST \
40+
-d '{"text":["Hello, world!","How are you?"]}' \
41+
-H 'Content-Type: application/json'
42+
```
43+
44+
We are also compatible with [OpenAI API](https://platform.openai.com/docs/api-reference/embeddings).
45+
46+
```bash
47+
## Input single text
48+
curl http://localhost:6000/v1/embeddings\
49+
-X POST \
50+
-d '{"input":"Hello, world!"}' \
51+
-H 'Content-Type: application/json'
52+
53+
## Input multiple texts with parameters
54+
curl http://localhost:6000/v1/embeddings\
55+
-X POST \
56+
-d '{"input":["Hello, world!","How are you?"], "dimensions":100}' \
57+
-H 'Content-Type: application/json'
3358
```

comps/embeddings/mosec/langchain/embedding_mosec.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import asyncio
55
import os
66
import time
7-
from typing import List, Optional
7+
from typing import List, Optional, Union
88

99
from langchain_community.embeddings import OpenAIEmbeddings
1010

@@ -18,6 +18,12 @@
1818
register_statistics,
1919
statistics_dict,
2020
)
21+
from comps.cores.proto.api_protocol import (
22+
ChatCompletionRequest,
23+
EmbeddingRequest,
24+
EmbeddingResponse,
25+
EmbeddingResponseData,
26+
)
2127

2228
logger = CustomLogger("embedding_mosec")
2329
logflag = os.getenv("LOGFLAG", False)
@@ -62,18 +68,43 @@ async def get_embedding(e: Optional[List[float]]) -> List[float]:
6268
output_datatype=EmbedDoc,
6369
)
6470
@register_statistics(names=["opea_service@embedding_mosec"])
65-
async def embedding(input: TextDoc) -> EmbedDoc:
71+
async def embedding(
72+
input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest]
73+
) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]:
6674
if logflag:
6775
logger.info(input)
6876
start = time.time()
69-
embed_vector = await embeddings.aembed_query(input.text)
70-
res = EmbedDoc(text=input.text, embedding=embed_vector)
77+
if isinstance(input, TextDoc):
78+
embed_vector = await get_embeddings(input.text)
79+
embedding_res = embed_vector[0] if isinstance(input.text, str) else embed_vector
80+
res = EmbedDoc(text=input.text, embedding=embedding_res)
81+
else:
82+
embed_vector = await get_embeddings(input.input)
83+
if input.dimensions is not None:
84+
embed_vector = [embed_vector[i][: input.dimensions] for i in range(len(embed_vector))]
85+
86+
# for standard openai embedding format
87+
res = EmbeddingResponse(
88+
data=[EmbeddingResponseData(index=i, embedding=embed_vector[i]) for i in range(len(embed_vector))]
89+
)
90+
91+
if isinstance(input, ChatCompletionRequest):
92+
input.embedding = res
93+
# keep
94+
res = input
95+
7196
statistics_dict["opea_service@embedding_mosec"].append_latency(time.time() - start, None)
7297
if logflag:
7398
logger.info(res)
7499
return res
75100

76101

102+
async def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]:
103+
texts = [text] if isinstance(text, str) else text
104+
embed_vector = await embeddings.aembed_documents(texts)
105+
return embed_vector
106+
107+
77108
if __name__ == "__main__":
78109
MOSEC_EMBEDDING_ENDPOINT = os.environ.get("MOSEC_EMBEDDING_ENDPOINT", "http://127.0.0.1:8080")
79110
os.environ["OPENAI_API_BASE"] = MOSEC_EMBEDDING_ENDPOINT

comps/embeddings/multimodal_clip/README.md

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,34 @@ curl http://localhost:6000/v1/health_check\
4444

4545
### 2.2 Consume Embedding Service
4646

47+
Use our basic API.
48+
49+
```bash
50+
## query with single text
51+
curl http://localhost:6000/v1/embeddings\
52+
-X POST \
53+
-d '{"text":"Hello, world!"}' \
54+
-H 'Content-Type: application/json'
55+
56+
## query with multiple texts
57+
curl http://localhost:6000/v1/embeddings\
58+
-X POST \
59+
-d '{"text":["Hello, world!","How are you?"]}' \
60+
-H 'Content-Type: application/json'
61+
```
62+
63+
We are also compatible with [OpenAI API](https://platform.openai.com/docs/api-reference/embeddings).
64+
4765
```bash
48-
curl http://localhost:6000/v1/embeddings \
49-
-X POST -d '{"text":"Sample text"}' \
50-
-H 'Content-Type: application/json'
66+
## Input single text
67+
curl http://localhost:6000/v1/embeddings\
68+
-X POST \
69+
-d '{"input":"Hello, world!"}' \
70+
-H 'Content-Type: application/json'
5171

72+
## Input multiple texts with parameters
73+
curl http://localhost:6000/v1/embeddings\
74+
-X POST \
75+
-d '{"input":["Hello, world!","How are you?"], "dimensions":100}' \
76+
-H 'Content-Type: application/json'
5277
```

comps/embeddings/multimodal_clip/embedding_multimodal.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import datetime
5+
import os
56
import time
7+
from typing import List, Optional, Union
68

79
from dateparser.search import search_dates
810
from embeddings_clip import vCLIP
911

1012
from comps import (
13+
CustomLogger,
1114
EmbedDoc,
1215
ServiceType,
1316
TextDoc,
@@ -16,6 +19,15 @@
1619
register_statistics,
1720
statistics_dict,
1821
)
22+
from comps.cores.proto.api_protocol import (
23+
ChatCompletionRequest,
24+
EmbeddingRequest,
25+
EmbeddingResponse,
26+
EmbeddingResponseData,
27+
)
28+
29+
logger = CustomLogger("embedding_multimodal")
30+
logflag = os.getenv("LOGFLAG", False)
1931

2032

2133
def filtler_dates(prompt):
@@ -64,21 +76,49 @@ def filtler_dates(prompt):
6476
output_datatype=EmbedDoc,
6577
)
6678
@register_statistics(names=["opea_service@embedding_multimodal"])
67-
def embedding(input: TextDoc) -> EmbedDoc:
79+
async def embedding(
80+
input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest]
81+
) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]:
82+
if logflag:
83+
logger.info(input)
6884
start = time.time()
6985

7086
if isinstance(input, TextDoc):
71-
# Handle text input
72-
embed_vector = embeddings.embed_query(input.text).tolist()[0]
73-
res = EmbedDoc(text=input.text, embedding=embed_vector, constraints=filtler_dates(input.text))
74-
87+
embed_vector = await get_embeddings(input.text)
88+
if isinstance(input.text, str):
89+
embedding_res = embed_vector[0]
90+
constraints_res = filtler_dates(input.text)
91+
else:
92+
embedding_res = embed_vector
93+
constraints_res = [filtler_dates(input.text[i]) for i in range(len(input.text))]
94+
res = EmbedDoc(text=input.text, embedding=embedding_res, constraints=constraints_res)
7595
else:
76-
raise ValueError("Invalid input type")
96+
embed_vector = await get_embeddings(input.input)
97+
if input.dimensions is not None:
98+
embed_vector = [embed_vector[i][: input.dimensions] for i in range(len(embed_vector))]
99+
100+
# for standard openai embedding format
101+
res = EmbeddingResponse(
102+
data=[EmbeddingResponseData(index=i, embedding=embed_vector[i]) for i in range(len(embed_vector))]
103+
)
104+
105+
if isinstance(input, ChatCompletionRequest):
106+
input.embedding = res
107+
# keep
108+
res = input
77109

78110
statistics_dict["opea_service@embedding_multimodal"].append_latency(time.time() - start, None)
111+
if logflag:
112+
logger.info(res)
79113
return res
80114

81115

116+
async def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]:
117+
texts = [text] if isinstance(text, str) else text
118+
embed_vector = embeddings.embed_query(texts).tolist()
119+
return embed_vector
120+
121+
82122
if __name__ == "__main__":
83123
embeddings = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 4})
84124
opea_microservices["opea_service@embedding_multimodal"].start()

comps/embeddings/predictionguard/README.md

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,34 @@ docker run -d --name="embedding-predictionguard" -p 6000:6000 -e PREDICTIONGUARD
3131

3232
## 🚀 Consume Embeddings Service
3333

34+
Use our basic API.
35+
3436
```bash
35-
curl localhost:6000/v1/embeddings \
36-
-X POST \
37-
-d '{"text":"Hello, world!"}' \
38-
-H 'Content-Type: application/json'
37+
## query with single text
38+
curl http://localhost:6000/v1/embeddings\
39+
-X POST \
40+
-d '{"text":"Hello, world!"}' \
41+
-H 'Content-Type: application/json'
42+
43+
## query with multiple texts
44+
curl http://localhost:6000/v1/embeddings\
45+
-X POST \
46+
-d '{"text":["Hello, world!","How are you?"]}' \
47+
-H 'Content-Type: application/json'
48+
```
49+
50+
We are also compatible with [OpenAI API](https://platform.openai.com/docs/api-reference/embeddings).
51+
52+
```bash
53+
## Input single text
54+
curl http://localhost:6000/v1/embeddings\
55+
-X POST \
56+
-d '{"input":"Hello, world!"}' \
57+
-H 'Content-Type: application/json'
58+
59+
## Input multiple texts with parameters
60+
curl http://localhost:6000/v1/embeddings\
61+
-X POST \
62+
-d '{"input":["Hello, world!","How are you?"], "dimensions":100}' \
63+
-H 'Content-Type: application/json'
3964
```

comps/embeddings/predictionguard/embedding_predictionguard.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
import os
66
import time
7+
from typing import List, Optional, Union
78

89
from predictionguard import PredictionGuard
910

1011
from comps import (
12+
CustomLogger,
1113
EmbedDoc,
1214
ServiceType,
1315
TextDoc,
@@ -16,6 +18,15 @@
1618
register_statistics,
1719
statistics_dict,
1820
)
21+
from comps.cores.proto.api_protocol import (
22+
ChatCompletionRequest,
23+
EmbeddingRequest,
24+
EmbeddingResponse,
25+
EmbeddingResponseData,
26+
)
27+
28+
logger = CustomLogger("embedding_predictionguard")
29+
logflag = os.getenv("LOGFLAG", False)
1930

2031
# Initialize Prediction Guard client
2132
client = PredictionGuard()
@@ -31,16 +42,46 @@
3142
output_datatype=EmbedDoc,
3243
)
3344
@register_statistics(names=["opea_service@embedding_predictionguard"])
34-
def embedding(input: TextDoc) -> EmbedDoc:
45+
async def embedding(
46+
input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest]
47+
) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]:
48+
if logflag:
49+
logger.info(input)
3550
start = time.time()
36-
response = client.embeddings.create(model=pg_embedding_model_name, input=[{"text": input.text}])
37-
embed_vector = response["data"][0]["embedding"]
38-
embed_vector = embed_vector[:512] # Keep only the first 512 elements
39-
res = EmbedDoc(text=input.text, embedding=embed_vector)
51+
52+
if isinstance(input, TextDoc):
53+
embed_vector = await get_embeddings(input.text)
54+
embedding_res = embed_vector[0] if isinstance(input.text, str) else embed_vector
55+
res = EmbedDoc(text=input.text, embedding=embedding_res)
56+
else:
57+
embed_vector = await get_embeddings(input.input)
58+
input.dimensions = input.dimensions if input.dimensions is not None else 512
59+
embed_vector = [embed_vector[i][: input.dimensions] for i in range(len(embed_vector))]
60+
61+
# for standard openai embedding format
62+
res = EmbeddingResponse(
63+
data=[EmbeddingResponseData(index=i, embedding=embed_vector[i]) for i in range(len(embed_vector))]
64+
)
65+
66+
if isinstance(input, ChatCompletionRequest):
67+
input.embedding = res
68+
# keep
69+
res = input
70+
4071
statistics_dict["opea_service@embedding_predictionguard"].append_latency(time.time() - start, None)
72+
if logflag:
73+
logger.info(res)
4174
return res
4275

4376

77+
async def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]:
78+
texts = [text] if isinstance(text, str) else text
79+
texts = [{"text": texts[i]} for i in range(len(texts))]
80+
response = client.embeddings.create(model=pg_embedding_model_name, input=texts)["data"]
81+
embed_vector = [response[i]["embedding"] for i in range(len(response))]
82+
return embed_vector
83+
84+
4485
if __name__ == "__main__":
4586
pg_embedding_model_name = os.getenv("PG_EMBEDDING_MODEL_NAME", "bridgetower-large-itm-mlm-itc")
4687
print("Prediction Guard Embedding initialized.")

0 commit comments

Comments
 (0)