
Commit 4d54ca5

Added groq provider
1 parent 1ce4e19 commit 4d54ca5

File tree

8 files changed: +257 -20 lines
+23 (new file)

@@ -0,0 +1,23 @@
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent import MessagesFormatterType
from llama_cpp_agent.providers.groq import GroqProvider

provider = GroqProvider(base_url="https://api.groq.com/openai/v1", model="mixtral-8x7b-32768", huggingface_model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key="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")

agent = LlamaCppAgent(
    provider,
    system_prompt="You are a helpful assistant.",
    predefined_messages_formatter_type=MessagesFormatterType.MISTRAL,
)

settings = provider.get_provider_default_settings()
settings.stream = True
settings.max_tokens = 512
settings.temperature = 0.65

while True:
    user_input = input(">")
    if user_input == "exit":
        break
    agent_output = agent.get_chat_response(user_input, llm_sampling_settings=settings)
    print(f"Agent: {agent_output.strip()}")

examples/07_Memory/MemoryAssistant/main.py

+1 -1

@@ -10,7 +10,7 @@
 from prompts import assistant_prompt, memory_prompt, wrap_function_response_in_xml_tags_json_mode, \
     generate_write_message, generate_write_message_with_examples, wrap_user_message_in_xml_tags_json_mode

-provider = LlamaCppServerProvider("http://hades.hq.solidrust.net:8084")
+provider = LlamaCppServerProvider("http://localhost:8080")

 agent = LlamaCppAgent(
     provider,

pyproject.toml

+1

@@ -27,6 +27,7 @@ email = "[email protected]"
 agent_memory = ["chromadb", "SQLAlchemy", "numpy", "scipy"]
 rag = ["ragatouille"]
 vllm_provider = ["openai", "transformers", "sentencepiece", "protobuf"]
+groq_provider = ["groq"]
 mixtral_agent = ["mistral-common"]
 web_search_summarization = ["duckduckgo_search", "trafilatura", "lxml-html-clean", "lxml", "googlesearch-python", "beautifulsoup4", "readability-lxml"]
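The new groq_provider extras group follows the pattern of the existing provider groups (vllm_provider, etc.), so the groq client library is only installed on request, e.g. pip install llama-cpp-agent[groq_provider] (assuming the package keeps its published name). Note that, as written, the GroqProvider below imports openai and transformers for its client and tokenizer rather than the groq package, so those dependencies appear to be needed as well.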

src/llama_cpp_agent/function_calling.py

+1 -1

@@ -83,7 +83,7 @@ def pydantic_model_to_openai_function_definition(pydantic_model: Type[BaseModel]
     function_definition = {
         "type": "function",
         "function": {
-            "name": pydantic_model.__name__.lower(),
+            "name": pydantic_model.__name__,
             "description": class_description,
             "parameters": {
                 "type": "object",

src/llama_cpp_agent/llm_agent.py

+9 -8

@@ -21,7 +21,8 @@
     function_calling_function_list_templater, structured_output_templater, \
     structured_output_thoughts_and_reasoning_templater

-from .providers.provider_base import LlmProvider, LlmSamplingSettings
+from .providers.provider_base import LlmProvider, LlmSamplingSettings, LlmProviderId
+

 class SystemPromptModulePosition(Enum):
     after_system_instructions = 1
@@ -195,7 +196,7 @@ def stream_results():
                     yield out_text

             return structured_output_settings.handle_structured_output(
-                full_response_stream
+                full_response_stream, provider=self.provider
             )

         if llm_sampling_settings.is_streaming():
@@ -219,7 +220,7 @@ def stream_results():
                 print("")
             self.last_response = full_response
             return structured_output_settings.handle_structured_output(
-                full_response
+                full_response, provider=self.provider
             )
         else:
             full_response = ""
@@ -229,7 +230,7 @@ def stream_results():
                 print(full_response)
             self.last_response = full_response
             return structured_output_settings.handle_structured_output(
-                full_response
+                full_response, provider=self.provider
             )
         return "Error: No model loaded!"

@@ -322,7 +323,7 @@ def stream_results():
                 }
             )
             return structured_output_settings.handle_structured_output(
-                full_response_stream, prompt_suffix=prompt_suffix
+                full_response_stream, prompt_suffix=prompt_suffix, provider=self.provider
             )

         if self.provider:
@@ -358,7 +359,7 @@ def stream_results():
                 )

                 return structured_output_settings.handle_structured_output(
-                    full_response, prompt_suffix=prompt_suffix
+                    full_response, prompt_suffix=prompt_suffix, provider=self.provider
                 )
             else:
                 text = completion["choices"][0]["text"]
@@ -377,7 +378,7 @@ def stream_results():
                     }
                 )

-                return structured_output_settings.handle_structured_output(text, prompt_suffix=prompt_suffix)
+                return structured_output_settings.handle_structured_output(text, prompt_suffix=prompt_suffix, provider=self.provider)
         return "Error: No model loaded!"

     def get_text_completion(
@@ -645,7 +646,7 @@ def get_response_role_and_completion(

         return (
             self.provider.create_completion(
-                prompt,
+                prompt if self.provider.get_provider_identifier() is not LlmProviderId.groq else messages,
                 structured_output_settings,
                 llm_sampling_settings,
                 self.messages_formatter.bos_token,
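The last hunk is the agent-side core of the Groq integration: every other provider still receives the flattened prompt string produced by the messages formatter, while a provider whose identifier is LlmProviderId.groq is handed the raw chat messages, since its OpenAI-compatible endpoint expects role/content dicts. The remaining hunks simply thread provider=self.provider into every handle_structured_output call. A rough sketch of the two input shapes (illustrative values only):

# What most providers receive: one formatted string, roughly Mistral-style here
prompt = "[INST] You are a helpful assistant.\n\nHello! [/INST]"

# What the Groq provider receives instead: OpenAI-style chat messages
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]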

src/llama_cpp_agent/llm_output_settings/settings.py

+3 -1

@@ -115,6 +115,7 @@ class LlmStructuredOutputSettings(BaseModel):
         False,
         description="If the output should be a tuple of the output and the generated JSON string by the LLM",
     )
+
     class Config:
         arbitrary_types_allowed = True

@@ -616,7 +617,8 @@ def add_all_current_functions_to_heartbeat_list(self, excluded: list[str] = None
             [tool.model.__name__ for tool in self.function_tools if tool.model.__name__ not in excluded]
         )

-    def handle_structured_output(self, llm_output: str, prompt_suffix: str = None):
+    def handle_structured_output(self, llm_output: str, prompt_suffix: str = None, provider=None):
+
         if self.output_raw_json_string:
             return llm_output
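In the hunk shown here the new provider parameter of handle_structured_output is accepted but not yet read; presumably it gives the structured-output code a hook for provider-aware handling, since the Groq provider below already returns tool calls pre-serialized as JSON text rather than as grammar-constrained completions.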

src/llama_cpp_agent/providers/groq.py

+209 (new file)

@@ -0,0 +1,209 @@
import json
from copy import copy, deepcopy
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Union

from llama_cpp_agent.llm_output_settings import (
    LlmStructuredOutputSettings,
    LlmStructuredOutputType,
)
from llama_cpp_agent.providers.provider_base import (
    LlmProvider,
    LlmProviderId,
    LlmSamplingSettings,
)


@dataclass
class GroqSamplingSettings(LlmSamplingSettings):
    """
    GroqSamplingSettings dataclass
    """

    top_p: float = 1
    temperature: float = 0.7
    max_tokens: int = 16
    stream: bool = False

    def get_provider_identifier(self) -> LlmProviderId:
        return LlmProviderId.groq

    def get_additional_stop_sequences(self) -> Union[List[str], None]:
        return None

    def add_additional_stop_sequences(self, sequences: List[str]):
        pass

    def is_streaming(self):
        return self.stream

    @staticmethod
    def load_from_dict(settings: dict) -> "GroqSamplingSettings":
        """
        Load the settings from a dictionary.

        Args:
            settings (dict): The dictionary containing the settings.

        Returns:
            LlamaCppSamplingSettings: The loaded settings.
        """
        return GroqSamplingSettings(**settings)

    def as_dict(self) -> dict:
        """
        Convert the settings to a dictionary.

        Returns:
            dict: The dictionary representation of the settings.
        """
        return self.__dict__


class GroqProvider(LlmProvider):
    def __init__(self, base_url: str, model: str, huggingface_model: str, api_key: str = None):
        from openai import OpenAI
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(huggingface_model)
        self.client = OpenAI(
            base_url=base_url,
            api_key=api_key if api_key else "xxx-xxxxxxxx",
        )
        self.model = model

    def is_using_json_schema_constraints(self):
        return True

    def get_provider_identifier(self) -> LlmProviderId:
        return LlmProviderId.groq

    def get_provider_default_settings(self) -> GroqSamplingSettings:
        return GroqSamplingSettings()

    def create_completion(
        self,
        prompt: str | list[dict],
        structured_output_settings: LlmStructuredOutputSettings,
        settings: GroqSamplingSettings,
        bos_token: str,
    ):
        tools = None
        if (
            structured_output_settings.output_type
            == LlmStructuredOutputType.function_calling
            or structured_output_settings.output_type == LlmStructuredOutputType.parallel_function_calling
        ):
            tools = [tool.to_openai_tool() for tool in structured_output_settings.function_tools]
        top_p = settings.top_p
        stream = settings.stream
        temperature = settings.temperature
        max_tokens = settings.max_tokens

        settings_dict = deepcopy(settings.as_dict())
        settings_dict.pop("top_p")
        settings_dict.pop("stream")
        settings_dict.pop("temperature")
        settings_dict.pop("max_tokens")

        if settings.stream:
            result = self.client.chat.completions.create(
                messages=prompt,
                model=self.model,
                extra_body=settings_dict,
                tools=tools,
                top_p=top_p,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens,
            )

            def generate_chunks():
                for chunk in result:
                    if chunk.choices[0].delta.tool_calls is not None:
                        if tools is not None:
                            args = chunk.choices[0].delta.tool_calls[0].function.arguments
                            args_loaded = json.loads(args)
                            function_name = chunk.choices[0].delta.tool_calls[0].function.name
                            function_dict = {structured_output_settings.function_calling_name_field_name: function_name, structured_output_settings.function_calling_content: args_loaded}
                            yield {"choices": [{"text": json.dumps(function_dict)}]}
                    if chunk.choices[0].delta.content is not None:
                        yield {"choices": [{"text": chunk.choices[0].delta.content}]}

            return generate_chunks()
        else:
            result = self.client.chat.completions.create(
                messages=prompt,
                model=self.model,
                extra_body=settings_dict,
                tools=tools,
                top_p=top_p,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            if tools is not None:
                args = result.choices[0].message.tool_calls[0].function.arguments
                args_loaded = json.loads(args)
                function_name = result.choices[0].message.tool_calls[0].function.name
                function_dict = {structured_output_settings.function_calling_name_field_name: function_name, structured_output_settings.function_calling_content: args_loaded}
                return {"choices": [{"text": json.dumps(function_dict)}]}
            return {"choices": [{"text": result.choices[0].message.content}]}

    def create_chat_completion(
        self,
        messages: List[Dict[str, str]],
        structured_output_settings: LlmStructuredOutputSettings,
        settings: GroqSamplingSettings
    ):
        grammar = None
        if (
            structured_output_settings.output_type
            != LlmStructuredOutputType.no_structured_output
        ):
            grammar = structured_output_settings.get_json_schema()

        top_p = settings.top_p
        stream = settings.stream
        temperature = settings.temperature
        max_tokens = settings.max_tokens

        settings_dict = copy(settings.as_dict())
        settings_dict.pop("top_p")
        settings_dict.pop("stream")
        settings_dict.pop("temperature")
        settings_dict.pop("max_tokens")
        if grammar is not None:
            settings_dict["guided_json"] = grammar

        if settings.stream:
            result = self.client.chat.completions.create(
                messages=messages,
                model=self.model,
                extra_body=settings_dict,
                top_p=top_p,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens,
            )

            def generate_chunks():
                for chunk in result:
                    if chunk.choices[0].delta.content is not None:
                        yield {"choices": [{"text": chunk.choices[0].delta.content}]}

            return generate_chunks()
        else:
            result = self.client.chat.completions.create(
                messages=messages,
                model=self.model,
                extra_body=settings_dict,
                top_p=top_p,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return {"choices": [{"text": result.choices[0].message.content}]}

    def tokenize(self, prompt: str) -> list[int]:
        result = self.tokenizer.encode(text=prompt)
        return result
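To sanity-check the new provider in isolation, here is a minimal sketch that reuses only names defined in this commit (the model, tokenizer repo, and key are placeholders):

from llama_cpp_agent.providers.groq import GroqProvider, GroqSamplingSettings
from llama_cpp_agent.providers.provider_base import LlmProviderId

provider = GroqProvider(
    base_url="https://api.groq.com/openai/v1",
    model="mixtral-8x7b-32768",
    huggingface_model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    api_key="YOUR_GROQ_API_KEY",  # placeholder
)

# The enum member added to LlmProviderId in provider_base.py is what llm_agent.py
# branches on when deciding to pass chat messages instead of a prompt string.
assert provider.get_provider_identifier() is LlmProviderId.groq

# Sampling settings round-trip through plain dicts.
settings = GroqSamplingSettings.load_from_dict({"temperature": 0.3, "max_tokens": 256, "stream": False})
print(settings.as_dict())

# Token counting goes through the Hugging Face tokenizer passed to the constructor.
print(len(provider.tokenize("Hello from the Groq provider")))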

src/llama_cpp_agent/providers/provider_base.py

+10 -9

@@ -22,6 +22,7 @@ class LlmProviderId(Enum):
     llama_cpp_python = "llama_cpp_python"
     tgi_server = "text_generation_inference"
     vllm_server = "vllm"
+    groq = "groq"


 class LlmSamplingSettings(ABC):
@@ -146,11 +147,11 @@ def get_provider_default_settings(self) -> LlmSamplingSettings:

     @abstractmethod
     def create_completion(
-            self,
-            prompt: str,
-            structured_output_settings: LlmStructuredOutputSettings,
-            settings: LlmSamplingSettings,
-            bos_token: str,
+        self,
+        prompt: str | list[dict],
+        structured_output_settings: LlmStructuredOutputSettings,
+        settings: LlmSamplingSettings,
+        bos_token: str,
     ):
         """
         Create a completion request with the LLM provider and returns the result.
@@ -168,10 +169,10 @@ def create_completion(

     @abstractmethod
     def create_chat_completion(
-            self,
-            messages: List[Dict[str, str]],
-            structured_output_settings: LlmStructuredOutputSettings,
-            settings: LlmSamplingSettings
+        self,
+        messages: List[Dict[str, str]],
+        structured_output_settings: LlmStructuredOutputSettings,
+        settings: LlmSamplingSettings
     ):
         """
         Create a chat completion request with the LLM provider and returns the result.
