
Commit b44299c
update model split tests with ui
1 parent: 66bfd70

8 files changed: +80 -3 lines

litellm/__init__.py (+1)

```diff
@@ -286,6 +286,7 @@ def identify(event_details):
     Logging,
     acreate,
     get_model_list,
+    completion_with_split_tests
 )
 from .main import *  # type: ignore
 from .integrations import *
```
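
With this re-export in place, the new helper is importable from the package root, which is what the new test below relies on:

```python
# Enabled by the export added above:
from litellm import completion_with_split_tests
```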

3 binary files changed (48 Bytes, 40 Bytes, 3.31 KB); contents not shown.

litellm/main.py (+2 -2)

```diff
@@ -94,7 +94,7 @@ def completion(
     custom_api_base=None,
     litellm_call_id=None,
     litellm_logging_obj=None,
-    completion_call_id=None, # this is an optional param to tag individual completion calls
+    id=None, # this is an optional param to tag individual completion calls
     # model specific optional params
     # used by text-bison only
     top_k=40,
@@ -154,7 +154,7 @@ def completion(
         custom_api_base=custom_api_base,
         litellm_call_id=litellm_call_id,
         model_alias_map=litellm.model_alias_map,
-        completion_call_id=completion_call_id
+        completion_call_id=id
     )
     logging.update_environment_variables(optional_params=optional_params, litellm_params=litellm_params)
     if custom_llm_provider == "azure":
```

litellm/tests/test_split_test.py (+24)

```diff
@@ -0,0 +1,24 @@
+#### What this tests ####
+# This tests the 'completion_with_split_tests' function to enable a/b testing between llm models
+
+import sys, os
+import traceback
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm import completion_with_split_tests
+litellm.set_verbose = True
+split_per_model = {
+    "gpt-4": 0.7,
+    "claude-instant-1.2": 0.3
+}
+
+messages = [{ "content": "Hello, how are you?","role": "user"}]
+
+# print(completion_with_split_tests(models=split_per_model, messages=messages))
+
+# test with client
+
+print(completion_with_split_tests(models=split_per_model, messages=messages, use_client=True, id=1234))
```
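
Under the hood (see `litellm/utils.py` below), the model is drawn with `random.choices`, so the values in `split_per_model` act as relative selection weights. A standalone sketch, using no network calls or API keys, showing how the 70/30 split above plays out over many draws:

```python
import random
from collections import Counter

split_per_model = {"gpt-4": 0.7, "claude-instant-1.2": 0.3}

# Draw 10,000 times the same way completion_with_split_tests does.
picks = Counter(
    random.choices(list(split_per_model.keys()),
                   weights=list(split_per_model.values()))[0]
    for _ in range(10_000)
)
for model, count in picks.most_common():
    print(f"{model}: {count / 10_000:.1%}")  # roughly 70% / 30%
```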

litellm/utils.py (+52)

```diff
@@ -1898,6 +1898,58 @@ async def stream_to_string(generator):
     return response


+########## experimental completion variants ############################
+
+def get_model_split_test(models, completion_call_id):
+    global last_fetched_at
+    try:
+        # make the api call
+        last_fetched_at = time.time()
+        print(f"last_fetched_at: {last_fetched_at}")
+        response = requests.post(
+            # http://api.litellm.ai
+            url="http://api.litellm.ai/get_model_split_test",  # get the updated dict from table or update the table with the dict
+            headers={"content-type": "application/json"},
+            data=json.dumps({"completion_call_id": completion_call_id, "models": models}),
+        )
+        print_verbose(f"get_model_split_test response: {response.text}")
+        data = response.json()
+        # update model list
+        split_test_models = data["split_test_models"]
+        # update environment - if required
+        threading.Thread(target=get_all_keys, args=()).start()
+        return split_test_models
+    except:
+        print_verbose(
+            f"[Non-Blocking Error] get_model_split_test error - {traceback.format_exc()}"
+        )
+
+
+def completion_with_split_tests(models={}, messages=[], use_client=False, **kwargs):
+    """
+    Example Usage:
+
+    models = {
+        "gpt-4": 0.7,
+        "huggingface/wizard-coder": 0.3
+    }
+    messages = [{ "content": "Hello, how are you?","role": "user"}]
+    completion_with_split_tests(models=models, messages=messages)
+    """
+    import random
+    if use_client:
+        if "id" not in kwargs or kwargs["id"] is None:
+            raise ValueError("Please tag this completion call, if you'd like to update its split test values through the UI - e.g. `completion_with_split_tests(.., id=1234)`.")
+        # get the most recent model split list from the server
+        models = get_model_split_test(models=models, completion_call_id=kwargs["id"])
+
+    try:
+        selected_llm = random.choices(list(models.keys()), weights=list(models.values()))[0]
+    except:
+        traceback.print_exc()
+        raise ValueError("""models does not follow the required format - {'model_name': 'split_percentage'}, e.g. {'gpt-4': 0.7, 'huggingface/wizard-coder': 0.3}""")
+    return litellm.completion(model=selected_llm, messages=messages, **kwargs)
+
 def completion_with_fallbacks(**kwargs):
     response = None
     rate_limited_models = set()
```
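
When `use_client=True`, the locally supplied weights are replaced by whatever the litellm server returns for this `completion_call_id`. A sketch of that round trip, reconstructed from `get_model_split_test` above (the server side is not part of this commit, so the response shape is inferred from the code):

```python
import json
import requests

# Same request shape as get_model_split_test above.
payload = {
    "completion_call_id": 1234,
    "models": {"gpt-4": 0.7, "claude-instant-1.2": 0.3},
}
response = requests.post(
    url="http://api.litellm.ai/get_model_split_test",
    headers={"content-type": "application/json"},
    data=json.dumps(payload),
)
# The response carries the (possibly UI-edited) weights, which replace
# the local dict, so split changes take effect on the next call.
split_test_models = response.json()["split_test_models"]
print(split_test_models)
```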

pyproject.toml (+1 -1)

```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.511"
+version = "0.1.512"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
```
