Parallel tool calls + abstract away serialization framework (#758)
* Draft for parallel tool calls + abstracting the serialization framework away so AI features can be implemented outside of Kotlin

* Stream AI events and ParallelToolCalls example

* Support for creating tools from anonymous functions and for passing descriptions at tool creation. Annotation-based functions appear impossible in KMP; they would only work on the JVM target

* Progress on running tools over different type shapes

* Attempt to fix tests

* Config for max rounds and concurrency of tool executions

* Rename Input to Value due to LLM confusion
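
Taken together, these changes let callers observe a run as a stream of AIEvent values while tool calls execute concurrently. A minimal sketch of that surface, assuming `myTools` is a List<Tool<*>> built with the new tool-creation helpers from this PR (imports from kotlinx.coroutines.flow and this package assumed):

suspend fun streamingDemo(myTools: List<Tool<*>>) {
  // Asking for Flow<AIEvent<String>> selects the Tool.FlowOfAIEvents case in AI.kt below,
  // so tool execution requests and responses are emitted as events alongside the result.
  val events: Flow<AIEvent<String>> = AI(
    "Plan a weekend trip and check the weather for each stop",
    config = AIConfig(tools = myTools)
  )
  events.collect { it.debugPrint() }
}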
raulraja authored Jun 18, 2024
1 parent 9e854cd commit 7f52a44
Showing 35 changed files with 1,028 additions and 519 deletions.
306 changes: 102 additions & 204 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
The file is rewritten end to end (237 lines before, 135 after). Before, AI was a sealed interface whose companion object wired prompts through kotlinx.serialization reflection and DefaultAI:

package com.xebia.functional.xef

import com.xebia.functional.openai.generated.api.Chat
import com.xebia.functional.openai.generated.api.Images
import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.conversation.Description
import com.xebia.functional.xef.prompt.Prompt
import kotlin.coroutines.cancellation.CancellationException
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.typeOf
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.descriptors.SerialKind
import kotlinx.serialization.serializer

sealed interface AI {

  @Serializable
  @Description("The selected items indexes")
  data class SelectedItems(
    @Description("The selected items indexes") val selectedItems: List<Int>,
  )

  data class Classification(
    val name: String,
    val description: String,
  )

  interface PromptClassifier {
    fun template(input: String, output: String, context: String): String
  }

  interface PromptMultipleClassifier {
    fun getItems(): List<Classification>

    fun template(input: String): String {
      val items = getItems()

      return """
        |Based on the <input>, identify whether the user is asking about one or more of the following items
        |
        |${
          items.joinToString("\n") { item -> "<${item.name}>${item.description}</${item.name}>" }
        }
        |
        |<items>
        |${
          items.mapIndexed { index, item -> "\t<item index=\"$index\">${item.name}</item>" }
            .joinToString("\n")
        }
        |</items>
        |<input>
        |$input
        |</input>
        """
        .trimMargin()
    }

    @OptIn(ExperimentalSerializationApi::class)
    fun KType.enumValuesName(
      serializer: KSerializer<Any?> = serializer(this)
    ): List<Classification> {
      return if (serializer.descriptor.kind != SerialKind.ENUM) {
        emptyList()
      } else {
        (0 until serializer.descriptor.elementsCount).map { index ->
          val name =
            serializer.descriptor
              .getElementName(index)
              .removePrefix(serializer.descriptor.serialName)
          val description =
            (serializer.descriptor.getElementAnnotations(index).first { it is Description }
                as Description)
              .value
          Classification(name, description)
        }
      }
    }
  }

  companion object {

    fun <A : Any> chat(
      target: KType,
      model: CreateChatCompletionRequestModel,
      api: Chat,
      conversation: Conversation,
      enumSerializer: ((case: String) -> A)?,
      caseSerializers: List<KSerializer<A>>,
      serializer: () -> KSerializer<A>,
    ): DefaultAI<A> =
      DefaultAI(
        target = target,
        model = model,
        api = api,
        serializer = serializer,
        conversation = conversation,
        enumSerializer = enumSerializer,
        caseSerializers = caseSerializers
      )

    fun images(
      config: Config = Config(),
    ): Images = OpenAI(config).images

    @PublishedApi
    internal suspend inline fun <reified A : Any> invokeEnum(
      prompt: Prompt,
      target: KType = typeOf<A>(),
      config: Config = Config(),
      api: Chat = OpenAI(config).chat,
      conversation: Conversation = Conversation()
    ): A =
      chat(
          target = target,
          model = prompt.model,
          api = api,
          conversation = conversation,
          enumSerializer = { @Suppress("UPPER_BOUND_VIOLATED") enumValueOf<A>(it) },
          caseSerializers = emptyList()
        ) {
          serializer<A>()
        }
        .invoke(prompt)

    /**
     * Classify a prompt using a given enum.
     *
     * @param input The input to the model.
     * @param output The output to the model.
     * @param context The context to the model.
     * @param model The model to use.
     * @param target The target type to return.
     * @param api The chat API to use.
     * @param conversation The conversation to use.
     * @return The classified enum.
     * @throws IllegalArgumentException If no enum values are found.
     */
    @AiDsl
    @Throws(IllegalArgumentException::class, CancellationException::class)
    suspend inline fun <reified E> classify(
      input: String,
      output: String,
      context: String,
      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
      target: KType = typeOf<E>(),
      config: Config = Config(),
      api: Chat = OpenAI(config).chat,
      conversation: Conversation = Conversation()
    ): E where E : PromptClassifier, E : Enum<E> {
      val value = enumValues<E>().firstOrNull() ?: error("No enum values found")
      return invoke(
        prompt = value.template(input, output, context),
        model = model,
        target = target,
        config = config,
        api = api,
        conversation = conversation
      )
    }

    @AiDsl
    @Throws(IllegalArgumentException::class, CancellationException::class)
    suspend inline fun <reified E> multipleClassify(
      input: String,
      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
      config: Config = Config(),
      api: Chat = OpenAI(config).chat,
      conversation: Conversation = Conversation()
    ): List<E> where E : PromptMultipleClassifier, E : Enum<E> {
      val values = enumValues<E>()
      val value = values.firstOrNull() ?: error("No enum values found")
      val selected: SelectedItems =
        invoke(
          prompt = value.template(input),
          model = model,
          config = config,
          api = api,
          conversation = conversation
        )
      return selected.selectedItems.mapNotNull { values.elementAtOrNull(it) }
    }

    @AiDsl
    suspend inline operator fun <reified A : Any> invoke(
      prompt: String,
      target: KType = typeOf<A>(),
      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125,
      config: Config = Config(),
      api: Chat = OpenAI(config).chat,
      conversation: Conversation = Conversation()
    ): A = chat(Prompt(model, prompt), target, config, api, conversation)

    @AiDsl
    suspend inline operator fun <reified A : Any> invoke(
      prompt: Prompt,
      target: KType = typeOf<A>(),
      config: Config = Config(),
      api: Chat = OpenAI(config).chat,
      conversation: Conversation = Conversation()
    ): A = chat(prompt, target, config, api, conversation)

    @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
    @AiDsl
    suspend inline fun <reified A : Any> chat(
      prompt: Prompt,
      target: KType = typeOf<A>(),
      config: Config = Config(),
      api: Chat = OpenAI(config).chat,
      conversation: Conversation = Conversation()
    ): A {
      val kind =
        (target.classifier as? KClass<*>)?.serializer()?.descriptor?.kind
          ?: error("Cannot find SerialKind for $target")
      return when (kind) {
        SerialKind.ENUM -> invokeEnum<A>(prompt, target, config, api, conversation)
        else -> {
          chat(
              target = target,
              model = prompt.model,
              api = api,
              conversation = conversation,
              enumSerializer = null,
              caseSerializers = emptyList()
            ) {
              serializer<A>()
            }
            .invoke(prompt)
        }
      }
    }
  }
}

After, AI is a class over the result type A, driven by an AIConfig plus a Tool<A> serializer, so the serialization framework is abstracted away from the entry points. SelectedItems, Classification, PromptClassifier, and PromptMultipleClassifier move out to sibling files in the same package (Classification.kt appears below):

package com.xebia.functional.xef

import com.xebia.functional.openai.generated.model.CreateChatCompletionRequest
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.llm.models.modelType
import com.xebia.functional.xef.llm.prompt
import com.xebia.functional.xef.llm.promptStreaming
import com.xebia.functional.xef.prompt.Prompt
import kotlin.coroutines.cancellation.CancellationException
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow

class AI<out A>(private val config: AIConfig, val serializer: Tool<A>) {

  private fun runStreamingWithStringSerializer(prompt: Prompt): Flow<String> =
    config.api.promptStreaming(prompt, config.conversation, config.tools)

  @PublishedApi
  internal suspend operator fun invoke(prompt: Prompt): A =
    when (val serializer = serializer) {
      is Tool.Callable -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.Contextual -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.Enumeration<A> -> runWithEnumSingleTokenSerializer(serializer, prompt)
      is Tool.FlowOfStreamedFunctions<*> -> {
        config.api.promptStreaming(prompt, config.conversation, serializer, config.tools) as A
      }
      is Tool.FlowOfStrings -> runStreamingWithStringSerializer(prompt) as A
      is Tool.Primitive -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.Sealed -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.FlowOfAIEventsSealed ->
        channelFlow {
          send(AIEvent.Start)
          config.api.prompt(
            prompt = prompt,
            scope = config.conversation,
            serializer = serializer.sealedSerializer,
            tools = config.tools,
            collector = this
          )
        } as A
      is Tool.FlowOfAIEvents ->
        channelFlow {
          send(AIEvent.Start)
          config.api.prompt(
            prompt = prompt,
            scope = config.conversation,
            serializer = serializer.serializer,
            tools = config.tools,
            collector = this
          )
        } as A
    }

  private suspend fun runWithEnumSingleTokenSerializer(
    serializer: Tool.Enumeration<A>,
    prompt: Prompt
  ): A {
    val encoding = prompt.model.modelType(forFunctions = false).encoding
    val cases = serializer.cases
    val logitBias =
      cases
        .flatMap {
          val result = encoding.encode(it.function.name)
          if (result.size > 1) {
            error("Cannot encode enum case $it into one token")
          }
          result
        }
        .associate { "$it" to 100 }
    val result =
      config.api.createChatCompletion(
        CreateChatCompletionRequest(
          messages = prompt.messages,
          model = prompt.model,
          logitBias = logitBias,
          maxTokens = 1,
          temperature = 0.0
        )
      )
    val choice = result.choices[0].message.content
    val enumSerializer = serializer.enumSerializer
    return if (choice != null) {
      enumSerializer(choice)
    } else {
      error("Cannot decode enum case from $choice")
    }
  }

  companion object {

    /**
     * Classify a prompt using a given enum.
     *
     * @param input The input to the model.
     * @param output The output to the model.
     * @param context The context to the model.
     * @return The classified enum.
     * @throws IllegalArgumentException If no enum values are found.
     */
    @AiDsl
    @Throws(IllegalArgumentException::class, CancellationException::class)
    suspend inline fun <reified E> classify(
      input: String,
      output: String,
      context: String,
      config: AIConfig = AIConfig(),
    ): E where E : Enum<E>, E : PromptClassifier {
      val value = enumValues<E>().firstOrNull() ?: error("No values to classify")
      return AI<E>(
        prompt = value.template(input, output, context),
        config = config,
      )
    }

    @AiDsl
    @Throws(IllegalArgumentException::class, CancellationException::class)
    suspend inline fun <reified E> multipleClassify(
      input: String,
      config: AIConfig = AIConfig(),
    ): List<E> where E : Enum<E>, E : PromptMultipleClassifier {
      val values = enumValues<E>()
      val value = values.firstOrNull() ?: error("No values to classify")
      val selected: SelectedItems =
        AI(
          prompt = value.template(input),
          serializer = Tool.fromKotlin<SelectedItems>(),
          config = config
        )
      return selected.selectedItems.mapNotNull { values.elementAtOrNull(it) }
    }
  }
}

@AiDsl
suspend inline fun <reified A> AI(
  prompt: String,
  serializer: Tool<A> = Tool.fromKotlin<A>(),
  config: AIConfig = AIConfig()
): A = AI(Prompt(config.model, prompt), serializer, config)

@AiDsl
suspend inline fun <reified A> AI(
  prompt: Prompt,
  serializer: Tool<A> = Tool.fromKotlin<A>(),
  config: AIConfig = AIConfig(),
): A = AI(config, serializer).invoke(prompt)
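
For comparison with the old invoke entry points, a minimal sketch of the rewritten API under the default AIConfig; the Recipe type is illustrative and assumes Tool.fromKotlin can derive a schema for any @Serializable class:

import kotlinx.serialization.Serializable

@Serializable data class Recipe(val title: String, val ingredients: List<String>)

// Tool.fromKotlin<Recipe>() is the default serializer argument, and
// AIConfig() supplies the default model (gpt_4o) and the OpenAI chat API.
suspend fun recipeDemo(): Recipe = AI("Give me a recipe for a vegan lasagna")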
15 changes: 15 additions & 0 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AIConfig.kt
@@ -0,0 +1,15 @@
package com.xebia.functional.xef

import com.xebia.functional.openai.generated.api.Chat
import com.xebia.functional.openai.generated.api.OpenAI
import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
import com.xebia.functional.xef.conversation.Conversation

data class AIConfig(
  val tools: List<Tool<*>> = emptyList(),
  val model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
  val config: Config = Config(),
  val openAI: OpenAI = OpenAI(config, logRequests = false),
  val api: Chat = openAI.chat,
  val conversation: Conversation = Conversation()
)
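
Every field has a default, so call sites override only what they need. A sketch with the model and tool list swapped and everything else left alone:

val cheapConfig = AIConfig(
  model = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125,
  tools = listOf() // Tool<*> instances to expose to the model
)

suspend fun summarize(text: String): String =
  AI("Summarize: $text", config = cheapConfig)

Note that `openAI` defaults to OpenAI(config, logRequests = false) and `api` to openAI.chat, so overriding `config` alone re-wires the whole chain.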
36 changes: 36 additions & 0 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AIEvent.kt
@@ -0,0 +1,36 @@
package com.xebia.functional.xef

sealed class AIEvent<out A> {
  data object Start : AIEvent<Nothing>()

  data class Result<out A>(val value: A) : AIEvent<A>()

  data class ToolExecutionRequest(val tool: Tool<*>, val input: Any?) : AIEvent<Nothing>()

  data class ToolExecutionResponse(val tool: Tool<*>, val output: Any?) : AIEvent<Nothing>()

  data class Stop(val usage: Usage) : AIEvent<Nothing>() {
    data class Usage(
      val llmCalls: Int,
      val toolCalls: Int,
      val inputTokens: Int,
      val outputTokens: Int,
      val totalTokens: Int,
    )
  }

  fun debugPrint(): Unit =
    when (this) {
      Start -> println("🚀 Starting...")
      is Result -> println("🎉 $value")
      is ToolExecutionRequest ->
        println("🔧 Executing tool: ${tool.function.name} with input: $input")
      is ToolExecutionResponse ->
        println("🔨 Tool response: ${tool.function.name} resulted in: $output")
      is Stop -> {
        println("🛑 Stopping...")
        println("📊 Usage: $usage")
      }
    }
}
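
Stop carries a Usage summary, so a caller can account for LLM and tool calls at the end of a run. A sketch of consuming the stream by hand rather than through debugPrint, with the Flow import and the event-producing AI call assumed as in the earlier example:

suspend fun usageOf(events: Flow<AIEvent<String>>): AIEvent.Stop.Usage? {
  var usage: AIEvent.Stop.Usage? = null
  events.collect { event ->
    when (event) {
      is AIEvent.Stop -> usage = event.usage // llmCalls, toolCalls, token counts
      else -> event.debugPrint()
    }
  }
  return usage
}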
6 changes: 6 additions & 0 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/Classification.kt
@@ -0,0 +1,6 @@
package com.xebia.functional.xef

data class Classification(
  val name: String,
  val description: String,
)
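
Classification feeds PromptMultipleClassifier.getItems(), which renders each item into the selection template that the old AI.kt carried. A sketch of the round trip through AI.multipleClassify, assuming PromptMultipleClassifier is now a top-level interface in this package, like Classification:

enum class Topic(val item: Classification) : PromptMultipleClassifier {
  billing(Classification("billing", "Questions about invoices and payments")),
  shipping(Classification("shipping", "Questions about delivery status"));

  override fun getItems(): List<Classification> = Topic.entries.map { it.item }
}

// The model answers with SelectedItems indexes, which map back to enum entries.
suspend fun topicsOf(text: String): List<Topic> =
  AI.multipleClassify<Topic>(text)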