Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add outputs_to_string to Tool and ComponentTool #9152

Merged
merged 4 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 61 additions & 46 deletions haystack/components/tools/tool_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import inspect
import json
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.core.component.sockets import Sockets
Expand Down Expand Up @@ -154,29 +154,62 @@ def _handle_error(self, error: Exception) -> str:
raise error
return str(error)

def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall) -> ChatMessage:
def _default_output_to_string_handler(self, result: Any) -> str:
"""
Default handler for converting a tool result to a string.

:param result: The tool result to convert to a string.
:returns: The converted tool result as a string.
"""
if self.convert_result_to_json_string:
# We disable ensure_ascii so special chars like emojis are not converted
tool_result_str = json.dumps(result, ensure_ascii=False)
else:
tool_result_str = str(result)
return tool_result_str

def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to_invoke: Tool) -> ChatMessage:
"""
Prepares a ChatMessage with the result of a tool invocation.

:param result:
The tool result.
:param tool_call:
The ToolCall object containing the tool name and arguments.
:param tool_to_invoke:
The Tool object that was invoked.
:returns:
A ChatMessage object containing the tool result as a string.
:raises StringConversionError:
If the conversion of the tool result to a string fails
and `raise_on_failure` is True.
"""
source_key = None
output_to_string_handler = None
if tool_to_invoke.outputs_to_string is not None:
if tool_to_invoke.outputs_to_string.get("source"):
source_key = tool_to_invoke.outputs_to_string["source"]
if tool_to_invoke.outputs_to_string.get("handler"):
output_to_string_handler = tool_to_invoke.outputs_to_string["handler"]

# If a source key is provided, we extract the result from the source key
if source_key is not None:
result_to_convert = result.get(source_key)
else:
result_to_convert = result

# If no handler is provided, we use the default handler
if output_to_string_handler is None:
output_to_string_handler = self._default_output_to_string_handler

error = False
try:
if self.convert_result_to_json_string:
# We disable ensure_ascii so special chars like emojis are not converted
tool_result_str = json.dumps(result, ensure_ascii=False)
else:
tool_result_str = str(result)
tool_result_str = output_to_string_handler(result_to_convert)
except Exception as e:
conversion_method = "json.dumps" if self.convert_result_to_json_string else "str"
try:
tool_result_str = self._handle_error(StringConversionError(tool_call.tool_name, conversion_method, e))
tool_result_str = self._handle_error(
StringConversionError(tool_call.tool_name, output_to_string_handler.__name__, e)
)
error = True
except StringConversionError as conversion_error:
# If _handle_error re-raises, this properly preserves the chain
Expand Down Expand Up @@ -221,9 +254,10 @@ def _inject_state_args(tool: Tool, llm_args: Dict[str, Any], state: State) -> Di

return final_args

def _merge_tool_outputs(self, tool: Tool, result: Any, state: State) -> Any:
@staticmethod
def _merge_tool_outputs(tool: Tool, result: Any, state: State) -> None:
"""
Merges the tool result into the global state and determines the response message.
Merges the tool result into the State.

This method processes the output of a tool execution and integrates it into the global state.
It also determines what message, if any, should be returned for further processing in a conversation.
Expand All @@ -245,49 +279,25 @@ def _merge_tool_outputs(self, tool: Tool, result: Any, state: State) -> Any:
- The merged result dictionary
- Or the raw result if not a dictionary
"""
# If result is not a dictionary, return it as the output message.
# If result is not a dictionary we exit
if not isinstance(result, dict):
return result
return

# If there is no specific `outputs_to_state` mapping, we just return the full result
# If there is no specific `outputs_to_state` mapping, we exit
if not hasattr(tool, "outputs_to_state") or not isinstance(tool.outputs_to_state, dict):
return result
return

# Handle tool outputs with specific mapping for message and state updates
return self._handle_tool_outputs(tool.outputs_to_state, result, state)

@staticmethod
def _handle_tool_outputs(outputs_to_state: dict, result: dict, state: State) -> Union[dict, str]:
"""
Handles the `outputs_to_state` mapping from the tool and updates the state accordingly.

:param outputs_to_state: Mapping of outputs from the tool.
:param result: Result of the tool execution.
:param state: Global state to merge results into.
:returns: Final message for LLM or the entire result.
"""
message_content = None

for state_key, config in outputs_to_state.items():
# Update the state with the tool outputs
for state_key, config in tool.outputs_to_state.items():
# Get the source key from the output config, otherwise use the entire result
source_key = config.get("source", None)
output_value = result if source_key is None else result.get(source_key)

# Get the handler function, if any
handler = config.get("handler", None)

if state_key == "message":
# Handle the message output separately
if handler is not None:
message_content = handler(output_value)
else:
message_content = str(output_value)
else:
# Merge other outputs into the state
state.set(state_key, output_value, handler_override=handler)

