Skip to content

Commit

Permalink
Tool output response schema (#685)
Browse files Browse the repository at this point in the history
* Include output schema in submit tool outputs

* spotless
  • Loading branch information
raulraja authored Mar 15, 2024
1 parent 41c32d0 commit 7ce94b6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import com.xebia.functional.openai.models.ext.assistant.AssistantToolsCode
import com.xebia.functional.openai.models.ext.assistant.AssistantToolsFunction
import com.xebia.functional.openai.models.ext.assistant.AssistantToolsRetrieval
import com.xebia.functional.xef.llm.fromEnvironment
import com.xebia.functional.xef.llm.models.functions.buildJsonSchema
import io.ktor.util.logging.*
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
Expand Down Expand Up @@ -42,7 +44,7 @@ class Assistant(
assistantsApi
)

suspend inline fun getToolRegistered(name: String, args: String): JsonElement =
suspend inline fun getToolRegistered(name: String, args: String): ToolOutput =
try {
val toolConfig = toolsConfig.firstOrNull { it.functionObject.name == name }

Expand All @@ -51,20 +53,26 @@ class Assistant(

val tool: Tool<Any?, Any?> = toolConfig.tool as Tool<Any?, Any?>

val schema = buildJsonSchema(toolSerializer.outputSerializer.descriptor)
val output: Any? = tool(input)
ApiClient.JSON_DEFAULT.encodeToJsonElement(
toolSerializer.outputSerializer as KSerializer<Any?>,
output
)
val result =
ApiClient.JSON_DEFAULT.encodeToJsonElement(
toolSerializer.outputSerializer as KSerializer<Any?>,
output
)
ToolOutput(schema, result)
} catch (e: Exception) {
val message = "Error calling to tool registered $name: ${e.message}"
val logger = KtorSimpleLogger("Functions")
logger.error(message, e)
JsonObject(mapOf("error" to JsonPrimitive(message)))
val result = JsonObject(mapOf("error" to JsonPrimitive(message)))
ToolOutput(JsonObject(emptyMap()), result)
}

companion object {

@Serializable data class ToolOutput(val schema: JsonObject, val result: JsonElement)

suspend operator fun invoke(
model: String,
name: String? = null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.flow
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject

class AssistantThread(
Expand Down Expand Up @@ -200,13 +199,12 @@ class AssistantThread(
run.status == RunObject.Status.requires_action &&
run.requiredAction?.type == RunObjectRequiredAction.Type.submit_tool_outputs
) {
val results: Map<String, JsonElement> =
val results: Map<String, Assistant.Companion.ToolOutput> =
calls
.filter { it.function != null }
.parMap { toolCall ->
val function = toolCall.function!!
val result: JsonElement =
assistant.getToolRegistered(function.name, function.arguments)
val result = assistant.getToolRegistered(function.name, function.arguments)
toolCall.id to result
}
.toMap()
Expand All @@ -222,7 +220,11 @@ class AssistantThread(
results.map { (toolCallId, result) ->
SubmitToolOutputsRunRequestToolOutputsInner(
toolCallId = toolCallId,
output = ApiClient.JSON_DEFAULT.encodeToString(result)
output =
ApiClient.JSON_DEFAULT.encodeToString(
Assistant.Companion.ToolOutput.serializer(),
result
)
)
}
)
Expand Down

0 comments on commit 7ce94b6

Please sign in to comment.