@@ -3,6 +3,7 @@ package main
33import (
44 "context"
55 "encoding/json"
6+ "flag"
67 "fmt"
78 "log"
89 "os"
@@ -12,10 +13,13 @@ import (
1213 "github.com/tmc/langchaingo/llms/ollama"
1314)
1415
16+ var flagVerbose = flag .Bool ("v" , false , "verbose mode" )
17+
1518func main () {
19+ flag .Parse ()
1620 // allow specifying your own model via OLLAMA_TEST_MODEL
1721 // (same as the Ollama unit tests).
18- model := "mistral:instruct "
22+ model := "llama3 "
1923 if v := os .Getenv ("OLLAMA_TEST_MODEL" ); v != "" {
2024 model = v
2125 }
@@ -31,14 +35,12 @@ func main() {
3135 var msgs []llms.MessageContent
3236
3337 // system message defines the available tools.
34- msgs = append (msgs , llms .TextParts (llms .ChatMessageTypeSystem ,
35- systemMessage ()))
36- msgs = append (msgs , llms .TextParts (llms .ChatMessageTypeHuman ,
37- "What's the weather like in Beijing?" ))
38+ msgs = append (msgs , llms .TextParts (llms .ChatMessageTypeSystem , systemMessage ()))
39+ msgs = append (msgs , llms .TextParts (llms .ChatMessageTypeHuman , "What's the weather like in Beijing?" ))
3840
3941 ctx := context .Background ()
4042
41- for {
43+ for retries := 3 ; retries > 0 ; retries = retries - 1 {
4244 resp , err := llm .GenerateContent (ctx , msgs )
4345 if err != nil {
4446 log .Fatal (err )
@@ -49,19 +51,23 @@ func main() {
4951
5052 if c := unmarshalCall (choice1 .Content ); c != nil {
5153 log .Printf ("Call: %v" , c .Tool )
52-
54+ if * flagVerbose {
55+ log .Printf ("Call: %v (raw: %v)" , c .Tool , choice1 .Content )
56+ }
5357 msg , cont := dispatchCall (c )
5458 if ! cont {
5559 break
5660 }
57-
5861 msgs = append (msgs , msg )
5962 } else {
6063 // Ollama doesn't always respond with a function call, let it try again.
6164 log .Printf ("Not a call: %v" , choice1 .Content )
62-
6365 msgs = append (msgs , llms .TextParts (llms .ChatMessageTypeHuman , "Sorry, I don't understand. Please try again." ))
6466 }
67+
68+ if retries == 0 {
69+ log .Fatal ("retries exhausted" )
70+ }
6571 }
6672}
6773
@@ -72,20 +78,17 @@ type Call struct {
7278
7379func unmarshalCall (input string ) * Call {
7480 var c Call
75-
7681 if err := json .Unmarshal ([]byte (input ), & c ); err == nil && c .Tool != "" {
7782 return & c
7883 }
79-
8084 return nil
8185}
8286
8387func dispatchCall (c * Call ) (llms.MessageContent , bool ) {
8488 // ollama doesn't always respond with a *valid* function call. As we're using prompt
8589 // engineering to inject the tools, it may hallucinate.
8690 if ! validTool (c .Tool ) {
87- log .Printf ("invalid function call: %#v" , c )
88-
91+ log .Printf ("invalid function call: %#v, prompting model to try again" , c )
8992 return llms .TextParts (llms .ChatMessageTypeHuman ,
9093 "Tool does not exist, please try again." ), true
9194 }
@@ -106,7 +109,7 @@ func dispatchCall(c *Call) (llms.MessageContent, bool) {
106109 if err != nil {
107110 log .Fatal (err )
108111 }
109- return llms .TextParts (llms .ChatMessageTypeSystem , weather ), true
112+ return llms .TextParts (llms .ChatMessageTypeHuman , weather ), true
110113 case "finalResponse" :
111114 resp , ok := c .Input ["response" ].(string )
112115 if ! ok {
@@ -124,11 +127,9 @@ func dispatchCall(c *Call) (llms.MessageContent, bool) {
124127
125128func validTool (name string ) bool {
126129 var valid []string
127-
128130 for _ , v := range functions {
129131 valid = append (valid , v .Name )
130132 }
131-
132133 return slices .Contains (valid , name )
133134}
134135
0 commit comments