diff --git a/docs/guardrails.md b/docs/guardrails.md index 2f0be0f2..69f8d9cd 100644 --- a/docs/guardrails.md +++ b/docs/guardrails.md @@ -2,10 +2,11 @@ Guardrails run _in parallel_ to your agents, enabling you to do checks and validations of user input. For example, imagine you have an agent that uses a very smart (and hence slow/expensive) model to help with customer requests. You wouldn't want malicious users to ask the model to help them with their math homework. So, you can run a guardrail with a fast/cheap model. If the guardrail detects malicious usage, it can immediately raise an error, which stops the expensive model from running and saves you time/money. -There are two kinds of guardrails: +There are three kinds of guardrails: 1. Input guardrails run on the initial user input 2. Output guardrails run on the final agent output +3. Fact checking guardrails run on the initial user input and the final agent output ## Input guardrails @@ -23,7 +24,7 @@ Input guardrails run in 3 steps: Output guardrails run in 3 steps: -1. First, the guardrail receives the same input passed to the agent. +1. First, the guardrail receives the output of the last agent. 2. Next, the guardrail function runs to produce a [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput], which is then wrapped in an [`OutputGuardrailResult`][agents.guardrail.OutputGuardrailResult] 3. Finally, we check if [`.tripwire_triggered`][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] is true. If true, an [`OutputGuardrailTripwireTriggered`][agents.exceptions.OutputGuardrailTripwireTriggered] exception is raised, so you can appropriately respond to the user or handle the exception. @@ -31,9 +32,22 @@ Output guardrails run in 3 steps: Output guardrails are intended to run on the final agent output, so an agent's guardrails only run if the agent is the *last* agent. 
Similar to the input guardrails, we do this because guardrails tend to be related to the actual Agent - you'd run different guardrails for different agents, so colocating the code is useful for readability. +## Fact Checking guardrails + +Fact Checking guardrails run in 3 steps: + +1. First, the guardrail receives the same input passed to the first agent and the output of the last agent. +2. Next, the guardrail function runs to produce a [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput], which is then wrapped in a [`FactCheckingGuardrailResult`][agents.guardrail.FactCheckingGuardrailResult] +3. Finally, we check if [`.tripwire_triggered`][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] is true. If true, a [`FactCheckingGuardrailTripwireTriggered`][agents.exceptions.FactCheckingGuardrailTripwireTriggered] exception is raised, so you can appropriately respond to the user or handle the exception. + +!!! Note + + Fact checking guardrails are intended to run on user input and the final agent output, so an agent's guardrails only run if the agent is the *last* agent. Similar to the output guardrails, we do this because guardrails tend to be related to the actual Agent - you'd run different guardrails for different agents, so colocating the code is useful for readability. + + ## Tripwires -If the input or output fails the guardrail, the Guardrail can signal this with a tripwire. As soon as we see a guardrail that has triggered the tripwires, we immediately raise a `{Input,Output}GuardrailTripwireTriggered` exception and halt the Agent execution. +If the input or output fails the guardrail, the Guardrail can signal this with a tripwire. As soon as we see a guardrail that has triggered the tripwires, we immediately raise a `{Input,Output,FactChecking}GuardrailTripwireTriggered` exception and halt the Agent execution. ## Implementing a guardrail @@ -152,3 +166,113 @@ async def main(): 2. This is the guardrail's output type. 3. 
This is the guardrail function that receives the agent's output, and returns the result. 4. This is the actual agent that defines the workflow. + +Fact checking guardrails are similar. + +```python +import json + +from pydantic import BaseModel, Field + +from agents import ( + Agent, + GuardrailFunctionOutput, + FactCheckingGuardrailTripwireTriggered, + RunContextWrapper, + Runner, + fact_checking_guardrail, +) + + +""" +This example shows how to use fact checking guardrails. + +Fact checking guardrails are checks that run on both the original input and the final output of an agent. +Their primary purpose is to ensure the consistency and accuracy of the agent’s response by verifying that +the output aligns with known facts or the provided input data. They can be used to: +- Validate that the agent's output correctly reflects the information given in the input. +- Ensure that any factual details in the response match expected values. +- Detect discrepancies or potential misinformation. + +In this example, we'll use a contrived scenario where we verify if the agent's response contains data that matches the input. +""" + + +class MessageOutput(BaseModel): # (1)! + reasoning: str = Field(description="Thoughts on how to respond to the user's message") + response: str = Field(description="The response to the user's message") + age: int | None = Field(description="Age of the person") + + +class FactCheckingOutput(BaseModel): # (2)! + reasoning: str + is_age_correct: bool + + +guardrail_agent = Agent( + name="Guardrail Check", + instructions=( + "You are given a task to determine if the hypothesis is grounded in the provided evidence. " + "Rely solely on the contents of the evidence without using external knowledge." + ), + output_type=FactCheckingOutput, +) + + +@fact_checking_guardrail +async def self_check_facts( # (3)! 
+ context: RunContextWrapper, + agent: Agent, + output: MessageOutput, + evidence: str) \ + -> GuardrailFunctionOutput: + """This is a facts checking guardrail function, which happens to call an agent to check if the output + is coherent with the input. + """ + message = ( + f"Input: {evidence}\n" + f"Age: {output.age}" + ) + + print(f"message: {message}") + + # Run the fact-checking agent using the constructed message. + result = await Runner.run(guardrail_agent, message, context=context.context) + final_output = result.final_output_as(FactCheckingOutput) + + return GuardrailFunctionOutput( + output_info=final_output, + tripwire_triggered=not final_output.is_age_correct, + ) + + +async def main(): + agent = Agent( # (4)! + name="Entities Extraction Agent", + instructions=""" + Always respond age = 28. + """, + fact_checking_guardrails=[self_check_facts], + output_type=MessageOutput, + ) + + await Runner.run(agent, "My name is Alex and I'm 28 years old.") + print("First message passed") + + # This should trip the guardrail + try: + result = await Runner.run( + agent, "My name is Alex and I'm 38." + ) + print( + f"Guardrail didn't trip - this is unexpected. Output: {json.dumps(result.final_output.model_dump(), indent=2)}" + ) + + except FactCheckingGuardrailTripwireTriggered as e: + print(f"Guardrail tripped. Info: {e.guardrail_result.output.output_info}") +``` + +1. This is the actual agent's output type. +2. This is the guardrail's output type. +3. This is the guardrail function that receives the user input and agent's output, and returns the result. +4. This is the actual agent that defines the workflow. 
diff --git a/examples/agent_patterns/README.md b/examples/agent_patterns/README.md index 96b48920..b39d5f6f 100644 --- a/examples/agent_patterns/README.md +++ b/examples/agent_patterns/README.md @@ -51,4 +51,4 @@ You can definitely do this without any special Agents SDK features by using para This is really useful for latency: for example, you might have a very fast model that runs the guardrail and a slow model that runs the actual agent. You wouldn't want to wait for the slow model to finish, so guardrails let you quickly reject invalid inputs. -See the [`input_guardrails.py`](./input_guardrails.py) and [`output_guardrails.py`](./output_guardrails.py) files for examples. +See the [`input_guardrails.py`](./input_guardrails.py), [`output_guardrails.py`](./output_guardrails.py) and [`fact_checking_guardrails.py`](./fact_checking_guardrails.py) files for examples. diff --git a/examples/agent_patterns/fact_checking_guardrails.py b/examples/agent_patterns/fact_checking_guardrails.py new file mode 100644 index 00000000..084a71ad --- /dev/null +++ b/examples/agent_patterns/fact_checking_guardrails.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import asyncio +import json + +from pydantic import BaseModel, Field + +from agents import ( + Agent, + FactCheckingGuardrailTripwireTriggered, + GuardrailFunctionOutput, + RunContextWrapper, + Runner, + fact_checking_guardrail, +) + +""" +This example shows how to use fact checking guardrails. + +Fact checking guardrails are checks that run on both the original input and the final output of an agent. +Their primary purpose is to ensure the consistency and accuracy of the agent’s response by verifying that +the output aligns with known facts or the provided input data. They can be used to: +- Validate that the agent's output correctly reflects the information given in the input. +- Ensure that any factual details in the response match expected values. +- Detect discrepancies or potential misinformation. 
+ +In this example, we'll use a contrived scenario where we verify if the agent's response contains data that matches the input. +""" + + +class MessageOutput(BaseModel): + reasoning: str = Field(description="Thoughts on how to respond to the user's message") + response: str = Field(description="The response to the user's message") + age: int | None = Field(description="Age of the person") + + +class FactCheckingOutput(BaseModel): + reasoning: str + is_age_correct: bool + + +guardrail_agent = Agent( + name="Guardrail Check", + instructions=( + "You are given a task to determine if the hypothesis is grounded in the provided evidence. " + "Rely solely on the contents of the evidence without using external knowledge." + ), + output_type=FactCheckingOutput, +) + + +@fact_checking_guardrail +async def self_check_facts( + context: RunContextWrapper, agent: Agent, output: MessageOutput, evidence: str +) -> GuardrailFunctionOutput: + """This is a facts checking guardrail function, which happens to call an agent to check if the output + is coherent with the input. + """ + message = f"Input: {evidence}\nAge: {output.age}" + + print(f"message: {message}") + + # Run the fact-checking agent using the constructed message. + result = await Runner.run(guardrail_agent, message, context=context.context) + final_output = result.final_output_as(FactCheckingOutput) + + return GuardrailFunctionOutput( + output_info=final_output, + tripwire_triggered=not final_output.is_age_correct, + ) + + +async def main(): + agent = Agent( + name="Entities Extraction Agent", + instructions=""" + Always respond age = 28. + """, + fact_checking_guardrails=[self_check_facts], + output_type=MessageOutput, + ) + + await Runner.run(agent, "My name is Alex and I'm 28 years old.") + print("First message passed") + + # This should trip the guardrail + try: + result = await Runner.run(agent, "My name is Alex and I'm 38.") + print( + f"Guardrail didn't trip - this is unexpected. 
Output: {json.dumps(result.final_output.model_dump(), indent=2)}" + ) + + except FactCheckingGuardrailTripwireTriggered as e: + print(f"Guardrail tripped. Info: {e.guardrail_result.output.output_info}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 242f5649..5771c7cb 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -10,6 +10,7 @@ from .computer import AsyncComputer, Button, Computer, Environment from .exceptions import ( AgentsException, + FactCheckingGuardrailTripwireTriggered, InputGuardrailTripwireTriggered, MaxTurnsExceeded, ModelBehaviorError, @@ -17,11 +18,14 @@ UserError, ) from .guardrail import ( + FactCheckingGuardrail, + FactCheckingGuardrailResult, GuardrailFunctionOutput, InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult, + fact_checking_guardrail, input_guardrail, output_guardrail, ) @@ -164,6 +168,7 @@ def enable_verbose_stdout_logging(): "AgentsException", "InputGuardrailTripwireTriggered", "OutputGuardrailTripwireTriggered", + "FactCheckingGuardrailTripwireTriggered", "MaxTurnsExceeded", "ModelBehaviorError", "UserError", @@ -171,9 +176,12 @@ def enable_verbose_stdout_logging(): "InputGuardrailResult", "OutputGuardrail", "OutputGuardrailResult", + "FactCheckingGuardrail", + "FactCheckingGuardrailResult", "GuardrailFunctionOutput", "input_guardrail", "output_guardrail", + "fact_checking_guardrail", "handoff", "Handoff", "HandoffInputData", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 94c181b7..141d21f4 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -32,7 +32,14 @@ from .agent_output import AgentOutputSchema from .computer import AsyncComputer, Computer from .exceptions import AgentsException, ModelBehaviorError, UserError -from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult +from .guardrail import ( + FactCheckingGuardrail, + 
FactCheckingGuardrailResult, + InputGuardrail, + InputGuardrailResult, + OutputGuardrail, + OutputGuardrailResult, +) from .handoffs import Handoff, HandoffInputData from .items import ( HandoffCallItem, @@ -708,6 +715,22 @@ async def run_single_output_guardrail( span_guardrail.span_data.triggered = result.output.tripwire_triggered return result + @classmethod + async def run_single_fact_checking_guardrail( + cls, + guardrail: FactCheckingGuardrail[TContext], + agent: Agent[Any], + agent_output: Any, + context: RunContextWrapper[TContext], + agent_input: Any, + ) -> FactCheckingGuardrailResult: + with guardrail_span(guardrail.get_name()) as span_guardrail: + result = await guardrail.run( + agent=agent, agent_output=agent_output, context=context, agent_input=agent_input + ) + span_guardrail.span_data.triggered = result.output.tripwire_triggered + return result + @classmethod def stream_step_result_to_queue( cls, diff --git a/src/agents/agent.py b/src/agents/agent.py index 13bb464e..824b1275 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -8,7 +8,7 @@ from typing_extensions import TypeAlias, TypedDict -from .guardrail import InputGuardrail, OutputGuardrail +from .guardrail import FactCheckingGuardrail, InputGuardrail, OutputGuardrail from .handoffs import Handoff from .items import ItemHelpers from .logger import logger @@ -129,6 +129,12 @@ class Agent(Generic[TContext]): Runs only if the agent produces a final output. """ + fact_checking_guardrails: list[FactCheckingGuardrail[TContext]] = field(default_factory=list) + """A list of checks that run on the original input + and the final output of the agent, after generating a response. + Runs only if the agent produces a final output. + """ + output_type: type[Any] | None = None """The type of the output object. 
If not provided, the output will be `str`.""" diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py index 78898f01..e74b02ad 100644 --- a/src/agents/exceptions.py +++ b/src/agents/exceptions.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .guardrail import InputGuardrailResult, OutputGuardrailResult + from .guardrail import FactCheckingGuardrailResult, InputGuardrailResult, OutputGuardrailResult class AgentsException(Exception): @@ -61,3 +61,16 @@ def __init__(self, guardrail_result: "OutputGuardrailResult"): super().__init__( f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" ) + + +class FactCheckingGuardrailTripwireTriggered(AgentsException): + """Exception raised when a guardrail tripwire is triggered.""" + + guardrail_result: "FactCheckingGuardrailResult" + """The result data of the guardrail that was triggered.""" + + def __init__(self, guardrail_result: "FactCheckingGuardrailResult"): + self.guardrail_result = guardrail_result + super().__init__( + f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" + ) diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index a96f0f7d..ea17c3d9 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -68,6 +68,34 @@ class OutputGuardrailResult: """The output of the guardrail function.""" +@dataclass +class FactCheckingGuardrailResult: + """The result of a guardrail run.""" + + guardrail: FactCheckingGuardrail[Any] + """ + The guardrail that was run. + """ + + agent_input: Any + """ + The input of the agent that was checked by the guardrail. + """ + + agent_output: Any + """ + The output of the agent that was checked by the guardrail. + """ + + agent: Agent[Any] + """ + The agent that was checked by the guardrail. 
+ """ + + output: GuardrailFunctionOutput + """The output of the guardrail function.""" + + @dataclass class InputGuardrail(Generic[TContext]): """Input guardrails are checks that run in parallel to the agent's execution. @@ -163,6 +191,7 @@ async def run( raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") output = self.guardrail_function(context, agent, agent_output) + if inspect.isawaitable(output): return OutputGuardrailResult( guardrail=self, @@ -179,6 +208,68 @@ async def run( ) +@dataclass +class FactCheckingGuardrail(Generic[TContext]): + """Fact checking guardrails are checks that run on the final output and the input of an agent. + They can be used to check if the output passes certain validation criteria + + You can use the `@fact_checking_guardrail()` + decorator to turn a function into a `FactCheckingGuardrail`, + or create a `FactCheckingGuardrail` manually. + + Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`, a + `FactCheckingGuardrailTripwireTriggered` exception will be raised. + """ + + guardrail_function: Callable[ + [RunContextWrapper[TContext], Agent[Any], Any, Any], + MaybeAwaitable[GuardrailFunctionOutput], + ] + """A function that receives the context, the final agent, its output, and the original + input, and returns a `GuardrailResult`. The result marks whether the tripwire was + triggered, and can optionally include information about the guardrail's output. + """ + + name: str | None = None + """The name of the guardrail, used for tracing. If not provided, we'll use the guardrail + function's name. 
+ """ + + def get_name(self) -> str: + if self.name: + return self.name + + return self.guardrail_function.__name__ + + async def run( + self, + context: RunContextWrapper[TContext], + agent: Agent[Any], + agent_output: Any, + agent_input: Any, + ) -> FactCheckingGuardrailResult: + if not callable(self.guardrail_function): + raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") + + output = self.guardrail_function(context, agent, agent_output, agent_input) + if inspect.isawaitable(output): + return FactCheckingGuardrailResult( + guardrail=self, + agent=agent, + agent_input=agent_input, + agent_output=agent_output, + output=await output, + ) + + return FactCheckingGuardrailResult( + guardrail=self, + agent=agent, + agent_input=agent_input, + agent_output=agent_output, + output=output, + ) + + TContext_co = TypeVar("TContext_co", bound=Any, covariant=True) # For InputGuardrail @@ -318,3 +409,72 @@ def decorator( # Decorator used with keyword arguments return decorator + + +_FactCheckingGuardrailFuncSync = Callable[ + [RunContextWrapper[TContext_co], "Agent[Any]", Any, Any], + GuardrailFunctionOutput, +] +_FactCheckingGuardrailAsync = Callable[ + [RunContextWrapper[TContext_co], "Agent[Any]", Any, Any], + Awaitable[GuardrailFunctionOutput], +] + + +@overload +def fact_checking_guardrail( + func: _FactCheckingGuardrailFuncSync[TContext_co], +) -> FactCheckingGuardrail[TContext_co]: ... + + +@overload +def fact_checking_guardrail( + func: _FactCheckingGuardrailAsync[TContext_co], +) -> FactCheckingGuardrail[TContext_co]: ... + + +@overload +def fact_checking_guardrail( + *, + name: str | None = None, +) -> Callable[ + [_FactCheckingGuardrailFuncSync[TContext_co] | _FactCheckingGuardrailAsync[TContext_co]], + FactCheckingGuardrail[TContext_co], +]: ... 
+ + +def fact_checking_guardrail( + func: _FactCheckingGuardrailFuncSync[TContext_co] + | _FactCheckingGuardrailAsync[TContext_co] + | None = None, + *, + name: str | None = None, +) -> ( + FactCheckingGuardrail[TContext_co] + | Callable[ + [_FactCheckingGuardrailFuncSync[TContext_co] | _FactCheckingGuardrailAsync[TContext_co]], + FactCheckingGuardrail[TContext_co], + ] +): + """ + Decorator that transforms a sync or async function into an `FactCheckingGuardrail`. + It can be used directly (no parentheses) or with keyword args, e.g.: + + @fact_checking_guardrail + def my_sync_guardrail(...): ... + + @fact_checking_guardrail(name="guardrail_name") + async def my_async_guardrail(...): ... + """ + + def decorator( + f: _FactCheckingGuardrailFuncSync[TContext_co] | _FactCheckingGuardrailAsync[TContext_co], + ) -> FactCheckingGuardrail[TContext_co]: + return FactCheckingGuardrail(guardrail_function=f, name=name) + + if func is not None: + # Decorator was used without parentheses + return decorator(func) + + # Decorator used with keyword arguments + return decorator diff --git a/src/agents/result.py b/src/agents/result.py index 40a64806..71f467d4 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -9,10 +9,9 @@ from typing_extensions import TypeVar from ._run_impl import QueueCompleteSentinel -from .agent import Agent from .agent_output import AgentOutputSchema from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded -from .guardrail import InputGuardrailResult, OutputGuardrailResult +from .guardrail import FactCheckingGuardrailResult, InputGuardrailResult, OutputGuardrailResult from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .logger import logger from .stream_events import StreamEvent @@ -20,7 +19,6 @@ from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming if TYPE_CHECKING: - from ._run_impl import QueueCompleteSentinel from .agent import Agent T = TypeVar("T") @@ -50,6 +48,9 @@ 
class RunResultBase(abc.ABC): output_guardrail_results: list[OutputGuardrailResult] """Guardrail results for the final output of the agent.""" + fact_checking_guardrail_results: list[FactCheckingGuardrailResult] + """Guardrail results for the original input and the final output of the agent.""" + @property @abc.abstractmethod def last_agent(self) -> Agent[Any]: @@ -135,6 +136,7 @@ class RunResultStreaming(RunResultBase): _run_impl_task: asyncio.Task[Any] | None = field(default=None, repr=False) _input_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False) _output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False) + _fact_checking_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False) _stored_exception: Exception | None = field(default=None, repr=False) @property @@ -211,6 +213,11 @@ def _check_errors(self): if exc and isinstance(exc, Exception): self._stored_exception = exc + if self._fact_checking_guardrails_task and self._fact_checking_guardrails_task.done(): + exc = self._fact_checking_guardrails_task.exception() + if exc and isinstance(exc, Exception): + self._stored_exception = exc + def _cleanup_tasks(self): if self._run_impl_task and not self._run_impl_task.done(): self._run_impl_task.cancel() @@ -221,5 +228,8 @@ def _cleanup_tasks(self): if self._output_guardrails_task and not self._output_guardrails_task.done(): self._output_guardrails_task.cancel() + if self._fact_checking_guardrails_task and not self._fact_checking_guardrails_task.done(): + self._fact_checking_guardrails_task.cancel() + def __str__(self) -> str: return pretty_print_run_result_streaming(self) diff --git a/src/agents/run.py b/src/agents/run.py index 0159822a..fb91f6a9 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -22,12 +22,20 @@ from .agent_output import AgentOutputSchema from .exceptions import ( AgentsException, + FactCheckingGuardrailTripwireTriggered, InputGuardrailTripwireTriggered, MaxTurnsExceeded, 
ModelBehaviorError, OutputGuardrailTripwireTriggered, ) -from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult +from .guardrail import ( + FactCheckingGuardrail, + FactCheckingGuardrailResult, + InputGuardrail, + InputGuardrailResult, + OutputGuardrail, + OutputGuardrailResult, +) from .handoffs import Handoff, HandoffInputFilter, handoff from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .lifecycle import RunHooks @@ -76,6 +84,10 @@ class RunConfig: output_guardrails: list[OutputGuardrail[Any]] | None = None """A list of output guardrails to run on the final output of the run.""" + fact_checking_guardrails: list[FactCheckingGuardrail[Any]] | None = None + """A list of fact checking guardrails to run on the original + input and the final output of the run.""" + tracing_disabled: bool = False """Whether tracing is disabled for the agent run. If disabled, we will not trace the agent run. """ @@ -257,6 +269,14 @@ async def run( turn_result.next_step.output, context_wrapper, ) + fact_checking_guardrail_results = await cls._run_fact_checking_guardrails( + current_agent.fact_checking_guardrails + + (run_config.fact_checking_guardrails or []), + current_agent, + turn_result.next_step.output, + context_wrapper, + original_input, + ) return RunResult( input=original_input, new_items=generated_items, @@ -265,6 +285,7 @@ async def run( _last_agent=current_agent, input_guardrail_results=input_guardrail_results, output_guardrail_results=output_guardrail_results, + fact_checking_guardrail_results=fact_checking_guardrail_results, ) elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) @@ -414,6 +435,7 @@ def run_streamed( max_turns=max_turns, input_guardrail_results=[], output_guardrail_results=[], + fact_checking_guardrail_results=[], _current_agent_output_schema=output_schema, _trace=new_trace, ) @@ -582,13 +604,34 @@ async def 
_run_streamed_impl( ) ) + streamed_result._fact_checking_guardrails_task = asyncio.create_task( + cls._run_fact_checking_guardrails( + current_agent.fact_checking_guardrails + + (run_config.fact_checking_guardrails or []), + current_agent, + turn_result.next_step.output, + context_wrapper, + copy.deepcopy(ItemHelpers.input_to_new_input_list(starting_input)), + ) + ) + try: output_guardrail_results = await streamed_result._output_guardrails_task except Exception: # Exceptions will be checked in the stream_events loop output_guardrail_results = [] + try: + fact_checking_guardrails_results = ( + await streamed_result._fact_checking_guardrails_task + ) + except Exception: + # Exceptions will be checked in the stream_events loop + fact_checking_guardrails_results = [] + streamed_result.output_guardrail_results = output_guardrail_results + streamed_result.fact_checking_guardrail_results = ( + fact_checking_guardrails_results) streamed_result.final_output = turn_result.next_step.output streamed_result.is_complete = True streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) @@ -876,6 +919,47 @@ async def _run_output_guardrails( return guardrail_results + @classmethod + async def _run_fact_checking_guardrails( + cls, + guardrails: list[FactCheckingGuardrail[TContext]], + agent: Agent[TContext], + agent_output: Any, + context: RunContextWrapper[TContext], + agent_input: Any, + ) -> list[FactCheckingGuardrailResult]: + if not guardrails: + return [] + + guardrail_tasks = [ + asyncio.create_task( + RunImpl.run_single_fact_checking_guardrail( + guardrail, agent, agent_output, context, agent_input + ) + ) + for guardrail in guardrails + ] + + guardrail_results = [] + + for done in asyncio.as_completed(guardrail_tasks): + result = await done + if result.output.tripwire_triggered: + # Cancel all guardrail tasks if a tripwire is triggered. 
+ for t in guardrail_tasks: + t.cancel() + _error_tracing.attach_error_to_current_span( + SpanError( + message="Guardrail tripwire triggered", + data={"guardrail": result.guardrail.get_name()}, + ) + ) + raise FactCheckingGuardrailTripwireTriggered(result) + else: + guardrail_results.append(result) + + return guardrail_results + @classmethod async def _get_new_response( cls, diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py index c9f318c3..08510279 100644 --- a/tests/test_guardrails.py +++ b/tests/test_guardrails.py @@ -6,6 +6,7 @@ from agents import ( Agent, + FactCheckingGuardrail, GuardrailFunctionOutput, InputGuardrail, OutputGuardrail, @@ -13,7 +14,7 @@ TResponseInputItem, UserError, ) -from agents.guardrail import input_guardrail, output_guardrail +from agents.guardrail import fact_checking_guardrail, input_guardrail, output_guardrail def get_sync_guardrail(triggers: bool, output_info: Any | None = None): @@ -260,3 +261,162 @@ async def test_output_guardrail_decorators(): assert not result.output.tripwire_triggered assert result.output.output_info == "test_4" assert guardrail.get_name() == "Custom name" + + +def get_sync_fact_checking_guardrail(triggers: bool, output_info: Any | None = None): + def sync_guardrail( + context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any, agent_input: Any + ): + return GuardrailFunctionOutput( + output_info=output_info, + tripwire_triggered=triggers, + ) + + return sync_guardrail + + +@pytest.mark.asyncio +async def test_sync_fact_checking_guardrail(): + guardrail = FactCheckingGuardrail( + guardrail_function=get_sync_fact_checking_guardrail(triggers=False) + ) + result = await guardrail.run( + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None), + ) + assert not result.output.tripwire_triggered + assert result.output.output_info is None + + guardrail = FactCheckingGuardrail( + 
guardrail_function=get_sync_fact_checking_guardrail(triggers=True) + ) + result = await guardrail.run( + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None), + ) + assert result.output.tripwire_triggered + assert result.output.output_info is None + + guardrail = FactCheckingGuardrail( + guardrail_function=get_sync_fact_checking_guardrail(triggers=True, output_info="test") + ) + result = await guardrail.run( + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None), + ) + assert result.output.tripwire_triggered + assert result.output.output_info == "test" + + +def get_async_fact_checking_guardrail(triggers: bool, output_info: Any | None = None): + async def async_guardrail( + context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any, agent_input: Any + ): + return GuardrailFunctionOutput( + output_info=output_info, + tripwire_triggered=triggers, + ) + + return async_guardrail + + +@pytest.mark.asyncio +async def test_async_fact_checking_guardrail(): + guardrail = FactCheckingGuardrail( + guardrail_function=get_async_fact_checking_guardrail(triggers=False) + ) + result = await guardrail.run( + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None), + ) + assert not result.output.tripwire_triggered + assert result.output.output_info is None + + guardrail = FactCheckingGuardrail( + guardrail_function=get_async_fact_checking_guardrail(triggers=True) + ) + result = await guardrail.run( + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None), + ) + assert result.output.tripwire_triggered + assert result.output.output_info is None + + guardrail = FactCheckingGuardrail( + guardrail_function=get_async_fact_checking_guardrail(triggers=True, output_info="test") + ) + result = await guardrail.run( + agent=Agent(name="test"), + 
agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None), + ) + assert result.output.tripwire_triggered + assert result.output.output_info == "test" + + +@pytest.mark.asyncio +async def test_invalid_fact_checking_guardrail_raises_user_error(): + with pytest.raises(UserError): + # Purposely ignoring type error + guardrail = FactCheckingGuardrail(guardrail_function="foo") # type: ignore + await guardrail.run( + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None), + ) + + +@fact_checking_guardrail +def decorated_fact_checking_guardrail( + context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any, agent_input: Any +) -> GuardrailFunctionOutput: + return GuardrailFunctionOutput( + output_info="test_5", + tripwire_triggered=False, + ) + + +@fact_checking_guardrail(name="Custom name") +def decorated_named_fact_checking_guardrail( + context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any, agent_input: Any +) -> GuardrailFunctionOutput: + return GuardrailFunctionOutput( + output_info="test_6", + tripwire_triggered=False, + ) + + +@pytest.mark.asyncio +async def test_fact_checking_guardrail_decorators(): + guardrail = decorated_fact_checking_guardrail + result = await guardrail.run( + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None), + ) + assert not result.output.tripwire_triggered + assert result.output.output_info == "test_5" + + guardrail = decorated_named_fact_checking_guardrail + result = await guardrail.run( + agent=Agent(name="test"), + agent_input="test", + agent_output="test", + context=RunContextWrapper(context=None), + ) + assert not result.output.tripwire_triggered + assert result.output.output_info == "test_6" + assert guardrail.get_name() == "Custom name" diff --git a/tests/test_result_cast.py b/tests/test_result_cast.py index ec17e327..33155fe9 100644 --- a/tests/test_result_cast.py 
+++ b/tests/test_result_cast.py @@ -14,6 +14,7 @@ def create_run_result(final_output: Any) -> RunResult: final_output=final_output, input_guardrail_results=[], output_guardrail_results=[], + fact_checking_guardrail_results=[], _last_agent=Agent(name="test"), )