Skip to content

Commit

Permalink
Refactor to use enhanced postgres checkpointer (#117)
Browse files Browse the repository at this point in the history
* Upgrade langgraph version to v0.2.3. Add langgraph-checkpoint-postgres

* Update models for new checkpointer

* Create migration file for new checkpointer

* Refactor code to use new AsyncPostgresSaver. Delete old PostgresSaver class.
  • Loading branch information
StreetLamb authored Aug 13, 2024
1 parent 84c4033 commit 5920256
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 678 deletions.
105 changes: 105 additions & 0 deletions backend/app/alembic/versions/20f584dc80d2_upgrade_checkpointer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '20f584dc80d2'
down_revision = '38a9c73bfce2'
branch_labels = None
depends_on = None

def upgrade():
    """Replace the legacy checkpointer tables with the new checkpointer schema.

    Creates ``checkpoint_blobs``, ``checkpoint_writes`` and a new
    ``checkpoints`` table, drops the legacy ``writes`` and ``checkpoints``
    tables, and finally deletes every row from ``thread``.

    This migration is destructive: nothing stored in the legacy tables is
    carried over.
    """
    # Create new tables
    op.create_table(
        'checkpoint_blobs',
        sa.Column('thread_id', sa.Uuid(), nullable=False),
        sa.Column('checkpoint_ns', sa.String(), nullable=False, server_default=''),
        sa.Column('channel', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('version', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('blob', sa.LargeBinary(), nullable=True),
        sa.ForeignKeyConstraint(['thread_id'], ['thread.id']),
        sa.PrimaryKeyConstraint('thread_id', 'checkpoint_ns', 'channel', 'version'),
    )
    op.create_table(
        'checkpoint_writes',
        sa.Column('thread_id', sa.Uuid(), nullable=False),
        sa.Column('checkpoint_ns', sa.String(), nullable=False, server_default=''),
        sa.Column('checkpoint_id', sa.Uuid(), nullable=False),
        sa.Column('task_id', sa.Uuid(), nullable=False),
        sa.Column('idx', sa.Integer(), nullable=False),
        sa.Column('channel', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column('blob', sa.LargeBinary(), nullable=False),
        sa.ForeignKeyConstraint(['thread_id'], ['thread.id']),
        sa.PrimaryKeyConstraint('thread_id', 'checkpoint_ns', 'checkpoint_id', 'task_id', 'idx'),
    )

    # Drop the old table
    op.drop_table('writes')

    # Drop the legacy checkpoints table *before* creating its replacement.
    # No rows are migrated, so parking it under a temporary name buys nothing,
    # and a renamed table keeps its original primary-key index name, which
    # would force the new table's auto-generated key onto a fallback name.
    op.drop_table('checkpoints')

    op.create_table(
        'checkpoints',
        sa.Column('thread_id', sa.Uuid(), nullable=False),
        sa.Column('checkpoint_ns', sa.String(), nullable=False, server_default=''),
        sa.Column('checkpoint_id', sa.Uuid(), nullable=False),
        sa.Column('parent_checkpoint_id', sa.Uuid(), nullable=True),
        sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column('checkpoint', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column('metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}'),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
        sa.ForeignKeyConstraint(['thread_id'], ['thread.id']),
        sa.PrimaryKeyConstraint('thread_id', 'checkpoint_ns', 'checkpoint_id'),
    )

    # Clear the threads table: the checkpoint rows that belonged to existing
    # threads were discarded above.
    op.execute('DELETE FROM thread')


def downgrade():
    """Restore the legacy ``checkpoints`` and ``writes`` tables.

    Drops the tables created by ``upgrade`` (``checkpoints``,
    ``checkpoint_writes``, ``checkpoint_blobs``), recreates the legacy
    ``checkpoints`` and ``writes`` tables, and finally deletes every row from
    ``thread``.

    This migration is destructive: nothing stored in the new tables is
    carried over.
    """
    # Drop the new checkpoints table first so the legacy one can be created
    # directly under its real name. Creating it as 'checkpoints_old' and
    # renaming afterwards would leave its auto-named constraints carrying the
    # temporary 'checkpoints_old_' prefix.
    op.drop_table('checkpoints')

    # Recreate the old checkpoints table
    op.create_table(
        'checkpoints',
        sa.Column('thread_id', sa.Uuid(), nullable=False),
        sa.Column('thread_ts', sa.Uuid(), nullable=False),
        sa.Column('parent_ts', sa.Uuid(), nullable=True),
        sa.Column('checkpoint', sa.LargeBinary(), nullable=True),
        sa.Column('metadata', sa.LargeBinary(), nullable=False, server_default=sa.text("'\\x'::bytea")),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
        sa.ForeignKeyConstraint(['thread_id'], ['thread.id']),
        sa.PrimaryKeyConstraint('thread_id', 'thread_ts'),
    )

    # Recreate the old 'writes' table
    op.create_table(
        'writes',
        sa.Column('thread_id', sa.UUID(), autoincrement=False, nullable=False),
        sa.Column('thread_ts', sa.UUID(), autoincrement=False, nullable=False),
        sa.Column('task_id', sa.UUID(), autoincrement=False, nullable=False),
        sa.Column('idx', sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column('channel', sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column('value', postgresql.BYTEA(), autoincrement=False, nullable=False),
        sa.ForeignKeyConstraint(['thread_id'], ['thread.id'], name='writes_thread_id_fkey'),
        sa.PrimaryKeyConstraint('thread_id', 'thread_ts', 'task_id', 'idx', name='writes_pkey'),
    )

    # Drop the new tables
    op.drop_table('checkpoint_writes')
    op.drop_table('checkpoint_blobs')

    # Clear the threads table: the checkpoint rows that belonged to existing
    # threads were discarded above.
    op.execute('DELETE FROM thread')
8 changes: 8 additions & 0 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from typing import Annotated, Any, Literal

from psycopg.rows import dict_row
from pydantic import (
AnyUrl,
BeforeValidator,
Expand Down Expand Up @@ -66,6 +67,13 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:
path=self.POSTGRES_DB,
)

# For checkpointer
# The graph builder unpacks this dict as keyword arguments into
# ``AsyncConnection.connect(settings.PG_DATABASE_URI, ...)`` to open the
# connection that is handed to ``AsyncPostgresSaver``.
# NOTE(review): despite the SQLALCHEMY_ prefix these look like psycopg
# connection options (``dict_row`` comes from ``psycopg.rows``) — confirm the
# name is intentional.
SQLALCHEMY_CONNECTION_KWARGS: dict[str, Any] = {
"autocommit": True,
"prepare_threshold": 0,
"row_factory": dict_row,
}

@computed_field # type: ignore[misc]
@property
def PG_DATABASE_URI(self) -> str:
Expand Down
11 changes: 7 additions & 4 deletions backend/app/core/graph/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
)
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import (
Expand All @@ -22,7 +23,6 @@
from psycopg import AsyncConnection

from app.core.config import settings
from app.core.graph.checkpoint.postgres import PostgresSaver
from app.core.graph.members import (
GraphLeader,
GraphMember,
Expand Down Expand Up @@ -471,8 +471,11 @@ async def generator(
]

try:
async with await AsyncConnection.connect(settings.PG_DATABASE_URI) as conn:
checkpointer = PostgresSaver(async_connection=conn)
async with await AsyncConnection.connect(
settings.PG_DATABASE_URI,
**settings.SQLALCHEMY_CONNECTION_KWARGS,
) as conn:
checkpointer = AsyncPostgresSaver(conn=conn)
if team.workflow == "hierarchical":
teams = convert_hierarchical_team_to_dict(team, members)
team_leader = list(teams.keys())[0]
Expand Down
Loading

0 comments on commit 5920256

Please sign in to comment.