from __future__ import annotations

from contextlib import contextmanager
- from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, ContextManager, Iterator, NamedTuple
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     AsyncIterator,
+     Callable,
+     ContextManager,
+     Generic,
+     Iterator,
+     NamedTuple,
+     TypeVar,
+     cast,
+ )

import openai
+ from openai._legacy_response import LegacyAPIResponse
+ from openai.types.chat.chat_completion import ChatCompletion
+ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+ from openai.types.completion import Completion
+ from openai.types.create_embedding_response import CreateEmbeddingResponse
+ from openai.types.images_response import ImagesResponse
from opentelemetry import context

if TYPE_CHECKING:
    from openai._models import FinalRequestOptions
    from openai._streaming import AsyncStream, Stream
-     from openai.types.chat.chat_completion import ChatCompletion
-     from openai.types.completion import Completion
-     from openai.types.create_embedding_response import CreateEmbeddingResponse
-     from openai.types.images_response import ImagesResponse
-     from typing_extensions import LiteralString
+     from openai._types import ResponseT
+     from typing_extensions import LiteralString, TypedDict, Unpack

from ..main import Logfire, LogfireSpan

+ # These typevars parametrize the generic `OpenAIRequest` TypedDict for the sync and async flavors.
+ _AsyncStreamT = TypeVar('_AsyncStreamT', bound=AsyncStream[Any])
+ _StreamT = TypeVar('_StreamT', bound=Stream[Any])
+
+ _ResponseType = TypeVar('_ResponseType')
+ _StreamType = TypeVar('_StreamType')
+
+ class OpenAIRequest(TypedDict, Generic[_ResponseType, _StreamType]):
+     cast_to: type[_ResponseType]
+     options: FinalRequestOptions
+     remaining_retries: int | None
+     stream: bool
+     stream_cls: type[_StreamType] | None
+
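+ # Unpacking this TypedDict in the signatures below (`**kwargs: Unpack[OpenAIRequest[...]]`)
+ # gives each request kwarg a precise type, e.g. `kwargs['options']` is a `FinalRequestOptions`.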

__all__ = ('instrument_openai',)

@@ -59,30 +87,30 @@ def instrument_openai_sync(logfire_openai: Logfire, openai_client: openai.OpenAI
    # WARNING: this method is very similar to `instrument_openai_async` below, any changes here should be reflected there
    openai_client._original_request_method = original_request_method = openai_client._request  # type: ignore

-     def instrumented_openai_request(**kwargs: Any) -> Any:
+     def instrumented_openai_request(**kwargs: Unpack[OpenAIRequest[ResponseT, _StreamT]]) -> ResponseT | _StreamT:
        if context.get_value('suppress_instrumentation'):
            return original_request_method(**kwargs)

-         options: FinalRequestOptions | None = kwargs.get('options')
+         options = kwargs['options']
        try:
-             message_template, span_data, on_response, content_from_stream = get_endpoint_config(options)
+             message_template, span_data, content_from_stream = get_endpoint_config(options)
        except ValueError as exc:
            logfire_openai.warn('Unable to instrument OpenAI API call: {error}', error=str(exc), kwargs=kwargs)
            return original_request_method(**kwargs)

        span_data['async'] = False
-         stream = bool(kwargs.get('stream'))
+         stream = kwargs['stream']

        if stream and content_from_stream:
-             stream_cls: type[Stream] | None = kwargs.get('stream_cls')  # type: ignore[reportMissingTypeArgument]
+             stream_cls = kwargs['stream_cls']
            assert stream_cls is not None, 'Expected `stream_cls` when streaming'

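+             # Subclass the concrete stream class requested by the SDK so streamed
+             # content is collected inside its own span as the caller iterates.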
            class LogfireInstrumentedStream(stream_cls):
                def __stream__(self) -> Iterator[Any]:
                    content: list[str] = []
                    with logfire_openai.span(STEAMING_MSG_TEMPLATE, **span_data) as stream_span:
                        with maybe_suppress_instrumentation(suppress_otel):
-                             for chunk in super().__stream__():  # type: ignore
+                             for chunk in super().__stream__():
                                chunk_content = content_from_stream(chunk)
                                if chunk_content is not None:
                                    content.append(chunk_content)
@@ -92,15 +120,14 @@ def __stream__(self) -> Iterator[Any]:
                            {'combined_chunk_content': ''.join(content), 'chunk_count': len(content)},
                        )

-             kwargs['stream_cls'] = LogfireInstrumentedStream
+             kwargs['stream_cls'] = LogfireInstrumentedStream  # type: ignore

        with logfire_openai.span(message_template, **span_data) as span:
            with maybe_suppress_instrumentation(suppress_otel):
                if stream:
                    return original_request_method(**kwargs)
                else:
-                     response = original_request_method(**kwargs)
-                     on_response(response, span)
+                     response = on_response(original_request_method(**kwargs), span)
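+                     # `on_response` both records span attributes and returns the
+                     # response, so the original return value flows back to the caller.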
                    return response

    openai_client._request = instrumented_openai_request  # type: ignore
@@ -110,30 +137,32 @@ def instrument_openai_async(logfire_openai: Logfire, openai_client: openai.Async
    # WARNING: this method is very similar to `instrument_openai_sync` above, any changes here should be reflected there
    openai_client._original_request_method = original_request_method = openai_client._request  # type: ignore

-     async def instrumented_openai_request(**kwargs: Any) -> Any:
+     async def instrumented_openai_request(
+         **kwargs: Unpack[OpenAIRequest[ResponseT, _AsyncStreamT]],
+     ) -> ResponseT | _AsyncStreamT:
        if context.get_value('suppress_instrumentation'):
            return await original_request_method(**kwargs)

-         options: FinalRequestOptions | None = kwargs.get('options')
+         options = kwargs['options']
        try:
-             message_template, span_data, on_response, content_from_stream = get_endpoint_config(options)
+             message_template, span_data, content_from_stream = get_endpoint_config(options)
        except ValueError as exc:
            logfire_openai.warn('Unable to instrument OpenAI API call: {error}', error=str(exc), kwargs=kwargs)
            return await original_request_method(**kwargs)

        span_data['async'] = True
-         stream = bool(kwargs.get('stream'))
+         stream = kwargs['stream']

        if stream and content_from_stream:
-             stream_cls: type[AsyncStream] | None = kwargs.get('stream_cls')  # type: ignore[reportMissingTypeArgument]
+             stream_cls = kwargs['stream_cls']
            assert stream_cls is not None, 'Expected `stream_cls` when streaming'

            class LogfireInstrumentedStream(stream_cls):
                async def __stream__(self) -> AsyncIterator[Any]:
                    content: list[str] = []
                    with logfire_openai.span(STEAMING_MSG_TEMPLATE, **span_data) as stream_span:
                        with maybe_suppress_instrumentation(suppress_otel):
-                             async for chunk in super().__stream__():  # type: ignore
+                             async for chunk in super().__stream__():
                                chunk_content = content_from_stream(chunk)
                                if chunk_content is not None:
                                    content.append(chunk_content)
@@ -143,15 +172,14 @@ async def __stream__(self) -> AsyncIterator[Any]:
                            {'combined_chunk_content': ''.join(content), 'chunk_count': len(content)},
                        )

-             kwargs['stream_cls'] = LogfireInstrumentedStream
+             kwargs['stream_cls'] = LogfireInstrumentedStream  # type: ignore

        with logfire_openai.span(message_template, **span_data) as span:
            with maybe_suppress_instrumentation(suppress_otel):
                if stream:
                    return await original_request_method(**kwargs)
                else:
-                     response = await original_request_method(**kwargs)
-                     on_response(response, span)
+                     response = on_response(await original_request_method(**kwargs), span)
                    return response

    openai_client._request = instrumented_openai_request  # type: ignore
@@ -160,13 +188,10 @@ async def __stream__(self) -> AsyncIterator[Any]:
class EndpointConfig(NamedTuple):
    message_template: LiteralString
    span_data: dict[str, Any]
-     on_response: Callable[[Any, LogfireSpan], None]
    content_from_stream: Callable[[Any], str | None] | None

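+ # Selected per request URL below: the span's message template, the request data to
+ # record, and (for streaming endpoints) how to extract text content from each chunk.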
- def get_endpoint_config(options: FinalRequestOptions | None) -> EndpointConfig:
-     if options is None:
-         raise ValueError('`options` is required')
+ def get_endpoint_config(options: FinalRequestOptions) -> EndpointConfig:
    url = options.url
    json_data = options.json_data
    if not isinstance(json_data, dict):
@@ -179,62 +204,63 @@ def get_endpoint_config(options: FinalRequestOptions | None) -> EndpointConfig:
        return EndpointConfig(
            message_template='Chat Completion with {request_data[model]!r}',
            span_data={'request_data': json_data},
-             on_response=on_chat_response,
-             content_from_stream=lambda chunk: chunk.choices[0].delta.content if chunk and chunk.choices else None,
+             content_from_stream=content_from_chat_completions,
        )
    elif url == '/completions':
        return EndpointConfig(
            message_template='Completion with {request_data[model]!r}',
            span_data={'request_data': json_data},
-             on_response=on_completion_response,
-             content_from_stream=lambda chunk: chunk.choices[0].text if chunk and chunk.choices else None,
+             content_from_stream=content_from_completions,
        )
    elif url == '/embeddings':
        return EndpointConfig(
            message_template='Embedding Creation with {request_data[model]!r}',
            span_data={'request_data': json_data},
-             on_response=on_embedding_response,
            content_from_stream=None,
        )
    elif url == '/images/generations':
        return EndpointConfig(
            message_template='Image Generation with {request_data[model]!r}',
            span_data={'request_data': json_data},
-             on_response=on_image_response,
            content_from_stream=None,
        )
    else:
        raise ValueError(f'Unknown OpenAI API endpoint: `{url}`')
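
+ # The streamed chunk types differ per endpoint: the legacy `/completions` endpoint
+ # streams whole `Completion` objects, while `/chat/completions` streams
+ # `ChatCompletionChunk` deltas, hence the two helpers below.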


- def on_chat_response(response: ChatCompletion, span: LogfireSpan) -> None:
-     span.set_attribute(
-         'response_data',
-         {
-             'message': response.choices[0].message,
-             'usage': response.usage,
-         },
-     )
+ def content_from_completions(chunk: Completion | None) -> str | None:
+     if chunk and chunk.choices:
+         return chunk.choices[0].text
+     return None  # pragma: no cover


- def on_completion_response(response: Completion, span: LogfireSpan) -> None:
-     first_choice = response.choices[0]
-     span.set_attribute(
-         'response_data',
-         {
-             'finish_reason': first_choice.finish_reason,
-             'text': first_choice.text,
-             'usage': response.usage,
-         },
-     )
+ def content_from_chat_completions(chunk: ChatCompletionChunk | None) -> str | None:
+     if chunk and chunk.choices:
+         return chunk.choices[0].delta.content
+     return None


- def on_embedding_response(response: CreateEmbeddingResponse, span: LogfireSpan) -> None:
-     span.set_attribute('response_data', {'usage': response.usage})
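+ # Single response handler: record span attributes by dispatching on the concrete
+ # response type, then hand the response back to the caller.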
+ def on_response(response: ResponseT, span: LogfireSpan) -> ResponseT:
+     if isinstance(response, LegacyAPIResponse):  # pragma: no cover
+         on_response(response.parse(), span)  # type: ignore
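+         # `response.parse()` yields the typed model wrapped by the `LegacyAPIResponse`
+         # (e.g. one returned via `.with_raw_response`), so its attributes still get recorded.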
+         return cast('ResponseT', response)

-
- def on_image_response(response: ImagesResponse, span: LogfireSpan) -> None:
-     span.set_attribute('response_data', {'images': response.data})
+     if isinstance(response, ChatCompletion):
+         span.set_attribute(
+             'response_data',
+             {'message': response.choices[0].message, 'usage': response.usage},
+         )
+     elif isinstance(response, Completion):
+         first_choice = response.choices[0]
+         span.set_attribute(
+             'response_data',
+             {'finish_reason': first_choice.finish_reason, 'text': first_choice.text, 'usage': response.usage},
+         )
+     elif isinstance(response, CreateEmbeddingResponse):
+         span.set_attribute('response_data', {'usage': response.usage})
+     elif isinstance(response, ImagesResponse):  # pragma: no branch
+         span.set_attribute('response_data', {'images': response.data})
+     return response

@contextmanager