Skip to content

Commit fa0cb75

Browse files
authored
Merge pull request #78 from pauldotyu/main
feat: add support for local llms running in cluster
2 parents 845c893 + be9b671 commit fa0cb75

File tree

3 files changed

+105
-49
lines changed

3 files changed

+105
-49
lines changed

src/ai-service/requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ pytest==7.3.1
55
httpx
66
pyyaml
77
semantic-kernel==0.3.1.dev0
8-
azure.identity==1.14.0
8+
azure.identity==1.14.0
9+
requests==2.31.0

src/ai-service/routers/description_generator.py

+102-47
Original file line numberDiff line numberDiff line change
@@ -6,53 +6,73 @@
66
from dotenv import load_dotenv
77
from typing import Any, List, Dict
88
import os
9-
import dotenv
9+
import requests
10+
import json
1011

11-
# Load environment variables from .env file
12-
load_dotenv()
12+
# Set the useLocalLLM and useAzureOpenAI variables based on environment variables
13+
useLocalLLM: bool = False
14+
useAzureOpenAI: bool = False
1315

14-
# Initialize the semantic kernel
15-
kernel: sk.Kernel = sk.Kernel()
16+
if os.environ.get("USE_LOCAL_LLM"):
17+
useLocalLLM = os.environ.get("USE_LOCAL_LLM").lower() == "true"
1618

17-
kernel = sk.Kernel()
19+
if os.environ.get("USE_AZURE_OPENAI"):
20+
useAzureOpenAI = os.environ.get("USE_AZURE_OPENAI").lower() == "true"
1821

19-
# Get the Azure OpenAI deployment name, API key, and endpoint or OpenAI org id from environment variables
20-
useAzureOpenAI: str = os.environ.get("USE_AZURE_OPENAI")
21-
api_key: str = os.environ.get("OPENAI_API_KEY")
22-
useAzureAD: str = os.environ.get("USE_AZURE_AD")
22+
# if useLocalLLM and useAzureOpenAI are both set to true, raise an exception
23+
if useLocalLLM and useAzureOpenAI:
24+
raise Exception("USE_LOCAL_LLM and USE_AZURE_OPENAI environment variables cannot both be set to true")
2325

24-
if (isinstance(api_key, str) == False or api_key == "") and (isinstance(useAzureAD, str) == False or useAzureAD == ""):
25-
raise Exception("OPENAI_API_KEY environment variable must be set")
26-
if isinstance(useAzureOpenAI, str) == False or (useAzureOpenAI.lower() != "true" and useAzureOpenAI.lower() != "false"):
27-
raise Exception("USE_AZURE_OPENAI environment variable must be set to 'True' or 'False' string not boolean")
26+
# if useLocalLLM or useAzureOpenAI are set to true, get the endpoint from the environment variables
27+
if useLocalLLM or useAzureOpenAI:
28+
endpoint: str = os.environ.get("AI_ENDPOINT") or os.environ.get("AZURE_OPENAI_ENDPOINT")
29+
30+
if isinstance(endpoint, str) == False or endpoint == "":
31+
raise Exception("AI_ENDPOINT or AZURE_OPENAI_ENDPOINT environment variable must be set when USE_LOCAL_LLM or USE_AZURE_OPENAI is set to true")
2832

33+
# if not using local LLM, set up the semantic kernel
34+
if useLocalLLM:
35+
print("Using Local LLM")
36+
else:
37+
print("Using OpenAI and setting up Semantic Kernel")
38+
# Load environment variables from .env file
39+
load_dotenv()
2940

30-
if useAzureOpenAI.lower() == "false":
31-
org_id = os.environ.get("OPENAI_ORG_ID")
32-
if isinstance(org_id, str) == False or org_id == "":
33-
raise Exception("OPENAI_ORG_ID environment variable must be set when USE_AZURE_OPENAI is set to False")
34-
# Add the OpenAI text completion service to the kernel
35-
kernel.add_chat_service("dv", OpenAIChatCompletion("gpt-3.5-turbo", api_key, org_id))
41+
# Initialize the semantic kernel
42+
kernel: sk.Kernel = sk.Kernel()
43+
44+
kernel = sk.Kernel()
45+
46+
# Get the Azure OpenAI deployment name, API key, and endpoint or OpenAI org id from environment variables
47+
api_key: str = os.environ.get("OPENAI_API_KEY")
48+
useAzureAD: str = os.environ.get("USE_AZURE_AD")
49+
50+
if (isinstance(api_key, str) == False or api_key == "") and (isinstance(useAzureAD, str) == False or useAzureAD == ""):
51+
raise Exception("OPENAI_API_KEY environment variable must be set")
52+
53+
if not useAzureOpenAI:
54+
org_id = os.environ.get("OPENAI_ORG_ID")
55+
if isinstance(org_id, str) == False or org_id == "":
56+
raise Exception("OPENAI_ORG_ID environment variable must be set when USE_AZURE_OPENAI is set to False")
57+
# Add the OpenAI text completion service to the kernel
58+
kernel.add_chat_service("dv", OpenAIChatCompletion("gpt-3.5-turbo", api_key, org_id))
3659

