Skip to content

Commit

Permalink
Add more permissive signature to insert/upsert rows
Browse files Browse the repository at this point in the history
  • Loading branch information
hunyadi committed Jul 11, 2024
1 parent 91728a9 commit de0bacd
Show file tree
Hide file tree
Showing 9 changed files with 351 additions and 103 deletions.
5 changes: 4 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,8 @@
"python.analysis.typeCheckingMode": "off",
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
}
},
"cSpell.words": [
"unsync"
]
}
271 changes: 210 additions & 61 deletions pysqlsync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,19 @@
import json
import logging
import types
import typing
from dataclasses import dataclass
from typing import Any, Callable, Iterable, Optional, Sized, TypeVar, Union, overload
from typing import (
Any,
AsyncIterable,
Callable,
Iterable,
Optional,
Sized,
TypeVar,
Union,
overload,
)
from urllib.parse import quote

from strong_typing.inspection import DataclassInstance, is_dataclass_type, is_type_enum
Expand All @@ -39,6 +50,9 @@
E = TypeVar("E", bound=enum.Enum)
T = TypeVar("T")

RecordType = tuple[Any, ...]
RecordIterable = Union[Iterable[RecordType], AsyncIterable[RecordType]]

LOGGER = logging.getLogger("pysqlsync")

