Skip to content

Commit

Permalink
Merge pull request #220 from tobiade/seqchain-memory
Browse files Browse the repository at this point in the history
chains: add memory support for sequential chain
  • Loading branch information
tmc authored Jul 25, 2023
2 parents 3f372d1 + b2aa463 commit 78b1a81
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 18 deletions.
51 changes: 40 additions & 11 deletions chains/sequential.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,44 @@ type SequentialChain struct {
memory schema.Memory
}

func NewSequentialChain(chains []Chain, inputKeys []string, outputKeys []string) (*SequentialChain, error) {
if err := validateSeqChain(chains, inputKeys, outputKeys); err != nil {
return nil, err
}

return &SequentialChain{
func NewSequentialChain(chains []Chain, inputKeys []string, outputKeys []string, opts ...SequentialChainOption) (*SequentialChain, error) { //nolint:lll
s := &SequentialChain{
chains: chains,
inputKeys: inputKeys,
outputKeys: outputKeys,
memory: memory.NewSimple(),
}, nil
}

for _, opt := range opts {
opt(s)
}

if err := s.validateSeqChain(); err != nil {
return nil, err
}

return s, nil
}

func validateSeqChain(chain []Chain, inputKeys []string, outputKeys []string) error {
knownKeys := util.ToSet(inputKeys)
func (c *SequentialChain) validateSeqChain() error {
knownKeys := util.ToSet(c.inputKeys)

// Make sure memory keys don't collide with input keys
memoryKeys := c.memory.MemoryVariables()
overlappingKeys := util.Intersection(memoryKeys, knownKeys)
if len(overlappingKeys) > 0 {
return fmt.Errorf(
"%w: input keys [%v] also exist in the memory keys: [%v] - please use input keys and memory keys that don't overlap",
ErrChainInitialization, strings.Join(overlappingKeys, delimiter), strings.Join(memoryKeys, delimiter),
)
}

// Add memory keys to known keys
for _, key := range memoryKeys {
knownKeys[key] = struct{}{}
}

for i, c := range chain {
for i, c := range c.chains {
// Check that chain has input keys that are in knownKeys
missingKeys := util.Difference(c.GetInputKeys(), knownKeys)
if len(missingKeys) > 0 {
Expand All @@ -64,7 +85,7 @@ func validateSeqChain(chain []Chain, inputKeys []string, outputKeys []string) er
}

// Check that outputKeys are in knownKeys
for _, key := range outputKeys {
for _, key := range c.outputKeys {
if _, ok := knownKeys[key]; !ok {
return fmt.Errorf("%w: output key %s is not in the known keys", ErrChainInitialization, key)
}
Expand Down Expand Up @@ -179,3 +200,11 @@ func (c *SimpleSequentialChain) GetInputKeys() []string {
func (c *SimpleSequentialChain) GetOutputKeys() []string {
return []string{output}
}

type SequentialChainOption func(*SequentialChain)

func WithSeqChainMemory(memory schema.Memory) SequentialChainOption {
return func(c *SequentialChain) {
c.memory = memory
}
}
19 changes: 14 additions & 5 deletions chains/sequential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,11 @@ func TestSequentialChainErrors(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
chains []Chain
initErr error
execErr error
name string
chains []Chain
initErr error
execErr error
seqChainOpts []SequentialChainOption
}{
{
name: "missing input key",
Expand Down Expand Up @@ -165,13 +166,21 @@ func TestSequentialChainErrors(t *testing.T) {
},
execErr: errDummy,
},
{
name: "memory key collides with input key",
chains: []Chain{
&testLLMChain{inputKeys: []string{"input1"}, outputKeys: []string{"output"}},
},
initErr: ErrChainInitialization,
seqChainOpts: []SequentialChainOption{WithSeqChainMemory(memory.NewBuffer(memory.WithMemoryKey("input1")))},
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
c, err := NewSequentialChain(tc.chains, []string{"input1", "input2"}, []string{"output"})
c, err := NewSequentialChain(tc.chains, []string{"input1", "input2"}, []string{"output"}, tc.seqChainOpts...)
if tc.initErr != nil {
assert.ErrorIs(t, err, tc.initErr)
} else {
Expand Down
4 changes: 2 additions & 2 deletions docs/parity_matrix.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ Please note that this page lists the current state of the LangChain Go project,
| PAL Chain ||
| LLM Requests Chain ||
| Moderation Chain ||
| Sequential Chain | |
| Simple Sequential Chain | |
| Sequential Chain | |
| Simple Sequential Chain | |

## Agents

Expand Down

0 comments on commit 78b1a81

Please sign in to comment.