
Commit ede63fc

Refactor LLM providers for consistent interface and enhanced functionality
- Update all provider classes to use a unified `generate` method signature
- Add support for system prompts and pre-defined message lists
- Implement flexible kwargs handling for provider-specific options
- Replace `max_tokens` and `temperature` params with a more generic approach
- Update type hints and imports across all provider files
- Improve error handling and default values in provider initializations
- Standardize payload construction and API call patterns
- Remove unused parameters and simplify client instantiations
1 parent b4deedc commit ede63fc
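
For orientation, here is a minimal sketch of the unified call pattern this commit introduces, written against the OllamaClient changed below. The model tag, prompts, and option values are illustrative assumptions, not part of the commit.

from llmdk.providers.ollama import OllamaClient

# base_url is now optional; it falls back to OLLAMA_API_URL, then localhost.
client = OllamaClient(
    model_name='llama3.1',             # hypothetical model tag
    options={'temperature': 0.2},      # constructor-level default options
)

# Plain prompt, with an optional system prompt prepended automatically.
text = client.generate(
    'Explain kwargs merging in two sentences.',
    system_prompt='Answer tersely.',
)

# Or pass a pre-built message list; per-call options merge over the defaults.
text = client.generate(
    '',                                # prompt is unused when messages is given
    messages=[{'role': 'user', 'content': 'Hello!'}],
    options={'num_predict': 64},
)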

7 files changed: +152 -121 lines changed


llmdk/providers/anthropic.py

Lines changed: 25 additions & 20 deletions
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-

 from os import environ as env
-from typing import Any, Optional
+from typing import Any, Optional, List, Dict

 from anthropic import Anthropic

@@ -30,24 +30,29 @@ def __init__(
     def generate(
         self,
         prompt: str,
-        temperature: Optional[float] = None,
-        max_tokens: Optional[int] = None,
+        system_prompt: Optional[str] = None,
+        messages: Optional[List[Dict[str, str]]] = None,
+        **kwargs: Any,
     ) -> str:
-        payload = {
-            'model': self._model_name,
-            'messages': [{
-                "role": "user",
-                "content": prompt,
-            }],
-        }
-
-        if temperature is not None:
-            payload['temperature'] = temperature
-
-        # Required by Anthropic
-        if max_tokens is None:
-            max_tokens = 4096
-        payload['max_tokens'] = max_tokens
-
-        message = self._client.messages.create(**payload).content[0].text
+        payload = self._generate_kwargs.copy()
+        payload.update(kwargs)
+        payload['model'] = self._model_name
+        payload['max_tokens'] = 4096  # Required
+
+        if messages is not None:
+            payload['messages'] = messages
+        else:
+            payload['messages'] = []
+            if system_prompt:
+                payload['messages'].append({
+                    'role': 'system',
+                    'content': system_prompt,
+                })
+            payload['messages'].append({
+                'role': 'user',
+                'content': prompt,
+            })
+
+        completion = self._client.messages.create(**payload)
+        message = completion.content[0].text
         return message
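
One consequence of the ordering above: max_tokens is assigned after the kwargs merge, so the hard-coded 4096 takes precedence over any max_tokens supplied through constructor kwargs or at the call site.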

llmdk/providers/groq.py

Lines changed: 24 additions & 23 deletions
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-

 from os import environ as env
-from typing import Any, Optional
+from typing import Any, Optional, List, Dict

 from groq import Groq

@@ -14,7 +14,6 @@ def __init__(
         self,
         model_name: str,
         api_key: Optional[str] = None,
-        base_url: Optional[str] = None,
         **kwargs: Any,
     ):
         super().__init__(model_name=model_name, **kwargs)
@@ -24,31 +23,33 @@ def __init__(

         self._client = Groq(
             api_key=api_key,
-            base_url=base_url,
         )

     def generate(
         self,
         prompt: str,
-        temperature: Optional[float] = None,
-        max_tokens: Optional[int] = None,
+        system_prompt: Optional[str] = None,
+        messages: Optional[List[Dict[str, str]]] = None,
+        **kwargs: Any,
     ) -> str:
-        payload = {
-            'model': self._model_name,
-            'messages': [{
-                "role": "user",
-                "content": prompt,
-            }],
-        }
-
-        if temperature is not None:
-            payload['temperature'] = temperature
-
-        if max_tokens is not None:
-            payload['max_tokens'] = max_tokens
-
-        message = self._client.chat.completions.create(
-            **payload
-        ).choices[0].message.content
-
+        payload = self._generate_kwargs.copy()
+        payload.update(kwargs)
+        payload['model'] = self._model_name
+
+        if messages is not None:
+            payload['messages'] = messages
+        else:
+            payload['messages'] = []
+            if system_prompt:
+                payload['messages'].append({
+                    'role': 'system',
+                    'content': system_prompt,
+                })
+            payload['messages'].append({
+                'role': 'user',
+                'content': prompt,
+            })
+
+        completion = self._client.chat.completions.create(**payload)
+        message = completion.choices[0].message.content
         return message

llmdk/providers/huggingface.py

Lines changed: 22 additions & 17 deletions
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-

 from os import environ as env
-from typing import Any, Optional
+from typing import Any, Optional, List, Dict

 from huggingface_hub import InferenceClient

@@ -22,29 +22,34 @@ def __init__(
         api_key = env.get('HF_TOKEN')

         self._client = InferenceClient(
-            model_name,
+            model=model_name,
             token=api_key,
         )

     def generate(
         self,
         prompt: str,
-        temperature: Optional[float] = None,
-        max_tokens: Optional[int] = None,
+        system_prompt: Optional[str] = None,
+        messages: Optional[List[Dict[str, str]]] = None,
+        **kwargs: Any,
     ) -> str:
-        payload = {
-            'messages': [{
-                "role": "user",
-                "content": prompt,
-            }],
-        }
-
-        if temperature is not None:
-            payload['temperature'] = temperature
-
-        max_tokens = max_tokens or self._max_tokens
-        if max_tokens is not None:
-            payload['max_tokens'] = max_tokens
+        payload = self._generate_kwargs.copy()
+        payload.update(kwargs)
+        payload['model'] = self._model_name
+
+        if messages is not None:
+            payload['messages'] = messages
+        else:
+            payload['messages'] = []
+            if system_prompt:
+                payload['messages'].append({
+                    'role': 'system',
+                    'content': system_prompt,
+                })
+            payload['messages'].append({
+                'role': 'user',
+                'content': prompt,
+            })

         completion = self._client.chat_completion(**payload)
         message = completion.choices[0].message.content

llmdk/providers/interface.py

Lines changed: 4 additions & 3 deletions
@@ -1,17 +1,17 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-

-from typing import Optional
+from typing import Any, Optional


 class LlmInterface:
     def __init__(
         self,
         model_name: str,
-        max_tokens: Optional[int] = None,
+        **kwargs: Any,
     ):
         self._model_name = model_name
-        self._max_tokens = max_tokens
+        self._generate_kwargs = kwargs

     @property
     def model_name(self) -> str:
@@ -22,5 +22,6 @@ def generate(
         prompt: str,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
+        **kwargs: Any,
     ) -> str:
         raise NotImplementedError
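
To make the base-class contract concrete, a minimal hypothetical provider following the pattern above could look like the sketch below. EchoClient and its echo behavior are invented for illustration; a real provider would call its SDK at the marked line.

from typing import Any, Dict, List, Optional

from llmdk.providers.interface import LlmInterface


class EchoClient(LlmInterface):
    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, str]]] = None,
        **kwargs: Any,
    ) -> str:
        # Constructor kwargs serve as defaults; call-site kwargs override them.
        payload = self._generate_kwargs.copy()
        payload.update(kwargs)
        payload['model'] = self._model_name

        if messages is not None:
            payload['messages'] = messages
        else:
            payload['messages'] = []
            if system_prompt:
                payload['messages'].append(
                    {'role': 'system', 'content': system_prompt})
            payload['messages'].append({'role': 'user', 'content': prompt})

        # A real provider would pass the payload to its SDK client here.
        return payload['messages'][-1]['content']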

llmdk/providers/ollama.py

Lines changed: 33 additions & 26 deletions
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-

 from os import environ as env
-from typing import Any, Optional
+from typing import Any, Optional, List, Dict

 from ollama import Client

@@ -12,48 +12,55 @@
 class OllamaClient(LlmInterface):
     def __init__(
         self,
-        base_url: str,
         model_name: str,
+        base_url: Optional[str] = None,
         headers: Optional[dict] = None,
         options: Optional[dict] = None,
         **kwargs: Any,
     ):
         super().__init__(model_name=model_name, **kwargs)

+        if not base_url:
+            base_url = env.get('OLLAMA_API_URL') or 'http://localhost:11434'
+
         self._client = Client(
             host=base_url,
-            headers=headers,
+            headers=headers or {},
         )

-        self._options = options
-        if self._options is None:
-            self._options = {}
+        self._options = options or {}

     def generate(
         self,
         prompt: str,
-        temperature: Optional[float] = None,
-        max_tokens: Optional[int] = None,
+        system_prompt: Optional[str] = None,
+        messages: Optional[List[Dict[str, str]]] = None,
+        options: Optional[dict] = None,
+        **kwargs: Any,
     ) -> str:
-        payload = {
-            'model': self._model_name,
-            'messages': [{
-                "role": "user",
-                "content": prompt,
-            }],
-        }
-
-        options = dict(self._options)
-        payload['options'] = options
-
-        if temperature is not None:
-            options['temperature'] = temperature
+        payload = self._generate_kwargs.copy()
+        payload.update(kwargs)
+        payload['model'] = self._model_name

-        if max_tokens is not None:
-            options['num_predict'] = max_tokens
+        if messages is not None:
+            payload['messages'] = messages
+        else:
+            payload['messages'] = []
+            if system_prompt:
+                payload['messages'].append({
+                    'role': 'system',
+                    'content': system_prompt,
+                })
+            payload['messages'].append({
+                'role': 'user',
+                'content': prompt,
+            })

-        message = self._client.chat(
-            **payload
-        )['message']['content']
+        merged_options = self._options.copy()
+        if options:
+            merged_options.update(options)
+        payload['options'] = merged_options

+        response = self._client.chat(**payload)
+        message = response['message']['content']
         return message
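
Note the merge direction in generate: a per-call options dict now layers over the constructor-level options instead of replacing them, so a call-site {'num_predict': 64} leaves a constructor-level temperature intact.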

llmdk/providers/openai.py

Lines changed: 21 additions & 17 deletions
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-

 from os import environ as env
-from typing import Any, Optional
+from typing import Any, Optional, List, Dict

 from openai import OpenAI

@@ -33,23 +33,27 @@ def __init__(
     def generate(
         self,
         prompt: str,
-        temperature: Optional[float] = None,
-        max_tokens: Optional[int] = None,
+        system_prompt: Optional[str] = None,
+        messages: Optional[List[Dict[str, str]]] = None,
+        **kwargs: Any,
     ) -> str:
-        payload = {
-            'model': self._model_name,
-            'messages': [{
-                "role": "user",
-                "content": prompt,
-            }],
-        }
-
-        if temperature is not None:
-            payload['temperature'] = temperature
-
-        max_tokens = max_tokens or self._max_tokens
-        if max_tokens is not None:
-            payload['max_tokens'] = max_tokens
+        payload = self._generate_kwargs.copy()
+        payload.update(kwargs)
+        payload['model'] = self._model_name
+
+        if messages is not None:
+            payload['messages'] = messages
+        else:
+            payload['messages'] = []
+            if system_prompt:
+                payload['messages'].append({
+                    'role': 'system',
+                    'content': system_prompt,
+                })
+            payload['messages'].append({
+                'role': 'user',
+                'content': prompt,
+            })

         completion = self._client.chat.completions.create(**payload)
         message = completion.choices[0].message.content
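
With the _max_tokens fallback removed from the interface, token limits and temperature now travel like any other provider option: through the constructor kwargs captured in _generate_kwargs, or through per-call kwargs.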
