Skip to content

Commit

Permalink
Merge branch 'main' into jenn/api-key
Browse files Browse the repository at this point in the history
  • Loading branch information
jennmueng authored Aug 26, 2024
2 parents b09bb1b + 4e54f8f commit 21aed40
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 1 deletion.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,4 @@ langfuse @ git+https://github.com/jennmueng/langfuse-python.git@9d9350de1e4e84fa
watchdog
stumpy==1.12.0
cryptography==43.0.0
pytest_alembic==0.11.1
2 changes: 2 additions & 0 deletions src/migrations/alembic.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# the 'revision' command, regardless of autogenerate
# revision_environment = false

script_location = /app/src/migrations


# Logging configuration
[loggers]
Expand Down
4 changes: 3 additions & 1 deletion src/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def process_revision_directives(context, revision, directives):
if conf_args.get("process_revision_directives") is None:
conf_args["process_revision_directives"] = process_revision_directives

connectable = get_engine()
connectable = context.config.attributes.get("connection", None)
if connectable is None:
connectable = get_engine()

with connectable.connect() as connection:
context.configure(
Expand Down
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
import os

import johen
import pytest
from flask import Flask
from johen.generators import pydantic, sqlalchemy
from pytest_alembic import Config
from sqlalchemy import text

from celery_app.config import CeleryConfig
Expand All @@ -14,6 +16,8 @@
from seer.dependency_injection import Module, resolve
from seer.inference_models import reset_loading_state

logger = logging.getLogger(__name__)


@pytest.fixture
def test_module() -> Module:
Expand All @@ -25,6 +29,29 @@ def configure_environment():
os.environ["LANGFUSE_HOST"] = "" # disable Langfuse logging for tests


@pytest.fixture
def alembic_config():
return Config.from_raw_config({"file": "/app/src/migrations/alembic.ini"})


@pytest.fixture
def alembic_runner(alembic_config: Config, setup_app):
import pytest_alembic

app = resolve(Flask)

with app.app_context():
db.metadata.drop_all(bind=db.engine)
with db.engine.connect() as c:
c.execute(text("""DROP SCHEMA public CASCADE"""))
c.execute(text("""CREATE SCHEMA public;"""))
c.commit()

with pytest_alembic.runner(config=alembic_config, engine=db.engine) as runner:
runner.set_revision("base")
yield runner


@pytest.fixture(autouse=True)
def setup_app(test_module: Module):
with module, configuration_test_module, test_module:
Expand Down
30 changes: 30 additions & 0 deletions tests/test_migrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from pytest_alembic import MigrationContext
from sqlalchemy import insert

from seer.db import DbRunState, Session


def test_run_state_migration(alembic_runner: MigrationContext):
alembic_runner.migrate_up_before("9b8704bd8c4a")

with Session() as session:
for id, json in (
(1, {"a": "\x00"}), # Valid, but pg sql won't decode it well
(2, {}), # Valid, missing keys
(3, {"updated_at": None}), # Valid, key is null
(4, {"last_triggered_at": None}), # Valid, key is null
(5, {"updated_at": "2020-06-15T13:45:30"}), # Valid, stuff
(6, {"last_triggered_at": "2019-04-12T11:15:31"}), # Valid, stuff
):
session.execute(
insert(alembic_runner.table_at_revision("run_state")).values(id=id, value=json)
)
session.commit()

alembic_runner.migrate_up_to("head")

with Session() as session:
for i in range(1, 7):
state = session.query(DbRunState).get(i)
assert state.last_triggered_at
assert state.updated_at

0 comments on commit 21aed40

Please sign in to comment.