|
1 | 1 | package com.xebia.functional.xef
|
2 | 2 |
|
3 |
| -import com.xebia.functional.openai.generated.api.Chat |
4 |
| -import com.xebia.functional.openai.generated.api.Images |
5 |
| -import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel |
| 3 | +import com.xebia.functional.openai.generated.model.CreateChatCompletionRequest |
6 | 4 | import com.xebia.functional.xef.conversation.AiDsl
|
7 |
| -import com.xebia.functional.xef.conversation.Conversation |
8 |
| -import com.xebia.functional.xef.conversation.Description |
| 5 | +import com.xebia.functional.xef.llm.models.modelType |
| 6 | +import com.xebia.functional.xef.llm.prompt |
| 7 | +import com.xebia.functional.xef.llm.promptStreaming |
9 | 8 | import com.xebia.functional.xef.prompt.Prompt
|
10 |
| -import kotlin.coroutines.cancellation.CancellationException |
11 |
| -import kotlin.reflect.KClass |
12 |
| -import kotlin.reflect.KType |
13 |
| -import kotlin.reflect.typeOf |
14 |
| -import kotlinx.serialization.ExperimentalSerializationApi |
15 |
| -import kotlinx.serialization.InternalSerializationApi |
16 |
| -import kotlinx.serialization.KSerializer |
17 |
| -import kotlinx.serialization.Serializable |
18 |
| -import kotlinx.serialization.descriptors.SerialKind |
19 |
| -import kotlinx.serialization.serializer |
20 |
| - |
21 |
| -sealed interface AI { |
22 |
| - |
23 |
| - @Serializable |
24 |
| - @Description("The selected items indexes") |
25 |
| - data class SelectedItems( |
26 |
| - @Description("The selected items indexes") val selectedItems: List<Int>, |
27 |
| - ) |
28 |
| - |
29 |
| - data class Classification( |
30 |
| - val name: String, |
31 |
| - val description: String, |
32 |
| - ) |
33 |
| - |
34 |
| - interface PromptClassifier { |
35 |
| - fun template(input: String, output: String, context: String): String |
36 |
| - } |
37 |
| - |
38 |
| - interface PromptMultipleClassifier { |
39 |
| - fun getItems(): List<Classification> |
40 |
| - |
41 |
| - fun template(input: String): String { |
42 |
| - val items = getItems() |
43 |
| - |
44 |
| - return """ |
45 |
| - |Based on the <input>, identify whether the user is asking about one or more of the following items |
46 |
| - | |
47 |
| - |${ |
48 |
| - items.joinToString("\n") { item -> "<${item.name}>${item.description}</${item.name}>" } |
| 9 | +import kotlinx.coroutines.flow.Flow |
| 10 | +import kotlinx.coroutines.flow.channelFlow |
| 11 | + |
/**
 * Entry point for running a [Prompt] against a chat model and decoding the response into [A].
 *
 * The [serializer] (a [Tool] case) selects both the invocation strategy and the decoding:
 * structured tool/function calls, plain or event-carrying streams, primitives, sealed
 * hierarchies, or single-token enum selection.
 */
class AI<out A>(private val config: AIConfig, val serializer: Tool<A>) {

  // Streams the raw assistant text with no structured decoding applied.
  private fun runStreamingWithStringSerializer(prompt: Prompt): Flow<String> =
    config.api.promptStreaming(prompt, config.conversation, config.tools)

  /**
   * Runs [prompt] and decodes the model response according to [serializer].
   *
   * The `as A` casts in the flow-producing branches are unchecked but safe by construction:
   * those [Tool] cases are only built when [A] is the corresponding flow type.
   */
  @PublishedApi
  internal suspend operator fun invoke(prompt: Prompt): A =
    when (val serializer = serializer) {
      is Tool.Callable -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.Contextual -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.Enumeration<A> -> runWithEnumSingleTokenSerializer(serializer, prompt)
      is Tool.FlowOfStreamedFunctions<*> -> {
        config.api.promptStreaming(prompt, config.conversation, serializer, config.tools) as A
      }
      is Tool.FlowOfStrings -> runStreamingWithStringSerializer(prompt) as A
      is Tool.Primitive -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.Sealed -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
      is Tool.FlowOfAIEventsSealed ->
        channelFlow {
          // Emit the start marker before the model call so collectors see the full lifecycle.
          send(AIEvent.Start)
          config.api.prompt(
            prompt = prompt,
            scope = config.conversation,
            serializer = serializer.sealedSerializer,
            tools = config.tools,
            collector = this
          )
        }
          as A
      is Tool.FlowOfAIEvents ->
        channelFlow {
          send(AIEvent.Start)
          config.api.prompt(
            prompt = prompt,
            scope = config.conversation,
            serializer = serializer.serializer,
            tools = config.tools,
            collector = this
          )
        }
          as A
    }

  /**
   * Selects one enum case by forcing the model to emit exactly one token.
   *
   * Every case name must encode to a single token under the model's encoding; those token
   * ids receive a strong positive logit bias and the request caps `maxTokens` at 1, so the
   * completion can only be one of the case names. Temperature 0 keeps the pick deterministic.
   *
   * @throws IllegalStateException if a case name does not encode to a single token, or if
   *   the response contains no decodable content.
   */
  private suspend fun runWithEnumSingleTokenSerializer(
    serializer: Tool.Enumeration<A>,
    prompt: Prompt
  ): A {
    val encoding = prompt.model.modelType(forFunctions = false).encoding
    val cases = serializer.cases
    val logitBias =
      cases
        .flatMap {
          val tokens = encoding.encode(it.function.name)
          if (tokens.size > 1) {
            // Report the human-readable case name, not the Tool case object's toString().
            error("Cannot encode enum case ${it.function.name} into one token")
          }
          tokens
        }
        .associate { "$it" to 100 }
    val result =
      config.api.createChatCompletion(
        CreateChatCompletionRequest(
          messages = prompt.messages,
          model = prompt.model,
          logitBias = logitBias,
          maxTokens = 1,
          temperature = 0.0
        )
      )
    // Guard against an empty choices list as well as a null content payload; either way we
    // fail with a clear message instead of an IndexOutOfBoundsException.
    val choice = result.choices.firstOrNull()?.message?.content
    val enumSerializer = serializer.enumSerializer
    return if (choice != null) {
      enumSerializer(choice)
    } else {
      error("Cannot decode enum case: model response contained no content")
    }
  }

  companion object {
    /**
     * Asks the model to classify the ([input], [output], [context]) triple as one case of
     * the enum [E], using the template supplied by [PromptClassifier].
     *
     * @throws IllegalStateException if [E] declares no values.
     */
    @AiDsl
    suspend inline fun <reified E> classify(
      input: String,
      output: String,
      context: String,
      config: AIConfig = AIConfig(),
    ): E where E : Enum<E>, E : PromptClassifier {
      // Any case can supply the template; we only need one value to reach the interface method.
      val value = enumValues<E>().firstOrNull() ?: error("No values to classify")
      return AI<E>(
        prompt = value.template(input, output, context),
        config = config,
      )
    }

    /**
     * Asks the model to select zero or more cases of [E] that apply to [input], using the
     * template supplied by [PromptMultipleClassifier].
     *
     * Indexes returned by the model that fall outside the enum's range are silently dropped.
     *
     * @throws IllegalStateException if [E] declares no values.
     */
    @AiDsl
    suspend inline fun <reified E> multipleClassify(
      input: String,
      config: AIConfig = AIConfig(),
    ): List<E> where E : Enum<E>, E : PromptMultipleClassifier {
      val values = enumValues<E>()
      val value = values.firstOrNull() ?: error("No values to classify")
      val selected: SelectedItems =
        AI(
          prompt = value.template(input),
          serializer = Tool.fromKotlin<SelectedItems>(),
          config = config
        )
      return selected.selectedItems.mapNotNull { values.elementAtOrNull(it) }
    }
  }
}
|
| 122 | + |
/**
 * Runs a raw-text [prompt] against the model configured in [config] and decodes the
 * response into [A] via [serializer].
 */
@AiDsl
suspend inline fun <reified A> AI(
  prompt: String,
  serializer: Tool<A> = Tool.fromKotlin<A>(),
  config: AIConfig = AIConfig()
): A {
  // Wrap the plain string in a Prompt bound to the configured model, then delegate.
  val modelPrompt = Prompt(config.model, prompt)
  return AI(modelPrompt, serializer, config)
}
| 129 | + |
/**
 * Runs [prompt] against the model and decodes the response into [A] via [serializer].
 */
@AiDsl
suspend inline fun <reified A> AI(
  prompt: Prompt,
  serializer: Tool<A> = Tool.fromKotlin<A>(),
  config: AIConfig = AIConfig(),
): A {
  // Build the runner once, then invoke it with the prompt.
  val runner = AI(config, serializer)
  return runner(prompt)
}
0 commit comments