File tree Expand file tree Collapse file tree 2 files changed +33
-2
lines changed Expand file tree Collapse file tree 2 files changed +33
-2
lines changed Original file line number Diff line number Diff line change @@ -156,6 +156,16 @@ def tool(func):
156
156
def wrap_error (func ):
157
157
"""Decorator to capture and return the possible error when running the given tool."""
158
158
159
+ async def __wrap_aiter (
160
+ aiter_ : AsyncIterator [ChatMessage | str ],
161
+ ) -> AsyncIterator [ChatMessage | str ]:
162
+ try :
163
+ async for chunk in aiter_ :
164
+ yield chunk
165
+ except Exception as exc :
166
+ logger .exception (exc )
167
+ yield f"Error: { exc } "
168
+
159
169
@functools .wraps (func )
160
170
async def run (
161
171
* args : Any , ** kwargs : Any
@@ -180,7 +190,7 @@ async def run(
180
190
181
191
result = func (* args , ** kwargs )
182
192
if is_async_iterator (result ):
183
- return result
193
+ return __wrap_aiter ( result )
184
194
else :
185
195
return await result
186
196
Original file line number Diff line number Diff line change
1
+ from typing import AsyncIterator
2
+
1
3
from pydantic import Field
2
4
import pytest
3
5
4
6
from coagent .agents .chat_agent import wrap_error
5
7
6
8
7
9
@pytest .mark .asyncio
8
- async def test_wrap_error ():
10
+ async def test_wrap_error_normal ():
9
11
@wrap_error
10
12
async def func (
11
13
a : int = Field (..., description = "Argument a" ),
@@ -16,3 +18,22 @@ async def func(
16
18
assert await func () == "Error: Missing required argument 'a'"
17
19
assert await func (a = 1 ) == 1
18
20
assert await func (a = 1 , b = 0 ) == "Error: division by zero"
21
+
22
+
23
+ @pytest .mark .asyncio
24
+ async def test_wrap_error_aiter ():
25
+ @wrap_error
26
+ async def func (
27
+ a : int = Field (..., description = "Argument a" ),
28
+ b : int = Field (1 , description = "Argument b" ),
29
+ ) -> AsyncIterator [float ]:
30
+ yield a / b
31
+
32
+ result = await func ()
33
+ assert result == "Error: Missing required argument 'a'"
34
+
35
+ result = await func (a = 1 )
36
+ assert await anext (result ) == 1
37
+
38
+ result = await func (a = 1 , b = 0 )
39
+ assert await anext (result ) == "Error: division by zero"
You can’t perform that action at this time.
0 commit comments