Skip to content

Commit 38b3b15

Browse files
committed
openapi: finish streaming tool calls as tool_calls
1 parent aed1419 commit 38b3b15

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

openai/openai.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,9 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
266266
Index: 0,
267267
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
268268
FinishReason: func(reason string) *string {
269+
if len(toolCalls) > 0 {
270+
reason = "tool_calls"
271+
}
269272
if len(reason) > 0 {
270273
return &reason
271274
}
@@ -570,8 +573,9 @@ type BaseWriter struct {
570573
}
571574

572575
type ChatWriter struct {
573-
stream bool
574-
id string
576+
stream bool
577+
finished bool
578+
id string
575579
BaseWriter
576580
}
577581

@@ -620,7 +624,15 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
620624

621625
// chat chunk
622626
if w.stream {
623-
d, err := json.Marshal(toChunk(w.id, chatResponse))
627+
chunk := toChunk(w.id, chatResponse)
628+
if w.finished {
629+
// we've already finished the chat, usually for tool calls, so do
630+
// not continue to send empty choices.
631+
chunk.Choices = nil
632+
} else if len(chunk.Choices) > 0 && chunk.Choices[0].FinishReason != nil {
633+
w.finished = true
634+
}
635+
d, err := json.Marshal(chunk)
624636
if err != nil {
625637
return 0, err
626638
}

0 commit comments

Comments
 (0)