Skip to content

Commit b1ffad0

Browse files
authored
Enforce max_output_length for shell tool outputs (#2299)
1 parent 813b035 commit b1ffad0

File tree

2 files changed

+309
-8
lines changed

2 files changed

+309
-8
lines changed

src/agents/_run_impl.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1739,6 +1739,9 @@ async def execute(
17391739
shell_output_payload: list[dict[str, Any]] | None = None
17401740
provider_meta: dict[str, Any] | None = None
17411741
max_output_length: int | None = None
1742+
requested_max_output_length = _normalize_max_output_length(
1743+
shell_call.action.max_output_length
1744+
)
17421745

17431746
try:
17441747
executor_result = call.shell_tool.executor(request)
@@ -1748,15 +1751,31 @@ async def execute(
17481751

17491752
if isinstance(result, ShellResult):
17501753
normalized = [_normalize_shell_output(entry) for entry in result.output]
1754+
result_max_output_length = _normalize_max_output_length(result.max_output_length)
1755+
if result_max_output_length is None:
1756+
max_output_length = requested_max_output_length
1757+
elif requested_max_output_length is None:
1758+
max_output_length = result_max_output_length
1759+
else:
1760+
max_output_length = min(result_max_output_length, requested_max_output_length)
1761+
if max_output_length is not None:
1762+
normalized = _truncate_shell_outputs(normalized, max_output_length)
17511763
output_text = _render_shell_outputs(normalized)
1764+
if max_output_length is not None:
1765+
output_text = output_text[:max_output_length]
17521766
shell_output_payload = [_serialize_shell_output(entry) for entry in normalized]
17531767
provider_meta = dict(result.provider_data or {})
1754-
max_output_length = result.max_output_length
17551768
else:
17561769
output_text = str(result)
1770+
if requested_max_output_length is not None:
1771+
max_output_length = requested_max_output_length
1772+
output_text = output_text[:max_output_length]
17571773
except Exception as exc:
17581774
status = "failed"
17591775
output_text = _format_shell_error(exc)
1776+
if requested_max_output_length is not None:
1777+
max_output_length = requested_max_output_length
1778+
output_text = output_text[:max_output_length]
17601779
logger.error("Shell executor failed: %s", exc, exc_info=True)
17611780

17621781
await asyncio.gather(
@@ -2029,6 +2048,51 @@ def _render_shell_outputs(outputs: Sequence[ShellCommandOutput]) -> str:
20292048
return "\n\n".join(rendered_chunks)
20302049

20312050

2051+
def _truncate_shell_outputs(
    outputs: Sequence[ShellCommandOutput], max_length: int
) -> list[ShellCommandOutput]:
    """Return copies of ``outputs`` whose text fits a shared character budget.

    A single budget of ``max_length`` characters is spent across all entries
    in order; within each entry ``stdout`` is consumed before ``stderr``.
    Once the budget is used up (or when ``max_length`` is zero or negative)
    the remaining text fields become empty strings. ``outcome``, ``command``
    and ``provider_data`` are carried over from each entry unchanged.

    Args:
        outputs: Shell command outputs to clip.
        max_length: Total number of characters allowed across every
            ``stdout``/``stderr`` field combined.

    Returns:
        A new list with one ``ShellCommandOutput`` per input entry.
    """
    # A non-positive limit is the same as having no budget at all.
    budget = max(max_length, 0)

    def _take(text):
        # Hand out as much of ``text`` as the remaining budget permits.
        nonlocal budget
        if budget <= 0 or not text:
            return ""
        piece = text[:budget]
        budget -= len(piece)
        return piece

    clipped: list[ShellCommandOutput] = []
    for entry in outputs:
        # Order matters: stdout gets first claim on the remaining budget.
        out_text = _take(entry.stdout)
        err_text = _take(entry.stderr)
        clipped.append(
            ShellCommandOutput(
                stdout=out_text,
                stderr=err_text,
                outcome=entry.outcome,
                command=entry.command,
                provider_data=entry.provider_data,
            )
        )
    return clipped
2088+
2089+
2090+
def _normalize_max_output_length(value: int | None) -> int | None:
2091+
if value is None:
2092+
return None
2093+
return max(0, value)
2094+
2095+
20322096
def _format_shell_error(error: Exception | BaseException | Any) -> str:
20332097
if isinstance(error, Exception):
20342098
message = str(error)
@@ -2078,9 +2142,9 @@ def _coerce_shell_call(tool_call: Any) -> ShellCallData:
20782142
)
20792143
timeout_ms = int(timeout_value) if isinstance(timeout_value, (int, float)) else None
20802144

2081-
max_length_value = _get_mapping_or_attr(
2082-
action_payload, "max_output_length"
2083-
) or _get_mapping_or_attr(action_payload, "maxOutputLength")
2145+
max_length_value = _get_mapping_or_attr(action_payload, "max_output_length")
2146+
if max_length_value is None:
2147+
max_length_value = _get_mapping_or_attr(action_payload, "maxOutputLength")
20842148
max_output_length = (
20852149
int(max_length_value) if isinstance(max_length_value, (int, float)) else None
20862150
)

tests/test_shell_tool.py

Lines changed: 241 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,19 @@ async def test_shell_tool_structured_output_is_rendered() -> None:
9494
async def test_shell_tool_executor_failure_returns_error() -> None:
9595
class ExplodingExecutor:
9696
def __call__(self, request):
97-
raise RuntimeError("boom")
97+
raise RuntimeError("boom" * 10)
9898

9999
shell_tool = ShellTool(executor=ExplodingExecutor())
100100
tool_call = {
101101
"type": "shell_call",
102102
"id": "shell_call_fail",
103103
"call_id": "call_shell_fail",
104104
"status": "completed",
105-
"action": {"commands": ["echo boom"], "timeout_ms": 1000},
105+
"action": {
106+
"commands": ["echo boom"],
107+
"timeout_ms": 1000,
108+
"max_output_length": 6,
109+
},
106110
}
107111
tool_run = ToolRunShellCall(tool_call=tool_call, shell_tool=shell_tool)
108112
agent = Agent(name="shell-agent", tools=[shell_tool])
@@ -117,12 +121,13 @@ def __call__(self, request):
117121
)
118122

119123
assert isinstance(result, ToolCallOutputItem)
120-
assert "boom" in result.output
124+
assert result.output == "boombo"
121125
raw_item = cast(dict[str, Any], result.raw_item)
122126
assert raw_item["type"] == "shell_call_output"
123127
assert raw_item["status"] == "failed"
128+
assert raw_item["max_output_length"] == 6
124129
assert isinstance(raw_item["output"], list)
125-
assert "boom" in raw_item["output"][0]["stdout"]
130+
assert raw_item["output"][0]["stdout"] == "boombo"
126131
first_output = raw_item["output"][0]
127132
assert first_output["outcome"]["type"] == "exit"
128133
assert first_output["outcome"]["exit_code"] == 1
@@ -135,3 +140,235 @@ def __call__(self, request):
135140
assert "status" not in payload_dict
136141
assert "shell_output" not in payload_dict
137142
assert "provider_data" not in payload_dict
143+
144+
145+
@pytest.mark.asyncio
async def test_shell_tool_output_respects_max_output_length() -> None:
    """The limit requested by the shell call action truncates the output.

    The executor sets no limit of its own, so the action's
    ``max_output_length`` of 6 applies: stdout is cut to six characters and
    stderr is left with nothing.
    """

    def fake_executor(request):
        command_output = ShellCommandOutput(
            stdout="0123456789",
            stderr="abcdef",
            outcome=ShellCallOutcome(type="exit", exit_code=0),
        )
        return ShellResult(output=[command_output])

    shell_tool = ShellTool(executor=fake_executor)

    action = {
        "commands": ["echo hi"],
        "timeout_ms": 1000,
        "max_output_length": 6,
    }
    tool_call = {
        "type": "shell_call",
        "id": "shell_call",
        "call_id": "call_shell",
        "status": "completed",
        "action": action,
    }

    context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None)
    result = await ShellAction.execute(
        agent=Agent(name="shell-agent", tools=[shell_tool]),
        call=ToolRunShellCall(tool_call=tool_call, shell_tool=shell_tool),
        hooks=RunHooks[Any](),
        context_wrapper=context_wrapper,
        config=RunConfig(),
    )

    assert isinstance(result, ToolCallOutputItem)
    assert result.output == "012345"

    payload = cast(dict[str, Any], result.raw_item)
    assert payload["max_output_length"] == 6
    first_entry = payload["output"][0]
    assert first_entry["stdout"] == "012345"
    assert first_entry["stderr"] == ""
189+
190+
191+
@pytest.mark.asyncio
async def test_shell_tool_uses_smaller_max_output_length() -> None:
    """When both sides set a limit, the smaller one is enforced.

    The executor result reports ``max_output_length=8`` while the action
    asks for 6; the emitted item is expected to use 6.
    """

    def fake_executor(request):
        command_output = ShellCommandOutput(
            stdout="0123456789",
            stderr="abcdef",
            outcome=ShellCallOutcome(type="exit", exit_code=0),
        )
        return ShellResult(output=[command_output], max_output_length=8)

    shell_tool = ShellTool(executor=fake_executor)

    action = {
        "commands": ["echo hi"],
        "timeout_ms": 1000,
        "max_output_length": 6,
    }
    tool_call = {
        "type": "shell_call",
        "id": "shell_call",
        "call_id": "call_shell",
        "status": "completed",
        "action": action,
    }

    context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None)
    result = await ShellAction.execute(
        agent=Agent(name="shell-agent", tools=[shell_tool]),
        call=ToolRunShellCall(tool_call=tool_call, shell_tool=shell_tool),
        hooks=RunHooks[Any](),
        context_wrapper=context_wrapper,
        config=RunConfig(),
    )

    assert isinstance(result, ToolCallOutputItem)
    assert result.output == "012345"

    payload = cast(dict[str, Any], result.raw_item)
    assert payload["max_output_length"] == 6
    first_entry = payload["output"][0]
    assert first_entry["stdout"] == "012345"
    assert first_entry["stderr"] == ""
