Fact checking guardrails #347

Open: wants to merge 26 commits into main
130 changes: 127 additions & 3 deletions docs/guardrails.md
@@ -2,10 +2,11 @@

Guardrails run _in parallel_ to your agents, enabling you to do checks and validations of user input. For example, imagine you have an agent that uses a very smart (and hence slow/expensive) model to help with customer requests. You wouldn't want malicious users to ask the model to help them with their math homework. So, you can run a guardrail with a fast/cheap model. If the guardrail detects malicious usage, it can immediately raise an error, which stops the expensive model from running and saves you time/money.

There are two kinds of guardrails:
There are three kinds of guardrails:

1. Input guardrails run on the initial user input
2. Output guardrails run on the final agent output
3. Fact checking guardrails run on both the initial user input and the final agent output

## Input guardrails

@@ -23,17 +24,30 @@ Input guardrails run in 3 steps:

Output guardrails run in 3 steps:

1. First, the guardrail receives the same input passed to the agent.
1. First, the guardrail receives the output of the last agent.
2. Next, the guardrail function runs to produce a [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput], which is then wrapped in an [`OutputGuardrailResult`][agents.guardrail.OutputGuardrailResult]
3. Finally, we check if [`.tripwire_triggered`][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] is true. If true, an [`OutputGuardrailTripwireTriggered`][agents.exceptions.OutputGuardrailTripwireTriggered] exception is raised, so you can appropriately respond to the user or handle the exception.

!!! Note

Output guardrails are intended to run on the final agent output, so an agent's guardrails only run if the agent is the *last* agent. Similar to the input guardrails, we do this because guardrails tend to be related to the actual Agent - you'd run different guardrails for different agents, so colocating the code is useful for readability.

## Fact checking guardrails

Fact checking guardrails run in 3 steps:

1. First, the guardrail receives the same input passed to the first agent, together with the output of the last agent.
2. Next, the guardrail function runs to produce a [`GuardrailFunctionOutput`][agents.guardrail.GuardrailFunctionOutput], which is then wrapped in a [`FactCheckingGuardrailResult`][agents.guardrail.FactCheckingGuardrailResult]
3. Finally, we check if [`.tripwire_triggered`][agents.guardrail.GuardrailFunctionOutput.tripwire_triggered] is true. If true, a [`FactCheckingGuardrailTripwireTriggered`][agents.exceptions.FactCheckingGuardrailTripwireTriggered] exception is raised, so you can appropriately respond to the user or handle the exception.

!!! Note

Fact checking guardrails are intended to run on user input and the final agent output, so an agent's guardrails only run if the agent is the *last* agent. Similar to the output guardrails, we do this because guardrails tend to be related to the actual Agent - you'd run different guardrails for different agents, so colocating the code is useful for readability.
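
The three steps can be sketched in plain Python without the SDK. Everything below is a stand-in written for illustration: the dataclasses only mimic the names used in this PR, their fields are assumptions, and `run_fact_check` and `total_matches_items` are hypothetical helpers rather than SDK functions.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable


@dataclass
class GuardrailFunctionOutput:
    """Stand-in for the SDK type of the same name."""

    output_info: Any
    tripwire_triggered: bool


@dataclass
class FactCheckingGuardrailResult:
    """Stand-in wrapper; the field names here are assumptions."""

    agent_input: Any
    agent_output: Any
    output: GuardrailFunctionOutput


class FactCheckingGuardrailTripwireTriggered(Exception):
    """Stand-in for the exception raised when the tripwire fires."""

    def __init__(self, guardrail_result: FactCheckingGuardrailResult):
        self.guardrail_result = guardrail_result
        super().__init__("Fact checking guardrail triggered tripwire")


# In this sketch a guardrail function takes (agent_input, agent_output).
GuardrailFn = Callable[[Any, Any], Awaitable[GuardrailFunctionOutput]]


async def run_fact_check(
    guardrail_fn: GuardrailFn, agent_input: Any, agent_output: Any
) -> FactCheckingGuardrailResult:
    # Step 1: the guardrail sees the first agent's input and the last agent's output.
    function_output = await guardrail_fn(agent_input, agent_output)
    # Step 2: wrap the function output in a result object.
    result = FactCheckingGuardrailResult(agent_input, agent_output, function_output)
    # Step 3: raise if the tripwire fired, otherwise hand the result back.
    if result.output.tripwire_triggered:
        raise FactCheckingGuardrailTripwireTriggered(result)
    return result


async def total_matches_items(agent_input: list[int], agent_output: int) -> GuardrailFunctionOutput:
    """Toy check: the reported total must equal the sum of the input items."""
    expected = sum(agent_input)
    return GuardrailFunctionOutput(
        output_info={"expected": expected, "reported": agent_output},
        tripwire_triggered=expected != agent_output,
    )
```

Calling `asyncio.run(run_fact_check(total_matches_items, [1, 2, 3], 6))` returns a result, while a reported total of `7` raises the stand-in exception with the mismatch available on `guardrail_result.output.output_info`.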


## Tripwires

If the input or output fails the guardrail, the Guardrail can signal this with a tripwire. As soon as we see a guardrail that has triggered the tripwires, we immediately raise a `{Input,Output}GuardrailTripwireTriggered` exception and halt the Agent execution.
If the input or output fails the guardrail, the Guardrail can signal this with a tripwire. As soon as we see a guardrail that has triggered the tripwires, we immediately raise a `{Input,Output,FactChecking}GuardrailTripwireTriggered` exception and halt the Agent execution.
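
One possible way to get "raise as soon as any guardrail trips" is to run the checks concurrently and stop at the first tripwire. The sketch below uses only `asyncio`; `TripwireTriggered`, `Check`, and `run_guardrails` are hypothetical names, and this is not claimed to be how the SDK schedules guardrails internally.

```python
import asyncio
from typing import Awaitable, Callable


class TripwireTriggered(Exception):
    """Hypothetical stand-in for the `*GuardrailTripwireTriggered` exceptions."""

    def __init__(self, guardrail_name: str):
        self.guardrail_name = guardrail_name
        super().__init__(f"Guardrail {guardrail_name} triggered tripwire")


# A check returns True when its tripwire should fire.
Check = Callable[[str], Awaitable[bool]]


async def run_guardrails(checks: dict[str, Check], text: str) -> list[str]:
    """Run all checks concurrently and raise on the first tripwire seen."""

    async def run_one(name: str, check: Check) -> tuple[str, bool]:
        return name, await check(text)

    tasks = [asyncio.create_task(run_one(name, check)) for name, check in checks.items()]
    passed: list[str] = []
    try:
        # Results arrive in completion order, so a fast check can halt the run early.
        for finished in asyncio.as_completed(tasks):
            name, tripped = await finished
            if tripped:
                raise TripwireTriggered(name)
            passed.append(name)
    finally:
        # Cancel whatever is still running; this is a no-op for finished tasks.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
    return passed
```

The `finally` block is the design choice worth noting: once one guardrail has tripped, the remaining checks are cancelled instead of being left to finish in the background.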

## Implementing a guardrail

@@ -152,3 +166,113 @@ async def main():
2. This is the guardrail's output type.
3. This is the guardrail function that receives the agent's output, and returns the result.
4. This is the actual agent that defines the workflow.

Fact checking guardrails are similar.

```python
import json

from pydantic import BaseModel, Field

from agents import (
    Agent,
    GuardrailFunctionOutput,
    FactCheckingGuardrailTripwireTriggered,
    RunContextWrapper,
    Runner,
    fact_checking_guardrail,
)


"""
This example shows how to use fact checking guardrails.

Fact checking guardrails are checks that run on both the original input and the final output of an agent.
Their primary purpose is to ensure the consistency and accuracy of the agent’s response by verifying that
the output aligns with known facts or the provided input data. They can be used to:
- Validate that the agent's output correctly reflects the information given in the input.
- Ensure that any factual details in the response match expected values.
- Detect discrepancies or potential misinformation.

In this example, we'll use a contrived scenario where we verify if the agent's response contains data that matches the input.
"""


class MessageOutput(BaseModel): # (1)!
    reasoning: str = Field(description="Thoughts on how to respond to the user's message")
    response: str = Field(description="The response to the user's message")
    age: int | None = Field(description="Age of the person")


class FactCheckingOutput(BaseModel): # (2)!
    reasoning: str
    is_age_correct: bool


guardrail_agent = Agent(
    name="Guardrail Check",
    instructions=(
        "You are given a task to determine if the hypothesis is grounded in the provided evidence. "
        "Rely solely on the contents of the evidence without using external knowledge."
    ),
    output_type=FactCheckingOutput,
)


@fact_checking_guardrail
async def self_check_facts( # (3)!
    context: RunContextWrapper,
    agent: Agent,
    output: MessageOutput,
    evidence: str,
) -> GuardrailFunctionOutput:
    """This is a fact checking guardrail function, which happens to call an agent to check if the output
    is coherent with the input.
    """
    # `evidence` is the original user input; `output` is the agent's final output.
    message = (
        f"Input: {evidence}\n"
        f"Age: {output.age}"
    )

    print(f"message: {message}")

    # Run the fact-checking agent using the constructed message.
    result = await Runner.run(guardrail_agent, message, context=context.context)
    final_output = result.final_output_as(FactCheckingOutput)

    # Trip the wire when the checker says the age is not supported by the input.
    return GuardrailFunctionOutput(
        output_info=final_output,
        tripwire_triggered=not final_output.is_age_correct,
    )


async def main():
    agent = Agent( # (4)!
        name="Entities Extraction Agent",
        instructions="""
            Always respond age = 28.
        """,
        fact_checking_guardrails=[self_check_facts],
        output_type=MessageOutput,
    )

    await Runner.run(agent, "My name is Alex and I'm 28 years old.")
    print("First message passed")

    # This should trip the guardrail
    try:
        result = await Runner.run(
            agent, "My name is Alex and I'm 38."
        )
        print(
            f"Guardrail didn't trip - this is unexpected. Output: {json.dumps(result.final_output.model_dump(), indent=2)}"
        )

    except FactCheckingGuardrailTripwireTriggered as e:
        print(f"Guardrail tripped. Info: {e.guardrail_result.output.output_info}")
```

1. This is the actual agent's output type.
2. This is the guardrail's output type.
3. This is the guardrail function that receives the user input and agent's output, and returns the result.
4. This is the actual agent that defines the workflow.
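
The example above delegates the comparison to a second agent. For a fact as simple as an age, a deterministic check could play the same role without a model call. The following is a minimal sketch under that assumption: `AgeCheck` and `check_age` are hypothetical names, and the regular expression simply treats any number of up to three digits in the input as a candidate age.

```python
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class AgeCheck:
    """Hypothetical result type, playing the role of `FactCheckingOutput`."""

    reasoning: str
    is_age_correct: bool


def check_age(evidence: str, claimed_age: Optional[int]) -> AgeCheck:
    """Accept the claimed age only if that number appears in the evidence."""
    if claimed_age is None:
        # Nothing was claimed, so there is nothing to contradict.
        return AgeCheck("No age was claimed.", True)
    # Assumption: any standalone 1-3 digit number in the input may be an age.
    stated = {int(number) for number in re.findall(r"\b\d{1,3}\b", evidence)}
    if claimed_age in stated:
        return AgeCheck(f"Age {claimed_age} appears in the input.", True)
    return AgeCheck(
        f"Age {claimed_age} is not supported by the input (found: {sorted(stated)}).",
        False,
    )
```

Inside a guardrail function, the tripwire would then be `not check_age(evidence, output.age).is_age_correct`, mirroring the agent-based version.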
2 changes: 1 addition & 1 deletion examples/agent_patterns/README.md
@@ -51,4 +51,4 @@ You can definitely do this without any special Agents SDK features by using para

This is really useful for latency: for example, you might have a very fast model that runs the guardrail and a slow model that runs the actual agent. You wouldn't want to wait for the slow model to finish, so guardrails let you quickly reject invalid inputs.

See the [`input_guardrails.py`](./input_guardrails.py) and [`output_guardrails.py`](./output_guardrails.py) files for examples.
See the [`input_guardrails.py`](./input_guardrails.py), [`output_guardrails.py`](./output_guardrails.py) and [`fact_checking_guardrails.py`](./fact_checking_guardrails.py) files for examples.
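
The latency argument can be illustrated with plain `asyncio`, independent of the SDK. In this sketch `fast_guardrail` and `slow_agent` are made-up coroutines that simulate a cheap and an expensive model with `asyncio.sleep`, and `InputRejected` is a hypothetical stand-in for the tripwire exception.

```python
import asyncio


class InputRejected(Exception):
    """Hypothetical stand-in for `InputGuardrailTripwireTriggered`."""


async def fast_guardrail(user_input: str) -> bool:
    """Simulated cheap model: returns True when the input should be rejected."""
    await asyncio.sleep(0.01)
    return "homework" in user_input.lower()


async def slow_agent(user_input: str) -> str:
    """Simulated expensive model."""
    await asyncio.sleep(0.2)
    return f"Answer to: {user_input}"


async def run_with_guardrail(user_input: str) -> str:
    # Start the expensive call right away so the guardrail adds no latency.
    agent_task = asyncio.create_task(slow_agent(user_input))
    try:
        if await fast_guardrail(user_input):
            raise InputRejected(user_input)
        return await agent_task
    finally:
        # No-op if the agent finished; otherwise stops the expensive call early.
        agent_task.cancel()
```

A rejected input raises after roughly the guardrail's latency, while an accepted input takes no longer than the agent alone because both coroutines run concurrently.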
98 changes: 98 additions & 0 deletions examples/agent_patterns/fact_checking_guardrails.py
@@ -0,0 +1,98 @@
from __future__ import annotations

import asyncio
import json

from pydantic import BaseModel, Field

from agents import (
    Agent,
    FactCheckingGuardrailTripwireTriggered,
    GuardrailFunctionOutput,
    RunContextWrapper,
    Runner,
    fact_checking_guardrail,
)

"""
This example shows how to use fact checking guardrails.

Fact checking guardrails are checks that run on both the original input and the final output of an agent.
Their primary purpose is to ensure the consistency and accuracy of the agent’s response by verifying that
the output aligns with known facts or the provided input data. They can be used to:
- Validate that the agent's output correctly reflects the information given in the input.
- Ensure that any factual details in the response match expected values.
- Detect discrepancies or potential misinformation.

In this example, we'll use a contrived scenario where we verify if the agent's response contains data that matches the input.
"""


class MessageOutput(BaseModel):
    reasoning: str = Field(description="Thoughts on how to respond to the user's message")
    response: str = Field(description="The response to the user's message")
    age: int | None = Field(description="Age of the person")


class FactCheckingOutput(BaseModel):
    reasoning: str
    is_age_correct: bool


guardrail_agent = Agent(
    name="Guardrail Check",
    instructions=(
        "You are given a task to determine if the hypothesis is grounded in the provided evidence. "
        "Rely solely on the contents of the evidence without using external knowledge."
    ),
    output_type=FactCheckingOutput,
)


@fact_checking_guardrail
async def self_check_facts(
    context: RunContextWrapper, agent: Agent, output: MessageOutput, evidence: str
) -> GuardrailFunctionOutput:
    """This is a fact checking guardrail function, which happens to call an agent to check if the output
    is coherent with the input.
    """
    # `evidence` is the original user input; `output` is the agent's final output.
    message = f"Input: {evidence}\nAge: {output.age}"

    print(f"message: {message}")

    # Run the fact-checking agent using the constructed message.
    result = await Runner.run(guardrail_agent, message, context=context.context)
    final_output = result.final_output_as(FactCheckingOutput)

    # Trip the wire when the checker says the age is not supported by the input.
    return GuardrailFunctionOutput(
        output_info=final_output,
        tripwire_triggered=not final_output.is_age_correct,
    )


async def main():
    agent = Agent(
        name="Entities Extraction Agent",
        instructions="""
            Always respond age = 28.
        """,
        fact_checking_guardrails=[self_check_facts],
        output_type=MessageOutput,
    )

    await Runner.run(agent, "My name is Alex and I'm 28 years old.")
    print("First message passed")

    # This should trip the guardrail
    try:
        result = await Runner.run(agent, "My name is Alex and I'm 38.")
        print(
            f"Guardrail didn't trip - this is unexpected. Output: {json.dumps(result.final_output.model_dump(), indent=2)}"
        )

    except FactCheckingGuardrailTripwireTriggered as e:
        print(f"Guardrail tripped. Info: {e.guardrail_result.output.output_info}")


if __name__ == "__main__":
    asyncio.run(main())
8 changes: 8 additions & 0 deletions src/agents/__init__.py
@@ -10,18 +10,22 @@
from .computer import AsyncComputer, Button, Computer, Environment
from .exceptions import (
    AgentsException,
    FactCheckingGuardrailTripwireTriggered,
    InputGuardrailTripwireTriggered,
    MaxTurnsExceeded,
    ModelBehaviorError,
    OutputGuardrailTripwireTriggered,
    UserError,
)
from .guardrail import (
    FactCheckingGuardrail,
    FactCheckingGuardrailResult,
    GuardrailFunctionOutput,
    InputGuardrail,
    InputGuardrailResult,
    OutputGuardrail,
    OutputGuardrailResult,
    fact_checking_guardrail,
    input_guardrail,
    output_guardrail,
)
@@ -164,16 +168,20 @@ def enable_verbose_stdout_logging():
    "AgentsException",
    "InputGuardrailTripwireTriggered",
    "OutputGuardrailTripwireTriggered",
    "FactCheckingGuardrailTripwireTriggered",
    "MaxTurnsExceeded",
    "ModelBehaviorError",
    "UserError",
    "InputGuardrail",
    "InputGuardrailResult",
    "OutputGuardrail",
    "OutputGuardrailResult",
    "FactCheckingGuardrail",
    "FactCheckingGuardrailResult",
    "GuardrailFunctionOutput",
    "input_guardrail",
    "output_guardrail",
    "fact_checking_guardrail",
    "handoff",
    "Handoff",
    "HandoffInputData",
25 changes: 24 additions & 1 deletion src/agents/_run_impl.py
@@ -32,7 +32,14 @@
from .agent_output import AgentOutputSchema
from .computer import AsyncComputer, Computer
from .exceptions import AgentsException, ModelBehaviorError, UserError
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
from .guardrail import (
    FactCheckingGuardrail,
    FactCheckingGuardrailResult,
    InputGuardrail,
    InputGuardrailResult,
    OutputGuardrail,
    OutputGuardrailResult,
)
from .handoffs import Handoff, HandoffInputData
from .items import (
    HandoffCallItem,
@@ -708,6 +715,22 @@ async def run_single_output_guardrail(
            span_guardrail.span_data.triggered = result.output.tripwire_triggered
            return result

    @classmethod
    async def run_single_fact_checking_guardrail(
        cls,
        guardrail: FactCheckingGuardrail[TContext],
        agent: Agent[Any],
        agent_output: Any,
        context: RunContextWrapper[TContext],
        agent_input: Any,
    ) -> FactCheckingGuardrailResult:
        with guardrail_span(guardrail.get_name()) as span_guardrail:
            result = await guardrail.run(
                agent=agent, agent_output=agent_output, context=context, agent_input=agent_input
            )
            span_guardrail.span_data.triggered = result.output.tripwire_triggered
            return result

    @classmethod
    def stream_step_result_to_queue(
        cls,
8 changes: 7 additions & 1 deletion src/agents/agent.py
@@ -8,7 +8,7 @@

from typing_extensions import TypeAlias, TypedDict

from .guardrail import InputGuardrail, OutputGuardrail
from .guardrail import FactCheckingGuardrail, InputGuardrail, OutputGuardrail
from .handoffs import Handoff
from .items import ItemHelpers
from .logger import logger
@@ -129,6 +129,12 @@ class Agent(Generic[TContext]):
    Runs only if the agent produces a final output.
    """

    fact_checking_guardrails: list[FactCheckingGuardrail[TContext]] = field(default_factory=list)
    """A list of checks that run on the original input
    and the final output of the agent, after generating a response.
    Runs only if the agent produces a final output.
    """

    output_type: type[Any] | None = None
    """The type of the output object. If not provided, the output will be `str`."""

15 changes: 14 additions & 1 deletion src/agents/exceptions.py
@@ -1,7 +1,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .guardrail import InputGuardrailResult, OutputGuardrailResult
    from .guardrail import FactCheckingGuardrailResult, InputGuardrailResult, OutputGuardrailResult


class AgentsException(Exception):
@@ -61,3 +61,16 @@ def __init__(self, guardrail_result: "OutputGuardrailResult"):
        super().__init__(
            f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
        )


class FactCheckingGuardrailTripwireTriggered(AgentsException):
    """Exception raised when a fact checking guardrail tripwire is triggered."""

    guardrail_result: "FactCheckingGuardrailResult"
    """The result data of the guardrail that was triggered."""

    def __init__(self, guardrail_result: "FactCheckingGuardrailResult"):
        self.guardrail_result = guardrail_result
        super().__init__(
            f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
        )