70 changes: 70 additions & 0 deletions .claude/commands/make-migration.md
# Create Migration

Generate a new migration file based on changes to `sql/schema.sql`.

## Arguments

- `$ARGUMENTS` - The migration name (e.g., "add_user_table", "update_claim_indexes")

## Workflow

1. **Read the current schema**: Read `sql/schema.sql` to understand the current desired state.

2. **Read existing migrations**: Read all files in `src/postgres/migrations/` to understand what's already been migrated.

3. **Determine the changes**: Compare the schema.sql against what the migrations would produce. Identify:
- New tables, columns, indexes, or constraints to add
- Modified functions or triggers
- Any DROP statements needed (be careful with these)

4. **Generate the migration SQL**: Create SQL that transforms the database from the current migrated state to the new schema.sql state.
- For new tables/indexes: Use `CREATE TABLE IF NOT EXISTS`, `CREATE INDEX IF NOT EXISTS`
- For function updates: Use `CREATE OR REPLACE FUNCTION`
- For existing queues that need new indexes: Include a `DO $$ ... END $$` block that applies changes to existing queue tables

5. **Create the migration file**: Generate a timestamped migration file:
- Filename format: `YYYYMMDDHHMMSS_<name>.sql`
- Place in: `src/postgres/migrations/`
- Use current UTC time for the timestamp

6. **Run validation**: Execute `./scripts/validate-schema` to verify the migration produces the correct schema.

## Example

If the user has added a new index to `ensure_queue_tables` in schema.sql:

```sql
-- New migration: 20260115143022_add_new_index.sql

-- Update ensure_queue_tables to include the new index for future queues
create or replace function durable.ensure_queue_tables (p_queue_name text)
returns void
language plpgsql
as $$
begin
-- ... (full function with new index)
end;
$$;

-- Apply the new index to existing queues
do $$
declare
v_queue text;
begin
for v_queue in select queue_name from durable.queues loop
execute format(
'create index if not exists %I on durable.%I (...)',
('t_' || v_queue) || '_new_idx',
't_' || v_queue
);
end loop;
end;
$$;
```

## Important Notes

- Always use `IF NOT EXISTS` for idempotent migrations
- For function changes, the full function must be included (not just the diff)
- The `DO $$ ... END $$` block for existing queues should NOT be in schema.sql (it's migration-only logic)
- Run validation after creating the migration to ensure schema.sql matches the migrations
42 changes: 42 additions & 0 deletions .claude/commands/validate-schema.md
# Validate Schema

Run the schema validation script to verify that `sql/schema.sql` matches the result of applying all migrations.

## How It Works

The validation script (`scripts/validate-schema`) uses testcontainers to:

1. Start two PostgreSQL 14 containers
2. Apply `sql/schema.sql` directly to container A
3. Apply all migrations in `src/postgres/migrations/` to container B
4. Dump both schemas using `pg_dump --schema-only --schema=durable`
5. Compare the normalized dumps
6. Report pass/fail with a diff on failure

## Running Validation

```bash
./scripts/validate-schema
```

## Requirements

- Docker must be running
- `uv` must be installed (the script uses inline dependencies)

## When to Run

- After creating a new migration with `/make-migration`
- Before committing schema changes
- CI runs this automatically on pull requests

## Troubleshooting

If validation fails, the output will show a unified diff between:
- `schema.sql` - What the schema file defines
- `migrations` - What applying all migrations produces

Common causes of failure:
- Forgot to update schema.sql after adding a migration
- Migration has different SQL than what's in schema.sql
- Migration includes logic that shouldn't be in schema.sql (like `DO $$ ... END $$` blocks for existing queues)
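The failure report is a unified diff of the two normalized dumps; the comparison step can be sketched as:

```python
import difflib


def schema_diff(schema_dump: str, migrations_dump: str) -> str:
    """Return a unified diff of two normalized schema dumps, or '' when they match."""
    return "\n".join(
        difflib.unified_diff(
            schema_dump.splitlines(),
            migrations_dump.splitlines(),
            fromfile="schema.sql",
            tofile="migrations",
            lineterm="",
        )
    )
```

Lines prefixed with `-` exist only in `schema.sql`; lines prefixed with `+` exist only in the migrated database.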
6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
@@ -32,6 +32,12 @@ jobs:
steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v4

- name: Validate schema matches migrations
run: ./scripts/validate-schema

- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
with:
223 changes: 223 additions & 0 deletions scripts/validate-schema
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.11"
# dependencies = ["psycopg>=3.2.0", "testcontainers>=4.0.0"]
# ///
"""
Validates that sql/schema.sql matches the result of applying all migrations.

This script:
1. Starts two PostgreSQL 14 containers (to match TensorZero's minimum supported version)
2. Container A: Applies sql/schema.sql directly
3. Container B: Applies all migrations in src/postgres/migrations/ in timestamp order
4. Dumps both schemas using pg_dump --schema-only --schema=durable
5. Compares the dumps (excluding the _sqlx_migrations table)
6. Reports pass/fail with diff on failure
"""

import difflib
import re
import subprocess
import sys
from pathlib import Path

import psycopg
from testcontainers.postgres import PostgresContainer


def get_project_root() -> Path:
"""Find the project root by looking for Cargo.toml."""
current = Path(__file__).resolve().parent
while current != current.parent:
if (current / "Cargo.toml").exists():
return current
current = current.parent
raise RuntimeError("Could not find project root (no Cargo.toml found)")


def get_migrations(project_root: Path) -> list[Path]:
"""Get all migration files sorted by timestamp."""
migrations_dir = project_root / "src" / "postgres" / "migrations"
if not migrations_dir.exists():
raise RuntimeError(f"Migrations directory not found: {migrations_dir}")

migrations = sorted(migrations_dir.glob("*.sql"))
if not migrations:
raise RuntimeError(f"No migration files found in {migrations_dir}")

return migrations


def get_psycopg_url(container: PostgresContainer) -> str:
"""Get a psycopg-compatible connection URL from a testcontainer."""
    # testcontainers' get_connection_url() returns a SQLAlchemy-style URL, so
    # build a plain libpq-style URL from the container's connection details instead
host = container.get_container_host_ip()
port = container.get_exposed_port(5432)
return f"postgresql://{container.username}:{container.password}@{host}:{port}/{container.dbname}"


def apply_schema(conn: psycopg.Connection, schema_path: Path) -> None:
"""Apply the schema.sql file to a database."""
sql = schema_path.read_text()
conn.execute(sql)
conn.commit()


def apply_migrations(conn: psycopg.Connection, migrations: list[Path]) -> None:
"""Apply all migrations to a database in order."""
for migration in migrations:
sql = migration.read_text()
conn.execute(sql)
conn.commit()


def dump_schema(container: PostgresContainer) -> str:
"""Dump the durable schema from a database using pg_dump."""
result = subprocess.run(
[
"docker",
"exec",
container.get_wrapped_container().id,
"pg_dump",
"-U",
container.username,
"-d",
container.dbname,
"--schema-only",
"--schema=durable",
"--no-owner",
"--no-privileges",
"--no-comments",
],
capture_output=True,
text=True,
check=True,
)
return result.stdout


def normalize_dump(dump: str) -> str:
r"""Normalize a pg_dump output for comparison.

Removes:
- SET statements and other session configuration
- Comments
- Empty lines
- The _sqlx_migrations table and related objects
    - pg_dump session markers (\restrict, \unrestrict)
"""
lines = dump.split("\n")
normalized = []
skip_until_semicolon = False

for line in lines:
# Skip SET statements
if line.startswith("SET "):
continue

# Skip SELECT statements (like pg_catalog.set_config)
if line.startswith("SELECT "):
continue

# Skip comments
if line.startswith("--"):
continue

# Skip empty lines
if not line.strip():
continue

# Skip pg_dump session markers
if line.startswith("\\restrict") or line.startswith("\\unrestrict"):
continue

        # Skip the _sqlx_migrations table and related objects. A line that both
        # names the table and ends its statement (e.g. a one-line ALTER TABLE ...
        # ADD CONSTRAINT ...;) must not leave skipping enabled for later lines.
        if "_sqlx_migrations" in line:
            skip_until_semicolon = ";" not in line
            continue

if skip_until_semicolon:
if ";" in line:
skip_until_semicolon = False
continue

normalized.append(line)

return "\n".join(normalized)


def main() -> int:
project_root = get_project_root()
schema_path = project_root / "sql" / "schema.sql"
migrations = get_migrations(project_root)

print(f"Project root: {project_root}")
print(f"Schema file: {schema_path}")
print(f"Found {len(migrations)} migrations:")
for m in migrations:
print(f" - {m.name}")
print()

if not schema_path.exists():
print(f"ERROR: Schema file not found: {schema_path}", file=sys.stderr)
return 1

    # Use PostgreSQL 14, the minimum supported version (see module docstring)
postgres_image = "postgres:14-alpine"

print("Starting PostgreSQL containers...")

with (
PostgresContainer(postgres_image) as schema_container,
PostgresContainer(postgres_image) as migrations_container,
):
print("Containers started.")
print()

# Apply schema.sql to container A
print("Applying schema.sql to container A...")
with psycopg.connect(get_psycopg_url(schema_container)) as conn:
apply_schema(conn, schema_path)
print("Schema applied.")

# Apply migrations to container B
print("Applying migrations to container B...")
with psycopg.connect(get_psycopg_url(migrations_container)) as conn:
apply_migrations(conn, migrations)
print("Migrations applied.")
print()

# Dump both schemas
print("Dumping schemas...")
schema_dump = dump_schema(schema_container)
migrations_dump = dump_schema(migrations_container)

# Normalize for comparison
schema_normalized = normalize_dump(schema_dump)
migrations_normalized = normalize_dump(migrations_dump)

# Compare
if schema_normalized == migrations_normalized:
print("SUCCESS: schema.sql matches migrations")
return 0

print("FAILURE: schema.sql does not match migrations")
print()
print("Diff (schema.sql vs migrations):")
print("-" * 60)

diff = difflib.unified_diff(
schema_normalized.split("\n"),
migrations_normalized.split("\n"),
fromfile="schema.sql",
tofile="migrations",
lineterm="",
)
for line in diff:
print(line)

return 1


if __name__ == "__main__":
sys.exit(main())