import asyncio
import logging
- from typing import Any, Dict, List, Optional, Union, Type
+ from typing import Any, Dict, List, Optional, Type

import litellm
- from litellm.exceptions import AuthenticationError
+ from litellm.exceptions import (
+     AuthenticationError,
+     ContentPolicyViolationError,
+     BadRequestError,
+     NotFoundError,
+ )
+ from litellm.types.utils import Usage
import instructor
- from instructor.exceptions import InstructorRetryException
+ from instructor.exceptions import InstructorRetryException, IncompleteOutputException
import traceback
from adala.utils.exceptions import ConstrainedGenerationError
from adala.utils.internal_data import InternalDataFrame
from adala.utils.logs import print_error
from adala.utils.parse import (
    parse_template,
    partial_str_format,
)
17
- from openai import NotFoundError
18
23
from pydantic import ConfigDict , field_validator , BaseModel
19
24
from rich import print
20
25
from tenacity import (
21
26
AsyncRetrying ,
22
27
Retrying ,
23
28
retry_if_not_exception_type ,
24
29
stop_after_attempt ,
30
+ wait_random_exponential ,
25
31
)
26
32
from pydantic_core ._pydantic_core import ValidationError
27
33
33
39
logger = logging .getLogger (__name__ )
34
40
35
41
+ # basically only retrying on timeout, incomplete output, or rate limit
+ # https://docs.litellm.ai/docs/exception_mapping#custom-mapping-list
+ # NOTE: token usage is only correctly calculated if we only use instructor retries, not litellm retries
+ # https://github.com/jxnl/instructor/pull/763
+ RETRY_POLICY = dict(
+     retry=retry_if_not_exception_type(
+         (
+             ValidationError,
+             ContentPolicyViolationError,
+             AuthenticationError,
+             BadRequestError,
+         )
+     ),
+     # should stop earlier on ValidationError and later on other errors, but couldn't figure out how to do that cleanly
+     stop=stop_after_attempt(3),
+     wait=wait_random_exponential(multiplier=1, max=60),
+ )
+
+
def get_messages(
    user_prompt: str,
    system_prompt: Optional[str] = None,
@@ -59,6 +84,37 @@ def _format_error_dict(e: Exception) -> dict:
    return error_dct


+ def _log_llm_exception(e) -> dict:
+     dct = _format_error_dict(e)
+     base_error = f"Inference error {dct['_adala_message']}"
+     tb = traceback.format_exc()
+     logger.error(f"{base_error}\nTraceback:\n{tb}")
+     return dct
+
+
+ def _get_usage_dict(usage: Usage, model: str) -> Dict:
+     data = dict()
+     data["_prompt_tokens"] = usage.prompt_tokens
+     # will not exist if there is no completion
+     data["_completion_tokens"] = usage.get("completion_tokens", 0)
+     # can't use litellm.completion_cost bc it only takes the most recent completion, and .usage is summed over retries
+     # TODO make sure this is calculated correctly after we turn on caching
+     # litellm will register the cost of an azure model on first successful completion. If there hasn't been a successful completion, the model will not be registered
+     try:
+         prompt_cost, completion_cost = litellm.cost_per_token(
+             model, usage.prompt_tokens, usage.get("completion_tokens", 0)
+         )
+         data["_prompt_cost_usd"] = prompt_cost
+         data["_completion_cost_usd"] = completion_cost
+         data["_total_cost_usd"] = prompt_cost + completion_cost
+     except NotFoundError:
+         logger.error(f"Failed to get cost for model {model}")
+         data["_prompt_cost_usd"] = None
+         data["_completion_cost_usd"] = None
+         data["_total_cost_usd"] = None
+     return data
+
+
class LiteLLMChatRuntime(Runtime):
    """
    Runtime that uses [LiteLLM API](https://litellm.vercel.app/docs) and chat
@@ -173,45 +229,59 @@ def record_to_record(
            instructions_first,
        )

-         retries = Retrying(
-             retry=retry_if_not_exception_type((ValidationError)),
-             stop=stop_after_attempt(3),
-         )
+         retries = Retrying(**RETRY_POLICY)

        try:
            # returns a pydantic model named Output
-             response = instructor_client.chat.completions.create(
-                 messages=messages,
-                 response_model=response_model,
-                 model=self.model,
-                 max_tokens=self.max_tokens,
-                 temperature=self.temperature,
-                 seed=self.seed,
-                 max_retries=retries,
-                 # extra inference params passed to this runtime
-                 **self.model_extra,
+             response, completion = (
+                 instructor_client.chat.completions.create_with_completion(
+                     messages=messages,
+                     response_model=response_model,
+                     model=self.model,
+                     max_tokens=self.max_tokens,
+                     temperature=self.temperature,
+                     seed=self.seed,
+                     max_retries=retries,
+                     # extra inference params passed to this runtime
+                     **self.model_extra,
+                 )
            )
+             usage = completion.usage
+             dct = response.dict()
+         except IncompleteOutputException as e:
+             usage = e.total_usage
+             dct = _log_llm_exception(e)
        except InstructorRetryException as e:
+             usage = e.total_usage
            # get root cause error from retries
            n_attempts = e.n_attempts
            e = e.__cause__.last_attempt.exception()
-             dct = _format_error_dict(e)
-             print_error(f"Inference error {dct['_adala_message']} after {n_attempts=}")
-             tb = traceback.format_exc()
-             logger.debug(tb)
-             return dct
+             dct = _log_llm_exception(e)
        except Exception as e:
+             # usage = e.total_usage
+             # not available here, so have to approximate by hand, assuming the same error occurred each time
+             n_attempts = retries.stop.max_attempt_number
+             prompt_tokens = n_attempts * litellm.token_counter(
+                 model=self.model, messages=messages[:-1]
+             )  # response is appended as the last message
+             # TODO a pydantic validation error may be appended as the last message, don't know how to get the raw response in this case
+             completion_tokens = 0
+             usage = Usage(
+                 prompt_tokens=prompt_tokens,
+                 completion_tokens=completion_tokens,
+                 total_tokens=(prompt_tokens + completion_tokens),
+             )
+
            # Catch case where the model does not return a properly formatted output
            if type(e).__name__ == "ValidationError" and "Invalid JSON" in str(e):
                e = ConstrainedGenerationError()
-             # the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
-             dct = _format_error_dict(e)
-             print_error(f"Inference error {dct['_adala_message']}")
-             tb = traceback.format_exc()
-             logger.debug(tb)
-             return dct
+             # there are no other known errors to catch
+             dct = _log_llm_exception(e)

-         return response.dict()
+         # Add usage data to the response (e.g. token counts, cost)
+         dct.update(_get_usage_dict(usage, model=self.model))
+
+         return dct


class AsyncLiteLLMChatRuntime(AsyncRuntime):
@@ -304,14 +374,11 @@ async def batch_to_batch(
            axis=1,
        ).tolist()

-         retries = AsyncRetrying(
-             retry=retry_if_not_exception_type((ValidationError)),
-             stop=stop_after_attempt(3),
-         )
+         retries = AsyncRetrying(**RETRY_POLICY)

        tasks = [
            asyncio.ensure_future(
-                 async_instructor_client.chat.completions.create(
+                 async_instructor_client.chat.completions.create_with_completion(
                    messages=get_messages(
                        user_prompt,
                        instructions_template,
@@ -334,31 +401,48 @@ async def batch_to_batch(
        # convert list of LLMResponse objects to the dataframe records
        df_data = []
        for response in responses:
-             if isinstance(response, InstructorRetryException):
+             if isinstance(response, IncompleteOutputException):
                e = response
+                 usage = e.total_usage
+                 dct = _log_llm_exception(e)
+             elif isinstance(response, InstructorRetryException):
+                 e = response
+                 usage = e.total_usage
                # get root cause error from retries
                n_attempts = e.n_attempts
                e = e.__cause__.last_attempt.exception()
-                 dct = _format_error_dict(e)
-                 print_error(
-                     f"Inference error {dct['_adala_message']} after {n_attempts=}"
-                 )
-                 tb = traceback.format_exc()
-                 logger.debug(tb)
-                 df_data.append(dct)
+                 dct = _log_llm_exception(e)
            elif isinstance(response, Exception):
                e = response
+                 # usage = e.total_usage
+                 # not available here, so have to approximate by hand, assuming the same error occurred each time
+                 n_attempts = retries.stop.max_attempt_number
+                 messages = []  # TODO how to get these?
+                 prompt_tokens = n_attempts * litellm.token_counter(
+                     model=self.model, messages=messages[:-1]
+                 )  # response is appended as the last message
+                 # TODO a pydantic validation error may be appended as the last message, don't know how to get the raw response in this case
+                 completion_tokens = 0
+                 usage = Usage(
+                     prompt_tokens,
+                     completion_tokens,
+                     total_tokens=(prompt_tokens + completion_tokens),
+                 )
+
                # Catch case where the model does not return a properly formatted output
                if type(e).__name__ == "ValidationError" and "Invalid JSON" in str(e):
                    e = ConstrainedGenerationError()
                # the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
-                 dct = _format_error_dict(e)
-                 print_error(f"Inference error {dct['_adala_message']}")
-                 tb = traceback.format_exc()
-                 logger.debug(tb)
-                 df_data.append(dct)
+                 dct = _log_llm_exception(e)
            else:
-                 df_data.append(response.dict())
+                 resp, completion = response
+                 usage = completion.usage
+                 dct = resp.dict()
+
+             # Add usage data to the response (e.g. token counts, cost)
+             dct.update(_get_usage_dict(usage, model=self.model))
+
+             df_data.append(dct)

        output_df = InternalDataFrame(df_data)
        return output_df.set_index(batch.index)
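
For reference, a minimal sketch of how the shared RETRY_POLICY and create_with_completion fit together. It assumes instructor.from_litellm is available in the installed instructor version and that RETRY_POLICY is imported from the module changed above; the Output schema, prompt, and model name are placeholder examples, not part of this change.

# Sketch only: illustrates the retry + usage pattern introduced in this diff.
import litellm
import instructor
from pydantic import BaseModel
from tenacity import Retrying


class Output(BaseModel):
    sentiment: str


# instructor client wrapping litellm, as the runtimes above are assumed to do
client = instructor.from_litellm(litellm.completion)

# build the tenacity retrier from the shared policy dict
retries = Retrying(**RETRY_POLICY)

# create_with_completion returns the parsed pydantic object plus the raw completion,
# so completion.usage stays available for _get_usage_dict even across instructor retries
response, completion = client.chat.completions.create_with_completion(
    messages=[{"role": "user", "content": "Classify the sentiment: 'great product!'"}],
    response_model=Output,
    model="gpt-4o-mini",  # placeholder model name
    max_retries=retries,
)
print(response.sentiment, completion.usage)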
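
Similarly, a small sketch of the per-token cost lookup that _get_usage_dict wraps via litellm.cost_per_token; the model name and token counts are arbitrary example values.

# Sketch only: litellm.cost_per_token returns a (prompt_cost_usd, completion_cost_usd) tuple.
import litellm

prompt_cost, completion_cost = litellm.cost_per_token(
    model="gpt-3.5-turbo",  # example model; unmapped models raise NotFoundError, as handled above
    prompt_tokens=1200,
    completion_tokens=300,
)
print(prompt_cost, completion_cost, prompt_cost + completion_cost)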