Skip to content

Commit

Permalink
Ensure synthesized data-class types are assigned to the right module …
Browse files Browse the repository at this point in the history
…in Python 3.12
  • Loading branch information
hunyadi committed Jan 4, 2024
1 parent 4bcf1f4 commit 031a97e
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 14 deletions.
4 changes: 3 additions & 1 deletion pysqlsync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ def create(
self.state = target
return statement
else:
raise NotImplementedError()
# should never be triggered; either `tables` or `modules` must be defined at this point
raise NotImplementedError("match condition not exhaustive")

def drop(self) -> Optional[str]:
"""
Expand Down Expand Up @@ -1022,6 +1023,7 @@ def _module_or_list(
elif module is not None:
entity_modules = [module]
else:
# should never be triggered; either `module` or `modules` must be defined at this point
raise NotImplementedError("match condition not exhaustive")

return entity_modules
Expand Down
8 changes: 4 additions & 4 deletions pysqlsync/dialect/trino/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def _execute(self, statement: str) -> None:
async def _execute_all(
self, statement: str, args: Iterable[tuple[Any, ...]]
) -> None:
raise NotImplementedError()
raise NotImplementedError("operation not supported for Trino")

@override
async def _query_all(self, signature: type[T], statement: str) -> list[T]:
Expand All @@ -65,7 +65,7 @@ async def _query_all(self, signature: type[T], statement: str) -> list[T]:

@override
async def insert_data(self, table: type[T], data: Iterable[T]) -> None:
raise NotImplementedError()
raise NotImplementedError("operation not supported for Trino")

@override
async def _insert_rows(
Expand All @@ -76,7 +76,7 @@ async def _insert_rows(
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
) -> None:
raise NotImplementedError()
raise NotImplementedError("operation not supported for Trino")

@override
async def _upsert_rows(
Expand All @@ -87,4 +87,4 @@ async def _upsert_rows(
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
) -> None:
raise NotImplementedError()
raise NotImplementedError("operation not supported for Trino")
2 changes: 1 addition & 1 deletion pysqlsync/dialect/trino/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def __init__(self, options: GeneratorOptions) -> None:

@override
def placeholder(self, index: int) -> str:
raise NotImplementedError()
raise NotImplementedError("operation not supported for Trino")
4 changes: 3 additions & 1 deletion pysqlsync/formation/object_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,9 @@ class DiscriminatedConstraint(ReferenceConstraint):

@property
def spec(self) -> str:
raise NotImplementedError()
raise NotImplementedError(
"cannot represent discriminated constraint in SQL DDL"
)


@dataclass
Expand Down
12 changes: 9 additions & 3 deletions pysqlsync/formation/sql_to_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,16 @@ def table_to_dataclass(self, table: Table) -> type[DataclassInstance]:
# default arguments must follow non-default arguments
fields.sort(key=lambda f: f[2].default is not dataclasses.MISSING)

# look up target module
module = self.options.namespaces[table.name.scope_id]

# produce class definition with docstring
typ = dataclasses.make_dataclass(class_name, fields) # type: ignore
if sys.version_info >= (3, 12):
typ = dataclasses.make_dataclass(class_name, fields, module=module.__name__)
else:
typ = dataclasses.make_dataclass(
class_name, fields, namespace={"__module__": module.__name__} # type: ignore
)
with StringIO() as out:
for field in dataclasses.fields(typ):
description = field.metadata.get("description")
Expand All @@ -194,8 +202,6 @@ def table_to_dataclass(self, table: Table) -> type[DataclassInstance]:
typ.__doc__ = docstring

# assign the newly created type to the target module
module = self.options.namespaces[table.name.scope_id]
typ.__module__ = module.__name__
setattr(sys.modules[module.__name__], class_name, typ)

return typ
Expand Down
12 changes: 9 additions & 3 deletions pysqlsync/model/entity_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,15 @@ def make_entity(cls: type[DataclassInstance], key: str) -> type[DataclassInstanc
class_module = cls.__module__
class_doc = cls.__doc__

class_type = dataclasses.make_dataclass(
class_name, target_fields, namespace={"__module__": class_module}
)
if sys.version_info >= (3, 12):
class_type = dataclasses.make_dataclass(
class_name, target_fields, module=class_module
)
else:
class_type = dataclasses.make_dataclass(
class_name, target_fields, namespace={"__module__": class_module}
)

class_type.__doc__ = class_doc
setattr(sys.modules[class_module], class_name, class_type)
return class_type
4 changes: 3 additions & 1 deletion pysqlsync/python_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def module_to_stream(module: ModuleType, target: TextIO) -> None:
elif is_dataclass_type(cls):
dataclass_to_stream(cls, target)
else:
raise NotImplementedError()
raise NotImplementedError(
"classes in module must be of enumeration or data-class type"
)
print(file=target)


Expand Down

0 comments on commit 031a97e

Please sign in to comment.