Skip to content

Commit 8a13a15

Browse files
authored
Break up chat_request() into smaller generics (#426)
And use that to simplify methods for OpenAI subclasses.
1 parent ad1db5b commit 8a13a15

File tree

7 files changed

+152
-147
lines changed

7 files changed

+152
-147
lines changed

R/provider-azure.R

Lines changed: 9 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -131,24 +131,20 @@ azure_endpoint <- function() {
131131
}
132132

133133
# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
134-
method(chat_request, ProviderAzureOpenAI) <- function(
135-
provider,
136-
stream = TRUE,
137-
turns = list(),
138-
tools = list(),
139-
type = NULL
140-
) {
141-
req <- request(provider@base_url)
142-
req <- req_url_path_append(req, "/chat/completions")
134+
method(base_request, ProviderAzureOpenAI) <- function(provider) {
135+
req <- base_request(super(provider, ProviderOpenAI))
136+
req <- req_headers(req, Authorization = NULL)
137+
143138
req <- req_url_query(req, `api-version` = provider@api_version)
144139
if (nchar(provider@api_key)) {
145140
req <- req_headers_redacted(req, `api-key` = provider@api_key)
146141
}
147142
req <- ellmer_req_credentials(req, provider@credentials)
148-
req <- req_retry(req, max_tries = 2)
149-
req <- ellmer_req_timeout(req, stream)
150-
req <- ellmer_req_user_agent(req)
151-
req <- req_error(req, body = function(resp) {
143+
req
144+
}
145+
146+
method(base_request_error, ProviderAzureOpenAI) <- function(provider, req) {
147+
req_error(req, body = function(resp) {
152148
error <- resp_body_json(resp)$error
153149
msg <- paste0(error$code, ": ", error$message)
154150
# Try to be helpful in the (common) case that the user or service
@@ -172,37 +168,6 @@ method(chat_request, ProviderAzureOpenAI) <- function(
172168
}
173169
msg
174170
})
175-
176-
messages <- compact(unlist(as_json(provider, turns), recursive = FALSE))
177-
tools <- as_json(provider, unname(tools))
178-
179-
if (!is.null(type)) {
180-
response_format <- list(
181-
type = "json_schema",
182-
json_schema = list(
183-
name = "structured_data",
184-
schema = as_json(provider, type),
185-
strict = TRUE
186-
)
187-
)
188-
} else {
189-
response_format <- NULL
190-
}
191-
192-
params <- chat_params(provider, provider@params)
193-
body <- compact(list2(
194-
messages = messages,
195-
model = provider@model,
196-
stream = stream,
197-
stream_options = if (stream) list(include_usage = TRUE),
198-
tools = tools,
199-
response_format = response_format,
200-
!!!params
201-
))
202-
body <- modify_list(body, provider@extra_args)
203-
req <- req_body_json(req, body)
204-
205-
req
206171
}
207172

208173
default_azure_credentials <- function(api_key = NULL, token = NULL) {

R/provider-databricks.R

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -85,57 +85,52 @@ ProviderDatabricks <- new_class(
8585
properties = list(credentials = class_function)
8686
)
8787

88-
method(chat_request, ProviderDatabricks) <- function(
88+
method(base_request, ProviderDatabricks) <- function(provider) {
89+
req <- request(provider@base_url)
90+
req <- ellmer_req_credentials(req, provider@credentials)
91+
req <- req_retry(req, max_tries = 2)
92+
req <- ellmer_req_timeout(req, stream)
93+
req <- ellmer_req_user_agent(req, databricks_user_agent())
94+
req <- base_request_error(provider, req)
95+
req
96+
}
97+
98+
method(chat_body, ProviderDatabricks) <- function(
8999
provider,
90100
stream = TRUE,
91101
turns = list(),
92102
tools = list(),
93103
type = NULL
94104
) {
95-
req <- request(provider@base_url)
105+
body <- chat_body(
106+
super(provider, ProviderOpenAI),
107+
stream = stream,
108+
turns = turns,
109+
tools = tools,
110+
type = type
111+
)
112+
113+
# Databricks doensn't support stream options
114+
body$stream_options <- NULL
115+
116+
body
117+
}
118+
119+
method(chat_path, ProviderDatabricks) <- function(provider) {
96120
# Note: this API endpoint is undocumented and seems to exist primarily for
97121
# compatibility with the OpenAI Python SDK. The documented endpoint is
98122
# `/serving-endpoints/<model>/invocations`.
99-
req <- req_url_path_append(req, "/serving-endpoints/chat/completions")
100-
req <- ellmer_req_credentials(req, provider@credentials)
101-
req <- ellmer_req_user_agent(req, databricks_user_agent())
102-
req <- req_retry(req, max_tries = 2)
103-
req <- ellmer_req_timeout(req, stream)
104-
req <- req_error(req, body = function(resp) {
123+
"/serving-endpoints/chat/completions"
124+
}
125+
126+
method(base_request_error, ProviderDatabricks) <- function(provider, req) {
127+
req_error(req, body = function(resp) {
105128
if (resp_content_type(resp) == "application/json") {
106129
# Databrick's "OpenAI-compatible" API has a slightly incompatible error
107130
# response format, which we account for here.
108131
resp_body_json(resp)$message
109132
}
110133
})
111-
112-
messages <- compact(unlist(as_json(provider, turns), recursive = FALSE))
113-
tools <- as_json(provider, unname(tools))
114-
115-
if (!is.null(type)) {
116-
response_format <- list(
117-
type = "json_schema",
118-
json_schema = list(
119-
name = "structured_data",
120-
schema = as_json(provider, type),
121-
strict = TRUE
122-
)
123-
)
124-
} else {
125-
response_format <- NULL
126-
}
127-
128-
body <- compact(list(
129-
messages = messages,
130-
model = provider@model,
131-
stream = stream,
132-
tools = tools,
133-
response_format = response_format
134-
))
135-
body <- modify_list(body, provider@extra_args)
136-
req <- req_body_json(req, body)
137-
138-
req
139134
}
140135

141136
method(as_json, list(ProviderDatabricks, Turn)) <- function(provider, x) {

R/provider-openai.R

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -105,29 +105,42 @@ openai_key <- function() {
105105
key_get("OPENAI_API_KEY")
106106
}
107107

108-
# https://platform.openai.com/docs/api-reference/chat/create
109-
method(chat_request, ProviderOpenAI) <- function(
110-
provider,
111-
stream = TRUE,
112-
turns = list(),
113-
tools = list(),
114-
type = NULL
115-
) {
108+
# Base request -----------------------------------------------------------------
109+
110+
method(base_request, ProviderOpenAI) <- function(provider) {
116111
req <- request(provider@base_url)
117-
req <- req_url_path_append(req, "/chat/completions")
118112
req <- req_auth_bearer_token(req, provider@api_key)
119113
req <- req_retry(req, max_tries = 2)
120114
req <- ellmer_req_timeout(req, stream)
121115
req <- ellmer_req_user_agent(req)
116+
req <- base_request_error(provider, req)
117+
req
118+
}
122119

123-
req <- req_error(req, body = function(resp) {
120+
method(base_request_error, ProviderOpenAI) <- function(provider, req) {
121+
req_error(req, body = function(resp) {
124122
if (resp_content_type(resp) == "application/json") {
125123
resp_body_json(resp)$error$message
126124
} else if (resp_content_type(resp) == "text/plain") {
127125
resp_body_string(resp)
128126
}
129127
})
128+
}
129+
130+
# Chat endpoint ----------------------------------------------------------------
130131

132+
method(chat_path, ProviderOpenAI) <- function(provider) {
133+
"/chat/completions"
134+
}
135+
136+
# https://platform.openai.com/docs/api-reference/chat/create
137+
method(chat_body, ProviderOpenAI) <- function(
138+
provider,
139+
stream = TRUE,
140+
turns = list(),
141+
tools = list(),
142+
type = NULL
143+
) {
131144
messages <- compact(unlist(as_json(provider, turns), recursive = FALSE))
132145
tools <- as_json(provider, unname(tools))
133146

@@ -147,7 +160,7 @@ method(chat_request, ProviderOpenAI) <- function(
147160
params <- chat_params(provider, provider@params)
148161
params$seed <- params$seed %||% provider@seed
149162

150-
body <- compact(list2(
163+
compact(list2(
151164
messages = messages,
152165
model = provider@model,
153166
!!!params,
@@ -156,12 +169,9 @@ method(chat_request, ProviderOpenAI) <- function(
156169
tools = tools,
157170
response_format = response_format
158171
))
159-
body <- utils::modifyList(body, provider@extra_args)
160-
req <- req_body_json(req, body)
161-
162-
req
163172
}
164173

174+
165175
method(chat_params, ProviderOpenAI) <- function(provider, params) {
166176
standardise_params(
167177
params,

R/provider-openrouter.R

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,21 +58,8 @@ openrouter_key <- function() {
5858
key_get("OPENROUTER_API_KEY")
5959
}
6060

61-
method(chat_request, ProviderOpenRouter) <- function(
62-
provider,
63-
stream = TRUE,
64-
turns = list(),
65-
tools = list(),
66-
type = NULL
67-
) {
68-
req <- chat_request(
69-
super(provider, ProviderOpenAI),
70-
stream = stream,
71-
turns = turns,
72-
tools = tools,
73-
type = type
74-
)
75-
61+
method(base_request, ProviderOpenRouter) <- function(provider) {
62+
req <- base_request(super(provider, ProviderOpenAI))
7663
# https://openrouter.ai/docs/api-keys
7764
req <- req_headers(
7865
req,

R/provider-snowflake.R

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -78,57 +78,51 @@ ProviderSnowflakeCortex <- new_class(
7878
)
7979
)
8080

81+
method(base_request, ProviderSnowflakeCortex) <- function(provider) {
82+
req <- request(provider@base_url)
83+
req <- ellmer_req_credentials(req, provider@credentials)
84+
req <- req_retry(req, max_tries = 2)
85+
req <- ellmer_req_timeout(req, stream)
86+
# Snowflake uses the User Agent header to identify "parter applications", so
87+
# identify requests as coming from "r_ellmer" (unless an explicit partner
88+
# application is set via the ambient SF_PARTNER environment variable).
89+
req <- ellmer_req_user_agent(req, Sys.getenv("SF_PARTNER"))
90+
91+
# Snowflake-specific error response handling:
92+
req <- req_error(req, body = function(resp) resp_body_json(resp)$message)
93+
94+
req
95+
}
96+
97+
method(chat_path, ProviderSnowflakeCortex) <- function(provider) {
98+
"/api/v2/cortex/inference:complete"
99+
}
100+
81101
# See: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
82-
method(chat_request, ProviderSnowflakeCortex) <- function(
102+
method(chat_body, ProviderSnowflakeCortex) <- function(
83103
provider,
84104
stream = TRUE,
85105
turns = list(),
86106
tools = list(),
87107
type = NULL
88108
) {
109+
call <- quote(chat_snowflake())
89110
if (length(tools) != 0) {
90-
cli::cli_abort(
91-
"Tool calling is not supported.",
92-
call = quote(chat_snowflake())
93-
)
111+
cli::cli_abort("Tool calling is not supported.", call = call)
94112
}
95113
if (!is.null(type) != 0) {
96-
cli::cli_abort(
97-
"Structured data extraction is not supported.",
98-
call = quote(chat_snowflake())
99-
)
114+
cli::cli_abort("Structured data extraction is not supported.", call = call)
100115
}
101116
if (!stream) {
102-
cli::cli_abort(
103-
"Non-streaming responses are not supported.",
104-
call = quote(chat_snowflake())
105-
)
117+
cli::cli_abort("Non-streaming responses are not supported.", call = call)
106118
}
107119

108-
req <- request(provider@base_url)
109-
req <- req_url_path_append(req, "/api/v2/cortex/inference:complete")
110-
req <- ellmer_req_credentials(req, provider@credentials)
111-
req <- req_retry(req, max_tries = 2)
112-
req <- ellmer_req_timeout(req, stream)
113-
# Snowflake uses the User Agent header to identify "parter applications", so
114-
# identify requests as coming from "r_ellmer" (unless an explicit partner
115-
# application is set via the ambient SF_PARTNER environment variable).
116-
req <- ellmer_req_user_agent(req, Sys.getenv("SF_PARTNER"))
117-
118-
# Snowflake-specific error response handling:
119-
req <- req_error(req, body = function(resp) resp_body_json(resp)$message)
120-
121120
messages <- as_json(provider, turns)
122-
123-
body <- list(
121+
list(
124122
messages = messages,
125123
model = provider@model,
126124
stream = stream
127125
)
128-
body <- modify_list(body, provider@extra_args)
129-
req <- req_body_json(req, body)
130-
131-
req
132126
}
133127

134128
# Snowflake -> ellmer --------------------------------------------------------

0 commit comments

Comments
 (0)