Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 201 additions & 14 deletions packages/ai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@ The following example demonstrates how to use the `@gel/ai` package to query an
import { createClient } from "gel";
import { createRAGClient } from "@gel/ai";

const client = createRAGClient({
instanceName: "_localdev",
database: "main",
tlsSecurity: "insecure",
});
const client = createClient();

const gpt4Rag = createRAGClient(client, {
model: "gpt-4-turbo",
Expand All @@ -35,33 +31,46 @@ const gpt4Rag = createRAGClient(client, {
const astronomyRag = gpt4Rag.withContext({ query: "Astronomy" });

console.time("gpt-4 Time");
console.log(await astronomyRag.queryRag({ prompt: "What color is the sky on Mars?" }));
console.log(
(await astronomyRag.queryRag({ prompt: "What color is the sky on Mars?" }))
.content,
);
console.timeEnd("gpt-4 Time");

const fastAstronomyRag = astronomyRag.withConfig({
model: "gpt-4o",
});

console.time("gpt-4o Time");
console.log(await fastAstronomyRag.queryRag({ prompt: "What color is the sky on Mars?" }));
console.log(
(
await fastAstronomyRag.queryRag({
prompt: "What color is the sky on Mars?",
})
).content,
);
console.timeEnd("gpt-4o Time");

const fastChemistryRag = fastAstronomyRag.withContext({ query: "Chemistry" });

console.log(
await fastChemistryRag.queryRag({ prompt: "What is the atomic number of gold?" }),
(
await fastChemistryRag.queryRag({
prompt: "What is the atomic number of gold?",
})
).content,
);

// handle the Response object
const response = await fastChemistryRag.streamRag(
{ prompt: "What is the atomic number of gold?" },
);
const response = await fastChemistryRag.streamRag({
prompt: "What is the atomic number of gold?",
});
handleReadableStream(response); // custom function that reads the stream

// handle individual chunks as they arrive
for await (const chunk of fastChemistryRag.streamRag(
{ prompt: "What is the atomic number of gold?" },
)) {
for await (const chunk of fastChemistryRag.streamRag({
prompt: "What is the atomic number of gold?",
})) {
console.log("chunk", chunk);
}

Expand All @@ -73,3 +82,181 @@ console.log(
}),
);
```

## Tool Calls

The `@gel/ai` package supports tool calls, allowing you to extend the capabilities of the AI model with your own functions. Here's how to use them:

1. **Define your tools**: Create an array of `ToolDefinition` objects that describe your functions, their parameters, and what they do.
2. **Send the request**: Call `queryRag` or `streamRag` with the user's prompt and the `tools` array. You can also use the `tool_choice` parameter to control how the model uses your tools.
3. **Handle the tool call**: If the model decides to use a tool, it will return an `AssistantMessage` with a `tool_calls` array. Your code needs to:
1. Parse the `tool_calls` array to identify the tool and its arguments.
2. Execute the tool and get the result.
3. Create a `ToolMessage` with the result.
4. Send the `ToolMessage` back to the model in a new request.
4. **Receive the final response**: The model will use the tool's output to generate a final response.

### Example

```typescript
import type {
Message,
ToolDefinition,
UserMessage,
ToolMessage,
AssistantMessage,
} from "@gel/ai";

// 1. Define your tools
// Each tool is described with a JSON-Schema-style `parameters` object so the
// model knows which arguments it may pass and which are mandatory.
const tools: ToolDefinition[] = [
  {
    type: "function",
    name: "get_weather",
    description: "Get the current weather for a given city.",
    parameters: {
      type: "object",
      properties: {
        city: {
          type: "string",
          description: "The city to get the weather for.",
        },
      },
      // The model must always supply `city`.
      required: ["city"],
    },
  },
];

// 2. Send the request
// Build the conversation history.
// NOTE(review): this example assumes a `ragClient` created via
// createRAGClient(...) as shown earlier in this README.
const userMessage: UserMessage = {
  role: "user",
  content: [{ type: "text", text: "What's the weather like in London?" }],
};

const messages: Message[] = [userMessage];

// `tool_choice: "auto"` lets the model decide whether to call a tool.
const response = await ragClient.queryRag({
  messages,
  tools,
  tool_choice: "auto",
});

// 3. Handle the tool call
if (response.tool_calls) {
  // This example only handles the first tool call; a model may return several.
  const toolCall = response.tool_calls[0];
  if (toolCall.function.name === "get_weather") {
    // Arguments arrive as a JSON string and must be parsed before use.
    const args = JSON.parse(toolCall.function.arguments);
    const weather = await getWeather(args.city); // Your function to get the weather

    // Tool results go back as a "tool" role message referencing the call id.
    const toolMessage: ToolMessage = {
      role: "tool",
      tool_call_id: toolCall.id,
      content: JSON.stringify({ weather }),
    };

    // Add the assistant's response and the tool message to the history
    messages.push(response);
    messages.push(toolMessage);

    // 4. Send the tool result back to the model
    const finalResponse = await ragClient.queryRag({
      messages,
      tools,
    });

    console.log(finalResponse.content);
  }
} else {
  // No tool call was made; the model answered directly.
  console.log(response.content);
}

// Dummy function for the example
// Hard-coded stand-in for a real weather-service call, used by the examples.
async function getWeather(city: string): Promise<string> {
  const condition = "sunny";
  return "The weather in " + city + " is " + condition + ".";
}
```

### Streaming Responses

When using `streamRag`, you can handle tool calls as they arrive in the stream. The process is similar to the `queryRag` example, but you'll need to accumulate the streamed chunks in order to reconstruct each tool call's name and arguments before executing it.

```typescript
// Function to handle the streaming response
/**
 * Drives one model turn over a streamed RAG response, accumulating any tool
 * calls chunk by chunk. When the turn ends with pending tool calls, the tools
 * are executed and the function recurses with the extended message history so
 * the model can produce its final answer.
 *
 * NOTE(review): assumes `ragClient`, `tools`, and `getWeather` from the
 * surrounding examples are in scope.
 */
async function handleStreamingResponse(initialMessages: Message[]) {
  const stream = ragClient.streamRag({
    messages: initialMessages,
    tools,
    tool_choice: "auto",
  });

  // Tool calls completed so far in this turn, plus the one currently streaming.
  let toolCalls: { id: string; name: string; arguments: string }[] = [];
  let currentToolCall: { id: string; name: string; arguments: string } | null =
    null;

  for await (const chunk of stream) {
    if (
      chunk.type === "content_block_start" &&
      chunk.content_block.type === "tool_use"
    ) {
      // A new tool call begins; its arguments may continue in later deltas.
      currentToolCall = {
        id: chunk.content_block.id!,
        name: chunk.content_block.name,
        arguments: chunk.content_block.args,
      };
    } else if (
      chunk.type === "content_block_delta" &&
      chunk.delta.type === "tool_call_delta"
    ) {
      if (currentToolCall) {
        // Argument JSON arrives in fragments; concatenate until the block stops.
        currentToolCall.arguments += chunk.delta.args;
      }
    } else if (chunk.type === "content_block_stop") {
      if (currentToolCall) {
        toolCalls.push(currentToolCall);
        currentToolCall = null;
      }
    } else if (chunk.type === "message_stop") {
      // The model has finished its turn
      if (toolCalls.length > 0) {
        // Echo the tool calls back as an assistant message so the model can
        // match each tool result to its own request by id.
        const assistantMessage: AssistantMessage = {
          role: "assistant",
          content: null,
          tool_calls: toolCalls.map((tc) => ({
            id: tc.id,
            type: "function",
            function: { name: tc.name, arguments: tc.arguments },
          })),
        };

        // Execute all requested tools in parallel.
        const toolMessages: ToolMessage[] = await Promise.all(
          toolCalls.map(async (tc) => {
            const args = JSON.parse(tc.arguments);
            const weather = await getWeather(args.city); // Your function to get the weather
            return {
              role: "tool",
              tool_call_id: tc.id,
              content: JSON.stringify({ weather }),
            };
          }),
        );

        const newMessages: Message[] = [
          ...initialMessages,
          assistantMessage,
          ...toolMessages,
        ];

        // Call the function again to get the final response
        await handleStreamingResponse(newMessages);
      }
    } else if (
      chunk.type === "content_block_delta" &&
      chunk.delta.type === "text_delta"
    ) {
      // Handle text responses from the model
      process.stdout.write(chunk.delta.text);
    }
  }
}

// Kick off the conversation. Catch rejections so a stream failure doesn't
// terminate the process with an unhandled promise rejection.
handleStreamingResponse(messages).catch(console.error);
```
3 changes: 3 additions & 0 deletions packages/ai/jest.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ export const JS_EXT_TO_TREAT_AS_ESM = [".jsx"];
export const ESM_TS_JS_TRANSFORM_PATTERN = "^.+\\.m?[tj]sx?$";

export default {
maxWorkers: 1,
maxConcurrency: 1,
testTimeout: 30_000,
testEnvironment: "node",
testPathIgnorePatterns: ["./dist"],
globalSetup: "./test/globalSetup.ts",
Expand Down
1 change: 1 addition & 0 deletions packages/ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"@repo/tsconfig": "*",
"@types/jest": "^29.5.12",
"@types/node": "^20.12.13",
"debug": "^4.4.1",
"gel": "^2.0.0",
"jest": "29.7.0",
"ts-jest": "^29.1.4",
Expand Down
54 changes: 32 additions & 22 deletions packages/ai/src/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
type RagRequest,
type EmbeddingRequest,
isPromptRequest,
type AssistantMessage,
} from "./types.js";
import { getHTTPSCRAMAuth } from "gel/dist/httpScram.js";
import { cryptoUtils } from "gel/dist/browserCrypto.js";
Expand Down Expand Up @@ -124,7 +125,10 @@ export class RAGClient {
return response;
}

async queryRag(request: RagRequest, context = this.context): Promise<string> {
async queryRag(
request: RagRequest,
context = this.context,
): Promise<AssistantMessage> {
const res = await this.fetchRag(
{
...request,
Expand Down Expand Up @@ -217,42 +221,48 @@ export class RAGClient {
}
}

async function parseRagResponse(response: Response): Promise<string> {
async function parseRagResponse(response: Response): Promise<AssistantMessage> {
if (!response.headers.get("content-type")?.includes("application/json")) {
throw new Error("Expected response to have content-type: application/json");
}

const data: unknown = await response.json();

if (!data) {
if (!data || typeof data !== "object") {
throw new Error(`Expected JSON data, but got ${JSON.stringify(data)}`);
}

if (typeof data !== "object") {
throw new Error(
`Expected response to be an object, but got ${JSON.stringify(data)}`,
);
// Handle the new tool call format from the AI extension
if ("tool_calls" in data && Array.isArray(data.tool_calls)) {
return {
role: "assistant",
content: "text" in data ? (data.text as string) : null,
tool_calls: data.tool_calls.map((tc: any) => ({
id: tc.id,
type: tc.type,
function: {
name: tc.name,
arguments: JSON.stringify(tc.args),
},
})),
};
}

if ("text" in data) {
if (typeof data.text !== "string") {
throw new Error(
`Expected data.text to be a string, but got ${typeof data.text}: ${JSON.stringify(data.text)}`,
);
}
return data.text;
if ("role" in data && data.role === "assistant") {
return data as AssistantMessage;
}

if ("response" in data) {
if (typeof data.response !== "string") {
throw new Error(
`Expected data.response to be a string, but got ${typeof data.response}: ${JSON.stringify(data.response)}`,
);
}
return data.response;
if ("text" in data && typeof data.text === "string") {
return { role: "assistant", content: data.text };
}

if ("response" in data && typeof data.response === "string") {
return { role: "assistant", content: data.response };
}

throw new Error(
`Expected response to include a non-empty string for either the 'text' or 'response' key, but got: ${JSON.stringify(data)}`,
`Expected response to be a message object or include a 'text' or 'response' key, but got: ${JSON.stringify(
data,
)}`,
);
}
Loading