From 8634fc65a6e5d20c26cab561ae47409ae1ddcd1a Mon Sep 17 00:00:00 2001 From: Levente Hunyadi Date: Fri, 12 Jan 2024 00:04:59 +0100 Subject: [PATCH] Improve unit tests --- .github/actions/test/action.yml | 2 + pysqlsync/dialect/oracle/generator.py | 7 +- tests/test_generator.py | 344 ++++++++++++++++++-------- 3 files changed, 245 insertions(+), 108 deletions(-) diff --git a/.github/actions/test/action.yml b/.github/actions/test/action.yml index 049619f..8d13e35 100644 --- a/.github/actions/test/action.yml +++ b/.github/actions/test/action.yml @@ -40,8 +40,10 @@ runs: - name: Fetch unit and/or integration tests uses: actions/checkout@v4 with: + # include `.github` to allow post-run for this composite action sparse-checkout: | . + .github tests - name: Install test requirements diff --git a/pysqlsync/dialect/oracle/generator.py b/pysqlsync/dialect/oracle/generator.py index 96a7e00..d87a174 100644 --- a/pysqlsync/dialect/oracle/generator.py +++ b/pysqlsync/dialect/oracle/generator.py @@ -29,6 +29,9 @@ ) from .object_types import OracleObjectFactory +MIN_DATETIME = datetime.datetime.min.replace(tzinfo=datetime.timezone.utc) +MIN_DATE = datetime.date.min + class OracleGenerator(BaseGenerator): "Generator for Oracle." @@ -100,9 +103,9 @@ def get_field_extractor( elif field_type is datetime.time: return ( lambda obj: datetime.datetime.combine( - datetime.date.min, getattr(obj, field_name) + MIN_DATE, getattr(obj, field_name) ) - - datetime.datetime.min + - MIN_DATETIME ) return super().get_field_extractor(column, field_name, field_type) diff --git a/tests/test_generator.py b/tests/test_generator.py index 2696632..691e191 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,30 +1,58 @@ import ipaddress import unittest -from datetime import date, datetime, time, timezone +from datetime import date, datetime, time, timedelta, timezone from strong_typing.inspection import DataclassInstance -from pysqlsync.base import BaseGenerator, GeneratorOptions -from pysqlsync.factory import get_dialect +from pysqlsync.base import GeneratorOptions from tests import tables +from tests.params import ( + MSSQLBase, + MySQLBase, + OracleBase, + PostgreSQLBase, + TestEngineBase, + configure, + has_env_var, +) + +if __name__ == "__main__": + configure() -def get_generator(dialect: str) -> BaseGenerator: - return get_dialect(dialect).create_generator( - GeneratorOptions(namespaces={tables: None}) - ) +class TestGenerator(TestEngineBase, unittest.TestCase): + @property + def options(self) -> GeneratorOptions: + return GeneratorOptions(namespaces={tables: None}) + def assertMatchSQLCreate( + self, dialect: str, table: type[DataclassInstance], sql: str + ) -> None: + if dialect != self.engine.name: + return + + statement = ( + self.engine.create_generator(self.options).create(tables=[table]) or "" + ) + self.assertMultiLineEqual(statement, sql) -def get_create_stmt(table: type[DataclassInstance], dialect: str) -> str: - statement = get_generator(dialect=dialect).create(tables=[table]) - return statement or "" + def assertMatchSQLUpsert( + self, dialect: str, table: type[DataclassInstance], sql: str + ) -> None: + if dialect != self.engine.name: + return + generator = self.engine.create_generator(self.options) + generator.create(tables=[table]) + statement = generator.get_dataclass_upsert_stmt(table) + self.assertMultiLineEqual(statement, sql) -class TestGenerator(unittest.TestCase): def test_create_boolean_table(self) -> None: self.maxDiff = None - self.assertMultiLineEqual( - get_create_stmt(tables.BooleanTable, dialect="postgresql"), + + self.assertMatchSQLCreate( + "postgresql", + tables.BooleanTable, 'CREATE TABLE "BooleanTable" (\n' '"id" bigint NOT NULL,\n' '"boolean" boolean NOT NULL,\n' @@ -32,8 +60,9 @@ def test_create_boolean_table(self) -> None: 'CONSTRAINT "pk_BooleanTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.BooleanTable, dialect="mssql"), + self.assertMatchSQLCreate( + "mssql", + tables.BooleanTable, 'CREATE TABLE "BooleanTable" (\n' '"id" bigint NOT NULL,\n' '"boolean" bit NOT NULL,\n' @@ -41,8 +70,9 @@ def test_create_boolean_table(self) -> None: 'CONSTRAINT "pk_BooleanTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.BooleanTable, dialect="mysql"), + self.assertMatchSQLCreate( + "mysql", + tables.BooleanTable, 'CREATE TABLE "BooleanTable" (\n' '"id" bigint NOT NULL,\n' '"boolean" tinyint NOT NULL,\n' @@ -55,8 +85,9 @@ def test_create_numeric_table(self) -> None: self.maxDiff = None for dialect in ["postgresql", "mssql", "mysql"]: with self.subTest(dialect=dialect): - self.assertMultiLineEqual( - get_create_stmt(tables.NumericTable, dialect=dialect), + self.assertMatchSQLCreate( + dialect, + tables.NumericTable, 'CREATE TABLE "NumericTable" (\n' '"id" bigint NOT NULL,\n' '"integer_8" smallint NOT NULL,\n' @@ -73,8 +104,9 @@ def test_create_default_numeric_table(self) -> None: self.maxDiff = None for dialect in ["postgresql", "mysql"]: with self.subTest(dialect=dialect): - self.assertMultiLineEqual( - get_create_stmt(tables.DefaultNumericTable, dialect=dialect), + self.assertMatchSQLCreate( + dialect, + tables.DefaultNumericTable, 'CREATE TABLE "DefaultNumericTable" (\n' '"id" bigint NOT NULL,\n' '"integer_8" smallint NOT NULL DEFAULT 127,\n' @@ -85,8 +117,9 @@ def test_create_default_numeric_table(self) -> None: 'CONSTRAINT "pk_DefaultNumericTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.DefaultNumericTable, dialect="mssql"), + self.assertMatchSQLCreate( + "mssql", + tables.DefaultNumericTable, 'CREATE TABLE "DefaultNumericTable" (\n' '"id" bigint NOT NULL,\n' '"integer_8" smallint NOT NULL,\n' @@ -107,8 +140,9 @@ def test_create_fixed_precision_float_table(self) -> None: self.maxDiff = None for dialect in ["postgresql", "mssql", "mysql"]: with self.subTest(dialect=dialect): - self.assertMultiLineEqual( - get_create_stmt(tables.FixedPrecisionFloatTable, dialect=dialect), + self.assertMatchSQLCreate( + dialect, + tables.FixedPrecisionFloatTable, 'CREATE TABLE "FixedPrecisionFloatTable" (\n' '"id" bigint NOT NULL,\n' '"float_32" real NOT NULL,\n' @@ -123,8 +157,9 @@ def test_create_decimal_table(self) -> None: self.maxDiff = None for dialect in ["postgresql", "mssql", "mysql"]: with self.subTest(dialect=dialect): - self.assertMultiLineEqual( - get_create_stmt(tables.DecimalTable, dialect=dialect), + self.assertMatchSQLCreate( + dialect, + tables.DecimalTable, 'CREATE TABLE "DecimalTable" (\n' '"id" bigint NOT NULL,\n' '"decimal_value" decimal NOT NULL,\n' @@ -136,8 +171,9 @@ def test_create_decimal_table(self) -> None: def test_create_string_table(self) -> None: self.maxDiff = None - self.assertMultiLineEqual( - get_create_stmt(tables.StringTable, dialect="postgresql"), + self.assertMatchSQLCreate( + "postgresql", + tables.StringTable, 'CREATE TABLE "StringTable" (\n' '"id" bigint NOT NULL,\n' '"arbitrary_length_string" text NOT NULL,\n' @@ -147,8 +183,9 @@ def test_create_string_table(self) -> None: 'CONSTRAINT "pk_StringTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.StringTable, dialect="mssql"), + self.assertMatchSQLCreate( + "mssql", + tables.StringTable, 'CREATE TABLE "StringTable" (\n' '"id" bigint NOT NULL,\n' '"arbitrary_length_string" varchar(max) NOT NULL,\n' @@ -158,8 +195,9 @@ def test_create_string_table(self) -> None: 'CONSTRAINT "pk_StringTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.StringTable, dialect="mysql"), + self.assertMatchSQLCreate( + "mysql", + tables.StringTable, 'CREATE TABLE "StringTable" (\n' '"id" bigint NOT NULL,\n' '"arbitrary_length_string" mediumtext NOT NULL,\n' @@ -172,8 +210,9 @@ def test_create_string_table(self) -> None: def test_create_date_time_table(self) -> None: self.maxDiff = None - self.assertMultiLineEqual( - get_create_stmt(tables.DateTimeTable, dialect="postgresql"), + self.assertMatchSQLCreate( + "postgresql", + tables.DateTimeTable, 'CREATE TABLE "DateTimeTable" (\n' '"id" bigint NOT NULL,\n' '"iso_date_time" timestamp NOT NULL,\n' @@ -184,8 +223,9 @@ def test_create_date_time_table(self) -> None: 'CONSTRAINT "pk_DateTimeTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.DateTimeTable, dialect="mssql"), + self.assertMatchSQLCreate( + "mssql", + tables.DateTimeTable, 'CREATE TABLE "DateTimeTable" (\n' '"id" bigint NOT NULL,\n' '"iso_date_time" datetime2 NOT NULL,\n' @@ -196,8 +236,9 @@ def test_create_date_time_table(self) -> None: 'CONSTRAINT "pk_DateTimeTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.DateTimeTable, dialect="mysql"), + self.assertMatchSQLCreate( + "mysql", + tables.DateTimeTable, 'CREATE TABLE "DateTimeTable" (\n' '"id" bigint NOT NULL,\n' '"iso_date_time" datetime NOT NULL,\n' @@ -210,32 +251,36 @@ def test_create_date_time_table(self) -> None: ) def test_create_default_datetime_table(self) -> None: - self.assertMultiLineEqual( - get_create_stmt(tables.DefaultDateTimeTable, dialect="postgresql"), + self.assertMatchSQLCreate( + "postgresql", + tables.DefaultDateTimeTable, 'CREATE TABLE "DefaultDateTimeTable" (\n' '"id" bigint NOT NULL,\n' """"iso_date_time" timestamp NOT NULL DEFAULT '1989-10-24 23:59:59',\n""" 'CONSTRAINT "pk_DefaultDateTimeTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.DefaultDateTimeTable, dialect="oracle"), + self.assertMatchSQLCreate( + "oracle", + tables.DefaultDateTimeTable, 'CREATE TABLE "DefaultDateTimeTable" (\n' '"id" number NOT NULL,\n' """"iso_date_time" timestamp DEFAULT TIMESTAMP '1989-10-24 23:59:59' NOT NULL,\n""" 'CONSTRAINT "pk_DefaultDateTimeTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.DefaultDateTimeTable, dialect="mysql"), + self.assertMatchSQLCreate( + "mysql", + tables.DefaultDateTimeTable, 'CREATE TABLE "DefaultDateTimeTable" (\n' '"id" bigint NOT NULL,\n' """"iso_date_time" datetime NOT NULL DEFAULT '1989-10-24 23:59:59',\n""" 'CONSTRAINT "pk_DefaultDateTimeTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.DefaultDateTimeTable, dialect="mssql"), + self.assertMatchSQLCreate( + "mssql", + tables.DefaultDateTimeTable, 'CREATE TABLE "DefaultDateTimeTable" (\n' '"id" bigint NOT NULL,\n' """"iso_date_time" datetime2 NOT NULL,\n""" @@ -246,8 +291,9 @@ def test_create_default_datetime_table(self) -> None: def test_create_enum_table(self) -> None: self.maxDiff = None - self.assertMultiLineEqual( - get_create_stmt(tables.EnumTable, dialect="postgresql"), + self.assertMatchSQLCreate( + "postgresql", + tables.EnumTable, """CREATE TYPE "WorkflowState" AS ENUM ('active', 'inactive', 'deleted');\n""" 'CREATE TABLE "EnumTable" (\n' '"id" bigint NOT NULL,\n' @@ -256,8 +302,9 @@ def test_create_enum_table(self) -> None: 'CONSTRAINT "pk_EnumTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.EnumTable, dialect="mysql"), + self.assertMatchSQLCreate( + "mysql", + tables.EnumTable, 'CREATE TABLE "EnumTable" (\n' '"id" bigint NOT NULL,\n' """"state" ENUM ('active', 'inactive', 'deleted') CHARACTER SET ascii COLLATE ascii_bin NOT NULL,\n""" @@ -265,8 +312,9 @@ def test_create_enum_table(self) -> None: 'CONSTRAINT "pk_EnumTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.EnumTable, dialect="mssql"), + self.assertMatchSQLCreate( + "mssql", + tables.EnumTable, 'CREATE TABLE "EnumTable" (\n' '"id" bigint NOT NULL,\n' '"state" integer NOT NULL,\n' @@ -287,8 +335,9 @@ def test_create_enum_table(self) -> None: def test_create_ipaddress_table(self) -> None: self.maxDiff = None - self.assertMultiLineEqual( - get_create_stmt(tables.IPAddressTable, dialect="postgresql"), + self.assertMatchSQLCreate( + "postgresql", + tables.IPAddressTable, 'CREATE TABLE "IPAddressTable" (\n' '"id" bigint NOT NULL,\n' '"ipv4" inet NOT NULL,\n' @@ -301,8 +350,9 @@ def test_create_ipaddress_table(self) -> None: ) for dialect in ["mssql", "mysql"]: with self.subTest(dialect=dialect): - self.assertMultiLineEqual( - get_create_stmt(tables.IPAddressTable, dialect=dialect), + self.assertMatchSQLCreate( + dialect, + tables.IPAddressTable, 'CREATE TABLE "IPAddressTable" (\n' '"id" bigint NOT NULL,\n' '"ipv4" binary(4) NOT NULL,\n' @@ -316,8 +366,9 @@ def test_create_ipaddress_table(self) -> None: def test_create_literal_table(self) -> None: self.maxDiff = None - self.assertMultiLineEqual( - get_create_stmt(tables.LiteralTable, dialect="postgresql"), + self.assertMatchSQLCreate( + "postgresql", + tables.LiteralTable, 'CREATE TABLE "LiteralTable" (\n' '"id" bigint NOT NULL,\n' '"single" char(5) NOT NULL,\n' @@ -327,8 +378,9 @@ def test_create_literal_table(self) -> None: 'CONSTRAINT "pk_LiteralTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.LiteralTable, dialect="mssql"), + self.assertMatchSQLCreate( + "mssql", + tables.LiteralTable, 'CREATE TABLE "LiteralTable" (\n' '"id" bigint NOT NULL,\n' '"single" char(5) NOT NULL,\n' @@ -338,8 +390,9 @@ def test_create_literal_table(self) -> None: 'CONSTRAINT "pk_LiteralTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.LiteralTable, dialect="mysql"), + self.assertMatchSQLCreate( + "mysql", + tables.LiteralTable, 'CREATE TABLE "LiteralTable" (\n' '"id" bigint NOT NULL,\n' '"single" char(5) NOT NULL,\n' @@ -352,24 +405,27 @@ def test_create_literal_table(self) -> None: def test_create_primary_key_table(self) -> None: self.maxDiff = None - self.assertMultiLineEqual( - get_create_stmt(tables.DataTable, dialect="postgresql"), + self.assertMatchSQLCreate( + "postgresql", + tables.DataTable, 'CREATE TABLE "DataTable" (\n' '"id" bigint NOT NULL,\n' '"data" text NOT NULL,\n' 'CONSTRAINT "pk_DataTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.DataTable, dialect="mssql"), + self.assertMatchSQLCreate( + "mssql", + tables.DataTable, 'CREATE TABLE "DataTable" (\n' '"id" bigint NOT NULL,\n' '"data" varchar(max) NOT NULL,\n' 'CONSTRAINT "pk_DataTable" PRIMARY KEY ("id")\n' ");", ) - self.assertMultiLineEqual( - get_create_stmt(tables.DataTable, dialect="mysql"), + self.assertMatchSQLCreate( + "mysql", + tables.DataTable, 'CREATE TABLE "DataTable" (\n' '"id" bigint NOT NULL,\n' '"data" mediumtext NOT NULL,\n' @@ -379,8 +435,9 @@ def test_create_primary_key_table(self) -> None: def test_create_table_with_description(self) -> None: self.maxDiff = None - self.assertMultiLineEqual( - get_create_stmt(tables.Person, dialect="postgresql"), + self.assertMatchSQLCreate( + "postgresql", + tables.Person, 'CREATE TABLE "Address" (\n' '"id" bigint NOT NULL,\n' '"city" text NOT NULL,\n' @@ -399,8 +456,9 @@ def test_create_table_with_description(self) -> None: 'ALTER TABLE "Person"\n' 'ADD CONSTRAINT "fk_Person_address" FOREIGN KEY ("address") REFERENCES "Address" ("id");', ) - self.assertMultiLineEqual( - get_create_stmt(tables.Person, dialect="mssql"), + self.assertMatchSQLCreate( + "mssql", + tables.Person, 'CREATE TABLE "Address" (\n' '"id" bigint NOT NULL,\n' '"city" varchar(max) NOT NULL,\n' @@ -416,8 +474,9 @@ def test_create_table_with_description(self) -> None: 'ALTER TABLE "Person" ADD\n' 'CONSTRAINT "fk_Person_address" FOREIGN KEY ("address") REFERENCES "Address" ("id");', ) - self.assertMultiLineEqual( - get_create_stmt(tables.Person, dialect="mysql"), + self.assertMatchSQLCreate( + "mysql", + tables.Person, 'CREATE TABLE "Address" (\n' '"id" bigint NOT NULL,\n' '"city" mediumtext NOT NULL,\n' @@ -437,8 +496,9 @@ def test_create_table_with_description(self) -> None: def test_create_type_with_description(self) -> None: self.maxDiff = None - self.assertMultiLineEqual( - get_create_stmt(tables.Location, dialect="postgresql"), + self.assertMatchSQLCreate( + "postgresql", + tables.Location, 'CREATE TYPE "Coordinates" AS (\n' '"lat" double precision,\n' '"long" double precision\n' @@ -455,10 +515,9 @@ def test_create_type_with_description(self) -> None: def test_insert_single(self) -> None: self.maxDiff = None - generator = get_generator(dialect="postgresql") - generator.create(tables=[tables.DataTable]) - self.assertMultiLineEqual( - generator.get_dataclass_upsert_stmt(tables.DataTable), + self.assertMatchSQLUpsert( + "postgresql", + tables.DataTable, 'INSERT INTO "DataTable"\n' '("id", "data") VALUES ($1, $2)\n' 'ON CONFLICT ("id") DO UPDATE SET\n' @@ -468,11 +527,11 @@ def test_insert_single(self) -> None: for dialect in ["mssql", "oracle"]: with self.subTest(dialect=dialect): - generator = get_generator(dialect=dialect) - generator.create(tables=[tables.DataTable]) + generator = self.engine.create_generator(self.options) value_list = f"({generator.placeholder(1)}, {generator.placeholder(2)})" - self.assertMultiLineEqual( - generator.get_dataclass_upsert_stmt(tables.DataTable), + self.assertMatchSQLUpsert( + dialect, + tables.DataTable, 'MERGE INTO "DataTable" target\n' f'USING (VALUES {value_list}) source("id", "data")\n' 'ON (target."id" = source."id")\n' @@ -483,10 +542,9 @@ def test_insert_single(self) -> None: ";", ) - generator = get_generator(dialect="mysql") - generator.create(tables=[tables.DataTable]) - self.assertMultiLineEqual( - generator.get_dataclass_upsert_stmt(tables.DataTable), + self.assertMatchSQLUpsert( + "mysql", + tables.DataTable, 'INSERT INTO "DataTable"\n' '("id", "data") VALUES (%s, %s)\n' "ON DUPLICATE KEY UPDATE\n" @@ -496,10 +554,9 @@ def test_insert_single(self) -> None: def test_insert_multiple(self) -> None: self.maxDiff = None - generator = get_generator(dialect="postgresql") - generator.create(tables=[tables.BooleanTable]) - self.assertMultiLineEqual( - generator.get_dataclass_upsert_stmt(tables.BooleanTable), + self.assertMatchSQLUpsert( + "postgresql", + tables.BooleanTable, 'INSERT INTO "BooleanTable"\n' '("id", "boolean", "nullable_boolean") VALUES ($1, $2, $3)\n' 'ON CONFLICT ("id") DO UPDATE SET\n' @@ -510,11 +567,11 @@ def test_insert_multiple(self) -> None: for dialect in ["mssql", "oracle"]: with self.subTest(dialect=dialect): - generator = get_generator(dialect=dialect) - generator.create(tables=[tables.BooleanTable]) + generator = self.engine.create_generator(self.options) value_list = f"({generator.placeholder(1)}, {generator.placeholder(2)}, {generator.placeholder(3)})" - self.assertMultiLineEqual( - generator.get_dataclass_upsert_stmt(tables.BooleanTable), + self.assertMatchSQLUpsert( + dialect, + tables.BooleanTable, 'MERGE INTO "BooleanTable" target\n' f'USING (VALUES {value_list}) source("id", "boolean", "nullable_boolean")\n' 'ON (target."id" = source."id")\n' @@ -527,10 +584,9 @@ def test_insert_multiple(self) -> None: ";", ) - generator = get_generator(dialect="mysql") - generator.create(tables=[tables.BooleanTable]) - self.assertMultiLineEqual( - generator.get_dataclass_upsert_stmt(tables.BooleanTable), + self.assertMatchSQLUpsert( + "mysql", + tables.BooleanTable, 'INSERT INTO "BooleanTable"\n' '("id", "boolean", "nullable_boolean") VALUES (%s, %s, %s)\n' "ON DUPLICATE KEY UPDATE\n" @@ -540,17 +596,13 @@ def test_insert_multiple(self) -> None: ) def test_table_data(self) -> None: - generator = get_generator(dialect="postgresql") + generator = self.engine.create_generator(self.options) generator.create( tables=[ tables.DataTable, - tables.DateTimeTable, - tables.EnumTable, - tables.IPAddressTable, tables.StringTable, ] ) - self.assertEqual( generator.get_dataclass_as_record( tables.DataTable, tables.DataTable(123, "abc") @@ -569,6 +621,10 @@ def test_table_data(self) -> None: ), (2, "abc", "def", "ghi", "jkl"), ) + + def test_table_data_datetime(self) -> None: + generator = self.engine.create_generator(self.options) + generator.create(tables=[tables.DateTimeTable]) self.assertEqual( generator.get_dataclass_as_record( tables.DateTimeTable, @@ -590,12 +646,76 @@ def test_table_data(self) -> None: datetime(1984, 1, 1, 23, 59, 59, tzinfo=timezone.utc), ), ) + + def test_table_data_ipaddress(self) -> None: + generator = self.engine.create_generator(self.options) + generator.create(tables=[tables.IPAddressTable]) + self.assertEqual( + generator.get_dataclass_as_record( + tables.IPAddressTable, + tables.IPAddressTable( + 1, + ipaddress.IPv4Address("192.168.0.1"), + ipaddress.IPv6Address("2001:db8::"), + ipaddress.IPv6Address("2001:db8::"), + None, + None, + ), + ), + ( + 1, + b"\xc0\xa8\x00\x01", + b" \x01\r\xb8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + b" \x01\r\xb8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + None, + None, + ), + ) + + +@unittest.skipUnless(has_env_var("ORACLE"), "Oracle tests are disabled") +class TestOracleGenerator(OracleBase, TestGenerator): + def test_table_data_datetime(self) -> None: + generator = self.engine.create_generator(self.options) + generator.create(tables=[tables.DateTimeTable]) + self.assertEqual( + generator.get_dataclass_as_record( + tables.DateTimeTable, + tables.DateTimeTable( + 1, + datetime(1982, 10, 23, 23, 59, 59, tzinfo=timezone.utc), + date(2023, 1, 1), + time(23, 59, 59, tzinfo=timezone.utc), + None, + datetime(1984, 1, 1, 23, 59, 59, tzinfo=timezone.utc), + ), + ), + ( + 1, + datetime(1982, 10, 23, 23, 59, 59, tzinfo=timezone.utc), + date(2023, 1, 1), + timedelta(seconds=86399), + None, + datetime(1984, 1, 1, 23, 59, 59, tzinfo=timezone.utc), + ), + ) + + +@unittest.skipUnless(has_env_var("POSTGRESQL"), "PostgreSQL tests are disabled") +class TestPostgreSQLGenerator(PostgreSQLBase, TestGenerator): + def test_table_data_enum(self) -> None: + generator = self.engine.create_generator(self.options) + generator.create(tables=[tables.EnumTable]) self.assertEqual( generator.get_dataclass_as_record( tables.EnumTable, tables.EnumTable(1, tables.WorkflowState.active, None) ), (1, "active", None), ) + + def test_table_data_ipaddress(self) -> None: + generator = self.engine.create_generator(self.options) + generator.create(tables=[tables.IPAddressTable]) self.assertEqual( generator.get_dataclass_as_record( tables.IPAddressTable, @@ -619,5 +739,17 @@ def test_table_data(self) -> None: ) +@unittest.skipUnless(has_env_var("MSSQL"), "Microsoft SQL tests are disabled") +class TestMSSQLGenerator(MSSQLBase, TestGenerator): + pass + + +@unittest.skipUnless(has_env_var("MYSQL"), "MySQL tests are disabled") +class TestMySQLGenerator(MySQLBase, TestGenerator): + pass + + +del TestGenerator + if __name__ == "__main__": unittest.main()