Skip to content

Commit 8c284e0

Browse files
committed
add model retry settings with backoff logic
1 parent 064e25b commit 8c284e0

File tree

6 files changed

+414
-18
lines changed

6 files changed

+414
-18
lines changed

Diff for: src/agents/models/__init__.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""Model implementations and utilities for working with language models."""
2+
3+
from ._openai_shared import (
4+
TOpenAIClient,
5+
create_client,
6+
get_default_openai_client,
7+
get_default_openai_key,
8+
get_use_responses_by_default,
9+
set_default_openai_client,
10+
set_default_openai_key,
11+
set_use_responses_by_default,
12+
)
13+
from .interface import Model, ModelProvider, ModelRetrySettings, ModelTracing
14+
from .openai_chatcompletions import OpenAIChatCompletionsModel
15+
from .openai_provider import OpenAIProvider
16+
from .openai_responses import OpenAIResponsesModel
17+
from .utils import (
18+
cache_model_response,
19+
clear_cache,
20+
compute_cache_key,
21+
get_token_count_estimate,
22+
set_cache_ttl,
23+
validate_response,
24+
)
25+
26+
# Public API of the models package; keep in sync with the imports above.
__all__ = [
    # Interface
    "Model",
    "ModelProvider",
    "ModelRetrySettings",
    "ModelTracing",
    # OpenAI utilities
    "get_default_openai_client",
    "get_default_openai_key",
    "get_use_responses_by_default",
    "set_default_openai_client",
    "set_default_openai_key",
    "set_use_responses_by_default",
    "TOpenAIClient",
    "create_client",
    # Model implementations
    "OpenAIChatCompletionsModel",
    "OpenAIProvider",
    "OpenAIResponsesModel",
    # Caching and utilities
    "cache_model_response",
    "clear_cache",
    "compute_cache_key",
    "get_token_count_estimate",
    "set_cache_ttl",
    "validate_response",
]

Diff for: src/agents/models/_openai_shared.py

+71-3
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,102 @@
11
from __future__ import annotations
22

3+
import logging
4+
from typing import Any, TypeAlias
5+
36
from openai import AsyncOpenAI
47

8+
# Type aliases for common OpenAI types
TOpenAIClient: TypeAlias = AsyncOpenAI
TOpenAIClientOptions: TypeAlias = dict[str, Any]

# Process-wide defaults, mutated only via the set_* functions below and read
# by the get_* functions and create_client().
_default_openai_key: str | None = None
_default_openai_client: TOpenAIClient | None = None
_use_responses_by_default: bool = True

_logger = logging.getLogger(__name__)
816

917

1018
def set_default_openai_key(key: str) -> None:
    """Set the default OpenAI API key to use when creating clients.

    The key is stored in a module-level global; `create_client` falls back to
    it whenever no explicit ``api_key`` is passed.

    Args:
        key: The OpenAI API key.
    """
    global _default_openai_key
    _default_openai_key = key
1326

1427

1528
def get_default_openai_key() -> str | None:
    """Get the default OpenAI API key.

    Returns:
        The key previously registered via `set_default_openai_key`, or None
        if no default has been set.
    """
    return _default_openai_key
1735

1836

19-
def set_default_openai_client(client: AsyncOpenAI) -> None:
37+
def set_default_openai_client(client: TOpenAIClient) -> None:
    """Set the default OpenAI client to use.

    Stored in a module-level global; providers that were not given an
    explicit client pick this one up lazily.

    Args:
        client: The OpenAI client instance.
    """
    global _default_openai_client
    _default_openai_client = client
2245

2346

24-
def get_default_openai_client() -> AsyncOpenAI | None:
47+
def get_default_openai_client() -> TOpenAIClient | None:
    """Get the default OpenAI client.

    Returns:
        The client previously registered via `set_default_openai_client`,
        or None if no default has been set.
    """
    return _default_openai_client
2654

2755

2856
def set_use_responses_by_default(use_responses: bool) -> None:
    """Set whether to use the Responses API by default.

    Args:
        use_responses: True to prefer the Responses API, False to prefer
            chat completions.
    """
    global _use_responses_by_default
    _use_responses_by_default = use_responses
3164

3265

3366
def get_use_responses_by_default() -> bool:
    """Get whether to use the Responses API by default.

    Returns:
        The current module-level flag (defaults to True).
    """
    return _use_responses_by_default
73+
74+
75+
def create_client(
    api_key: str | None = None,
    base_url: str | None = None,
    organization: str | None = None,
    project: str | None = None,
    http_client: Any = None,
) -> TOpenAIClient:
    """Build and return a new OpenAI client.

    Centralizes client construction so every call site configures clients the
    same way.

    Args:
        api_key: The API key to use. When omitted (or falsy), falls back to
            the module-level default key.
        base_url: The base URL to use, passed through to the client.
        organization: The organization to use.
        project: The project to use.
        http_client: The HTTP client to use.

    Returns:
        A newly constructed OpenAI client.
    """
    resolved_key = api_key or get_default_openai_key()
    client_kwargs = {
        "api_key": resolved_key,
        "base_url": base_url,
        "organization": organization,
        "project": project,
        "http_client": http_client,
    }
    return AsyncOpenAI(**client_kwargs)

Diff for: src/agents/models/interface.py

+73-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations

import abc
import asyncio
import enum
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Awaitable, Callable
79

810
from ..agent_output import AgentOutputSchema
911
from ..handoffs import Handoff
@@ -31,6 +33,76 @@ def include_data(self) -> bool:
3133
return self == ModelTracing.ENABLED
3234

3335

36+
@dataclass
37+
class ModelRetrySettings:
38+
"""Settings for retrying model calls on failure.
39+
40+
This class helps manage backoff and retry logic when API calls fail.
41+
"""
42+
43+
max_retries: int = 3
44+
"""Maximum number of retries to attempt."""
45+
46+
initial_backoff_seconds: float = 1.0
47+
"""Initial backoff time in seconds before the first retry."""
48+
49+
max_backoff_seconds: float = 30.0
50+
"""Maximum backoff time in seconds between retries."""
51+
52+
backoff_multiplier: float = 2.0
53+
"""Multiplier for backoff time after each retry."""
54+
55+
retryable_status_codes: list[int] = field(default_factory=lambda: [429, 500, 502, 503, 504])
56+
"""HTTP status codes that should trigger a retry."""
57+
58+
async def execute_with_retry(
59+
self,
60+
operation: Callable[[], Any],
61+
should_retry: Callable[[Exception], bool] | None = None
62+
) -> Any:
63+
"""Execute an operation with retry logic.
64+
65+
Args:
66+
operation: Async function to execute
67+
should_retry: Optional function to determine if an exception should trigger a retry
68+
69+
Returns:
70+
The result of the operation if successful
71+
72+
Raises:
73+
The last exception encountered if all retries fail
74+
"""
75+
last_exception = None
76+
backoff = self.initial_backoff_seconds
77+
78+
for attempt in range(self.max_retries + 1):
79+
try:
80+
return await operation()
81+
except Exception as e:
82+
last_exception = e
83+
84+
# Check if we should retry
85+
if attempt >= self.max_retries:
86+
break
87+
88+
should_retry_exception = True
89+
if should_retry is not None:
90+
should_retry_exception = should_retry(e)
91+
92+
if not should_retry_exception:
93+
break
94+
95+
# Wait before retrying
96+
await asyncio.sleep(backoff)
97+
backoff = min(backoff * self.backoff_multiplier, self.max_backoff_seconds)
98+
99+
if last_exception:
100+
raise last_exception
101+
102+
# This should never happen, but just in case
103+
raise RuntimeError("Retry logic failed in an unexpected way")
104+
105+
34106
class Model(abc.ABC):
35107
"""The base interface for calling an LLM."""
36108

Diff for: src/agents/models/openai_provider.py

+37-13
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
from __future__ import annotations
22

3+
import logging
4+
35
import httpx
4-
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
6+
from openai import DefaultAsyncHttpxClient, OpenAIError
57

68
from . import _openai_shared
9+
from ._openai_shared import TOpenAIClient, create_client
710
from .interface import Model, ModelProvider
811
from .openai_chatcompletions import OpenAIChatCompletionsModel
912
from .openai_responses import OpenAIResponsesModel
1013

1114
DEFAULT_MODEL: str = "gpt-4o"
12-
15+
_logger = logging.getLogger(__name__)
1316

1417
_http_client: httpx.AsyncClient | None = None
1518

@@ -29,10 +32,11 @@ def __init__(
2932
*,
3033
api_key: str | None = None,
3134
base_url: str | None = None,
32-
openai_client: AsyncOpenAI | None = None,
35+
openai_client: TOpenAIClient | None = None,
3336
organization: str | None = None,
3437
project: str | None = None,
3538
use_responses: bool | None = None,
39+
default_model: str = DEFAULT_MODEL,
3640
) -> None:
3741
"""Create a new OpenAI provider.
3842
@@ -46,12 +50,13 @@ def __init__(
4650
organization: The organization to use for the OpenAI client.
4751
project: The project to use for the OpenAI client.
4852
use_responses: Whether to use the OpenAI responses API.
53+
default_model: The default model to use if none is specified.
4954
"""
5055
if openai_client is not None:
5156
assert api_key is None and base_url is None, (
5257
"Don't provide api_key or base_url if you provide openai_client"
5358
)
54-
self._client: AsyncOpenAI | None = openai_client
59+
self._client: TOpenAIClient | None = openai_client
5560
else:
5661
self._client = None
5762
self._stored_api_key = api_key
@@ -64,23 +69,42 @@ def __init__(
6469
else:
6570
self._use_responses = _openai_shared.get_use_responses_by_default()
6671

72+
self._default_model = default_model
73+
6774
# We lazy load the client in case you never actually use OpenAIProvider(). Otherwise
6875
# AsyncOpenAI() raises an error if you don't have an API key set.
69-
def _get_client(self) -> AsyncOpenAI:
76+
def _get_client(self) -> TOpenAIClient:
7077
if self._client is None:
71-
self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
72-
api_key=self._stored_api_key or _openai_shared.get_default_openai_key(),
73-
base_url=self._stored_base_url,
74-
organization=self._stored_organization,
75-
project=self._stored_project,
76-
http_client=shared_http_client(),
77-
)
78+
default_client = _openai_shared.get_default_openai_client()
79+
if default_client:
80+
self._client = default_client
81+
else:
82+
try:
83+
self._client = create_client(
84+
api_key=self._stored_api_key,
85+
base_url=self._stored_base_url,
86+
organization=self._stored_organization,
87+
project=self._stored_project,
88+
http_client=shared_http_client(),
89+
)
90+
except OpenAIError as e:
91+
_logger.error(f"Failed to create OpenAI client: {e}")
92+
raise
7893

7994
return self._client
8095

8196
def get_model(self, model_name: str | None) -> Model:
97+
"""Get a model instance by name.
98+
99+
Args:
100+
model_name: The name of the model to get. If None, uses the default model.
101+
102+
Returns:
103+
An OpenAI model implementation (either Responses or ChatCompletions
104+
based on configuration)
105+
"""
82106
if model_name is None:
83-
model_name = DEFAULT_MODEL
107+
model_name = self._default_model
84108

85109
client = self._get_client()
86110

0 commit comments

Comments
 (0)