From de0bacd29a69d413044bb102aaa0a3d991950f44 Mon Sep 17 00:00:00 2001 From: Levente Hunyadi Date: Thu, 11 Jul 2024 01:36:14 +0200 Subject: [PATCH] Add more permissive signature to insert/upsert rows --- .vscode/settings.json | 5 +- pysqlsync/base.py | 271 ++++++++++++++++----- pysqlsync/dialect/mssql/connection.py | 41 +++- pysqlsync/dialect/mysql/connection.py | 19 +- pysqlsync/dialect/oracle/connection.py | 45 +++- pysqlsync/dialect/postgresql/connection.py | 18 +- pysqlsync/dialect/redshift/connection.py | 33 ++- pysqlsync/dialect/trino/connection.py | 12 +- pysqlsync/util/unsync.py | 10 + 9 files changed, 351 insertions(+), 103 deletions(-) create mode 100644 pysqlsync/util/unsync.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 237afd2..c6cda2c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -17,5 +17,8 @@ "python.analysis.typeCheckingMode": "off", "[python]": { "editor.defaultFormatter": "ms-python.black-formatter" - } + }, + "cSpell.words": [ + "unsync" + ] } diff --git a/pysqlsync/base.py b/pysqlsync/base.py index 605616a..e1b2fff 100644 --- a/pysqlsync/base.py +++ b/pysqlsync/base.py @@ -13,8 +13,19 @@ import json import logging import types +import typing from dataclasses import dataclass -from typing import Any, Callable, Iterable, Optional, Sized, TypeVar, Union, overload +from typing import ( + Any, + AsyncIterable, + Callable, + Iterable, + Optional, + Sized, + TypeVar, + Union, + overload, +) from urllib.parse import quote from strong_typing.inspection import DataclassInstance, is_dataclass_type, is_type_enum @@ -39,6 +50,9 @@ E = TypeVar("E", bound=enum.Enum) T = TypeVar("T") +RecordType = tuple[Any, ...] +RecordIterable = Union[Iterable[RecordType], AsyncIterable[RecordType]] + LOGGER = logging.getLogger("pysqlsync") _JSON_ENCODER = json.JSONEncoder( @@ -593,9 +607,7 @@ async def _execute(self, statement: str) -> None: ... - async def execute_all( - self, statement: str, args: Iterable[tuple[Any, ...]] - ) -> None: + async def execute_all(self, statement: str, args: Iterable[RecordType]) -> None: "Executes a SQL statement with several records of data." if not statement: @@ -618,9 +630,7 @@ async def execute_all( raise QueryException(statement) from e @abc.abstractmethod - async def _execute_all( - self, statement: str, records: Iterable[tuple[Any, ...]] - ) -> None: + async def _execute_all(self, statement: str, records: RecordIterable) -> None: "Executes a SQL statement with several records of data." ... @@ -628,7 +638,7 @@ async def _execute_all( async def _execute_typed( self, statement: str, - records: Iterable[tuple[Any, ...]], + records: RecordIterable, table: Table, order: Optional[tuple[str, ...]] = None, ) -> None: @@ -765,7 +775,7 @@ async def delete_data( async def insert_rows( self, table: Table, - records: Iterable[tuple[Any, ...]], + records: RecordIterable, *, field_types: tuple[type, ...], field_names: Optional[tuple[str, ...]] = None, @@ -797,7 +807,7 @@ async def insert_rows( async def _insert_rows( self, table: Table, - records: Iterable[tuple[Any, ...]], + records: RecordIterable, *, field_types: tuple[type, ...], field_names: Optional[tuple[str, ...]] = None, @@ -813,12 +823,13 @@ async def _insert_rows( ) order = tuple(name for name in field_names if name) if field_names else None statement = self.connection.generator.get_table_insert_stmt(table, order) + await self._execute_typed(statement, record_generator, table, order) async def upsert_rows( self, table: Table, - records: Iterable[tuple[Any, ...]], + records: RecordIterable, *, field_types: tuple[type, ...], field_names: Optional[tuple[str, ...]] = None, @@ -854,7 +865,7 @@ async def upsert_rows( async def _upsert_rows( self, table: Table, - records: Iterable[tuple[Any, ...]], + records: RecordIterable, *, field_types: tuple[type, ...], field_names: Optional[tuple[str, ...]] = None, @@ -870,16 +881,37 @@ async def _upsert_rows( ) order = tuple(name for name in field_names if name) if field_names else None statement = self.connection.generator.get_table_upsert_stmt(table, order) + await self._execute_typed(statement, record_generator, table, order) + @overload + async def _generate_records( + self, + table: Table, + records: Iterable[RecordType], + *, + field_types: tuple[type, ...], + field_names: Optional[tuple[str, ...]] = None, + ) -> Iterable[RecordType]: ... + + @overload async def _generate_records( self, table: Table, - records: Iterable[tuple[Any, ...]], + records: AsyncIterable[RecordType], *, field_types: tuple[type, ...], field_names: Optional[tuple[str, ...]] = None, - ) -> Iterable[tuple[Any, ...]]: + ) -> AsyncIterable[RecordType]: ... + + async def _generate_records( + self, + table: Table, + records: RecordIterable, + *, + field_types: tuple[type, ...], + field_names: Optional[tuple[str, ...]] = None, + ) -> RecordIterable: """ Creates a record generator for a database table. @@ -915,39 +947,136 @@ async def _generate_records( if len(indices) == len(field_types): return records else: - return (tuple(record[i] for i in indices) for record in records) + return self._select_columns(indices, records) else: if len(indices) == len(field_types): - return ( - tuple( - ( - (transformer(field) if field is not None else None) - if transformer is not None - else field - ) - for transformer, field in zip(transformers, record) + return self._transform_columns(transformers, records) + else: + return self._select_transform_columns(indices, transformers, records) + + @overload + @staticmethod + def _select_columns( + indices: list[int], records: Iterable[RecordType] + ) -> Iterable[RecordType]: ... + + @overload + @staticmethod + def _select_columns( + indices: list[int], records: AsyncIterable[RecordType] + ) -> AsyncIterable[RecordType]: ... + + @staticmethod + def _select_columns(indices: list[int], records: RecordIterable) -> RecordIterable: + if isinstance(records, Iterable): + return (tuple(record[i] for i in indices) for record in records) + elif isinstance(records, AsyncIterable): + return (tuple(record[i] for i in indices) async for record in records) + else: + raise TypeError("expected: `Iterable` or `AsyncIterable` of records") + + @overload + @staticmethod + def _transform_columns( + transformers: list[Optional[Callable[[Any], Any]]], + records: Iterable[RecordType], + ) -> Iterable[RecordType]: ... + + @overload + @staticmethod + def _transform_columns( + transformers: list[Optional[Callable[[Any], Any]]], + records: AsyncIterable[RecordType], + ) -> AsyncIterable[RecordType]: ... + + @staticmethod + def _transform_columns( + transformers: list[Optional[Callable[[Any], Any]]], records: RecordIterable + ) -> RecordIterable: + if isinstance(records, Iterable): + return ( + tuple( + ( + (transformer(field) if field is not None else None) + if transformer is not None + else field ) - for record in records + for transformer, field in zip(transformers, record) ) - else: - return ( - tuple( - ( - (transformer(field) if field is not None else None) - if transformer is not None - else field - ) - for transformer, field in zip( - transformers, (record[i] for i in indices) - ) + for record in records + ) + elif isinstance(records, AsyncIterable): + return ( + tuple( + ( + (transformer(field) if field is not None else None) + if transformer is not None + else field + ) + for transformer, field in zip(transformers, record) + ) + async for record in records + ) + else: + raise TypeError("expected: `Iterable` or `AsyncIterable` of records") + + @overload + @staticmethod + def _select_transform_columns( + indices: list[int], + transformers: list[Optional[Callable[[Any], Any]]], + records: Iterable[RecordType], + ) -> Iterable[RecordType]: ... + + @overload + @staticmethod + def _select_transform_columns( + indices: list[int], + transformers: list[Optional[Callable[[Any], Any]]], + records: AsyncIterable[RecordType], + ) -> AsyncIterable[RecordType]: ... + + @staticmethod + def _select_transform_columns( + indices: list[int], + transformers: list[Optional[Callable[[Any], Any]]], + records: RecordIterable, + ) -> RecordIterable: + if isinstance(records, Iterable): + return ( + tuple( + ( + (transformer(field) if field is not None else None) + if transformer is not None + else field + ) + for transformer, field in zip( + transformers, (record[i] for i in indices) ) - for record in records ) + for record in records + ) + elif isinstance(records, AsyncIterable): + return ( + tuple( + ( + (transformer(field) if field is not None else None) + if transformer is not None + else field + ) + for transformer, field in zip( + transformers, (record[i] for i in indices) + ) + ) + async for record in records + ) + else: + raise TypeError("expected: `Iterable` or `AsyncIterable` of records") async def _get_transformer( self, table: Table, - records: Iterable[tuple[Any, ...]], + records: RecordIterable, generator: BaseGenerator, index: int, field_type: type, @@ -972,31 +1101,51 @@ async def _get_transformer( transformer = generator.get_value_transformer(column, field_type) - if table.is_relation(column): - relation = generator.state.get_referenced_table(table.name, column.name) - if relation.is_lookup_table(): - LOGGER.debug( - "found lookup table column %s in table %s", column.name, table.name - ) + if not table.is_relation(column): + return transformer + relation = generator.state.get_referenced_table(table.name, column.name) + if not relation.is_lookup_table(): + return transformer - enum_dict: dict[str, int] - - if field_type is str: - # a single enumeration value represented as a string - values = set(record[index] for record in records) - values.discard(None) # do not insert NULL into referenced table - enum_dict = await self._merge_lookup_table(relation, values) - return generator.get_enum_transformer(enum_dict) - elif field_type is list or field_type is set: - # a list of enumeration values represented as a list of strings - values = set() - for record in records: - values.update(record[index]) - values.discard(None) # do not insert NULL into referenced table - enum_dict = await self._merge_lookup_table(relation, values) - return generator.get_enum_list_transformer(enum_dict) - - return transformer + LOGGER.debug( + "found lookup table column %s in table %s", column.name, table.name + ) + + enum_dict: dict[str, int] + values: set[Optional[str]] = set() + + if field_type is str: + # a single enumeration value represented as a string + if isinstance(records, Iterable): + for record in records: + values.add(record[index]) + elif isinstance(records, AsyncIterable): + async for record in records: + values.add(record[index]) + else: + raise TypeError("expected: `Iterable` or `AsyncIterable` of records") + values.discard(None) # do not insert NULL into referenced table + enum_dict = await self._merge_lookup_table( + relation, typing.cast(set[str], values) + ) + return generator.get_enum_transformer(enum_dict) + elif field_type is list or field_type is set: + # a list of enumeration values represented as a list of strings + if isinstance(records, Iterable): + for record in records: + values.update(record[index]) + elif isinstance(records, AsyncIterable): + async for record in records: + values.update(record[index]) + else: + raise TypeError("expected: `Iterable` or `AsyncIterable` of records") + values.discard(None) # do not insert NULL into referenced table + enum_dict = await self._merge_lookup_table( + relation, typing.cast(set[str], values) + ) + return generator.get_enum_list_transformer(enum_dict) + else: + return transformer async def _merge_lookup_table( self, table: Table, values: set[str] diff --git a/pysqlsync/dialect/mssql/connection.py b/pysqlsync/dialect/mssql/connection.py index c574eda..eae9724 100644 --- a/pysqlsync/dialect/mssql/connection.py +++ b/pysqlsync/dialect/mssql/connection.py @@ -1,17 +1,18 @@ import logging import typing -from typing import Any, Iterable, Optional, TypeVar +from typing import AsyncIterable, Iterable, Optional, TypeVar import pyodbc from strong_typing.inspection import is_dataclass_type -from pysqlsync.base import BaseConnection, BaseContext +from pysqlsync.base import BaseConnection, BaseContext, RecordIterable, RecordType from pysqlsync.formation.object_types import Table from pysqlsync.model.data_types import quote from pysqlsync.model.id_types import LocalId, QualifiedId from pysqlsync.resultset import resultset_unwrap_object, resultset_unwrap_tuple from pysqlsync.util.dispatch import thread_dispatch from pysqlsync.util.typing import override +from pysqlsync.util.unsync import unsync from .data_types import sql_to_odbc_type @@ -75,18 +76,44 @@ def _execute(self, statement: str) -> None: cur.execute(statement) @override + async def _execute_all(self, statement: str, records: RecordIterable) -> None: + if isinstance(records, Iterable): + await self._internal_execute_all(statement, records) + elif isinstance(records, AsyncIterable): + await self._internal_execute_all(statement, await unsync(records)) + else: + raise TypeError("expected: `Iterable` or `AsyncIterable` of records") + + @override + async def _execute_typed( + self, + statement: str, + records: RecordIterable, + table: Table, + order: Optional[tuple[str, ...]] = None, + ) -> None: + if isinstance(records, Iterable): + await self._internal_execute_typed(statement, records, table, order) + elif isinstance(records, AsyncIterable): + await self._internal_execute_typed( + statement, await unsync(records), table, order + ) + else: + raise TypeError("expected: `Iterable` or `AsyncIterable` of records") + @thread_dispatch - def _execute_all(self, statement: str, args: Iterable[tuple[Any, ...]]) -> None: + def _internal_execute_all( + self, statement: str, records: Iterable[RecordType] + ) -> None: with self.native_connection.cursor() as cur: cur.fast_executemany = True - cur.executemany(statement, args) + cur.executemany(statement, records) - @override @thread_dispatch - def _execute_typed( + def _internal_execute_typed( self, statement: str, - records: Iterable[tuple[Any, ...]], + records: Iterable[RecordType], table: Table, order: Optional[tuple[str, ...]] = None, ) -> None: diff --git a/pysqlsync/dialect/mysql/connection.py b/pysqlsync/dialect/mysql/connection.py index 6ae1144..3086aaf 100644 --- a/pysqlsync/dialect/mysql/connection.py +++ b/pysqlsync/dialect/mysql/connection.py @@ -1,15 +1,16 @@ import logging import typing -from typing import Any, Iterable, Optional, TypeVar +from typing import AsyncIterable, Iterable, Optional, TypeVar import aiomysql from strong_typing.inspection import DataclassInstance, is_dataclass_type -from pysqlsync.base import BaseConnection, BaseContext +from pysqlsync.base import BaseConnection, BaseContext, RecordIterable from pysqlsync.model.data_types import escape_like from pysqlsync.model.id_types import LocalId from pysqlsync.resultset import resultset_unwrap_dict, resultset_unwrap_tuple from pysqlsync.util.typing import override +from pysqlsync.util.unsync import unsync D = TypeVar("D", bound=DataclassInstance) T = TypeVar("T") @@ -65,11 +66,15 @@ async def _execute(self, statement: str) -> None: await cur.execute(statement) @override - async def _execute_all( - self, statement: str, args: Iterable[tuple[Any, ...]] - ) -> None: - async with self.native_connection.cursor() as cur: - await cur.executemany(statement, args) + async def _execute_all(self, statement: str, records: RecordIterable) -> None: + if isinstance(records, Iterable): + async with self.native_connection.cursor() as cur: + await cur.executemany(statement, records) + elif isinstance(records, AsyncIterable): + async with self.native_connection.cursor() as cur: + await cur.executemany(statement, await unsync(records)) + else: + raise TypeError("expected: `Iterable` or `AsyncIterable` of records") @override async def _query_all(self, signature: type[T], statement: str) -> list[T]: diff --git a/pysqlsync/dialect/oracle/connection.py b/pysqlsync/dialect/oracle/connection.py index 96d93e0..1132a63 100644 --- a/pysqlsync/dialect/oracle/connection.py +++ b/pysqlsync/dialect/oracle/connection.py @@ -1,18 +1,25 @@ import logging import re import typing -from typing import Any, Iterable, Optional, TypeVar +from typing import AsyncIterable, Iterable, Optional, TypeVar import oracledb from strong_typing.inspection import DataclassInstance, is_dataclass_type -from pysqlsync.base import BaseConnection, BaseContext, QueryException +from pysqlsync.base import ( + BaseConnection, + BaseContext, + QueryException, + RecordIterable, + RecordType, +) from pysqlsync.formation.object_types import Table from pysqlsync.model.data_types import escape_like from pysqlsync.model.id_types import LocalId from pysqlsync.resultset import resultset_unwrap_tuple from pysqlsync.util.dispatch import thread_dispatch from pysqlsync.util.typing import override +from pysqlsync.util.unsync import unsync from .data_types import sql_to_oracle_type @@ -75,18 +82,44 @@ def _execute(self, statement: str) -> None: raise QueryException(s) from e @override + async def _execute_all(self, statement: str, records: RecordIterable) -> None: + if isinstance(records, Iterable): + await self._internal_execute_all(statement, records) + elif isinstance(records, AsyncIterable): + await self._internal_execute_all(statement, await unsync(records)) + else: + raise TypeError("expected: `Iterable` or `AsyncIterable` of records") + + @override + async def _execute_typed( + self, + statement: str, + records: RecordIterable, + table: Table, + order: Optional[tuple[str, ...]] = None, + ) -> None: + if isinstance(records, Iterable): + await self._internal_execute_typed(statement, records, table, order) + elif isinstance(records, AsyncIterable): + await self._internal_execute_typed( + statement, await unsync(records), table, order + ) + else: + raise TypeError("expected: `Iterable` or `AsyncIterable` of records") + @thread_dispatch - def _execute_all(self, statement: str, records: Iterable[tuple[Any, ...]]) -> None: + def _internal_execute_all( + self, statement: str, records: Iterable[RecordType] + ) -> None: statement = statement.rstrip("\r\n\t\v ;") with self.native_connection.cursor() as cur: cur.executemany(statement, list(records)) - @override @thread_dispatch - def _execute_typed( + def _internal_execute_typed( self, statement: str, - records: Iterable[tuple[Any, ...]], + records: Iterable[RecordType], table: Table, order: Optional[tuple[str, ...]] = None, ) -> None: diff --git a/pysqlsync/dialect/postgresql/connection.py b/pysqlsync/dialect/postgresql/connection.py index f217e7d..2c894bb 100644 --- a/pysqlsync/dialect/postgresql/connection.py +++ b/pysqlsync/dialect/postgresql/connection.py @@ -1,16 +1,17 @@ import dataclasses import logging import typing -from typing import Any, Iterable, Optional, TypeVar +from typing import AsyncIterable, Iterable, Optional, TypeVar import asyncpg from strong_typing.inspection import DataclassInstance, is_dataclass_type -from pysqlsync.base import BaseConnection, BaseContext, ClassRef +from pysqlsync.base import BaseConnection, BaseContext, ClassRef, RecordIterable from pysqlsync.formation.object_types import Table from pysqlsync.model.properties import is_identity_type from pysqlsync.resultset import resultset_unwrap_dict, resultset_unwrap_tuple from pysqlsync.util.typing import override +from pysqlsync.util.unsync import unsync D = TypeVar("D", bound=DataclassInstance) T = TypeVar("T") @@ -62,10 +63,13 @@ async def _execute(self, statement: str) -> None: await self.native_connection.execute(statement) @override - async def _execute_all( - self, statement: str, args: Iterable[tuple[Any, ...]] - ) -> None: - await self.native_connection.executemany(statement, args) + async def _execute_all(self, statement: str, records: RecordIterable) -> None: + if isinstance(records, Iterable): + await self.native_connection.executemany(statement, records) + elif isinstance(records, AsyncIterable): + await self.native_connection.executemany(statement, await unsync(records)) + else: + raise TypeError("expected: `Iterable` or `AsyncIterable` of records") @override async def _query_all(self, signature: type[T], statement: str) -> list[T]: @@ -98,7 +102,7 @@ async def insert_data(self, table: type[D], data: Iterable[D]) -> None: async def _insert_rows( self, table: Table, - records: Iterable[tuple[Any, ...]], + records: RecordIterable, *, field_types: tuple[type, ...], field_names: Optional[tuple[str, ...]] = None, diff --git a/pysqlsync/dialect/redshift/connection.py b/pysqlsync/dialect/redshift/connection.py index 642b170..9355b9f 100644 --- a/pysqlsync/dialect/redshift/connection.py +++ b/pysqlsync/dialect/redshift/connection.py @@ -3,18 +3,26 @@ import re import typing from io import BytesIO -from typing import Any, Iterable, Optional, TypeVar +from typing import AsyncIterable, Iterable, Optional, TypeVar import redshift_connector from strong_typing.inspection import DataclassInstance, is_dataclass_type -from pysqlsync.base import BaseConnection, BaseContext, ClassRef, QueryException +from pysqlsync.base import ( + BaseConnection, + BaseContext, + ClassRef, + QueryException, + RecordIterable, + RecordType, +) from pysqlsync.formation.object_types import Table from pysqlsync.model.id_types import LocalId, SupportsQualifiedId from pysqlsync.model.properties import is_identity_type from pysqlsync.resultset import resultset_unwrap_object, resultset_unwrap_tuple from pysqlsync.util.dispatch import thread_dispatch from pysqlsync.util.typing import override +from pysqlsync.util.unsync import unsync D = TypeVar("D", bound=DataclassInstance) T = TypeVar("T") @@ -72,10 +80,18 @@ def _execute(self, statement: str) -> None: raise QueryException(s) from e @override + async def _execute_all(self, statement: str, records: RecordIterable) -> None: + if isinstance(records, Iterable): + await self._internal_execute_all(statement, records) + elif isinstance(records, AsyncIterable): + await self._internal_execute_all(statement, await unsync(records)) + else: + raise TypeError("expected: `Iterable` or `AsyncIterable` of records") + @thread_dispatch - def _execute_all(self, statement: str, args: Iterable[tuple[Any, ...]]) -> None: + def _internal_execute_all(self, statement: str, records: RecordIterable) -> None: with self.native_connection.cursor() as cursor: - cursor.executemany(statement, args) + cursor.executemany(statement, records) @override @thread_dispatch @@ -107,11 +123,14 @@ def insert_data(self, table: type[D], data: Iterable[D]) -> None: async def _insert_rows( self, table: Table, - records: Iterable[tuple[Any, ...]], + records: RecordIterable, *, field_types: tuple[type, ...], field_names: Optional[tuple[str, ...]] = None, ) -> None: + if not isinstance(records, Iterable): + raise TypeError("expected: `Iterable` of records") + order = tuple(name for name in field_names if name) if field_names else None columns = [col.name for col in table.get_columns(order)] record_generator = await self._generate_records( @@ -124,7 +143,7 @@ def _insert_copy_stream_async( self, table_name: SupportsQualifiedId, columns: list[LocalId], - records: Iterable[tuple[Any, ...]], + records: Iterable[RecordType], ) -> None: self._insert_copy_stream(table_name, columns, records) @@ -132,7 +151,7 @@ def _insert_copy_stream( self, table_name: SupportsQualifiedId, columns: list[LocalId], - records: Iterable[tuple[Any, ...]], + records: Iterable[RecordType], ) -> None: column_list = ", ".join(str(column) for column in columns) copy_query = f"COPY {table_name} ({column_list}) FROM STDIN" diff --git a/pysqlsync/dialect/trino/connection.py b/pysqlsync/dialect/trino/connection.py index 01de9ac..3c9b6b7 100644 --- a/pysqlsync/dialect/trino/connection.py +++ b/pysqlsync/dialect/trino/connection.py @@ -1,10 +1,10 @@ import typing -from typing import Any, Iterable, Optional, TypeVar +from typing import Iterable, Optional, TypeVar import aiotrino from strong_typing.inspection import is_dataclass_type -from pysqlsync.base import BaseConnection, BaseContext +from pysqlsync.base import BaseConnection, BaseContext, RecordIterable from pysqlsync.formation.object_types import Table from pysqlsync.resultset import resultset_unwrap_tuple from pysqlsync.util.typing import override @@ -48,9 +48,7 @@ async def _execute(self, statement: str) -> None: await cur.execute(statement) @override - async def _execute_all( - self, statement: str, args: Iterable[tuple[Any, ...]] - ) -> None: + async def _execute_all(self, statement: str, records: RecordIterable) -> None: raise NotImplementedError("operation not supported for Trino") @override @@ -71,7 +69,7 @@ async def insert_data(self, table: type[T], data: Iterable[T]) -> None: async def _insert_rows( self, table: Table, - records: Iterable[tuple[Any, ...]], + records: RecordIterable, *, field_types: tuple[type, ...], field_names: Optional[tuple[str, ...]] = None, @@ -82,7 +80,7 @@ async def _insert_rows( async def _upsert_rows( self, table: Table, - records: Iterable[tuple[Any, ...]], + records: RecordIterable, *, field_types: tuple[type, ...], field_names: Optional[tuple[str, ...]] = None, diff --git a/pysqlsync/util/unsync.py b/pysqlsync/util/unsync.py new file mode 100644 index 0000000..262f2d3 --- /dev/null +++ b/pysqlsync/util/unsync.py @@ -0,0 +1,10 @@ +from typing import AsyncIterable, TypeVar + +T = TypeVar("T") + + +async def unsync(items: AsyncIterable[T]) -> list[T]: + result: list[T] = [] + async for item in items: + result.append(item) + return result