@@ -88,13 +88,13 @@ export type StateGraphNodeSpec<RunInput, RunOutput> = NodeSpec<
88
88
cachePolicy ?: CachePolicy ;
89
89
} ;
90
90
91
- export type StateGraphAddNodeOptions = {
91
+ export type StateGraphAddNodeOptions < Nodes extends string = string > = {
92
92
retryPolicy ?: RetryPolicy ;
93
93
cachePolicy ?: CachePolicy | boolean ;
94
94
// TODO: Fix generic typing for annotations
95
95
// eslint-disable-next-line @typescript-eslint/no-explicit-any
96
96
input ?: AnnotationRoot < any > | AnyZodObject ;
97
- } & AddNodeOptions ;
97
+ } & AddNodeOptions < Nodes > ;
98
98
99
99
export type StateGraphArgsWithStateSchema <
100
100
SD extends StateDefinition ,
@@ -446,7 +446,12 @@ export class StateGraph<
446
446
isMultipleNodes ( args ) // eslint-disable-line no-nested-ternary
447
447
? Array . isArray ( args [ 0 ] )
448
448
? args [ 0 ]
449
- : Object . entries ( args [ 0 ] )
449
+ : Object . entries ( args [ 0 ] ) . map ( ( [ key , action ] ) => [
450
+ key ,
451
+ action ,
452
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
453
+ ( action as any ) [ Symbol . for ( "langgraph.state.node" ) ] ?? undefined ,
454
+ ] )
450
455
: [ [ args [ 0 ] , args [ 1 ] , args [ 2 ] ] ]
451
456
) as [
452
457
K ,
@@ -594,7 +599,12 @@ export class StateGraph<
594
599
) : StateGraph < SD , S , U , N | K , I , O , C > {
595
600
const parsedNodes = Array . isArray ( nodes )
596
601
? nodes
597
- : ( Object . entries ( nodes ) as [ K , NodeAction < S , U , C > ] [ ] ) ;
602
+ : ( Object . entries ( nodes ) . map ( ( [ key , action ] ) => [
603
+ key ,
604
+ action ,
605
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
606
+ ( action as any ) [ Symbol . for ( "langgraph.state.node" ) ] ?? undefined ,
607
+ ] ) as [ K , NodeAction < S , U , C > , StateGraphAddNodeOptions | undefined ] [ ] ) ;
598
608
599
609
if ( parsedNodes . length === 0 ) {
600
610
throw new Error ( "Sequence requires at least one node." ) ;
@@ -1125,3 +1135,41 @@ function _getControlBranch() {
1125
1135
path : CONTROL_BRANCH_PATH ,
1126
1136
} ) ;
1127
1137
}
1138
+
1139
+ type TypedNodeAction < SD extends StateDefinition , Nodes extends string > = (
1140
+ state : StateType < SD > ,
1141
+ config : LangGraphRunnableConfig
1142
+ ) => UpdateType < SD > | Command < unknown , UpdateType < SD > , Nodes > ;
1143
+
1144
+ export function typedNode < SD extends SDZod , Nodes extends string > (
1145
+ _state : SD extends StateDefinition ? AnnotationRoot < SD > : never ,
1146
+ _options ?: { nodes ?: Nodes [ ] }
1147
+ ) : (
1148
+ func : TypedNodeAction < ToStateDefinition < SD > , Nodes > ,
1149
+ options ?: StateGraphAddNodeOptions < Nodes >
1150
+ ) => TypedNodeAction < ToStateDefinition < SD > , Nodes > ;
1151
+
1152
+ export function typedNode < SD extends SDZod , Nodes extends string > (
1153
+ _state : SD extends AnyZodObject ? SD : never ,
1154
+ _options ?: { nodes ?: Nodes [ ] }
1155
+ ) : (
1156
+ func : TypedNodeAction < ToStateDefinition < SD > , Nodes > ,
1157
+ options ?: StateGraphAddNodeOptions < Nodes >
1158
+ ) => TypedNodeAction < ToStateDefinition < SD > , Nodes > ;
1159
+
1160
+ export function typedNode < SD extends SDZod , Nodes extends string > (
1161
+ _state : SD extends AnyZodObject
1162
+ ? SD
1163
+ : SD extends StateDefinition
1164
+ ? AnnotationRoot < SD >
1165
+ : never ,
1166
+ _options ?: { nodes ?: Nodes [ ] }
1167
+ ) {
1168
+ return (
1169
+ func : TypedNodeAction < ToStateDefinition < SD > , Nodes > ,
1170
+ options ?: StateGraphAddNodeOptions < Nodes >
1171
+ ) => {
1172
+ Object . assign ( func , { [ Symbol . for ( "langgraph.state.node" ) ] : options } ) ;
1173
+ return func ;
1174
+ } ;
1175
+ }
0 commit comments