Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"encoding/json"
"errors"
"net/http"

"github.com/sashabaranov/go-openai/jsonschema"
)

// Chat message role defined by the OpenAI API.
Expand Down Expand Up @@ -238,12 +236,18 @@ func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) erro
r.Description = raw.Description
r.Strict = raw.Strict
if len(raw.Schema) > 0 && string(raw.Schema) != "null" {
var d jsonschema.Definition
err := json.Unmarshal(raw.Schema, &d)
if err != nil {
return err
// Validate that the schema is a JSON object (must start with '{')
// JSON Schema definitions must be objects
trimmed := raw.Schema
for len(trimmed) > 0 && (trimmed[0] == ' ' || trimmed[0] == '\t' || trimmed[0] == '\n' || trimmed[0] == '\r') {
trimmed = trimmed[1:]
}
if len(trimmed) == 0 || trimmed[0] != '{' {
return errors.New("schema must be a JSON object")
}
r.Schema = &d
// Use json.RawMessage directly to preserve all JSON Schema features
// (anyOf, oneOf, allOf, const, title, etc.) that jsonschema.Definition doesn't support
r.Schema = raw.Schema
}
return nil
}
Expand Down
245 changes: 245 additions & 0 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1205,3 +1205,248 @@
})
}
}

func TestChatCompletionResponseFormatJSONSchema_PreservesAllSchemaFeatures(t *testing.T) {
tests := []struct {
name string
inputSchema string
expectedInJSON string
}{
{
name: "preserves anyOf",
inputSchema: `{
"name": "test",
"strict": true,
"schema": {
"type": "object",
"properties": {
"value": {
"anyOf": [
{"type": "string"},
{"type": "null"}
]
}
}
}
}`,
expectedInJSON: "anyOf",
},
{
name: "preserves oneOf",
inputSchema: `{
"name": "test",
"strict": true,
"schema": {
"oneOf": [
{"type": "string"},
{"type": "number"}
]
}
}`,
expectedInJSON: "oneOf",
},
{
name: "preserves allOf",
inputSchema: `{
"name": "test",
"strict": true,
"schema": {
"allOf": [
{"type": "object"},
{"properties": {"name": {"type": "string"}}}
]
}
}`,
expectedInJSON: "allOf",
},
{
name: "preserves const",
inputSchema: `{
"name": "test",
"strict": true,
"schema": {
"type": "object",
"properties": {
"status": {"const": "active"}
}
}
}`,
expectedInJSON: "const",
},
{
name: "preserves title",
inputSchema: `{
"name": "test",
"strict": true,
"schema": {
"type": "object",
"title": "MySchema",
"properties": {
"name": {"type": "string"}
}
}
}`,
expectedInJSON: "title",
},
{
name: "preserves $defs and $ref",
inputSchema: `{
"name": "test",
"strict": true,
"schema": {
"type": "object",
"$defs": {
"Item": {
"type": "object",
"properties": {
"name": {"type": "string"}
}
}
},
"properties": {
"items": {
"type": "array",
"items": {"$ref": "#/$defs/Item"}
}
}
}
}`,
expectedInJSON: "$defs",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var schema openai.ChatCompletionResponseFormatJSONSchema
err := json.Unmarshal([]byte(tt.inputSchema), &schema)
if err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}

marshaled, err := json.Marshal(schema)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}

if !strings.Contains(string(marshaled), tt.expectedInJSON) {
t.Errorf("Expected %q to be preserved in marshaled output, got: %s", tt.expectedInJSON, string(marshaled))
}
})
}
}

func TestChatCompletionResponseFormatJSONSchema_UnmarshalJSON_Validation(t *testing.T) {
tests := []struct {
name string
input string
wantErr bool
errContains string
}{
{
name: "valid schema with leading whitespace",
input: `{
"name": "test",
"strict": true,
"schema": {
"type": "object"
}
}`,
wantErr: false,
},
{
name: "invalid schema - array instead of object",
input: `{
"name": "test",
"strict": true,
"schema": ["not", "an", "object"]
}`,
wantErr: true,
errContains: "schema must be a JSON object",
},
{
name: "invalid schema - string instead of object",
input: `{
"name": "test",
"strict": true,
"schema": "not an object"
}`,
wantErr: true,
errContains: "schema must be a JSON object",
},
{
name: "invalid schema - number instead of object",
input: `{
"name": "test",
"strict": true,
"schema": 123
}`,
wantErr: true,
errContains: "schema must be a JSON object",
},
{
name: "invalid schema - boolean instead of object",
input: `{
"name": "test",
"strict": true,
"schema": true
}`,
wantErr: true,
errContains: "schema must be a JSON object",
},
{
name: "null schema is allowed",
input: `{
"name": "test",
"strict": true,
"schema": null
}`,
wantErr: false,
},
{
name: "no schema field is allowed",
input: `{
"name": "test",
"strict": true
}`,
wantErr: false,
},
{
name: "schema with newline before object",
input: `{
"name": "test",
"strict": true,
"schema":
{
"type": "object"
}
}`,
wantErr: false,
},
{
name: "schema with tabs before object",

Check failure on line 1426 in chat_test.go

View workflow job for this annotation

GitHub Actions / Sanity check

File is not properly formatted (goimports)
input: "{\"name\": \"test\", \"strict\": true, \"schema\": \t\t{\"type\": \"object\"}}",
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var schema openai.ChatCompletionResponseFormatJSONSchema
err := json.Unmarshal([]byte(tt.input), &schema)

if tt.wantErr {
if err == nil {
t.Errorf("Expected error but got nil")
return
}
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Expected error containing %q, got: %v", tt.errContains, err)
}
} else {

Check failure on line 1445 in chat_test.go

View workflow job for this annotation

GitHub Actions / Sanity check

elseif: can replace 'else {if cond {}}' with 'else if cond {}' (gocritic)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
}
})
}
}
Loading