Skip to content

Commit 34c2960

Browse files
authored
fix: marshal/unmarshal ProviderError when Cause is openai.APIError (#78)
1 parent ade493f commit 34c2960

File tree

3 files changed

+151
-66
lines changed

3 files changed

+151
-66
lines changed

content_json.go

Lines changed: 15 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -393,11 +393,8 @@ func (t *ToolResultContent) UnmarshalJSON(data []byte) error {
393393

394394
// MarshalJSON implements json.Marshaler for ToolResultOutputContentText.
395395
func (t ToolResultOutputContentText) MarshalJSON() ([]byte, error) {
396-
dataBytes, err := json.Marshal(struct {
397-
Text string `json:"text"`
398-
}{
399-
Text: t.Text,
400-
})
396+
type alias ToolResultOutputContentText
397+
dataBytes, err := json.Marshal(alias(t))
401398
if err != nil {
402399
return nil, err
403400
}
@@ -415,15 +412,14 @@ func (t *ToolResultOutputContentText) UnmarshalJSON(data []byte) error {
415412
return err
416413
}
417414

418-
var temp struct {
419-
Text string `json:"text"`
420-
}
415+
type alias ToolResultOutputContentText
416+
var temp alias
421417

422418
if err := json.Unmarshal(tr.Data, &temp); err != nil {
423419
return err
424420
}
425421

426-
t.Text = temp.Text
422+
*t = ToolResultOutputContentText(temp)
427423
return nil
428424
}
429425

@@ -470,13 +466,8 @@ func (t *ToolResultOutputContentError) UnmarshalJSON(data []byte) error {
470466

471467
// MarshalJSON implements json.Marshaler for ToolResultOutputContentMedia.
472468
func (t ToolResultOutputContentMedia) MarshalJSON() ([]byte, error) {
473-
dataBytes, err := json.Marshal(struct {
474-
Data string `json:"data"`
475-
MediaType string `json:"media_type"`
476-
}{
477-
Data: t.Data,
478-
MediaType: t.MediaType,
479-
})
469+
type alias ToolResultOutputContentMedia
470+
dataBytes, err := json.Marshal(alias(t))
480471
if err != nil {
481472
return nil, err
482473
}
@@ -494,17 +485,14 @@ func (t *ToolResultOutputContentMedia) UnmarshalJSON(data []byte) error {
494485
return err
495486
}
496487

497-
var temp struct {
498-
Data string `json:"data"`
499-
MediaType string `json:"media_type"`
500-
}
488+
type alias ToolResultOutputContentMedia
489+
var temp alias
501490

502491
if err := json.Unmarshal(tr.Data, &temp); err != nil {
503492
return err
504493
}
505494

506-
t.Data = temp.Data
507-
t.MediaType = temp.MediaType
495+
*t = ToolResultOutputContentMedia(temp)
508496
return nil
509497
}
510498

@@ -870,15 +858,8 @@ func (f *FunctionTool) UnmarshalJSON(data []byte) error {
870858

871859
// MarshalJSON implements json.Marshaler for ProviderDefinedTool.
872860
func (p ProviderDefinedTool) MarshalJSON() ([]byte, error) {
873-
dataBytes, err := json.Marshal(struct {
874-
ID string `json:"id"`
875-
Name string `json:"name"`
876-
Args map[string]any `json:"args"`
877-
}{
878-
ID: p.ID,
879-
Name: p.Name,
880-
Args: p.Args,
881-
})
861+
type alias ProviderDefinedTool
862+
dataBytes, err := json.Marshal(alias(p))
882863
if err != nil {
883864
return nil, err
884865
}
@@ -896,19 +877,14 @@ func (p *ProviderDefinedTool) UnmarshalJSON(data []byte) error {
896877
return err
897878
}
898879

899-
var aux struct {
900-
ID string `json:"id"`
901-
Name string `json:"name"`
902-
Args map[string]any `json:"args"`
903-
}
880+
type alias ProviderDefinedTool
881+
var aux alias
904882

905883
if err := json.Unmarshal(tj.Data, &aux); err != nil {
906884
return err
907885
}
908886

909-
p.ID = aux.ID
910-
p.Name = aux.Name
911-
p.Args = aux.Args
887+
*p = ProviderDefinedTool(aux)
912888

913889
return nil
914890
}

json_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,3 +645,108 @@ func TestPromptSerialization(t *testing.T) {
645645
}
646646
})
647647
}
648+
649+
func TestStreamPartErrorSerialization(t *testing.T) {
650+
t.Run("stream part with ProviderError containing OpenAI API error", func(t *testing.T) {
651+
// Create a mock OpenAI API error
652+
openaiErr := errors.New("invalid_api_key: Incorrect API key provided")
653+
654+
// Wrap in ProviderError
655+
providerErr := &ProviderError{
656+
Title: "unauthorized",
657+
Message: "Incorrect API key provided",
658+
Cause: openaiErr,
659+
URL: "https://api.openai.com/v1/chat/completions",
660+
StatusCode: 401,
661+
RequestBody: []byte(`{"model":"gpt-4","messages":[]}`),
662+
ResponseHeaders: map[string]string{
663+
"content-type": "application/json",
664+
},
665+
ResponseBody: []byte(`{"error":{"message":"Incorrect API key provided","type":"invalid_request_error"}}`),
666+
}
667+
668+
// Create StreamPart with error
669+
streamPart := StreamPart{
670+
Type: StreamPartTypeError,
671+
Error: providerErr,
672+
}
673+
674+
// Marshal the stream part
675+
data, err := json.Marshal(streamPart)
676+
if err != nil {
677+
t.Fatalf("failed to marshal stream part: %v", err)
678+
}
679+
680+
// Unmarshal back
681+
var decoded StreamPart
682+
err = json.Unmarshal(data, &decoded)
683+
if err != nil {
684+
t.Fatalf("failed to unmarshal stream part: %v", err)
685+
}
686+
687+
// Verify the stream part type
688+
if decoded.Type != StreamPartTypeError {
689+
t.Errorf("type mismatch: got %v, want %v", decoded.Type, StreamPartTypeError)
690+
}
691+
692+
// Verify error exists
693+
if decoded.Error == nil {
694+
t.Fatal("expected error to be present, got nil")
695+
}
696+
697+
// Verify error message
698+
expectedMsg := "unauthorized: Incorrect API key provided"
699+
if decoded.Error.Error() != expectedMsg {
700+
t.Errorf("error message mismatch: got %q, want %q", decoded.Error.Error(), expectedMsg)
701+
}
702+
})
703+
704+
t.Run("unmarshal stream part with error from JSON", func(t *testing.T) {
705+
// JSON representing a StreamPart with an error
706+
jsonData := `{
707+
"type": "error",
708+
"error": "unauthorized: Incorrect API key provided",
709+
"id": "",
710+
"tool_call_name": "",
711+
"tool_call_input": "",
712+
"delta": "",
713+
"provider_executed": false,
714+
"usage": {
715+
"input_tokens": 0,
716+
"output_tokens": 0,
717+
"total_tokens": 0,
718+
"reasoning_tokens": 0,
719+
"cache_creation_tokens": 0,
720+
"cache_read_tokens": 0
721+
},
722+
"finish_reason": "",
723+
"warnings": null,
724+
"source_type": "",
725+
"url": "",
726+
"title": "",
727+
"provider_metadata": null
728+
}`
729+
730+
var streamPart StreamPart
731+
err := json.Unmarshal([]byte(jsonData), &streamPart)
732+
if err != nil {
733+
t.Fatalf("failed to unmarshal stream part: %v", err)
734+
}
735+
736+
// Verify the stream part type
737+
if streamPart.Type != StreamPartTypeError {
738+
t.Errorf("type mismatch: got %v, want %v", streamPart.Type, StreamPartTypeError)
739+
}
740+
741+
// Verify error exists
742+
if streamPart.Error == nil {
743+
t.Fatal("expected error to be present, got nil")
744+
}
745+
746+
// Verify error message
747+
expectedMsg := "unauthorized: Incorrect API key provided"
748+
if streamPart.Error.Error() != expectedMsg {
749+
t.Errorf("error message mismatch: got %q, want %q", streamPart.Error.Error(), expectedMsg)
750+
}
751+
})
752+
}

model_json.go

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -102,42 +102,46 @@ func (r *Response) UnmarshalJSON(data []byte) error {
102102
return nil
103103
}
104104

105+
// MarshalJSON implements json.Marshaler for StreamPart.
106+
func (s StreamPart) MarshalJSON() ([]byte, error) {
107+
type alias StreamPart
108+
aux := struct {
109+
alias
110+
Error string `json:"error,omitempty"`
111+
}{
112+
alias: (alias)(s),
113+
}
114+
115+
// Marshal error to string
116+
if s.Error != nil {
117+
aux.Error = s.Error.Error()
118+
}
119+
120+
// Clear the original Error field to avoid duplicate marshaling
121+
aux.alias.Error = nil
122+
123+
return json.Marshal(aux)
124+
}
125+
105126
// UnmarshalJSON implements json.Unmarshaler for StreamPart.
106127
func (s *StreamPart) UnmarshalJSON(data []byte) error {
107-
var aux struct {
108-
Type StreamPartType `json:"type"`
109-
ID string `json:"id"`
110-
ToolCallName string `json:"tool_call_name"`
111-
ToolCallInput string `json:"tool_call_input"`
112-
Delta string `json:"delta"`
113-
ProviderExecuted bool `json:"provider_executed"`
114-
Usage Usage `json:"usage"`
115-
FinishReason FinishReason `json:"finish_reason"`
116-
Error error `json:"error"`
117-
Warnings []CallWarning `json:"warnings"`
118-
SourceType SourceType `json:"source_type"`
119-
URL string `json:"url"`
120-
Title string `json:"title"`
128+
type alias StreamPart
129+
aux := struct {
130+
*alias
131+
Error string `json:"error"`
121132
ProviderMetadata map[string]json.RawMessage `json:"provider_metadata"`
133+
}{
134+
alias: (*alias)(s),
122135
}
123136

124137
if err := json.Unmarshal(data, &aux); err != nil {
125138
return err
126139
}
127140

128-
s.Type = aux.Type
129-
s.ID = aux.ID
130-
s.ToolCallName = aux.ToolCallName
131-
s.ToolCallInput = aux.ToolCallInput
132-
s.Delta = aux.Delta
133-
s.ProviderExecuted = aux.ProviderExecuted
134-
s.Usage = aux.Usage
135-
s.FinishReason = aux.FinishReason
136-
s.Error = aux.Error
137-
s.Warnings = aux.Warnings
138-
s.SourceType = aux.SourceType
139-
s.URL = aux.URL
140-
s.Title = aux.Title
141+
// Unmarshal error string back to error type
142+
if aux.Error != "" {
143+
s.Error = fmt.Errorf("%s", aux.Error)
144+
}
141145

142146
// Unmarshal ProviderMetadata
143147
if len(aux.ProviderMetadata) > 0 {

0 commit comments

Comments
 (0)