Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

memory/postgre #186

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ require (
github.com/imdario/mergo v0.3.11 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mitchellh/copystructure v1.0.0 // indirect
Expand Down Expand Up @@ -82,4 +84,6 @@ require (
google.golang.org/api v0.122.0
google.golang.org/grpc v1.55.0
google.golang.org/protobuf v1.30.0
gorm.io/driver/postgres v1.5.2
gorm.io/gorm v1.25.2
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have alot of deps already, let's make this optional behind an interface

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Take a look ;)

)
8 changes: 8 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.4.1 h1:oKfB/FhuVtit1bBM3zNRRsZ925ZkMN3HXL+LgLUM9lE=
github.com/jackc/pgx/v5 v5.4.1/go.mod h1:q6iHT8uDNXWiFNOlRqJzBTaSH3+2xCXkokxHZC5qWFY=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
Expand Down Expand Up @@ -682,6 +686,10 @@ gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
Expand Down
106 changes: 106 additions & 0 deletions memory/postgre/internal/database.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package internal

import (
"errors"

"github.com/tmc/langchaingo/schema"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)

var (
ErrDBConnection = errors.New("can't connect to database")
ErrDBMigration = errors.New("can't migrate database")
ErrMissingSessionID = errors.New("session id can not be empty")
)

type Database struct {
gorm *gorm.DB
history *ChatHistory
sessionID string
}

func NewDatabase(dsn string) (*Database, error) {
database := &Database{
history: &ChatHistory{},
}

gorm, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return nil, ErrDBConnection
}

database.gorm = gorm

err = database.gorm.AutoMigrate(ChatHistory{})
if err != nil {
return nil, ErrDBMigration
}

return database, nil
}

func (db *Database) SetSession(id string) {
db.sessionID = id
}

func (db *Database) SessionID() string {
return db.sessionID
}

func (db *Database) SaveHistory(msgs []schema.ChatMessage, bs string) error {
if db.sessionID == "" {
return ErrMissingSessionID
}

newMsgs := Messages{}
for _, msg := range msgs {
newMsgs = append(newMsgs, Message{
Type: string(msg.GetType()),
Text: msg.GetText(),
})
}

db.history.SessionID = db.sessionID
db.history.ChatHistory = &newMsgs
db.history.BufferString = bs

err := db.gorm.Save(&db.history).Error
if err != nil {
return err
}

return nil
}

func (db *Database) GetHistroy() ([]schema.ChatMessage, error) {
if db.sessionID == "" {
return nil, ErrMissingSessionID
}

err := db.gorm.Where(ChatHistory{SessionID: db.sessionID}).Find(&db.history).Error
if err != nil {
return nil, err
}

msgs := []schema.ChatMessage{}
if db.history.ChatHistory != nil {
for i := range *db.history.ChatHistory {
msg := (*db.history.ChatHistory)[i]

if msg.Type == "human" {
msgs = append(msgs, schema.HumanChatMessage{Text: msg.Text})
}

if msg.Type == "ai" {
msgs = append(msgs, schema.AIChatMessage{Text: msg.Text})
}
}
}

return msgs, nil
}

func (db *Database) ClearHistroy() error {
return nil
}
38 changes: 38 additions & 0 deletions memory/postgre/internal/history.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package internal

import (
"database/sql/driver"
"encoding/json"
"errors"
)

var ErrScanType = errors.New("could not scan type into Message")

type ChatHistory struct {
ID int `gorm:"primary_key"`
SessionID string `gorm:"type:varchar(256)"`
BufferString string `gorm:"type:text"`
ChatHistory *Messages `gorm:"type:jsonb;column:chat_history" json:"chat_history"`
}

type Message struct {
Type string `json:"type"`
Text string `json:"text"`
}

type Messages []Message

// Value implements the driver.Valuer interface, this method allows us to
// customize how we store the Message type in the database.
func (m Messages) Value() (driver.Value, error) {
return json.Marshal(m)
}

// Scan implements the sql.Scanner interface, this method allows us to
// define how we convert the Message data from the database into our Message type.
func (m *Messages) Scan(src interface{}) error {
if bytes, ok := src.([]byte); ok {
return json.Unmarshal(bytes, m)
}
return ErrScanType
}
183 changes: 183 additions & 0 deletions memory/postgre/postgre.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package postgre

import (
"errors"
"fmt"
"log"

"github.com/tmc/langchaingo/memory"
"github.com/tmc/langchaingo/memory/postgre/internal"
"github.com/tmc/langchaingo/schema"
)

// ErrInvalidInputValues is returned when input values given to a memory in save context are invalid.
var ErrInvalidInputValues = errors.New("invalid input values")

