Skip to content

Commit e6ed971

Browse files
committed
fix: src/codeinterpreterapi/session.py for CallbackHandler
1 parent cf28909 commit e6ed971

File tree

1 file changed

+175
-25
lines changed

1 file changed

+175
-25
lines changed

src/codeinterpreterapi/session.py

Lines changed: 175 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
import re
22
import traceback
33
from types import TracebackType
4-
from typing import Any, AsyncGenerator, Dict, Iterator, List, Optional, Type, Union
4+
from typing import Any, AsyncGenerator, Dict, Iterator, List, Optional, Type
55
from uuid import UUID
66

77
from codeboxapi import CodeBox # type: ignore
88
from codeboxapi.schema import CodeBoxStatus # type: ignore
99
from gui_agent_loop_core.schema.message.schema import BaseMessageContent
1010
from langchain.callbacks.base import Callbacks
11-
12-
from langchain_core.callbacks import BaseCallbackHandler
1311
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
1412
from langchain_community.chat_message_histories.postgres import PostgresChatMessageHistory
1513
from langchain_community.chat_message_histories.redis import RedisChatMessageHistory
1614
from langchain_core.agents import AgentAction, AgentFinish
15+
from langchain_core.callbacks import BaseCallbackHandler
1716
from langchain_core.chat_history import BaseChatMessageHistory
1817
from langchain_core.language_models import BaseLanguageModel
1918
from langchain_core.messages.base import BaseMessage
@@ -44,18 +43,71 @@ class AgentCallbackHandler(BaseCallbackHandler):
4443
"""Base callback handler that can be used to handle callbacks from langchain."""
4544

4645
def __init__(self, agent_callback_func: callable):
46+
print("AgentCallbackHandler __init__ agent_callback_func=", agent_callback_func)
4747
self.agent_callback_func = agent_callback_func
4848

49-
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
50-
"""Run when chain starts running."""
51-
# print("AgentCallbackHandler on_chain_start")
52-
53-
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
54-
"""Run when chain ends running."""
55-
# print("AgentCallbackHandler on_chain_end type(outputs)=", type(outputs))
56-
# print("AgentCallbackHandler on_chain_end type(outputs)=", outputs)
49+
### on_chain callbacks ###
50+
def on_chain_start(
51+
self,
52+
serialized: dict[str, Any],
53+
inputs: dict[str, Any],
54+
*,
55+
run_id: UUID,
56+
parent_run_id: Optional[UUID] = None,
57+
tags: Optional[list[str]] = None,
58+
metadata: Optional[dict[str, Any]] = None,
59+
**kwargs: Any,
60+
) -> Any:
61+
"""Run when a chain starts running.
62+
63+
Args:
64+
serialized (Dict[str, Any]): The serialized chain.
65+
inputs (Dict[str, Any]): The inputs.
66+
run_id (UUID): The run ID. This is the ID of the current run.
67+
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
68+
tags (Optional[List[str]]): The tags.
69+
metadata (Optional[Dict[str, Any]]): The metadata.
70+
kwargs (Any): Additional keyword arguments.
71+
"""
72+
print("AgentCallbackHandler on_chain_start run_id=", run_id)
73+
74+
def on_chain_end(
75+
self,
76+
outputs: dict[str, Any],
77+
*,
78+
run_id: UUID,
79+
parent_run_id: Optional[UUID] = None,
80+
**kwargs: Any,
81+
) -> Any:
82+
"""Run when chain ends running.
83+
84+
Args:
85+
outputs (Dict[str, Any]): The outputs of the chain.
86+
run_id (UUID): The run ID. This is the ID of the current run.
87+
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
88+
kwargs (Any): Additional keyword arguments."""
89+
print("AgentCallbackHandler on_chain_end run_id=", run_id, ", type(outputs)=", type(outputs))
90+
print("AgentCallbackHandler on_chain_end self.agent_callback_func=", self.agent_callback_func)
5791
self.agent_callback_func(outputs)
5892