_JSON_ENCODER = json.JSONEncoder(
Expand Down Expand Up @@ -593,9 +607,7 @@ async def _execute(self, statement: str) -> None:

...

async def execute_all(
self, statement: str, args: Iterable[tuple[Any, ...]]
) -> None:
async def execute_all(self, statement: str, args: Iterable[RecordType]) -> None:
"Executes a SQL statement with several records of data."

if not statement:
Expand All @@ -618,17 +630,15 @@ async def execute_all(
raise QueryException(statement) from e

@abc.abstractmethod
async def _execute_all(
self, statement: str, records: Iterable[tuple[Any, ...]]
) -> None:
async def _execute_all(self, statement: str, records: RecordIterable) -> None:
"Executes a SQL statement with several records of data."

...

async def _execute_typed(
self,
statement: str,
records: Iterable[tuple[Any, ...]],
records: RecordIterable,
table: Table,
order: Optional[tuple[str, ...]] = None,
) -> None:
Expand Down Expand Up @@ -765,7 +775,7 @@ async def delete_data(
async def insert_rows(
self,
table: Table,
records: Iterable[tuple[Any, ...]],
records: RecordIterable,
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
Expand Down Expand Up @@ -797,7 +807,7 @@ async def insert_rows(
async def _insert_rows(
self,
table: Table,
records: Iterable[tuple[Any, ...]],
records: RecordIterable,
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
Expand All @@ -813,12 +823,13 @@ async def _insert_rows(
)
order = tuple(name for name in field_names if name) if field_names else None
statement = self.connection.generator.get_table_insert_stmt(table, order)

await self._execute_typed(statement, record_generator, table, order)

async def upsert_rows(
self,
table: Table,
records: Iterable[tuple[Any, ...]],
records: RecordIterable,
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
Expand Down Expand Up @@ -854,7 +865,7 @@ async def upsert_rows(
async def _upsert_rows(
self,
table: Table,
records: Iterable[tuple[Any, ...]],
records: RecordIterable,
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
Expand All @@ -870,16 +881,37 @@ async def _upsert_rows(
)
order = tuple(name for name in field_names if name) if field_names else None
statement = self.connection.generator.get_table_upsert_stmt(table, order)

await self._execute_typed(statement, record_generator, table, order)

@overload
async def _generate_records(
self,
table: Table,
records: Iterable[RecordType],
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
) -> Iterable[RecordType]: ...

@overload
async def _generate_records(
self,
table: Table,
records: Iterable[tuple[Any, ...]],
records: AsyncIterable[RecordType],
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
) -> Iterable[tuple[Any, ...]]:
) -> AsyncIterable[RecordType]: ...

async def _generate_records(
self,
table: Table,
records: RecordIterable,
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
) -> RecordIterable:
"""
Creates a record generator for a database table.
Expand Down Expand Up @@ -915,39 +947,136 @@ async def _generate_records(
if len(indices) == len(field_types):
return records
else:
return (tuple(record[i] for i in indices) for record in records)
return self._select_columns(indices, records)
else:
if len(indices) == len(field_types):
return (
tuple(
(
(transformer(field) if field is not None else None)
if transformer is not None
else field
)
for transformer, field in zip(transformers, record)
return self._transform_columns(transformers, records)
else:
return self._select_transform_columns(indices, transformers, records)

@overload
@staticmethod
def _select_columns(
indices: list[int], records: Iterable[RecordType]
) -> Iterable[RecordType]: ...

@overload
@staticmethod
def _select_columns(
indices: list[int], records: AsyncIterable[RecordType]
) -> AsyncIterable[RecordType]: ...

@staticmethod
def _select_columns(indices: list[int], records: RecordIterable) -> RecordIterable:
if isinstance(records, Iterable):
return (tuple(record[i] for i in indices) for record in records)
elif isinstance(records, AsyncIterable):
return (tuple(record[i] for i in indices) async for record in records)
else:
raise TypeError("expected: `Iterable` or `AsyncIterable` of records")

@overload
@staticmethod
def _transform_columns(
transformers: list[Optional[Callable[[Any], Any]]],
records: Iterable[RecordType],
) -> Iterable[RecordType]: ...

@overload
@staticmethod
def _transform_columns(
transformers: list[Optional[Callable[[Any], Any]]],
records: AsyncIterable[RecordType],
) -> AsyncIterable[RecordType]: ...

@staticmethod
def _transform_columns(
transformers: list[Optional[Callable[[Any], Any]]], records: RecordIterable
) -> RecordIterable:
if isinstance(records, Iterable):
return (
tuple(
(
(transformer(field) if field is not None else None)
if transformer is not None
else field
)
for record in records
for transformer, field in zip(transformers, record)
)
else:
return (
tuple(
(
(transformer(field) if field is not None else None)
if transformer is not None
else field
)
for transformer, field in zip(
transformers, (record[i] for i in indices)
)
for record in records
)
elif isinstance(records, AsyncIterable):
return (
tuple(
(
(transformer(field) if field is not None else None)
if transformer is not None
else field
)
for transformer, field in zip(transformers, record)
)
async for record in records
)
else:
raise TypeError("expected: `Iterable` or `AsyncIterable` of records")

@overload
@staticmethod
def _select_transform_columns(
indices: list[int],
transformers: list[Optional[Callable[[Any], Any]]],
records: Iterable[RecordType],
) -> Iterable[RecordType]: ...

@overload
@staticmethod
def _select_transform_columns(
indices: list[int],
transformers: list[Optional[Callable[[Any], Any]]],
records: AsyncIterable[RecordType],
) -> AsyncIterable[RecordType]: ...

@staticmethod
def _select_transform_columns(
indices: list[int],
transformers: list[Optional[Callable[[Any], Any]]],
records: RecordIterable,
) -> RecordIterable:
if isinstance(records, Iterable):
return (
tuple(
(
(transformer(field) if field is not None else None)
if transformer is not None
else field
)
for transformer, field in zip(
transformers, (record[i] for i in indices)
)
for record in records
)
for record in records
)
elif isinstance(records, AsyncIterable):
return (
tuple(
(
(transformer(field) if field is not None else None)
if transformer is not None
else field
)
for transformer, field in zip(
transformers, (record[i] for i in indices)
)
)
async for record in records
)
else:
raise TypeError("expected: `Iterable` or `AsyncIterable` of records")

async def _get_transformer(
self,
table: Table,
records: Iterable[tuple[Any, ...]],
records: RecordIterable,
generator: BaseGenerator,
index: int,
field_type: type,
Expand All @@ -972,31 +1101,51 @@ async def _get_transformer(

transformer = generator.get_value_transformer(column, field_type)

if table.is_relation(column):
relation = generator.state.get_referenced_table(table.name, column.name)
if relation.is_lookup_table():
LOGGER.debug(
"found lookup table column %s in table %s", column.name, table.name
)
if not table.is_relation(column):
return transformer
relation = generator.state.get_referenced_table(table.name, column.name)
if not relation.is_lookup_table():
return transformer

enum_dict: dict[str, int]

if field_type is str:
# a single enumeration value represented as a string
values = set(record[index] for record in records)
values.discard(None) # do not insert NULL into referenced table
enum_dict = await self._merge_lookup_table(relation, values)
return generator.get_enum_transformer(enum_dict)
elif field_type is list or field_type is set:
# a list of enumeration values represented as a list of strings
values = set()
for record in records:
values.update(record[index])
values.discard(None) # do not insert NULL into referenced table
enum_dict = await self._merge_lookup_table(relation, values)
return generator.get_enum_list_transformer(enum_dict)

return transformer
LOGGER.debug(
"found lookup table column %s in table %s", column.name, table.name
)

enum_dict: dict[str, int]
values: set[Optional[str]] = set()

if field_type is str:
# a single enumeration value represented as a string
if isinstance(records, Iterable):
for record in records:
values.add(record[index])
elif isinstance(records, AsyncIterable):
async for record in records:
values.add(record[index])
else:
raise TypeError("expected: `Iterable` or `AsyncIterable` of records")
values.discard(None) # do not insert NULL into referenced table
enum_dict = await self._merge_lookup_table(
relation, typing.cast(set[str], values)
)
return generator.get_enum_transformer(enum_dict)
elif field_type is list or field_type is set:
# a list of enumeration values represented as a list of strings
if isinstance(records, Iterable):
for record in records:
values.update(record[index])
elif isinstance(records, AsyncIterable):
async for record in records:
values.update(record[index])
else:
raise TypeError("expected: `Iterable` or `AsyncIterable` of records")
values.discard(None) # do not insert NULL into referenced table
enum_dict = await self._merge_lookup_table(
relation, typing.cast(set[str], values)
)
return generator.get_enum_list_transformer(enum_dict)
else:
return transformer

async def _merge_lookup_table(
self, table: Table, values: set[str]
Expand Down
Loading

0 comments on commit de0bacd

Please sign in to comment.