Skip to content

Commit 8a235e7

Browse files
authored
(Refactor / QA) - Use LoggingCallbackManager to append callbacks and ensure no duplicate callbacks are added (#8112)
* LoggingCallbackManager * add logging_callback_manager * use logging_callback_manager * add add_litellm_failure_callback * use add_litellm_callback * use add_litellm_async_success_callback * add_litellm_async_failure_callback * linting fix * fix logging callback manager * test_duplicate_multiple_loggers_test * use _reset_all_callbacks * fix testing with dup callbacks * test_basic_image_generation * reset callbacks for tests * fix check for _add_custom_logger_to_list * fix test_amazing_sync_embedding * fix _get_custom_logger_key * fix batches testing * fix _reset_all_callbacks * fix _check_callback_list_size * add callback_manager_test * fix test gemini-2.0-flash-thinking-exp-01-21
1 parent 3eac163 commit 8a235e7

File tree

19 files changed

+607
-59
lines changed

19 files changed

+607
-59
lines changed

.circleci/config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,7 @@ jobs:
994994
- run: ruff check ./litellm
995995
# - run: python ./tests/documentation_tests/test_general_setting_keys.py
996996
- run: python ./tests/code_coverage_tests/router_code_coverage.py
997+
- run: python ./tests/code_coverage_tests/callback_manager_test.py
997998
- run: python ./tests/code_coverage_tests/recursive_detector.py
998999
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
9991000
- run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py

litellm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
4040
from litellm.integrations.custom_logger import CustomLogger
41+
from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
4142
import httpx
4243
import dotenv
4344
from enum import Enum
@@ -50,6 +51,7 @@
5051
_turn_on_debug()
5152
###############################################
5253
### Callbacks /Logging / Success / Failure Handlers #####
54+
logging_callback_manager = LoggingCallbackManager()
5355
input_callback: List[Union[str, Callable, CustomLogger]] = []
5456
success_callback: List[Union[str, Callable, CustomLogger]] = []
5557
failure_callback: List[Union[str, Callable, CustomLogger]] = []

litellm/caching/caching.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,9 @@ def __init__(
207207
if "cache" not in litellm.input_callback:
208208
litellm.input_callback.append("cache")
209209
if "cache" not in litellm.success_callback:
210-
litellm.success_callback.append("cache")
210+
litellm.logging_callback_manager.add_litellm_success_callback("cache")
211211
if "cache" not in litellm._async_success_callback:
212-
litellm._async_success_callback.append("cache")
212+
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
213213
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
214214
self.type = type
215215
self.namespace = namespace
@@ -774,9 +774,9 @@ def enable_cache(
774774
if "cache" not in litellm.input_callback:
775775
litellm.input_callback.append("cache")
776776
if "cache" not in litellm.success_callback:
777-
litellm.success_callback.append("cache")
777+
litellm.logging_callback_manager.add_litellm_success_callback("cache")
778778
if "cache" not in litellm._async_success_callback:
779-
litellm._async_success_callback.append("cache")
779+
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
780780

781781
if litellm.cache is None:
782782
litellm.cache = Cache(
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
from typing import Callable, List, Union
2+
3+
import litellm
4+
from litellm._logging import verbose_logger
5+
from litellm.integrations.custom_logger import CustomLogger
6+
7+
8+
class LoggingCallbackManager:
9+
"""
10+
A centralized class that allows easy add / remove callbacks for litellm.
11+
12+
Goals of this class:
13+
- Prevent adding duplicate callbacks / success_callback / failure_callback
14+
- Keep a reasonable MAX_CALLBACKS limit (this ensures callbacks don't exponentially grow and consume CPU Resources)
15+
"""
16+
17+
# healthy maximum number of callbacks - unlikely someone needs more than 20
18+
MAX_CALLBACKS = 30
19+
20+
def add_litellm_input_callback(self, callback: Union[CustomLogger, str]):
21+
"""
22+
Add a input callback to litellm.input_callback
23+
"""
24+
self._safe_add_callback_to_list(
25+
callback=callback, parent_list=litellm.input_callback
26+
)
27+
28+
def add_litellm_service_callback(
29+
self, callback: Union[CustomLogger, str, Callable]
30+
):
31+
"""
32+
Add a service callback to litellm.service_callback
33+
"""
34+
self._safe_add_callback_to_list(
35+
callback=callback, parent_list=litellm.service_callback
36+
)
37+
38+
def add_litellm_callback(self, callback: Union[CustomLogger, str, Callable]):
39+
"""
40+
Add a callback to litellm.callbacks
41+
42+
Ensures no duplicates are added.
43+
"""
44+
self._safe_add_callback_to_list(
45+
callback=callback, parent_list=litellm.callbacks # type: ignore
46+
)
47+
48+
def add_litellm_success_callback(
49+
self, callback: Union[CustomLogger, str, Callable]
50+
):
51+
"""
52+
Add a success callback to `litellm.success_callback`
53+
"""
54+
self._safe_add_callback_to_list(
55+
callback=callback, parent_list=litellm.success_callback
56+
)
57+
58+
def add_litellm_failure_callback(
59+
self, callback: Union[CustomLogger, str, Callable]
60+
):
61+
"""
62+
Add a failure callback to `litellm.failure_callback`
63+
"""
64+
self._safe_add_callback_to_list(
65+
callback=callback, parent_list=litellm.failure_callback
66+
)
67+
68+
def add_litellm_async_success_callback(
69+
self, callback: Union[CustomLogger, Callable, str]
70+
):
71+
"""
72+
Add a success callback to litellm._async_success_callback
73+
"""
74+
self._safe_add_callback_to_list(
75+
callback=callback, parent_list=litellm._async_success_callback
76+
)
77+
78+
def add_litellm_async_failure_callback(
79+
self, callback: Union[CustomLogger, Callable, str]
80+
):
81+
"""
82+
Add a failure callback to litellm._async_failure_callback
83+
"""
84+
self._safe_add_callback_to_list(
85+
callback=callback, parent_list=litellm._async_failure_callback
86+
)
87+
88+
def _add_string_callback_to_list(
89+
self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]]
90+
):
91+
"""
92+
Add a string callback to a list, if the callback is already in the list, do not add it again.
93+
"""
94+
if callback not in parent_list:
95+
parent_list.append(callback)
96+
else:
97+
verbose_logger.debug(
98+
f"Callback {callback} already exists in {parent_list}, not adding again.."
99+
)
100+
101+
def _check_callback_list_size(
102+
self, parent_list: List[Union[CustomLogger, Callable, str]]
103+
) -> bool:
104+
"""
105+
Check if adding another callback would exceed MAX_CALLBACKS
106+
Returns True if safe to add, False if would exceed limit
107+
"""
108+
if len(parent_list) >= self.MAX_CALLBACKS:
109+
verbose_logger.warning(
110+
f"Cannot add callback - would exceed MAX_CALLBACKS limit of {self.MAX_CALLBACKS}. Current callbacks: {len(parent_list)}"
111+
)
112+
return False
113+
return True
114+
115+
def _safe_add_callback_to_list(
116+
self,
117+
callback: Union[CustomLogger, Callable, str],
118+
parent_list: List[Union[CustomLogger, Callable, str]],
119+
):
120+
"""
121+
Safe add a callback to a list, if the callback is already in the list, do not add it again.
122+
123+
Ensures no duplicates are added for `str`, `Callable`, and `CustomLogger` callbacks.
124+
"""
125+
# Check max callbacks limit first
126+
if not self._check_callback_list_size(parent_list):
127+
return
128+
129+
if isinstance(callback, str):
130+
self._add_string_callback_to_list(
131+
callback=callback, parent_list=parent_list
132+
)
133+
elif isinstance(callback, CustomLogger):
134+
self._add_custom_logger_to_list(
135+
custom_logger=callback,
136+
parent_list=parent_list,
137+
)
138+
elif callable(callback):
139+
self._add_callback_function_to_list(
140+
callback=callback, parent_list=parent_list
141+
)
142+
143+
def _add_callback_function_to_list(
144+
self, callback: Callable, parent_list: List[Union[CustomLogger, Callable, str]]
145+
):
146+
"""
147+
Add a callback function to a list, if the callback is already in the list, do not add it again.
148+
"""
149+
# Check if the function already exists in the list by comparing function objects
150+
if callback not in parent_list:
151+
parent_list.append(callback)
152+
else:
153+
verbose_logger.debug(
154+
f"Callback function {callback.__name__} already exists in {parent_list}, not adding again.."
155+
)
156+
157+
def _add_custom_logger_to_list(
158+
self,
159+
custom_logger: CustomLogger,
160+
parent_list: List[Union[CustomLogger, Callable, str]],
161+
):
162+
"""
163+
Add a custom logger to a list, if another instance of the same custom logger exists in the list, do not add it again.
164+
"""
165+
# Check if an instance of the same class already exists in the list
166+
custom_logger_key = self._get_custom_logger_key(custom_logger)
167+
custom_logger_type_name = type(custom_logger).__name__
168+
for existing_logger in parent_list:
169+
if (
170+
isinstance(existing_logger, CustomLogger)
171+
and self._get_custom_logger_key(existing_logger) == custom_logger_key
172+
):
173+
verbose_logger.debug(
174+
f"Custom logger of type {custom_logger_type_name}, key: {custom_logger_key} already exists in {parent_list}, not adding again.."
175+
)
176+
return
177+
parent_list.append(custom_logger)
178+
179+
def _get_custom_logger_key(self, custom_logger: CustomLogger):
180+
"""
181+
Get a unique key for a custom logger that considers only fundamental instance variables
182+
183+
Returns:
184+
str: A unique key combining the class name and fundamental instance variables (str, bool, int)
185+
"""
186+
key_parts = [type(custom_logger).__name__]
187+
188+
# Add only fundamental type instance variables to the key
189+
for attr_name, attr_value in vars(custom_logger).items():
190+
if not attr_name.startswith("_"): # Skip private attributes
191+
if isinstance(attr_value, (str, bool, int)):
192+
key_parts.append(f"{attr_name}={attr_value}")
193+
194+
return "-".join(key_parts)
195+
196+
def _reset_all_callbacks(self):
197+
"""
198+
Reset all callbacks to an empty list
199+
200+
Note: this is an internal function and should be used sparingly.
201+
"""
202+
litellm.input_callback = []
203+
litellm.success_callback = []
204+
litellm.failure_callback = []
205+
litellm._async_success_callback = []
206+
litellm._async_failure_callback = []
207+
litellm.callbacks = []

litellm/proxy/guardrails/guardrail_initializers.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def initialize_aporia(litellm_params, guardrail):
1313
event_hook=litellm_params["mode"],
1414
default_on=litellm_params["default_on"],
1515
)
16-
litellm.callbacks.append(_aporia_callback)
16+
litellm.logging_callback_manager.add_litellm_callback(_aporia_callback)
1717

1818

1919
def initialize_bedrock(litellm_params, guardrail):
@@ -28,7 +28,7 @@ def initialize_bedrock(litellm_params, guardrail):
2828
guardrailVersion=litellm_params["guardrailVersion"],
2929
default_on=litellm_params["default_on"],
3030
)
31-
litellm.callbacks.append(_bedrock_callback)
31+
litellm.logging_callback_manager.add_litellm_callback(_bedrock_callback)
3232

