Commit b1726cf

Add streaming support and refactor client interface

- Implement streaming functionality for all providers
- Update README with streaming examples and improved usage instructions
- Refactor LlmInterface to support both generate and stream methods
- Remove vLLM provider
- Improve type hints and error handling
- Update client initialization in examples
1 parent c390365

9 files changed (+239 −89 lines)

README.md

Lines changed: 90 additions & 19 deletions

````diff
@@ -15,56 +15,127 @@
 pip install llmdk
 ```
 
-# Usage
+# Basic Usage
+
+## Client
+
+```python
+from llmdk import Llmdk, Providers
+
+# You can also set OPENAI_API_KEY
+client = Llmdk(
+    provider=Providers.OPENAI,
+    model_name='gpt-4o-mini',
+    # api_key='***',
+)
+```
+
+## Generate
+
+### Prompt
+
+```python
+output = client.generate(
+    'Who are you?',
+    # system='Write in Portuguese.',
+)
+```
+
+### List of messages
+
+```python
+output = client.generate(
+    messages=[
+        # {'role': 'system', 'content': 'Write in Portuguese.'},
+        {'role': 'user', 'content': 'Who are you?'},
+    ],
+)
+```
+
+## Stream
+
+### Prompt
+
+```python
+for chunk in client.stream(
+    'Who are you?',
+    # system='Write in Portuguese.',
+):
+    print(chunk, end='', flush=True)
+```
+
+### List of messages
+
+```python
+for chunk in client.stream([
+    # {'role': 'system', 'content': 'Write in Portuguese.'},
+    {'role': 'user', 'content': 'Who are you?'},
+]):
+    print(chunk, end='', flush=True)
+```
+
+# Supported Providers
 
 ## Anthropic
+
 ```python
 from llmdk import Llmdk, Providers
 
 # You can also set ANTHROPIC_API_KEY
-client = Llmdk(Providers.ANTHROPIC, 'claude-3-5-sonnet-20240620', api_key='***')
-output = client.generate('Who are you?')
+client = Llmdk(
+    provider=Providers.ANTHROPIC,
+    model_name='claude-3-5-sonnet-20240620',
+    # api_key='***',
+)
 ```
 
 ## Groq
+
 ```python
 from llmdk import Llmdk, Providers
 
 # You can also set GROQ_API_KEY
-client = Llmdk(Providers.GROQ, 'llama-3.1-70b-versatile', api_key='***')
-output = client.generate('Who are you?')
+client = Llmdk(
+    provider=Providers.GROQ,
+    model_name='llama-3.1-70b-versatile',
+    # api_key='***',
+)
 ```
 
 ## HuggingFace
+
 ```python
 from llmdk import Llmdk, Providers
 
 # You can also set HF_TOKEN
-client = Llmdk(Providers.HUGGINGFACE, 'meta-llama/Meta-Llama-3.1-70B-Instruct', api_key='***')
-output = client.generate('Who are you?')
+client = Llmdk(
+    provider=Providers.HUGGINGFACE,
+    model_name='meta-llama/Meta-Llama-3.1-70B-Instruct',
+    # api_key='***',
+)
 ```
 
 ## Ollama
+
 ```python
 from llmdk import Llmdk, Providers
 
-client = Llmdk(Providers.OLLAMA, 'llama3.1:8b', base_url='http://...')
-output = client.generate('Who are you?')
+client = Llmdk(
+    provider=Providers.OLLAMA,
+    model_name='llama3.2:1b',
+    # base_url='http://localhost:11434',
+)
 ```
 
 ## OpenAI
-```python
-from llmdk import Llmdk, Providers
 
-# You can also set OPENAI_API_KEY
-client = Llmdk(Providers.OPENAI, 'gpt-4o-2024-08-06', api_key='***')
-output = client.generate('Who are you?')
-```
-
-## vLLM
 ```python
 from llmdk import Llmdk, Providers
 
-client = Llmdk(Providers.VLLM, base_url='http://...')
-output = client.generate('Who are you?')
+# You can also set OPENAI_API_KEY
+client = Llmdk(
+    provider=Providers.OPENAI,
+    model_name='gpt-4o-mini',
+    # api_key='***',
+)
 ```
````
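
Since `stream` yields plain-text chunks, a full response can be reassembled without a second `generate` call; a minimal sketch, reusing the `client` from the examples above:

```python
# Collect the streamed chunks into a single string.
text = ''.join(client.stream('Who are you?'))
```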

llmdk/llmdk.py

Lines changed: 16 additions & 12 deletions

```diff
@@ -1,14 +1,14 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+from collections.abc import Iterator
 from enum import Enum
-
+from typing import Any, Dict, List, Optional
 from llmdk.providers.anthropic import AnthropicClient
 from llmdk.providers.groq import GroqClient
 from llmdk.providers.huggingface import HuggingFaceClient
 from llmdk.providers.ollama import OllamaClient
 from llmdk.providers.openai import OpenAiClient
-from llmdk.providers.vllm import VllmClient
 
 
 class Providers(Enum):
@@ -17,7 +17,6 @@ class Providers(Enum):
     HUGGINGFACE = 'huggingface'
     OLLAMA = 'ollama'
     OPENAI = 'openai'
-    VLLM = 'vllm'
 
 
 class Llmdk:
@@ -80,17 +79,22 @@ def __init__(
             )
             return
 
-        if (
-            provider == Providers.VLLM
-            or provider == Providers.VLLM.value
-        ):
-            self._client = VllmClient(
-                base_url=base_url,
-            )
-            return
-
         raise ValueError(f"Provider {provider} is not supported")
 
     # Fallback to the original client
     def __getattr__(self, name):
         return getattr(self._client, name)
+
+    def stream(
+        self,
+        prompt: str,
+        system: Optional[str] = None,
+        messages: Optional[List[Dict[str, str]]] = None,
+        **kwargs: Any,
+    ) -> Iterator[str]:
+        return self._client.stream(
+            prompt,
+            system=system,
+            messages=messages,
+            **kwargs,
+        )
```
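
The new top-level `stream` just forwards to the active provider client. `LlmInterface` itself is not shown in this excerpt; a minimal sketch of the base-class plumbing the provider diffs below imply, where `_build_payload` is a hypothetical helper name for turning the prompt, optional system message, or explicit message list into a chat payload:

```python
# Sketch only: the real LlmInterface is not part of this diff, and
# _build_payload is a hypothetical name. The shape is inferred from the
# provider subclasses, which implement _execute_request and
# _execute_stream_request against a prepared payload.
from collections.abc import Iterator
from typing import Any, Dict, List, Optional


class LlmInterface:
    def __init__(self, model_name: Optional[str] = None):
        self._model_name = model_name

    def _build_payload(
        self,
        prompt: Optional[str] = None,
        system: Optional[str] = None,
        messages: Optional[List[Dict[str, str]]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        # Explicit messages take precedence; otherwise build them from
        # the prompt and the optional system instruction.
        if messages is None:
            messages = []
            if system:
                messages.append({'role': 'system', 'content': system})
            messages.append({'role': 'user', 'content': prompt})
        return {'model': self._model_name, 'messages': messages, **kwargs}

    def _execute_request(self, payload: Dict[str, Any]) -> str:
        raise NotImplementedError

    def _execute_stream_request(self, payload: Dict[str, Any]) -> Iterator[str]:
        raise NotImplementedError

    def generate(
        self,
        prompt: Optional[str] = None,
        system: Optional[str] = None,
        messages: Optional[List[Dict[str, str]]] = None,
        **kwargs: Any,
    ) -> str:
        return self._execute_request(
            self._build_payload(prompt, system=system, messages=messages, **kwargs),
        )

    def stream(
        self,
        prompt: Optional[str] = None,
        system: Optional[str] = None,
        messages: Optional[List[Dict[str, str]]] = None,
        **kwargs: Any,
    ) -> Iterator[str]:
        return self._execute_stream_request(
            self._build_payload(prompt, system=system, messages=messages, **kwargs),
        )
```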

llmdk/providers/anthropic.py

Lines changed: 26 additions & 3 deletions

```diff
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 
 from os import environ as env
-from typing import Any, Optional, Dict
+from typing import Any, Iterator, Optional, Dict
 from anthropic import Anthropic
 from llmdk.providers.interface import LlmInterface
 
@@ -25,7 +25,30 @@ def __init__(
             base_url=base_url,
         )
 
+    def _prepare_payload(self, payload: Dict[str, Any]) -> Dict[str, Any]:
+        if 'max_tokens' not in payload:
+            payload['max_tokens'] = 4096  # Required by Anthropic
+
+        # Check for system message and move it to the system property
+        messages = payload.get('messages', [])
+        if messages and messages[0]['role'] == 'system':
+            payload['system'] = messages[0]['content']
+            payload['messages'] = messages[1:]
+
+        return payload
+
     def _execute_request(self, payload: Dict[str, Any]) -> str:
-        payload['max_tokens'] = 4096  # Required
-        completion = self._client.messages.create(**payload)
+        prepared_payload = self._prepare_payload(payload)
+        completion = self._client.messages.create(**prepared_payload)
         return completion.content[0].text
+
+    def _execute_stream_request(
+        self,
+        payload: Dict[str, Any],
+    ) -> Iterator[str]:
+        prepared_payload = self._prepare_payload(payload)
+
+        with self._client.messages.stream(**prepared_payload) as stream:
+            for message in stream:
+                if message.type == 'content_block_delta':
+                    yield message.delta.text
```
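
Anthropic's Messages API requires `max_tokens` and takes the system prompt as a top-level `system` field rather than as a message with role `'system'`; `_prepare_payload` now handles both in one place for `generate` and `stream`. A worked example of the transformation (calling the private helper directly just to show the shapes; values illustrative):

```python
# client is an AnthropicClient; _prepare_payload is internal and is
# called here only to illustrate the payload rewrite.
payload = {
    'model': 'claude-3-5-sonnet-20240620',
    'messages': [
        {'role': 'system', 'content': 'Write in Portuguese.'},
        {'role': 'user', 'content': 'Who are you?'},
    ],
}
prepared = client._prepare_payload(payload)
# prepared now carries:
#   'max_tokens': 4096  (default added because none was set)
#   'system': 'Write in Portuguese.'  (moved out of messages)
#   'messages': [{'role': 'user', 'content': 'Who are you?'}]
```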

llmdk/providers/groq.py

Lines changed: 10 additions & 1 deletion

```diff
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 
 from os import environ as env
-from typing import Any, Optional, Dict
+from typing import Any, Iterator, Optional, Dict
 from groq import Groq
 from llmdk.providers.interface import LlmInterface
 
@@ -26,3 +26,12 @@ def __init__(
     def _execute_request(self, payload: Dict[str, Any]) -> str:
         completion = self._client.chat.completions.create(**payload)
         return completion.choices[0].message.content
+
+    def _execute_stream_request(
+        self,
+        payload: Dict[str, Any],
+    ) -> Iterator[str]:
+        payload['stream'] = True
+        for chunk in self._client.chat.completions.create(**payload):
+            if chunk.choices[0].delta.content is not None:
+                yield chunk.choices[0].delta.content
```
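
With `stream=True` the Groq SDK returns an iterator of chunks instead of a single completion, and a chunk's `delta.content` can be `None`, hence the guard. Consumed through the top-level client it looks like the README example; a sketch, assuming `GROQ_API_KEY` is set in the environment:

```python
from llmdk import Llmdk, Providers

client = Llmdk(
    provider=Providers.GROQ,
    model_name='llama-3.1-70b-versatile',
)

# Chunks arrive as plain strings, already filtered of empty deltas.
for chunk in client.stream('Who are you?'):
    print(chunk, end='', flush=True)
```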

llmdk/providers/huggingface.py

Lines changed: 10 additions & 1 deletion

```diff
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 
 from os import environ as env
-from typing import Any, Optional, Dict
+from typing import Any, Iterator, Optional, Dict
 from huggingface_hub import InferenceClient
 from llmdk.providers.interface import LlmInterface
 
@@ -27,3 +27,12 @@ def __init__(
     def _execute_request(self, payload: Dict[str, Any]) -> str:
         completion = self._client.chat_completion(**payload)
         return completion.choices[0].message.content
+
+    def _execute_stream_request(
+        self,
+        payload: Dict[str, Any],
+    ) -> Iterator[str]:
+        payload['stream'] = True
+        output = self._client.chat.completions.create(**payload)
+        for chunk in output:
+            yield chunk.choices[0].delta.content
```
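
Unlike the Groq version, this loop does not guard against a `None` `delta.content`, so a chunk carrying no text would yield `None` into the stream. A None-guarded variant, mirroring the Groq implementation (a sketch, not part of this commit; same imports as the file above):

```python
def _execute_stream_request(
    self,
    payload: Dict[str, Any],
) -> Iterator[str]:
    payload['stream'] = True
    output = self._client.chat.completions.create(**payload)
    for chunk in output:
        content = chunk.choices[0].delta.content
        if content is not None:  # skip chunks that carry no text
            yield content
```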
