Commit c6b8f4f

doslindos and tmc authored

ollama: Fix JSON format bug issue when not streaming (#892)

* Graceful handling when the LLM emits whitespace in JSON mode with Ollama.
* ollama: Simplify the stream representation, spruce up the function-calling example.

Co-authored-by: Travis Cline <[email protected]>

1 parent e380482 commit c6b8f4f
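
For context on the first bullet: in JSON mode the model can emit whitespace-only content, which json.Unmarshal rejects, so a guard like the example's unmarshalCall (diffed below) treats it as "not a call" and lets the loop retry. A minimal standalone sketch of that guard; the Call type and its json tag names here are assumptions inferred from the usage of c.Tool and c.Input in the diff, not the repository's exact definition:

package main

import (
	"encoding/json"
	"fmt"
)

// Call mirrors how the example uses c.Tool and c.Input; the tag names
// are assumed for illustration.
type Call struct {
	Tool  string         `json:"tool"`
	Input map[string]any `json:"tool_input"`
}

// unmarshalCall returns nil for anything that is not a well-formed call,
// including whitespace-only output, so the caller can simply retry.
func unmarshalCall(input string) *Call {
	var c Call
	if err := json.Unmarshal([]byte(input), &c); err == nil && c.Tool != "" {
		return &c
	}
	return nil
}

func main() {
	fmt.Println(unmarshalCall(`{"tool":"getCurrentWeather","tool_input":{"location":"Beijing"}}`) != nil) // true
	fmt.Println(unmarshalCall(" \n\t ") != nil) // false: json.Unmarshal fails on whitespace-only input
}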

File tree

3 files changed, +20 -19 lines changed:

examples/ollama-functions-example/ollama_functions_example.go
llms/ollama/internal/ollamaclient/types.go
llms/ollama/ollamallm.go


examples/ollama-functions-example/ollama_functions_example.go

Lines changed: 17 additions & 16 deletions
@@ -3,6 +3,7 @@ package main
 import (
 	"context"
 	"encoding/json"
+	"flag"
 	"fmt"
 	"log"
 	"os"
@@ -12,10 +13,13 @@ import (
 	"github.com/tmc/langchaingo/llms/ollama"
 )
 
+var flagVerbose = flag.Bool("v", false, "verbose mode")
+
 func main() {
+	flag.Parse()
 	// allow specifying your own model via OLLAMA_TEST_MODEL
 	// (same as the Ollama unit tests).
-	model := "mistral:instruct"
+	model := "llama3"
 	if v := os.Getenv("OLLAMA_TEST_MODEL"); v != "" {
 		model = v
 	}
@@ -31,14 +35,12 @@ func main() {
 	var msgs []llms.MessageContent
 
 	// system message defines the available tools.
-	msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeSystem,
-		systemMessage()))
-	msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeHuman,
-		"What's the weather like in Beijing?"))
+	msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeSystem, systemMessage()))
+	msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeHuman, "What's the weather like in Beijing?"))
 
 	ctx := context.Background()
 
-	for {
+	for retries := 3; retries > 0; retries = retries - 1 {
 		resp, err := llm.GenerateContent(ctx, msgs)
 		if err != nil {
 			log.Fatal(err)
@@ -49,19 +51,23 @@ func main() {
 
 		if c := unmarshalCall(choice1.Content); c != nil {
 			log.Printf("Call: %v", c.Tool)
-
+			if *flagVerbose {
+				log.Printf("Call: %v (raw: %v)", c.Tool, choice1.Content)
+			}
 			msg, cont := dispatchCall(c)
 			if !cont {
 				break
 			}
-
 			msgs = append(msgs, msg)
 		} else {
 			// Ollama doesn't always respond with a function call, let it try again.
 			log.Printf("Not a call: %v", choice1.Content)
-
 			msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeHuman, "Sorry, I don't understand. Please try again."))
 		}
+
+		if retries == 0 {
+			log.Fatal("retries exhausted")
+		}
 	}
 }
 
@@ -72,20 +78,17 @@ type Call struct {
 
 func unmarshalCall(input string) *Call {
 	var c Call
-
 	if err := json.Unmarshal([]byte(input), &c); err == nil && c.Tool != "" {
 		return &c
 	}
-
 	return nil
 }
 
 func dispatchCall(c *Call) (llms.MessageContent, bool) {
 	// ollama doesn't always respond with a *valid* function call. As we're using prompt
 	// engineering to inject the tools, it may hallucinate.
 	if !validTool(c.Tool) {
-		log.Printf("invalid function call: %#v", c)
-
+		log.Printf("invalid function call: %#v, prompting model to try again", c)
 		return llms.TextParts(llms.ChatMessageTypeHuman,
 			"Tool does not exist, please try again."), true
 	}
@@ -106,7 +109,7 @@ func dispatchCall(c *Call) (llms.MessageContent, bool) {
 		if err != nil {
 			log.Fatal(err)
 		}
-		return llms.TextParts(llms.ChatMessageTypeSystem, weather), true
+		return llms.TextParts(llms.ChatMessageTypeHuman, weather), true
 	case "finalResponse":
 		resp, ok := c.Input["response"].(string)
 		if !ok {
@@ -124,11 +127,9 @@ func dispatchCall(c *Call) (llms.MessageContent, bool) {
 
 func validTool(name string) bool {
	var valid []string
-
 	for _, v := range functions {
 		valid = append(valid, v.Name)
 	}
-
 	return slices.Contains(valid, name)
 }
 
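The example's unbounded for loop is now a counted one. One subtlety, noted with hedging: as diffed, the body runs only while retries > 0 and retries is not modified inside the body, so the new in-loop retries == 0 check cannot fire; a common variant checks exhaustion after the loop instead. A standalone sketch of that variant, where attempt is a hypothetical stand-in for one GenerateContent round-trip:

package main

import "log"

// attempt stands in for one model round-trip; it reports whether a valid
// tool call came back. Hypothetical: here it succeeds on the last try.
func attempt(try int) bool { return try == 1 }

func main() {
	ok := false
	for retries := 3; retries > 0; retries-- {
		if attempt(retries) {
			ok = true
			break
		}
		log.Printf("not a call, %d retries left", retries-1)
	}
	// Checking after the loop is guaranteed to run, since inside the
	// body retries is always greater than zero.
	if !ok {
		log.Fatal("retries exhausted")
	}
	log.Println("got a valid call")
}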

llms/ollama/internal/ollamaclient/types.go

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ type Message struct {
 type ChatRequest struct {
 	Model     string     `json:"model"`
 	Messages  []*Message `json:"messages"`
-	Stream    *bool      `json:"stream,omitempty"`
+	Stream    bool       `json:"stream,omitempty"`
 	Format    string     `json:"format"`
 	KeepAlive string     `json:"keep_alive,omitempty"`
 
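The field type change interacts with omitempty: a plain bool that is false is dropped from the payload entirely, while a non-nil *bool is always emitted, even when it points at false. A trimmed-down demonstration of the serialization difference (not the full ChatRequest):

package main

import (
	"encoding/json"
	"fmt"
)

type withPtr struct {
	Stream *bool `json:"stream,omitempty"`
}

type withBool struct {
	Stream bool `json:"stream,omitempty"`
}

func main() {
	f := false
	a, _ := json.Marshal(withPtr{Stream: &f})
	b, _ := json.Marshal(withBool{Stream: false})
	c, _ := json.Marshal(withBool{Stream: true})
	fmt.Println(string(a)) // {"stream":false} — omitempty only skips a nil pointer
	fmt.Println(string(b)) // {} — false is omitted for a value-typed field
	fmt.Println(string(c)) // {"stream":true}
}

This is also what lets ollamallm.go below assign opts.StreamingFunc != nil directly instead of going through a pointer-making closure.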

llms/ollama/ollamallm.go

Lines changed: 2 additions & 2 deletions
@@ -108,7 +108,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 		Format:   format,
 		Messages: chatMsgs,
 		Options:  ollamaOptions,
-		Stream:   func(b bool) *bool { return &b }(opts.StreamingFunc != nil),
+		Stream:   opts.StreamingFunc != nil,
 	}
 
 	keepAlive := o.options.keepAlive
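The deleted closure existed because Go cannot take the address of an arbitrary expression, so producing a *bool from opts.StreamingFunc != nil required a helper; with the field now a plain bool, the value is assigned directly. A standalone illustration:

package main

import "fmt"

func main() {
	streaming := true // stand-in for opts.StreamingFunc != nil

	// Go only allows taking the address of addressable values, not of
	// expressions, hence the old func(b bool) *bool helper:
	ptr := func(b bool) *bool { return &b }(streaming)

	// With a value-typed field, a direct assignment suffices:
	val := streaming

	fmt.Println(*ptr, val) // true true
}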
@@ -129,7 +129,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 		if response.Message != nil {
 			streamedResponse += response.Message.Content
 		}
-		if response.Done {
+		if !req.Stream || response.Done {
 			resp = response
 			resp.Message = &ollamaclient.Message{
 				Role: "assistant",
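
This is the core of the non-streaming fix: the loop now finalizes the response for non-streaming requests regardless of the Done flag, while streaming requests still accumulate chunks until Done. A trimmed, standalone sketch of that pattern, using stand-in types rather than the real ollamaclient ones:

package main

import "fmt"

type Message struct {
	Role, Content string
}

type ChatResponse struct {
	Message *Message
	Done    bool
}

// finalize mirrors the accumulation loop in the diff: content is
// accumulated across chunks, and the response is captured either when
// the request was non-streaming or when the server marks it Done.
func finalize(stream bool, chunks []ChatResponse) ChatResponse {
	var resp ChatResponse
	var streamed string
	for _, response := range chunks {
		if response.Message != nil {
			streamed += response.Message.Content
		}
		if !stream || response.Done {
			resp = response
			resp.Message = &Message{Role: "assistant", Content: streamed}
		}
	}
	return resp
}

func main() {
	// Non-streaming: a single reply, captured even if Done is unset.
	one := []ChatResponse{{Message: &Message{Role: "assistant", Content: `{"tool":"x"}`}}}
	fmt.Println(finalize(false, one).Message.Content)

	// Streaming: content accumulates until the Done chunk arrives.
	many := []ChatResponse{
		{Message: &Message{Role: "assistant", Content: `{"tool":`}},
		{Message: &Message{Role: "assistant", Content: `"x"}`}, Done: true},
	}
	fmt.Println(finalize(true, many).Message.Content)
}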
