
Commit ddba62c

Authored by niklub and matt-bernstein
feat: DIA-1360: Add token cost KPI to the Prompt aggregate-level subset metrics (#201)
Co-authored-by: nik <[email protected]>
Co-authored-by: Matt Bernstein <[email protected]>
Co-authored-by: niklub <[email protected]>
1 parent a0e7f6e · commit ddba62c

13 files changed: +623 −125 lines

adala/runtimes/_litellm.py (+133 −49)
```diff
@@ -1,11 +1,17 @@
 import asyncio
 import logging
-from typing import Any, Dict, List, Optional, Union, Type
+from typing import Any, Dict, List, Optional, Type
 
 import litellm
-from litellm.exceptions import AuthenticationError
+from litellm.exceptions import (
+    AuthenticationError,
+    ContentPolicyViolationError,
+    BadRequestError,
+    NotFoundError,
+)
+from litellm.types.utils import Usage
 import instructor
-from instructor.exceptions import InstructorRetryException
+from instructor.exceptions import InstructorRetryException, IncompleteOutputException
 import traceback
 from adala.utils.exceptions import ConstrainedGenerationError
 from adala.utils.internal_data import InternalDataFrame
@@ -14,14 +20,14 @@
     parse_template,
     partial_str_format,
 )
-from openai import NotFoundError
 from pydantic import ConfigDict, field_validator, BaseModel
 from rich import print
 from tenacity import (
     AsyncRetrying,
     Retrying,
     retry_if_not_exception_type,
     stop_after_attempt,
+    wait_random_exponential,
 )
 from pydantic_core._pydantic_core import ValidationError
 
@@ -33,6 +39,25 @@
 logger = logging.getLogger(__name__)
 
 
+# basically only retrying on timeout, incomplete output, or rate limit
+# https://docs.litellm.ai/docs/exception_mapping#custom-mapping-list
+# NOTE: token usage is only correctly calculated if we only use instructor retries, not litellm retries
+# https://github.com/jxnl/instructor/pull/763
+RETRY_POLICY = dict(
+    retry=retry_if_not_exception_type(
+        (
+            ValidationError,
+            ContentPolicyViolationError,
+            AuthenticationError,
+            BadRequestError,
+        )
+    ),
+    # should stop earlier on ValidationError and later on other errors, but couldn't figure out how to do that cleanly
+    stop=stop_after_attempt(3),
+    wait=wait_random_exponential(multiplier=1, max=60),
+)
+
+
 def get_messages(
     user_prompt: str,
     system_prompt: Optional[str] = None,
```
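As an aside, here is a minimal, self-contained sketch of how a policy dict like `RETRY_POLICY` plugs into tenacity's `Retrying`/`AsyncRetrying`: exceptions excluded by `retry_if_not_exception_type` fail immediately, everything else is retried with jittered exponential backoff. The `flaky_call` function and the `ValueError` stand-in for the non-retryable errors are hypothetical, not part of this PR.

```python
# Sketch only: `flaky_call` and ValueError are stand-ins; the policy mirrors
# the shape of RETRY_POLICY above, not the exact exception list.
from tenacity import (
    RetryError,
    Retrying,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

POLICY = dict(
    # exceptions listed here are NOT retried (like ValidationError / auth errors above)
    retry=retry_if_not_exception_type((ValueError,)),
    stop=stop_after_attempt(3),
    wait=wait_random_exponential(multiplier=1, max=60),
)

calls = {"n": 0}


def flaky_call() -> str:
    # Pretend this is an LLM request that times out twice, then succeeds.
    calls["n"] += 1
    if calls["n"] < 3:
        raise TimeoutError("transient failure")
    return "ok"


try:
    for attempt in Retrying(**POLICY):
        with attempt:
            result = flaky_call()
    print(result)  # "ok" on the third attempt
except RetryError:
    print("gave up after 3 attempts")
```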
```diff
@@ -59,6 +84,37 @@ def _format_error_dict(e: Exception) -> dict:
     return error_dct
 
 
+def _log_llm_exception(e) -> dict:
+    dct = _format_error_dict(e)
+    base_error = f"Inference error {dct['_adala_message']}"
+    tb = traceback.format_exc()
+    logger.error(f"{base_error}\nTraceback:\n{tb}")
+    return dct
+
+
+def _get_usage_dict(usage: Usage, model: str) -> Dict:
+    data = dict()
+    data["_prompt_tokens"] = usage.prompt_tokens
+    # will not exist if there is no completion
+    data["_completion_tokens"] = usage.get("completion_tokens", 0)
+    # can't use litellm.completion_cost bc it only takes the most recent completion, and .usage is summed over retries
+    # TODO make sure this is calculated correctly after we turn on caching
+    # litellm will register the cost of an azure model on first successful completion. If there hasn't been a successful completion, the model will not be registered
+    try:
+        prompt_cost, completion_cost = litellm.cost_per_token(
+            model, usage.prompt_tokens, usage.get("completion_tokens", 0)
+        )
+        data["_prompt_cost_usd"] = prompt_cost
+        data["_completion_cost_usd"] = completion_cost
+        data["_total_cost_usd"] = prompt_cost + completion_cost
+    except NotFoundError:
+        logger.error(f"Failed to get cost for model {model}")
+        data["_prompt_cost_usd"] = None
+        data["_completion_cost_usd"] = None
+        data["_total_cost_usd"] = None
+    return data
+
+
 class LiteLLMChatRuntime(Runtime):
     """
     Runtime that uses [LiteLLM API](https://litellm.vercel.app/docs) and chat
```
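For a sense of what `_get_usage_dict` emits, a small sketch using `litellm.cost_per_token` directly; the model name and token counts below are arbitrary, and actual prices come from litellm's pricing table.

```python
# Illustrative only: made-up token counts; the keys mirror the ones
# _get_usage_dict adds to each output record.
import litellm

model = "gpt-3.5-turbo"
prompt_tokens, completion_tokens = 1200, 250

prompt_cost, completion_cost = litellm.cost_per_token(
    model, prompt_tokens, completion_tokens
)

usage_fields = {
    "_prompt_tokens": prompt_tokens,
    "_completion_tokens": completion_tokens,
    "_prompt_cost_usd": prompt_cost,
    "_completion_cost_usd": completion_cost,
    "_total_cost_usd": prompt_cost + completion_cost,
}
print(usage_fields)
```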
```diff
@@ -173,45 +229,59 @@ def record_to_record(
             instructions_first,
         )
 
-        retries = Retrying(
-            retry=retry_if_not_exception_type((ValidationError)),
-            stop=stop_after_attempt(3),
-        )
+        retries = Retrying(**RETRY_POLICY)
 
         try:
             # returns a pydantic model named Output
-            response = instructor_client.chat.completions.create(
-                messages=messages,
-                response_model=response_model,
-                model=self.model,
-                max_tokens=self.max_tokens,
-                temperature=self.temperature,
-                seed=self.seed,
-                max_retries=retries,
-                # extra inference params passed to this runtime
-                **self.model_extra,
+            response, completion = (
+                instructor_client.chat.completions.create_with_completion(
+                    messages=messages,
+                    response_model=response_model,
+                    model=self.model,
+                    max_tokens=self.max_tokens,
+                    temperature=self.temperature,
+                    seed=self.seed,
+                    max_retries=retries,
+                    # extra inference params passed to this runtime
+                    **self.model_extra,
+                )
             )
+            usage = completion.usage
+            dct = response.dict()
+        except IncompleteOutputException as e:
+            usage = e.total_usage
+            dct = _log_llm_exception(e)
         except InstructorRetryException as e:
+            usage = e.total_usage
             # get root cause error from retries
             n_attempts = e.n_attempts
             e = e.__cause__.last_attempt.exception()
-            dct = _format_error_dict(e)
-            print_error(f"Inference error {dct['_adala_message']} after {n_attempts=}")
-            tb = traceback.format_exc()
-            logger.debug(tb)
-            return dct
+            dct = _log_llm_exception(e)
         except Exception as e:
+            # usage = e.total_usage
+            # not available here, so have to approximate by hand, assuming the same error occurred each time
+            n_attempts = retries.stop.max_attempt_number
+            prompt_tokens = n_attempts * litellm.token_counter(
+                model=self.model, messages=messages[:-1]
+            )  # response is appended as the last message
+            # TODO a pydantic validation error may be appended as the last message, don't know how to get the raw response in this case
+            completion_tokens = 0
+            usage = Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=(prompt_tokens + completion_tokens),
+            )
+
             # Catch case where the model does not return a properly formatted output
             if type(e).__name__ == "ValidationError" and "Invalid JSON" in str(e):
                 e = ConstrainedGenerationError()
-            # the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
-            dct = _format_error_dict(e)
-            print_error(f"Inference error {dct['_adala_message']}")
-            tb = traceback.format_exc()
-            logger.debug(tb)
-            return dct
+            # there are no other known errors to catch
+            dct = _log_llm_exception(e)
 
-        return response.dict()
+        # Add usage data to the response (e.g. token counts, cost)
+        dct.update(_get_usage_dict(usage, model=self.model))
+
+        return dct
 
 
 class AsyncLiteLLMChatRuntime(AsyncRuntime):
```
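Because every record, successful or failed, now carries the usage fields, a prompt-level token/cost KPI becomes a simple sum over the batch. A hedged sketch of that downstream aggregation (the records below are fabricated and the aggregation step itself is hypothetical, not code from this PR):

```python
# Hypothetical downstream aggregation over records shaped like the ones
# record_to_record / batch_to_batch now return.
import pandas as pd

records = [
    {"label": "A", "_prompt_tokens": 900, "_completion_tokens": 40,
     "_prompt_cost_usd": 0.00045, "_completion_cost_usd": 0.00006,
     "_total_cost_usd": 0.00051},
    {"_adala_message": "ConstrainedGenerationError",  # failed record, usage still attached
     "_prompt_tokens": 2700, "_completion_tokens": 0,
     "_prompt_cost_usd": 0.00135, "_completion_cost_usd": 0.0,
     "_total_cost_usd": 0.00135},
]

df = pd.DataFrame(records)
total_tokens = int((df["_prompt_tokens"] + df["_completion_tokens"]).sum())
total_cost = df["_total_cost_usd"].sum()  # NaN-safe if cost lookup failed for a model
print(f"{total_tokens} tokens, ${total_cost:.5f} total")
```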
```diff
@@ -304,14 +374,11 @@ async def batch_to_batch(
             axis=1,
         ).tolist()
 
-        retries = AsyncRetrying(
-            retry=retry_if_not_exception_type((ValidationError)),
-            stop=stop_after_attempt(3),
-        )
+        retries = AsyncRetrying(**RETRY_POLICY)
 
         tasks = [
             asyncio.ensure_future(
-                async_instructor_client.chat.completions.create(
+                async_instructor_client.chat.completions.create_with_completion(
                     messages=get_messages(
                         user_prompt,
                         instructions_template,
@@ -334,31 +401,48 @@ async def batch_to_batch(
         # convert list of LLMResponse objects to the dataframe records
         df_data = []
         for response in responses:
-            if isinstance(response, InstructorRetryException):
+            if isinstance(response, IncompleteOutputException):
                 e = response
+                usage = e.total_usage
+                dct = _log_llm_exception(e)
+            elif isinstance(response, InstructorRetryException):
+                e = response
+                usage = e.total_usage
                 # get root cause error from retries
                 n_attempts = e.n_attempts
                 e = e.__cause__.last_attempt.exception()
-                dct = _format_error_dict(e)
-                print_error(
-                    f"Inference error {dct['_adala_message']} after {n_attempts=}"
-                )
-                tb = traceback.format_exc()
-                logger.debug(tb)
-                df_data.append(dct)
+                dct = _log_llm_exception(e)
             elif isinstance(response, Exception):
                 e = response
+                # usage = e.total_usage
+                # not available here, so have to approximate by hand, assuming the same error occurred each time
+                n_attempts = retries.stop.max_attempt_number
+                messages = []  # TODO how to get these?
+                prompt_tokens = n_attempts * litellm.token_counter(
+                    model=self.model, messages=messages[:-1]
+                )  # response is appended as the last message
+                # TODO a pydantic validation error may be appended as the last message, don't know how to get the raw response in this case
+                completion_tokens = 0
+                usage = Usage(
+                    prompt_tokens,
+                    completion_tokens,
+                    total_tokens=(prompt_tokens + completion_tokens),
+                )
+
                 # Catch case where the model does not return a properly formatted output
                 if type(e).__name__ == "ValidationError" and "Invalid JSON" in str(e):
                     e = ConstrainedGenerationError()
                 # the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
-                dct = _format_error_dict(e)
-                print_error(f"Inference error {dct['_adala_message']}")
-                tb = traceback.format_exc()
-                logger.debug(tb)
-                df_data.append(dct)
+                dct = _log_llm_exception(e)
             else:
-                df_data.append(response.dict())
+                resp, completion = response
+                usage = completion.usage
+                dct = resp.dict()
+
+            # Add usage data to the response (e.g. token counts, cost)
+            dct.update(_get_usage_dict(usage, model=self.model))
+
+            df_data.append(dct)
 
         output_df = InternalDataFrame(df_data)
         return output_df.set_index(batch.index)
```
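The per-response dispatch above assumes the tasks are gathered with exceptions returned as values rather than raised (the gather call itself is not shown in these hunks), so each element is either a `(response, completion)` tuple or an exception object. A stand-alone sketch of that pattern, with a hypothetical `fake_request` standing in for `create_with_completion`:

```python
# Sketch of the gather-then-dispatch pattern batch_to_batch relies on;
# `fake_request` is a made-up stand-in, and return_exceptions=True is an
# assumption about how the tasks are awaited.
import asyncio


async def fake_request(i: int):
    if i == 1:
        raise TimeoutError("simulated failure")
    return {"output": f"label_{i}"}, {"usage": {"prompt_tokens": 10 * i}}


async def main() -> None:
    tasks = [asyncio.ensure_future(fake_request(i)) for i in range(3)]
    responses = await asyncio.gather(*tasks, return_exceptions=True)
    for response in responses:
        if isinstance(response, Exception):
            # handled like the exception branches in the loop above
            print("error:", response)
        else:
            resp, completion = response  # success: unpack the (response, completion) tuple
            print("ok:", resp, completion)


asyncio.run(main())
```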

docker-compose.yml (+1 −1)

```diff
@@ -45,7 +45,7 @@ services:
         condition: service_healthy
     environment:
       - REDIS_URL=redis://redis:6379/0
-      - MODULE_NAME=process_file.app
+      - MODULE_NAME=stream_inference.app
      - KAFKA_BOOTSTRAP_SERVERS=kafka:9093
      - LOG_LEVEL=DEBUG
      - C_FORCE_ROOT=true # needed when using pickle serializer in celery + running as root - remove when we dont run as root
```

poetry.lock (+24 −24)

(Generated file; diff not rendered.)

pyproject.toml (+1 −1)

```diff
@@ -47,7 +47,7 @@ instructor = "^1.3.7"
 [tool.poetry.dev-dependencies]
 pytest = "^7.4.3"
 pytest-cov = "^4.1.0"
-black = "^24.3.0"
+black = "^24.8.0"
 pytest-black = "^0.3.12"
 mkdocs = "^1.5.3"
 mkdocs-jupyter = "^0.24.3"
```

server/README.md (+1 −1)

````diff
@@ -25,7 +25,7 @@ poetry run uvicorn app:app --host 0.0.0.0 --port 30001
 
 ```bash
 cd tasks/
-poetry run celery -A process_file worker --loglevel=info
+poetry run celery -A stream_inference worker --loglevel=info
 ```
 
 # run in Docker
````

server/app.py (+1 −1)

```diff
@@ -19,7 +19,7 @@
 
 from server.handlers.result_handlers import ResultHandler
 from server.log_middleware import LogMiddleware
-from server.tasks.process_file import streaming_parent_task
+from server.tasks.stream_inference import streaming_parent_task
 from server.utils import (
     Settings,
     delete_topic,
```
