-
Notifications
You must be signed in to change notification settings - Fork 8
/
models.ts
154 lines (142 loc) · 3.76 KB
/
models.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import s from 'dedent'
import { zodResponseFormat } from 'openai/helpers/zod.mjs'
import { ParsedFunctionToolCall } from 'openai/resources/beta/chat/completions'
import { FunctionParameters } from 'openai/resources/shared.js'
import { z, ZodObject } from 'zod'
import { zodToJsonSchema } from 'zod-to-json-schema'
import { Message } from './messages.js'
import { Tool } from './tool.js'
/**
* Resposne format for LLM calls is an object of Zod schemas.
*
* The object is then converted either to discriminated union or tool calls,
* depending on the provider.
*/
type LLMResponseFormat = Record<string, z.ZodObject<any>>
type LLMCall<T extends LLMResponseFormat> = {
messages: Message[]
response_format: T
temperature?: number
}
type LLMCallWithTools<T extends LLMResponseFormat> = LLMCall<T> & {
tools: Record<string, Tool>
}
type LLMResponse<T extends LLMResponseFormat> = {
[K in keyof T]: {
type: K
value: z.infer<T[K]>
}
}[keyof T]
type FunctionToolCall = {
type: 'tool_call'
value: ParsedFunctionToolCall[]
}
/**
* Interface for LLM providers
*/
export interface Provider {
/**
* Method for chat-like interactions with LLM.
*
* For response_format such as:
* ```
* { a: z.object({ b: z.string() }) }
* ```
* the return type is:
* ```
* { type: "a", value: { b: string } }
* ```
*
* If you pass tools, the return type is further extended with:
* ```
* { type: 'tool_call', value: ParsedFunctionToolCall[] }
* ```
*/
chat<T extends LLMResponseFormat>(
args: LLMCallWithTools<T>
): Promise<FunctionToolCall | LLMResponse<T>>
chat<T extends LLMResponseFormat>(args: LLMCall<T>): Promise<LLMResponse<T>>
}
/**
* Converts an object of tools to OpenAI tools format.
*/
export const toLLMTools = (tools: Record<string, Tool>, strict: boolean = true) => {
return Object.entries(tools).map(([name, tool]) => ({
type: 'function' as const,
function: {
name,
parameters: zodToJsonSchema(tool.parameters),
description: tool.description,
strict,
},
}))
}
/**
* Converts an object such as
* ```
* { a: z.object({ b: z.string() }) }
* ```
* to OpenAI structured output.
*
* Each key in the union is converted to a discriminated union, such as:
* ```
* z.discriminatedUnion('type', [
* z.object({ type: z.literal('a'), value: z.object({ b: z.string() }) }),
* ])
* ```
*/
export const responseAsStructuredOutput = (response_format: Record<string, any>) => {
const [first, ...rest] = Object.entries(response_format)
return zodResponseFormat(
z.object({
response: z.discriminatedUnion('type', [entryToObject(first), ...rest.map(entryToObject)]),
}),
'task_result'
)
}
const entryToObject = ([key, value]: [string, ZodObject<any>]) => {
return z.object({ type: z.literal(key), value })
}
/**
* Converts an object such as
* ```
* { a: z.object({ b: z.string() }) }
* ```
* to a list of tool calls such as
* ```
* [
* { type: 'function', function: { name: 'a', parameters: { b: z.string() } } },
* ]
* ```
*/
export const responseAsToolCall = (
response_format: Record<string, any>,
strict: boolean = true
) => {
return Object.entries(response_format).map(([name, response]) => {
const schema: FunctionToolSchema = {
type: 'function' as const,
function: {
name,
parameters: zodToJsonSchema(response),
description: s`
Call this function when you are done processing user request
and want to return "${name}" as the result.
`,
},
}
if (strict) {
schema.function.strict = strict
}
return schema
})
}
type FunctionToolSchema = {
type: 'function'
function: {
name: string
parameters: FunctionParameters
description: string
strict?: boolean
}
}