37-
else:
38-
deployment: str = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
39-
endpoint: str = os.environ.get("AZURE_OPENAI_ENDPOINT")
40-
if isinstance(deployment, str) == False or isinstance(endpoint, str) == False or deployment == "" or endpoint == "":
41-
raise Exception("AZURE_OPENAI_DEPLOYMENT_NAME and AZURE_OPENAI_ENDPOINT environment variables must be set when USE_AZURE_OPENAI is set to true")
42-
# Add the Azure OpenAI text completion service to the kernel
43-
if isinstance(useAzureAD, str) == True and useAzureAD.lower() == "true":
44-
print("Authenticating to Azure OpenAI with Azure AD Workload Identity")
45-
credential = DefaultAzureCredential()
46-
access_token = credential.get_token("https://cognitiveservices.azure.com/.default")
47-
kernel.add_chat_service("dv", AzureChatCompletion(deployment_name=deployment, endpoint=endpoint, api_key=access_token.token, ad_auth=True))
4860
else:
49-
print("Authenticating to Azure OpenAI with OpenAI API key")
50-
kernel.add_chat_service("dv", AzureChatCompletion(deployment, endpoint, api_key))
61+
deployment: str = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
62+
# Add the Azure OpenAI text completion service to the kernel
63+
if isinstance(useAzureAD, str) == True and useAzureAD.lower() == "true":
64+
print("Authenticating to Azure OpenAI with Azure AD Workload Identity")
65+
credential = DefaultAzureCredential()
66+
access_token = credential.get_token("https://cognitiveservices.azure.com/.default")
67+
kernel.add_chat_service("dv", AzureChatCompletion(deployment_name=deployment, endpoint=endpoint, api_key=access_token.token, ad_auth=True))
68+
else:
69+
print("Authenticating to Azure OpenAI with OpenAI API key")
70+
kernel.add_chat_service("dv", AzureChatCompletion(deployment, endpoint, api_key))
5171

52-
# Import semantic skills from the "skills" directory
53-
skills_directory: str = "skills"
54-
productFunctions: dict = kernel.import_semantic_skill_from_directory(skills_directory, "ProductSkill")
55-
descriptionFunction: Any = productFunctions["Description"]
72+
# Import semantic skills from the "skills" directory
73+
skills_directory: str = "skills"
74+
productFunctions: dict = kernel.import_semantic_skill_from_directory(skills_directory, "ProductSkill")
75+
descriptionFunction: Any = productFunctions["Description"]
5676

5777
# Define the description API router
5878
description: APIRouter = APIRouter(prefix="/generate", tags=["generate"])
@@ -62,7 +82,7 @@ class Product:
6282
def __init__(self, product: Dict[str, List]) -> None:
6383
self.name: str = product["name"]
6484
self.tags: List[str] = product["tags"]
65-
85+
6686
# Define the post_description endpoint
6787
@description.post("/description", summary="Get description for a product", operation_id="getDescription")
6888
async def post_description(request: Request) -> JSONResponse:
@@ -73,15 +93,50 @@ async def post_description(request: Request) -> JSONResponse:
7393
name: str = product.name
7494
tags: List = ",".join(product.tags)
7595

76-
# Create a new context and invoke the description function
77-
context: Any = kernel.create_new_context()
78-
context["name"] = name
79-
context["tags"] = tags
80-
result: str = await descriptionFunction.invoke_async(context=context)
81-
if "error" in str(result).lower():
82-
return Response(content=str(result), status_code=status.HTTP_401_UNAUTHORIZED)
83-
print(result)
84-
result = str(result).replace("\n", "")
96+
if useLocalLLM:
97+
print("Calling local LLM")
98+
99+
prompt = f"Describe this pet store product using joyful, playful, and enticing language.\nProduct name: {name}\ntags: {tags}\ndescription:\""
100+
temperature = 0.5
101+
top_p = 0.0
102+
103+
url = endpoint
104+
payload = {
105+
"prompt": prompt,
106+
"temperature": temperature,
107+
"top_p": top_p
108+
}
109+
headers = {"Content-Type": "application/json"}
110+
response = requests.request("POST", url, headers=headers, json=payload)
111+
112+
# convert response.text to json
113+
result = json.loads(response.text)
114+
result = result["Result"]
115+
result = result.split("description:")[1]
116+
117+
# remove all double quotes
118+
if "\"" in result:
119+
result = result.replace("\"", "")
120+
121+
# # if first character is a double quote, remove it
122+
# if result[0] == "\"":
123+
# result = result[1:]
124+
# # if last character is a double quote, remove it
125+
# if result[-1] == "\"":
126+
# result = result[:-1]
127+
128+
print(result)
129+
else:
130+
print("Calling OpenAI")
131+
# Create a new context and invoke the description function
132+
context: Any = kernel.create_new_context()
133+
context["name"] = name
134+
context["tags"] = tags
135+
result: str = await descriptionFunction.invoke_async(context=context)
136+
if "error" in str(result).lower():
137+
return Response(content=str(result), status_code=status.HTTP_401_UNAUTHORIZED)
138+
print(result)
139+
result = str(result).replace("\n", "")
85140

86141
# Return the description as a JSON response
87142
return JSONResponse(content={"description": result}, status_code=status.HTTP_200_OK)

src/store-admin/src/components/ProductForm.vue

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
<div class="form-row">
2929
<label for="product-description">Description</label>
3030
<textarea id="product-description" placeholder="Product Description" v-model="product.description" />
31-
<button @click="generateDescription" class="ai-button">Ask OpenAI</button>
31+
<button @click="generateDescription" class="ai-button">Ask AI Assistant</button>
3232
<input type="hidden" id="product-id" placeholder="Product ID" v-model="product.id" />
3333
</div>
3434

0 commit comments

Comments
 (0)