Skip to content

Commit eb55b4e

Browse files
authored
Assistants API (#57)
* add beginning of assistant stuff * remove api_key grab * add message api * add some tests, additional endpoints * complete test suite * uncomment remaining tests * add create_thread_and_run
1 parent 71a69e0 commit eb55b4e

File tree

4 files changed

+1244
-30
lines changed

4 files changed

+1244
-30
lines changed

src/OpenAI.jl

+86-21
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using JSON3
44
using HTTP
55
using Dates
66

7+
78
abstract type AbstractOpenAIProvider end
89
Base.@kwdef struct OpenAIProvider <: AbstractOpenAIProvider
910
api_key::String = ""
@@ -17,7 +18,7 @@ Base.@kwdef struct AzureProvider <: AbstractOpenAIProvider
1718
end
1819

1920
"""
20-
DEFAULT_PROVIDER
21+
DEFAULT_PROVIDER
2122
2223
Default provider for OpenAI API requests.
2324
"""
@@ -53,8 +54,8 @@ end
5354

5455
"""
5556
build_url(provider::AbstractOpenAIProvider, api::AbstractString)
56-
57-
Return the URL for the given provider and API.
57+
58+
Return the URL for the given provider and API.
5859
"""
5960
build_url(provider::AbstractOpenAIProvider) = build_url(provider, provider.api)
6061
function build_url(provider::OpenAIProvider, api::String)
@@ -75,9 +76,18 @@ function build_params(kwargs)
7576
return buf
7677
end
7778

78-
function request_body(url, method; input, headers, kwargs...)
79-
input = input === nothing ? [] : input
80-
resp = HTTP.request(method, url; body=input, headers=headers, kwargs...)
79+
function request_body(url, method; input, headers, query, kwargs...)
80+
input = isnothing(input) ? [] : input
81+
query = isnothing(query) ? [] : query
82+
83+
resp = HTTP.request(
84+
method,
85+
url;
86+
body=input,
87+
query=query,
88+
headers=headers,
89+
kwargs...
90+
)
8191
return resp, resp.body
8292
end
8393

@@ -132,7 +142,17 @@ function status_error(resp, log=nothing)
132142
error("request status $(resp.message)$logs")
133143
end
134144

135-
function _request(api::AbstractString, provider::AbstractOpenAIProvider, api_key::AbstractString=provider.api_key; method, http_kwargs, streamcallback=nothing, kwargs...)
145+
function _request(
146+
api::AbstractString,
147+
provider::AbstractOpenAIProvider,
148+
api_key::AbstractString=provider.api_key;
149+
method,
150+
query=nothing,
151+
http_kwargs,
152+
streamcallback=nothing,
153+
additional_headers::AbstractVector=Pair{String,String}[],
154+
kwargs...
155+
)
136156
# add stream: True to the API call if a stream callback function is passed
137157
if !isnothing(streamcallback)
138158
kwargs = (kwargs..., stream=true)
@@ -141,10 +161,28 @@ function _request(api::AbstractString, provider::AbstractOpenAIProvider, api_key
141161
params = build_params(kwargs)
142162
url = build_url(provider, api)
143163
resp, body = let
164+
# Add whatever other headers we were given
165+
headers = vcat(auth_header(provider, api_key), additional_headers)
166+
144167
if isnothing(streamcallback)
145-
request_body(url, method; input=params, headers=auth_header(provider, api_key), http_kwargs...)
168+
request_body(
169+
url,
170+
method;
171+
input=params,
172+
headers=headers,
173+
query=query,
174+
http_kwargs...
175+
)
146176
else
147-
request_body_live(url; method, input=params, headers=auth_header(provider, api_key), streamcallback=streamcallback, http_kwargs...)
177+
request_body_live(
178+
url;
179+
method,
180+
input=params,
181+
headers=headers,
182+
query=query,
183+
streamcallback=streamcallback,
184+
http_kwargs...
185+
)
148186
end
149187
end
150188
if resp.status >= 400
@@ -278,7 +316,7 @@ message returned by the API.
278316
julia> CC = create_chat(key, "gpt-3.5-turbo",
279317
[Dict("role" => "user", "content"=> "What continent is New York in? Two word answer.")],
280318
streamcallback = x->println(Dates.now()));
281-
2023-03-27T12:34:50.428
319+
2023-03-27T12:34:50.428
282320
2023-03-27T12:34:50.524
283321
2023-03-27T12:34:50.524
284322
2023-03-27T12:34:50.524
@@ -336,14 +374,14 @@ Create embeddings
336374
- `api_key::String`: OpenAI API key
337375
- `input`: The input text to generate the embedding(s) for, as String or array of tokens.
338376
To get embeddings for multiple inputs in a single request, pass an array of strings
339-
or array of token arrays. Each input must not exceed 8192 tokens in length.
340-
- `model_id::String`: Model id. Defaults to $DEFAULT_EMBEDDING_MODEL_ID.
341-
342-
# Keyword Arguments:
343-
- `http_kwargs::NamedTuple`: Optional. Keyword arguments to pass to HTTP.request.
344-
345-
For additional details about the endpoint, visit <https://platform.openai.com/docs/api-reference/embeddings>
346-
"""
377+
or array of token arrays. Each input must not exceed 8192 tokens in length.
378+
- `model_id::String`: Model id. Defaults to $DEFAULT_EMBEDDING_MODEL_ID.
379+
380+
# Keyword Arguments:
381+
- `http_kwargs::NamedTuple`: Optional. Keyword arguments to pass to HTTP.request.
382+
383+
For additional details about the endpoint, visit <https://platform.openai.com/docs/api-reference/embeddings>
384+
"""
347385
function create_embeddings(api_key::String, input, model_id::String=DEFAULT_EMBEDDING_MODEL_ID; http_kwargs::NamedTuple=NamedTuple(), kwargs...)
348386
return openai_request("embeddings", api_key; method="POST", http_kwargs=http_kwargs, model=model_id, input, kwargs...)
349387
end
@@ -371,8 +409,6 @@ function create_images(api_key::String, prompt, n::Integer=1, size::String="256x
371409
return openai_request("images/generations", api_key; method="POST", http_kwargs=http_kwargs, prompt, kwargs...)
372410
end
373411

374-
# api usage status
375-
376412
"""
377413
get_usage_status(provider::OpenAIProvider; numofdays::Int=99)
378414
@@ -383,7 +419,7 @@ end
383419
# Arguments:
384420
- `provider::OpenAIProvider`: OpenAI provider object.
385421
- `numofdays::Int`: Optional. Defaults to 99. The number of days to get usage status for.
386-
Note that the maximum `numofdays` is 99.
422+
Note that the maximum `numofdays` is 99.
387423
388424
# Returns:
389425
- `quota`: The total quota for the subscription.(unit: USD)
@@ -450,6 +486,8 @@ function get_usage_status(provider::OpenAIProvider; numofdays::Int=99)
450486
return (; quota, usage, daily_costs)
451487
end
452488

489+
include("assistants.jl")
490+
453491
export OpenAIResponse
454492
export list_models
455493
export retrieve_model
@@ -460,4 +498,31 @@ export create_embeddings
460498
export create_images
461499
export get_usage_status
462500

501+
# Assistant exports
502+
export list_assistants
503+
export create_assistant
504+
export get_assistant
505+
export delete_assistant
506+
export modify_assistant
507+
508+
# Thread exports
509+
export create_thread
510+
export retrieve_thread
511+
export delete_thread
512+
export modify_thread
513+
514+
# Message exports
515+
export create_message
516+
export list_messages
517+
export retrieve_message
518+
export delete_message
519+
export modify_message
520+
521+
# Run exports
522+
export create_run
523+
export list_runs
524+
export retrieve_run
525+
export delete_run
526+
export modify_run
527+
463528
end # module

0 commit comments

Comments
 (0)