236+
237+
238+
@pytest.mark.asyncio
async def test_shell_tool_executor_can_override_max_output_length_to_zero() -> None:
    """An executor-supplied limit of zero wins over a larger action limit.

    The result carries ``max_output_length=0`` and the action asks for 6;
    all text is expected to be dropped and the reported limit to be 0.
    """

    def fake_executor(request):
        command_output = ShellCommandOutput(
            stdout="0123456789",
            stderr="abcdef",
            outcome=ShellCallOutcome(type="exit", exit_code=0),
        )
        return ShellResult(output=[command_output], max_output_length=0)

    shell_tool = ShellTool(executor=fake_executor)

    action = {
        "commands": ["echo hi"],
        "timeout_ms": 1000,
        "max_output_length": 6,
    }
    tool_call = {
        "type": "shell_call",
        "id": "shell_call",
        "call_id": "call_shell",
        "status": "completed",
        "action": action,
    }

    context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None)
    result = await ShellAction.execute(
        agent=Agent(name="shell-agent", tools=[shell_tool]),
        call=ToolRunShellCall(tool_call=tool_call, shell_tool=shell_tool),
        hooks=RunHooks[Any](),
        context_wrapper=context_wrapper,
        config=RunConfig(),
    )

    assert isinstance(result, ToolCallOutputItem)
    assert result.output == ""

    payload = cast(dict[str, Any], result.raw_item)
    assert payload["max_output_length"] == 0
    first_entry = payload["output"][0]
    assert first_entry["stdout"] == ""
    assert first_entry["stderr"] == ""
