Skip to content

Commit 72a06f1

Browse files
fix: support using static values for all input (#218)
1 parent d2834e3 commit 72a06f1

File tree

3 files changed

+71
-17
lines changed

3 files changed

+71
-17
lines changed

compose/field_mapping.go

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -686,19 +686,14 @@ func validateFieldMapping(predecessorType reflect.Type, successorType reflect.Ty
686686
return nil, fmt.Errorf("static check fail: predecessor output type should be struct or map, actual: %v", predecessorType)
687687
}
688688

689-
var (
690-
predecessorFieldType, successorFieldType reflect.Type
691-
err error
692-
predecessorIntermediateInterface, successorIntermediateInterface bool
693-
)
694-
695-
for _, mapping := range mappings {
696-
predecessorFieldType, predecessorIntermediateInterface, err = checkAndExtractFieldType(splitFieldPath(mapping.from), predecessorType)
689+
for i := range mappings {
690+
mapping := mappings[i]
691+
predecessorFieldType, predecessorIntermediateInterface, err := checkAndExtractFieldType(splitFieldPath(mapping.from), predecessorType)
697692
if err != nil {
698693
return nil, fmt.Errorf("static check failed for mapping %s: %w", mapping, err)
699694
}
700695

701-
successorFieldType, successorIntermediateInterface, err = checkAndExtractFieldType(splitFieldPath(mapping.to), successorType)
696+
successorFieldType, successorIntermediateInterface, err := checkAndExtractFieldType(splitFieldPath(mapping.to), successorType)
702697
if err != nil {
703698
return nil, fmt.Errorf("static check failed for mapping %s: %w", mapping, err)
704699
}

compose/workflow.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -451,10 +451,7 @@ func (wf *Workflow[I, O]) compile(ctx context.Context, options *graphCompileOpti
451451
return mergeValues(values)
452452
},
453453
transform: func(in streamReader) streamReader {
454-
sr, sw := schema.Pipe[map[string]any](1)
455-
sw.Send(value, nil)
456-
sw.Close()
457-
454+
sr := schema.StreamReaderFromArray([]map[string]any{value})
458455
newS, err := mergeValues([]any{in, packStreamReader(sr)})
459456
if err != nil {
460457
errSR, errSW := schema.Pipe[map[string]any](1)
@@ -467,11 +464,11 @@ func (wf *Workflow[I, O]) compile(ctx context.Context, options *graphCompileOpti
467464
},
468465
}
469466

470-
if _, ok := wf.g.handlerPreNode[n.key]; !ok {
471-
wf.g.handlerPreNode[n.key] = []handlerPair{pair}
472-
} else {
473-
wf.g.handlerPreNode[n.key] = append([]handlerPair{pair}, wf.g.handlerPreNode[n.key]...)
467+
for i := range paths {
468+
wf.g.fieldMappingRecords[n.key] = append(wf.g.fieldMappingRecords[n.key], ToFieldPath(paths[i]))
474469
}
470+
471+
wf.g.handlerPreNode[n.key] = []handlerPair{pair}
475472
}
476473
}
477474

compose/workflow_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,68 @@ func TestStaticValue(t *testing.T) {
899899
_, err := wf.Compile(context.Background())
900900
assert.ErrorContains(t, err, "two terminal field paths conflict for node 0: [prefilled], [prefilled]")
901901
})
902+
903+
t.Run("all inputs are static values", func(t *testing.T) {
904+
wf := NewWorkflow[string, map[string]any]()
905+
wf.AddLambdaNode("0", InvokableLambda(func(ctx context.Context, in map[string]any) (output map[string]any, err error) {
906+
return in, nil
907+
})).
908+
AddDependency(START).
909+
SetStaticValue(FieldPath{"a", "b"}, "a_b").
910+
SetStaticValue(FieldPath{"c", "d"}, "c_d").
911+
SetStaticValue(FieldPath{"a", "d"}, "a_d")
912+
wf.End().AddInput("0")
913+
r, err := wf.Compile(context.Background())
914+
assert.NoError(t, err)
915+
out, err := r.Invoke(context.Background(), "hello")
916+
assert.NoError(t, err)
917+
assert.Equal(t, map[string]any{
918+
"a": map[string]any{
919+
"b": "a_b",
920+
"d": "a_d",
921+
},
922+
"c": map[string]any{
923+
"d": "c_d",
924+
},
925+
}, out)
926+
927+
type a struct {
928+
B string
929+
D string
930+
}
931+
932+
type s struct {
933+
A a
934+
C map[string]any
935+
}
936+
937+
wf1 := NewWorkflow[string, *s]()
938+
wf1.AddLambdaNode("0", InvokableLambda(func(ctx context.Context, in map[string]any) (output map[string]any, err error) {
939+
return in, nil
940+
})).
941+
AddDependency(START).
942+
SetStaticValue(FieldPath{"A", "B"}, "a_b").
943+
SetStaticValue(FieldPath{"C", "D"}, "c_d").
944+
SetStaticValue(FieldPath{"A", "D"}, "a_d")
945+
wf1.End().AddInput("0", MapFieldPaths(FieldPath{"A", "B"}, FieldPath{"A", "B"}),
946+
MapFieldPaths(FieldPath{"A", "D"}, FieldPath{"A", "D"}),
947+
MapFields("C", "C"))
948+
r1, err := wf1.Compile(context.Background())
949+
assert.NoError(t, err)
950+
out1, err := r1.Stream(context.Background(), "hello")
951+
assert.NoError(t, err)
952+
outChunk, err := out1.Recv()
953+
out1.Close()
954+
assert.Equal(t, &s{
955+
A: a{
956+
B: "a_b",
957+
D: "a_d",
958+
},
959+
C: map[string]any{
960+
"D": "c_d",
961+
},
962+
}, outChunk)
963+
})
902964
}
903965

904966
func TestBranch(t *testing.T) {

0 commit comments

Comments
 (0)