Skip to content

Commit

Permalink
Update reranking_tei to support authentication for tgi endpoints. Use…
Browse files Browse the repository at this point in the history
…s get_access_token from utils
  • Loading branch information
sgurunat committed Oct 29, 2024
1 parent 9a8e582 commit 5ef283f
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions comps/reranks/tei/reranking_tei.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,16 @@
RerankingResponse,
RerankingResponseData,
)
from comps.cores.mega.utils import get_access_token


logger = CustomLogger("reranking_tei")
logflag = os.getenv("LOGFLAG", False)

# Environment variables
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")

@register_microservice(
name="opea_service@reranking_tei",
Expand All @@ -46,6 +52,7 @@ async def reranking(
logger.info(input)
start = time.time()
reranking_results = []
access_token = get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
if input.retrieved_docs:
docs = [doc.text for doc in input.retrieved_docs]
url = tei_reranking_endpoint + "/rerank"
Expand All @@ -56,6 +63,8 @@ async def reranking(
query = input.input
data = {"query": query, "texts": docs}
headers = {"Content-Type": "application/json"}
if access_token:
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {access_token}"}
async with aiohttp.ClientSession() as session:
async with session.post(url, data=json.dumps(data), headers=headers) as response:
response_data = await response.json()
Expand Down

0 comments on commit 5ef283f

Please sign in to comment.