283+
284+
285+
@pytest.mark.asyncio
async def test_shell_tool_action_can_request_zero_max_output_length() -> None:
    """A limit of zero in the action is honoured rather than ignored.

    The executor sets no limit; the action's ``max_output_length`` of 0 is
    expected to blank out both streams and be reported back as 0.
    """

    def fake_executor(request):
        command_output = ShellCommandOutput(
            stdout="0123456789",
            stderr="abcdef",
            outcome=ShellCallOutcome(type="exit", exit_code=0),
        )
        return ShellResult(output=[command_output])

    shell_tool = ShellTool(executor=fake_executor)

    action = {
        "commands": ["echo hi"],
        "timeout_ms": 1000,
        "max_output_length": 0,
    }
    tool_call = {
        "type": "shell_call",
        "id": "shell_call",
        "call_id": "call_shell",
        "status": "completed",
        "action": action,
    }

    context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None)
    result = await ShellAction.execute(
        agent=Agent(name="shell-agent", tools=[shell_tool]),
        call=ToolRunShellCall(tool_call=tool_call, shell_tool=shell_tool),
        hooks=RunHooks[Any](),
        context_wrapper=context_wrapper,
        config=RunConfig(),
    )

    assert isinstance(result, ToolCallOutputItem)
    assert result.output == ""

    payload = cast(dict[str, Any], result.raw_item)
    assert payload["max_output_length"] == 0
    first_entry = payload["output"][0]
    assert first_entry["stdout"] == ""
    assert first_entry["stderr"] == ""
329+
330+
331+
@pytest.mark.asyncio
async def test_shell_tool_action_negative_max_output_length_clamps_to_zero() -> None:
    """A negative limit in the action is treated as zero.

    The action asks for ``max_output_length=-5``; the emitted item is
    expected to report 0 and contain no stdout or stderr text.
    """

    def fake_executor(request):
        command_output = ShellCommandOutput(
            stdout="0123456789",
            stderr="abcdef",
            outcome=ShellCallOutcome(type="exit", exit_code=0),
        )
        return ShellResult(output=[command_output])

    shell_tool = ShellTool(executor=fake_executor)

    action = {
        "commands": ["echo hi"],
        "timeout_ms": 1000,
        "max_output_length": -5,
    }
    tool_call = {
        "type": "shell_call",
        "id": "shell_call",
        "call_id": "call_shell",
        "status": "completed",
        "action": action,
    }

    context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None)
    result = await ShellAction.execute(
        agent=Agent(name="shell-agent", tools=[shell_tool]),
        call=ToolRunShellCall(tool_call=tool_call, shell_tool=shell_tool),
        hooks=RunHooks[Any](),
        context_wrapper=context_wrapper,
        config=RunConfig(),
    )

    assert isinstance(result, ToolCallOutputItem)
    assert result.output == ""

    payload = cast(dict[str, Any], result.raw_item)
    assert payload["max_output_length"] == 0
    first_entry = payload["output"][0]
    assert first_entry["stdout"] == ""
    assert first_entry["stderr"] == ""

0 commit comments

Comments
 (0)