Skip to content

Commit 839878f

Browse files
Support x-litellm-api-key header param + allow key at max budget to call non-llm api endpoints (#10392)
* fix(user_api_key_auth.py): fix passing `x-litellm-api-key` to user api key auth. Support using this header when given, or the bearer token when given. Fixes issue with auth on vertex passthrough.
* test(test_user_api_key_auth.py): use new fastapi.security check
* fix(user_api_key_auth.py): allow key at budget to still call non-llm api endpoints. Fixes issue where a key at budget couldn't call `/key/info`.
1 parent 70accb7 commit 839878f

File tree

5 files changed

+115
-41
lines changed

5 files changed

+115
-41
lines changed

litellm/proxy/_experimental/out/onboarding.html

-1
This file was deleted.

litellm/proxy/_types.py

+35-33
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ class LiteLLMRoutes(enum.Enum):
326326
"/v1/messages",
327327
]
328328

329+
llm_api_routes = openai_routes + anthropic_routes + mapped_pass_through_routes
329330
info_routes = [
330331
"/key/info",
331332
"/key/health",
@@ -654,9 +655,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
654655
allowed_cache_controls: Optional[list] = []
655656
config: Optional[dict] = {}
656657
permissions: Optional[dict] = {}
657-
model_max_budget: Optional[dict] = (
658-
{}
659-
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
658+
model_max_budget: Optional[
659+
dict
660+
] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
660661

661662
model_config = ConfigDict(protected_namespaces=())
662663
model_rpm_limit: Optional[dict] = None
@@ -918,12 +919,12 @@ class NewCustomerRequest(BudgetNewRequest):
918919
alias: Optional[str] = None # human-friendly alias
919920
blocked: bool = False # allow/disallow requests for this end-user
920921
budget_id: Optional[str] = None # give either a budget_id or max_budget
921-
allowed_model_region: Optional[AllowedModelRegion] = (
922-
None # require all user requests to use models in this specific region
923-
)
924-
default_model: Optional[str] = (
925-
None # if no equivalent model in allowed region - default all requests to this model
926-
)
922+
allowed_model_region: Optional[
923+
AllowedModelRegion
924+
] = None # require all user requests to use models in this specific region
925+
default_model: Optional[
926+
str
927+
] = None # if no equivalent model in allowed region - default all requests to this model
927928

928929
@model_validator(mode="before")
929930
@classmethod
@@ -945,12 +946,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
945946
blocked: bool = False # allow/disallow requests for this end-user
946947
max_budget: Optional[float] = None
947948
budget_id: Optional[str] = None # give either a budget_id or max_budget
948-
allowed_model_region: Optional[AllowedModelRegion] = (
949-
None # require all user requests to use models in this specific region
950-
)
951-
default_model: Optional[str] = (
952-
None # if no equivalent model in allowed region - default all requests to this model
953-
)
949+
allowed_model_region: Optional[
950+
AllowedModelRegion
951+
] = None # require all user requests to use models in this specific region
952+
default_model: Optional[
953+
str
954+
] = None # if no equivalent model in allowed region - default all requests to this model
954955

955956

956957
class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
@@ -1086,9 +1087,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase):
10861087

10871088
class AddTeamCallback(LiteLLMPydanticObjectBase):
10881089
callback_name: str
1089-
callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = (
1090-
"success_and_failure"
1091-
)
1090+
callback_type: Optional[
1091+
Literal["success", "failure", "success_and_failure"]
1092+
] = "success_and_failure"
10921093
callback_vars: Dict[str, str]
10931094

10941095
@model_validator(mode="before")
@@ -1346,9 +1347,9 @@ class ConfigList(LiteLLMPydanticObjectBase):
13461347
stored_in_db: Optional[bool]
13471348
field_default_value: Any
13481349
premium_field: bool = False
1349-
nested_fields: Optional[List[FieldDetail]] = (
1350-
None # For nested dictionary or Pydantic fields
1351-
)
1350+
nested_fields: Optional[
1351+
List[FieldDetail]
1352+
] = None # For nested dictionary or Pydantic fields
13521353

13531354

13541355
class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
@@ -1616,9 +1617,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
16161617
budget_id: Optional[str] = None
16171618
created_at: datetime
16181619
updated_at: datetime
1619-
user: Optional[Any] = (
1620-
None # You might want to replace 'Any' with a more specific type if available
1621-
)
1620+
user: Optional[
1621+
Any
1622+
] = None # You might want to replace 'Any' with a more specific type if available
16221623
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
16231624

16241625
model_config = ConfigDict(protected_namespaces=())
@@ -2368,9 +2369,9 @@ class TeamModelDeleteRequest(BaseModel):
23682369
# Organization Member Requests
23692370
class OrganizationMemberAddRequest(OrgMemberAddRequest):
23702371
organization_id: str
2371-
max_budget_in_organization: Optional[float] = (
2372-
None # Users max budget within the organization
2373-
)
2372+
max_budget_in_organization: Optional[
2373+
float
2374+
] = None # Users max budget within the organization
23742375

23752376

23762377
class OrganizationMemberDeleteRequest(MemberDeleteRequest):
@@ -2451,6 +2452,7 @@ class SpecialHeaders(enum.Enum):
24512452
anthropic_authorization = "x-api-key"
24522453
google_ai_studio_authorization = "x-goog-api-key"
24532454
azure_apim_authorization = "Ocp-Apim-Subscription-Key"
2455+
custom_litellm_api_key = "x-litellm-api-key"
24542456

24552457

24562458
class LitellmDataForBackendLLMCall(TypedDict, total=False):
@@ -2559,9 +2561,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
25592561
Maps provider names to their budget configs.
25602562
"""
25612563

2562-
providers: Dict[str, ProviderBudgetResponseObject] = (
2563-
{}
2564-
) # Dictionary mapping provider names to their budget configurations
2564+
providers: Dict[
2565+
str, ProviderBudgetResponseObject
2566+
] = {} # Dictionary mapping provider names to their budget configurations
25652567

25662568

25672569
class ProxyStateVariables(TypedDict):
@@ -2689,9 +2691,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
26892691
enforce_rbac: bool = False
26902692
roles_jwt_field: Optional[str] = None # v2 on role mappings
26912693
role_mappings: Optional[List[RoleMapping]] = None
2692-
object_id_jwt_field: Optional[str] = (
2693-
None # can be either user / team, inferred from the role mapping
2694-
)
2694+
object_id_jwt_field: Optional[
2695+
str
2696+
] = None # can be either user / team, inferred from the role mapping
26952697
scope_mappings: Optional[List[ScopeMapping]] = None
26962698
enforce_scope_based_access: bool = False
26972699
enforce_team_based_model_access: bool = False

litellm/proxy/auth/user_api_key_auth.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@
5757

5858
user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
5959

60-
60+
custom_litellm_key_header = APIKeyHeader(
61+
name=SpecialHeaders.custom_litellm_api_key.value,
62+
auto_error=False,
63+
description="Bearer token",
64+
)
6165
api_key_header = APIKeyHeader(
6266
name=SpecialHeaders.openai_authorization.value,
6367
auto_error=False,
@@ -228,6 +232,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
228232
google_ai_studio_api_key_header: Optional[str],
229233
azure_apim_header: Optional[str],
230234
request_data: dict,
235+
custom_litellm_key_header: Optional[str] = None,
231236
) -> UserAPIKeyAuth:
232237
from litellm.proxy.proxy_server import (
233238
general_settings,
@@ -261,7 +266,10 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
261266
"pass_through_endpoints", None
262267
)
263268
passed_in_key: Optional[str] = None
264-
if isinstance(api_key, str):
269+
## CHECK IF X-LITELLM-API-KEY IS PASSED IN - supersedes Authorization header
270+
if isinstance(custom_litellm_key_header, str):
271+
api_key = custom_litellm_key_header
272+
elif isinstance(api_key, str):
265273
passed_in_key = api_key
266274
api_key = _get_bearer_token(api_key=api_key)
267275
elif isinstance(azure_api_key_header, str):
@@ -867,11 +875,12 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
867875
)
868876

869877
# Check 4. Token Spend is under budget
870-
await _virtual_key_max_budget_check(
871-
valid_token=valid_token,
872-
proxy_logging_obj=proxy_logging_obj,
873-
user_obj=user_obj,
874-
)
878+
if route in LiteLLMRoutes.llm_api_routes.value:
879+
await _virtual_key_max_budget_check(
880+
valid_token=valid_token,
881+
proxy_logging_obj=proxy_logging_obj,
882+
user_obj=user_obj,
883+
)
875884

876885
# Check 5. Soft Budget Check
877886
await _virtual_key_soft_budget_check(
@@ -1025,6 +1034,9 @@ async def user_api_key_auth(
10251034
google_ai_studio_api_key_header
10261035
),
10271036
azure_apim_header: Optional[str] = fastapi.Security(azure_apim_header),
1037+
custom_litellm_key_header: Optional[str] = fastapi.Security(
1038+
custom_litellm_key_header
1039+
),
10281040
) -> UserAPIKeyAuth:
10291041
"""
10301042
Parent function to authenticate user api key / jwt token.
@@ -1041,6 +1053,7 @@ async def user_api_key_auth(
10411053
google_ai_studio_api_key_header=google_ai_studio_api_key_header,
10421054
azure_apim_header=azure_apim_header,
10431055
request_data=request_data,
1056+
custom_litellm_key_header=custom_litellm_key_header,
10441057
)
10451058

10461059
end_user_id = get_end_user_id_from_request_body(request_data)

tests/proxy_unit_tests/test_user_api_key_auth.py

+37
Original file line numberDiff line numberDiff line change
@@ -999,3 +999,40 @@ async def test_jwt_non_admin_team_route_access(monkeypatch):
999999
except ProxyException as e:
10001000
print("e", e)
10011001
assert "Only proxy admin can be used to generate" in str(e.message)
1002+
1003+
1004+
@pytest.mark.asyncio
1005+
async def test_x_litellm_api_key():
1006+
"""
1007+
Check if auth can pick up x-litellm-api-key header, even if Bearer token is provided
1008+
"""
1009+
from fastapi import Request
1010+
from starlette.datastructures import URL
1011+
1012+
from litellm.proxy._types import (
1013+
LiteLLM_TeamTable,
1014+
LiteLLM_TeamTableCachedObj,
1015+
UserAPIKeyAuth,
1016+
)
1017+
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
1018+
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
1019+
1020+
master_key = "sk-1234"
1021+
1022+
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
1023+
setattr(litellm.proxy.proxy_server, "master_key", master_key)
1024+
setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world")
1025+
1026+
ignored_key = "aj12445"
1027+
1028+
# Create a minimal HTTP request with no headers; both keys are passed directly as arguments to user_api_key_auth below
1029+
request = Request(
1030+
scope={
1031+
"type": "http"
1032+
}
1033+
)
1034+
request._url = URL(url="/chat/completions")
1035+
1036+
valid_token = await user_api_key_auth(request=request, api_key="Bearer " + ignored_key, custom_litellm_key_header=master_key)
1037+
assert valid_token.token == hash_token(master_key)
1038+

tests/test_keys.py

+23
Original file line numberDiff line numberDiff line change
@@ -842,3 +842,26 @@ async def test_key_user_not_in_db():
842842
await chat_completion(session=session, key=key)
843843
except Exception as e:
844844
pytest.fail(f"Expected this call to work - {str(e)}")
845+
846+
847+
@pytest.mark.asyncio
848+
async def test_key_over_budget():
849+
"""
850+
Test if key over budget is handled as expected.
851+
"""
852+
async with aiohttp.ClientSession() as session:
853+
key_gen = await generate_key(session=session, i=0, budget=0.0000001)
854+
key = key_gen["key"]
855+
try:
856+
await chat_completion(session=session, key=key)
857+
except Exception as e:
858+
pytest.fail(f"Expected this call to work - {str(e)}")
859+
860+
## CALL `/key/info` (via get_key_info) - expect to work
861+
model_list = await get_key_info(session=session, get_key=key, call_key=key)
862+
## CALL `/chat/completions` - expect to fail
863+
try:
864+
await chat_completion(session=session, key=key)
865+
pytest.fail("Expected this call to fail")
866+
except Exception as e:
867+
assert "Budget has been exceeded!" in str(e)

0 commit comments

Comments
 (0)