93+
def on_chain_error(
94+
self,
95+
error: BaseException,
96+
*,
97+
run_id: UUID,
98+
parent_run_id: Optional[UUID] = None,
99+
**kwargs: Any,
100+
) -> Any:
101+
"""Run when chain errors.
102+
103+
Args:
104+
error (BaseException): The error that occurred.
105+
run_id (UUID): The run ID. This is the ID of the current run.
106+
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
107+
kwargs (Any): Additional keyword arguments."""
108+
print("AgentCallbackHandler on_chain_error")
109+
110+
### on_chat callbacks ###
59111
def on_chat_model_start(
60112
self,
61113
serialized: Dict[str, Any],
@@ -67,20 +119,115 @@ def on_chat_model_start(
67119
metadata: Optional[Dict[str, Any]] = None,
68120
**kwargs: Any,
69121
) -> Any:
70-
"""Run when chain starts running."""
71-
# print("AgentCallbackHandler on_chat_model_start")
122+
"""Run when a chain starts running.
123+
124+
Args:
125+
serialized (Dict[str, Any]): The serialized chain.
126+
inputs (Dict[str, Any]): The inputs.
127+
run_id (UUID): The run ID. This is the ID of the current run.
128+
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
129+
tags (Optional[List[str]]): The tags.
130+
metadata (Optional[Dict[str, Any]]): The metadata.
131+
kwargs (Any): Additional keyword arguments.
132+
"""
133+
print("AgentCallbackHandler on_chat_model_start")
134+
135+
### on_agent callbacks ###
136+
137+
def on_agent_action(
138+
self,
139+
action: AgentAction,
140+
*,
141+
run_id: UUID,
142+
parent_run_id: Optional[UUID] = None,
143+
**kwargs: Any,
144+
) -> Any:
145+
"""Run on agent action.
146+
147+
Args:
148+
action (AgentAction): The agent action.
149+
run_id (UUID): The run ID. This is the ID of the current run.
150+
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
151+
kwargs (Any): Additional keyword arguments."""
152+
print("AgentCallbackHandler on_agent_action")
72153

73-
def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
74-
"""Run when chain errors."""
75-
# print("AgentCallbackHandler on_chain_error")
154+
def on_agent_finish(
155+
self,
156+
finish: AgentFinish,
157+
*,
158+
run_id: UUID,
159+
parent_run_id: Optional[UUID] = None,
160+
**kwargs: Any,
161+
) -> Any:
162+
"""Run on the agent end.
76163
77-
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
78-
"""Run on agent action."""
79-
# print("AgentCallbackHandler on_agent_action")
164+
Args:
165+
finish (AgentFinish): The agent finish.
166+
run_id (UUID): The run ID. This is the ID of the current run.
167+
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
168+
kwargs (Any): Additional keyword arguments."""
169+
print("AgentCallbackHandler on_agent_finish")
170+
171+
### on_tool callbacks ###
172+
def on_tool_start(
173+
self,
174+
serialized: dict[str, Any],
175+
input_str: str,
176+
*,
177+
run_id: UUID,
178+
parent_run_id: Optional[UUID] = None,
179+
tags: Optional[list[str]] = None,
180+
metadata: Optional[dict[str, Any]] = None,
181+
inputs: Optional[dict[str, Any]] = None,
182+
**kwargs: Any,
183+
) -> Any:
184+
"""Run when the tool starts running.
185+
186+
Args:
187+
serialized (Dict[str, Any]): The serialized tool.
188+
input_str (str): The input string.
189+
run_id (UUID): The run ID. This is the ID of the current run.
190+
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
191+
tags (Optional[List[str]]): The tags.
192+
metadata (Optional[Dict[str, Any]]): The metadata.
193+
inputs (Optional[Dict[str, Any]]): The inputs.
194+
kwargs (Any): Additional keyword arguments.
195+
"""
196+
print("AgentCallbackHandler on_tool_start")
197+
198+
def on_tool_end(
199+
self,
200+
output: Any,
201+
*,
202+
run_id: UUID,
203+
parent_run_id: Optional[UUID] = None,
204+
**kwargs: Any,
205+
) -> Any:
206+
"""Run when the tool ends running.
207+
208+
Args:
209+
output (Any): The output of the tool.
210+
run_id (UUID): The run ID. This is the ID of the current run.
211+
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
212+
kwargs (Any): Additional keyword arguments."""
213+
print("AgentCallbackHandler on_tool_end")
214+
215+
def on_tool_error(
216+
self,
217+
error: BaseException,
218+
*,
219+
run_id: UUID,
220+
parent_run_id: Optional[UUID] = None,
221+
**kwargs: Any,
222+
) -> Any:
223+
"""Run when tool errors.
80224
81-
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
82-
"""Run on agent end."""
83-
# print("AgentCallbackHandler on_agent_finish")
225+
Args:
226+
error (BaseException): The error that occurred.
227+
run_id (UUID): The run ID. This is the ID of the current run.
228+
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
229+
kwargs (Any): Additional keyword arguments."""
230+
print("AgentCallbackHandler on_tool_error")
84231

85232

86233
class CodeInterpreterSession:
@@ -313,12 +460,16 @@ def _output_handler_post(self, final_response: str) -> CodeInterpreterResponse:
313460

314461
def _output_handler(self, response: Any) -> CodeInterpreterResponse:
315462
"""Embed images in the response"""
463+
print("XXXX _output_handler in response=", type(response))
316464
final_response = self._output_handler_pre(response)
465+
print("XXXX _output_handler step1 ")
317466
response = self._output_handler_post(final_response)
467+
print("XXXX _output_handler out ")
318468
return response
319469

320470
async def _aoutput_handler(self, response: str) -> CodeInterpreterResponse:
321471
"""Embed images in the response"""
472+
print("XXXX _aoutput_handler in response=", type(response))
322473
final_response = self._output_handler_pre(response)
323474
for file in self.output_files:
324475
if str(file.name) in final_response:
@@ -405,7 +556,7 @@ async def agenerate_response(
405556
traceback.print_exc()
406557
if settings.DETAILED_ERROR:
407558
return CodeInterpreterResponse(
408-
content="Error in CodeInterpreterSession(agenerate_response): " f"{e.__class__.__name__} - {e}",
559+
content=f"Error in CodeInterpreterSession(agenerate_response): {e.__class__.__name__} - {e}",
409560
agent_name=self.brain.current_agent,
410561
)
411562
else:
@@ -479,8 +630,7 @@ async def agenerate_response_stream(
479630
traceback.print_exc()
480631
if settings.DETAILED_ERROR:
481632
yield CodeInterpreterResponse(
482-
content="Error in CodeInterpreterSession(agenerate_response_stream): "
483-
f"{e.__class__.__name__} - {e}"
633+
content=f"Error in CodeInterpreterSession(agenerate_response_stream): {e.__class__.__name__} - {e}"
484634
)
485635
else:
486636
yield CodeInterpreterResponse(

0 commit comments

Comments
 (0)