From 31f712f0e4e9fbbbdf292772b68c4efd714c8638 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Vojt=C4=9Bch=20Tranta?=
Date: Mon, 28 Oct 2024 11:56:07 +0100
Subject: [PATCH] add parallelToolCalls option and pass it through to the API

---
 src/LangtailPrompts.spec.ts | 36 ++++++++++++++++++++++++++++++++++++
 src/LangtailPrompts.ts      |  1 +
 src/getOpenAIBody.test.ts   | 33 +++++++++++++++++++++++++++++++++
 src/getOpenAIBody.ts        |  3 +++
 src/schemas.ts              |  2 ++
 5 files changed, 75 insertions(+)

diff --git a/src/LangtailPrompts.spec.ts b/src/LangtailPrompts.spec.ts
index 001a6b6..07c8ab5 100644
--- a/src/LangtailPrompts.spec.ts
+++ b/src/LangtailPrompts.spec.ts
@@ -245,6 +245,42 @@ describe.skipIf(!liveTesting)(
 
 describe("LangtailPrompts", () => {
   describe("invoke with optional callbacks", () => {
+
+    it("should pass parallelToolCalls param to fetch", async () => {
+      const mockFetch = vi.fn().mockResolvedValue({
+        ok: true,
+        status: 200,
+        json: async () => ({
+          choices: [
+            {
+              message: {
+                content: "Test response",
+              },
+            },
+          ],
+        }),
+        headers: new Headers({
+          'Content-Type': 'application/json',
+          'X-API-Key': 'test-api-key',
+          'x-langtail-thread-id': 'test-thread-id'
+        }),
+      });
+
+      const lt = new LangtailPrompts({
+        apiKey: "test-api-key",
+        fetch: mockFetch,
+      });
+
+      await lt.invoke({
+        prompt: "test-prompt",
+        environment: "production",
+        parallelToolCalls: true,
+      });
+
+      expect(mockFetch).toHaveBeenCalled();
+      const body = JSON.parse(mockFetch.mock.calls[0][1].body);
+      expect(body).toHaveProperty('parallelToolCalls', true);
+    });
     it("should trigger onRawResponse callback when response is returned", async () => {
       const mockFetch = vi.fn().mockResolvedValue({
         ok: true,
diff --git a/src/LangtailPrompts.ts b/src/LangtailPrompts.ts
index 9f1a7d6..b0bc600 100644
--- a/src/LangtailPrompts.ts
+++ b/src/LangtailPrompts.ts
@@ -45,6 +45,7 @@ export type IRequestParams<P extends PromptSlug, E extends Environment<P> = unde
 
 export type IRequestParamsStream<P extends PromptSlug, E extends Environment<P> = undefined, V extends Version<P, E> = undefined, S extends boolean | undefined = false> = IRequestParams<P, E, V> & {
   stream?: S
+  parallelToolCalls?: boolean
 }
 
 export type IInvokeOptionalCallbacks = {
diff --git a/src/getOpenAIBody.test.ts b/src/getOpenAIBody.test.ts
index 6c7d7d2..a6f73a1 100644
--- a/src/getOpenAIBody.test.ts
+++ b/src/getOpenAIBody.test.ts
@@ -117,6 +117,39 @@ describe("getOpenAIBody", () => {
 
     `)
   })
+
+  it("should add parallel_tool_calls param when it is set in parsedBody", () => {
+    const completionConfig = {
+      state: {
+        type: "chat" as const,
+        args: {
+          model: "gpt-3.5-turbo",
+          max_tokens: 100,
+          temperature: 0.8,
+          top_p: 1,
+          presence_penalty: 0,
+          frequency_penalty: 0,
+          jsonmode: false,
+          seed: 123,
+          stop: [],
+        },
+        template: [
+          {
+            role: "system" as const,
+            content: "tell me a story",
+          },
+        ],
+      },
+      chatInput: {},
+    }
+
+    const openAIbody = getOpenAIBody(completionConfig, {
+      parallelToolCalls: true,
+    })
+
+    expect(openAIbody).toHaveProperty('parallel_tool_calls', true)
+  })
+
   it("should override parameters from the playground with the ones in parsedBody", () => {
     const completionConfig = {
       state: {
diff --git a/src/getOpenAIBody.ts b/src/getOpenAIBody.ts
index ffe7010..332e62b 100644
--- a/src/getOpenAIBody.ts
+++ b/src/getOpenAIBody.ts
@@ -52,6 +52,9 @@ export function getOpenAIBody(
     temperature: parsedBody.temperature ?? completionArgs.temperature,
     messages: inputMessages,
     top_p: parsedBody.top_p ?? completionArgs.top_p,
+    ...(parsedBody.parallelToolCalls !== undefined
+      ? { parallel_tool_calls: parsedBody.parallelToolCalls }
+      : {}),
     presence_penalty:
       parsedBody.presence_penalty ?? completionArgs.presence_penalty,
     frequency_penalty:
diff --git a/src/schemas.ts b/src/schemas.ts
index 711489e..f9440f5 100644
--- a/src/schemas.ts
+++ b/src/schemas.ts
@@ -213,6 +213,7 @@ export const bodyMetadataSchema = z
 
 export const langtailBodySchema = z.object({
   doNotRecord: z.boolean().optional(),
+  parallelToolCalls: z.boolean().optional(),
   metadata: bodyMetadataSchema,
   _langtailTestRunId: z.string().optional(),
   _langtailTestInputId: z.string().optional(),
@@ -226,6 +227,7 @@ export const openAIBodySchemaObjectDefinition = {
   max_tokens: z.number().optional(),
   temperature: z.number().optional(),
   top_p: z.number().optional(),
+  parallel_tool_calls: z.boolean().optional(),
   presence_penalty: z.number().optional(),
   frequency_penalty: z.number().optional(),
   model: z.string().optional(),
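--
Usage note (a sketch, not part of the patch): with this change a caller can
control OpenAI's parallel tool calling per request. The SDK accepts the
camelCase parallelToolCalls option, sends it to the Langtail API as-is (which
is what the spec test asserts), and getOpenAIBody maps it to OpenAI's
snake_case parallel_tool_calls, forwarding it only when it is explicitly set,
so omitting the option leaves existing request bodies untouched. A minimal
sketch of the call site follows; the prompt slug and environment variable are
illustrative, not from the patch:

    import { LangtailPrompts } from "./LangtailPrompts"

    const lt = new LangtailPrompts({
      apiKey: process.env.LANGTAIL_API_KEY!, // assumed to be set
    })

    // parallelToolCalls: false asks the model to make at most one tool
    // call per assistant turn; leaving it undefined keeps the provider
    // default, because undefined values are never forwarded.
    const completion = await lt.invoke({
      prompt: "weather-bot", // hypothetical prompt slug
      environment: "production",
      parallelToolCalls: false,
    })

    console.log(completion.choices[0]?.message?.content)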