Skip to content

Commit 4e71e9b

Browse files
keetrapaymeric-roucher
authored andcommitted
Fix memory step model output in ToolCallingAgent (#1156)
1 parent c97c5c4 commit 4e71e9b

File tree

7 files changed

+308
-29
lines changed

7 files changed

+308
-29
lines changed

examples/smolagents_benchmark/run.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,12 @@ def answer_single_question(example, model, answers_file, action_type):
141141
try:
142142
if action_type == "vanilla":
143143
answer = agent([{"role": "user", "content": augmented_question}]).content
144-
token_count = agent.last_output_token_count
144+
token_counts = agent.monitor.get_total_token_counts()
145145
intermediate_steps = answer
146146
else:
147147
# Run agent 🚀
148148
answer = str(agent.run(augmented_question))
149-
token_count = agent.monitor.get_total_token_counts()
149+
token_counts = agent.monitor.get_total_token_counts()
150150
# Remove memory from logs to make them more compact.
151151
for step in agent.memory.steps:
152152
if isinstance(step, ActionStep):
@@ -157,6 +157,8 @@ def answer_single_question(example, model, answers_file, action_type):
157157
except Exception as e:
158158
print("Error on ", augmented_question, e)
159159
intermediate_steps = []
160+
token_counts = {"input": 0, "output": 0}
161+
answer = str(e)
160162
end_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
161163
annotated_example = {
162164
"model_id": model.model_id,
@@ -169,7 +171,7 @@ def answer_single_question(example, model, answers_file, action_type):
169171
"intermediate_steps": intermediate_steps,
170172
"start_time": start_time,
171173
"end_time": end_time,
172-
"token_counts": token_count,
174+
"token_counts": token_counts,
173175
}
174176
append_answer(annotated_example, answers_file)
175177

@@ -233,7 +235,7 @@ def answer_questions(
233235
max_completion_tokens=8192,
234236
)
235237
else:
236-
model = HfApiModel(model_id=args.model_id, provider="together", max_tokens=8192)
238+
model = HfApiModel(model_id=args.model_id, max_tokens=8192)
237239

238240
answer_questions(
239241
eval_ds,

examples/smolagents_benchmark/score.ipynb

+283-4
Large diffs are not rendered by default.

src/smolagents/agents.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ def step(self, memory_step: ActionStep) -> Union[None, Any]:
994994
tool_call = model_message.tool_calls[0]
995995
tool_name, tool_call_id = tool_call.function.name, tool_call.id
996996
tool_arguments = tool_call.function.arguments
997-
997+
memory_step.model_output = str(f"Called Tool: '{tool_name}' with arguments: {tool_arguments}")
998998
memory_step.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
999999

10001000
# Execute

src/smolagents/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def parse_code_blobs(text: str) -> str:
186186
Raises:
187187
ValueError: If no valid code block is found in the text.
188188
"""
189-
pattern = r"```(?:py|python)?\n(.*?)\n```"
189+
pattern = r"```(?:py|python)?\s*\n(.*?)\n```"
190190
matches = re.findall(pattern, text, re.DOTALL)
191191
if matches:
192192
return "\n\n".join(match.strip() for match in matches)

tests/test_agents.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def test_fake_toolcalling_agent(self):
317317
assert "7.2904" in output
318318
assert agent.memory.steps[0].task == "What is 2 multiplied by 3.6452?"
319319
assert "7.2904" in agent.memory.steps[1].observations
320-
assert agent.memory.steps[2].model_output is None
320+
assert agent.memory.steps[2].model_output == "Called Tool: 'final_answer' with arguments: {'answer': '7.2904'}"
321321

322322
def test_toolcalling_agent_handles_image_tool_outputs(self, shared_datadir):
323323
import PIL.Image
@@ -495,6 +495,15 @@ def test_replay_shows_logs(self):
495495
assert 'final_answer("got' in str_output
496496
assert "```<end_code>" in str_output
497497

498+
agent = ToolCallingAgent(tools=[PythonInterpreterTool()], model=FakeToolCallModel(), verbosity_level=0)
499+
agent.run("What is 2 multiplied by 3.6452?")
500+
with agent.logger.console.capture() as capture:
501+
agent.replay()
502+
str_output = capture.get().replace("\n", "")
503+
assert "Called" in str_output
504+
assert "Tool" in str_output
505+
assert "arguments" in str_output
506+
498507
def test_code_nontrivial_final_answer_works(self):
499508
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
500509
return ChatMessage(

tests/test_monitoring.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_streaming_agent_image_output(self):
160160
)
161161
)
162162

163-
self.assertEqual(len(outputs), 5)
163+
self.assertEqual(len(outputs), 6)
164164
final_message = outputs[-1]
165165
self.assertEqual(final_message.role, "assistant")
166166
self.assertIsInstance(final_message.content, dict)

tests/test_utils.py

+6-17
Original file line numberDiff line numberDiff line change
@@ -113,25 +113,14 @@ def test_parse_code_blobs(self):
113113
output = parse_code_blobs(code_blob)
114114
assert output == code_blob
115115

116-
def test_multiple_code_blobs(self):
117-
test_input = """Here's a function that adds numbers:
118-
```python
119-
def add(a, b):
120-
return a + b
121-
```
122-
And here's a function that multiplies them:
123-
```py
124-
def multiply(a, b):
125-
return a * b
126-
```"""
116+
# Allow whitespaces after header
117+
output = parse_code_blobs("```py \ncode_a\n````")
118+
assert output == "code_a"
127119

128-
expected_output = """def add(a, b):
129-
return a + b
130-
131-
def multiply(a, b):
132-
return a * b"""
120+
def test_multiple_code_blobs(self):
121+
test_input = "```\nFoo\n```\n\n```py\ncode_a\n````\n\n```python\ncode_b\n```"
133122
result = parse_code_blobs(test_input)
134-
assert result == expected_output
123+
assert result == "nFoo\n\ncode_a\n\ncode_b"
135124

136125

137126
@pytest.fixture(scope="function")

0 commit comments

Comments
 (0)