Skip to content

Commit

Permalink
Expose Assistant Tool input and output json config (#790)
Browse files Browse the repository at this point in the history
* feat: expose tool input and output json config

* refactor: move json config out of companion
  • Loading branch information
realdavidvega authored Sep 27, 2024
1 parent ee1c9e8 commit e5916d6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,17 @@ class Assistant(
suspend inline fun getToolRegistered(name: String, args: String): ToolOutput =
try {
val toolConfig = toolsConfig.firstOrNull { it.functionObject.name == name }
val (inputSerializer, outputSerializer) =
toolConfig?.serialization ?: error("Function $name not registered")

val toolSerializer = toolConfig?.serializers ?: error("Function $name not registered")
val input = toolConfig.json.decodeFromString(toolSerializer.inputSerializer, args)

val input = inputSerializer.json.decodeFromString(inputSerializer.serializer, args)
val tool: Tool<Any?, Any?> = toolConfig.tool as Tool<Any?, Any?>

val schema = buildJsonSchema(toolSerializer.outputSerializer.descriptor)
val schema = buildJsonSchema(outputSerializer.serializer.descriptor)
val output: Any? = tool(input)
val result =
toolConfig.json.encodeToJsonElement(
toolSerializer.outputSerializer as KSerializer<Any?>,
outputSerializer.json.encodeToJsonElement(
outputSerializer.serializer as KSerializer<Any?>,
output
)
ToolOutput(schema, result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,40 @@ import kotlinx.serialization.serializer
fun interface Tool<Input, out Output> {
suspend operator fun invoke(input: Input): Output

data class JsonConfig(val inputJson: Json, val outputJson: Json) {
companion object {
val Default = JsonConfig(inputJson = Json.Default, outputJson = Json.Default)
}
}

companion object {

data class ToolConfig<Input, out Output>(
val functionObject: FunctionObject,
val serializers: ToolSerializer,
val tool: Tool<Input, Output>,
val json: Json = Json.Default
val serialization: ToolSerialization,
val tool: Tool<Input, Output>
)

data class ToolSerializer(
val inputSerializer: KSerializer<*>,
val outputSerializer: KSerializer<*>
data class ToolSerialization(
val inputSerializer: ToolSerializer,
val outputSerializer: ToolSerializer
)

inline fun <reified I, reified O> toolOf(tool: Tool<I, O>): ToolConfig<I, O> {
val inputSerializer = serializer<I>()
val outputSerializer = serializer<O>()
val toolSerializer = ToolSerializer(inputSerializer, outputSerializer)
val fn = chatFunction(inputSerializer.descriptor)
data class ToolSerializer(val serializer: KSerializer<*>, val json: Json)

inline fun <reified I, reified O> toolOf(
tool: Tool<I, O>,
jsonConfig: JsonConfig = JsonConfig.Default
): ToolConfig<I, O> {
val inputSerializer = ToolSerializer(serializer<I>(), jsonConfig.inputJson)
val outputSerializer = ToolSerializer(serializer<O>(), jsonConfig.outputJson)
val toolSerializer = ToolSerialization(inputSerializer, outputSerializer)
val fn = chatFunction(inputSerializer.serializer.descriptor)
val fnName = tool::class.simpleName ?: error("unnamed class")
val fnDescription = defaultFunctionDescription(fnName)
return ToolConfig(
functionObject = fn.copy(name = fnName, description = fnDescription),
serializers = toolSerializer,
serialization = toolSerializer,
tool = tool
)
}
Expand Down

0 comments on commit e5916d6

Please sign in to comment.