-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathfake_model.py
141 lines (125 loc) · 4.49 KB
/
fake_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Any
from openai.types.responses import Response, ResponseCompletedEvent
from agents.agent_output import AgentOutputSchema
from agents.handoffs import Handoff
from agents.items import (
ModelResponse,
TResponseInputItem,
TResponseOutputItem,
TResponseStreamEvent,
)
from agents.model_settings import ModelSettings
from agents.models.interface import Model, ModelTracing
from agents.tool import Tool
from agents.tracing import SpanError, generation_span
from agents.usage import Usage
class FakeModel(Model):
    """Scripted ``Model`` that replays pre-queued turn outputs.

    Each call to ``get_response`` or ``stream_response`` consumes the oldest
    entry of ``turn_outputs``. An entry is either a list of output items,
    which is returned/streamed, or an ``Exception`` instance, which is
    recorded on the generation span and then raised. When the queue is empty
    an empty output list is produced.

    The arguments of the most recent call are kept in ``last_turn_args`` so
    callers can inspect what the model was invoked with.
    """

    def __init__(
        self,
        tracing_enabled: bool = False,
        initial_output: list[TResponseOutputItem] | Exception | None = None,
    ):
        """Create the fake model.

        Args:
            tracing_enabled: When False, the generation span opened for each
                call is created with ``disabled=True``.
            initial_output: Optional first queued output. ``None`` or any
                other falsy value (e.g. an empty list) leaves the queue empty.
        """
        if initial_output is None:
            initial_output = []
        # A falsy initial output is not queued; an empty queue already
        # yields [] from get_next_output().
        self.turn_outputs: list[list[TResponseOutputItem] | Exception] = (
            [initial_output] if initial_output else []
        )
        self.tracing_enabled = tracing_enabled
        self.last_turn_args: dict[str, Any] = {}

    def set_next_output(self, output: list[TResponseOutputItem] | Exception):
        """Append a single output (or exception) to the end of the queue."""
        self.turn_outputs.append(output)

    def add_multiple_turn_outputs(self, outputs: list[list[TResponseOutputItem] | Exception]):
        """Append several outputs (or exceptions) to the queue, in order."""
        self.turn_outputs.extend(outputs)

    def get_next_output(self) -> list[TResponseOutputItem] | Exception:
        """Remove and return the oldest queued output, or ``[]`` if none is queued."""
        if not self.turn_outputs:
            return []
        return self.turn_outputs.pop(0)

    def _record_turn_args(
        self,
        *,
        system_instructions: str | None,
        turn_input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        previous_response_id: str | None,
    ) -> None:
        """Store the call arguments in ``last_turn_args`` (a fresh dict per call)."""
        self.last_turn_args = {
            "system_instructions": system_instructions,
            "input": turn_input,
            "model_settings": model_settings,
            "tools": tools,
            "output_schema": output_schema,
            "previous_response_id": previous_response_id,
        }

    def _next_output_or_raise(self, span: Any) -> list[TResponseOutputItem]:
        """Pop the next queued output; if it is an exception, record it on ``span`` and raise it."""
        output = self.get_next_output()
        if isinstance(output, Exception):
            # Attach the error to the span before raising it to the caller.
            span.set_error(
                SpanError(
                    message="Error",
                    data={
                        "name": output.__class__.__name__,
                        "message": str(output),
                    },
                )
            )
            raise output
        return output

    async def get_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
        *,
        previous_response_id: str | None,
    ) -> ModelResponse:
        """Return the next queued output wrapped in a ``ModelResponse``.

        The call arguments (except ``handoffs`` and ``tracing``, which are
        accepted but not used) are saved to ``last_turn_args`` first.

        Returns:
            A ``ModelResponse`` with the queued output, a default ``Usage()``
            and ``response_id=None``.

        Raises:
            Exception: The queued exception instance, if the next queued
                entry is an exception.
        """
        self._record_turn_args(
            system_instructions=system_instructions,
            turn_input=input,
            model_settings=model_settings,
            tools=tools,
            output_schema=output_schema,
            previous_response_id=previous_response_id,
        )

        with generation_span(disabled=not self.tracing_enabled) as span:
            output = self._next_output_or_raise(span)
            return ModelResponse(
                output=output,
                usage=Usage(),
                response_id=None,
            )

    async def stream_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
        *,
        previous_response_id: str | None,
    ) -> AsyncIterator[TResponseStreamEvent]:
        """Yield a single ``response.completed`` event carrying the next queued output.

        The call arguments (except ``handoffs`` and ``tracing``, which are
        accepted but not used) are saved to ``last_turn_args`` first. Because
        this is an async generator, nothing runs until iteration starts.

        Yields:
            One ``ResponseCompletedEvent`` whose response is built by
            ``get_response_obj`` from the queued output.

        Raises:
            Exception: The queued exception instance, if the next queued
                entry is an exception.
        """
        self._record_turn_args(
            system_instructions=system_instructions,
            turn_input=input,
            model_settings=model_settings,
            tools=tools,
            output_schema=output_schema,
            previous_response_id=previous_response_id,
        )

        with generation_span(disabled=not self.tracing_enabled) as span:
            output = self._next_output_or_raise(span)
            yield ResponseCompletedEvent(
                type="response.completed",
                response=get_response_obj(output),
            )
def get_response_obj(output: list[TResponseOutputItem], response_id: str | None = None) -> Response:
    """Build a ``Response`` around ``output`` using fixed placeholder metadata.

    Args:
        output: Output items to place on the response.
        response_id: Id for the response; any falsy value (``None`` or an
            empty string) is replaced by the placeholder ``"123"``.

    Returns:
        A ``Response`` with the given output, no tools, ``tool_choice="none"``
        and parallel tool calls disabled.
    """
    resolved_id = response_id if response_id else "123"
    return Response(
        id=resolved_id,
        object="response",
        model="test_model",
        created_at=123,
        output=output,
        tools=[],
        tool_choice="none",
        parallel_tool_calls=False,
        top_p=None,
    )