3333

3434
def initialize_lakera(litellm_params, guardrail):
@@ -42,7 +42,7 @@ def initialize_lakera(litellm_params, guardrail):
4242
category_thresholds=litellm_params.get("category_thresholds"),
4343
default_on=litellm_params["default_on"],
4444
)
45-
litellm.callbacks.append(_lakera_callback)
45+
litellm.logging_callback_manager.add_litellm_callback(_lakera_callback)
4646

4747

4848
def initialize_aim(litellm_params, guardrail):
@@ -55,7 +55,7 @@ def initialize_aim(litellm_params, guardrail):
5555
event_hook=litellm_params["mode"],
5656
default_on=litellm_params["default_on"],
5757
)
58-
litellm.callbacks.append(_aim_callback)
58+
litellm.logging_callback_manager.add_litellm_callback(_aim_callback)
5959

6060

6161
def initialize_presidio(litellm_params, guardrail):
@@ -71,7 +71,7 @@ def initialize_presidio(litellm_params, guardrail):
7171
mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
7272
default_on=litellm_params["default_on"],
7373
)
74-
litellm.callbacks.append(_presidio_callback)
74+
litellm.logging_callback_manager.add_litellm_callback(_presidio_callback)
7575

7676
if litellm_params["output_parse_pii"]:
7777
_success_callback = _OPTIONAL_PresidioPIIMasking(
@@ -81,7 +81,7 @@ def initialize_presidio(litellm_params, guardrail):
8181
presidio_ad_hoc_recognizers=litellm_params["presidio_ad_hoc_recognizers"],
8282
default_on=litellm_params["default_on"],
8383
)
84-
litellm.callbacks.append(_success_callback)
84+
litellm.logging_callback_manager.add_litellm_callback(_success_callback)
8585