# If no "message" key was found, return the result or message content
return message_content if message_content is not None else result
# Merge other outputs into the state
state.set(state_key, output_value, handler_override=handler)

@component.output_types(tool_messages=List[ChatMessage], state=State)
def run(self, messages: List[ChatMessage], state: Optional[State] = None) -> Dict[str, Any]:
Expand Down Expand Up @@ -343,9 +353,9 @@ def run(self, messages: List[ChatMessage], state: Optional[State] = None) -> Dic
tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
continue

# 3) Merge outputs into state & create a single ChatMessage for the LLM
# 3) Merge outputs into state
try:
tool_text = self._merge_tool_outputs(tool_to_invoke, tool_result, state)
self._merge_tool_outputs(tool_to_invoke, tool_result, state)
except Exception as e:
try:
error_message = self._handle_error(
Expand All @@ -359,7 +369,12 @@ def run(self, messages: List[ChatMessage], state: Optional[State] = None) -> Dic
# Re-raise with proper error chain
raise propagated_e from e

tool_messages.append(self._prepare_tool_result_message(result=tool_text, tool_call=tool_call))
# 4) Prepare the tool result ChatMessage message
tool_messages.append(
self._prepare_tool_result_message(
result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
)
)

return {"tool_messages": tool_messages, "state": state}

Expand Down
35 changes: 31 additions & 4 deletions haystack/tools/component_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ def __init__(
description: Optional[str] = None,
parameters: Optional[Dict[str, Any]] = None,
*,
inputs_from_state: Optional[Dict[str, Any]] = None,
outputs_to_state: Optional[Dict[str, Any]] = None,
outputs_to_string: Optional[Dict[str, Union[str, Callable[[Any], str]]]] = None,
inputs_from_state: Optional[Dict[str, str]] = None,
outputs_to_state: Optional[Dict[str, Dict[str, Union[str, Callable]]]] = None,
):
"""
Create a Tool instance from a Haystack component.
Expand All @@ -108,14 +109,25 @@ def __init__(
:param parameters:
A JSON schema defining the parameters expected by the Tool.
Will fall back to the parameters defined in the component's run method signature if not provided.
:param outputs_to_string:
Optional dictionary defining how a tool's outputs should be converted into a string.
If the source is provided, only the specified output key is sent to the handler.
If the source is omitted, the whole tool result is sent to the handler.
Example: {
"source": "docs", "handler": format_documents
}
:param inputs_from_state:
Optional dictionary mapping state keys to tool parameter names.
Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
:param outputs_to_state:
Optional dictionary defining how tool outputs map to keys within state as well as optional handlers.
If the source is provided only the specified output key is sent to the handler.
Example: {
"documents": {"source": "docs", "handler": custom_handler}
}
If the source is omitted the whole tool result is sent to the handler.
Example: {
"documents": {"source": "docs", "handler": custom_handler},
"message": {"source": "summary", "handler": format_summary}
"documents": {"handler": custom_handler}
}
:raises ValueError: If the component is invalid or schema generation fails.
"""
Expand Down Expand Up @@ -186,6 +198,7 @@ def component_invoker(**kwargs):
function=component_invoker,
inputs_from_state=inputs_from_state,
outputs_to_state=outputs_to_state,
outputs_to_string=outputs_to_string,
)
self._component = component

Expand All @@ -200,7 +213,9 @@ def to_dict(self) -> Dict[str, Any]:
"name": self.name,
"description": self.description,
"parameters": self._unresolved_parameters,
"outputs_to_string": self.outputs_to_string,
"inputs_from_state": self.inputs_from_state,
"outputs_to_state": self.outputs_to_state,
}

if self.outputs_to_state is not None:
Expand All @@ -212,6 +227,9 @@ def to_dict(self) -> Dict[str, Any]:
serialized_outputs[key] = serialized_config
serialized["outputs_to_state"] = serialized_outputs

if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
serialized["outputs_to_string"] = serialize_callable(self.outputs_to_string["handler"])

return {"type": generate_qualified_class_name(type(self)), "data": serialized}

@classmethod
Expand All @@ -232,11 +250,20 @@ def from_dict(cls, data: Dict[str, Any]) -> "Tool":
deserialized_outputs[key] = deserialized_config
inner_data["outputs_to_state"] = deserialized_outputs

if (
inner_data.get("outputs_to_string") is not None
and inner_data["outputs_to_string"].get("handler") is not None
):
inner_data["outputs_to_string"]["handler"] = deserialize_callable(
inner_data["outputs_to_string"]["handler"]
)

