-
Notifications
You must be signed in to change notification settings - Fork 9
/
openai_function_calling.py
122 lines (105 loc) · 3.84 KB
/
openai_function_calling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from pprint import pprint
from litellm.utils import ChatCompletionMessageToolCall
from pydantic import TypeAdapter
from tapeagents.agent import Agent
from tapeagents.core import Prompt
from tapeagents.dialog_tape import (
AssistantStep,
DialogContext,
DialogTape,
ToolCalls,
ToolResult,
ToolSpec,
UserStep,
)
from tapeagents.environment import (
ExternalObservationNeeded,
MockToolEnvironment,
)
from tapeagents.llms import LiteLLM, LLMStream
from tapeagents.orchestrator import main_loop
from tapeagents.prompting import tape_to_messages
class FunctionCallingAgent(Agent[DialogTape]):
    """Agent that replies with plain text or with OpenAI-style tool calls."""

    def make_prompt(self, tape: DialogTape):
        """Build a prompt from the context's tool schemas and the dialog history."""
        assert tape.context
        tool_dicts = [spec.model_dump() for spec in tape.context.tools]
        return Prompt(tools=tool_dicts, messages=tape_to_messages(tape))

    def generate_steps(self, _, llm_stream: LLMStream):
        """Convert one LLM output into an assistant message or a tool-call step."""
        output = llm_stream.get_output()
        if output.content:
            yield AssistantStep(content=output.content)
            return
        if output.tool_calls:
            assert all(isinstance(call, ChatCompletionMessageToolCall) for call in output.tool_calls)
            yield ToolCalls.from_llm_output(output)
            return
        raise ValueError(f"don't know what to do with message {output}")
# Single demo tool in the OpenAI function-calling schema format.
_WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    },
}

# Validate the raw dicts into ToolSpec models up front.
TOOL_SCHEMAS = TypeAdapter(list[ToolSpec]).validate_python([_WEATHER_TOOL])
def try_openai_function_calling():
    """Demo: ask about the weather, hand back canned tool outputs, then rerun the agent."""
    llm = LiteLLM(model_name="gpt-3.5-turbo")
    agent = FunctionCallingAgent.create(llm)
    tape = DialogTape(context=DialogContext(tools=TOOL_SCHEMAS), steps=[])
    tape = tape.append(UserStep(content="What's the weather like in San Francisco, Tokyo"))
    for event in agent.run(tape):
        if event.step:
            print(event.step)
    # The last event of the run carries the completed tape.
    assert event.final_tape
    tape = event.final_tape
    tool_call_step = tape.steps[-1]
    assert isinstance(tool_call_step, ToolCalls)
    calls = tool_call_step.tool_calls
    # Answer each expected tool call (one per city) with a canned observation.
    canned_results = [
        ToolResult(tool_call_id=calls[0].id, content="Cloudy, 13C"),
        ToolResult(tool_call_id=calls[1].id, content="Sunny, 23C"),
    ]
    for result in canned_results:
        tape = tape.append(result)
    for event in agent.run(tape):
        if event.step:
            print(event.step)
def try_openai_function_callling_with_environment():
    """Demo: same weather question, but tool calls are answered by a mock environment
    driven through the orchestrator's main_loop instead of hand-written ToolResults.

    NOTE(review): "callling" is a typo in this function's name; kept as-is because
    the __main__ block calls it by this spelling.
    """
    llm = LiteLLM(model_name="gpt-3.5-turbo")
    agent = FunctionCallingAgent.create(llm)
    tape = DialogTape(
        context=DialogContext(tools=TOOL_SCHEMAS),
        steps=[UserStep(content="What's the weather like in San Francisco, Tokyo")],
    )
    # Environment receives the LLM — presumably to fabricate plausible tool outputs; verify.
    environment = MockToolEnvironment(llm)
    for s in tape.steps:
        print("USER STEP")
        pprint(s.model_dump(exclude_none=True))
    try:
        # Alternate agent and environment turns until done or max_loops is reached.
        for event in main_loop(agent, tape, environment, max_loops=3):
            if ae := event.agent_event:
                if ae.step:
                    print("AGENT STEP")
                    pprint(ae.step.model_dump(exclude_none=True))
            elif event.observation:
                print("OBSERVATION")
                pprint(event.observation.model_dump(exclude_none=True))
    except ExternalObservationNeeded as e:
        # Raised when the next step must come from a human user rather than a tool.
        assert isinstance(e.action, AssistantStep)
        print("Stopping, next user message is needed")
if __name__ == "__main__":
    # Run both demos; each issues real LLM calls through LiteLLM.
    try_openai_function_calling()
    try_openai_function_callling_with_environment()