11import re
22import traceback
33from types import TracebackType
4- from typing import Any , AsyncGenerator , Dict , Iterator , List , Optional , Type , Union
4+ from typing import Any , AsyncGenerator , Dict , Iterator , List , Optional , Type
55from uuid import UUID
66
77from codeboxapi import CodeBox # type: ignore
88from codeboxapi .schema import CodeBoxStatus # type: ignore
99from gui_agent_loop_core .schema .message .schema import BaseMessageContent
1010from langchain .callbacks .base import Callbacks
11-
12- from langchain_core .callbacks import BaseCallbackHandler
1311from langchain_community .chat_message_histories .in_memory import ChatMessageHistory
1412from langchain_community .chat_message_histories .postgres import PostgresChatMessageHistory
1513from langchain_community .chat_message_histories .redis import RedisChatMessageHistory
1614from langchain_core .agents import AgentAction , AgentFinish
15+ from langchain_core .callbacks import BaseCallbackHandler
1716from langchain_core .chat_history import BaseChatMessageHistory
1817from langchain_core .language_models import BaseLanguageModel
1918from langchain_core .messages .base import BaseMessage
@@ -44,18 +43,71 @@ class AgentCallbackHandler(BaseCallbackHandler):
4443 """Base callback handler that can be used to handle callbacks from langchain."""
4544
4645 def __init__ (self , agent_callback_func : callable ):
46+ print ("AgentCallbackHandler __init__ agent_callback_func=" , agent_callback_func )
4747 self .agent_callback_func = agent_callback_func
4848
49- def on_chain_start (self , serialized : Dict [str , Any ], inputs : Dict [str , Any ], ** kwargs : Any ) -> Any :
50- """Run when chain starts running."""
51- # print("AgentCallbackHandler on_chain_start")
52-
53- def on_chain_end (self , outputs : Dict [str , Any ], ** kwargs : Any ) -> Any :
54- """Run when chain ends running."""
55- # print("AgentCallbackHandler on_chain_end type(outputs)=", type(outputs))
56- # print("AgentCallbackHandler on_chain_end type(outputs)=", outputs)
49+ ### on_chain callbacks ###
50+ def on_chain_start (
51+ self ,
52+ serialized : dict [str , Any ],
53+ inputs : dict [str , Any ],
54+ * ,
55+ run_id : UUID ,
56+ parent_run_id : Optional [UUID ] = None ,
57+ tags : Optional [list [str ]] = None ,
58+ metadata : Optional [dict [str , Any ]] = None ,
59+ ** kwargs : Any ,
60+ ) -> Any :
61+ """Run when a chain starts running.
62+
63+ Args:
64+ serialized (Dict[str, Any]): The serialized chain.
65+ inputs (Dict[str, Any]): The inputs.
66+ run_id (UUID): The run ID. This is the ID of the current run.
67+ parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
68+ tags (Optional[List[str]]): The tags.
69+ metadata (Optional[Dict[str, Any]]): The metadata.
70+ kwargs (Any): Additional keyword arguments.
71+ """
72+ print ("AgentCallbackHandler on_chain_start run_id=" , run_id )
73+
74+ def on_chain_end (
75+ self ,
76+ outputs : dict [str , Any ],
77+ * ,
78+ run_id : UUID ,
79+ parent_run_id : Optional [UUID ] = None ,
80+ ** kwargs : Any ,
81+ ) -> Any :
82+ """Run when chain ends running.
83+
84+ Args:
85+ outputs (Dict[str, Any]): The outputs of the chain.
86+ run_id (UUID): The run ID. This is the ID of the current run.
87+ parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
88+ kwargs (Any): Additional keyword arguments."""
89+ print ("AgentCallbackHandler on_chain_end run_id=" , run_id , ", type(outputs)=" , type (outputs ))
90+ print ("AgentCallbackHandler on_chain_end self.agent_callback_func=" , self .agent_callback_func )
5791 self .agent_callback_func (outputs )
5892
93+ def on_chain_error (
94+ self ,
95+ error : BaseException ,
96+ * ,
97+ run_id : UUID ,
98+ parent_run_id : Optional [UUID ] = None ,
99+ ** kwargs : Any ,
100+ ) -> Any :
101+ """Run when chain errors.
102+
103+ Args:
104+ error (BaseException): The error that occurred.
105+ run_id (UUID): The run ID. This is the ID of the current run.
106+ parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
107+ kwargs (Any): Additional keyword arguments."""
108+ print ("AgentCallbackHandler on_chain_error" )
109+
110+ ### on_chat callbacks ###
59111 def on_chat_model_start (
60112 self ,
61113 serialized : Dict [str , Any ],
@@ -67,20 +119,115 @@ def on_chat_model_start(
67119 metadata : Optional [Dict [str , Any ]] = None ,
68120 ** kwargs : Any ,
69121 ) -> Any :
70- """Run when chain starts running."""
71- # print("AgentCallbackHandler on_chat_model_start")
122+ """Run when a chain starts running.
123+
124+ Args:
125+ serialized (Dict[str, Any]): The serialized chain.
126+ inputs (Dict[str, Any]): The inputs.
127+ run_id (UUID): The run ID. This is the ID of the current run.
128+ parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
129+ tags (Optional[List[str]]): The tags.
130+ metadata (Optional[Dict[str, Any]]): The metadata.
131+ kwargs (Any): Additional keyword arguments.
132+ """
133+ print ("AgentCallbackHandler on_chat_model_start" )
134+
135+ ### on_agent callbacks ###
136+
137+ def on_agent_action (
138+ self ,
139+ action : AgentAction ,
140+ * ,
141+ run_id : UUID ,
142+ parent_run_id : Optional [UUID ] = None ,
143+ ** kwargs : Any ,
144+ ) -> Any :
145+ """Run on agent action.
146+
147+ Args:
148+ action (AgentAction): The agent action.
149+ run_id (UUID): The run ID. This is the ID of the current run.
150+ parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
151+ kwargs (Any): Additional keyword arguments."""
152+ print ("AgentCallbackHandler on_agent_action" )
72153
73- def on_chain_error (self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any ) -> Any :
74- """Run when chain errors."""
75- # print("AgentCallbackHandler on_chain_error")
154+ def on_agent_finish (
155+ self ,
156+ finish : AgentFinish ,
157+ * ,
158+ run_id : UUID ,
159+ parent_run_id : Optional [UUID ] = None ,
160+ ** kwargs : Any ,
161+ ) -> Any :
162+ """Run on the agent end.
76163
77- def on_agent_action (self , action : AgentAction , ** kwargs : Any ) -> Any :
78- """Run on agent action."""
79- # print("AgentCallbackHandler on_agent_action")
164+ Args:
165+ finish (AgentFinish): The agent finish.
166+ run_id (UUID): The run ID. This is the ID of the current run.
167+ parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
168+ kwargs (Any): Additional keyword arguments."""
169+ print ("AgentCallbackHandler on_agent_finish" )
170+
171+ ### on_tool callbacks ###
172+ def on_tool_start (
173+ self ,
174+ serialized : dict [str , Any ],
175+ input_str : str ,
176+ * ,
177+ run_id : UUID ,
178+ parent_run_id : Optional [UUID ] = None ,
179+ tags : Optional [list [str ]] = None ,
180+ metadata : Optional [dict [str , Any ]] = None ,
181+ inputs : Optional [dict [str , Any ]] = None ,
182+ ** kwargs : Any ,
183+ ) -> Any :
184+ """Run when the tool starts running.
185+
186+ Args:
187+ serialized (Dict[str, Any]): The serialized tool.
188+ input_str (str): The input string.
189+ run_id (UUID): The run ID. This is the ID of the current run.
190+ parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
191+ tags (Optional[List[str]]): The tags.
192+ metadata (Optional[Dict[str, Any]]): The metadata.
193+ inputs (Optional[Dict[str, Any]]): The inputs.
194+ kwargs (Any): Additional keyword arguments.
195+ """
196+ print ("AgentCallbackHandler on_tool_start" )
197+
198+ def on_tool_end (
199+ self ,
200+ output : Any ,
201+ * ,
202+ run_id : UUID ,
203+ parent_run_id : Optional [UUID ] = None ,
204+ ** kwargs : Any ,
205+ ) -> Any :
206+ """Run when the tool ends running.
207+
208+ Args:
209+ output (Any): The output of the tool.
210+ run_id (UUID): The run ID. This is the ID of the current run.
211+ parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
212+ kwargs (Any): Additional keyword arguments."""
213+ print ("AgentCallbackHandler on_tool_end" )
214+
215+ def on_tool_error (
216+ self ,
217+ error : BaseException ,
218+ * ,
219+ run_id : UUID ,
220+ parent_run_id : Optional [UUID ] = None ,
221+ ** kwargs : Any ,
222+ ) -> Any :
223+ """Run when tool errors.
80224
81- def on_agent_finish (self , finish : AgentFinish , ** kwargs : Any ) -> Any :
82- """Run on agent end."""
83- # print("AgentCallbackHandler on_agent_finish")
225+ Args:
226+ error (BaseException): The error that occurred.
227+ run_id (UUID): The run ID. This is the ID of the current run.
228+ parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
229+ kwargs (Any): Additional keyword arguments."""
230+ print ("AgentCallbackHandler on_tool_error" )
84231
85232
86233class CodeInterpreterSession :
@@ -313,12 +460,16 @@ def _output_handler_post(self, final_response: str) -> CodeInterpreterResponse:
313460
314461 def _output_handler (self , response : Any ) -> CodeInterpreterResponse :
315462 """Embed images in the response"""
463+ print ("XXXX _output_handler in response=" , type (response ))
316464 final_response = self ._output_handler_pre (response )
465+ print ("XXXX _output_handler step1 " )
317466 response = self ._output_handler_post (final_response )
467+ print ("XXXX _output_handler out " )
318468 return response
319469
320470 async def _aoutput_handler (self , response : str ) -> CodeInterpreterResponse :
321471 """Embed images in the response"""
472+ print ("XXXX _aoutput_handler in response=" , type (response ))
322473 final_response = self ._output_handler_pre (response )
323474 for file in self .output_files :
324475 if str (file .name ) in final_response :
@@ -405,7 +556,7 @@ async def agenerate_response(
405556 traceback .print_exc ()
406557 if settings .DETAILED_ERROR :
407558 return CodeInterpreterResponse (
408- content = "Error in CodeInterpreterSession(agenerate_response): " f" { e .__class__ .__name__ } - { e } " ,
559+ content = f "Error in CodeInterpreterSession(agenerate_response): { e .__class__ .__name__ } - { e } " ,
409560 agent_name = self .brain .current_agent ,
410561 )
411562 else :
@@ -479,8 +630,7 @@ async def agenerate_response_stream(
479630 traceback .print_exc ()
480631 if settings .DETAILED_ERROR :
481632 yield CodeInterpreterResponse (
482- content = "Error in CodeInterpreterSession(agenerate_response_stream): "
483- f"{ e .__class__ .__name__ } - { e } "
633+ content = f"Error in CodeInterpreterSession(agenerate_response_stream): { e .__class__ .__name__ } - { e } "
484634 )
485635 else :
486636 yield CodeInterpreterResponse (
0 commit comments