Skip to content

Commit 120785d

Browse files
committed
openai: finish streaming tool calls as tool_calls
1 parent aed1419 commit 120785d

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

openai/openai.go

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,9 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
266266
Index: 0,
267267
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
268268
FinishReason: func(reason string) *string {
269+
if len(toolCalls) > 0 {
270+
reason = "tool_calls"
271+
}
269272
if len(reason) > 0 {
270273
return &reason
271274
}
@@ -570,8 +573,9 @@ type BaseWriter struct {
570573
}
571574

572575
type ChatWriter struct {
573-
stream bool
574-
id string
576+
stream bool
577+
finished bool
578+
id string
575579
BaseWriter
576580
}
577581

@@ -620,15 +624,19 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
620624

621625
// chat chunk
622626
if w.stream {
623-
d, err := json.Marshal(toChunk(w.id, chatResponse))
624-
if err != nil {
625-
return 0, err
626-
}
627+
// If we've already finished, don't send any more chunks with choices.
628+
if !w.finished {
629+
chunk := toChunk(w.id, chatResponse)
630+
d, err := json.Marshal(chunk)
631+
if err != nil {
632+
return 0, err
633+
}
627634

628-
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
629-
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
630-
if err != nil {
631-
return 0, err
635+
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
636+
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
637+
if err != nil {
638+
return 0, err
639+
}
632640
}
633641

634642
if chatResponse.Done {

0 commit comments

Comments
 (0)