-
Notifications
You must be signed in to change notification settings - Fork 808
/
Copy pathcustom_example_provider.py
77 lines (57 loc) · 2.21 KB
/
custom_example_provider.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from __future__ import annotations
import asyncio
import os
from openai import AsyncOpenAI
from agents import (
Agent,
Model,
ModelProvider,
OpenAIChatCompletionsModel,
RunConfig,
Runner,
function_tool,
set_tracing_disabled,
)
BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""
if not BASE_URL or not API_KEY or not MODEL_NAME:
raise ValueError(
"Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
)
"""This example uses a custom provider for some calls to Runner.run(), and direct calls to OpenAI for
others. Steps:
1. Create a custom OpenAI client.
2. Create a ModelProvider that uses the custom client.
3. Use the ModelProvider in calls to Runner.run(), only when we want to use the custom LLM provider.
Note that in this example, we disable tracing under the assumption that you don't have an API key
from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
or call set_tracing_export_api_key() to set a tracing specific key.
"""
client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)
set_tracing_disabled(disabled=True)
class CustomModelProvider(ModelProvider):
def get_model(self, model_name: str | None) -> Model:
return OpenAIChatCompletionsModel(model=model_name or MODEL_NAME, openai_client=client)
CUSTOM_MODEL_PROVIDER = CustomModelProvider()
@function_tool
def get_weather(city: str):
print(f"[debug] getting weather for {city}")
return f"The weather in {city} is sunny."
async def main():
agent = Agent(name="Assistant", instructions="You only respond in haikus.", tools=[get_weather])
# This will use the custom model provider
result = await Runner.run(
agent,
"What's the weather in Tokyo?",
run_config=RunConfig(model_provider=CUSTOM_MODEL_PROVIDER),
)
print(result.final_output)
# If you uncomment this, it will use OpenAI directly, not the custom provider
# result = await Runner.run(
# agent,
# "What's the weather in Tokyo?",
# )
# print(result.final_output)
if __name__ == "__main__":
asyncio.run(main())