8686

8787
def initialize_hide_secrets(litellm_params, guardrail):
@@ -93,7 +93,7 @@ def initialize_hide_secrets(litellm_params, guardrail):
9393
guardrail_name=guardrail["guardrail_name"],
9494
default_on=litellm_params["default_on"],
9595
)
96-
litellm.callbacks.append(_secret_detection_object)
96+
litellm.logging_callback_manager.add_litellm_callback(_secret_detection_object)
9797

9898

9999
def initialize_guardrails_ai(litellm_params, guardrail):
@@ -111,4 +111,4 @@ def initialize_guardrails_ai(litellm_params, guardrail):
111111
guardrail_name=SupportedGuardrailIntegrations.GURDRAILS_AI.value,
112112
default_on=litellm_params["default_on"],
113113
)
114-
litellm.callbacks.append(_guardrails_ai_callback)
114+
litellm.logging_callback_manager.add_litellm_callback(_guardrails_ai_callback)

litellm/proxy/guardrails/init_guardrails.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def init_guardrails_v2(
157157
event_hook=litellm_params["mode"],
158158
default_on=litellm_params["default_on"],
159159
)
160-
litellm.callbacks.append(_guardrail_callback) # type: ignore
160+
litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback) # type: ignore
161161
else:
162162
raise ValueError(f"Unsupported guardrail: {guardrail_type}")
163163

0 commit comments

Comments
 (0)