Skip to content

Commit

Permalink
Add more permissive signature to insert/upsert rows
Browse files Browse the repository at this point in the history
  • Loading branch information
hunyadi committed Jul 11, 2024
1 parent 91728a9 commit 983fe4b
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 95 deletions.
290 changes: 227 additions & 63 deletions pysqlsync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,20 @@
import json
import logging
import types
import typing
from dataclasses import dataclass
from typing import Any, Callable, Iterable, Optional, Sized, TypeVar, Union, overload
from typing import (
Any,
AsyncIterable,
Callable,
Iterable,
Optional,
Sized,
TypeAlias,
TypeVar,
Union,
overload,
)
from urllib.parse import quote

from strong_typing.inspection import DataclassInstance, is_dataclass_type, is_type_enum
Expand All @@ -39,6 +51,9 @@
E = TypeVar("E", bound=enum.Enum)
T = TypeVar("T")

RecordType: TypeAlias = tuple[Any, ...]
RecordIterable: TypeAlias = Union[Iterable[RecordType], AsyncIterable[RecordType]]

LOGGER = logging.getLogger("pysqlsync")

_JSON_ENCODER = json.JSONEncoder(
Expand Down Expand Up @@ -593,9 +608,7 @@ async def _execute(self, statement: str) -> None:

...

async def execute_all(
self, statement: str, args: Iterable[tuple[Any, ...]]
) -> None:
async def execute_all(self, statement: str, args: Iterable[RecordType]) -> None:
"Executes a SQL statement with several records of data."

if not statement:
Expand All @@ -618,17 +631,15 @@ async def execute_all(
raise QueryException(statement) from e

@abc.abstractmethod
async def _execute_all(
self, statement: str, records: Iterable[tuple[Any, ...]]
) -> None:
async def _execute_all(self, statement: str, records: Iterable[RecordType]) -> None:
"Executes a SQL statement with several records of data."

...

async def _execute_typed(
self,
statement: str,
records: Iterable[tuple[Any, ...]],
records: Iterable[RecordType],
table: Table,
order: Optional[tuple[str, ...]] = None,
) -> None:
Expand Down Expand Up @@ -765,7 +776,7 @@ async def delete_data(
async def insert_rows(
self,
table: Table,
records: Iterable[tuple[Any, ...]],
records: RecordIterable,
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
Expand Down Expand Up @@ -797,7 +808,7 @@ async def insert_rows(
async def _insert_rows(
self,
table: Table,
records: Iterable[tuple[Any, ...]],
records: RecordIterable,
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
Expand All @@ -813,12 +824,18 @@ async def _insert_rows(
)
order = tuple(name for name in field_names if name) if field_names else None
statement = self.connection.generator.get_table_insert_stmt(table, order)
await self._execute_typed(statement, record_generator, table, order)

if isinstance(record_generator, Iterable):
await self._execute_typed(statement, record_generator, table, order)
elif isinstance(record_generator, AsyncIterable):
raise NotImplementedError()
else:
raise TypeError("expected: `Iterable` or `AsyncIterable` of records")

async def upsert_rows(
self,
table: Table,
records: Iterable[tuple[Any, ...]],
records: RecordIterable,
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
Expand Down Expand Up @@ -854,7 +871,7 @@ async def upsert_rows(
async def _upsert_rows(
self,
table: Table,
records: Iterable[tuple[Any, ...]],
records: RecordIterable,
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
Expand All @@ -870,16 +887,42 @@ async def _upsert_rows(
)
order = tuple(name for name in field_names if name) if field_names else None
statement = self.connection.generator.get_table_upsert_stmt(table, order)
await self._execute_typed(statement, record_generator, table, order)

if isinstance(record_generator, Iterable):
await self._execute_typed(statement, record_generator, table, order)
elif isinstance(record_generator, AsyncIterable):
raise NotImplementedError()
else:
raise TypeError("expected: `Iterable` or `AsyncIterable` of records")

@overload
async def _generate_records(
self,
table: Table,
records: Iterable[RecordType],
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
) -> Iterable[RecordType]: ...

@overload
async def _generate_records(
self,
table: Table,
records: Iterable[tuple[Any, ...]],
records: AsyncIterable[RecordType],
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
) -> Iterable[tuple[Any, ...]]:
) -> AsyncIterable[RecordType]: ...

async def _generate_records(
self,
table: Table,
records: RecordIterable,
*,
field_types: tuple[type, ...],
field_names: Optional[tuple[str, ...]] = None,
) -> RecordIterable:
"""
Creates a record generator for a database table.
Expand Down Expand Up @@ -915,39 +958,136 @@ async def _generate_records(
if len(indices) == len(field_types):
return records
else:
return (tuple(record[i] for i in indices) for record in records)
return self._select_columns(indices, records)
else:
if len(indices) == len(field_types):
return (
tuple(
(
(transformer(field) if field is not None else None)
if transformer is not None
else field
)
for transformer, field in zip(transformers, record)
return self._transform_columns(transformers, records)
else:
return self._select_transform_columns(indices, transformers, records)

@overload
@staticmethod
def _select_columns(
indices: list[int], records: Iterable[RecordType]
) -> Iterable[RecordType]: ...

@overload
@staticmethod
def _select_columns(
indices: list[int], records: AsyncIterable[RecordType]
) -> AsyncIterable[RecordType]: ...

@staticmethod
def _select_columns(indices: list[int], records: RecordIterable) -> RecordIterable:
if isinstance(records, Iterable):
return (tuple(record[i] for i in indices) for record in records)
elif isinstance(records, AsyncIterable):
return (tuple(record[i] for i in indices) async for record in records)
else:
raise TypeError("expected: `Iterable` or `AsyncIterable` of records")

@overload
@staticmethod
def _transform_columns(
transformers: list[Optional[Callable[[Any], Any]]],
records: Iterable[RecordType],
) -> Iterable[RecordType]: ...

@overload
@staticmethod
def _transform_columns(
transformers: list[Optional[Callable[[Any], Any]]],
records: AsyncIterable[RecordType],
) -> AsyncIterable[RecordType]: ...

@staticmethod
def _transform_columns(
transformers: list[Optional[Callable[[Any], Any]]], records: RecordIterable
) -> RecordIterable:
if isinstance(records, Iterable):
return (
tuple(
(
(transformer(field) if field is not None else None)
if transformer is not None
else field
)
for record in records
for transformer, field in zip(transformers, record)
)
else:
return (
tuple(
(
(transformer(field) if field is not None else None)
if transformer is not None
else field
)
for transformer, field in zip(
transformers, (record[i] for i in indices)
)
for record in records
)
elif isinstance(records, AsyncIterable):
return (
tuple(
(
(transformer(field) if field is not None else None)
if transformer is not None
else field
)
for record in records
for transformer, field in zip(transformers, record)
)
async for record in records
)
else:
raise TypeError("expected: `Iterable` or `AsyncIterable` of records")

@overload
@staticmethod
def _select_transform_columns(
indices: list[int],
transformers: list[Optional[Callable[[Any], Any]]],
records: Iterable[RecordType],
) -> Iterable[RecordType]: ...

@overload
@staticmethod
def _select_transform_columns(
indices: list[int],
transformers: list[Optional[Callable[[Any], Any]]],
records: AsyncIterable[RecordType],
) -> AsyncIterable[RecordType]: ...

@staticmethod
def _select_transform_columns(
indices: list[int],
transformers: list[Optional[Callable[[Any], Any]]],
records: RecordIterable,
) -> RecordIterable:
if isinstance(records, Iterable):
return (
tuple(
(
(transformer(field) if field is not None else None)
if transformer is not None
else field
)
for transformer, field in zip(
transformers, (record[i] for i in indices)
)
)
for record in records
)
elif isinstance(records, AsyncIterable):
return (
tuple(
(
(transformer(field) if field is not None else None)
if transformer is not None
else field
)
for transformer, field in zip(
transformers, (record[i] for i in indices)
)
)
async for record in records
)
else:
raise TypeError("expected: `Iterable` or `AsyncIterable` of records")

async def _get_transformer(
self,
table: Table,
records: Iterable[tuple[Any, ...]],
records: RecordIterable,
generator: BaseGenerator,
index: int,
field_type: type,
Expand All @@ -972,31 +1112,55 @@ async def _get_transformer(

transformer = generator.get_value_transformer(column, field_type)

if table.is_relation(column):
relation = generator.state.get_referenced_table(table.name, column.name)
if relation.is_lookup_table():
LOGGER.debug(
"found lookup table column %s in table %s", column.name, table.name
)
if not table.is_relation(column):
return transformer
relation = generator.state.get_referenced_table(table.name, column.name)
if not relation.is_lookup_table():
return transformer

enum_dict: dict[str, int]

if field_type is str:
# a single enumeration value represented as a string
values = set(record[index] for record in records)
values.discard(None) # do not insert NULL into referenced table
enum_dict = await self._merge_lookup_table(relation, values)
return generator.get_enum_transformer(enum_dict)
elif field_type is list or field_type is set:
# a list of enumeration values represented as a list of strings
values = set()
for record in records:
values.update(record[index])
values.discard(None) # do not insert NULL into referenced table
enum_dict = await self._merge_lookup_table(relation, values)
return generator.get_enum_list_transformer(enum_dict)

return transformer
LOGGER.debug(
"found lookup table column %s in table %s", column.name, table.name
)

enum_dict: dict[str, int]
values: set[Optional[str]] = set()

if field_type is str:
# a single enumeration value represented as a string
if isinstance(records, Iterable):
for record in records:
values.add(record[index])
elif isinstance(records, AsyncIterable):
async for record in records:
values.add(record[index])
else:
raise TypeError("expected: `Iterable` or `AsyncIterable` of records")
print("=== single ===")
print(values)
values.discard(None) # do not insert NULL into referenced table
enum_dict = await self._merge_lookup_table(
relation, typing.cast(set[str], values)
)
return generator.get_enum_transformer(enum_dict)
elif field_type is list or field_type is set:
# a list of enumeration values represented as a list of strings
if isinstance(records, Iterable):
for record in records:
values.update(record[index])
elif isinstance(records, AsyncIterable):
async for record in records:
values.update(record[index])
else:
raise TypeError("expected: `Iterable` or `AsyncIterable` of records")
print("=== multiple ===")
print(values)
values.discard(None) # do not insert NULL into referenced table
enum_dict = await self._merge_lookup_table(
relation, typing.cast(set[str], values)
)
return generator.get_enum_list_transformer(enum_dict)
else:
return transformer

async def _merge_lookup_table(
self, table: Table, values: set[str]
Expand Down
Loading

0 comments on commit 983fe4b

Please sign in to comment.