From 4e54f8f685eaee6d4aad3f7470b7066a6514ef99 Mon Sep 17 00:00:00 2001 From: Zach Collins Date: Mon, 26 Aug 2024 16:42:26 -0700 Subject: [PATCH] Migration Tests (#1066) Adds the capability to test migrations directly. Setup was a huge pain in the butt thanks to the fact that pytest alembic doesn't work out of the box and requires banging it a bunch until it works well. --- requirements.txt | 1 + src/migrations/alembic.ini | 2 ++ src/migrations/env.py | 4 +++- tests/conftest.py | 27 +++++++++++++++++++++++++++ tests/test_migrations.py | 30 ++++++++++++++++++++++++++++++ 5 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 tests/test_migrations.py diff --git a/requirements.txt b/requirements.txt index 3ee42185c..f46d0d729 100644 --- a/requirements.txt +++ b/requirements.txt @@ -104,3 +104,4 @@ anthropic[vertex]==0.31.2 langfuse @ git+https://github.com/jennmueng/langfuse-python.git@9d9350de1e4e84fa548fe84f82c1b826be17956e watchdog stumpy==1.12.0 +pytest_alembic==0.11.1 diff --git a/src/migrations/alembic.ini b/src/migrations/alembic.ini index 502a22590..51f0e6f6a 100644 --- a/src/migrations/alembic.ini +++ b/src/migrations/alembic.ini @@ -8,6 +8,8 @@ # the 'revision' command, regardless of autogenerate # revision_environment = false +script_location = /app/src/migrations + # Logging configuration [loggers] diff --git a/src/migrations/env.py b/src/migrations/env.py index b4a7a4a36..b94dc32bb 100644 --- a/src/migrations/env.py +++ b/src/migrations/env.py @@ -93,7 +93,9 @@ def process_revision_directives(context, revision, directives): if conf_args.get("process_revision_directives") is None: conf_args["process_revision_directives"] = process_revision_directives - connectable = get_engine() + connectable = context.config.attributes.get("connection", None) + if connectable is None: + connectable = get_engine() with connectable.connect() as connection: context.configure( diff --git a/tests/conftest.py b/tests/conftest.py index 276cfaa1d..2a259648d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,11 @@ +import logging import os import johen import pytest from flask import Flask from johen.generators import pydantic, sqlalchemy +from pytest_alembic import Config from sqlalchemy import text from celery_app.config import CeleryConfig @@ -14,6 +16,8 @@ from seer.dependency_injection import Module, resolve from seer.inference_models import reset_loading_state +logger = logging.getLogger(__name__) + @pytest.fixture def test_module() -> Module: @@ -25,6 +29,29 @@ def configure_environment(): os.environ["LANGFUSE_HOST"] = "" # disable Langfuse logging for tests +@pytest.fixture +def alembic_config(): + return Config.from_raw_config({"file": "/app/src/migrations/alembic.ini"}) + + +@pytest.fixture +def alembic_runner(alembic_config: Config, setup_app): + import pytest_alembic + + app = resolve(Flask) + + with app.app_context(): + db.metadata.drop_all(bind=db.engine) + with db.engine.connect() as c: + c.execute(text("""DROP SCHEMA public CASCADE""")) + c.execute(text("""CREATE SCHEMA public;""")) + c.commit() + + with pytest_alembic.runner(config=alembic_config, engine=db.engine) as runner: + runner.set_revision("base") + yield runner + + @pytest.fixture(autouse=True) def setup_app(test_module: Module): with module, configuration_test_module, test_module: diff --git a/tests/test_migrations.py b/tests/test_migrations.py new file mode 100644 index 000000000..a8208a37d --- /dev/null +++ b/tests/test_migrations.py @@ -0,0 +1,30 @@ +from pytest_alembic import MigrationContext +from sqlalchemy import insert + +from seer.db import DbRunState, Session + + +def test_run_state_migration(alembic_runner: MigrationContext): + alembic_runner.migrate_up_before("9b8704bd8c4a") + + with Session() as session: + for id, json in ( + (1, {"a": "\x00"}), # Valid, but pg sql won't decode it well + (2, {}), # Valid, missing keys + (3, {"updated_at": None}), # Valid, key is null + (4, {"last_triggered_at": None}), # Valid, key is null + (5, {"updated_at": "2020-06-15T13:45:30"}), # Valid, stuff + (6, {"last_triggered_at": "2019-04-12T11:15:31"}), # Valid, stuff + ): + session.execute( + insert(alembic_runner.table_at_revision("run_state")).values(id=id, value=json) + ) + session.commit() + + alembic_runner.migrate_up_to("head") + + with Session() as session: + for i in range(1, 7): + state = session.query(DbRunState).get(i) + assert state.last_triggered_at + assert state.updated_at