Skip to content

Commit d02c98b

Browse files
committed
refactor: json check func should not be part of the public api
1 parent 7ae3cd4 commit d02c98b

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

internal/jsonext/json.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package jsonext
2+
3+
import (
4+
"encoding/json"
5+
)
6+
7+
func IsValidJSON[T string | []byte](data T) bool {
8+
if len(data) == 0 { // hot path
9+
return false
10+
}
11+
var m json.RawMessage
12+
err := json.Unmarshal([]byte(data), &m)
13+
return err == nil
14+
}

providers/openai.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"strings"
1212

1313
"github.com/charmbracelet/ai"
14+
"github.com/charmbracelet/ai/internal/jsonext"
1415
"github.com/google/uuid"
1516
"github.com/openai/openai-go/v2"
1617
"github.com/openai/openai-go/v2/option"
@@ -618,7 +619,7 @@ func (o openAiLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.Strea
618619
return
619620
}
620621
toolCalls[toolCallDelta.Index] = existingToolCall
621-
if existingToolCall.arguments != "" && ai.IsParsableJSON(existingToolCall.arguments) {
622+
if jsonext.IsValidJSON(existingToolCall.arguments) {
622623
if !yield(ai.StreamPart{
623624
Type: ai.StreamPartTypeToolInputEnd,
624625
ID: existingToolCall.id,
@@ -679,7 +680,7 @@ func (o openAiLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.Strea
679680
}) {
680681
return
681682
}
682-
if ai.IsParsableJSON(toolCalls[toolCallDelta.Index].arguments) {
683+
if jsonext.IsValidJSON(toolCalls[toolCallDelta.Index].arguments) {
683684
if !yield(ai.StreamPart{
684685
Type: ai.StreamPartTypeToolInputEnd,
685686
ID: toolCallDelta.ID,

util.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
package ai
22

33
import (
4-
"encoding/json"
5-
64
"github.com/go-viper/mapstructure/v2"
75
)
86

@@ -13,9 +11,3 @@ func ParseOptions[T any](options map[string]any, m *T) error {
1311
func FloatOption(f float64) *float64 {
1412
return &f
1513
}
16-
17-
func IsParsableJSON(data string) bool {
18-
var m map[string]any
19-
err := json.Unmarshal([]byte(data), &m)
20-
return err == nil
21-
}

0 commit comments

Comments
 (0)