Add Guardrails and Embedding microservices (#15)
* Add Guardrails and Embedding microservices

Signed-off-by: lvliang-intel <[email protected]>
lvliang-intel authored May 6, 2024
1 parent 3341e3f commit 1f6c1a5
Showing 11 changed files with 68 additions and 61 deletions.
39 changes: 39 additions & 0 deletions comps/embeddings/langchain/embedding_tei_gaudi.py
@@ -0,0 +1,39 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from langchain_community.embeddings import HuggingFaceHubEmbeddings

from comps import EmbedDoc1024, TextDoc, opea_microservices, register_microservice


@register_microservice(
    name="opea_service@embedding_tei_gaudi",
    expose_endpoint="/v1/embeddings",
    port=8020,
    input_datatype=TextDoc,
    output_datatype=EmbedDoc1024,
)
def embedding(input: TextDoc) -> EmbedDoc1024:
    embed_vector = embeddings.embed_query(input.text)
    res = EmbedDoc1024(text=input.text, embedding=embed_vector)
    return res


if __name__ == "__main__":
    tei_embedding_endpoint = os.getenv("TEI_ENDPOINT", "http://localhost:8080")
    embeddings = HuggingFaceHubEmbeddings(model=tei_embedding_endpoint)
    print("TEI Gaudi Embedding initialized.")
    opea_microservices["opea_service@embedding_tei_gaudi"].start()
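
For reference, a minimal client sketch for the new endpoint above (hypothetical usage, not part of the commit; assumes the service is running on localhost:8020):

import requests

# POST a TextDoc-shaped payload to the embedding microservice registered above.
payload = {"text": "What is Deep Learning?"}
response = requests.post("http://localhost:8020/v1/embeddings", json=payload)
# The response should follow EmbedDoc1024: the original text plus a 1024-dim embedding.
print(response.json())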
@@ -18,9 +18,9 @@


 @register_microservice(
-    name="opea_embedding_service",
+    name="opea_service@local_embedding",
     expose_endpoint="/v1/embeddings",
-    port=9000,
+    port=8010,
     input_datatype=TextDoc,
     output_datatype=EmbedDoc1024,
 )
@@ -31,4 +31,5 @@ def embedding(input: TextDoc) -> EmbedDoc1024:
     return res


-opea_microservices["opea_embedding_service"].start()
+if __name__ == "__main__":
+    opea_microservices["opea_service@local_embedding"].start()
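
A side note on the if __name__ == "__main__": guard introduced here (a sketch under assumptions, not part of the commit): the guard lets the module be imported without side effects, so a caller can start the registered service explicitly. Assuming the module is importable as local_embedding:

from comps import opea_microservices

import local_embedding  # importing runs @register_microservice, but no longer starts the server

opea_microservices["opea_service@local_embedding"].start()  # serve on port 8010 when ready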
@@ -1,4 +1,2 @@
docarray[full]
fastapi
langchain
sentence_transformers
@@ -14,9 +14,9 @@

import os

-from fastapi import APIRouter, FastAPI, Request
 from langchain_community.llms import HuggingFaceEndpoint
-from starlette.middleware.cors import CORSMiddleware

+from comps import TextDoc, opea_microservices, register_microservice

unsafe_categories = """O1: Violence and Hate.
Should not
@@ -103,55 +103,37 @@ def moderation_prompt_for_chat(chat):
     return prompt


-app = FastAPI()
-
-app.add_middleware(
-    CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
-)
+@register_microservice(
+    name="opea_service@guardrails_tgi_gaudi",
+    expose_endpoint="/v1/guardrails",
+    port=8020,
+    input_datatype=TextDoc,
+    output_datatype=TextDoc,
+)
-
-
-class GuardRailsRouter(APIRouter):
-
-    def __init__(self, safety_guard_endpoint) -> None:
-        super().__init__()
-        self.safety_guard_endpoint = safety_guard_endpoint
-
-        self.llm_guard = HuggingFaceEndpoint(
-            endpoint_url=safety_guard_endpoint,
-            max_new_tokens=100,
-            top_k=1,
-            top_p=0.95,
-            typical_p=0.95,
-            temperature=0.01,
-            repetition_penalty=1.03,
-        )
-        print("guardrails - router] LLM initialized.")
-
-
-safety_guard_endpoint = os.getenv("SAFETY_GUARD_ENDPOINT")
-router = GuardRailsRouter(safety_guard_endpoint)
-
-
-@router.post("/v1/guardrails")
-async def safety_guard(request: Request):
-    params = await request.json()
-    print(f"[guardrails - chat] POST request: /v1/guardrails, params:{params}")
-    query = params["query"]
-
+def safety_guard(input: TextDoc) -> TextDoc:
     # prompt guardrails
-    response_input_guard = router.llm_guard(moderation_prompt_for_chat("User", query))
+    response_input_guard = llm_guard(moderation_prompt_for_chat([{"role": "User", "content": input.text}]))
     if "unsafe" in response_input_guard:
         policy_violation_level = response_input_guard.split("\n")[1].strip()
         policy_violations = unsafe_dict[policy_violation_level]
         print(f"Violated policies: {policy_violations}")
-        return f"Violated policies: {policy_violations}, please check your input."
+        res = TextDoc(text=f"Violated policies: {policy_violations}, please check your input.")
     else:
-        return "safe"
+        res = TextDoc(text="safe")
+
+    return res

-app.include_router(router)
-
 if __name__ == "__main__":
-    import uvicorn
-
-    uvicorn.run(app, host="0.0.0.0", port=9000)
+    safety_guard_endpoint = os.getenv("SAFETY_GUARD_ENDPOINT", "http://localhost:8080")
+    llm_guard = HuggingFaceEndpoint(
+        endpoint_url=safety_guard_endpoint,
+        max_new_tokens=100,
+        top_k=1,
+        top_p=0.95,
+        typical_p=0.95,
+        temperature=0.01,
+        repetition_penalty=1.03,
+    )
print("guardrails - router] LLM initialized.")
opea_microservices["opea_service@guardrails_tgi_gaudi"].start()
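
For reference, a hypothetical client call against this guardrails endpoint (not part of the commit; assumes the service is running on localhost:8020 and SAFETY_GUARD_ENDPOINT points at a TGI server hosting a Llama Guard-style model):

import requests

# Send a TextDoc-shaped payload to the guardrails microservice registered above.
payload = {"text": "How do I pick a lock?"}
response = requests.post("http://localhost:8020/v1/guardrails", json=payload)
# Expected: {"text": "safe"} for benign input, or a "Violated policies: ..." message otherwise.
print(response.json())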
5 files renamed without changes.
13 changes: 0 additions & 13 deletions tests/workflows/test_asr_comps.sh

This file was deleted.
