Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to use enhanced postgres checkpointer #117

Merged
merged 4 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions backend/app/alembic/versions/20f584dc80d2_upgrade_checkpointer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '20f584dc80d2'       # this migration's unique id
down_revision = '38a9c73bfce2'  # migration this one applies on top of
branch_labels = None
depends_on = None

def upgrade():
    """Switch to langgraph's enhanced Postgres checkpointer schema.

    Creates the new ``checkpoint_blobs`` and ``checkpoint_writes`` tables and
    rebuilds ``checkpoints`` with the new column layout, replacing the legacy
    ``checkpoints``/``writes`` pair.

    Destructive: old checkpoint rows cannot be converted to the new format,
    so they are discarded and the ``thread`` table is cleared.
    """
    # New auxiliary tables used by the enhanced checkpointer.
    op.create_table('checkpoint_blobs',
        sa.Column('thread_id', sa.Uuid(), nullable=False),
        sa.Column('checkpoint_ns', sa.String(), nullable=False, server_default=''),
        sa.Column('channel', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('version', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('blob', sa.LargeBinary(), nullable=True),
        sa.ForeignKeyConstraint(['thread_id'], ['thread.id']),
        sa.PrimaryKeyConstraint('thread_id', 'checkpoint_ns', 'channel', 'version')
    )
    op.create_table('checkpoint_writes',
        sa.Column('thread_id', sa.Uuid(), nullable=False),
        sa.Column('checkpoint_ns', sa.String(), nullable=False, server_default=''),
        sa.Column('checkpoint_id', sa.Uuid(), nullable=False),
        sa.Column('task_id', sa.Uuid(), nullable=False),
        sa.Column('idx', sa.Integer(), nullable=False),
        sa.Column('channel', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column('blob', sa.LargeBinary(), nullable=False),
        sa.ForeignKeyConstraint(['thread_id'], ['thread.id']),
        sa.PrimaryKeyConstraint('thread_id', 'checkpoint_ns', 'checkpoint_id', 'task_id', 'idx')
    )

    # Drop the legacy tables outright. No data is migrated, so the previous
    # rename-to-'checkpoints_old' / create / drop-old sequence was redundant —
    # a direct drop + create reaches the same end state.
    op.drop_table('writes')
    op.drop_table('checkpoints')

    # Recreate 'checkpoints' with the new schema (JSONB payloads instead of
    # the old binary columns).
    op.create_table(
        'checkpoints',
        sa.Column('thread_id', sa.Uuid(), nullable=False),
        sa.Column('checkpoint_ns', sa.String(), nullable=False, server_default=''),
        sa.Column('checkpoint_id', sa.Uuid(), nullable=False),
        sa.Column('parent_checkpoint_id', sa.Uuid(), nullable=True),
        sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column('checkpoint', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column('metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}'),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
        sa.ForeignKeyConstraint(['thread_id'], ['thread.id']),
        sa.PrimaryKeyConstraint('thread_id', 'checkpoint_ns', 'checkpoint_id')
    )

    # Old checkpoint data is unusable with the new checkpointer, so clear the
    # threads that referenced it.
    op.execute('DELETE FROM thread')


def downgrade():
    """Revert to the legacy checkpointer schema.

    Rebuilds the old ``checkpoints``/``writes`` tables (binary payload
    columns) and removes the enhanced checkpointer's tables.

    Destructive: new-format checkpoint rows cannot be converted back, so
    they are discarded and the ``thread`` table is cleared.
    """
    # Drop the new-format 'checkpoints' and recreate it directly with the old
    # schema. No data is migrated, so the previous create-as-'checkpoints_old'
    # / drop / rename sequence was redundant.
    op.drop_table('checkpoints')
    op.create_table(
        'checkpoints',
        sa.Column('thread_id', sa.Uuid(), nullable=False),
        sa.Column('thread_ts', sa.Uuid(), nullable=False),
        sa.Column('parent_ts', sa.Uuid(), nullable=True),
        sa.Column('checkpoint', sa.LargeBinary(), nullable=True),
        sa.Column('metadata', sa.LargeBinary(), nullable=False, server_default=sa.text("'\\x'::bytea")),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
        sa.ForeignKeyConstraint(['thread_id'], ['thread.id']),
        sa.PrimaryKeyConstraint('thread_id', 'thread_ts')
    )

    # Recreate the old 'writes' table.
    op.create_table('writes',
        sa.Column('thread_id', sa.UUID(), autoincrement=False, nullable=False),
        sa.Column('thread_ts', sa.UUID(), autoincrement=False, nullable=False),
        sa.Column('task_id', sa.UUID(), autoincrement=False, nullable=False),
        sa.Column('idx', sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column('channel', sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column('value', postgresql.BYTEA(), autoincrement=False, nullable=False),
        sa.ForeignKeyConstraint(['thread_id'], ['thread.id'], name='writes_thread_id_fkey'),
        sa.PrimaryKeyConstraint('thread_id', 'thread_ts', 'task_id', 'idx', name='writes_pkey')
    )

    # Remove the enhanced checkpointer's auxiliary tables.
    op.drop_table('checkpoint_writes')
    op.drop_table('checkpoint_blobs')

    # New-format checkpoint data is unusable with the old checkpointer, so
    # clear the threads that referenced it.
    op.execute('DELETE FROM thread')
8 changes: 8 additions & 0 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from typing import Annotated, Any, Literal

from psycopg.rows import dict_row
from pydantic import (
AnyUrl,
BeforeValidator,
Expand Down Expand Up @@ -66,6 +67,13 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:
path=self.POSTGRES_DB,
)

# For checkpointer
SQLALCHEMY_CONNECTION_KWARGS: dict[str, Any] = {
"autocommit": True,
"prepare_threshold": 0,
"row_factory": dict_row,
}

@computed_field # type: ignore[misc]
@property
def PG_DATABASE_URI(self) -> str:
Expand Down
11 changes: 7 additions & 4 deletions backend/app/core/graph/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
)
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import (
Expand All @@ -22,7 +23,6 @@
from psycopg import AsyncConnection

from app.core.config import settings
from app.core.graph.checkpoint.postgres import PostgresSaver
from app.core.graph.members import (
GraphLeader,
GraphMember,
Expand Down Expand Up @@ -471,8 +471,11 @@ async def generator(
]

try:
async with await AsyncConnection.connect(settings.PG_DATABASE_URI) as conn:
checkpointer = PostgresSaver(async_connection=conn)
async with await AsyncConnection.connect(
settings.PG_DATABASE_URI,
**settings.SQLALCHEMY_CONNECTION_KWARGS,
) as conn:
checkpointer = AsyncPostgresSaver(conn=conn)
if team.workflow == "hierarchical":
teams = convert_hierarchical_team_to_dict(team, members)
team_leader = list(teams.keys())[0]
Expand Down
Loading
Loading