Skip to content

Commit 2124f7f

Browse files
authored
fix(mistral): supports the default llms.WithTools implementation (#970)
1 parent 71160f9 commit 2124f7f

File tree

1 file changed

+58
-19
lines changed

1 file changed

+58
-19
lines changed

llms/mistral/mistralmodel.go

+58-19
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,28 @@ func mistralChatParamsFromCallOptions(callOpts *llms.CallOptions) sdk.ChatReques
117117
chatOpts.Temperature = callOpts.Temperature
118118
chatOpts.RandomSeed = callOpts.Seed
119119
chatOpts.Tools = make([]sdk.Tool, 0)
120-
for _, function := range callOpts.Functions {
121-
chatOpts.Tools = append(chatOpts.Tools, sdk.Tool{
122-
Type: "function",
123-
Function: sdk.Function{
124-
Name: function.Name,
125-
Description: function.Description,
126-
Parameters: function.Parameters,
127-
},
128-
})
120+
if len(callOpts.Tools) > 0 {
121+
for _, tool := range callOpts.Tools {
122+
chatOpts.Tools = append(chatOpts.Tools, sdk.Tool{
123+
Type: "function",
124+
Function: sdk.Function{
125+
Name: tool.Function.Name,
126+
Description: tool.Function.Description,
127+
Parameters: tool.Function.Parameters,
128+
},
129+
})
130+
}
131+
} else {
132+
for _, function := range callOpts.Functions {
133+
chatOpts.Tools = append(chatOpts.Tools, sdk.Tool{
134+
Type: "function",
135+
Function: sdk.Function{
136+
Name: function.Name,
137+
Description: function.Description,
138+
Parameters: function.Parameters,
139+
},
140+
})
141+
}
129142
}
130143
return chatOpts
131144
}
@@ -159,6 +172,16 @@ func generateNonStreamingContent(ctx context.Context, m *Model, callOptions *llm
159172
toolCalls := choice.Message.ToolCalls
160173
if len(toolCalls) > 0 {
161174
langchainContentResponse.Choices[idx].FuncCall = (*llms.FunctionCall)(&toolCalls[0].Function)
175+
for _, tool := range toolCalls {
176+
langchainContentResponse.Choices[0].ToolCalls = append(langchainContentResponse.Choices[0].ToolCalls, llms.ToolCall{
177+
ID: tool.Id,
178+
Type: string(tool.Type),
179+
FunctionCall: &llms.FunctionCall{
180+
Name: tool.Function.Name,
181+
Arguments: tool.Function.Arguments,
182+
},
183+
})
184+
}
162185
}
163186
}
164187
m.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, langchainContentResponse)
@@ -192,6 +215,16 @@ func generateStreamingContent(ctx context.Context, m *Model, callOptions *llms.C
192215
langchainContentResponse.Choices[0].StopReason = string(choice.FinishReason)
193216
if len(choice.Delta.ToolCalls) > 0 {
194217
langchainContentResponse.Choices[0].FuncCall = (*llms.FunctionCall)(&choice.Delta.ToolCalls[0].Function)
218+
for _, tool := range choice.Delta.ToolCalls {
219+
langchainContentResponse.Choices[0].ToolCalls = append(langchainContentResponse.Choices[0].ToolCalls, llms.ToolCall{
220+
ID: tool.Id,
221+
Type: string(tool.Type),
222+
FunctionCall: &llms.FunctionCall{
223+
Name: tool.Function.Name,
224+
Arguments: tool.Function.Arguments,
225+
},
226+
})
227+
}
195228
}
196229
}
197230
err := callOptions.StreamingFunc(ctx, []byte(chunkStr))
@@ -209,19 +242,25 @@ func generateStreamingContent(ctx context.Context, m *Model, callOptions *llms.C
209242
func convertToMistralChatMessages(langchainMessages []llms.MessageContent) ([]sdk.ChatMessage, error) {
210243
messages := make([]sdk.ChatMessage, 0)
211244
for _, msg := range langchainMessages {
212-
msgText := ""
213245
for _, part := range msg.Parts {
214-
textContent, ok := part.(llms.TextContent)
215-
if !ok {
246+
switch p := part.(type) {
247+
case llms.TextContent:
248+
chatMsg := sdk.ChatMessage{Content: p.Text, Role: string(msg.Role)}
249+
setMistralChatMessageRole(&msg, &chatMsg) // #nosec G601
250+
if chatMsg.Content != "" && chatMsg.Role != "" {
251+
messages = append(messages, chatMsg)
252+
}
253+
case llms.ToolCallResponse:
254+
chatMsg := sdk.ChatMessage{Role: string(msg.Role), Content: p.Content}
255+
setMistralChatMessageRole(&msg, &chatMsg) // #nosec G601
256+
messages = append(messages, chatMsg)
257+
case llms.ToolCall:
258+
chatMsg := sdk.ChatMessage{Role: string(msg.Role), ToolCalls: []sdk.ToolCall{{Id: p.ID, Type: sdk.ToolTypeFunction, Function: sdk.FunctionCall{Name: p.FunctionCall.Name, Arguments: p.FunctionCall.Arguments}}}}
259+
setMistralChatMessageRole(&msg, &chatMsg) // #nosec G601
260+
messages = append(messages, chatMsg)
261+
default:
216262
return nil, errors.New("unsupported content type encountered while preparing chat messages to send to mistral platform")
217263
}
218-
msgText += textContent.Text
219-
}
220-
chatMsg := sdk.ChatMessage{Content: msgText, Role: "user"}
221-
222-
setMistralChatMessageRole(&msg, &chatMsg) // #nosec G601
223-
if chatMsg.Content != "" && chatMsg.Role != "" {
224-
messages = append(messages, chatMsg)
225264
}
226265
}
227266
return messages, nil

0 commit comments

Comments
 (0)