From ffccbc7f2db1b472c01c3db293d41598d5b29f78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciek=20Bry=C5=84ski?= Date: Thu, 31 Oct 2024 17:29:40 -0400 Subject: [PATCH] Add ability to configure alembic_version table in DialectImpl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added a new hook to the :class:`.DefaultImpl` :meth:`.DefaultImpl.version_table_impl`. This allows third party dialects to define the exact structure of the alembic_version table, to include use cases where the table requires special directives and/or additional columns so that it may function correctly on a particular backend. This is not intended as a user-expansion hook, only a dialect implementation hook to produce a working alembic_version table. Pull request courtesy Maciek Bryński. This will be 1.14 so this also version bumps Fixes: #1560 Closes: #1563 Pull-request: https://github.com/sqlalchemy/alembic/pull/1563 Pull-request-sha: e70fdc8f4e405cabf5099c2100763d7b24da3be8 Change-Id: I5e565dff60a979526608d2a1c0c620fbca269a3f --- alembic/ddl/impl.py | 41 ++++++++++++++++++++++++++++-- alembic/runtime/migration.py | 29 +++++++++------------ docs/build/changelog.rst | 2 +- docs/build/unreleased/1560.rst | 12 +++++++++ tests/test_version_table.py | 46 ++++++++++++++++++++++++++++++++++ 5 files changed, 110 insertions(+), 20 deletions(-) create mode 100644 docs/build/unreleased/1560.rst diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 25746889..2609a62d 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -21,7 +21,12 @@ from typing import Union from sqlalchemy import cast +from sqlalchemy import Column +from sqlalchemy import MetaData +from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import schema +from sqlalchemy import String +from sqlalchemy import Table from sqlalchemy import text from . import _autogen @@ -43,11 +48,9 @@ from sqlalchemy.sql import Executable from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import quoted_name - from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.schema import ForeignKeyConstraint from sqlalchemy.sql.schema import Index - from sqlalchemy.sql.schema import Table from sqlalchemy.sql.schema import UniqueConstraint from sqlalchemy.sql.selectable import TableClause from sqlalchemy.sql.type_api import TypeEngine @@ -136,6 +139,40 @@ def static_output(self, text: str) -> None: self.output_buffer.write(text + "\n\n") self.output_buffer.flush() + def version_table_impl( + self, + *, + version_table: str, + version_table_schema: Optional[str], + version_table_pk: bool, + **kw: Any, + ) -> Table: + """Generate a :class:`.Table` object which will be used as the + structure for the Alembic version table. + + Third party dialects may override this hook to provide an alternate + structure for this :class:`.Table`; requirements are only that it + be named based on the ``version_table`` parameter and contains + at least a single string-holding column named ``version_num``. + + .. versionadded:: 1.14 + + """ + vt = Table( + version_table, + MetaData(), + Column("version_num", String(32), nullable=False), + schema=version_table_schema, + ) + if version_table_pk: + vt.append_constraint( + PrimaryKeyConstraint( + "version_num", name=f"{version_table}_pkc" + ) + ) + + return vt + def requires_recreate_in_batch( self, batch_op: BatchOperationsImpl ) -> bool: diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 6cfe5e23..28f01c3b 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -24,10 +24,6 @@ from sqlalchemy import Column from sqlalchemy import literal_column -from sqlalchemy import MetaData -from sqlalchemy import PrimaryKeyConstraint -from sqlalchemy import String -from sqlalchemy import Table from sqlalchemy.engine import Engine from sqlalchemy.engine import url as sqla_url from sqlalchemy.engine.strategies import MockEngineStrategy @@ -36,6 +32,7 @@ from .. import util from ..util import sqla_compat from ..util.compat import EncodedIO +from ..util.sqla_compat import _select if TYPE_CHECKING: from sqlalchemy.engine import Dialect @@ -190,18 +187,6 @@ def __init__( self.version_table_schema = version_table_schema = opts.get( "version_table_schema", None ) - self._version = Table( - version_table, - MetaData(), - Column("version_num", String(32), nullable=False), - schema=version_table_schema, - ) - if opts.get("version_table_pk", True): - self._version.append_constraint( - PrimaryKeyConstraint( - "version_num", name="%s_pkc" % version_table - ) - ) self._start_from_rev: Optional[str] = opts.get("starting_rev") self.impl = ddl.DefaultImpl.get_by_dialect(dialect)( @@ -212,6 +197,13 @@ def __init__( self.output_buffer, opts, ) + + self._version = self.impl.version_table_impl( + version_table=version_table, + version_table_schema=version_table_schema, + version_table_pk=opts.get("version_table_pk", True), + ) + log.info("Context impl %s.", self.impl.__class__.__name__) if self.as_sql: log.info("Generating static SQL") @@ -540,7 +532,10 @@ def get_current_heads(self) -> Tuple[str, ...]: return () assert self.connection is not None return tuple( - row[0] for row in self.connection.execute(self._version.select()) + row[0] + for row in self.connection.execute( + _select(self._version.c.version_num) + ) ) def _ensure_version_table(self, purge: bool = False) -> None: diff --git a/docs/build/changelog.rst b/docs/build/changelog.rst index 51a8a5e2..2d33a186 100644 --- a/docs/build/changelog.rst +++ b/docs/build/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== .. changelog:: - :version: 1.13.4 + :version: 1.14.0 :include_notes_from: unreleased .. changelog:: diff --git a/docs/build/unreleased/1560.rst b/docs/build/unreleased/1560.rst new file mode 100644 index 00000000..e808b307 --- /dev/null +++ b/docs/build/unreleased/1560.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: usecase, runtime + :tickets: 1560 + + Added a new hook to the :class:`.DefaultImpl` + :meth:`.DefaultImpl.version_table_impl`. This allows third party dialects + to define the exact structure of the alembic_version table, to include use + cases where the table requires special directives and/or additional columns + so that it may function correctly on a particular backend. This is not + intended as a user-expansion hook, only a dialect implementation hook to + produce a working alembic_version table. Pull request courtesy Maciek + Bryński. diff --git a/tests/test_version_table.py b/tests/test_version_table.py index 5ad3c21d..ca569366 100644 --- a/tests/test_version_table.py +++ b/tests/test_version_table.py @@ -1,10 +1,15 @@ from sqlalchemy import Column from sqlalchemy import inspect +from sqlalchemy import Integer from sqlalchemy import MetaData +from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy.dialects import registry +from sqlalchemy.engine import default from alembic import migration +from alembic.ddl import impl from alembic.testing import assert_raises from alembic.testing import assert_raises_message from alembic.testing import config @@ -373,3 +378,44 @@ def test_delete_multi_match_no_sane_rowcount(self): self.connection.dialect, "supports_sane_rowcount", False ): self.updater.update_to_step(_down("a", None, True)) + + +registry.register("custom_version", __name__, "CustomVersionDialect") + + +class CustomVersionDialect(default.DefaultDialect): + name = "custom_version" + + +class CustomVersionTableImpl(impl.DefaultImpl): + __dialect__ = "custom_version" + + def version_table_impl( + self, + *, + version_table, + version_table_schema, + version_table_pk, + **kw, + ): + vt = Table( + version_table, + MetaData(), + Column("id", Integer, autoincrement=True), + Column("version_num", String(32), nullable=False), + schema=version_table_schema, + ) + if version_table_pk: + vt.append_constraint( + PrimaryKeyConstraint("id", name=f"{version_table}_pkc") + ) + return vt + + +class CustomVersionTableTest(TestMigrationContext): + + def test_custom_version_table(self): + context = migration.MigrationContext.configure( + dialect_name="custom_version", + ) + eq_(len(context._version.columns), 2)