type PostgreBuffer struct {
ChatHistory *memory.ChatMessageHistory
DB *internal.Database

ReturnMessages bool
InputKey string
OutputKey string
HumanPrefix string
AIPrefix string
MemoryKey string
}

var _ schema.Memory = &PostgreBuffer{}

func NewPostgreBuffer(dsn string) *PostgreBuffer {
buffer := PostgreBuffer{
ChatHistory: memory.NewChatMessageHistory(),

ReturnMessages: false,
InputKey: "",
OutputKey: "",
HumanPrefix: "Human",
AIPrefix: "AI",
MemoryKey: "history",
}

db, err := internal.NewDatabase(dsn)
if err != nil {
log.Fatal(err)
}

buffer.DB = db

return &buffer
}

func (buffer *PostgreBuffer) SetSession(id string) {
buffer.DB.SetSession(id)
}

func (buffer *PostgreBuffer) SessionID() string {
return buffer.DB.SessionID()
}

func (buffer *PostgreBuffer) MemoryVariables() []string {
return []string{buffer.MemoryKey}
}

func (buffer *PostgreBuffer) LoadMemoryVariables(inputs map[string]any) (map[string]any, error) {
msgs, err := buffer.DB.GetHistroy()
if err != nil {
return nil, err
}

buffer.ChatHistory = memory.NewChatMessageHistory(
memory.WithPreviousMessages(msgs),
)

if buffer.ReturnMessages {
return map[string]any{
buffer.MemoryKey: buffer.ChatHistory.Messages(),
}, nil
}

bufferString, err := schema.GetBufferString(buffer.ChatHistory.Messages(), buffer.HumanPrefix, buffer.AIPrefix)
if err != nil {
return nil, err
}

return map[string]any{
buffer.MemoryKey: bufferString,
}, nil
}

// SaveContext saves the context of the PostgreBuffer.
//
// It takes in two maps, inputs and outputs, which contain key-value pairs of any type.
// The function retrieves the value associated with buffer.InputKey from the inputs map
// and adds it as a user message to the ChatHistory. Then, it retrieves the value
// associated with buffer.OutputKey from the outputs map and adds it as an AI message
// to the ChatHistory. The function then uses the ChatHistory, HumanPrefix, and AIPrefix
// properties of the buffer to generate a bufferString using the GetBufferString function
// from the schema package. Finally, it saves the ChatHistory messages and bufferString
// to the DB using the SaveHistory function, and returns any error encountered.
//
// Return type: error.
func (buffer *PostgreBuffer) SaveContext(inputs map[string]any, outputs map[string]any) error {
userInputValue, err := getInputValue(inputs, buffer.InputKey)
if err != nil {
return err
}

buffer.ChatHistory.AddUserMessage(userInputValue)

aiOutPutValue, err := getInputValue(outputs, buffer.OutputKey)
if err != nil {
return err
}

buffer.ChatHistory.AddAIMessage(aiOutPutValue)

bufferString, err := schema.GetBufferString(buffer.ChatHistory.Messages(), buffer.HumanPrefix, buffer.AIPrefix)
if err != nil {
return err
}

err = buffer.DB.SaveHistory(buffer.ChatHistory.Messages(), bufferString)
if err != nil {
return err
}

return nil
}

func (buffer *PostgreBuffer) Clear() error {
buffer.ChatHistory.Clear()
err := buffer.DB.ClearHistroy()
if err != nil {
return err
}
return nil
}

func getInputValue(inputValues map[string]any, inputKey string) (string, error) {
// If the input key is set, return the value in the inputValues with the input key.
if inputKey != "" {
inputValue, ok := inputValues[inputKey]
if !ok {
return "", fmt.Errorf(
"%w: %v do not contain inputKey %s",
ErrInvalidInputValues,
inputValues,
inputKey,
)
}

return getInputValueReturnToString(inputValue)
}

// Otherwise error if length of map isn't one, or return the only entry in the map.
if len(inputValues) > 1 {
return "", fmt.Errorf(
"%w: multiple keys and no input key set",
ErrInvalidInputValues,
)
}

for _, inputValue := range inputValues {
return getInputValueReturnToString(inputValue)
}

return "", fmt.Errorf("%w: 0 keys", ErrInvalidInputValues)
}

func getInputValueReturnToString(
inputValue interface{},
) (string, error) {
switch value := inputValue.(type) {
case string:
return value, nil
default:
return "", fmt.Errorf(
"%w: input value %v not string",
ErrInvalidInputValues,
inputValue,
)
}
}