
Commit 331e784

[Feat] Responses API - Add session management support for non-OpenAI models (#10321)

* add session id in spendLogs
* fix: log proxy server request as independent field
* use trace id for SpendLogs
* add _ENTERPRISE_ResponsesSessionHandler
* use _ENTERPRISE_ResponsesSessionHandler
* working session_ids
* working session management
* working session_ids
* test_async_gcs_pub_sub_v1
* test_spend_logs_payload_e2e
* working session_ids
* test_get_standard_logging_payload_trace_id
* test_get_standard_logging_payload_trace_id
* test_gcs_pub_sub.py
* fix all linting errors
* test_spend_logs_payload_with_prompts_enabled
* _ENTERPRISE_ResponsesSessionHandler
* _ENTERPRISE_ResponsesSessionHandler
* expose session id on UI
* get spend logs by session
* add sessionSpendLogsCall
* add session handling
* session logs
* UI session details
* fix on rowExpandDetails
* UI working sessions
1 parent c66c821 commit 331e784
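Usage sketch (not part of the diff): with this change, a Responses API request that passes previous_response_id to a non-OpenAI model should have its earlier turns rebuilt from the proxy's spend logs. A minimal illustration, assuming a proxy running at http://localhost:4000 with the anthropic/* config added in this commit; the model name and API key below are placeholders:

from openai import OpenAI

# Point the OpenAI SDK at the LiteLLM proxy (base URL and key are placeholders).
client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

# First turn: the proxy logs this request/response in LiteLLM_SpendLogs.
first = client.responses.create(
    model="anthropic/claude-3-5-sonnet-20240620",
    input="Remember that my favorite color is blue.",
)

# Second turn: previous_response_id lets the proxy reload the stored session
# history before calling the non-OpenAI model.
second = client.responses.create(
    model="anthropic/claude-3-5-sonnet-20240620",
    previous_response_id=first.id,
    input="What is my favorite color?",
)
print(second.output_text)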

File tree

22 files changed (+878 / -469 lines)

New file (+136 lines)

from litellm.proxy._types import SpendLogsPayload
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_proxy_logger
from typing import Optional, List, Union
import json
from litellm.types.utils import ModelResponse, Message
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionResponseMessage,
    GenericChatCompletionMessage,
    ResponseInputParam,
)
from litellm.types.utils import ChatCompletionMessageToolCall

from litellm.responses.utils import ResponsesAPIRequestUtils
from typing import TypedDict


class ChatCompletionSession(TypedDict, total=False):
    messages: List[Union[AllMessageValues, GenericChatCompletionMessage, ChatCompletionMessageToolCall, ChatCompletionResponseMessage, Message]]
    litellm_session_id: Optional[str]


class _ENTERPRISE_ResponsesSessionHandler:
    @staticmethod
    async def get_chat_completion_message_history_for_previous_response_id(
        previous_response_id: str,
    ) -> ChatCompletionSession:
        """
        Return the chat completion message history for a previous response id
        """
        from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig
        all_spend_logs: List[SpendLogsPayload] = await _ENTERPRISE_ResponsesSessionHandler.get_all_spend_logs_for_previous_response_id(previous_response_id)

        litellm_session_id: Optional[str] = None
        if len(all_spend_logs) > 0:
            litellm_session_id = all_spend_logs[0].get("session_id")

        chat_completion_message_history: List[
            Union[
                AllMessageValues,
                GenericChatCompletionMessage,
                ChatCompletionMessageToolCall,
                ChatCompletionResponseMessage,
                Message,
            ]
        ] = []
        for spend_log in all_spend_logs:
            proxy_server_request: Union[str, dict] = spend_log.get("proxy_server_request") or "{}"
            proxy_server_request_dict: Optional[dict] = None
            response_input_param: Optional[Union[str, ResponseInputParam]] = None
            if isinstance(proxy_server_request, dict):
                proxy_server_request_dict = proxy_server_request
            else:
                proxy_server_request_dict = json.loads(proxy_server_request)

            ############################################################
            # Add Input messages for this Spend Log
            ############################################################
            if proxy_server_request_dict:
                _response_input_param = proxy_server_request_dict.get("input", None)
                if isinstance(_response_input_param, str):
                    response_input_param = _response_input_param
                elif isinstance(_response_input_param, dict):
                    response_input_param = ResponseInputParam(**_response_input_param)

            if response_input_param:
                chat_completion_messages = LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
                    input=response_input_param,
                    responses_api_request=proxy_server_request_dict or {}
                )
                chat_completion_message_history.extend(chat_completion_messages)

            ############################################################
            # Add Output messages for this Spend Log
            ############################################################
            _response_output = spend_log.get("response", "{}")
            if isinstance(_response_output, dict):
                # transform `ChatCompletion Response` to `ResponsesAPIResponse`
                model_response = ModelResponse(**_response_output)
                for choice in model_response.choices:
                    if hasattr(choice, "message"):
                        chat_completion_message_history.append(choice.message)

        verbose_proxy_logger.debug("chat_completion_message_history %s", json.dumps(chat_completion_message_history, indent=4, default=str))
        return ChatCompletionSession(
            messages=chat_completion_message_history,
            litellm_session_id=litellm_session_id
        )

    @staticmethod
    async def get_all_spend_logs_for_previous_response_id(
        previous_response_id: str
    ) -> List[SpendLogsPayload]:
        """
        Get all spend logs for a previous response id

        Roughly the following SQL:

        SELECT session_id FROM spend_logs WHERE response_id = previous_response_id;
        SELECT * FROM spend_logs WHERE session_id = session_id;
        """
        from litellm.proxy.proxy_server import prisma_client
        decoded_response_id = ResponsesAPIRequestUtils._decode_responses_api_response_id(previous_response_id)
        previous_response_id = decoded_response_id.get("response_id", previous_response_id)
        if prisma_client is None:
            return []

        query = """
        WITH matching_session AS (
            SELECT session_id
            FROM "LiteLLM_SpendLogs"
            WHERE request_id = $1
        )
        SELECT *
        FROM "LiteLLM_SpendLogs"
        WHERE session_id IN (SELECT session_id FROM matching_session)
        ORDER BY "endTime" ASC;
        """

        spend_logs = await prisma_client.db.query_raw(
            query,
            previous_response_id
        )

        verbose_proxy_logger.debug(
            "Found the following spend logs for previous response id %s: %s",
            previous_response_id,
            json.dumps(spend_logs, indent=4, default=str)
        )

        return spend_logs
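For orientation, a sketch of how the handler above might be driven (it assumes a configured prisma_client on the proxy and a response id previously issued by the Responses API; the id below is a placeholder):

import asyncio

async def rebuild_history(previous_response_id: str) -> None:
    # Look up the stored session for a response id and print the replayed turns.
    session = await _ENTERPRISE_ResponsesSessionHandler.get_chat_completion_message_history_for_previous_response_id(
        previous_response_id=previous_response_id
    )
    print("session id:", session.get("litellm_session_id"))
    for message in session.get("messages", []):
        print(message)

# "resp_abc123" is a placeholder; real ids come back from /v1/responses.
asyncio.run(rebuild_history("resp_abc123"))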
New file (+4 lines)

-- AlterTable
ALTER TABLE "LiteLLM_SpendLogs" ADD COLUMN "proxy_server_request" JSONB DEFAULT '{}',
ADD COLUMN "session_id" TEXT;

litellm-proxy-extras/litellm_proxy_extras/schema.prisma (+2)

@@ -226,6 +226,8 @@ model LiteLLM_SpendLogs {
   requester_ip_address String?
   messages Json? @default("{}")
   response Json? @default("{}")
+  session_id String?
+  proxy_server_request Json? @default("{}")
   @@index([startTime])
   @@index([end_user])
 }

litellm/litellm_core_utils/litellm_logging.py (+20 / -3)

@@ -28,7 +28,6 @@
 from litellm.batches.batch_utils import _handle_completed_batch
 from litellm.caching.caching import DualCache, InMemoryCache
 from litellm.caching.caching_handler import LLMCachingHandler
-
 from litellm.constants import (
     DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT,
     DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT,
@@ -249,7 +248,7 @@ def __init__(
         self.start_time = start_time  # log the call start time
         self.call_type = call_type
         self.litellm_call_id = litellm_call_id
-        self.litellm_trace_id = litellm_trace_id
+        self.litellm_trace_id: str = litellm_trace_id or str(uuid.uuid4())
         self.function_id = function_id
         self.streaming_chunks: List[Any] = []  # for generating complete stream response
         self.sync_streaming_chunks: List[Any] = (
@@ -3500,6 +3499,21 @@ def get_response_time(
         else:
             return end_time_float - start_time_float
 
+    @staticmethod
+    def _get_standard_logging_payload_trace_id(
+        logging_obj: Logging,
+        litellm_params: dict,
+    ) -> str:
+        """
+        Returns the `litellm_trace_id` for this request
+
+        This helps link sessions when multiple requests are made in a single session
+        """
+        dynamic_trace_id = litellm_params.get("litellm_trace_id")
+        if dynamic_trace_id:
+            return str(dynamic_trace_id)
+        return logging_obj.litellm_trace_id
+
 
 def get_standard_logging_object_payload(
     kwargs: Optional[dict],
@@ -3652,7 +3666,10 @@ def get_standard_logging_object_payload(
 
     payload: StandardLoggingPayload = StandardLoggingPayload(
         id=str(id),
-        trace_id=kwargs.get("litellm_trace_id"),  # type: ignore
+        trace_id=StandardLoggingPayloadSetup._get_standard_logging_payload_trace_id(
+            logging_obj=logging_obj,
+            litellm_params=litellm_params,
+        ),
         call_type=call_type or "",
         cache_hit=cache_hit,
         stream=stream,
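The trace-id change above is what groups a multi-turn session: every Logging object now gets a litellm_trace_id (generated if none is supplied), and an explicitly passed trace id takes precedence in the standard logging payload. A self-contained sketch of that precedence rule, using a stand-in class rather than litellm's real Logging object:

import uuid
from typing import Optional

class FakeLogging:
    # Stand-in for litellm's Logging object: trace id falls back to a fresh UUID.
    def __init__(self, litellm_trace_id: Optional[str] = None):
        self.litellm_trace_id: str = litellm_trace_id or str(uuid.uuid4())

def resolve_trace_id(logging_obj: FakeLogging, litellm_params: dict) -> str:
    # Mirrors _get_standard_logging_payload_trace_id: a dynamic trace id wins.
    dynamic_trace_id = litellm_params.get("litellm_trace_id")
    if dynamic_trace_id:
        return str(dynamic_trace_id)
    return logging_obj.litellm_trace_id

assert resolve_trace_id(FakeLogging(), {"litellm_trace_id": "session-123"}) == "session-123"
print(resolve_trace_id(FakeLogging(), {}))  # falls back to the auto-generated UUID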

litellm/proxy/_types.py (+35 / -33)

@@ -654,9 +654,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
     allowed_cache_controls: Optional[list] = []
     config: Optional[dict] = {}
     permissions: Optional[dict] = {}
-    model_max_budget: Optional[
-        dict
-    ] = {}  # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
+    model_max_budget: Optional[dict] = (
+        {}
+    )  # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
 
     model_config = ConfigDict(protected_namespaces=())
     model_rpm_limit: Optional[dict] = None
@@ -918,12 +918,12 @@ class NewCustomerRequest(BudgetNewRequest):
     alias: Optional[str] = None  # human-friendly alias
     blocked: bool = False  # allow/disallow requests for this end-user
     budget_id: Optional[str] = None  # give either a budget_id or max_budget
-    allowed_model_region: Optional[
-        AllowedModelRegion
-    ] = None  # require all user requests to use models in this specific region
-    default_model: Optional[
-        str
-    ] = None  # if no equivalent model in allowed region - default all requests to this model
+    allowed_model_region: Optional[AllowedModelRegion] = (
+        None  # require all user requests to use models in this specific region
+    )
+    default_model: Optional[str] = (
+        None  # if no equivalent model in allowed region - default all requests to this model
+    )
 
     @model_validator(mode="before")
     @classmethod
@@ -945,12 +945,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
     blocked: bool = False  # allow/disallow requests for this end-user
     max_budget: Optional[float] = None
     budget_id: Optional[str] = None  # give either a budget_id or max_budget
-    allowed_model_region: Optional[
-        AllowedModelRegion
-    ] = None  # require all user requests to use models in this specific region
-    default_model: Optional[
-        str
-    ] = None  # if no equivalent model in allowed region - default all requests to this model
+    allowed_model_region: Optional[AllowedModelRegion] = (
+        None  # require all user requests to use models in this specific region
+    )
+    default_model: Optional[str] = (
+        None  # if no equivalent model in allowed region - default all requests to this model
+    )
 
 
 class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
@@ -1086,9 +1086,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase):
 
 class AddTeamCallback(LiteLLMPydanticObjectBase):
     callback_name: str
-    callback_type: Optional[
-        Literal["success", "failure", "success_and_failure"]
-    ] = "success_and_failure"
+    callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = (
+        "success_and_failure"
+    )
     callback_vars: Dict[str, str]
 
     @model_validator(mode="before")
@@ -1346,9 +1346,9 @@ class ConfigList(LiteLLMPydanticObjectBase):
     stored_in_db: Optional[bool]
     field_default_value: Any
     premium_field: bool = False
-    nested_fields: Optional[
-        List[FieldDetail]
-    ] = None  # For nested dictionary or Pydantic fields
+    nested_fields: Optional[List[FieldDetail]] = (
+        None  # For nested dictionary or Pydantic fields
+    )
 
 
 class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
@@ -1616,9 +1616,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
     budget_id: Optional[str] = None
     created_at: datetime
     updated_at: datetime
-    user: Optional[
-        Any
-    ] = None  # You might want to replace 'Any' with a more specific type if available
+    user: Optional[Any] = (
+        None  # You might want to replace 'Any' with a more specific type if available
+    )
     litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
 
     model_config = ConfigDict(protected_namespaces=())
@@ -2015,6 +2015,8 @@ class SpendLogsPayload(TypedDict):
     custom_llm_provider: Optional[str]
     messages: Optional[Union[str, list, dict]]
     response: Optional[Union[str, list, dict]]
+    proxy_server_request: Optional[str]
+    session_id: Optional[str]
 
 
 class SpanAttributes(str, enum.Enum):
@@ -2366,9 +2368,9 @@ class TeamModelDeleteRequest(BaseModel):
 # Organization Member Requests
 class OrganizationMemberAddRequest(OrgMemberAddRequest):
     organization_id: str
-    max_budget_in_organization: Optional[
-        float
-    ] = None  # Users max budget within the organization
+    max_budget_in_organization: Optional[float] = (
+        None  # Users max budget within the organization
+    )
 
 
 class OrganizationMemberDeleteRequest(MemberDeleteRequest):
@@ -2557,9 +2559,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
     Maps provider names to their budget configs.
     """
 
-    providers: Dict[
-        str, ProviderBudgetResponseObject
-    ] = {}  # Dictionary mapping provider names to their budget configurations
+    providers: Dict[str, ProviderBudgetResponseObject] = (
+        {}
+    )  # Dictionary mapping provider names to their budget configurations
 
 
 class ProxyStateVariables(TypedDict):
@@ -2687,9 +2689,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
     enforce_rbac: bool = False
     roles_jwt_field: Optional[str] = None  # v2 on role mappings
     role_mappings: Optional[List[RoleMapping]] = None
-    object_id_jwt_field: Optional[
-        str
-    ] = None  # can be either user / team, inferred from the role mapping
+    object_id_jwt_field: Optional[str] = (
+        None  # can be either user / team, inferred from the role mapping
+    )
     scope_mappings: Optional[List[ScopeMapping]] = None
     enforce_scope_based_access: bool = False
     enforce_team_based_model_access: bool = False
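The two new SpendLogsPayload fields above are what the session handler reads back later: proxy_server_request carries the original Responses API body as a JSON string, and session_id groups the turns of one conversation. An illustrative entry with made-up values (the proxy fills these in when writing a spend log):

import json

# Illustrative values only; ids, model, and content are placeholders.
spend_log_entry = {
    "request_id": "resp_abc123",
    "session_id": "3f6a2c90-example-session-uuid",
    "proxy_server_request": json.dumps(
        {
            "model": "anthropic/claude-3-5-sonnet-20240620",
            "input": "What is my favorite color?",
        }
    ),
    "response": {"choices": [{"message": {"role": "assistant", "content": "Blue."}}]},
}
# The handler reads "input" back out of proxy_server_request and the assistant
# message out of "response" to rebuild the chat history.
print(json.loads(spend_log_entry["proxy_server_request"])["input"])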

litellm/proxy/proxy_config.yaml (+5 / -6)

@@ -1,8 +1,7 @@
 model_list:
-  - model_name: openai/*
+  - model_name: anthropic/*
     litellm_params:
-      model: openai/*
-      api_key: os.environ/OPENAI_API_KEY
-
-router_settings:
-  optional_pre_call_checks: ["responses_api_deployment_check"]
+      model: anthropic/*
+      api_key: os.environ/ANTHROPIC_API_KEY
+general_settings:
+  store_prompts_in_spend_logs: true

litellm/proxy/schema.prisma (+2)

@@ -226,6 +226,8 @@ model LiteLLM_SpendLogs {
   requester_ip_address String?
   messages Json? @default("{}")
   response Json? @default("{}")
+  session_id String?
+  proxy_server_request Json? @default("{}")
   @@index([startTime])
   @@index([end_user])
 }
