diff --git a/packages/ai/README.md b/packages/ai/README.md index a820de9eb..a5c71493e 100644 --- a/packages/ai/README.md +++ b/packages/ai/README.md @@ -22,11 +22,7 @@ The following example demonstrates how to use the `@gel/ai` package to query an import { createClient } from "gel"; import { createRAGClient } from "@gel/ai"; -const client = createRAGClient({ - instanceName: "_localdev", - database: "main", - tlsSecurity: "insecure", -}); +const client = createClient(); const gpt4Rag = createRAGClient(client, { model: "gpt-4-turbo", @@ -35,7 +31,10 @@ const gpt4Rag = createRAGClient(client, { const astronomyRag = gpt4Rag.withContext({ query: "Astronomy" }); console.time("gpt-4 Time"); -console.log(await astronomyRag.queryRag({ prompt: "What color is the sky on Mars?" })); +console.log( + (await astronomyRag.queryRag({ prompt: "What color is the sky on Mars?" })) + .content, +); console.timeEnd("gpt-4 Time"); const fastAstronomyRag = astronomyRag.withConfig({ @@ -43,25 +42,35 @@ const fastAstronomyRag = astronomyRag.withConfig({ }); console.time("gpt-4o Time"); -console.log(await fastAstronomyRag.queryRag({ prompt: "What color is the sky on Mars?" })); +console.log( + ( + await fastAstronomyRag.queryRag({ + prompt: "What color is the sky on Mars?", + }) + ).content, +); console.timeEnd("gpt-4o Time"); const fastChemistryRag = fastAstronomyRag.withContext({ query: "Chemistry" }); console.log( - await fastChemistryRag.queryRag({ prompt: "What is the atomic number of gold?" }), + ( + await fastChemistryRag.queryRag({ + prompt: "What is the atomic number of gold?", + }) + ).content, ); // handle the Response object -const response = await fastChemistryRag.streamRag( - { prompt: "What is the atomic number of gold?" }, -); +const response = await fastChemistryRag.streamRag({ + prompt: "What is the atomic number of gold?", +}); handleReadableStream(response); // custom function that reads the stream // handle individual chunks as they arrive -for await (const chunk of fastChemistryRag.streamRag( - { prompt: "What is the atomic number of gold?" }, -)) { +for await (const chunk of fastChemistryRag.streamRag({ + prompt: "What is the atomic number of gold?", +})) { console.log("chunk", chunk); } @@ -73,3 +82,181 @@ console.log( }), ); ``` + +## Tool Calls + +The `@gel/ai` package supports tool calls, allowing you to extend the capabilities of the AI model with your own functions. Here's how to use them: + +1. **Define your tools**: Create an array of `ToolDefinition` objects that describe your functions, their parameters, and what they do. +2. **Send the request**: Call `queryRag` or `streamRag` with the user's prompt and the `tools` array. You can also use the `tool_choice` parameter to control how the model uses your tools. +3. **Handle the tool call**: If the model decides to use a tool, it will return an `AssistantMessage` with a `tool_calls` array. Your code needs to: + 1. Parse the `tool_calls` array to identify the tool and its arguments. + 2. Execute the tool and get the result. + 3. Create a `ToolMessage` with the result. + 4. Send the `ToolMessage` back to the model in a new request. +4. **Receive the final response**: The model will use the tool's output to generate a final response. + +### Example + +```typescript +import type { + Message, + ToolDefinition, + UserMessage, + ToolMessage, + AssistantMessage, +} from "@gel/ai"; + +// 1. 
Define your tools
+const tools: ToolDefinition[] = [
+  {
+    type: "function",
+    name: "get_weather",
+    description: "Get the current weather for a given city.",
+    parameters: {
+      type: "object",
+      properties: {
+        city: {
+          type: "string",
+          description: "The city to get the weather for.",
+        },
+      },
+      required: ["city"],
+    },
+  },
+];
+
+// `ragClient` is assumed to be a client created as in the example above,
+// e.g. `createRAGClient(createClient(), { model: "gpt-4-turbo" })`.
+
+// 2. Send the request
+const userMessage: UserMessage = {
+  role: "user",
+  content: [{ type: "text", text: "What's the weather like in London?" }],
+};
+
+const messages: Message[] = [userMessage];
+
+const response = await ragClient.queryRag({
+  messages,
+  tools,
+  tool_choice: "auto",
+});
+
+// 3. Handle the tool call
+if (response.tool_calls) {
+  const toolCall = response.tool_calls[0];
+  if (toolCall.function.name === "get_weather") {
+    const args = JSON.parse(toolCall.function.arguments);
+    const weather = await getWeather(args.city); // Your function to get the weather
+
+    const toolMessage: ToolMessage = {
+      role: "tool",
+      tool_call_id: toolCall.id,
+      content: JSON.stringify({ weather }),
+    };
+
+    // Add the assistant's response and the tool message to the history
+    messages.push(response);
+    messages.push(toolMessage);
+
+    // 4. Send the tool result back to the model
+    const finalResponse = await ragClient.queryRag({
+      messages,
+      tools,
+    });
+
+    console.log(finalResponse.content);
+  }
+} else {
+  console.log(response.content);
+}
+
+// Dummy function for the example
+async function getWeather(city: string): Promise<string> {
+  return `The weather in ${city} is sunny.`;
+}
+```
+
+### Streaming Responses
+
+When using `streamRag`, tool calls arrive incrementally as streaming chunks. The flow is the same as in the `queryRag` example, but you need to accumulate the chunks to reconstruct each tool call before executing it.
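+
+For reference, here is a simplified sketch of the chunk variants the example below matches on, abbreviated from the `StreamingMessage` union in `src/types.ts` (the exact `tool_call_delta` payload is an assumption inferred from how argument fragments are streamed):
+
+```typescript
+// Simplified sketch only; see src/types.ts for the full definitions.
+type StreamedToolUseBlock = {
+  type: "tool_use";
+  id?: string | null;
+  name: string;
+  args?: Record<string, unknown> | null;
+};
+
+type RelevantChunk =
+  | {
+      type: "content_block_start";
+      index: number;
+      content_block: { type: "text"; text: string } | StreamedToolUseBlock;
+    }
+  | {
+      type: "content_block_delta";
+      delta:
+        | { type: "text_delta"; text: string }
+        // assumption: `args` carries a fragment of the JSON arguments string
+        | { type: "tool_call_delta"; args: string };
+    }
+  | { type: "content_block_stop" }
+  | { type: "message_stop" };
+```
+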
+```typescript
+// Handle the streaming response, accumulating tool-call chunks as they arrive.
+async function handleStreamingResponse(initialMessages: Message[]) {
+  const stream = ragClient.streamRag({
+    messages: initialMessages,
+    tools,
+    tool_choice: "auto",
+  });
+
+  const toolCalls: { id: string; name: string; arguments: string }[] = [];
+  let currentToolCall: { id: string; name: string; arguments: string } | null =
+    null;
+
+  for await (const chunk of stream) {
+    if (
+      chunk.type === "content_block_start" &&
+      chunk.content_block.type === "tool_use"
+    ) {
+      currentToolCall = {
+        id: chunk.content_block.id!,
+        name: chunk.content_block.name,
+        // `args` may hold a partial arguments object on the start chunk;
+        // serialize it if present, otherwise start with an empty string.
+        arguments: chunk.content_block.args
+          ? JSON.stringify(chunk.content_block.args)
+          : "",
+      };
+    } else if (
+      chunk.type === "content_block_delta" &&
+      chunk.delta.type === "tool_call_delta"
+    ) {
+      if (currentToolCall) {
+        currentToolCall.arguments += chunk.delta.args;
+      }
+    } else if (chunk.type === "content_block_stop") {
+      if (currentToolCall) {
+        toolCalls.push(currentToolCall);
+        currentToolCall = null;
+      }
+    } else if (chunk.type === "message_stop") {
+      // The model has finished its turn
+      if (toolCalls.length > 0) {
+        const assistantMessage: AssistantMessage = {
+          role: "assistant",
+          content: null,
+          tool_calls: toolCalls.map((tc) => ({
+            id: tc.id,
+            type: "function",
+            function: { name: tc.name, arguments: tc.arguments },
+          })),
+        };
+
+        const toolMessages: ToolMessage[] = await Promise.all(
+          toolCalls.map(async (tc) => {
+            const args = JSON.parse(tc.arguments);
+            const weather = await getWeather(args.city); // Your function to get the weather
+            return {
+              role: "tool",
+              tool_call_id: tc.id,
+              content: JSON.stringify({ weather }),
+            };
+          }),
+        );
+
+        const newMessages: Message[] = [
+          ...initialMessages,
+          assistantMessage,
+          ...toolMessages,
+        ];
+
+        // Call the function again to get the final response
+        await handleStreamingResponse(newMessages);
+      }
+    } else if (
+      chunk.type === "content_block_delta" &&
+      chunk.delta.type === "text_delta"
+    ) {
+      // Handle text responses from the model
+      process.stdout.write(chunk.delta.text);
+    }
+  }
+}
+
+// Start the stream with the `messages` and `tools` defined in the previous example.
+await handleStreamingResponse(messages);
+```
diff --git a/packages/ai/jest.config.js b/packages/ai/jest.config.js
index d08204e90..9de7fa374 100644
--- a/packages/ai/jest.config.js
+++ b/packages/ai/jest.config.js
@@ -3,6 +3,9 @@ export const JS_EXT_TO_TREAT_AS_ESM = [".jsx"];
 export const ESM_TS_JS_TRANSFORM_PATTERN = "^.+\\.m?[tj]sx?$";
 
 export default {
+  maxWorkers: 1,
+  maxConcurrency: 1,
+  testTimeout: 30_000,
   testEnvironment: "node",
   testPathIgnorePatterns: ["./dist"],
   globalSetup: "./test/globalSetup.ts",
diff --git a/packages/ai/package.json b/packages/ai/package.json
index 18204b3ef..e2934ead3 100644
--- a/packages/ai/package.json
+++ b/packages/ai/package.json
@@ -38,6 +38,7 @@
     "@repo/tsconfig": "*",
     "@types/jest": "^29.5.12",
     "@types/node": "^20.12.13",
+    "debug": "^4.4.1",
     "gel": "^2.0.0",
     "jest": "29.7.0",
     "ts-jest": "^29.1.4",
diff --git a/packages/ai/src/core.ts b/packages/ai/src/core.ts
index 653a26b85..12c103453 100644
--- a/packages/ai/src/core.ts
+++ b/packages/ai/src/core.ts
@@ -13,6 +13,7 @@ import {
   type RagRequest,
   type EmbeddingRequest,
   isPromptRequest,
+  type AssistantMessage,
 } from "./types.js";
 import { getHTTPSCRAMAuth } from "gel/dist/httpScram.js";
 import { cryptoUtils } from "gel/dist/browserCrypto.js";
@@ -124,7 +125,10 @@ export class RAGClient {
     return response;
   }
 
-  async queryRag(request: RagRequest, context = this.context): Promise<string> {
+  async queryRag(
+    request: RagRequest,
+    context = this.context,
+  ): Promise<AssistantMessage> {
     const res = await this.fetchRag(
       {
         ...request,
@@ -217,42 +221,48 @@ export class RAGClient {
   }
 }
 
-async function parseRagResponse(response: Response): Promise<string> {
+async function parseRagResponse(response: Response): Promise<AssistantMessage> {
   if (!response.headers.get("content-type")?.includes("application/json")) {
     throw new Error("Expected response to have content-type: application/json");
   }
 
   const data: unknown = await response.json();
 
-  if (!data) {
+  if (!data || typeof data !== "object") {
     throw new Error(`Expected JSON data, but got ${JSON.stringify(data)}`);
   }
 
-  if (typeof data !== "object") {
-    throw new Error(
-      `Expected response to be an object, but got ${JSON.stringify(data)}`,
-    );
+  // Handle the new tool call format from the AI extension
+  if ("tool_calls" in data && Array.isArray(data.tool_calls)) {
+    return {
+      role: "assistant",
+      content: "text" in data ? (data.text as string) : null,
+      tool_calls: data.tool_calls.map((tc: any) => ({
+        id: tc.id,
+        type: tc.type,
+        function: {
+          name: tc.name,
+          arguments: JSON.stringify(tc.args),
+        },
+      })),
+    };
   }
 
-  if ("text" in data) {
-    if (typeof data.text !== "string") {
-      throw new Error(
-        `Expected data.text to be a string, but got ${typeof data.text}: ${JSON.stringify(data.text)}`,
-      );
-    }
-    return data.text;
+  if ("role" in data && data.role === "assistant") {
+    return data as AssistantMessage;
   }
 
-  if ("response" in data) {
-    if (typeof data.response !== "string") {
-      throw new Error(
-        `Expected data.response to be a string, but got ${typeof data.response}: ${JSON.stringify(data.response)}`,
-      );
-    }
-    return data.response;
+  if ("text" in data && typeof data.text === "string") {
+    return { role: "assistant", content: data.text };
+  }
+
+  if ("response" in data && typeof data.response === "string") {
+    return { role: "assistant", content: data.response };
   }
 
   throw new Error(
-    `Expected response to include a non-empty string for either the 'text' or 'response' key, but got: ${JSON.stringify(data)}`,
+    `Expected response to be a message object or include a 'text' or 'response' key, but got: ${JSON.stringify(
+      data,
+    )}`,
   );
 }
diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts
index 45f0cb1c2..039d20abf 100644
--- a/packages/ai/src/types.ts
+++ b/packages/ai/src/types.ts
@@ -12,7 +12,7 @@ export interface UserMessage {
 
 export interface AssistantMessage {
   role: "assistant";
-  content: string;
+  content: string | null;
   tool_calls?: {
     id: string;
     type: "function";
@@ -49,6 +49,22 @@ export interface QueryContext {
   max_object_count?: number;
 }
 
+export interface OpenAIToolDefinition {
+  type: "function";
+  name: string;
+  description: string;
+  parameters: unknown;
+  strict?: boolean;
+}
+
+export interface AnthropicToolDefinition {
+  name: string;
+  description: string;
+  input_schema: unknown;
+}
+
+export type ToolDefinition = OpenAIToolDefinition | AnthropicToolDefinition;
+
 export interface RagRequestPrompt {
   prompt: string;
   [key: string]: unknown;
@@ -56,6 +72,8 @@ export interface RagRequestPrompt {
 
 export interface RagRequestMessages {
   messages: Message[];
+  tools?: ToolDefinition[];
+  tool_choice?: "auto" | "none" | "required";
   [key: string]: unknown;
 }
 
@@ -80,20 +98,22 @@ export interface MessageStart {
   };
 }
 
+export type ContentBlock =
+  | {
+      type: "text";
+      text: string;
+    }
+  | {
+      type: "tool_use";
+      id?: string | null;
+      name: string;
+      args?: Record<string, unknown> | null;
+    };
+
 export interface ContentBlockStart {
   type: "content_block_start";
   index: number;
-  content_block:
-    | {
-        type: "text";
-        text: string;
-      }
-    | {
-        type: "tool_use";
-        id?: 
string | null; - name: string; - args?: string | null; - }; + content_block: ContentBlock; } export interface ContentBlockDelta { @@ -142,6 +162,19 @@ export interface MessageError { }; } +export interface ToolCallChunk { + type: "tool_call_chunk"; + tool_call_chunk: { + index: number; + id?: string; + type?: "function"; + function: { + name?: string; + arguments?: string; + }; + }; +} + export type StreamingMessage = | MessageStart | ContentBlockStart @@ -149,7 +182,8 @@ export type StreamingMessage = | ContentBlockStop | MessageDelta | MessageStop - | MessageError; + | MessageError + | ToolCallChunk; export interface EmbeddingRequest { inputs: string[]; diff --git a/packages/ai/src/utils.ts b/packages/ai/src/utils.ts index 27bbebddd..fd8ba5615 100644 --- a/packages/ai/src/utils.ts +++ b/packages/ai/src/utils.ts @@ -29,5 +29,5 @@ export async function handleResponseError(response: Response): Promise { const bodyText = await response.text(); errorMessage = bodyText || "An unknown error occurred"; } - throw new Error(errorMessage); + throw new Error(`Status: ${response.status}. Message: ${errorMessage}`); } diff --git a/packages/ai/test/core.test.ts b/packages/ai/test/core.test.ts index b0b897b78..a562c3ab8 100644 --- a/packages/ai/test/core.test.ts +++ b/packages/ai/test/core.test.ts @@ -1,84 +1,37 @@ import { type Client } from "gel"; import { createRAGClient } from "../dist/index.js"; -import { getClient, waitFor, getAvailableExtensions } from "@repo/test-utils"; -import { createMockHttpServer, type MockHttpServer } from "./mockHttpServer"; +import { waitFor, getAvailableExtensions } from "@repo/test-utils"; +import { type MockHttpServer } from "./mockHttpServer"; +import { setupTestEnvironment } from "./test-setup"; const availableExtensions = getAvailableExtensions(); if (availableExtensions.has("ai")) { let mockServer: MockHttpServer; + let client: Client; beforeAll(async () => { - // Start the mock server - mockServer = createMockHttpServer(); - - const client = getClient(); - await client.ensureConnected(); - try { - await client.execute(` -create extension pgvector; -create extension ai; - -create type TestEmbeddingModel extending ext::ai::EmbeddingModel { - alter annotation ext::ai::model_name := "text-embedding-test"; - alter annotation ext::ai::model_provider := "custom::test"; - alter annotation ext::ai::embedding_model_max_input_tokens := "8191"; - alter annotation ext::ai::embedding_model_max_batch_tokens := "16384"; - alter annotation ext::ai::embedding_model_max_output_dimensions := "10"; - alter annotation ext::ai::embedding_model_supports_shortening := "true"; -}; - -create type TestTextGenerationModel extending ext::ai::TextGenerationModel { - alter annotation ext::ai::model_name := "text-generation-test"; - alter annotation ext::ai::model_provider := "custom::test"; - alter annotation ext::ai::text_gen_model_context_window := "16385"; -}; - -create type Astronomy { - create required property content: str; - - create deferred index ext::ai::index(embedding_model := "text-embedding-test") on (.content); -}; - -configure current branch insert ext::ai::CustomProviderConfig { - name := "custom::test", - secret := "dummy-key", - api_url := "${mockServer.url}/v1", - api_style := ext::ai::ProviderAPIStyle.OpenAI, -}; - -configure current branch set ext::ai::Config::indexer_naptime := "100ms"; + ({ mockServer, client } = await setupTestEnvironment()); + await client.execute(` +insert Astronomy { content := 'Skies on Mars are red' }; +insert Astronomy { content := 'Skies on Earth 
are blue' }; `); - } finally { - await client.close(); - } - }, 25_000); + }, 60_000); afterAll(async () => { // Stop the mock server if (mockServer) { await mockServer.close(); } + await client.close(); }); describe("@gel/ai", () => { - let client: Client; beforeEach(() => { mockServer.resetRequests(); }); - afterEach(async () => { - await client?.close(); - }); - test("RAG query", async () => { - client = getClient({ - tlsSecurity: "insecure", - }); - await client.execute(` -insert Astronomy { content := 'Skies on Mars are red' }; -insert Astronomy { content := 'Skies on Earth are blue' }; - `); await waitFor(async () => expect(mockServer.getEmbeddingsRequests().length).toBe(1), ); @@ -93,7 +46,7 @@ insert Astronomy { content := 'Skies on Earth are blue' }; prompt: "What color are the skies on Mars?", }); - expect(result).toEqual("This is a mock response."); + expect(result.content).toEqual("This is a mock response."); const streamedResult = ragClient.streamRag({ prompt: "What color are the skies on Mars?", @@ -115,12 +68,9 @@ insert Astronomy { content := 'Skies on Earth are blue' }; } expect(streamedResultString).toEqual("This is a mock response."); - }, 25_000); + }); test("embedding request", async () => { - client = getClient({ - tlsSecurity: "insecure", - }); const ragClient = createRAGClient(client, { model: "text-generation-test", }); @@ -142,5 +92,109 @@ insert Astronomy { content := 'Skies on Earth are blue' }; [0, 2, 0, 0, 2, 0, 0, 0, 0, 0], ); }); + + test("OpenAI style function calling", async () => { + const ragClient = createRAGClient(client, { + model: "text-generation-test", + }).withContext({ + query: "select Astronomy", + }); + + const result = await ragClient.queryRag({ + messages: [ + { + role: "user", + content: [{ type: "text", text: "What is the diameter of Mars?" }], + }, + ], + tools: [ + { + type: "function", + name: "get_planet_diameter", + description: "Get the diameter of a given planet.", + parameters: { + type: "object", + properties: { + planet_name: { + type: "string", + description: "The name of the planet, e.g. Mars", + }, + }, + required: ["planet_name"], + }, + }, + ], + tool_choice: "auto", + }); + + expect(result.tool_calls).toBeDefined(); + expect(result.tool_calls?.[0].function.name).toEqual( + "get_planet_diameter", + ); + expect(result.tool_calls?.[0].function.arguments).toEqual( + '{"planet_name":"Mars"}', + ); + }); + + test("OpenAI style streaming tool calling", async () => { + const ragClient = createRAGClient(client, { + model: "text-generation-test", + }).withContext({ + query: "select Astronomy", + }); + + const streamedResult = ragClient.streamRag({ + messages: [ + { + role: "user", + content: [{ type: "text", text: "What is the diameter of Mars?" }], + }, + ], + tools: [ + { + type: "function", + name: "get_planet_diameter", + description: "Get the diameter of a given planet.", + parameters: { + type: "object", + properties: { + planet_name: { + type: "string", + description: "The name of the planet, e.g. 
Mars", + }, + }, + required: ["planet_name"], + }, + }, + ], + tool_choice: "auto", + }); + + let functionName = ""; + let functionArguments = ""; + + for await (const message of streamedResult) { + if ( + message.type === "content_block_start" && + message.content_block.type === "tool_use" + ) { + if (message.content_block.name) { + functionName += message.content_block.name; + } + if (message.content_block.args) { + functionArguments += message.content_block.args; + } + } + if ( + message.type === "content_block_delta" && + message.delta.type === "tool_call_delta" + ) { + functionArguments += message.delta.args; + } + } + + expect(functionName).toEqual("get_planet_diameter"); + expect(functionArguments).toEqual('{"planet_name":"Mars"}'); + }); }); } diff --git a/packages/ai/test/globalSetup.ts b/packages/ai/test/globalSetup.ts index 8f44499a8..8141cd75e 100644 --- a/packages/ai/test/globalSetup.ts +++ b/packages/ai/test/globalSetup.ts @@ -9,17 +9,20 @@ import { } from "@repo/test-utils"; export default async () => { - // tslint:disable-next-line console.log("\nStarting Gel test cluster..."); const statusFile = generateStatusFileName("node"); console.log("Node status file:", statusFile); const { args, availableFeatures } = getServerCommand(getWSLPath(statusFile)); - console.log(`Starting server...`); + console.time("server"); + console.time("server-start"); const { proc, config } = await startServer(args, statusFile); + console.timeEnd("server-start"); + console.time("server-connect"); const { client, version } = await connectToServer(config); + console.timeEnd("server-connect"); const jestConfig: ConnectConfig = { ...config, @@ -36,6 +39,7 @@ export default async () => { global.gelConn = client; process.env._JEST_GEL_VERSION = JSON.stringify(version); + console.time("server-extension-list"); const availableExtensions = ( await client.query<{ name: string; @@ -44,7 +48,8 @@ export default async () => { ).map(({ name, version }) => [name, version]); process.env._JEST_GEL_AVAILABLE_EXTENSIONS = JSON.stringify(availableExtensions); + console.timeEnd("server-extension-list"); - // tslint:disable-next-line + console.timeEnd("server"); console.log(`Gel test cluster is up [port: ${jestConfig.port}]...`); }; diff --git a/packages/ai/test/mockHttpServer.ts b/packages/ai/test/mockHttpServer.ts index 8b1dd4dc8..64315cfef 100644 --- a/packages/ai/test/mockHttpServer.ts +++ b/packages/ai/test/mockHttpServer.ts @@ -1,5 +1,8 @@ import http from "node:http"; import type { AddressInfo } from "node:net"; +import Debug from "debug"; + +const debug = Debug("gel:test:ai:mockHttpServer"); export interface RecordedRequest { url?: string; @@ -43,24 +46,69 @@ const defaultChatCompletionResponse = { system_fingerprint: "fp_test", }; +const openAIFunctionCallingResponse = { + id: "chatcmpl-test-fn-calling", + object: "chat.completion", + created: Math.floor(Date.now() / 1000), + model: "gpt-3.5-turbo-0125", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: null, + tool_calls: [ + { + id: "call_123", + type: "function", + function: { + name: "get_planet_diameter", + arguments: '{"planet_name":"Mars"}', + }, + }, + ], + }, + logprobs: null, + finish_reason: "tool_calls", + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + system_fingerprint: "fp_test", +}; + export function createMockHttpServer(): MockHttpServer { let chatCompletionsRequests: RecordedRequest[] = []; let embeddingsRequests: RecordedRequest[] = []; let otherRequests: RecordedRequest[] = 
[]; const server = http.createServer((req, res) => { + debug("Request received."); + debug(`Request URL: ${req.url}, Method: ${req.method}`); + debug("Request headers:", req.headers); + let bodyChunks: Buffer[] = []; req.on("data", (chunk) => { + debug("Receiving data chunk."); bodyChunks.push(chunk); }); req.on("end", () => { + debug("Request data fully received."); const bodyString = Buffer.concat(bodyChunks).toString(); + debug("Request body (raw):", bodyString); let parsedBody: any = null; try { parsedBody = bodyString ? JSON.parse(bodyString) : null; + debug("Request body (parsed):", parsedBody); } catch (error) { - console.error("Mock server failed to parse request body:", error); + debug("Failed to parse request body:", error); + res.writeHead(500, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ error: "Failed to parse request body" })); + return; } const recordedRequest: RecordedRequest = { @@ -73,69 +121,165 @@ export function createMockHttpServer(): MockHttpServer { res.setHeader("Content-Type", "application/json"); if (req.method === "POST" && req.url === "/v1/chat/completions") { - console.log( - `Mock server received /v1/chat/completions request: ${bodyString}`, - ); + debug("Handling /v1/chat/completions request."); chatCompletionsRequests = [...chatCompletionsRequests, recordedRequest]; const acceptHeader = req.headers["accept"]; if (acceptHeader && acceptHeader.includes("text/event-stream")) { + debug("Handling streaming chat completion."); res.writeHead(200, { "Content-Type": "text/event-stream" }); const completionId = "chatcmpl-e7f8e220-656c-4455-a132-dacfc1370798"; const model = parsedBody.model; const created = Math.floor(Date.now() / 1000); + + if (parsedBody.tools) { + debug("Handling streaming tool calling."); + const toolCallId = "call_123"; + const functionName = "get_planet_diameter"; + const functionArgs = '{"planet_name":"Mars"}'; + + // First chunk: role and tool call metadata + res.write( + `data: ${JSON.stringify({ + id: completionId, + object: "chat.completion.chunk", + created, + model, + choices: [ + { + index: 0, + delta: { + role: "assistant", + content: null, + tool_calls: [ + { + index: 0, + id: toolCallId, + type: "function", + function: { name: functionName, arguments: "" }, + }, + ], + }, + logprobs: null, + }, + ], + })} + +`, + ); + + // Argument chunks + const argChunks = functionArgs.match(/.{1,10}/g) || []; + argChunks.forEach((argChunk) => { + res.write( + `data: ${JSON.stringify({ + id: completionId, + object: "chat.completion.chunk", + created, + model, + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: 0, + type: "tool_call_delta", + function: { arguments: argChunk }, + }, + ], + }, + }, + ], + })} + +`, + ); + }); + + // Final chunk with finish reason + res.write( + `data: ${JSON.stringify({ + id: completionId, + object: "chat.completion.chunk", + created, + model, + choices: [{ index: 0, delta: {}, finish_reason: "tool_calls" }], + })} + +`, + ); + + res.write("data: [DONE]\n\n"); + res.end(); + return; + } + const finishReason = defaultChatCompletionResponse.choices[0].finish_reason; const content = defaultChatCompletionResponse.choices[0].message.content; const contentChunks = content.match(/.{1,50}/g) || []; // Split content into chunks of 50 characters - res.write( - `data: {"id":"${completionId}","object":"chat.completion.chunk","created":${created},"model":"${model}",` + - `"system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}\n\n`, - 
); + const firstChunk = `data: {"id":"${completionId}","object":"chat.completion.chunk","created":${created},"model":"${model}","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":null},"finish_reason":null}]} + +`; + debug("Writing stream chunk:", firstChunk); + res.write(firstChunk); contentChunks.forEach((text, index) => { - res.write( - `data: {"id":"${completionId}","object":"chat.completion.chunk","created":${created},"model":"${model}",` + - `"system_fingerprint":null,"choices":[{"index":${index + 1},"delta":{"content":"${text}"},"finish_reason":null}]}\n\n`, - ); + const chunk = `data: {"id":"${completionId}","object":"chat.completion.chunk","created":${created},"model":"${model}","system_fingerprint":null,"choices":[{"index":${index + 1},"delta":{"content":"${text}"},"finish_reason":null}]} + +`; + debug("Writing stream chunk:", chunk); + res.write(chunk); }); - res.write( - `data: {"id":"${completionId}","object":"chat.completion.chunk","created":${created},"model":"${model}",` + - `"system_fingerprint":null,"choices":[{"index":0,"delta":{},"finish_reason":"${finishReason}"}]}\n\n`, - ); + const penultimateChunk = `data: {"id":"${completionId}","object":"chat.completion.chunk","created":${created},"model":"${model}","system_fingerprint":null,"choices":[{"index":0,"delta":{},"finish_reason":"${finishReason}"}]} - res.write( - `data: {"id":"${completionId}","object":"chat.completion.chunk","created":${created},"model":"${model}",` + - `"system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"finish_reason":"${finishReason}"}],` + - `"usage":{"queue_time":0.061348671,"prompt_tokens":18,"prompt_time":0.000211569,` + - `"completion_tokens":439,"completion_time":0.798181818,"total_tokens":457,"total_time":0.798393387}}\n\n`, - ); +`; + debug("Writing stream chunk:", penultimateChunk); + res.write(penultimateChunk); + const finalChunkBeforeDone = `data: {"id":"${completionId}","object":"chat.completion.chunk","created":${created},"model":"${model}","system_fingerprint":"fp_10c08bf97d","choices":[{"index":0,"delta":{},"finish_reason":"${finishReason}"}],"usage":{"queue_time":0.061348671,"prompt_tokens":18,"prompt_time":0.000211569,"completion_tokens":439,"completion_time":0.798181818,"total_tokens":457,"total_time":0.798393387}} + +`; + debug("Writing stream chunk:", finalChunkBeforeDone); + res.write(finalChunkBeforeDone); + + debug("Writing [DONE] chunk."); res.write("data: [DONE]\n\n"); res.end(); + debug("Stream ended."); } else { - res.writeHead(200, { "Content-Type": "application/json" }); - res.end(JSON.stringify(defaultChatCompletionResponse)); + debug("Handling non-streaming chat completion."); + if (parsedBody.tools) { + debug("'tools' detected, sending function calling response."); + const responseBody = JSON.stringify(openAIFunctionCallingResponse); + debug("Response body:", responseBody); + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(responseBody); + } else { + debug("No 'tools' detected, sending default chat response."); + const responseBody = JSON.stringify(defaultChatCompletionResponse); + debug("Response body:", responseBody); + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(responseBody); + } } } else if (req.method === "POST" && req.url === "/v1/embeddings") { - console.log( - `Mock server received /v1/embeddings request: ${bodyString}`, - ); + debug("Handling /v1/embeddings request."); embeddingsRequests = [...embeddingsRequests, recordedRequest]; if ( parsedBody && 
"input" in parsedBody && Array.isArray(parsedBody.input) ) { + debug("Valid embeddings request body."); const inputs: string[] = parsedBody.input; const responseData = inputs.map((input, index) => ({ object: "embedding", index: index, - // Produce a dummy embedding as the number of occurences of the first ten - // letters of the alphabet. embedding: Array.from( { length: 10 }, (_, c) => input.split(String.fromCharCode(97 + c)).length - 1, @@ -146,18 +290,25 @@ export function createMockHttpServer(): MockHttpServer { data: responseData, }; res.writeHead(200); - res.end(JSON.stringify(response)); + const responseBody = JSON.stringify(response); + debug("Response body:", responseBody); + res.end(responseBody); } else { + debug("Invalid embeddings request body."); res.writeHead(400); - res.end(JSON.stringify({ error: "Invalid request body" })); + const responseBody = JSON.stringify({ + error: "Invalid request body", + }); + debug("Response body:", responseBody); + res.end(responseBody); } } else { - console.log( - `Mock server received unhandled request: ${req.method} ${req.url}`, - ); + debug(`Handling unhandled request: ${req.method} ${req.url}`); otherRequests = [...otherRequests, recordedRequest]; res.writeHead(404); - res.end(JSON.stringify({ error: "Not Found" })); + const responseBody = JSON.stringify({ error: "Not Found" }); + debug("Response body:", responseBody); + res.end(responseBody); } }); }); @@ -166,7 +317,7 @@ export function createMockHttpServer(): MockHttpServer { const address = server.address() as AddressInfo; const serverUrl = `http://localhost:${address.port}`; - console.log(`Mock HTTP server listening on ${serverUrl}`); + debug(`HTTP server listening on ${serverUrl}`); return { server, @@ -186,7 +337,7 @@ export function createMockHttpServer(): MockHttpServer { if (err) { reject(err); } else { - console.log(`Mock HTTP server on port ${address.port} closed.`); + debug(`HTTP server on port ${address.port} closed.`); resolve(); } }); diff --git a/packages/ai/test/test-setup.ts b/packages/ai/test/test-setup.ts new file mode 100644 index 000000000..b593aa47b --- /dev/null +++ b/packages/ai/test/test-setup.ts @@ -0,0 +1,56 @@ +import type { Client } from "gel"; +import { getClient } from "@repo/test-utils"; +import { createMockHttpServer, type MockHttpServer } from "./mockHttpServer"; + +export async function setupTestEnvironment(): Promise<{ + mockServer: MockHttpServer; + client: Client; +}> { + const mockServer = createMockHttpServer(); + + const client = getClient({ + tlsSecurity: "insecure", + }); + + await client.ensureConnected(); + await client.execute(` +reset schema to initial; +create extension pgvector; +create extension ai; + +create type TestEmbeddingModel extending ext::ai::EmbeddingModel { + alter annotation ext::ai::model_name := "text-embedding-test"; + alter annotation ext::ai::model_provider := "custom::test"; + alter annotation ext::ai::embedding_model_max_input_tokens := "8191"; + alter annotation ext::ai::embedding_model_max_batch_tokens := "16384"; + alter annotation ext::ai::embedding_model_max_output_dimensions := "10"; + alter annotation ext::ai::embedding_model_supports_shortening := "true"; +}; + +create type TestTextGenerationModel extending ext::ai::TextGenerationModel { + alter annotation ext::ai::model_name := "text-generation-test"; + alter annotation ext::ai::model_provider := "custom::test"; + alter annotation ext::ai::text_gen_model_context_window := "16385"; +}; + +create type Astronomy { + create required property content: str; + + 
create deferred index ext::ai::index(embedding_model := "text-embedding-test") on (.content); +}; + +configure current branch insert ext::ai::CustomProviderConfig { + name := "custom::test", + secret := "dummy-key", + api_url := "${mockServer.url}/v1", + api_style := ext::ai::ProviderAPIStyle.OpenAI, +}; + +configure current branch set ext::ai::Config::indexer_naptime := "100ms"; + `); + + return { + mockServer, + client, + }; +} diff --git a/yarn.lock b/yarn.lock index 1d4f34b54..5665e3636 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1212,6 +1212,18 @@ resolved "https://registry.npmjs.org/@eslint/js/-/js-9.3.0.tgz" integrity sha512-niBqk8iwv96+yuTwjM6bWg8ovzAPF9qkICsGtcoa5/dmqcEMfdwNAX7+/OHcJHc7wj7XqPxH98oAHytFYlw6Sw== +"@gel/create@^0.3.0-rc": + version "0.3.2" + resolved "https://registry.yarnpkg.com/@gel/create/-/create-0.3.2.tgz#72b8f60b33ae5d9568e88fc3f36b6027988fe7f8" + integrity sha512-FNBhGMlK+hdvS0mIUt8GUD9nGJyg5zFe4/sEcGEdWW58ldAU6FBjo7AFxGoQiBB9wQUQwXq13YdPyY99sA7trw== + dependencies: + "@clack/prompts" "^0.7.0" + debug "^4.3.4" + picocolors "^1.0.1" + read-pkg "^9.0.1" + shell-quote "^1.8.2" + write-package "^7.0.1" + "@humanwhocodes/config-array@^0.13.0": version "0.13.0" resolved "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.13.0.tgz" @@ -3034,6 +3046,13 @@ debug@^4.3.7: dependencies: ms "^2.1.3" +debug@^4.4.1: + version "4.4.1" + resolved "https://registry.yarnpkg.com/debug/-/debug-4.4.1.tgz#e5a8bc6cbc4c6cd3e64308b0693a3d4fa550189b" + integrity sha512-KcKCqiftBJcZr++7ykoDIEwSa3XWowTfNPo92BYxjXiyYEVrUQh2aLyhxBCwww+heortUFxEJYcRzosstTEBYQ== + dependencies: + ms "^2.1.3" + decimal.js@^10.4.2: version "10.4.3" resolved "https://registry.npmjs.org/decimal.js/-/decimal.js-10.4.3.tgz" @@ -5899,7 +5918,16 @@ string-length@^4.0.1: char-regex "^1.0.2" strip-ansi "^6.0.0" -"string-width-cjs@npm:string-width@^4.2.0", string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.3: +"string-width-cjs@npm:string-width@^4.2.0": + version "4.2.3" + resolved "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz" + integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== + dependencies: + emoji-regex "^8.0.0" + is-fullwidth-code-point "^3.0.0" + strip-ansi "^6.0.1" + +string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.3: version "4.2.3" resolved "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz" integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== @@ -5924,7 +5952,14 @@ string_decoder@^1.1.1: dependencies: safe-buffer "~5.2.0" -"strip-ansi-cjs@npm:strip-ansi@^6.0.1", strip-ansi@^6.0.0, strip-ansi@^6.0.1: +"strip-ansi-cjs@npm:strip-ansi@^6.0.1": + version "6.0.1" + resolved "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz" + integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== + dependencies: + ansi-regex "^5.0.1" + +strip-ansi@^6.0.0, strip-ansi@^6.0.1: version "6.0.1" resolved "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz" integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== @@ -6528,8 +6563,16 @@ word-wrap@^1.2.5: resolved "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.5.tgz" integrity sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA== -"wrap-ansi-cjs@npm:wrap-ansi@^7.0.0", wrap-ansi@^7.0.0: - name wrap-ansi-cjs 
+"wrap-ansi-cjs@npm:wrap-ansi@^7.0.0": + version "7.0.0" + resolved "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz" + integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== + dependencies: + ansi-styles "^4.0.0" + string-width "^4.1.0" + strip-ansi "^6.0.0" + +wrap-ansi@^7.0.0: version "7.0.0" resolved "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz" integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==