Commit bd88263

[Feat - Cost Tracking improvement] Track prompt caching metrics in DailyUserSpendTransactions (#10029)
* stash changes
* emit cache read/write tokens to daily spend update
* emit cache read/write tokens on daily activity
* update types.ts
* docs prompt caching
* undo ui change
* fix activity metrics
* fix prompt caching metrics
* fix typed dict fields
* fix get_aggregated_daily_spend_update_transactions
* fix aggregating cache tokens
* test_cache_token_fields_aggregation
* daily_transaction
* add cache_creation_input_tokens and cache_read_input_tokens to LiteLLM_DailyUserSpend
* test_daily_spend_update_queue.py
1 parent d32d6fe commit bd88263

12 files changed (+197, -35 lines)
@@ -0,0 +1,4 @@
+-- AlterTable
+ALTER TABLE "LiteLLM_DailyUserSpend" ADD COLUMN "cache_creation_input_tokens" INTEGER NOT NULL DEFAULT 0,
+ADD COLUMN "cache_read_input_tokens" INTEGER NOT NULL DEFAULT 0;
+

litellm-proxy-extras/litellm_proxy_extras/schema.prisma

Lines changed: 2 additions & 0 deletions
@@ -326,6 +326,8 @@ model LiteLLM_DailyUserSpend {
 custom_llm_provider String?
 prompt_tokens Int @default(0)
 completion_tokens Int @default(0)
+cache_read_input_tokens Int @default(0)
+cache_creation_input_tokens Int @default(0)
 spend Float @default(0.0)
 api_requests Int @default(0)
 successful_requests Int @default(0)

litellm/proxy/_types.py

Lines changed: 6 additions & 0 deletions
@@ -2777,8 +2777,14 @@ class BaseDailySpendTransaction(TypedDict):
 model: str
 model_group: Optional[str]
 custom_llm_provider: Optional[str]
+
+# token count metrics
 prompt_tokens: int
 completion_tokens: int
+cache_read_input_tokens: int
+cache_creation_input_tokens: int
+
+# request level metrics
 spend: float
 api_requests: int
 successful_requests: int
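For context, here is a minimal, self-contained sketch (not code from the repo) of what a daily spend transaction dict carries once the new token-count fields are present. The class name DailySpendTransactionSketch and all literal values are made up for illustration; the field names mirror the BaseDailySpendTransaction hunk above.

from typing import TypedDict

class DailySpendTransactionSketch(TypedDict):
    # token count metrics (mirrors the fields added above)
    prompt_tokens: int
    completion_tokens: int
    cache_read_input_tokens: int
    cache_creation_input_tokens: int
    # request level metrics
    spend: float
    api_requests: int
    successful_requests: int
    failed_requests: int

example: DailySpendTransactionSketch = {
    "prompt_tokens": 100,
    "completion_tokens": 25,
    "cache_read_input_tokens": 60,      # prompt tokens served from the provider's prompt cache
    "cache_creation_input_tokens": 40,  # prompt tokens written to the provider's prompt cache
    "spend": 0.0042,
    "api_requests": 1,
    "successful_requests": 1,
    "failed_requests": 0,
}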

litellm/proxy/db/db_spend_update_writer.py

Lines changed: 26 additions & 0 deletions
@@ -6,6 +6,7 @@
 """
 
 import asyncio
+import json
 import os
 import time
 import traceback
@@ -24,6 +25,7 @@
 DBSpendUpdateTransactions,
 Litellm_EntityType,
 LiteLLM_UserTable,
+SpendLogsMetadata,
 SpendLogsPayload,
 SpendUpdateQueueItem,
 )
@@ -806,6 +808,12 @@ async def update_daily_user_spend(
 "completion_tokens": transaction[
 "completion_tokens"
 ],
+"cache_read_input_tokens": transaction.get(
+"cache_read_input_tokens", 0
+),
+"cache_creation_input_tokens": transaction.get(
+"cache_creation_input_tokens", 0
+),
 "spend": transaction["spend"],
 "api_requests": transaction["api_requests"],
 "successful_requests": transaction[
@@ -824,6 +832,16 @@ async def update_daily_user_spend(
 "completion_tokens"
 ]
 },
+"cache_read_input_tokens": {
+"increment": transaction.get(
+"cache_read_input_tokens", 0
+)
+},
+"cache_creation_input_tokens": {
+"increment": transaction.get(
+"cache_creation_input_tokens", 0
+)
+},
 "spend": {"increment": transaction["spend"]},
 "api_requests": {
 "increment": transaction["api_requests"]
@@ -1024,6 +1042,8 @@ async def _common_add_spend_log_transaction_to_daily_transaction(
 
 request_status = prisma_client.get_request_status(payload)
 verbose_proxy_logger.info(f"Logged request status: {request_status}")
+_metadata: SpendLogsMetadata = json.loads(payload["metadata"])
+usage_obj = _metadata.get("usage_object", {}) or {}
 if isinstance(payload["startTime"], datetime):
 start_time = payload["startTime"].isoformat()
 date = start_time.split("T")[0]
@@ -1047,6 +1067,12 @@ async def _common_add_spend_log_transaction_to_daily_transaction(
 api_requests=1,
 successful_requests=1 if request_status == "success" else 0,
 failed_requests=1 if request_status != "success" else 0,
+cache_read_input_tokens=usage_obj.get("cache_read_input_tokens", 0)
+or 0,
+cache_creation_input_tokens=usage_obj.get(
+"cache_creation_input_tokens", 0
+)
+or 0,
 )
 return daily_transaction
 except Exception as e:
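The metadata path added above can be hard to see inside the diff, so here is a hedged, self-contained sketch of the same pattern: the spend log's JSON metadata is parsed and the cache token counts are read from its usage_object with safe defaults. The payload dict below is a hypothetical stand-in for a SpendLogsPayload row, not the proxy's real object.

import json

# Hypothetical spend-log payload; only the "metadata" field matters here.
payload = {
    "metadata": json.dumps(
        {"usage_object": {"cache_read_input_tokens": 25, "cache_creation_input_tokens": 10}}
    )
}

# Same extraction pattern as the diff: parse metadata, then read the usage object
# with defaults so missing keys or None values fall back to 0.
_metadata = json.loads(payload["metadata"])
usage_obj = _metadata.get("usage_object", {}) or {}

cache_read = usage_obj.get("cache_read_input_tokens", 0) or 0
cache_creation = usage_obj.get("cache_creation_input_tokens", 0) or 0
print(cache_read, cache_creation)  # 25 10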

litellm/proxy/db/db_transaction_queue/daily_spend_update_queue.py

Lines changed: 17 additions & 7 deletions
@@ -53,9 +53,9 @@ class DailySpendUpdateQueue(BaseUpdateQueue):
 
 def __init__(self):
 super().__init__()
-self.update_queue: asyncio.Queue[
-Dict[str, BaseDailySpendTransaction]
-] = asyncio.Queue()
+self.update_queue: asyncio.Queue[Dict[str, BaseDailySpendTransaction]] = (
+asyncio.Queue()
+)
 
 async def add_update(self, update: Dict[str, BaseDailySpendTransaction]):
 """Enqueue an update."""
@@ -72,9 +72,9 @@ async def aggregate_queue_updates(self):
 Combine all updates in the queue into a single update.
 This is used to reduce the size of the in-memory queue.
 """
-updates: List[
-Dict[str, BaseDailySpendTransaction]
-] = await self.flush_all_updates_from_in_memory_queue()
+updates: List[Dict[str, BaseDailySpendTransaction]] = (
+await self.flush_all_updates_from_in_memory_queue()
+)
 aggregated_updates = self.get_aggregated_daily_spend_update_transactions(
 updates
 )
@@ -98,7 +98,7 @@ async def flush_and_get_aggregated_daily_spend_update_transactions(
 
 @staticmethod
 def get_aggregated_daily_spend_update_transactions(
-updates: List[Dict[str, BaseDailySpendTransaction]]
+updates: List[Dict[str, BaseDailySpendTransaction]],
 ) -> Dict[str, BaseDailySpendTransaction]:
 """Aggregate updates by daily_transaction_key."""
 aggregated_daily_spend_update_transactions: Dict[
@@ -118,6 +118,16 @@ def get_aggregated_daily_spend_update_transactions(
 "successful_requests"
 ]
 daily_transaction["failed_requests"] += payload["failed_requests"]
+
+# Add optional metrics cache_read_input_tokens and cache_creation_input_tokens
+daily_transaction["cache_read_input_tokens"] = (
+payload.get("cache_read_input_tokens", 0) or 0
+) + daily_transaction.get("cache_read_input_tokens", 0)
+
+daily_transaction["cache_creation_input_tokens"] = (
+payload.get("cache_creation_input_tokens", 0) or 0
+) + daily_transaction.get("cache_creation_input_tokens", 0)
+
 else:
 aggregated_daily_spend_update_transactions[_key] = deepcopy(payload)
 return aggregated_daily_spend_update_transactions
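A minimal sketch of the aggregation behaviour introduced above, using hypothetical payload dicts: because the cache fields are read with .get(..., 0) or 0, transactions enqueued before this change (which lack the keys entirely, or carry None) still aggregate without a KeyError.

# Hypothetical transaction payloads; field names follow the diff above.
older = {"prompt_tokens": 10, "completion_tokens": 5}  # pre-change shape, no cache fields
newer = {
    "prompt_tokens": 20,
    "completion_tokens": 8,
    "cache_read_input_tokens": 7,
    "cache_creation_input_tokens": 3,
}

daily_transaction = dict(older)  # running aggregate for one daily_transaction_key
payload = newer                  # next update for the same key

# Same defensive pattern as get_aggregated_daily_spend_update_transactions:
daily_transaction["cache_read_input_tokens"] = (
    payload.get("cache_read_input_tokens", 0) or 0
) + daily_transaction.get("cache_read_input_tokens", 0)
daily_transaction["cache_creation_input_tokens"] = (
    payload.get("cache_creation_input_tokens", 0) or 0
) + daily_transaction.get("cache_creation_input_tokens", 0)

print(daily_transaction["cache_read_input_tokens"])      # 7
print(daily_transaction["cache_creation_input_tokens"])  # 3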

litellm/proxy/management_endpoints/internal_user_endpoints.py

Lines changed: 39 additions & 23 deletions
@@ -82,9 +82,9 @@ def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> d
 data_json["user_id"] = str(uuid.uuid4())
 auto_create_key = data_json.pop("auto_create_key", True)
 if auto_create_key is False:
-data_json[
-"table_name"
-] = "user" # only create a user, don't create key if 'auto_create_key' set to False
+data_json["table_name"] = (
+"user" # only create a user, don't create key if 'auto_create_key' set to False
+)
 
 is_internal_user = False
 if data.user_role and data.user_role.is_internal_user_role:
@@ -651,9 +651,9 @@ def _update_internal_user_params(data_json: dict, data: UpdateUserRequest) -> di
 "budget_duration" not in non_default_values
 ): # applies internal user limits, if user role updated
 if is_internal_user and litellm.internal_user_budget_duration is not None:
-non_default_values[
-"budget_duration"
-] = litellm.internal_user_budget_duration
+non_default_values["budget_duration"] = (
+litellm.internal_user_budget_duration
+)
 duration_s = duration_in_seconds(
 duration=non_default_values["budget_duration"]
 )
@@ -964,13 +964,13 @@ async def get_users(
 "in": user_id_list, # Now passing a list of strings as required by Prisma
 }
 
-users: Optional[
-List[LiteLLM_UserTable]
-] = await prisma_client.db.litellm_usertable.find_many(
-where=where_conditions,
-skip=skip,
-take=page_size,
-order={"created_at": "desc"},
+users: Optional[List[LiteLLM_UserTable]] = (
+await prisma_client.db.litellm_usertable.find_many(
+where=where_conditions,
+skip=skip,
+take=page_size,
+order={"created_at": "desc"},
+)
 )
 
 # Get total count of user rows
@@ -1225,13 +1225,13 @@ async def ui_view_users(
 }
 
 # Query users with pagination and filters
-users: Optional[
-List[BaseModel]
-] = await prisma_client.db.litellm_usertable.find_many(
-where=where_conditions,
-skip=skip,
-take=page_size,
-order={"created_at": "desc"},
+users: Optional[List[BaseModel]] = (
+await prisma_client.db.litellm_usertable.find_many(
+where=where_conditions,
+skip=skip,
+take=page_size,
+order={"created_at": "desc"},
+)
 )
 
 if not users:
@@ -1258,6 +1258,8 @@ class SpendMetrics(BaseModel):
 spend: float = Field(default=0.0)
 prompt_tokens: int = Field(default=0)
 completion_tokens: int = Field(default=0)
+cache_read_input_tokens: int = Field(default=0)
+cache_creation_input_tokens: int = Field(default=0)
 total_tokens: int = Field(default=0)
 successful_requests: int = Field(default=0)
 failed_requests: int = Field(default=0)
@@ -1312,6 +1314,8 @@ class DailySpendMetadata(BaseModel):
 total_api_requests: int = Field(default=0)
 total_successful_requests: int = Field(default=0)
 total_failed_requests: int = Field(default=0)
+total_cache_read_input_tokens: int = Field(default=0)
+total_cache_creation_input_tokens: int = Field(default=0)
 page: int = Field(default=1)
 total_pages: int = Field(default=1)
 has_more: bool = Field(default=False)
@@ -1332,6 +1336,8 @@ class LiteLLM_DailyUserSpend(BaseModel):
 custom_llm_provider: Optional[str] = None
 prompt_tokens: int = 0
 completion_tokens: int = 0
+cache_read_input_tokens: int = 0
+cache_creation_input_tokens: int = 0
 spend: float = 0.0
 api_requests: int = 0
 successful_requests: int = 0
@@ -1349,6 +1355,8 @@ def update_metrics(
 group_metrics.spend += record.spend
 group_metrics.prompt_tokens += record.prompt_tokens
 group_metrics.completion_tokens += record.completion_tokens
+group_metrics.cache_read_input_tokens += record.cache_read_input_tokens
+group_metrics.cache_creation_input_tokens += record.cache_creation_input_tokens
 group_metrics.total_tokens += record.prompt_tokens + record.completion_tokens
 group_metrics.api_requests += record.api_requests
 group_metrics.successful_requests += record.successful_requests
@@ -1448,6 +1456,8 @@ async def get_user_daily_activity(
 - spend
 - prompt_tokens
 - completion_tokens
+- cache_read_input_tokens
+- cache_creation_input_tokens
 - total_tokens
 - api_requests
 - breakdown by model, api_key, provider
@@ -1484,9 +1494,9 @@ async def get_user_daily_activity(
 user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
 and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY
 ):
-where_conditions[
-"user_id"
-] = user_api_key_dict.user_id # only allow access to own data
+where_conditions["user_id"] = (
+user_api_key_dict.user_id
+) # only allow access to own data
 
 # Get total count for pagination
 total_count = await prisma_client.db.litellm_dailyuserspend.count(
@@ -1560,6 +1570,10 @@ async def get_user_daily_activity(
 total_metrics.total_tokens += (
 record.prompt_tokens + record.completion_tokens
 )
+total_metrics.cache_read_input_tokens += record.cache_read_input_tokens
+total_metrics.cache_creation_input_tokens += (
+record.cache_creation_input_tokens
+)
 total_metrics.api_requests += record.api_requests
 total_metrics.successful_requests += record.successful_requests
 total_metrics.failed_requests += record.failed_requests
@@ -1587,6 +1601,8 @@ async def get_user_daily_activity(
 total_api_requests=total_metrics.api_requests,
 total_successful_requests=total_metrics.successful_requests,
 total_failed_requests=total_metrics.failed_requests,
+total_cache_read_input_tokens=total_metrics.cache_read_input_tokens,
+total_cache_creation_input_tokens=total_metrics.cache_creation_input_tokens,
 page=page,
 total_pages=-(-total_count // page_size), # Ceiling division
 has_more=(page * page_size) < total_count,
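A rough sketch of the roll-up the daily activity endpoint now performs: per-record cache token counts are summed into response-level totals. The Totals class and the records list below are hypothetical stand-ins for the SpendMetrics/DailySpendMetadata models and the LiteLLM_DailyUserSpend rows in the hunks above.

from dataclasses import dataclass

@dataclass
class Totals:
    # hypothetical stand-in for the new total_* fields on DailySpendMetadata
    cache_read_input_tokens: int = 0
    cache_creation_input_tokens: int = 0

# hypothetical daily spend rows, one per (user, date, key, model) bucket
records = [
    {"cache_read_input_tokens": 120, "cache_creation_input_tokens": 40},
    {"cache_read_input_tokens": 30, "cache_creation_input_tokens": 0},
]

total_metrics = Totals()
for record in records:
    total_metrics.cache_read_input_tokens += record["cache_read_input_tokens"]
    total_metrics.cache_creation_input_tokens += record["cache_creation_input_tokens"]

print(total_metrics)  # Totals(cache_read_input_tokens=150, cache_creation_input_tokens=40)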

litellm/proxy/schema.prisma

Lines changed: 2 additions & 0 deletions
@@ -326,6 +326,8 @@ model LiteLLM_DailyUserSpend {
 custom_llm_provider String?
 prompt_tokens Int @default(0)
 completion_tokens Int @default(0)
+cache_read_input_tokens Int @default(0)
+cache_creation_input_tokens Int @default(0)
 spend Float @default(0.0)
 api_requests Int @default(0)
 successful_requests Int @default(0)

schema.prisma

Lines changed: 2 additions & 0 deletions
@@ -326,6 +326,8 @@ model LiteLLM_DailyUserSpend {
 custom_llm_provider String?
 prompt_tokens Int @default(0)
 completion_tokens Int @default(0)
+cache_read_input_tokens Int @default(0)
+cache_creation_input_tokens Int @default(0)
 spend Float @default(0.0)
 api_requests Int @default(0)
 successful_requests Int @default(0)

tests/litellm/proxy/db/db_transaction_queue/test_daily_spend_update_queue.py

Lines changed: 46 additions & 0 deletions
@@ -204,6 +204,8 @@ async def test_get_aggregated_daily_spend_update_transactions_same_key():
 "api_requests": 2, # 1 + 1
 "successful_requests": 2, # 1 + 1
 "failed_requests": 0, # 0 + 0
+"cache_creation_input_tokens": 0,
+"cache_read_input_tokens": 0,
 }
 
 updates = [{test_key: test_transaction1}, {test_key: test_transaction2}]
@@ -249,6 +251,8 @@ async def test_flush_and_get_aggregated_daily_spend_update_transactions(
 "api_requests": 2, # 1 + 1
 "successful_requests": 2, # 1 + 1
 "failed_requests": 0, # 0 + 0
+"cache_creation_input_tokens": 0,
+"cache_read_input_tokens": 0,
 }
 
 # Add updates to queue
@@ -368,6 +372,48 @@ async def test_aggregate_queue_updates_accuracy(daily_spend_update_queue):
 assert daily_spend_update_transactions[test_key3]["failed_requests"] == 0
 
 
+@pytest.mark.asyncio
+async def test_cache_token_fields_aggregation(daily_spend_update_queue):
+"""Test that cache_read_input_tokens and cache_creation_input_tokens are handled and aggregated correctly."""
+test_key = "user1_2023-01-01_key123_gpt-4_openai"
+transaction1 = {
+"spend": 1.0,
+"prompt_tokens": 10,
+"completion_tokens": 5,
+"api_requests": 1,
+"successful_requests": 1,
+"failed_requests": 0,
+"cache_read_input_tokens": 7,
+"cache_creation_input_tokens": 3,
+}
+transaction2 = {
+"spend": 2.0,
+"prompt_tokens": 20,
+"completion_tokens": 10,
+"api_requests": 1,
+"successful_requests": 1,
+"failed_requests": 0,
+"cache_read_input_tokens": 5,
+"cache_creation_input_tokens": 4,
+}
+# Add both updates
+await daily_spend_update_queue.add_update({test_key: transaction1})
+await daily_spend_update_queue.add_update({test_key: transaction2})
+# Aggregate
+await daily_spend_update_queue.aggregate_queue_updates()
+updates = await daily_spend_update_queue.flush_all_updates_from_in_memory_queue()
+assert len(updates) == 1
+agg = updates[0][test_key]
+assert agg["cache_read_input_tokens"] == 12 # 7 + 5
+assert agg["cache_creation_input_tokens"] == 7 # 3 + 4
+assert agg["spend"] == 3.0
+assert agg["prompt_tokens"] == 30
+assert agg["completion_tokens"] == 15
+assert agg["api_requests"] == 2
+assert agg["successful_requests"] == 2
+assert agg["failed_requests"] == 0
+
+
 @pytest.mark.asyncio
 async def test_queue_size_reduction_with_large_volume(
 monkeypatch, daily_spend_update_queue
