Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 139 additions & 17 deletions llms/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,17 @@ func convertSchemaRecursive(schemaMap map[string]any, toolIndex int, propertyPat
// convertTools converts from a list of langchaingo tools to a list of genai
// tools.
func convertTools(tools []llms.Tool) ([]*genai.Tool, error) {
genaiFuncDecls := make([]*genai.FunctionDeclaration, 0, len(tools))
if len(tools) == 0 {
return nil, nil
}

// Initialize a single genaiTool to hold all function declarations.
// This approach is used because the GoogleAI API expects a single tool
// with multiple function declarations, rather than multiple tools each with
// a single function declaration.
genaiTool := genai.Tool{
FunctionDeclarations: make([]*genai.FunctionDeclaration, 0, len(tools)),
}
for i, tool := range tools {
if tool.Type != "function" {
return nil, fmt.Errorf("tool [%d]: unsupported type %q, want 'function'", i, tool.Type)
Expand All @@ -504,33 +514,145 @@ func convertTools(tools []llms.Tool) ([]*genai.Tool, error) {
Description: tool.Function.Description,
}

// Expect the Parameters field to be a map[string]any, from which we will
// extract properties to populate the schema.
params, ok := tool.Function.Parameters.(map[string]any)
schema, err := convertToSchema(tool.Function.Parameters, true)
if err != nil {
return nil, fmt.Errorf("tool [%d]: %w", i, err)
}
genaiFuncDecl.Parameters = schema

genaiTool.FunctionDeclarations = append(genaiTool.FunctionDeclarations, genaiFuncDecl)
}

return []*genai.Tool{&genaiTool}, nil
}

// convert map[any]any to map[string]any if possible
func convertMaps(i any) any {
switch v := i.(type) {
case map[any]any:
m := make(map[string]any)
for key, val := range v {
sKey, ok := key.(string)
if !ok {
return v
}
m[sKey] = convertMaps(val)
}
return m
case []any:
s := make([]any, len(v))
for idx, val := range v {
s[idx] = convertMaps(val)
}
return s
}
return i
}

func convertToSchema(e any, topLevel bool) (*genai.Schema, error) {
e = convertMaps(e)
schema := &genai.Schema{}

eMap, ok := e.(map[string]any)
if !ok {
return nil, fmt.Errorf("tool: unsupported type %T of Parameters", e)
}

if ty, ok := eMap["type"]; ok {
tyString, ok := ty.(string)
if !ok {
return nil, fmt.Errorf("tool [%d]: unsupported type %T of Parameters", i, tool.Function.Parameters)
return nil, fmt.Errorf("tool: expected string for type")
}
schema.Type = convertToolSchemaType(tyString)

if topLevel && schema.Type != genai.TypeObject {
return nil, fmt.Errorf("tool: top-level schema must be an object")
}
}

schema, err := convertSchemaRecursive(params, i, "")
_, ok = eMap["properties"]
if ok {
paramProperties, ok := eMap["properties"].(map[string]any)
if !ok {
return nil, fmt.Errorf("tool: expected map[string]any for properties")
}
schema.Properties = make(map[string]*genai.Schema)
for propName, propValue := range paramProperties {
recSchema, err := convertToSchema(propValue, false)
if err != nil {
return nil, fmt.Errorf("tool, property [%v]: %w", propName, err)
}
schema.Properties[propName] = recSchema
}
} else if schema.Type == genai.TypeObject {
return nil, fmt.Errorf("tool: object schema must have properties")
}

items, ok := eMap["items"]
if ok {
itemsSchema, err := convertToSchema(items, false)
if err != nil {
return nil, err
return nil, fmt.Errorf("tool: %w", err)
}
genaiFuncDecl.Parameters = schema
schema.Items = itemsSchema
} else if schema.Type == genai.TypeArray {
return nil, fmt.Errorf("tool: array schema must have items")
}

// google genai only support one tool, multiple tools must be embedded into function declarations:
// https://github.com/GoogleCloudPlatform/generative-ai/issues/636
// https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#chat-samples
genaiFuncDecls = append(genaiFuncDecls, genaiFuncDecl)
if description, ok := eMap["description"]; ok {
descString, ok := description.(string)
if !ok {
return nil, fmt.Errorf("tool: expected string for description")
}
schema.Description = descString
}

// Return nil if no tools are provided
if len(genaiFuncDecls) == 0 {
return nil, nil
if nullable, ok := eMap["nullable"]; ok {
nullableBool, ok := nullable.(bool)
if !ok {
return nil, fmt.Errorf("tool: expected bool for nullable")
}
schema.Nullable = nullableBool
}

genaiTools := []*genai.Tool{{FunctionDeclarations: genaiFuncDecls}}
if enum, ok := eMap["enum"]; ok {
enumSlice, err := convertToSliceOfStrings(enum)
if err != nil {
return nil, fmt.Errorf("tool: %w", err)
}
schema.Enum = enumSlice

return genaiTools, nil
}

if required, ok := eMap["required"]; ok {
requiredSlice, err := convertToSliceOfStrings(required)
if err != nil {
return nil, fmt.Errorf("tool: %w", err)
}
schema.Required = requiredSlice
}

return schema, nil
}

func convertToSliceOfStrings(e any) ([]string, error) {
if rs, ok := e.([]string); ok {
return rs, nil
}

ri, ok := e.([]interface{})
if !ok {
return nil, fmt.Errorf("tool: expected []interface{} for required")
}
rs := make([]string, 0, len(ri))
for _, r := range ri {
rString, ok := r.(string)
if !ok {
return nil, fmt.Errorf("tool: expected string for required")
}
rs = append(rs, rString)
}
return rs, nil
}

// convertToolSchemaType converts a tool's schema type from its langchaingo
Expand Down
2 changes: 1 addition & 1 deletion llms/googleai/googleai_unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ func TestConvertTools(t *testing.T) { //nolint:funlen // comprehensive test //no
}
result, err := convertTools(tools)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expected to find a map of properties")
assert.Contains(t, err.Error(), "object schema must have properties")
assert.Nil(t, result)
})

Expand Down
Loading