Commit 944c590

Improve the OpenAI integration (#104)
1 parent dc58761 commit 944c590

File tree

2 files changed (+83 -62 lines)

logfire/_internal/integrations/openai.py

Lines changed: 83 additions & 57 deletions
@@ -1,22 +1,50 @@
 from __future__ import annotations

 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, ContextManager, Iterator, NamedTuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Callable,
+    ContextManager,
+    Generic,
+    Iterator,
+    NamedTuple,
+    TypeVar,
+    cast,
+)

 import openai
+from openai._legacy_response import LegacyAPIResponse
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+from openai.types.completion import Completion
+from openai.types.create_embedding_response import CreateEmbeddingResponse
+from openai.types.images_response import ImagesResponse
 from opentelemetry import context

 if TYPE_CHECKING:
     from openai._models import FinalRequestOptions
     from openai._streaming import AsyncStream, Stream
-    from openai.types.chat.chat_completion import ChatCompletion
-    from openai.types.completion import Completion
-    from openai.types.create_embedding_response import CreateEmbeddingResponse
-    from openai.types.images_response import ImagesResponse
-    from typing_extensions import LiteralString
+    from openai._types import ResponseT
+    from typing_extensions import LiteralString, TypedDict, Unpack

 from ..main import Logfire, LogfireSpan

+# The following typevars parametrize the generic `OpenAIRequest` TypedDict for the sync and async flavors
+_AsyncStreamT = TypeVar('_AsyncStreamT', bound=AsyncStream[Any])
+_StreamT = TypeVar('_StreamT', bound=Stream[Any])
+
+_ResponseType = TypeVar('_ResponseType')
+_StreamType = TypeVar('_StreamType')
+
+
+class OpenAIRequest(TypedDict, Generic[_ResponseType, _StreamType]):
+    cast_to: type[_ResponseType]
+    options: FinalRequestOptions
+    remaining_retries: int | None
+    stream: bool
+    stream_cls: type[_StreamType] | None
+

 __all__ = ('instrument_openai',)
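
Note: the interesting technique in this hunk is typing `**kwargs` with PEP 692's `Unpack` over a generic TypedDict, so each client flavor gets precise request types. A minimal sketch of the idea, independent of logfire and openai (all names here are illustrative):

    from __future__ import annotations

    from typing import Generic, TypeVar

    from typing_extensions import TypedDict, Unpack

    _T = TypeVar('_T')


    class Request(TypedDict, Generic[_T]):
        # Hypothetical request shape, mirroring the cast_to/stream fields above.
        cast_to: type[_T]
        stream: bool


    def request(**kwargs: Unpack[Request[_T]]) -> _T:
        # A type checker now knows the exact shape of **kwargs:
        # kwargs['cast_to'] is type[_T] and kwargs['stream'] is bool,
        # and the return type follows whatever class is passed as cast_to.
        return kwargs['cast_to']()


    response = request(cast_to=dict, stream=False)  # inferred as a dict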

@@ -59,30 +87,30 @@ def instrument_openai_sync(logfire_openai: Logfire, openai_client: openai.OpenAI
     # WARNING: this method is very similar to `instrument_openai_async` below, any changes here should be reflected there
     openai_client._original_request_method = original_request_method = openai_client._request  # type: ignore

-    def instrumented_openai_request(**kwargs: Any) -> Any:
+    def instrumented_openai_request(**kwargs: Unpack[OpenAIRequest[ResponseT, _StreamT]]) -> ResponseT | _StreamT:
         if context.get_value('suppress_instrumentation'):
             return original_request_method(**kwargs)

-        options: FinalRequestOptions | None = kwargs.get('options')
+        options = kwargs['options']
         try:
-            message_template, span_data, on_response, content_from_stream = get_endpoint_config(options)
+            message_template, span_data, content_from_stream = get_endpoint_config(options)
         except ValueError as exc:
             logfire_openai.warn('Unable to instrument OpenAI API call: {error}', error=str(exc), kwargs=kwargs)
             return original_request_method(**kwargs)

         span_data['async'] = False
-        stream = bool(kwargs.get('stream'))
+        stream = kwargs['stream']

         if stream and content_from_stream:
-            stream_cls: type[Stream] | None = kwargs.get('stream_cls')  # type: ignore[reportMissingTypeArgument]
+            stream_cls = kwargs['stream_cls']
             assert stream_cls is not None, 'Expected `stream_cls` when streaming'

             class LogfireInstrumentedStream(stream_cls):
                 def __stream__(self) -> Iterator[Any]:
                     content: list[str] = []
                     with logfire_openai.span(STEAMING_MSG_TEMPLATE, **span_data) as stream_span:
                         with maybe_suppress_instrumentation(suppress_otel):
-                            for chunk in super().__stream__():  # type: ignore
+                            for chunk in super().__stream__():
                                 chunk_content = content_from_stream(chunk)
                                 if chunk_content is not None:
                                     content.append(chunk_content)
@@ -92,15 +120,14 @@ def __stream__(self) -> Iterator[Any]:
                             {'combined_chunk_content': ''.join(content), 'chunk_count': len(content)},
                         )

-            kwargs['stream_cls'] = LogfireInstrumentedStream
+            kwargs['stream_cls'] = LogfireInstrumentedStream  # type: ignore

         with logfire_openai.span(message_template, **span_data) as span:
             with maybe_suppress_instrumentation(suppress_otel):
                 if stream:
                     return original_request_method(**kwargs)
                 else:
-                    response = original_request_method(**kwargs)
-                    on_response(response, span)
+                    response = on_response(original_request_method(**kwargs), span)
                     return response

     openai_client._request = instrumented_openai_request  # type: ignore
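
The instrumentation itself is plain method wrapping: keep a reference to the client's original `_request`, swap in a wrapper that opens a span around the delegated call, and stash the original so it can be restored. A stripped-down, library-agnostic sketch of that pattern (the `Client` class is a toy stand-in, not openai's):

    from typing import Any


    class Client:
        def _request(self, **kwargs: Any) -> Any:
            return {'echo': kwargs}


    def instrument(client: Client) -> None:
        client._original_request_method = original = client._request  # keep for uninstrumenting

        def instrumented(**kwargs: Any) -> Any:
            print('span opened for', kwargs)  # stand-in for logfire_openai.span(...)
            response = original(**kwargs)
            print('span recorded', response)  # stand-in for on_response(response, span)
            return response

        client._request = instrumented  # type: ignore[method-assign]


    client = Client()
    instrument(client)
    client._request(stream=False)  # prints the before/after lines, then returns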
@@ -110,30 +137,32 @@ def instrument_openai_async(logfire_openai: Logfire, openai_client: openai.Async
     # WARNING: this method is very similar to `instrument_openai_sync` above, any changes here should be reflected there
     openai_client._original_request_method = original_request_method = openai_client._request  # type: ignore

-    async def instrumented_openai_request(**kwargs: Any) -> Any:
+    async def instrumented_openai_request(
+        **kwargs: Unpack[OpenAIRequest[ResponseT, _AsyncStreamT]],
+    ) -> ResponseT | _AsyncStreamT:
         if context.get_value('suppress_instrumentation'):
             return await original_request_method(**kwargs)

-        options: FinalRequestOptions | None = kwargs.get('options')
+        options = kwargs['options']
         try:
-            message_template, span_data, on_response, content_from_stream = get_endpoint_config(options)
+            message_template, span_data, content_from_stream = get_endpoint_config(options)
         except ValueError as exc:
             logfire_openai.warn('Unable to instrument OpenAI API call: {error}', error=str(exc), kwargs=kwargs)
             return await original_request_method(**kwargs)

         span_data['async'] = True
-        stream = bool(kwargs.get('stream'))
+        stream = kwargs['stream']

         if stream and content_from_stream:
-            stream_cls: type[AsyncStream] | None = kwargs.get('stream_cls')  # type: ignore[reportMissingTypeArgument]
+            stream_cls = kwargs['stream_cls']
             assert stream_cls is not None, 'Expected `stream_cls` when streaming'

             class LogfireInstrumentedStream(stream_cls):
                 async def __stream__(self) -> AsyncIterator[Any]:
                     content: list[str] = []
                     with logfire_openai.span(STEAMING_MSG_TEMPLATE, **span_data) as stream_span:
                         with maybe_suppress_instrumentation(suppress_otel):
-                            async for chunk in super().__stream__():  # type: ignore
+                            async for chunk in super().__stream__():
                                 chunk_content = content_from_stream(chunk)
                                 if chunk_content is not None:
                                     content.append(chunk_content)
@@ -143,15 +172,14 @@ async def __stream__(self) -> AsyncIterator[Any]:
                             {'combined_chunk_content': ''.join(content), 'chunk_count': len(content)},
                         )

-            kwargs['stream_cls'] = LogfireInstrumentedStream
+            kwargs['stream_cls'] = LogfireInstrumentedStream  # type: ignore

         with logfire_openai.span(message_template, **span_data) as span:
             with maybe_suppress_instrumentation(suppress_otel):
                 if stream:
                     return await original_request_method(**kwargs)
                 else:
-                    response = await original_request_method(**kwargs)
-                    on_response(response, span)
+                    response = on_response(await original_request_method(**kwargs), span)
                     return response

     openai_client._request = instrumented_openai_request  # type: ignore
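
For streaming requests neither flavor can record the response when `_request` returns, since content arrives chunk by chunk; instead a subclass of the caller-provided `stream_cls` overrides `__stream__` to accumulate content as it is yielded and records the combined result once the stream is exhausted. A self-contained sketch of that idea (`BaseStream` is a stand-in, not openai's stream classes):

    from typing import Iterator


    class BaseStream:
        def __init__(self, chunks: list[str]) -> None:
            self._chunks = chunks

        def __stream__(self) -> Iterator[str]:
            yield from self._chunks


    def wrap(stream_cls: type[BaseStream]) -> type[BaseStream]:
        class InstrumentedStream(stream_cls):
            def __stream__(self) -> Iterator[str]:
                content: list[str] = []
                for chunk in super().__stream__():
                    content.append(chunk)
                    yield chunk
                # The real integration sets this on a logfire span instead of printing.
                print({'combined_chunk_content': ''.join(content), 'chunk_count': len(content)})

        return InstrumentedStream


    stream = wrap(BaseStream)(['Hello', ', ', 'world'])
    assert list(stream.__stream__()) == ['Hello', ', ', 'world']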
@@ -160,13 +188,10 @@ async def __stream__(self) -> AsyncIterator[Any]:
 class EndpointConfig(NamedTuple):
     message_template: LiteralString
     span_data: dict[str, Any]
-    on_response: Callable[[Any, LogfireSpan], None]
     content_from_stream: Callable[[Any], str | None] | None


-def get_endpoint_config(options: FinalRequestOptions | None) -> EndpointConfig:
-    if options is None:
-        raise ValueError('`options` is required')
+def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig:
     url = options.url
     json_data = options.json_data
     if not isinstance(json_data, dict):
@@ -179,62 +204,63 @@ def get_endpoint_config(options: FinalRequestOptions | None) -> EndpointConfig:
         return EndpointConfig(
             message_template='Chat Completion with {request_data[model]!r}',
             span_data={'request_data': json_data},
-            on_response=on_chat_response,
-            content_from_stream=lambda chunk: chunk.choices[0].delta.content if chunk and chunk.choices else None,
+            content_from_stream=content_from_chat_completions,
         )
     elif url == '/completions':
         return EndpointConfig(
             message_template='Completion with {request_data[model]!r}',
             span_data={'request_data': json_data},
-            on_response=on_completion_response,
-            content_from_stream=lambda chunk: chunk.choices[0].text if chunk and chunk.choices else None,
+            content_from_stream=content_from_completions,
         )
     elif url == '/embeddings':
         return EndpointConfig(
             message_template='Embedding Creation with {request_data[model]!r}',
             span_data={'request_data': json_data},
-            on_response=on_embedding_response,
             content_from_stream=None,
         )
     elif url == '/images/generations':
         return EndpointConfig(
             message_template='Image Generation with {request_data[model]!r}',
             span_data={'request_data': json_data},
-            on_response=on_image_response,
             content_from_stream=None,
         )
     else:
         raise ValueError(f'Unknown OpenAI API endpoint: `{url}`')


-def on_chat_response(response: ChatCompletion, span: LogfireSpan) -> None:
-    span.set_attribute(
-        'response_data',
-        {
-            'message': response.choices[0].message,
-            'usage': response.usage,
-        },
-    )
+def content_from_completions(chunk: Completion | None) -> str | None:
+    if chunk and chunk.choices:
+        return chunk.choices[0].text
+    return None  # pragma: no cover


-def on_completion_response(response: Completion, span: LogfireSpan) -> None:
-    first_choice = response.choices[0]
-    span.set_attribute(
-        'response_data',
-        {
-            'finish_reason': first_choice.finish_reason,
-            'text': first_choice.text,
-            'usage': response.usage,
-        },
-    )
+def content_from_chat_completions(chunk: ChatCompletionChunk | None) -> str | None:
+    if chunk and chunk.choices:
+        return chunk.choices[0].delta.content
+    return None


-def on_embedding_response(response: CreateEmbeddingResponse, span: LogfireSpan) -> None:
-    span.set_attribute('response_data', {'usage': response.usage})
+def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT:
+    if isinstance(response, LegacyAPIResponse):  # pragma: no cover
+        on_response(response.parse(), span)  # type: ignore
+        return cast('ResponseT', response)

-
-def on_image_response(response: ImagesResponse, span: LogfireSpan) -> None:
-    span.set_attribute('response_data', {'images': response.data})
+    if isinstance(response, ChatCompletion):
+        span.set_attribute(
+            'response_data',
+            {'message': response.choices[0].message, 'usage': response.usage},
+        )
+    elif isinstance(response, Completion):
+        first_choice = response.choices[0]
+        span.set_attribute(
+            'response_data',
+            {'finish_reason': first_choice.finish_reason, 'text': first_choice.text, 'usage': response.usage},
+        )
+    elif isinstance(response, CreateEmbeddingResponse):
+        span.set_attribute('response_data', {'usage': response.usage})
+    elif isinstance(response, ImagesResponse):  # pragma: no branch
+        span.set_attribute('response_data', {'images': response.data})
+    return response


 @contextmanager
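
With the four per-endpoint `on_response` callbacks folded into the single `isinstance`-dispatching `on_response` above, the public entry point is unchanged. A usage sketch, assuming this module's `instrument_openai` is exposed as `logfire.instrument_openai` (model and prompt are illustrative):

    import logfire
    import openai

    client = openai.OpenAI()
    logfire.instrument_openai(client)  # swaps in instrumented_openai_request, as above

    # Any supported endpoint now produces a span, recorded by on_response.
    completion = client.chat.completions.create(
        model='gpt-4',
        messages=[{'role': 'user', 'content': 'Hello!'}],
    )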

tests/otel_integrations/test_openai.py

Lines changed: 0 additions & 5 deletions
@@ -890,11 +890,6 @@ async def test_async_unknown_method(instrumented_async_client: openai.AsyncClien
     )


-def test_get_endpoint_config_none():
-    with pytest.raises(ValueError, match='`options` is required'):
-        get_endpoint_config(None)
-
-
 def test_get_endpoint_config_json_not_dict():
     with pytest.raises(ValueError, match='Expected `options.json_data` to be a dictionary'):
         get_endpoint_config(FinalRequestOptions(method='POST', url='...'))
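
The deleted test covered the `options is None` branch that no longer exists; the narrowed signature makes that case unrepresentable. If equivalent coverage for the stricter signature were wanted, a sketch along these lines would exercise the remaining error path (the `/unknown` URL is made up):

    import pytest
    from openai._models import FinalRequestOptions

    from logfire._internal.integrations.openai import get_endpoint_config


    def test_get_endpoint_config_unknown_url():
        options = FinalRequestOptions(method='POST', url='/unknown', json_data={'model': 'x'})
        with pytest.raises(ValueError, match='Unknown OpenAI API endpoint'):
            get_endpoint_config(options)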