return cls(
component=component,
name=inner_data["name"],
description=inner_data["description"],
parameters=inner_data.get("parameters", None),
outputs_to_string=inner_data.get("outputs_to_string", None),
inputs_from_state=inner_data.get("inputs_from_state", None),
outputs_to_state=inner_data.get("outputs_to_state", None),
)
Expand Down
39 changes: 34 additions & 5 deletions haystack/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,33 @@ class Tool:
A JSON schema defining the parameters expected by the Tool.
:param function:
The function that will be invoked when the Tool is called.
:param outputs_to_string:
Optional dictionary defining how a tool's outputs should be converted into a string.
If the source is provided, only the specified output key is sent to the handler.
If the source is omitted, the whole tool result is sent to the handler.
Example: {
"source": "docs", "handler": format_documents
}
:param inputs_from_state:
Optional dictionary mapping state keys to tool parameter names.
Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
:param outputs_to_state:
Optional dictionary defining how tool outputs map to keys within state as well as optional handlers.
If the source is provided only the specified output key is sent to the handler.
Example: {
"documents": {"source": "docs", "handler": custom_handler}
}
If the source is omitted the whole tool result is sent to the handler.
Example: {
"documents": {"source": "docs", "handler": custom_handler},
"message": {"source": "summary", "handler": format_summary}
"documents": {"handler": custom_handler}
}
"""

name: str
description: str
parameters: Dict[str, Any]
function: Callable
outputs_to_string: Optional[Dict[str, Any]] = None
inputs_from_state: Optional[Dict[str, str]] = None
outputs_to_state: Optional[Dict[str, Dict[str, Any]]] = None

Expand All @@ -58,11 +70,17 @@ def __post_init__(self):
if self.outputs_to_state is not None:
for key, config in self.outputs_to_state.items():
if not isinstance(config, dict):
raise ValueError(f"Output configuration for key '{key}' must be a dictionary")
raise ValueError(f"outputs_to_state configuration for key '{key}' must be a dictionary")
if "source" in config and not isinstance(config["source"], str):
raise ValueError(f"Output source for key '{key}' must be a string.")
raise ValueError(f"outputs_to_state source for key '{key}' must be a string.")
if "handler" in config and not callable(config["handler"]):
raise ValueError(f"Output handler for key '{key}' must be callable")
raise ValueError(f"outputs_to_state handler for key '{key}' must be callable")

if self.outputs_to_string is not None:
if "source" in self.outputs_to_string and not isinstance(self.outputs_to_string["source"], str):
raise ValueError("outputs_to_string source must be a string.")
if "handler" in self.outputs_to_string and not callable(self.outputs_to_string["handler"]):
raise ValueError("outputs_to_string handler must be callable")

@property
def tool_spec(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -103,6 +121,9 @@ def to_dict(self) -> Dict[str, Any]:
serialized_outputs[key] = serialized_config
data["outputs_to_state"] = serialized_outputs

if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
data["outputs_to_string"] = serialize_callable(self.outputs_to_string["handler"])

return {"type": generate_qualified_class_name(type(self)), "data": data}

@classmethod
Expand All @@ -128,6 +149,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "Tool":
deserialized_outputs[key] = deserialized_config
init_parameters["outputs_to_state"] = deserialized_outputs

if (
init_parameters.get("outputs_to_string") is not None
and init_parameters["outputs_to_string"].get("handler") is not None
):
init_parameters["outputs_to_string"]["handler"] = deserialize_callable(
init_parameters["outputs_to_string"]["handler"]
)

return cls(**init_parameters)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Adds `outputs_to_string` to `Tool` and `ComponentTool`, allowing users to customize how a Tool's output is converted into a string so it can be provided back to the ChatGenerator in a ChatMessage.
If `outputs_to_string` is not provided, the ToolInvoker uses a default converter that preserves the previous default behavior.
2 changes: 2 additions & 0 deletions test/components/generators/chat/test_hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def test_to_dict(self, mock_check_valid_model):
"inputs_from_state": None,
"name": "name",
"outputs_to_state": None,
"outputs_to_string": None,
"parameters": {"x": {"type": "string"}},
},
}
Expand Down Expand Up @@ -324,6 +325,7 @@ def test_serde_in_pipeline(self, mock_check_valid_model):
"inputs_from_state": None,
"name": "name",
"outputs_to_state": None,
"outputs_to_string": None,
"description": "description",
"parameters": {"x": {"type": "string"}},
"function": "builtins.print",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def test_to_dict(self, model_info_mock, tools):
"inputs_from_state": None,
"name": "weather",
"outputs_to_state": None,
"outputs_to_string": None,
"description": "useful to determine the weather in a given location",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
"function": "generators.chat.test_hugging_face_local.get_weather",
Expand Down
1 change: 1 addition & 0 deletions test/components/generators/chat/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
"inputs_from_state": None,
"name": "name",
"outputs_to_state": None,
"outputs_to_string": None,
"parameters": {"x": {"type": "string"}},
},
}
Expand Down
Loading
Loading