-
Notifications
You must be signed in to change notification settings - Fork 694
feat: add langgraph 2.0 checkpoints migration and fix lint issues #921
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -96,8 +96,125 @@ async def update_table_wrapper(): | |
| await update_table_schema(conn, dialect, model_cls) | ||
|
|
||
| await update_table_wrapper() | ||
|
|
||
| # Migrate checkpoints tables | ||
| await migrate_checkpoints_table(conn) | ||
| except Exception as e: | ||
| logger.error(f"Error updating database schema: {str(e)}") | ||
| raise | ||
|
|
||
| logger.info("Database schema updated successfully") | ||
|
|
||
|
|
||
async def migrate_checkpoints_table(conn) -> None:
    """Migrate the checkpoint tables to the layout used with langgraph 2.0.

    For each of ``checkpoints``, ``checkpoint_blobs`` and ``checkpoint_writes``
    that exists in the database, the migration:

    1. adds the ``checkpoint_ns`` column if it is missing;
    2. deletes rows that would collide under the new primary key;
    3. drops the columns ShallowPostgresSaver does not use;
    4. rebuilds the primary key when it differs from the expected one.

    Step 2 has to run before step 3: the de-duplication statements filter on
    ``checkpoints.checkpoint_id`` and ``checkpoint_blobs.version``, which are
    exactly the columns that step 3 removes. Once those columns are gone the
    de-duplication is skipped, so re-running the migration is a no-op.

    Args:
        conn: Async database connection; it is used through
            ``await conn.run_sync(...)`` and ``await conn.execute(...)``.

    Raises:
        Any error raised by the database while executing the statements is
        propagated unchanged; nothing is caught here.
    """
    # Iteration order matters: the checkpoint_blobs cleanup joins against the
    # checkpoints table and expects it to be de-duplicated already.
    expected_pks = {
        "checkpoints": ("thread_id", "checkpoint_ns"),
        "checkpoint_blobs": ("thread_id", "checkpoint_ns", "channel"),
        "checkpoint_writes": (
            "thread_id",
            "checkpoint_ns",
            "checkpoint_id",
            "task_id",
            "idx",
        ),
    }
    # Columns ShallowPostgresSaver doesn't use, per table.
    legacy_columns = {
        "checkpoints": ("checkpoint_id", "parent_checkpoint_id"),
        "checkpoint_blobs": ("version",),
    }

    def _get_tables(connection):
        return inspect(connection).get_table_names()

    def _get_columns(connection, table_name):
        return {col["name"] for col in inspect(connection).get_columns(table_name)}

    def _get_pk(connection, table_name):
        return inspect(connection).get_pk_constraint(table_name)

    existing_tables = await conn.run_sync(_get_tables)

    for table, pk_columns in expected_pks.items():
        if table not in existing_tables:
            continue

        # 1. Add checkpoint_ns column
        await conn.execute(
            text(
                f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS checkpoint_ns TEXT DEFAULT ''"
            )
        )

        # 2. Remove rows that would violate the new primary key. This must
        #    happen while the legacy columns still exist, because the DELETE
        #    statements reference them.
        columns = await conn.run_sync(_get_columns, table)

        if table == "checkpoints" and "checkpoint_id" in columns:
            # Keep only the latest checkpoint for each (thread_id, checkpoint_ns)
            # based on checkpoint_id (time-ordered)
            await conn.execute(
                text("""
                    DELETE FROM checkpoints
                    WHERE (thread_id, checkpoint_ns, checkpoint_id) NOT IN (
                        SELECT thread_id, checkpoint_ns, MAX(checkpoint_id)
                        FROM checkpoints
                        GROUP BY thread_id, checkpoint_ns
                    )
                """)
            )
        elif table == "checkpoint_blobs" and "version" in columns:
            # Keep only blobs that are referenced by the remaining checkpoints.
            # The relationship is:
            # checkpoints.checkpoint -> 'channel_versions' ->> blob.channel = blob.version
            await conn.execute(
                text("""
                    DELETE FROM checkpoint_blobs cb
                    WHERE NOT EXISTS (
                        SELECT 1
                        FROM checkpoints cp
                        WHERE cp.thread_id = cb.thread_id
                        AND cp.checkpoint_ns = cb.checkpoint_ns
                        AND (cp.checkpoint -> 'channel_versions' ->> cb.channel) = cb.version
                    )
                """)
            )

        # 3. Drop columns that ShallowPostgresSaver doesn't use
        for column in legacy_columns.get(table, ()):
            await conn.execute(
                text(f"ALTER TABLE {table} DROP COLUMN IF EXISTS {column}")
            )

        # 4. Update the primary key. It is inspected only now because
        #    PostgreSQL drops a constraint automatically when one of its
        #    columns is dropped, so step 3 may already have removed the old key.
        pk = await conn.run_sync(_get_pk, table)
        current_cols = set(pk.get("constrained_columns") or [])
        expected_cols = set(pk_columns)
        if current_cols == expected_cols:
            continue

        logger.info(f"Migrating {table} PK from {current_cols} to {expected_cols}")

        pk_name = pk.get("name")
        if pk_name:
            # Quote the reflected name so mixed-case constraint names work.
            await conn.execute(
                text(f'ALTER TABLE {table} DROP CONSTRAINT "{pk_name}"')
            )

        await conn.execute(
            text(f"ALTER TABLE {table} ADD PRIMARY KEY ({', '.join(pk_columns)})")
        )
|
Comment on lines
+119
to
+220
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new `migrate_checkpoints_table` function lacks test coverage. Given that this migration performs destructive operations (dropping columns, deleting data), it's critical to have tests that verify its behavior. Consider adding tests similar to the existing model tests in `tests/models/test_skills_schema.py`.