Skip to content

Commit

Permalink
GCP client for Google AI platform (#276)
Browse files Browse the repository at this point in the history
  • Loading branch information
nomisRev authored Jul 26, 2023
1 parent e156190 commit bb07d43
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 31 deletions.
7 changes: 7 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ kotlinx-coroutines-reactive = { module = "org.jetbrains.kotlinx:kotlinx-coroutin
ktor-utils = { module = "io.ktor:ktor-utils", version.ref = "ktor" }
ktor-http = { module = "io.ktor:ktor-http", version.ref = "ktor" }
ktor-client ={ module = "io.ktor:ktor-client-core", version.ref = "ktor" }
ktor-client-content-negotiation ={ module = "io.ktor:ktor-client-content-negotiation", version.ref = "ktor" }
ktor-client-serialization = { module = "io.ktor:ktor-serialization-kotlinx-json", version.ref = "ktor" }
ktor-client-cio = { module = "io.ktor:ktor-client-cio", version.ref = "ktor" }
ktor-client-js = { module = "io.ktor:ktor-client-js", version.ref = "ktor" }
ktor-client-winhttp = { module = "io.ktor:ktor-client-winhttp", version.ref = "ktor" }
Expand Down Expand Up @@ -94,6 +96,11 @@ jackson-schema-jakarta = { module = "com.github.victools:jsonschema-module-jakar
jakarta-validation = { module = "jakarta.validation:jakarta.validation-api", version.ref = "jakarta" }

[bundles]
ktor-client = [
"ktor-client",
"ktor-client-content-negotiation",
"ktor-client-serialization"
]
arrow = [
"arrow-core",
"arrow-fx-coroutines"
Expand Down
100 changes: 69 additions & 31 deletions integrations/gcp/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,51 +1,89 @@
plugins {
id(libs.plugins.kotlin.multiplatform.get().pluginId)
id(libs.plugins.kotlinx.serialization.get().pluginId)
alias(libs.plugins.spotless)
alias(libs.plugins.arrow.gradle.publish)
alias(libs.plugins.semver.gradle)
id(libs.plugins.kotlin.multiplatform.get().pluginId)
id(libs.plugins.kotlinx.serialization.get().pluginId)
alias(libs.plugins.spotless)
alias(libs.plugins.arrow.gradle.publish)
alias(libs.plugins.semver.gradle)
}

repositories {
mavenCentral()
mavenCentral()
}

java {
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11
toolchain {
languageVersion = JavaLanguageVersion.of(11)
}
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11
toolchain {
languageVersion = JavaLanguageVersion.of(11)
}
}

kotlin {
jvm()
js(IR) {
browser()
nodejs()
jvm()
js(IR) {
browser()
nodejs()
}

linuxX64()
macosX64()
macosArm64()
mingwX64()

sourceSets {
val commonMain by getting {
dependencies {
api(projects.xefCore)
implementation(libs.bundles.ktor.client)
}
}

val jvmMain by getting {
dependencies {
implementation(libs.logback)
api(libs.ktor.client.cio)
}
}

val jsMain by getting {
dependencies {
api(libs.ktor.client.js)
}
}

linuxX64()
macosX64()
macosArm64()
mingwX64()

sourceSets {
val commonMain by getting {
dependencies {
api(projects.xefCore)
}
}
val linuxX64Main by getting {
dependencies {
api(libs.ktor.client.cio)
}
}

val macosX64Main by getting {
dependencies {
api(libs.ktor.client.cio)
}
}

val macosArm64Main by getting {
dependencies {
api(libs.ktor.client.cio)
}
}

val mingwX64Main by getting {
dependencies {
api(libs.ktor.client.winhttp)
}
}
}
}

spotless {
kotlin {
target("**/*.kt")
ktfmt().googleStyle()
}
kotlin {
target("**/*.kt")
ktfmt().googleStyle()
}
}

tasks.withType<AbstractPublishToMaven> {
dependsOn(tasks.withType<Sign>())
dependsOn(tasks.withType<Sign>())
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package com.xebia.functional.xef.gcp

import com.xebia.functional.xef.AIError
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.HttpRequestRetry
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.statement.bodyAsText
import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.http.isSuccess
import io.ktor.serialization.kotlinx.json.json
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json

@OptIn(ExperimentalStdlibApi::class)
class GcpClient(
private val apiEndpoint: String,
private val projectId: String,
private val modelId: String,
private val token: String
) : AutoCloseable {
private val http: HttpClient = HttpClient {
install(HttpTimeout)
install(HttpRequestRetry)
install(ContentNegotiation) {
json(
Json {
encodeDefaults = false
isLenient = true
}
)
}
}

@Serializable
private data class Prompt(val instances: List<Instance>, val parameters: Parameters? = null)

@Serializable
private data class Instance(
val context: String? = null,
val examples: List<Example>? = null,
val messages: List<Message>,
)

@Serializable data class Example(val input: String, val output: String)

@Serializable private data class Message(val author: String, val content: String)

@Serializable
private class Parameters(
val temperature: Double? = null,
val maxOutputTokens: Int? = null,
val topK: Int? = null,
val topP: Double? = null
)

@Serializable data class Response(val predictions: List<Predictions>)

@Serializable
data class SafetyAttributes(
val blocked: Boolean,
val scores: List<String>,
val categories: List<String>
)

@Serializable data class CitationMetadata(val citations: List<String>)

@Serializable data class Candidates(val author: String?, val content: String?)

@Serializable
data class Predictions(
val safetyAttributes: List<SafetyAttributes>,
val citationMetadata: List<CitationMetadata>,
val candidates: List<Candidates>
)

suspend fun promptMessage(
prompt: String,
temperature: Double? = null,
maxOutputTokens: Int? = null,
topK: Int? = null,
topP: Double? = null
): String {
val body =
Prompt(
listOf(Instance(messages = listOf(Message(author = "user", content = prompt)))),
Parameters(temperature, maxOutputTokens, topK, topP)
)
val response =
http.post(
"https://$apiEndpoint/v1/projects/$projectId/locations/us-central1/publishers/google/models/$modelId:predict"
) {
header("Authorization", "Bearer $token")
contentType(ContentType.Application.Json)
setBody(body)
}

return if (response.status.isSuccess())
response.body<Response>().predictions.firstOrNull()?.candidates?.firstOrNull()?.content
?: throw AIError.NoResponse()
else throw GcpClientException(response.status, response.bodyAsText())
}

class GcpClientException(val httpStatusCode: HttpStatusCode, val error: String) :
IllegalStateException("$httpStatusCode: $error")

override fun close() {
http.close()
}
}

0 comments on commit bb07d43

Please sign in to comment.