Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assistants API #57

Merged
merged 7 commits into from
Dec 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 86 additions & 21 deletions src/OpenAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using JSON3
using HTTP
using Dates


abstract type AbstractOpenAIProvider end
Base.@kwdef struct OpenAIProvider <: AbstractOpenAIProvider
api_key::String = ""
Expand All @@ -17,7 +18,7 @@ Base.@kwdef struct AzureProvider <: AbstractOpenAIProvider
end

"""
DEFAULT_PROVIDER
DEFAULT_PROVIDER

Default provider for OpenAI API requests.
"""
Expand Down Expand Up @@ -53,8 +54,8 @@ end

"""
build_url(provider::AbstractOpenAIProvider, api::AbstractString)

Return the URL for the given provider and API.
Return the URL for the given provider and API.
"""
build_url(provider::AbstractOpenAIProvider) = build_url(provider, provider.api)
function build_url(provider::OpenAIProvider, api::String)
Expand All @@ -75,9 +76,18 @@ function build_params(kwargs)
return buf
end

function request_body(url, method; input, headers, kwargs...)
input = input === nothing ? [] : input
resp = HTTP.request(method, url; body=input, headers=headers, kwargs...)
function request_body(url, method; input, headers, query, kwargs...)
input = isnothing(input) ? [] : input
query = isnothing(query) ? [] : query

resp = HTTP.request(
method,
url;
body=input,
query=query,
headers=headers,
kwargs...
)
return resp, resp.body
end

Expand Down Expand Up @@ -132,7 +142,17 @@ function status_error(resp, log=nothing)
error("request status $(resp.message)$logs")
end

function _request(api::AbstractString, provider::AbstractOpenAIProvider, api_key::AbstractString=provider.api_key; method, http_kwargs, streamcallback=nothing, kwargs...)
function _request(
api::AbstractString,
provider::AbstractOpenAIProvider,
api_key::AbstractString=provider.api_key;
method,
query=nothing,
http_kwargs,
streamcallback=nothing,
additional_headers::AbstractVector=Pair{String,String}[],
kwargs...
)
# add stream: True to the API call if a stream callback function is passed
if !isnothing(streamcallback)
kwargs = (kwargs..., stream=true)
Expand All @@ -141,10 +161,28 @@ function _request(api::AbstractString, provider::AbstractOpenAIProvider, api_key
params = build_params(kwargs)
url = build_url(provider, api)
resp, body = let
# Add whatever other headers we were given
headers = vcat(auth_header(provider, api_key), additional_headers)

if isnothing(streamcallback)
request_body(url, method; input=params, headers=auth_header(provider, api_key), http_kwargs...)
request_body(
url,
method;
input=params,
headers=headers,
query=query,
http_kwargs...
)
else
request_body_live(url; method, input=params, headers=auth_header(provider, api_key), streamcallback=streamcallback, http_kwargs...)
request_body_live(
url;
method,
input=params,
headers=headers,
query=query,
streamcallback=streamcallback,
http_kwargs...
)
end
end
if resp.status >= 400
Expand Down Expand Up @@ -278,7 +316,7 @@ message returned by the API.
julia> CC = create_chat(key, "gpt-3.5-turbo",
[Dict("role" => "user", "content"=> "What continent is New York in? Two word answer.")],
streamcallback = x->println(Dates.now()));
2023-03-27T12:34:50.428
2023-03-27T12:34:50.428
2023-03-27T12:34:50.524
2023-03-27T12:34:50.524
2023-03-27T12:34:50.524
Expand Down Expand Up @@ -336,14 +374,14 @@ Create embeddings
- `api_key::String`: OpenAI API key
- `input`: The input text to generate the embedding(s) for, as String or array of tokens.
To get embeddings for multiple inputs in a single request, pass an array of strings
or array of token arrays. Each input must not exceed 8192 tokens in length.
- `model_id::String`: Model id. Defaults to $DEFAULT_EMBEDDING_MODEL_ID.

# Keyword Arguments:
- `http_kwargs::NamedTuple`: Optional. Keyword arguments to pass to HTTP.request.

For additional details about the endpoint, visit <https://platform.openai.com/docs/api-reference/embeddings>
"""
or array of token arrays. Each input must not exceed 8192 tokens in length.
- `model_id::String`: Model id. Defaults to $DEFAULT_EMBEDDING_MODEL_ID.
# Keyword Arguments:
- `http_kwargs::NamedTuple`: Optional. Keyword arguments to pass to HTTP.request.
For additional details about the endpoint, visit <https://platform.openai.com/docs/api-reference/embeddings>
"""
function create_embeddings(api_key::String, input, model_id::String=DEFAULT_EMBEDDING_MODEL_ID; http_kwargs::NamedTuple=NamedTuple(), kwargs...)
return openai_request("embeddings", api_key; method="POST", http_kwargs=http_kwargs, model=model_id, input, kwargs...)
end
Expand Down Expand Up @@ -371,8 +409,6 @@ function create_images(api_key::String, prompt, n::Integer=1, size::String="256x
return openai_request("images/generations", api_key; method="POST", http_kwargs=http_kwargs, prompt, kwargs...)
end

# api usage status

"""
get_usage_status(provider::OpenAIProvider; numofdays::Int=99)

Expand All @@ -383,7 +419,7 @@ end
# Arguments:
- `provider::OpenAIProvider`: OpenAI provider object.
- `numofdays::Int`: Optional. Defaults to 99. The number of days to get usage status for.
Note that the maximum `numofdays` is 99.
Note that the maximum `numofdays` is 99.

# Returns:
- `quota`: The total quota for the subscription.(unit: USD)
Expand Down Expand Up @@ -450,6 +486,8 @@ function get_usage_status(provider::OpenAIProvider; numofdays::Int=99)
return (; quota, usage, daily_costs)
end

include("assistants.jl")

export OpenAIResponse
export list_models
export retrieve_model
Expand All @@ -460,4 +498,31 @@ export create_embeddings
export create_images
export get_usage_status

# Assistant exports
export list_assistants
export create_assistant
export get_assistant
export delete_assistant
export modify_assistant

# Thread exports
export create_thread
export retrieve_thread
export delete_thread
export modify_thread

# Message exports
export create_message
export list_messages
export retrieve_message
export delete_message
export modify_message

# Run exports
export create_run
export list_runs
export retrieve_run
export delete_run
export modify_run

end # module
Loading