Skip to content

Commit a512380

Browse files
authored
feat(langgraph): add typedNode utility (#1235)
2 parents 07a8594 + d3a3da5 commit a512380

File tree

3 files changed

+143
-6
lines changed

3 files changed

+143
-6
lines changed

libs/langgraph/src/graph/graph.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,11 @@ export type NodeSpec<RunInput, RunOutput> = {
168168
defer?: boolean;
169169
};
170170

171-
export type AddNodeOptions = {
171+
export type AddNodeOptions<Nodes extends string = string> = {
172172
metadata?: Record<string, unknown>;
173173
// eslint-disable-next-line @typescript-eslint/no-explicit-any
174174
subgraphs?: Pregel<any, any>[];
175-
ends?: string[];
175+
ends?: Nodes[];
176176
defer?: boolean;
177177
};
178178

libs/langgraph/src/graph/state.ts

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,13 @@ export type StateGraphNodeSpec<RunInput, RunOutput> = NodeSpec<
8888
cachePolicy?: CachePolicy;
8989
};
9090

91-
export type StateGraphAddNodeOptions = {
91+
export type StateGraphAddNodeOptions<Nodes extends string = string> = {
9292
retryPolicy?: RetryPolicy;
9393
cachePolicy?: CachePolicy | boolean;
9494
// TODO: Fix generic typing for annotations
9595
// eslint-disable-next-line @typescript-eslint/no-explicit-any
9696
input?: AnnotationRoot<any> | AnyZodObject;
97-
} & AddNodeOptions;
97+
} & AddNodeOptions<Nodes>;
9898

9999
export type StateGraphArgsWithStateSchema<
100100
SD extends StateDefinition,
@@ -446,7 +446,12 @@ export class StateGraph<
446446
isMultipleNodes(args) // eslint-disable-line no-nested-ternary
447447
? Array.isArray(args[0])
448448
? args[0]
449-
: Object.entries(args[0])
449+
: Object.entries(args[0]).map(([key, action]) => [
450+
key,
451+
action,
452+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
453+
(action as any)[Symbol.for("langgraph.state.node")] ?? undefined,
454+
])
450455
: [[args[0], args[1], args[2]]]
451456
) as [
452457
K,
@@ -594,7 +599,12 @@ export class StateGraph<
594599
): StateGraph<SD, S, U, N | K, I, O, C> {
595600
const parsedNodes = Array.isArray(nodes)
596601
? nodes
597-
: (Object.entries(nodes) as [K, NodeAction<S, U, C>][]);
602+
: (Object.entries(nodes).map(([key, action]) => [
603+
key,
604+
action,
605+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
606+
(action as any)[Symbol.for("langgraph.state.node")] ?? undefined,
607+
]) as [K, NodeAction<S, U, C>, StateGraphAddNodeOptions | undefined][]);
598608

599609
if (parsedNodes.length === 0) {
600610
throw new Error("Sequence requires at least one node.");
@@ -1125,3 +1135,41 @@ function _getControlBranch() {
11251135
path: CONTROL_BRANCH_PATH,
11261136
});
11271137
}
1138+
1139+
type TypedNodeAction<SD extends StateDefinition, Nodes extends string> = (
1140+
state: StateType<SD>,
1141+
config: LangGraphRunnableConfig
1142+
) => UpdateType<SD> | Command<unknown, UpdateType<SD>, Nodes>;
1143+
1144+
export function typedNode<SD extends SDZod, Nodes extends string>(
1145+
_state: SD extends StateDefinition ? AnnotationRoot<SD> : never,
1146+
_options?: { nodes?: Nodes[] }
1147+
): (
1148+
func: TypedNodeAction<ToStateDefinition<SD>, Nodes>,
1149+
options?: StateGraphAddNodeOptions<Nodes>
1150+
) => TypedNodeAction<ToStateDefinition<SD>, Nodes>;
1151+
1152+
export function typedNode<SD extends SDZod, Nodes extends string>(
1153+
_state: SD extends AnyZodObject ? SD : never,
1154+
_options?: { nodes?: Nodes[] }
1155+
): (
1156+
func: TypedNodeAction<ToStateDefinition<SD>, Nodes>,
1157+
options?: StateGraphAddNodeOptions<Nodes>
1158+
) => TypedNodeAction<ToStateDefinition<SD>, Nodes>;
1159+
1160+
export function typedNode<SD extends SDZod, Nodes extends string>(
1161+
_state: SD extends AnyZodObject
1162+
? SD
1163+
: SD extends StateDefinition
1164+
? AnnotationRoot<SD>
1165+
: never,
1166+
_options?: { nodes?: Nodes[] }
1167+
) {
1168+
return (
1169+
func: TypedNodeAction<ToStateDefinition<SD>, Nodes>,
1170+
options?: StateGraphAddNodeOptions<Nodes>
1171+
) => {
1172+
Object.assign(func, { [Symbol.for("langgraph.state.node")]: options });
1173+
return func;
1174+
};
1175+
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import { z } from "zod";
2+
import { Command } from "../constants.js";
3+
import { Annotation } from "../graph/annotation.js";
4+
import {
5+
MessagesAnnotation,
6+
MessagesZodState,
7+
} from "../graph/messages_annotation.js";
8+
import { StateGraph, typedNode } from "../graph/state.js";
9+
import { _AnyIdHumanMessage } from "./utils.js";
10+
11+
it("Annotation.Root", async () => {
12+
const StateAnnotation = Annotation.Root({
13+
messages: MessagesAnnotation.spec.messages,
14+
foo: Annotation<string>,
15+
});
16+
17+
const node = typedNode(StateAnnotation, {
18+
nodes: ["nodeA", "nodeB", "nodeC"],
19+
});
20+
21+
const nodeA = node(
22+
(state) => {
23+
const goto = state.foo === "foo" ? "nodeB" : "nodeC";
24+
return new Command({
25+
update: { messages: [{ type: "user", content: "a" }], foo: "a" },
26+
goto,
27+
});
28+
},
29+
{ ends: ["nodeB", "nodeC"] }
30+
);
31+
32+
const nodeB = node(() => {
33+
return new Command({
34+
goto: "nodeC",
35+
update: { foo: "123" },
36+
});
37+
});
38+
const nodeC = node((state) => ({ foo: `${state.foo}|c` }));
39+
40+
const graph = new StateGraph(StateAnnotation)
41+
.addNode({ nodeA, nodeB, nodeC })
42+
.addEdge("__start__", "nodeA")
43+
.compile();
44+
45+
expect(await graph.invoke({ foo: "foo" })).toEqual({
46+
messages: [new _AnyIdHumanMessage("a")],
47+
foo: "123|c",
48+
});
49+
});
50+
51+
it("Zod", async () => {
52+
const StateAnnotation = MessagesZodState.extend({
53+
foo: z.string(),
54+
});
55+
56+
const node = typedNode(StateAnnotation, {
57+
nodes: ["nodeA", "nodeB", "nodeC"],
58+
});
59+
60+
const nodeA = node(
61+
(state) => {
62+
const goto = state.foo === "foo" ? "nodeB" : "nodeC";
63+
return new Command({
64+
update: { messages: [{ type: "user", content: "a" }], foo: "a" },
65+
goto,
66+
});
67+
},
68+
{ ends: ["nodeB", "nodeC"] }
69+
);
70+
71+
const nodeB = node(() => {
72+
return new Command({
73+
goto: "nodeC",
74+
update: { foo: "123" },
75+
});
76+
});
77+
78+
const nodeC = node((state) => ({ foo: `${state.foo}|c` }));
79+
80+
const graph = new StateGraph(StateAnnotation)
81+
.addNode({ nodeA, nodeB, nodeC })
82+
.addEdge("__start__", "nodeA")
83+
.compile();
84+
85+
expect(await graph.invoke({ foo: "foo" })).toEqual({
86+
messages: [new _AnyIdHumanMessage("a")],
87+
foo: "123|c",
88+
});
89+
});

0 commit comments

Comments
 (0)