
Commit 4ac66bd

LiteLLM Minor Fixes and Improvements (09/07/2024) (BerriAI#5580)
* fix(litellm_logging.py): set completion_start_time_float to end_time_float if none. Fixes BerriAI#5500
* feat(__init__.py): add new 'openai_text_completion_compatible_providers' list. Fixes BerriAI#5558 by correctly routing Fireworks AI calls made via text completions
* fix: fix linting errors
* fix: fix linting errors
* fix(openai.py): fix exception raised
* fix(openai.py): fix error handling
* fix(_redis.py): allow all supported arguments for redis cluster (BerriAI#5554)
* Revert "fix(_redis.py): allow all supported arguments for redis cluster (BerriAI#5554)" (BerriAI#5583). This reverts commit f2191ef.
* fix(router.py): return model alias w/ underlying deployment on router.get_model_list(). Fixes BerriAI#5524 (comment)
* test: handle flaky tests

Co-authored-by: Jonas Dittrich <[email protected]>
1 parent c86b333 commit 4ac66bd

14 files changed: +101 −34 lines

litellm/__init__.py

+6-1
@@ -483,7 +483,12 @@ def identify(event_details):
     "azure_ai",
     "github",
 ]
-
+openai_text_completion_compatible_providers: List = (
+    [  # providers that support `/v1/completions`
+        "together_ai",
+        "fireworks_ai",
+    ]
+)
 
 # well supported replicate llms
 replicate_models: List = [
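Downstream code can treat the new list as a plain module-level constant when deciding whether a provider can be routed through `/v1/completions`. A minimal sketch of that membership check, assuming a litellm install that includes this commit:

import litellm

# membership check against the constant added above
print("fireworks_ai" in litellm.openai_text_completion_compatible_providers)  # True
print("together_ai" in litellm.openai_text_completion_compatible_providers)   # True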

litellm/litellm_core_utils/litellm_logging.py

+2
@@ -2329,6 +2329,8 @@ def get_standard_logging_object_payload(
         completion_start_time_float = completion_start_time.timestamp()
     elif isinstance(completion_start_time, float):
         completion_start_time_float = completion_start_time
+    else:
+        completion_start_time_float = end_time_float
     # clean up litellm hidden params
     clean_hidden_params = StandardLoggingHiddenParams(
         model_id=None,
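The new `else` branch guarantees `completion_start_time_float` is always populated: if neither a datetime nor a float start time was recorded, the payload falls back to the end time. A small self-contained sketch of the same normalization (the helper name here is hypothetical, not part of litellm):

from datetime import datetime
from typing import Optional, Union

def _normalize_completion_start(
    completion_start_time: Optional[Union[datetime, float]],
    end_time_float: float,
) -> float:
    # mirrors the branch above: datetime -> timestamp, float -> as-is, anything else -> end time
    if isinstance(completion_start_time, datetime):
        return completion_start_time.timestamp()
    elif isinstance(completion_start_time, float):
        return completion_start_time
    else:
        return end_time_float

assert _normalize_completion_start(None, end_time_float=1725700000.0) == 1725700000.0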

litellm/llms/OpenAI/openai.py

+2-2
@@ -1263,6 +1263,7 @@ async def async_streaming(
 
         error_headers = getattr(e, "headers", None)
         if response is not None and hasattr(response, "text"):
+            error_headers = getattr(e, "headers", None)
             raise OpenAIError(
                 status_code=500,
                 message=f"{str(e)}\n\nOriginal Response: {response.text}",
@@ -1800,12 +1801,11 @@ def completion(
         headers: Optional[dict] = None,
     ):
         super().completion()
-        exception_mapping_worked = False
         try:
             if headers is None:
                 headers = self.validate_environment(api_key=api_key)
             if model is None or messages is None:
-                raise OpenAIError(status_code=422, message=f"Missing model or messages")
+                raise OpenAIError(status_code=422, message="Missing model or messages")
 
             if (
                 len(messages) > 0

litellm/llms/azure_text.py

+9-5
@@ -162,11 +162,10 @@ def completion(
         client=None,
     ):
         super().completion()
-        exception_mapping_worked = False
         try:
             if model is None or messages is None:
                 raise AzureOpenAIError(
-                    status_code=422, message=f"Missing model or messages"
+                    status_code=422, message="Missing model or messages"
                 )
 
             max_retries = optional_params.pop("max_retries", 2)
@@ -293,7 +292,10 @@ def completion(
                     "api-version", api_version
                 )
 
-            response = azure_client.completions.create(**data, timeout=timeout)  # type: ignore
+            raw_response = azure_client.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )
+            response = raw_response.parse()
             stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
@@ -380,13 +382,15 @@ async def acompletion(
                     "complete_input_dict": data,
                 },
             )
-            response = await azure_client.completions.create(**data, timeout=timeout)
+            raw_response = await azure_client.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )
+            response = raw_response.parse()
             return openai_text_completion_config.convert_to_chat_model_response_object(
                 response_object=response.model_dump(),
                 model_response_object=model_response,
             )
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
             status_code = getattr(e, "status_code", 500)
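Both Azure text-completion paths now go through the OpenAI SDK's `with_raw_response` wrapper followed by `.parse()`, so the parsed `Completion` object is unchanged while the raw HTTP response (headers, status) stays reachable. A rough sketch of the pattern in isolation, with placeholder credentials and deployment name:

from openai import AzureOpenAI

client = AzureOpenAI(
    api_key="my-azure-key",                             # placeholder
    api_version="2024-02-01",                           # placeholder
    azure_endpoint="https://example.openai.azure.com",  # placeholder
)

raw_response = client.completions.with_raw_response.create(
    model="my-deployment", prompt="Hello", timeout=600.0
)
print(raw_response.headers)       # raw HTTP headers are now accessible
response = raw_response.parse()   # same typed object the plain .create() call returned
print(response.model_dump())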

litellm/main.py

+5-2
@@ -1209,6 +1209,9 @@ def completion(
             custom_llm_provider == "text-completion-openai"
             or "ft:babbage-002" in model
             or "ft:davinci-002" in model  # support for finetuned completion models
+            or custom_llm_provider
+            in litellm.openai_text_completion_compatible_providers
+            and kwargs.get("text_completion") is True
         ):
             openai.api_type = "openai"
 
@@ -4099,8 +4102,8 @@ def process_prompt(i, individual_prompt):
 
         kwargs.pop("prompt", None)
 
-        if (
-            _model is not None and custom_llm_provider == "openai"
+        if _model is not None and (
+            custom_llm_provider == "openai"
         ):  # for openai compatible endpoints - e.g. vllm, call the native /v1/completions endpoint for text completion calls
             if _model not in litellm.open_ai_chat_completion_models:
                 model = "text-completion-openai/" + _model
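Because `and` binds tighter than `or`, the new clause in the first hunk only takes effect when the provider is in `openai_text_completion_compatible_providers` and the caller actually passed `text_completion=True`; the existing conditions are untouched. A hypothetical helper showing the explicitly parenthesized equivalent:

import litellm

def _routes_to_openai_text_completion(
    custom_llm_provider: str, model: str, text_completion: bool
) -> bool:
    # explicitly parenthesized version of the condition in the hunk above
    return (
        custom_llm_provider == "text-completion-openai"
        or "ft:babbage-002" in model
        or "ft:davinci-002" in model
        or (
            custom_llm_provider in litellm.openai_text_completion_compatible_providers
            and text_completion is True
        )
    )

print(_routes_to_openai_text_completion("fireworks_ai", "llama-v3p1-8b-instruct", True))   # True
print(_routes_to_openai_text_completion("fireworks_ai", "llama-v3p1-8b-instruct", False))  # False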

litellm/proxy/_new_secret_config.yaml

+6-13
@@ -1,16 +1,9 @@
 model_list:
-  - model_name: "anthropic/claude-3-5-sonnet-20240620"
+  - model_name: "gpt-turbo"
     litellm_params:
-      model: anthropic/claude-3-5-sonnet-20240620
-      # api_base: http://0.0.0.0:9000
-  - model_name: gpt-3.5-turbo
-    litellm_params:
-      model: openai/*
+      model: azure/chatgpt-v-2
+      api_key: os.environ/AZURE_API_KEY
+      api_base: os.environ/AZURE_API_BASE
 
-litellm_settings:
-  success_callback: ["s3"]
-  s3_callback_params:
-    s3_bucket_name: litellm-logs # AWS Bucket Name for S3
-    s3_region_name: us-west-2 # AWS Region Name for S3
-    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
-    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
+router_settings:
+  model_group_alias: {"gpt-4": "gpt-turbo"}
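With this example config, a proxy client asks for `gpt-4` and the router serves the `gpt-turbo` (Azure) deployment via the `model_group_alias` mapping. A sketch of the client side, assuming the proxy runs locally on port 4000 and exposes its usual OpenAI-compatible endpoint:

from openai import OpenAI

# assumed local proxy address and key; adjust to your deployment
client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-4",  # alias only; the router resolves it to "gpt-turbo"
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)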

litellm/proxy/health_check.py

+23-1
@@ -3,7 +3,7 @@
 import asyncio
 import logging
 import random
-from typing import Optional
+from typing import List, Optional
 
 import litellm
 from litellm._logging import print_verbose
@@ -36,6 +36,25 @@ def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True):
     )
 
 
+def filter_deployments_by_id(
+    model_list: List,
+) -> List:
+    seen_ids = set()
+    filtered_deployments = []
+
+    for deployment in model_list:
+        _model_info = deployment.get("model_info") or {}
+        _id = _model_info.get("id") or None
+        if _id is None:
+            continue
+
+        if _id not in seen_ids:
+            seen_ids.add(_id)
+            filtered_deployments.append(deployment)
+
+    return filtered_deployments
+
+
 async def _perform_health_check(model_list: list, details: Optional[bool] = True):
     """
     Perform a health check for each model in the list.
@@ -105,6 +124,9 @@ async def perform_health_check(
         _new_model_list = [x for x in model_list if x["model_name"] == model]
         model_list = _new_model_list
 
+    model_list = filter_deployments_by_id(
+        model_list=model_list
+    )  # filter duplicate deployments (e.g. when model alias'es are used)
     healthy_endpoints, unhealthy_endpoints = await _perform_health_check(
         model_list, details
     )
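Because `filter_deployments_by_id` keys only on `model_info.id`, aliased entries that point at the same underlying deployment collapse to a single health check, and entries without an id are skipped. A small usage sketch with made-up deployment dicts:

from litellm.proxy.health_check import filter_deployments_by_id

model_list = [
    {"model_name": "gpt-turbo", "model_info": {"id": "deployment-1"}},
    {"model_name": "gpt-4", "model_info": {"id": "deployment-1"}},  # alias of the same deployment
    {"model_name": "claude", "model_info": {"id": "deployment-2"}},
    {"model_name": "no-id-entry"},  # skipped: no model_info.id
]

deduped = filter_deployments_by_id(model_list=model_list)
assert [d["model_info"]["id"] for d in deduped] == ["deployment-1", "deployment-2"]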

litellm/proxy/management_helpers/utils.py

+2-2
@@ -109,8 +109,8 @@ async def add_new_member(
             where={"user_id": user_info.user_id},  # type: ignore
             data={"teams": {"push": [team_id]}},
         )
-
-        returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
+        if _returned_user is not None:
+            returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
     elif len(existing_user_row) > 1:
         raise HTTPException(
             status_code=400,

litellm/router.py

+29-6
@@ -4556,6 +4556,27 @@ def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
                 ids.append(id)
         return ids
 
+    def _get_all_deployments(
+        self, model_name: str, model_alias: Optional[str] = None
+    ) -> List[DeploymentTypedDict]:
+        """
+        Return all deployments of a model name
+
+        Used for accurate 'get_model_list'.
+        """
+
+        returned_models: List[DeploymentTypedDict] = []
+        for model in self.model_list:
+            if model["model_name"] == model_name:
+                if model_alias is not None:
+                    alias_model = copy.deepcopy(model)
+                    alias_model["model_name"] = model_name
+                    returned_models.append(alias_model)
+                else:
+                    returned_models.append(model)
+
+        return returned_models
+
     def get_model_names(self) -> List[str]:
         """
         Returns all possible model names for router.
@@ -4567,24 +4588,26 @@ def get_model_names(self) -> List[str]:
     def get_model_list(
         self, model_name: Optional[str] = None
     ) -> Optional[List[DeploymentTypedDict]]:
+        """
+        Includes router model_group_alias'es as well
+        """
         if hasattr(self, "model_list"):
             returned_models: List[DeploymentTypedDict] = []
 
             for model_alias, model_value in self.model_group_alias.items():
-                model_alias_item = DeploymentTypedDict(
-                    model_name=model_alias,
-                    litellm_params=LiteLLMParamsTypedDict(model=model_value),
+                returned_models.extend(
+                    self._get_all_deployments(
+                        model_name=model_value, model_alias=model_alias
+                    )
                 )
-                returned_models.append(model_alias_item)
 
             if model_name is None:
                 returned_models += self.model_list
 
                 return returned_models
 
             for model in self.model_list:
-                if model["model_name"] == model_name:
-                    returned_models.append(model)
+                returned_models.extend(self._get_all_deployments(model_name=model_name))
 
             return returned_models
         return None
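The net effect is that `router.get_model_list()` now expands a `model_group_alias` entry into the real deployments behind it instead of returning a synthetic record that only carried the alias name (the BerriAI#5524 request). A rough sketch of how this might be exercised, with a single hypothetical deployment and a dummy key:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "dummy-key"},
        }
    ],
    model_group_alias={"gpt-4": "gpt-turbo"},
)

# Returns the underlying "gpt-turbo" deployment(s) for the alias rather than a bare alias entry.
print(router.get_model_list(model_name="gpt-4"))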

litellm/tests/test_completion.py

+2
@@ -626,6 +626,8 @@ async def test_model_function_invoke(model, sync_mode, api_key, api_base):
             response = await litellm.acompletion(**data)
 
         print(f"response: {response}")
+    except litellm.InternalServerError:
+        pass
     except litellm.RateLimitError as e:
         pass
     except Exception as e:

litellm/tests/test_exceptions.py

+2-2
@@ -864,7 +864,7 @@ def _pre_call_utils(
         data["messages"] = [{"role": "user", "content": "Hello world"}]
         if streaming is True:
             data["stream"] = True
-        mapped_target = client.chat.completions.with_raw_response
+        mapped_target = client.chat.completions.with_raw_response  # type: ignore
         if sync_mode:
             original_function = litellm.completion
         else:
@@ -873,7 +873,7 @@ def _pre_call_utils(
         data["prompt"] = "Hello world"
         if streaming is True:
             data["stream"] = True
-        mapped_target = client.completions.with_raw_response
+        mapped_target = client.completions.with_raw_response  # type: ignore
         if sync_mode:
             original_function = litellm.text_completion
         else:

litellm/tests/test_function_calling.py

+1
@@ -52,6 +52,7 @@ def get_current_weather(location, unit="fahrenheit"):
         # "anthropic.claude-3-sonnet-20240229-v1:0",
     ],
 )
+@pytest.mark.flaky(retries=3, delay=1)
 def test_aaparallel_function_call(model):
     try:
         litellm.set_verbose = True

litellm/tests/test_text_completion.py

+11
@@ -4239,3 +4239,14 @@ def test_completion_vllm():
         mock_call.assert_called_once()
 
         assert "hello" in mock_call.call_args.kwargs["extra_body"]
+
+
+def test_completion_fireworks_ai_multiple_choices():
+    litellm.set_verbose = True
+    response = litellm.text_completion(
+        model="fireworks_ai/llama-v3p1-8b-instruct",
+        prompt=["halo", "hi", "halo", "hi"],
+    )
+    print(response.choices)
+
+    assert len(response.choices) == 4

proxy_server_config.yaml

+1
@@ -148,6 +148,7 @@ router_settings:
     redis_password: os.environ/REDIS_PASSWORD
     redis_port: os.environ/REDIS_PORT
     enable_pre_call_checks: true
+    model_group_alias: {"my-special-fake-model-alias-name": "fake-openai-endpoint-3"}
 
 general_settings:
   master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
