Add batch api example #808

Open · wants to merge 9 commits into master
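
This PR adds two convenience wrappers around the Batch API, `CreateBatchWithChatCompletions` and `CreateBatchWithEmbeddings`, which build the batch input file from in-memory requests, upload it, and create the batch in a single call. It also documents the flow in the README, covers the new helpers with unit and integration tests, and ships a runnable example under `examples/batch/`.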
README.md: 43 additions & 0 deletions
@@ -741,6 +741,49 @@ func main() {
//
// fmt.Println(resp.Choices[0].Text)
}
```
</details>

<details>
<summary>BatchAPI</summary>

```go
package main

import (
	"context"
	"fmt"

	"github.com/sashabaranov/go-openai"
)

func main() {
client := openai.NewClient("your token")
ctx := context.Background()
var chatCompletions = make([]openai.BatchChatCompletion, 5)
for i := 0; i < 5; i++ {
chatCompletions[i] = openai.BatchChatCompletion{
CustomID: fmt.Sprintf("req-%d", i),
ChatCompletion: openai.ChatCompletionRequest{
Model: openai.GPT4oMini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: fmt.Sprintf("What is the square of %d?", i+1),
},
},
},
}
}

resp, err := client.CreateBatchWithChatCompletions(ctx, openai.CreateBatchWithChatCompletionsRequest{
ChatCompletions: chatCompletions,
})
if err != nil {
fmt.Printf("CreateBatchWithChatCompletions error: %v\n", err)
return
}
fmt.Println("batchID:", resp.ID)
}
```
</details>
See the `examples/` folder for more.
api_integration_test.go: 68 additions & 0 deletions
@@ -6,6 +6,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"testing"
@@ -145,6 +146,73 @@ func TestCompletionStream(t *testing.T) {
}
}

func TestBatchAPI(t *testing.T) {
ctx := context.Background()
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
}
var err error
c := openai.NewClient(apiToken)

req := openai.CreateBatchWithUploadFileRequest{
Endpoint: openai.BatchEndpointChatCompletions,
CompletionWindow: "24h",
}
for i := 0; i < 5; i++ {
req.AddChatCompletion(fmt.Sprintf("req-%d", i), openai.ChatCompletionRequest{
Model: openai.GPT4oMini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: fmt.Sprintf("What is the square of %d?", i+1),
},
},
})
}
_, err = c.CreateBatchWithUploadFile(ctx, req)
checks.NoError(t, err, "CreateBatchWithUploadFile error")

var chatCompletions = make([]openai.BatchChatCompletion, 5)
for i := 0; i < 5; i++ {
chatCompletions[i] = openai.BatchChatCompletion{
CustomID: fmt.Sprintf("req-%d", i),
ChatCompletion: openai.ChatCompletionRequest{
Model: openai.GPT4oMini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: fmt.Sprintf("What is the square of %d?", i+1),
},
},
},
}
}
_, err = c.CreateBatchWithChatCompletions(ctx, openai.CreateBatchWithChatCompletionsRequest{
ChatCompletions: chatCompletions,
})
checks.NoError(t, err, "CreateBatchWithChatCompletions error")

var embeddings = make([]openai.BatchEmbedding, 3)
for i := 0; i < 3; i++ {
embeddings[i] = openai.BatchEmbedding{
CustomID: fmt.Sprintf("req-%d", i),
Embedding: openai.EmbeddingRequest{
Input: "The food was delicious and the waiter...",
Model: openai.AdaEmbeddingV2,
EncodingFormat: openai.EmbeddingEncodingFormatFloat,
},
}
}
_, err = c.CreateBatchWithEmbeddings(ctx, openai.CreateBatchWithEmbeddingsRequest{
Embeddings: embeddings,
})
checks.NoError(t, err, "CreateBatchWithEmbeddings error")

_, err = c.ListBatch(ctx, nil, nil)
checks.NoError(t, err, "ListBatch error")
}

func TestAPIError(t *testing.T) {
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
batch.go: 84 additions & 0 deletions
@@ -209,6 +209,90 @@ func (c *Client) CreateBatchWithUploadFile(
})
}

type CreateBatchWithChatCompletionsRequest struct {
CompletionWindow string
Metadata map[string]any
FileName string
ChatCompletions []BatchChatCompletion
}

type BatchChatCompletion struct {
CustomID string
ChatCompletion ChatCompletionRequest
}

// CreateBatchWithChatCompletions — API call to create a batch of chat completion requests.
func (c *Client) CreateBatchWithChatCompletions(
ctx context.Context,
request CreateBatchWithChatCompletionsRequest,
) (response BatchResponse, err error) {
var file File
var lines = make([]BatchLineItem, len(request.ChatCompletions))
for i, completion := range request.ChatCompletions {
lines[i] = BatchChatCompletionRequest{
CustomID: completion.CustomID,
Body: completion.ChatCompletion,
Method: "POST",
URL: BatchEndpointChatCompletions,
}
}
file, err = c.UploadBatchFile(ctx, UploadBatchFileRequest{
FileName: request.FileName,
Lines: lines,
})
if err != nil {
return
}
return c.CreateBatch(ctx, CreateBatchRequest{
InputFileID: file.ID,
Endpoint: BatchEndpointChatCompletions,
CompletionWindow: request.CompletionWindow,
Metadata: request.Metadata,
})
}

type CreateBatchWithEmbeddingsRequest struct {
CompletionWindow string
Metadata map[string]any
FileName string
Embeddings []BatchEmbedding
}

type BatchEmbedding struct {
CustomID string `json:"custom_id"`
Embedding EmbeddingRequest `json:"body"`
}

// CreateBatchWithEmbeddings — API call to create a batch of embedding requests.
func (c *Client) CreateBatchWithEmbeddings(
ctx context.Context,
request CreateBatchWithEmbeddingsRequest,
) (response BatchResponse, err error) {
var file File
var lines = make([]BatchLineItem, len(request.Embeddings))
for i, embedding := range request.Embeddings {
lines[i] = BatchEmbeddingRequest{
CustomID: embedding.CustomID,
Body: embedding.Embedding,
Method: "POST",
URL: BatchEndpointEmbeddings,
}
}
file, err = c.UploadBatchFile(ctx, UploadBatchFileRequest{
FileName: request.FileName,
Lines: lines,
})
if err != nil {
return
}
return c.CreateBatch(ctx, CreateBatchRequest{
InputFileID: file.ID,
Endpoint: BatchEndpointEmbeddings,
CompletionWindow: request.CompletionWindow,
Metadata: request.Metadata,
})
}

// RetrieveBatch — API call to Retrieve batch.
func (c *Client) RetrieveBatch(
ctx context.Context,
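
For reviewers unfamiliar with the Batch API file format: both helpers serialize each request into one line of the JSONL input file that `UploadBatchFile` writes, then create the batch against that file. Assuming the JSON tags on `BatchChatCompletionRequest` follow the OpenAI Batch API spec (they are not shown in this diff), a line produced by `CreateBatchWithChatCompletions` would look roughly like:

```json
{"custom_id": "req-0", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "What is the square of 1?"}]}}
```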
batch_test.go: 56 additions & 0 deletions
@@ -66,6 +66,62 @@ func TestCreateBatchWithUploadFile(t *testing.T) {
checks.NoError(t, err, "CreateBatchWithUploadFile error")
}

func TestCreateBatchWithChatCompletions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/files", handleCreateFile)
server.RegisterHandler("/v1/batches", handleBatchEndpoint)

var chatCompletions = make([]openai.BatchChatCompletion, 5)
for i := 0; i < 5; i++ {
chatCompletions[i] = openai.BatchChatCompletion{
CustomID: fmt.Sprintf("req-%d", i),
ChatCompletion: openai.ChatCompletionRequest{
Model: openai.GPT4oMini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: fmt.Sprintf("What is the square of %d?", i+1),
},
},
},
}
}
_, err := client.CreateBatchWithChatCompletions(
context.Background(),
openai.CreateBatchWithChatCompletionsRequest{
ChatCompletions: chatCompletions,
},
)
checks.NoError(t, err, "CreateBatchWithChatCompletions error")
}

func TestCreateBatchWithEmbeddings(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/files", handleCreateFile)
server.RegisterHandler("/v1/batches", handleBatchEndpoint)

var embeddings = make([]openai.BatchEmbedding, 3)
for i := 0; i < 3; i++ {
embeddings[i] = openai.BatchEmbedding{
CustomID: fmt.Sprintf("req-%d", i),
Embedding: openai.EmbeddingRequest{
Input: "The food was delicious and the waiter...",
Model: openai.AdaEmbeddingV2,
EncodingFormat: openai.EmbeddingEncodingFormatFloat,
},
}
}
_, err := client.CreateBatchWithEmbeddings(
context.Background(),
openai.CreateBatchWithEmbeddingsRequest{
Embeddings: embeddings,
},
)
checks.NoError(t, err, "CreateBatchWithEmbeddings error")
}

func TestRetrieveBatch(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
client_test.go: 6 additions & 0 deletions
@@ -402,6 +402,12 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
{"CreateBatchWithUploadFile", func() (any, error) {
return client.CreateBatchWithUploadFile(ctx, CreateBatchWithUploadFileRequest{})
}},
{"CreateBatchWithChatCompletions", func() (any, error) {
return client.CreateBatchWithChatCompletions(ctx, CreateBatchWithChatCompletionsRequest{})
}},
{"CreateBatchWithEmbeddings", func() (any, error) {
return client.CreateBatchWithEmbeddings(ctx, CreateBatchWithEmbeddingsRequest{})
}},
{"RetrieveBatch", func() (any, error) {
return client.RetrieveBatch(ctx, "")
}},
examples/batch/main.go: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
package main

import (
	"context"
	"fmt"
	"io"
	"log"
	"os"

	"github.com/sashabaranov/go-openai"
)

func main() {
client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
ctx := context.Background()

// create batch
response, err := createBatchChatCompletion(ctx, client)
if err != nil {
log.Fatal(err)
}
fmt.Printf("batchID: %s\n", response.ID)

	// To retrieve an existing batch instead, uncomment and set a real batch ID:
	// batchID := "batch_XXXXXXXXXXXXX"
	// retrieveBatch(ctx, client, batchID)
}

func createBatchChatCompletion(ctx context.Context, client *openai.Client) (openai.BatchResponse, error) {
var chatCompletions = make([]openai.BatchChatCompletion, 5)
for i := 0; i < 5; i++ {
chatCompletions[i] = openai.BatchChatCompletion{
CustomID: fmt.Sprintf("req-%d", i),
ChatCompletion: openai.ChatCompletionRequest{
Model: openai.GPT4oMini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: fmt.Sprintf("What is the square of %d?", i+1),
},
},
},
}
}

return client.CreateBatchWithChatCompletions(ctx, openai.CreateBatchWithChatCompletionsRequest{
ChatCompletions: chatCompletions,
})
}

func retrieveBatch(ctx context.Context, client *openai.Client, batchID string) {
	batch, err := client.RetrieveBatch(ctx, batchID)
	if err != nil {
		log.Printf("RetrieveBatch error: %v", err)
		return
	}
	fmt.Printf("batchStatus: %s\n", batch.Status)

	// The output and error file IDs are only populated once the batch has finished.
	files := map[string]*string{
		"inputFile":  &batch.InputFileID,
		"outputFile": batch.OutputFileID,
		"errorFile":  batch.ErrorFileID,
	}
	for name, fileID := range files {
		if fileID != nil {
			content, err := client.GetFileContent(ctx, *fileID)
			if err != nil {
				log.Printf("GetFileContent error: %v", err)
				return
			}
			all, err := io.ReadAll(content)
			content.Close()
			if err != nil {
				log.Printf("reading %s: %v", name, err)
				return
			}
			fmt.Printf("%s: %s\n", name, all)
		}
	}
}
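
One thing the example leaves implicit: batches are asynchronous, so callers normally poll `RetrieveBatch` until the batch reaches a terminal state before reading the output file. A minimal sketch of that loop, assuming the status strings from the OpenAI Batch API docs (plain strings here, not constants from this diff) and a `time` import:

```go
// waitForBatch polls RetrieveBatch until the batch reaches a terminal
// status or ctx is cancelled.
func waitForBatch(ctx context.Context, client *openai.Client, batchID string) (openai.BatchResponse, error) {
	for {
		batch, err := client.RetrieveBatch(ctx, batchID)
		if err != nil {
			return openai.BatchResponse{}, err
		}
		// Terminal statuses per the OpenAI Batch API documentation (assumed).
		switch batch.Status {
		case "completed", "failed", "expired", "cancelled":
			return batch, nil
		}
		select {
		case <-ctx.Done():
			return openai.BatchResponse{}, ctx.Err()
		case <-time.After(30 * time.Second): // batches can run for minutes to hours
		}
	}
}
```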