Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func main() {
Content: "Lorem ipsum",
},
},
Stream: true,
Stream: openai.TruePtr(),
}
stream, err := c.CreateChatCompletionStream(ctx, req)
if err != nil {
Expand Down Expand Up @@ -177,7 +177,7 @@ func main() {
Model: openai.GPT3Babbage002,
MaxTokens: 5,
Prompt: "Lorem ipsum",
Stream: true,
Stream: openai.TruePtr(),
}
stream, err := c.CreateCompletionStream(ctx, req)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion api_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func TestCompletionStream(t *testing.T) {
Prompt: "Ex falso quodlibet",
Model: openai.GPT3Babbage002,
MaxTokens: 5,
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand Down
18 changes: 16 additions & 2 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@ import (
"github.com/sashabaranov/go-openai/jsonschema"
)

// TruePtr returns a pointer to a true boolean value.
// This is useful for setting the Stream field in requests.
func TruePtr() *bool {
	v := new(bool)
	*v = true
	return v
}

// FalsePtr returns a pointer to a false boolean value.
// This is useful for explicitly setting Stream to false in requests.
func FalsePtr() *bool {
	// new(bool) yields a pointer to the zero value, which is false.
	return new(bool)
}

// Chat message role defined by the OpenAI API.
const (
ChatMessageRoleSystem = "system"
Expand Down Expand Up @@ -274,7 +288,7 @@ type ChatCompletionRequest struct {
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Stream bool `json:"stream,omitempty"`
Stream *bool `json:"stream,omitempty"`
Stop []string `json:"stop,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"`
Expand Down Expand Up @@ -467,7 +481,7 @@ func (c *Client) CreateChatCompletion(
ctx context.Context,
request ChatCompletionRequest,
) (response ChatCompletionResponse, err error) {
if request.Stream {
if request.Stream != nil && *request.Stream {
err = ErrChatCompletionStreamNotSupported
return
}
Expand Down
2 changes: 1 addition & 1 deletion chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (c *Client) CreateChatCompletionStream(
return
}

request.Stream = true
request.Stream = TruePtr()
reasoningValidator := NewReasoningValidator()
if err = reasoningValidator.Validate(request); err != nil {
return
Expand Down
28 changes: 14 additions & 14 deletions chat_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand Down Expand Up @@ -167,7 +167,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand Down Expand Up @@ -207,7 +207,7 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand Down Expand Up @@ -250,7 +250,7 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand Down Expand Up @@ -287,7 +287,7 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand Down Expand Up @@ -328,7 +328,7 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
})
var apiErr *openai.APIError
if !errors.As(err, &apiErr) {
Expand Down Expand Up @@ -376,7 +376,7 @@ func TestCreateChatCompletionStreamWithRefusal(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand Down Expand Up @@ -497,7 +497,7 @@ func TestCreateChatCompletionStreamWithLogprobs(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand Down Expand Up @@ -632,7 +632,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
})
if !errors.As(err, &apiErr) {
t.Errorf("Did not return APIError: %+v\n", apiErr)
Expand Down Expand Up @@ -689,7 +689,7 @@ func TestCreateChatCompletionStreamStreamOptions(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
StreamOptions: &openai.StreamOptions{
IncludeUsage: true,
},
Expand Down Expand Up @@ -835,7 +835,7 @@ func TestCreateChatCompletionStreamWithReasoningModel(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand Down Expand Up @@ -946,7 +946,7 @@ func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
})

if stream != nil {
Expand All @@ -971,7 +971,7 @@ func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
})

if stream != nil {
Expand All @@ -996,7 +996,7 @@ func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) {
Content: "Hello!",
},
},
Stream: true,
Stream: openai.TruePtr(),
})

if stream != nil {
Expand Down
41 changes: 40 additions & 1 deletion chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,14 +465,53 @@ func TestChatRequestOmitEmpty(t *testing.T) {
}
}

func TestChatRequestStreamFalseExplicit(t *testing.T) {
	// A Stream field set via FalsePtr() must serialize as an explicit "stream":false.
	req := openai.ChatCompletionRequest{
		Model:  "gpt-4",
		Stream: openai.FalsePtr(),
	}
	data, err := json.Marshal(req)
	checks.NoError(t, err)

	payload := string(data)
	if !strings.Contains(payload, `"stream":false`) {
		t.Errorf("expected stream:false to be present in JSON, but got: %v", payload)
	}
}

func TestChatRequestStreamTrueExplicit(t *testing.T) {
	// A Stream field set via TruePtr() must serialize as an explicit "stream":true.
	req := openai.ChatCompletionRequest{
		Model:  "gpt-4",
		Stream: openai.TruePtr(),
	}
	data, err := json.Marshal(req)
	checks.NoError(t, err)

	payload := string(data)
	if !strings.Contains(payload, `"stream":true`) {
		t.Errorf("expected stream:true to be present in JSON, but got: %v", payload)
	}
}

func TestChatRequestStreamNilOmitted(t *testing.T) {
	// When Stream is left nil, omitempty must drop the key from the JSON entirely.
	req := openai.ChatCompletionRequest{Model: "gpt-4"}
	data, err := json.Marshal(req)
	checks.NoError(t, err)

	payload := string(data)
	if strings.Contains(payload, `"stream"`) {
		t.Errorf("expected stream to be omitted from JSON when nil, but got: %v", payload)
	}
}

func TestChatCompletionsWithStream(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
ctx := context.Background()

streamTrue := true
req := openai.ChatCompletionRequest{
Stream: true,
Stream: &streamTrue,
}
_, err := client.CreateChatCompletion(ctx, req)
checks.ErrorIs(t, err, openai.ErrChatCompletionStreamNotSupported, "unexpected error")
Expand Down
4 changes: 2 additions & 2 deletions completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ type CompletionRequest struct {
PresencePenalty float32 `json:"presence_penalty,omitempty"`
Seed *int `json:"seed,omitempty"`
Stop []string `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"`
Stream *bool `json:"stream,omitempty"`
Suffix string `json:"suffix,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
Expand Down Expand Up @@ -264,7 +264,7 @@ func (c *Client) CreateCompletion(
ctx context.Context,
request CompletionRequest,
) (response CompletionResponse, err error) {
if request.Stream {
if request.Stream != nil && *request.Stream {
err = ErrCompletionStreamNotSupported
return
}
Expand Down
3 changes: 2 additions & 1 deletion completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ func TestCompletionWithStream(t *testing.T) {
client := openai.NewClientWithConfig(config)

ctx := context.Background()
req := openai.CompletionRequest{Stream: true}
streamTrue := true
req := openai.CompletionRequest{Stream: &streamTrue}
_, err := client.CreateCompletion(ctx, req)
if !errors.Is(err, openai.ErrCompletionStreamNotSupported) {
t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported")
Expand Down
4 changes: 2 additions & 2 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func ExampleClient_CreateChatCompletionStream() {
Content: "Lorem ipsum",
},
},
Stream: true,
Stream: openai.TruePtr(),
},
)
if err != nil {
Expand Down Expand Up @@ -102,7 +102,7 @@ func ExampleClient_CreateCompletionStream() {
Model: openai.GPT3Babbage002,
MaxTokens: 5,
Prompt: "Lorem ipsum",
Stream: true,
Stream: openai.TruePtr(),
},
)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (c *Client) CreateCompletionStream(
return
}

request.Stream = true
request.Stream = TruePtr()
req, err := c.newRequest(
ctx,
http.MethodPost,
Expand Down
14 changes: 7 additions & 7 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestCreateCompletionStream(t *testing.T) {
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand Down Expand Up @@ -133,7 +133,7 @@ func TestCreateCompletionStreamError(t *testing.T) {
MaxTokens: 5,
Model: openai.GPT3TextDavinci003,
Prompt: "Hello!",
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand Down Expand Up @@ -171,7 +171,7 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
MaxTokens: 5,
Model: openai.GPT3Babbage002,
Prompt: "Hello!",
Stream: true,
Stream: openai.TruePtr(),
})
if !errors.As(err, &apiErr) {
t.Errorf("TestCreateCompletionStreamRateLimitError did not return APIError")
Expand Down Expand Up @@ -213,7 +213,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand Down Expand Up @@ -248,7 +248,7 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand Down Expand Up @@ -289,7 +289,7 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) {
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
Stream: openai.TruePtr(),
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
Expand All @@ -316,7 +316,7 @@ func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) {
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
Stream: openai.TruePtr(),
})
if err == nil {
t.Fatal("Did not return error")
Expand Down
Loading