Commit 29086f5

Merge pull request #272 from rhesis-ai/feature/add-llm-providers

Extend LLMService:
- LiteLLM support
- HuggingFace support

2 parents 2b8a32f + f63eb40

12 files changed: 950 additions & 56 deletions

sdk/pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,8 @@ dependencies = [
     "tomli>=2.2.1",
     "tomli-w>=1.2.0",
     "litellm>=1.76.0",
+    "torch>=2.8.0",
+    "transformers>=4.56.0",
 ]
 
 [project.license]

sdk/src/rhesis/sdk/errors.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Error messages used in the SDK. This file makes it easier to find and edit the used Error messages
+
+# For this file too long lines are fine for better readability
+# flake8: noqa: E501
+
+
+# LLM Errors
+NO_MODEL_NAME_PROVIDED = "The model name is not valid. Please provide a non-empty string."
+HUGGINGFACE_MODEL_NOT_LOADED = "Hugging Face model is not loaded. Set auto_loading=True or load it manually using `load_model()`."
+MODEL_RELOAD_WARNING = "WARNING: The model {} is already loaded. It will be reloaded."
+WARNING_TOKENIZER_ALREADY_LOADED_RELOAD = (
+    "WARNING: The tokenizer for model {} is already loaded. It will be reloaded."
+)
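
The two WARNING constants are format templates with a single {} placeholder for the model name, while the plain messages are passed to exceptions unchanged. A minimal sketch of the intended use (the model name below is only illustrative):

from rhesis.sdk.errors import MODEL_RELOAD_WARNING, NO_MODEL_NAME_PROVIDED

# Templates carry a single {} placeholder for the model name.
print(MODEL_RELOAD_WARNING.format("crumb/nano-mistral"))

# The non-template messages are raised verbatim by the providers, e.g.:
# raise ValueError(NO_MODEL_NAME_PROVIDED)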

sdk/src/rhesis/sdk/models/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,7 @@
 from rhesis.sdk.models.providers.gemini import GeminiLLM
+from rhesis.sdk.models.providers.huggingface import HuggingFaceLLM
+from rhesis.sdk.models.providers.litellm import LiteLLM
 from rhesis.sdk.models.providers.native import RhesisLLM
 from rhesis.sdk.models.providers.openai import OpenAILLM
 
-__all__ = ["RhesisLLM", "GeminiLLM", "OpenAILLM"]
+__all__ = ["RhesisLLM", "HuggingFaceLLM", "LiteLLM", "GeminiLLM", "OpenAILLM"]
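
With the extended `__all__`, the new providers are importable directly from `rhesis.sdk.models` alongside the existing ones; a minimal sketch (the model strings are placeholders):

from rhesis.sdk.models import GeminiLLM, HuggingFaceLLM, LiteLLM, OpenAILLM, RhesisLLM

remote = LiteLLM("provider/model")            # any provider/model string LiteLLM accepts
local = HuggingFaceLLM("crumb/nano-mistral")  # loads the model and tokenizer eagerly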

sdk/src/rhesis/sdk/models/providers/gemini.py

Lines changed: 5 additions & 34 deletions
@@ -9,21 +9,15 @@
 
 """
 
-import json
 import os
-from typing import Optional, Union
 
-from litellm import completion
-from pydantic import BaseModel
-
-from rhesis.sdk.models.base import BaseLLM
-from rhesis.sdk.models.utils import validate_llm_response
+from rhesis.sdk.models.providers.litellm import LiteLLM
 
 PROVIDER = "gemini"
 DEFAULT_MODEL_NAME = "gemini-2.0-flash"
 
 
-class GeminiLLM(BaseLLM):
+class GeminiLLM(LiteLLM):
     def __init__(self, model_name: str = DEFAULT_MODEL_NAME, api_key=None, **kwargs):
         """
         GeminiLLM: Google Gemini LLM Provider
@@ -47,30 +41,7 @@ def __init__(self, model_name: str = DEFAULT_MODEL_NAME, api_key=None, **kwargs)
         Raises:
             ValueError: If the API key is not set.
         """
-        self.api_key = api_key or os.getenv("GEMINI_API_KEY")
-        if self.api_key is None:
+        api_key = api_key or os.getenv("GEMINI_API_KEY")
+        if api_key is None:
             raise ValueError("GEMINI_API_KEY is not set")
-        super().__init__(model_name)
-
-    def load_model(self, *args, **kwargs):
-        return None  # LiteLLM handles model loading internally
-
-    def generate(
-        self, prompt: str, schema: Optional[BaseModel] = None, *args, **kwargs
-    ) -> Union[str, dict]:
-        messages = [{"role": "user", "content": prompt}]
-        response = completion(
-            model=f"{PROVIDER}/{self.model_name}",
-            messages=messages,
-            response_format=schema,
-            api_key=self.api_key,
-            *args,
-            **kwargs,
-        )
-        response_content = response.choices[0].message.content
-        if schema:
-            response_content = json.loads(response_content)
-            validate_llm_response(response_content, schema)
-            return response_content
-        else:
-            return response_content
+        super().__init__(PROVIDER + "/" + model_name, api_key=api_key)
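
GeminiLLM is now a thin wrapper that prefixes the model name with "gemini/" and delegates everything else to LiteLLM. A minimal sketch, assuming GEMINI_API_KEY is set in the environment (or passed via api_key=):

from rhesis.sdk.models import GeminiLLM

llm = GeminiLLM()  # defaults to "gemini-2.0-flash", resolved to "gemini/gemini-2.0-flash"
print(llm.generate("Tell me a joke.", system_prompt="You are funny"))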

sdk/src/rhesis/sdk/models/providers/huggingface.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
+import gc
+from typing import Optional
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from rhesis.sdk.errors import (
+    HUGGINGFACE_MODEL_NOT_LOADED,
+    MODEL_RELOAD_WARNING,
+    NO_MODEL_NAME_PROVIDED,
+    WARNING_TOKENIZER_ALREADY_LOADED_RELOAD,
+)
+from rhesis.sdk.models.base import BaseLLM
+
+
+class HuggingFaceLLM(BaseLLM):
+    """
+    A standard implementation of a model available on Hugging Face's model hub.
+    This class provides a basic structure for loading and using models from Hugging Face.
+    It can be extended to include specific models or configurations as needed.
+    A complete implementation may be needed for unusual models or configurations.
+    Example usage:
+        >>> llm = HuggingFaceLLM("crumb/nano-mistral")
+        >>> result = llm.generate("Tell me a joke.")
+        >>> print(result)
+    """
+
+    def __init__(
+        self, model_name: str, auto_loading: bool = True, default_kwargs: Optional[dict] = None
+    ):
+        """
+        Initialize the model with the given name and location.
+        Args:
+            model_name: The location to pull the model from
+            auto_loading: Whether to automatically load the model on initialization.
+                If turned off, manual loading is needed. Allows lazy loading.
+        """
+        if not model_name or not isinstance(model_name, str) or model_name.strip() == "":
+            raise ValueError(NO_MODEL_NAME_PROVIDED)
+
+        self.model_name = model_name
+        self.default_kwargs = default_kwargs
+
+        self.model = None
+        self.tokenizer = None
+        self.device = None
+
+        if auto_loading:
+            (self.model, self.tokenizer, self.device) = self.load_model()
+
+    def __del__(self):
+        """
+        If the model or tokenizer is loaded, unload them to free up resources.
+        Unloading manually is cleaner, but this is a fallback.
+        """
+        if self.model is not None or self.tokenizer is not None:
+            self.unload_model()
+
+    def load_model(self):
+        """
+        Load the model and tokenizer from the specified location.
+        """
+        if self.model is not None:
+            print(MODEL_RELOAD_WARNING.format(self.model_name))
+        if self.tokenizer is not None:
+            print(WARNING_TOKENIZER_ALREADY_LOADED_RELOAD.format(self.model_name))
+
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+        )
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model.to(device)
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_name,
+        )
+
+        return model, tokenizer, device
+
+    def unload_model(self):
+        """
+        Aggressively unload the model and tokenizer to free up GPU/CPU memory.
+        This handles edge cases such as partial allocations and hanging references.
+        """
+        # Unload model
+        try:
+            if self.model is not None:
+                try:
+                    self.model.cpu()
+                except Exception:
+                    pass
+                try:
+                    # Clear state_dict if available
+                    if hasattr(self.model, "state_dict"):
+                        sd = self.model.state_dict()
+                        for k in list(sd.keys()):
+                            sd.pop(k)
+                        del sd
+                except Exception:
+                    pass
+                del self.model
+                self.model = None
+        except Exception:
+            pass
+
+        # Unload tokenizer
+        try:
+            if self.tokenizer is not None:
+                try:
+                    if hasattr(self.tokenizer, "backend_tokenizer"):
+                        self.tokenizer.backend_tokenizer = None
+                except Exception:
+                    pass
+                del self.tokenizer
+                self.tokenizer = None
+        except Exception:
+            pass
+
+        # Force cleanup
+        torch.cuda.empty_cache()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def generate(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        **kwargs,
+    ) -> str:
+        """
+        Generate a response from the model
+        """
+
+        # check model and tokenizer
+        if self.model is None or self.tokenizer is None:
+            raise RuntimeError(HUGGINGFACE_MODEL_NOT_LOADED)
+
+        # format arguments
+        if self.default_kwargs:
+            kwargs = {**self.default_kwargs, **kwargs}
+
+        if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template is not None:
+            messages = (
+                [
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": prompt},
+                ]
+                if system_prompt
+                else [
+                    {"role": "user", "content": prompt},
+                ]
+            )
+            inputs = self.tokenizer.apply_chat_template(
+                messages, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+            ).to(self.device)
+        else:
+            messages = f"{system_prompt}\n\n{prompt}" if system_prompt else prompt
+            inputs = self.tokenizer(messages, return_tensors="pt").to(self.device)
+
+        # generate response
+        output_ids = self.model.generate(
+            **inputs,
+            pad_token_id=self.tokenizer.eos_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+            **kwargs,
+        )
+
+        completion = self.tokenizer.decode(
+            output_ids[0][inputs["input_ids"].shape[1] :],  # only take the newly generated content
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=True,
+        ).strip()
+
+        return completion
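
A hedged usage sketch for the class above. The eager path mirrors the docstring example; the lazy path is an assumption drawn from __init__, since load_model() returns the (model, tokenizer, device) tuple rather than setting the attributes itself, and max_new_tokens is only an illustrative transformers generation argument:

from rhesis.sdk.models import HuggingFaceLLM

# Eager loading (default): model, tokenizer and device are set during __init__.
llm = HuggingFaceLLM("crumb/nano-mistral", default_kwargs={"max_new_tokens": 64})
print(llm.generate("Tell me a joke.", system_prompt="You are funny"))

# Lazy loading: generate() raises RuntimeError until the model is loaded.
lazy = HuggingFaceLLM("crumb/nano-mistral", auto_loading=False)
lazy.model, lazy.tokenizer, lazy.device = lazy.load_model()
print(lazy.generate("Tell me a joke."))

lazy.unload_model()  # frees GPU/CPU memory explicitly instead of waiting for __del__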

sdk/src/rhesis/sdk/models/providers/litellm.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+import json
+from typing import Optional
+
+from litellm import completion
+from pydantic import BaseModel
+
+from rhesis.sdk.errors import NO_MODEL_NAME_PROVIDED
+from rhesis.sdk.models.base import BaseLLM
+from rhesis.sdk.models.utils import validate_llm_response
+
+
+class LiteLLM(BaseLLM):
+    def __init__(self, model_name: str, api_key: Optional[str] = None):
+        """
+        LiteLLM: LiteLLM Provider for Model inference
+
+        This class provides an interface for interacting with all models accessible through LiteLLM.
+
+        Args:
+            model_name (str): The name of the model to use including the provider.
+            api_key (Optional[str]): The API key for authentication.
+                If not provided, LiteLLM will handle it internally.
+
+        Usage:
+            >>> llm = LiteLLM(model_name="provider/model", api_key="your_api_key")
+            >>> result = llm.generate(prompt="Tell me a joke.", system_prompt="You are funny")
+            >>> print(result)
+
+        If a Pydantic schema is provided to `generate`, the response will be validated and returned
+        as a dict.
+        """
+        self.api_key = api_key  # LiteLLM will handle Environment Retrieval
+        if not model_name or not isinstance(model_name, str) or model_name.strip() == "":
+            raise ValueError(NO_MODEL_NAME_PROVIDED)
+        super().__init__(model_name)
+
+    def load_model(self):
+        """
+        LiteLLM handles model loading internally, so no loading is needed
+        """
+        pass
+
+    def generate(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        schema: Optional[BaseModel] = None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Run a chat completion using LiteLLM, returning the response.
+        The schema will be used to validate the response if provided.
+        """
+        # handle system prompt
+        messages = (
+            [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
+            if system_prompt
+            else [{"role": "user", "content": prompt}]
+        )
+
+        # Call the completion function passing given arguments
+        response = completion(
+            model=self.model_name,
+            messages=messages,
+            response_format=schema,
+            api_key=self.api_key,
+            *args,
+            **kwargs,
+        )
+
+        response_content = response.choices[0].message.content
+        if schema:
+            response_content = json.loads(response_content)
+            validate_llm_response(response_content, schema)
+            return response_content
+        else:
+            return response_content
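
A sketch of the structured-output path described in the docstring: when a Pydantic schema is passed, the completion is parsed with json.loads, validated, and returned as a dict. The model string and schema below are illustrative:

from pydantic import BaseModel

from rhesis.sdk.models import LiteLLM


class Joke(BaseModel):
    setup: str
    punchline: str


llm = LiteLLM("gemini/gemini-2.0-flash")  # any "provider/model" string LiteLLM accepts
result = llm.generate("Tell me a joke.", schema=Joke)
print(result["punchline"])  # result is a validated dict, not a raw string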

sdk/src/rhesis/sdk/models/providers/openai.py

Lines changed: 12 additions & 13 deletions
@@ -1,18 +1,17 @@
-from litellm import completion
+"""
+Only partially supported. No testing done yet
+"""
 
-from rhesis.sdk.models.base import BaseLLM
+import os
 
+from rhesis.sdk.models.providers.litellm import LiteLLM
 
-class OpenAILLM(BaseLLM):
-    def load_model(self, *args, **kwargs):
-        return None  # LiteLLM handles model loading internally
+DEFAULT_MODEL_NAME = "gpt-4"
 
-    def generate(self, prompt: str, *args, **kwargs) -> str:
-        messages = [{"role": "user", "content": prompt}]
-        response = completion(model=self.model_name, messages=messages, *args, **kwargs)
-        return response.choices[0].message.content
 
-
-if __name__ == "__main__":
-    openai = OpenAILLM(model_name="gpt-4")
-    print(openai.generate("Hello, how are you?"))
+class OpenAILLM(LiteLLM):
+    def __init__(self, model_name=DEFAULT_MODEL_NAME, api_key=None):
+        api_key = api_key or os.getenv("OPENAI_API_KEY")
+        if api_key is None:
+            raise ValueError("OPENAI_API_KEY is not set")
+        super().__init__(model_name, api_key=api_key)
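
The OpenAI provider becomes another thin LiteLLM wrapper (flagged above as only partially supported). A minimal sketch, assuming OPENAI_API_KEY is set:

from rhesis.sdk.models import OpenAILLM

llm = OpenAILLM()  # defaults to "gpt-4"; raises ValueError if OPENAI_API_KEY is unset
print(llm.generate("Hello, how are you?"))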
