
Commit 7f52a44

Parallel tool calls + abstract away serialization framework (#758)
* Draft for parallel tool calls; abstract away the serialization framework so AI features can be implemented outside of Kotlin
* Stream AI events and add a ParallelToolCalls example
* Support creating tools from anonymous functions and passing descriptions at tool creation. Annotation-based functions seem impossible in KMP; they would only work on the JVM target
* Progress running tools over different type shapes
* Attempt to fix tests
* Config for max rounds and concurrency of tool executions
* Rename Input to Value due to LLM confusion
1 parent 9e854cd commit 7f52a44

File tree

35 files changed: +1028 −519 lines
Lines changed: 102 additions & 204 deletions

@@ -1,237 +1,135 @@
 package com.xebia.functional.xef

-import com.xebia.functional.openai.generated.api.Chat
-import com.xebia.functional.openai.generated.api.Images
-import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
+import com.xebia.functional.openai.generated.model.CreateChatCompletionRequest
 import com.xebia.functional.xef.conversation.AiDsl
-import com.xebia.functional.xef.conversation.Conversation
-import com.xebia.functional.xef.conversation.Description
+import com.xebia.functional.xef.llm.models.modelType
+import com.xebia.functional.xef.llm.prompt
+import com.xebia.functional.xef.llm.promptStreaming
 import com.xebia.functional.xef.prompt.Prompt
-import kotlin.coroutines.cancellation.CancellationException
-import kotlin.reflect.KClass
-import kotlin.reflect.KType
-import kotlin.reflect.typeOf
-import kotlinx.serialization.ExperimentalSerializationApi
-import kotlinx.serialization.InternalSerializationApi
-import kotlinx.serialization.KSerializer
-import kotlinx.serialization.Serializable
-import kotlinx.serialization.descriptors.SerialKind
-import kotlinx.serialization.serializer
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.channelFlow

-sealed interface AI {
+class AI<out A>(private val config: AIConfig, val serializer: Tool<A>) {

-  @Serializable
-  @Description("The selected items indexes")
-  data class SelectedItems(
-    @Description("The selected items indexes") val selectedItems: List<Int>,
-  )
-
-  data class Classification(
-    val name: String,
-    val description: String,
-  )
-
-  interface PromptClassifier {
-    fun template(input: String, output: String, context: String): String
-  }
-
-  interface PromptMultipleClassifier {
-    fun getItems(): List<Classification>
-
-    fun template(input: String): String {
-      val items = getItems()
-
-      return """
-      |Based on the <input>, identify whether the user is asking about one or more of the following items
-      |
-      |${
-        items.joinToString("\n") { item -> "<${item.name}>${item.description}</${item.name}>" }
-      }
-      |
-      |<items>
-      |${
-        items.mapIndexed { index, item -> "\t<item index=\"$index\">${item.name}</item>" }
-          .joinToString("\n")
-      }
-      |</items>
-      |<input>
-      |$input
-      |</input>
-      """
-    }
-  }
+  private fun runStreamingWithStringSerializer(prompt: Prompt): Flow<String> =
+    config.api.promptStreaming(prompt, config.conversation, config.tools)

+  @PublishedApi
+  internal suspend operator fun invoke(prompt: Prompt): A =
+    when (val serializer = serializer) {
+      is Tool.Callable -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
+      is Tool.Contextual -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
+      is Tool.Enumeration<A> -> runWithEnumSingleTokenSerializer(serializer, prompt)
+      is Tool.FlowOfStreamedFunctions<*> -> {
+        config.api.promptStreaming(prompt, config.conversation, serializer, config.tools) as A
+      }
+      is Tool.FlowOfStrings -> runStreamingWithStringSerializer(prompt) as A
+      is Tool.Primitive -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
+      is Tool.Sealed -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
+      is Tool.FlowOfAIEventsSealed ->
+        channelFlow {
+          send(AIEvent.Start)
+          config.api.prompt(
+            prompt = prompt,
+            scope = config.conversation,
+            serializer = serializer.sealedSerializer,
+            tools = config.tools,
+            collector = this
+          )
+        }
+          as A
+      is Tool.FlowOfAIEvents ->
+        channelFlow {
+          send(AIEvent.Start)
+          config.api.prompt(
+            prompt = prompt,
+            scope = config.conversation,
+            serializer = serializer.serializer,
+            tools = config.tools,
+            collector = this
+          )
+        }
+          as A
+    }

-  @OptIn(ExperimentalSerializationApi::class)
-  fun KType.enumValuesName(
-    serializer: KSerializer<Any?> = serializer(this)
-  ): List<Classification> {
-    return if (serializer.descriptor.kind != SerialKind.ENUM) {
-      emptyList()
-    } else {
-      (0 until serializer.descriptor.elementsCount).map { index ->
-        val name =
-          serializer.descriptor
-            .getElementName(index)
-            .removePrefix(serializer.descriptor.serialName)
-        val description =
-          (serializer.descriptor.getElementAnnotations(index).first { it is Description }
-              as Description)
-            .value
-        Classification(name, description)
-      }
-    }
-  }
+  private suspend fun runWithEnumSingleTokenSerializer(
+    serializer: Tool.Enumeration<A>,
+    prompt: Prompt
+  ): A {
+    val encoding = prompt.model.modelType(forFunctions = false).encoding
+    val cases = serializer.cases
+    val logitBias =
+      cases
+        .flatMap {
+          val result = encoding.encode(it.function.name)
+          if (result.size > 1) {
+            error("Cannot encode enum case $it into one token")
+          }
+          result
+        }
+        .associate { "$it" to 100 }
+    val result =
+      config.api.createChatCompletion(
+        CreateChatCompletionRequest(
+          messages = prompt.messages,
+          model = prompt.model,
+          logitBias = logitBias,
+          maxTokens = 1,
+          temperature = 0.0
+        )
+      )
+    val choice = result.choices[0].message.content
+    val enumSerializer = serializer.enumSerializer
+    return if (choice != null) {
+      enumSerializer(choice)
+    } else {
+      error("Cannot decode enum case from $choice")
+    }
+  }

   companion object {
-
-    fun <A : Any> chat(
-      target: KType,
-      model: CreateChatCompletionRequestModel,
-      api: Chat,
-      conversation: Conversation,
-      enumSerializer: ((case: String) -> A)?,
-      caseSerializers: List<KSerializer<A>>,
-      serializer: () -> KSerializer<A>,
-    ): DefaultAI<A> =
-      DefaultAI(
-        target = target,
-        model = model,
-        api = api,
-        serializer = serializer,
-        conversation = conversation,
-        enumSerializer = enumSerializer,
-        caseSerializers = caseSerializers
-      )
-
-    fun images(
-      config: Config = Config(),
-    ): Images = OpenAI(config).images
-
-    @PublishedApi
-    internal suspend inline fun <reified A : Any> invokeEnum(
-      prompt: Prompt,
-      target: KType = typeOf<A>(),
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): A =
-      chat(
-          target = target,
-          model = prompt.model,
-          api = api,
-          conversation = conversation,
-          enumSerializer = { @Suppress("UPPER_BOUND_VIOLATED") enumValueOf<A>(it) },
-          caseSerializers = emptyList()
-        ) {
-          serializer<A>()
-        }
-        .invoke(prompt)
-
-    /**
-     * Classify a prompt using a given enum.
-     *
-     * @param input The input to the model.
-     * @param output The output to the model.
-     * @param context The context to the model.
-     * @param model The model to use.
-     * @param target The target type to return.
-     * @param api The chat API to use.
-     * @param conversation The conversation to use.
-     * @return The classified enum.
-     * @throws IllegalArgumentException If no enum values are found.
-     */
     @AiDsl
-    @Throws(IllegalArgumentException::class, CancellationException::class)
     suspend inline fun <reified E> classify(
       input: String,
       output: String,
       context: String,
-      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
-      target: KType = typeOf<E>(),
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): E where E : PromptClassifier, E : Enum<E> {
-      val value = enumValues<E>().firstOrNull() ?: error("No enum values found")
-      return invoke(
+      config: AIConfig = AIConfig(),
+    ): E where E : Enum<E>, E : PromptClassifier {
+      val value = enumValues<E>().firstOrNull() ?: error("No values to classify")
+      return AI<E>(
         prompt = value.template(input, output, context),
-        model = model,
-        target = target,
         config = config,
-        api = api,
-        conversation = conversation
       )
     }

     @AiDsl
-    @Throws(IllegalArgumentException::class, CancellationException::class)
     suspend inline fun <reified E> multipleClassify(
       input: String,
-      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): List<E> where E : PromptMultipleClassifier, E : Enum<E> {
+      config: AIConfig = AIConfig(),
+    ): List<E> where E : Enum<E>, E : PromptMultipleClassifier {
       val values = enumValues<E>()
-      val value = values.firstOrNull() ?: error("No enum values found")
+      val value = values.firstOrNull() ?: error("No values to classify")
       val selected: SelectedItems =
-        invoke(
+        AI(
           prompt = value.template(input),
-          model = model,
-          config = config,
-          api = api,
-          conversation = conversation
+          serializer = Tool.fromKotlin<SelectedItems>(),
+          config = config
         )
       return selected.selectedItems.mapNotNull { values.elementAtOrNull(it) }
     }
-
-    @AiDsl
-    suspend inline operator fun <reified A : Any> invoke(
-      prompt: String,
-      target: KType = typeOf<A>(),
-      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125,
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): A = chat(Prompt(model, prompt), target, config, api, conversation)
-
-    @AiDsl
-    suspend inline operator fun <reified A : Any> invoke(
-      prompt: Prompt,
-      target: KType = typeOf<A>(),
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): A = chat(prompt, target, config, api, conversation)
-
-    @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
-    @AiDsl
-    suspend inline fun <reified A : Any> chat(
-      prompt: Prompt,
-      target: KType = typeOf<A>(),
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): A {
-      val kind =
-        (target.classifier as? KClass<*>)?.serializer()?.descriptor?.kind
-          ?: error("Cannot find SerialKind for $target")
-      return when (kind) {
-        SerialKind.ENUM -> invokeEnum<A>(prompt, target, config, api, conversation)
-        else -> {
-          chat(
-              target = target,
-              model = prompt.model,
-              api = api,
-              conversation = conversation,
-              enumSerializer = null,
-              caseSerializers = emptyList()
-            ) {
-              serializer<A>()
-            }
-            .invoke(prompt)
-        }
-      }
-    }
   }
 }
+
+@AiDsl
+suspend inline fun <reified A> AI(
+  prompt: String,
+  serializer: Tool<A> = Tool.fromKotlin<A>(),
+  config: AIConfig = AIConfig()
+): A = AI(Prompt(config.model, prompt), serializer, config)
+
+@AiDsl
+suspend inline fun <reified A> AI(
+  prompt: Prompt,
+  serializer: Tool<A> = Tool.fromKotlin<A>(),
+  config: AIConfig = AIConfig(),
+): A = AI(config, serializer).invoke(prompt)
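The user-facing change in this file: the AI companion-object invoke/chat overloads are gone, and callers use the top-level AI functions at the bottom of the diff instead. A minimal usage sketch against those new signatures (the Recipe class and prompt text are illustrative, not part of the commit, and an OpenAI key is assumed to be picked up by the default Config):

  import com.xebia.functional.xef.AI
  import kotlinx.serialization.Serializable

  @Serializable
  data class Recipe(val title: String, val ingredients: List<String>)

  suspend fun example() {
    // Tool.fromKotlin<Recipe>() is the default serializer argument,
    // so the completion is decoded straight into Recipe.
    val recipe: Recipe = AI("Give me a recipe for vegan lasagna")
    println(recipe)
  }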
Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+package com.xebia.functional.xef
+
+import com.xebia.functional.openai.generated.api.Chat
+import com.xebia.functional.openai.generated.api.OpenAI
+import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
+import com.xebia.functional.xef.conversation.Conversation
+
+data class AIConfig(
+  val tools: List<Tool<*>> = emptyList(),
+  val model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
+  val config: Config = Config(),
+  val openAI: OpenAI = OpenAI(config, logRequests = false),
+  val api: Chat = openAI.chat,
+  val conversation: Conversation = Conversation()
+)
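AIConfig bundles the knobs (model, api, conversation, config) that the removed overloads took as individual parameters, so a caller overrides only what it needs. A sketch, reusing the gpt_3_5_turbo_0125 constant that appears in the deleted code above (any other CreateChatCompletionRequestModel case works the same way):

  import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
  import com.xebia.functional.xef.AI
  import com.xebia.functional.xef.AIConfig

  suspend fun summarize(text: String): String {
    // Override just the model; tools, conversation, and api keep their defaults.
    val config = AIConfig(model = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125)
    return AI("Summarize: $text", config = config)
  }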
Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+package com.xebia.functional.xef
+
+sealed class AIEvent<out A> {
+  data object Start : AIEvent<Nothing>()
+
+  data class Result<out A>(val value: A) : AIEvent<A>()
+
+  data class ToolExecutionRequest(val tool: Tool<*>, val input: Any?) : AIEvent<Nothing>()
+
+  data class ToolExecutionResponse(val tool: Tool<*>, val output: Any?) : AIEvent<Nothing>()
+
+  data class Stop(val usage: Usage) : AIEvent<Nothing>() {
+    data class Usage(
+      val llmCalls: Int,
+      val toolCalls: Int,
+      val inputTokens: Int,
+      val outputTokens: Int,
+      val totalTokens: Int,
+    )
+  }
+
+  fun debugPrint(): Unit =
+    when (this) {
+      // emoji for start is: 🚀
+      Start -> println("🚀 Starting...")
+      is Result -> println("🎉 $value")
+      is ToolExecutionRequest ->
+        println("🔧 Executing tool: ${tool.function.name} with input: $input")
+      is ToolExecutionResponse ->
+        println("🔨 Tool response: ${tool.function.name} resulted in: $output")
+      is Stop -> {
+        println("🛑 Stopping...")
+        println("📊 Usage: $usage")
+      }
+    }
+}
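The Tool.FlowOfAIEvents branch in the first file streams these events while tools run, and debugPrint above is a ready-made console renderer for them. A sketch of consuming the stream, assuming Tool.fromKotlin resolves a reified Flow<AIEvent<String>> to that branch (that resolution lives in Tool code not shown in this diff):

  import com.xebia.functional.xef.AI
  import com.xebia.functional.xef.AIEvent
  import kotlinx.coroutines.flow.Flow

  suspend fun observe() {
    val events: Flow<AIEvent<String>> = AI("Plan a three day trip to Madrid")
    // Emits Start first, then tool requests/responses, the Result, and Stop with usage.
    events.collect { it.debugPrint() }
  }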
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+package com.xebia.functional.xef
+
+data class Classification(
+  val name: String,
+  val description: String,
+)
