Skip to content

Commit

Permalink
Add new option to auto-assign DEFAULT for NOT NULL columns
Browse files Browse the repository at this point in the history
  • Loading branch information
hunyadi committed Oct 28, 2024
1 parent 839ee63 commit 48ac6fd
Show file tree
Hide file tree
Showing 12 changed files with 59 additions and 18 deletions.
2 changes: 2 additions & 0 deletions pysqlsync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class GeneratorOptions:
:param initialize_tables: Whether to populate special tables (e.g. enumerations) with data.
:param synchronization: Synchronization options.
:param skip_annotations: Annotation classes to ignore on table column types.
:param auto_default: Automatically assign a default value to non-nullable types.
"""

enum_mode: Optional[EnumMode] = None
Expand All @@ -152,6 +153,7 @@ class GeneratorOptions:
initialize_tables: bool = False
synchronization: MutatorOptions = dataclasses.field(default_factory=MutatorOptions)
skip_annotations: tuple[type, ...] = ()
auto_default: bool = False


class BaseGenerator(abc.ABC):
Expand Down
3 changes: 2 additions & 1 deletion pysqlsync/dialect/delta/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ def __init__(self, options: GeneratorOptions) -> None:
ipaddress.IPv4Address: DeltaFixedBinaryType(4),
ipaddress.IPv6Address: DeltaFixedBinaryType(16),
},
skip_annotations=options.skip_annotations,
factory=self.factory,
skip_annotations=options.skip_annotations,
auto_default=options.auto_default,
)
)

Expand Down
1 change: 1 addition & 0 deletions pysqlsync/dialect/mssql/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self, options: GeneratorOptions) -> None:
},
factory=self.factory,
skip_annotations=options.skip_annotations,
auto_default=options.auto_default,
)
)

Expand Down
3 changes: 2 additions & 1 deletion pysqlsync/dialect/oracle/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ def __init__(self, options: GeneratorOptions) -> None:
ipaddress.IPv4Address: OracleVariableBinaryType(4),
ipaddress.IPv6Address: OracleVariableBinaryType(16),
},
skip_annotations=options.skip_annotations,
factory=self.factory,
skip_annotations=options.skip_annotations,
auto_default=options.auto_default,
)
)

Expand Down
3 changes: 2 additions & 1 deletion pysqlsync/dialect/postgresql/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def __init__(self, options: GeneratorOptions) -> None:
substitutions={
JsonType: PostgreSQLJsonType(),
},
skip_annotations=options.skip_annotations,
factory=self.factory,
skip_annotations=options.skip_annotations,
auto_default=options.auto_default,
)
)

Expand Down
3 changes: 2 additions & 1 deletion pysqlsync/dialect/redshift/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ def __init__(self, options: GeneratorOptions) -> None:
ipaddress.IPv4Address: RedshiftVariableBinaryType(4),
ipaddress.IPv6Address: RedshiftVariableBinaryType(16),
},
skip_annotations=options.skip_annotations,
factory=self.factory,
skip_annotations=options.skip_annotations,
auto_default=options.auto_default,
)
)

Expand Down
3 changes: 2 additions & 1 deletion pysqlsync/dialect/snowflake/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ def __init__(self, options: GeneratorOptions) -> None:
ipaddress.IPv4Address: SqlFixedBinaryType(4),
ipaddress.IPv6Address: SqlFixedBinaryType(16),
},
skip_annotations=options.skip_annotations,
factory=self.factory,
skip_annotations=options.skip_annotations,
auto_default=options.auto_default,
)
)

Expand Down
1 change: 1 addition & 0 deletions pysqlsync/dialect/trino/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self, options: GeneratorOptions) -> None:
foreign_constraints=options.foreign_constraints,
initialize_tables=options.initialize_tables,
skip_annotations=options.skip_annotations,
auto_default=options.auto_default,
)
)

Expand Down
1 change: 1 addition & 0 deletions pysqlsync/formation/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def is_ip_address_type(field_type: type) -> bool:


def dataclass_has_primary_key(typ: type[DataclassInstance]) -> bool:
typ = unwrap_annotated_type(typ)
for field in dataclass_fields(typ):
if is_primary_key_type(field.type):
return True
Expand Down
41 changes: 32 additions & 9 deletions pysqlsync/formation/py_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import types
import typing
import uuid
from typing import Annotated, Callable, Iterable, Optional, TypeVar
from typing import Annotated, Any, Callable, Iterable, Optional, TypeVar

from strong_typing.auxiliary import (
MaxLength,
Expand Down Expand Up @@ -150,6 +150,7 @@ def is_extensible_enum_type(typ: TypeLike, cls: type) -> bool:

if not is_type_union(typ):
return False
typ = unwrap_annotated_type(typ)
member_types = [evaluate_member_type(t, cls) for t in unwrap_union_types(typ)]
for member_type in member_types:
if member_type is None:
Expand Down Expand Up @@ -292,6 +293,7 @@ class DataclassConverterOptions:
:param substitutions: SQL type to be substituted for a specific Python type.
:param factory: Creates new column, table, struct and namespace instances.
:param skip_annotations: Annotation classes to ignore on table column types.
:param auto_default: Automatically assign a default value to non-nullable types.
"""

