Skip to content

Commit be4970c

Browse files
committed
Fix wrap_error()
1 parent 1182a8d commit be4970c

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

coagent/agents/chat_agent.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,16 @@ def tool(func):
156156
def wrap_error(func):
157157
"""Decorator to capture and return the possible error when running the given tool."""
158158

159+
async def __wrap_aiter(
160+
aiter_: AsyncIterator[ChatMessage | str],
161+
) -> AsyncIterator[ChatMessage | str]:
162+
try:
163+
async for chunk in aiter_:
164+
yield chunk
165+
except Exception as exc:
166+
logger.exception(exc)
167+
yield f"Error: {exc}"
168+
159169
@functools.wraps(func)
160170
async def run(
161171
*args: Any, **kwargs: Any
@@ -180,7 +190,7 @@ async def run(
180190

181191
result = func(*args, **kwargs)
182192
if is_async_iterator(result):
183-
return result
193+
return __wrap_aiter(result)
184194
else:
185195
return await result
186196

tests/agents/test_chat_agent.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
from typing import AsyncIterator
2+
13
from pydantic import Field
24
import pytest
35

46
from coagent.agents.chat_agent import wrap_error
57

68

79
@pytest.mark.asyncio
8-
async def test_wrap_error():
10+
async def test_wrap_error_normal():
911
@wrap_error
1012
async def func(
1113
a: int = Field(..., description="Argument a"),
@@ -16,3 +18,22 @@ async def func(
1618
assert await func() == "Error: Missing required argument 'a'"
1719
assert await func(a=1) == 1
1820
assert await func(a=1, b=0) == "Error: division by zero"
21+
22+
23+
@pytest.mark.asyncio
24+
async def test_wrap_error_aiter():
25+
@wrap_error
26+
async def func(
27+
a: int = Field(..., description="Argument a"),
28+
b: int = Field(1, description="Argument b"),
29+
) -> AsyncIterator[float]:
30+
yield a / b
31+
32+
result = await func()
33+
assert result == "Error: Missing required argument 'a'"
34+
35+
result = await func(a=1)
36+
assert await anext(result) == 1
37+
38+
result = await func(a=1, b=0)
39+
assert await anext(result) == "Error: division by zero"

0 commit comments

Comments
 (0)