Skip to content

Commit

Permalink
Merge pull request #249 from edocevol/feature/memory_support_context
Browse files Browse the repository at this point in the history
memory: support call with context
  • Loading branch information
tmc authored Aug 15, 2023
2 parents fd8b7f0 + cb3a6d5 commit eb0cbd3
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 74 deletions.
8 changes: 4 additions & 4 deletions chains/chains.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ type Chain interface {
}

// Call is the standard function used for executing chains.
func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...ChainCallOption) (map[string]any, error) { //nolint: lll
func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...ChainCallOption) (map[string]any, error) { // nolint: lll
fullValues := make(map[string]any, 0)
for key, value := range inputValues {
fullValues[key] = value
}

newValues, err := c.GetMemory().LoadMemoryVariables(inputValues)
newValues, err := c.GetMemory().LoadMemoryVariables(ctx, inputValues)
if err != nil {
return nil, err
}
Expand All @@ -53,7 +53,7 @@ func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...C
return nil, err
}

err = c.GetMemory().SaveContext(inputValues, outputValues)
err = c.GetMemory().SaveContext(ctx, inputValues, outputValues)
if err != nil {
return nil, err
}
Expand All @@ -65,7 +65,7 @@ func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...C
// string output.
func Run(ctx context.Context, c Chain, input any, options ...ChainCallOption) (string, error) {
inputKeys := c.GetInputKeys()
memoryKeys := c.GetMemory().MemoryVariables()
memoryKeys := c.GetMemory().MemoryVariables(ctx)
neededKeys := make([]string, 0, len(inputKeys))

// Remove keys gotten from the memory.
Expand Down
6 changes: 3 additions & 3 deletions chains/conversational_retrieval_qa.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,14 @@ func NewConversationalRetrievalQAFromLLM(

// Call gets question, and relevant documents by question from the retriever and gives them to the combine
// documents chain.
func (c ConversationalRetrievalQA) Call(ctx context.Context, values map[string]any, options ...ChainCallOption) (map[string]any, error) { //nolint: lll
func (c ConversationalRetrievalQA) Call(ctx context.Context, values map[string]any, options ...ChainCallOption) (map[string]any, error) { // nolint: lll
query, ok := values[c.InputKey].(string)
if !ok {
return nil, fmt.Errorf("%w: %w", ErrInvalidInputValues, ErrInputValuesWrongType)
}
chatHistoryStr, ok := values[c.Memory.GetMemoryKey()].(string)
chatHistoryStr, ok := values[c.Memory.GetMemoryKey(ctx)].(string)
if !ok {
chatHistory, ok := values[c.Memory.GetMemoryKey()].([]schema.ChatMessage)
chatHistory, ok := values[c.Memory.GetMemoryKey(ctx)].([]schema.ChatMessage)
if !ok {
return nil, fmt.Errorf("%w: %w", ErrMissingMemoryKeyValues, ErrMemoryValuesWrongType)
}
Expand Down
2 changes: 1 addition & 1 deletion chains/sequential.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (c *SequentialChain) validateSeqChain() error {
knownKeys := util.ToSet(c.inputKeys)

// Make sure memory keys don't collide with input keys
memoryKeys := c.memory.MemoryVariables()
memoryKeys := c.memory.MemoryVariables(context.Background())
overlappingKeys := util.Intersection(memoryKeys, knownKeys)
if len(overlappingKeys) > 0 {
return fmt.Errorf(
Expand Down
21 changes: 14 additions & 7 deletions memory/buffer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package memory

import (
"context"
"errors"
"fmt"

Expand Down Expand Up @@ -31,16 +32,18 @@ func NewConversationBuffer(options ...ConversationBufferOption) *ConversationBuf
}

// MemoryVariables gets the input key the buffer memory class will load dynamically.
func (m *ConversationBuffer) MemoryVariables() []string {
func (m *ConversationBuffer) MemoryVariables(context.Context) []string {
return []string{m.MemoryKey}
}

// LoadMemoryVariables returns the previous chat messages stored in memory. Previous chat messages
// are returned in a map with the key specified in the MemoryKey field. This key defaults to
// "history". If ReturnMessages is set to true the output is a slice of schema.ChatMessage. Otherwise
// the output is a buffer string of the chat messages.
func (m *ConversationBuffer) LoadMemoryVariables(map[string]any) (map[string]any, error) {
messages, err := m.ChatHistory.Messages()
func (m *ConversationBuffer) LoadMemoryVariables(
ctx context.Context, _ map[string]any,
) (map[string]any, error) {
messages, err := m.ChatHistory.Messages(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -68,12 +71,16 @@ func (m *ConversationBuffer) LoadMemoryVariables(map[string]any) (map[string]any
// input key must be a key in the input values and the output key must be a key in the output
// values. The values in the input and output values used to save a user and ai message must
// be strings.
func (m *ConversationBuffer) SaveContext(inputValues map[string]any, outputValues map[string]any) error {
func (m *ConversationBuffer) SaveContext(
ctx context.Context,
inputValues map[string]any,
outputValues map[string]any,
) error {
userInputValue, err := getInputValue(inputValues, m.InputKey)
if err != nil {
return err
}
err = m.ChatHistory.AddUserMessage(userInputValue)
err = m.ChatHistory.AddUserMessage(ctx, userInputValue)
if err != nil {
return err
}
Expand All @@ -82,7 +89,7 @@ func (m *ConversationBuffer) SaveContext(inputValues map[string]any, outputValue
if err != nil {
return err
}
err = m.ChatHistory.AddAIMessage(aiOutputValue)
err = m.ChatHistory.AddAIMessage(ctx, aiOutputValue)
if err != nil {
return err
}
Expand All @@ -95,7 +102,7 @@ func (m *ConversationBuffer) Clear() error {
return m.ChatHistory.Clear()
}

func (m *ConversationBuffer) GetMemoryKey() string {
func (m *ConversationBuffer) GetMemoryKey(context.Context) string {
return m.MemoryKey
}

Expand Down
29 changes: 15 additions & 14 deletions memory/buffer_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package memory

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -12,15 +13,15 @@ func TestBufferMemory(t *testing.T) {
t.Parallel()

m := NewConversationBuffer()
result1, err := m.LoadMemoryVariables(map[string]any{})
result1, err := m.LoadMemoryVariables(context.Background(), map[string]any{})
require.NoError(t, err)
expected1 := map[string]any{"history": ""}
assert.Equal(t, expected1, result1)

err = m.SaveContext(map[string]any{"foo": "bar"}, map[string]any{"bar": "foo"})
err = m.SaveContext(context.Background(), map[string]any{"foo": "bar"}, map[string]any{"bar": "foo"})
require.NoError(t, err)

result2, err := m.LoadMemoryVariables(map[string]any{})
result2, err := m.LoadMemoryVariables(context.Background(), map[string]any{})
require.NoError(t, err)

expected2 := map[string]any{"history": "Human: bar\nAI: foo"}
Expand All @@ -33,14 +34,14 @@ func TestBufferMemoryReturnMessage(t *testing.T) {
m := NewConversationBuffer()
m.ReturnMessages = true
expected1 := map[string]any{"history": []schema.ChatMessage{}}
result1, err := m.LoadMemoryVariables(map[string]any{})
result1, err := m.LoadMemoryVariables(context.Background(), map[string]any{})
require.NoError(t, err)
assert.Equal(t, expected1, result1)

err = m.SaveContext(map[string]any{"foo": "bar"}, map[string]any{"bar": "foo"})
err = m.SaveContext(context.Background(), map[string]any{"foo": "bar"}, map[string]any{"bar": "foo"})
require.NoError(t, err)

result2, err := m.LoadMemoryVariables(map[string]any{})
result2, err := m.LoadMemoryVariables(context.Background(), map[string]any{})
require.NoError(t, err)

expectedChatHistory := NewChatMessageHistory(
Expand All @@ -50,7 +51,7 @@ func TestBufferMemoryReturnMessage(t *testing.T) {
}),
)

messages, err := expectedChatHistory.Messages()
messages, err := expectedChatHistory.Messages(context.Background())
assert.NoError(t, err)
expected2 := map[string]any{"history": messages}
assert.Equal(t, expected2, result2)
Expand All @@ -66,7 +67,7 @@ func TestBufferMemoryWithPreLoadedHistory(t *testing.T) {
}),
)))

result, err := m.LoadMemoryVariables(map[string]any{})
result, err := m.LoadMemoryVariables(context.Background(), map[string]any{})
require.NoError(t, err)
expected := map[string]any{"history": "Human: bar\nAI: foo"}
assert.Equal(t, expected, result)
Expand All @@ -76,27 +77,27 @@ type testChatMessageHistory struct{}

var _ schema.ChatMessageHistory = testChatMessageHistory{}

func (t testChatMessageHistory) AddUserMessage(_ string) error {
func (t testChatMessageHistory) AddUserMessage(context.Context, string) error {
return nil
}

func (t testChatMessageHistory) AddAIMessage(_ string) error {
func (t testChatMessageHistory) AddAIMessage(context.Context, string) error {
return nil
}

func (t testChatMessageHistory) AddMessage(_ schema.ChatMessage) error {
func (t testChatMessageHistory) AddMessage(context.Context, schema.ChatMessage) error {
return nil
}

func (t testChatMessageHistory) Clear() error {
return nil
}

func (t testChatMessageHistory) SetMessages(_ []schema.ChatMessage) error {
func (t testChatMessageHistory) SetMessages(context.Context, []schema.ChatMessage) error {
return nil
}

func (t testChatMessageHistory) Messages() ([]schema.ChatMessage, error) {
func (t testChatMessageHistory) Messages(context.Context) ([]schema.ChatMessage, error) {
return []schema.ChatMessage{
schema.HumanChatMessage{Content: "user message test"},
schema.AIChatMessage{Content: "ai message test"},
Expand All @@ -109,7 +110,7 @@ func TestBufferMemoryWithChatHistoryOption(t *testing.T) {
chatMessageHistory := testChatMessageHistory{}
m := NewConversationBuffer(WithChatHistory(chatMessageHistory))

result, err := m.LoadMemoryVariables(map[string]any{})
result, err := m.LoadMemoryVariables(context.Background(), map[string]any{})
require.NoError(t, err)
expected := map[string]any{"history": "Human: user message test\nAI: ai message test"}
assert.Equal(t, expected, result)
Expand Down
16 changes: 10 additions & 6 deletions memory/chat.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package memory

import "github.com/tmc/langchaingo/schema"
import (
"context"

"github.com/tmc/langchaingo/schema"
)

// ChatMessageHistory is a struct that stores chat messages.
type ChatMessageHistory struct {
Expand All @@ -16,18 +20,18 @@ func NewChatMessageHistory(options ...ChatMessageHistoryOption) *ChatMessageHist
}

// Messages returns all messages stored.
func (h *ChatMessageHistory) Messages() ([]schema.ChatMessage, error) {
func (h *ChatMessageHistory) Messages(_ context.Context) ([]schema.ChatMessage, error) {
return h.messages, nil
}

// AddAIMessage adds an AIMessage to the chat message history.
func (h *ChatMessageHistory) AddAIMessage(text string) error {
func (h *ChatMessageHistory) AddAIMessage(_ context.Context, text string) error {
h.messages = append(h.messages, schema.AIChatMessage{Content: text})
return nil
}

// AddUserMessage adds an user to the chat message history.
func (h *ChatMessageHistory) AddUserMessage(text string) error {
func (h *ChatMessageHistory) AddUserMessage(_ context.Context, text string) error {
h.messages = append(h.messages, schema.HumanChatMessage{Content: text})
return nil
}
Expand All @@ -37,12 +41,12 @@ func (h *ChatMessageHistory) Clear() error {
return nil
}

func (h *ChatMessageHistory) AddMessage(message schema.ChatMessage) error {
func (h *ChatMessageHistory) AddMessage(_ context.Context, message schema.ChatMessage) error {
h.messages = append(h.messages, message)
return nil
}

func (h *ChatMessageHistory) SetMessages(messages []schema.ChatMessage) error {
func (h *ChatMessageHistory) SetMessages(_ context.Context, messages []schema.ChatMessage) error {
h.messages = messages
return nil
}
11 changes: 6 additions & 5 deletions memory/chat_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package memory

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -11,12 +12,12 @@ func TestChatMessageHistory(t *testing.T) {
t.Parallel()

h := NewChatMessageHistory()
err := h.AddAIMessage("foo")
err := h.AddAIMessage(context.Background(), "foo")
assert.NoError(t, err)
err = h.AddUserMessage("bar")
err = h.AddUserMessage(context.Background(), "bar")
assert.NoError(t, err)

messages, err := h.Messages()
messages, err := h.Messages(context.Background())
assert.NoError(t, err)

assert.Equal(t, []schema.ChatMessage{
Expand All @@ -30,10 +31,10 @@ func TestChatMessageHistory(t *testing.T) {
schema.SystemChatMessage{Content: "bar"},
}),
)
err = h.AddUserMessage("zoo")
err = h.AddUserMessage(context.Background(), "zoo")
assert.NoError(t, err)

messages, err = h.Messages()
messages, err = h.Messages(context.Background())
assert.NoError(t, err)

assert.Equal(t, []schema.ChatMessage{
Expand Down
12 changes: 7 additions & 5 deletions memory/simple.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package memory

import (
"context"

"github.com/tmc/langchaingo/schema"
)

Expand All @@ -15,22 +17,22 @@ func NewSimple() Simple {
// Statically assert that Simple implement the memory interface.
var _ schema.Memory = Simple{}

func (m Simple) MemoryVariables() []string {
func (m Simple) MemoryVariables(context.Context) []string {
return nil
}

func (m Simple) LoadMemoryVariables(map[string]any) (map[string]any, error) {
return make(map[string]any, 0), nil
func (m Simple) LoadMemoryVariables(context.Context, map[string]any) (map[string]any, error) {
return make(map[string]any), nil
}

func (m Simple) SaveContext(map[string]any, map[string]any) error {
func (m Simple) SaveContext(context.Context, map[string]any, map[string]any) error {
return nil
}

func (m Simple) Clear() error {
return nil
}

func (m Simple) GetMemoryKey() string {
func (m Simple) GetMemoryKey(context.Context) string {
return ""
}
Loading

0 comments on commit eb0cbd3

Please sign in to comment.