enum_mode: EnumMode = EnumMode.TYPE
Expand All @@ -307,6 +309,7 @@ class DataclassConverterOptions:
substitutions: dict[TypeLike, SqlDataType] = dataclasses.field(default_factory=dict)
factory: ObjectFactory = dataclasses.field(default_factory=ObjectFactory)
skip_annotations: tuple[type, ...] = ()
auto_default: bool = False


class DataclassConverter:
Expand Down Expand Up @@ -540,7 +543,7 @@ def member_to_sql_data_type(self, typ: TypeLike, cls: type) -> SqlDataType:
)
elif self.options.struct_mode is StructMode.INLINE:
inline_types: list[SqlStructMember] = []
for field in dataclass_fields(typ):
for field in dataclass_fields(unwrap_annotated_type(typ)):
inline_props = get_field_properties(field.type)
inline_type = self.member_to_sql_data_type(
inline_props.field_type, typ
Expand Down Expand Up @@ -609,7 +612,8 @@ def member_to_sql_data_type(self, typ: TypeLike, cls: type) -> SqlDataType:
return self._enumeration_key_type()
if is_type_union(typ):
member_types = [
evaluate_member_type(t, cls) for t in unwrap_union_types(typ)
evaluate_member_type(t, cls)
for t in unwrap_union_types(unwrap_annotated_type(typ))
]
if all(is_entity_type(t) for t in member_types):
# discriminated union type
Expand Down Expand Up @@ -644,11 +648,30 @@ def member_to_column(

props = get_field_properties(field.type)
data_type = self.member_to_sql_data_type(props.field_type, cls)
default_value = (
constant(field.default)
if field.default is not dataclasses.MISSING and field.default is not None
else None
)
if field.default is not dataclasses.MISSING and field.default is not None:
default_value = constant(field.default)
elif (
not props.nullable
and self.options.auto_default
and not is_dataclass_type(props.plain_type)
and not is_generic_list(props.plain_type)
and not is_generic_set(props.plain_type)
and not is_type_union(props.plain_type)
):
default_expr: Any
if props.plain_type is datetime.datetime:
default_expr = datetime.datetime.min
elif props.plain_type is datetime.date:
default_expr = datetime.date.min
elif props.plain_type is datetime.time:
default_expr = datetime.time.min
elif is_type_enum(props.plain_type):
default_expr = next(member.value for member in props.plain_type)
else:
default_expr = props.plain_type() # type: ignore
default_value = constant(default_expr)
else:
default_value = None
description = (
doc.params[field.name].description if field.name in doc.params else None
)
Expand Down Expand Up @@ -796,7 +819,7 @@ def dataclass_to_constraints(
if is_type_union(field.type):
member_types = [
evaluate_member_type(t, cls)
for t in unwrap_union_types(field.type)
for t in unwrap_union_types(unwrap_annotated_type(field.type))
]
if all(is_entity_type(t) for t in member_types):
# discriminated keys
Expand Down
14 changes: 11 additions & 3 deletions pysqlsync/model/data_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import decimal
from dataclasses import dataclass
from functools import reduce
from typing import Any, Optional
Expand Down Expand Up @@ -31,9 +32,14 @@ def constant(v: Any) -> str:
return str(v)
elif isinstance(v, bool):
return "TRUE" if v else "FALSE"
elif isinstance(v, decimal.Decimal):
return str(v)
elif isinstance(v, datetime.datetime):
timestamp = v.astimezone(tz=datetime.timezone.utc).replace(tzinfo=None)
return quote(timestamp.isoformat(sep=" "))
if v.tzinfo is not None:
timestamp = v.astimezone(tz=datetime.timezone.utc).replace(tzinfo=None)
else:
timestamp = v
return f"TIMESTAMP {quote(timestamp.isoformat(sep=' '))}"
elif isinstance(v, tuple):
values = ", ".join(constant(value) for value in v)
return f"({values})"
Expand All @@ -43,7 +49,9 @@ def constant(v: Any) -> str:
)
return f"({values})"
else:
raise NotImplementedError(f"unknown constant representation for value: {v}")
raise NotImplementedError(
f"unknown constant representation for value (of type): {v} ({type(v)})"
)


def escape_like(value: str, escape_char: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion pysqlsync/model/entity_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def make_entity(cls: type[DataclassInstance], key: str) -> type[DataclassInstanc
Transforms a regular data-class type into an entity type.
:param cls: The data-class type to transform.
:param key: The field to be come the primary key.
:param key: The field to become the primary key.
"""

key_field, value_fields = key_value_fields(cls, key)
Expand Down

0 comments on commit 48ac6fd

Please sign in to comment.