Skip to content

Commit

Permalink
Migrate explicit enumerations to table relations in PostgreSQL and MySQL
Browse files Browse the repository at this point in the history
  • Loading branch information
hunyadi committed Mar 6, 2024
1 parent cb07629 commit 71bcd39
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 20 deletions.
29 changes: 27 additions & 2 deletions pysqlsync/dialect/mysql/mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,38 @@
from typing import Optional

from pysqlsync.formation.mutation import Mutator
from pysqlsync.formation.object_types import Column, StatementList, Table, join_or_none
from pysqlsync.model.data_types import quote
from pysqlsync.formation.object_types import (
Column,
StatementList,
Table,
deleted,
join_or_none,
)
from pysqlsync.model.data_types import SqlEnumType, quote
from pysqlsync.model.id_types import LocalId

from .object_types import MySQLColumn, MySQLTable


class MySQLMutator(Mutator):
def migrate_column_stmt(
self, source_table: Table, source: Column, target_table: Table, target: Column
) -> Optional[str]:
statements: list[str] = []
ref = target_table.get_constraint(target.name)
if isinstance(source.data_type, SqlEnumType):
enum_values = ", ".join(f"({quote(v)})" for v in source.data_type.values)
statements.append(
f'INSERT INTO {ref.table} ("value") VALUES {enum_values}\n'
'ON DUPLICATE KEY UPDATE "value" = "value";'
)
statements.append(
f"UPDATE {source_table.name} data_table\n"
f'JOIN {ref.table} enum_table ON data_table.{LocalId(deleted(source.name.id))} = enum_table."value"\n'
f'SET data_table.{target.name} = enum_table."id";'
)
return "\n".join(statements)

def mutate_table_stmt(
self, source_table: Table, target_table: Table
) -> Optional[str]:
Expand Down
10 changes: 5 additions & 5 deletions pysqlsync/formation/mutation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Optional

from ..model.data_types import SqlUserDefinedType, constant, quote
from ..model.data_types import SqlEnumType, SqlUserDefinedType, constant, quote
from ..model.id_types import SupportsName
from .object_types import (
Catalog,
Expand Down Expand Up @@ -101,19 +101,19 @@ def is_column_migrated(self, source: Column, target: Column) -> bool:
if source == target or source.data_type == target.data_type:
return False

is_user_source = isinstance(source.data_type, SqlUserDefinedType)
is_user_target = isinstance(target.data_type, SqlUserDefinedType)
is_user_source = isinstance(source.data_type, (SqlEnumType, SqlUserDefinedType))
is_user_target = isinstance(target.data_type, (SqlEnumType, SqlUserDefinedType))

if is_user_source and is_user_target:
raise ColumnFormationError(
"operation not permitted; cannot migrate data between two different user-defined types",
"operation not permitted; cannot migrate data between two different user-defined or enumeration types",
source.name,
)
elif is_user_source and not is_user_target:
return True
elif not is_user_source and is_user_target:
raise ColumnFormationError(
"operation not permitted; cannot migrate data to a user-defined type",
"operation not permitted; cannot migrate data to a user-defined or enumeration type",
source.name,
)
return False
Expand Down
3 changes: 2 additions & 1 deletion pysqlsync/formation/py_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def enum_value_type(enum_type: type[enum.Enum]) -> type:
raise TypeError(
f"inconsistent enumeration value types for type {enum_type.__name__}: {value_types}"
)
return value_types.pop()
value_type = value_types.pop()
return value_type if value_type is not str else ENUM_LABEL_TYPE


def is_extensible_enum_type(typ: TypeLike, cls: type) -> bool:
Expand Down
32 changes: 20 additions & 12 deletions tests/test_synchronize.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,19 @@ async def test_identity_dataclass(self) -> None:

await conn.drop_objects()

async def enum_migration(self, options: GeneratorOptions) -> None:
async with self.engine.create_connection(self.parameters, options) as conn:
explorer = self.engine.create_explorer(conn)
await explorer.synchronize(module=tables)

options = GeneratorOptions(
enum_mode=EnumMode.RELATION,
namespaces={tables: "sample"},
)
async with self.engine.create_connection(self.parameters, options) as conn:
explorer = self.engine.create_explorer(conn)
await explorer.synchronize(module=tables)


@unittest.skipUnless(has_env_var("ORACLE"), "Oracle tests are disabled")
class TestOracleSynchronize(OracleBase, TestSynchronize):
Expand All @@ -290,17 +303,7 @@ async def test_enum_migration(self) -> None:
enum_mode=EnumMode.TYPE,
namespaces={tables: "sample"},
)
async with self.engine.create_connection(self.parameters, options) as conn:
explorer = self.engine.create_explorer(conn)
await explorer.synchronize(module=tables)

options = GeneratorOptions(
enum_mode=EnumMode.RELATION,
namespaces={tables: "sample"},
)
async with self.engine.create_connection(self.parameters, options) as conn:
explorer = self.engine.create_explorer(conn)
await explorer.synchronize(module=tables)
await self.enum_migration(options)


@unittest.skipUnless(has_env_var("MSSQL"), "Microsoft SQL tests are disabled")
Expand All @@ -310,7 +313,12 @@ class TestMSSQLSynchronize(MSSQLBase, TestSynchronize):

@unittest.skipUnless(has_env_var("MYSQL"), "MySQL tests are disabled")
class TestMySQLSynchronize(MySQLBase, TestSynchronize):
pass
async def test_enum_migration(self) -> None:
options = GeneratorOptions(
enum_mode=EnumMode.INLINE,
namespaces={tables: "sample"},
)
await self.enum_migration(options)


del TestSynchronize
Expand Down

0 comments on commit 71bcd39

Please sign in to comment.