Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 44 additions & 18 deletions packages/ai/src/core.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import type { Client } from "gel";
import { EventSourceParserStream } from "eventsource-parser/stream";

import type { ResolvedConnectConfig } from "gel/dist/conUtils.js";
import {
getAuthenticatedFetch,
type AuthenticatedFetch,
Expand Down Expand Up @@ -46,9 +45,7 @@ export class RAGClient {
}

private static async getAuthenticatedFetch(client: Client) {
const connectConfig: ResolvedConnectConfig = (
await (client as any).pool._getNormalizedConnectConfig()
).connectionParams;
const connectConfig = await client.resolveConnectionParams();

return getAuthenticatedFetch(connectConfig, httpSCRAMAuth, "ext/ai/");
}
Expand Down Expand Up @@ -110,7 +107,7 @@ export class RAGClient {
!providedPrompt && {
name: "builtin::rag-default",
}),
custom: [...(this.options.prompt?.custom || []), ...messages],
custom: [...(this.options.prompt?.custom ?? []), ...messages],
},
query: [...messages].reverse().find((msg) => msg.role === "user")!
.content[0].text,
Expand Down Expand Up @@ -139,18 +136,9 @@ export class RAGClient {
);
}

const data = await res.json();
const data: unknown = await res.json();

if (
!data ||
typeof data !== "object" ||
typeof data.response !== "string"
) {
throw new Error(
"Expected response to be an object with response key of type string",
);
}
return data.response;
return parseResponse(data);
}

streamRag(
Expand Down Expand Up @@ -190,6 +178,7 @@ export class RAGClient {
}
},
then<TResult1 = Response, TResult2 = never>(
/* eslint-disable @typescript-eslint/no-duplicate-type-constituents */
onfulfilled?:
| ((value: Response) => TResult1 | PromiseLike<TResult1>)
| undefined
Expand All @@ -198,6 +187,7 @@ export class RAGClient {
| ((reason: any) => TResult2 | PromiseLike<TResult2>)
| undefined
| null,
/* eslint-enable @typescript-eslint/no-duplicate-type-constituents */
): Promise<TResult1 | TResult2> {
return fetchRag(
{
Expand Down Expand Up @@ -229,7 +219,43 @@ export class RAGClient {
throw new Error(bodyText);
}

const data: { data: { embedding: number[] }[] } = await response.json();
return data.data[0].embedding;
const data: unknown = await response.json();
return parseEmbeddingResponse(data);
}
}

function parseResponse(data: unknown): string {
if (
typeof data === "object" &&
data != null &&
"response" in data &&
typeof data.response === "string"
) {
return data.response;
}

throw new Error(
"Expected response to be an object with response key of type string",
);
}

function parseEmbeddingResponse(responseData: unknown): number[] {
if (
typeof responseData === "object" &&
responseData != null &&
"data" in responseData &&
Array.isArray(responseData.data)
) {
const firstItem: unknown = responseData.data[0];
if (
typeof firstItem === "object" &&
firstItem != null &&
"embedding" in firstItem
) {
return firstItem.embedding as number[];
}
}
throw new Error(
"Expected response to be an object with data key of type array of objects with embedding key of type number[]",
);
}