From c4fe805b9444a118c25dbca8f7dd25e722dab4f0 Mon Sep 17 00:00:00 2001 From: dmitry krokhin Date: Tue, 21 Nov 2023 20:05:45 +0300 Subject: [PATCH] introduce tarantool driver --- Dockerfile | 2 +- docker-compose.yml | 7 + registry/drivers.py | 113 ------------- registry/drivers/__init__.py | 73 +++++++++ registry/drivers/memory.py | 43 +++++ registry/drivers/tarantool.py | 293 ++++++++++++++++++++++++++++++++++ registry/registry.py | 18 ++- registry/repository.py | 22 ++- requirements.txt | 1 + tests/test_repository.py | 34 +++- 10 files changed, 474 insertions(+), 132 deletions(-) delete mode 100644 registry/drivers.py create mode 100644 registry/drivers/__init__.py create mode 100644 registry/drivers/memory.py create mode 100644 registry/drivers/tarantool.py diff --git a/Dockerfile b/Dockerfile index 5cf1c07..e95c5d0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,4 +8,4 @@ RUN pip install --upgrade pip --no-cache-dir -r requirements.txt COPY ./registry /app/registry COPY ./tests /app/tests -CMD pytest +CMD pytest -s diff --git a/docker-compose.yml b/docker-compose.yml index 09b3858..684123c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,6 +10,13 @@ services: tarantool: image: tarantool/tarantool + tarantool-admin: + image: quay.io/basis-company/tarantool-admin + environment: + TARANTOOL_CONNECTIONS: 'tarantool' + ports: + - "80:80" + tests: build: . depends_on: diff --git a/registry/drivers.py b/registry/drivers.py deleted file mode 100644 index f840f3e..0000000 --- a/registry/drivers.py +++ /dev/null @@ -1,113 +0,0 @@ -from typing import Optional, Protocol - -from registry.entity import Entity -from registry.schema import StorageDriver - - -class Driver(Protocol): - def __init__(self, dsn: str) -> None: - ... - - async def find( - self, - entity: type[Entity], - queries: list[dict], - limit: Optional[int] = None, - ) -> list[dict]: - raise NotImplementedError() - - async def find_one( - self, - entity: type[Entity], - queries: list[dict], - ) -> Optional[dict]: - rows = await self.find(entity, queries, limit=1) - if len(rows): - return rows[0] - - return None - - async def find_or_create( - self, entity: type[Entity], query: dict, data: dict - ) -> dict: - result = await self.find(entity, [query]) - if len(result): - return result[0] - - return await self.insert(entity, data) - - async def find_or_fail( - self, - entity: type[Entity], - queries: list[dict], - ) -> dict: - instance = await self.find_one(entity, queries) - if not instance: - raise LookupError(f'{entity.__name__} not found') - - return instance - - async def init_schema(self, entity: type[Entity]) -> None: - raise NotImplementedError() - - async def insert(self, entity: type[Entity], data: dict) -> dict: - raise NotImplementedError() - - -driver_instances: dict[str, dict[str, Driver]] = {} - - -async def get_driver(driver: StorageDriver, dsn: str) -> Driver: - if driver not in driver_instances: - driver_instances[driver] = {} - if dsn not in driver_instances[driver]: - driver_instances[driver][dsn] = get_implementation(driver)(dsn) - return driver_instances[driver][dsn] - - -def get_implementation(driver: StorageDriver) -> type[Driver]: - implementations: dict[StorageDriver, type[Driver]] = { - StorageDriver.MEMORY: MemoryDriver - } - if driver in implementations: - return implementations[driver] - raise NotImplementedError(f'Driver {driver} not implemented') - - -class MemoryDriver(Driver): - def __init__(self, dsn: str) -> None: - self.data: dict[type[Entity], list[dict]] = {} - - async def find( - self, - entity: type[Entity], - queries: list[dict], - limit: Optional[int] = None, - ) -> list[dict]: - await self.init_schema(entity) - rows = [ - row for row in self.data[entity] - if await self.is_valid(row, queries) - ] - if limit: - rows = rows[0:limit] - return rows - - async def init_schema(self, entity: type[Entity]) -> None: - if entity not in self.data: - self.data[entity] = [] - - async def insert(self, entity: type[Entity], data: dict) -> dict: - await self.init_schema(entity) - data['id'] = len(self.data[entity]) + 1 - self.data[entity].append(data) - return data - - async def is_valid(self, row, queries: list) -> bool: - for query in queries: - if False not in [ - row[key] == value for (key, value) in query.items() - ]: - return True - - return False diff --git a/registry/drivers/__init__.py b/registry/drivers/__init__.py new file mode 100644 index 0000000..12f8e27 --- /dev/null +++ b/registry/drivers/__init__.py @@ -0,0 +1,73 @@ +from functools import cache +from typing import Optional, Protocol + +from registry.entity import Entity +from registry.schema import StorageDriver + + +class Driver(Protocol): + def __init__(self, dsn: str) -> None: + ... + + async def find( + self, + entity: type[Entity], + queries: list[dict], + limit: Optional[int] = None, + ) -> list[dict]: + raise NotImplementedError() + + async def find_one( + self, + entity: type[Entity], + queries: list[dict], + ) -> Optional[dict]: + rows = await self.find(entity, queries, limit=1) + if len(rows): + return rows[0] + + return None + + async def find_or_create( + self, entity: type[Entity], query: dict, data: dict + ) -> dict: + result = await self.find(entity, [query]) + if len(result): + return result[0] + + return await self.insert(entity, data) + + async def find_or_fail( + self, + entity: type[Entity], + queries: list[dict], + ) -> dict: + instance = await self.find_one(entity, queries) + if not instance: + raise LookupError(f'{entity.__name__} not found') + + return instance + + async def init_schema(self, entity: type[Entity]) -> None: + raise NotImplementedError() + + async def insert(self, entity: type[Entity], data: dict) -> dict: + raise NotImplementedError() + + +@cache +def get_driver(driver: StorageDriver, dsn: str) -> Driver: + return get_implementation(driver)(dsn) + + +@cache +def get_implementation(driver: StorageDriver) -> type[Driver]: + if driver is StorageDriver.MEMORY: + from registry.drivers.memory import MemoryDriver + return MemoryDriver + + if driver is StorageDriver.TARANTOOL: + from registry.drivers.tarantool import TarantoolDriver + return TarantoolDriver + + raise NotImplementedError(f'{driver} driver not implemented') diff --git a/registry/drivers/memory.py b/registry/drivers/memory.py new file mode 100644 index 0000000..ebeb276 --- /dev/null +++ b/registry/drivers/memory.py @@ -0,0 +1,43 @@ +from typing import Optional + +from registry.drivers import Driver +from registry.entity import Entity + + +class MemoryDriver(Driver): + def __init__(self, dsn: str) -> None: + self.data: dict[type[Entity], list[dict]] = {} + + async def find( + self, + entity: type[Entity], + queries: list[dict], + limit: Optional[int] = None, + ) -> list[dict]: + await self.init_schema(entity) + rows = [ + row for row in self.data[entity] + if await self.is_valid(row, queries) + ] + if limit: + rows = rows[0:limit] + return rows + + async def init_schema(self, entity: type[Entity]) -> None: + if entity not in self.data: + self.data[entity] = [] + + async def insert(self, entity: type[Entity], data: dict) -> dict: + await self.init_schema(entity) + data['id'] = len(self.data[entity]) + 1 + self.data[entity].append(data) + return data + + async def is_valid(self, row, queries: list) -> bool: + for query in queries: + if False not in [ + row[key] == value for (key, value) in query.items() + ]: + return True + + return False diff --git a/registry/drivers/tarantool.py b/registry/drivers/tarantool.py new file mode 100644 index 0000000..ee155f8 --- /dev/null +++ b/registry/drivers/tarantool.py @@ -0,0 +1,293 @@ +from asyncio import create_task, gather +from decimal import Decimal +from math import ceil +from typing import Any, Callable, Optional, get_type_hints + +from asynctnt import Connection +from asynctnt.iproto.protocol import SchemaIndex +from dsnparse import parse + +from registry.drivers import Driver +from registry.entity import Entity +from registry.repository import Index, get_entity_repository_class + +constructor_cache: dict[str, Callable] = {} + + +def get_constructor( + space: str, + connection: Connection +) -> Callable: + if space not in constructor_cache: + keys = list( + connection.schema.spaces[space].metadata.name_id_map.keys() + ) + source = [ + 'def ' + space + '(source) -> dict:', + ' return {', + ] + [ + ' "' + keys[n] + '": source[' + str(n) + '],' + for n in range(0, len(keys)) + ] + [ + ' }' + ] + exec("\n\r".join(source), constructor_cache) + + return constructor_cache[space] + + +class TarantoolDriver(Driver): + page_size: int = 5000 + max_threads: int = 2 + + def __init__(self, dsn: str) -> None: + info = parse(dsn) + self.connection = Connection( + host=info.host, + port=info.port, + username=info.username or 'guest', + password=info.password or '', + ) + self.format: dict[type[Entity], dict[str, type]] = {} + self.init: dict[type[Entity], bool] = {} + + async def find( + self, + entity: type[Entity], + queries: list[dict], + limit: Optional[int] = None, + ) -> list[dict]: + await self.init_schema(entity) + constructor = get_constructor(entity.__name__, self.connection) + keys = (queries[0] or {}).keys() + index = cast_index(self.connection, entity.__name__, keys) + params = [get_index_tuple(index, query) for query in queries] + result: list[dict] = [] + + async def worker(): + while len(params) > 0: + values = [ + params.pop() + for _ in range(min(self.page_size, len(params))) + ] + + query = """ + local space, index, values = ... + local result = {} + for i, value in pairs(values) do + local rows = {} + box.space[space].index[index]:pairs(value) + :each(function(t) + table.insert(rows, t) + end) + if #rows > 0 then + table.insert(result, rows) + end + end + return result + """ + + query_params = (entity.__name__, index.name, values) + eval = await self.connection.eval(query, query_params) + for col in eval[0]: + result.extend([constructor(t) for t in col]) + + if limit is not None and len(result) > limit: + break + + workers = 1 + if limit != 1: + workers = min(self.max_threads, ceil(len(params) / self.page_size)) + + tasks = [create_task(worker()) for _ in range(workers)] + + await gather(*tasks, return_exceptions=True) + + return result + + async def insert(self, entity: type[Entity], data: dict) -> dict: + await self.init_schema(entity) + + lua = f''' + local tuple = ... + if tuple[2] == 0 then + if box.sequence.sq_{entity.__name__} == nil then + opts = {'{}'} + last = box.space.{entity.__name__}.index.bucket_id_id:max() + if last ~= nil then + opts['start'] = last.id + 1 + end + box.schema.sequence.create('sq_{entity.__name__}', opts) + end + tuple[2] = box.sequence.sq_{entity.__name__}:next() + end + return box.space.{entity.__name__}:insert(tuple) + ''' + + res = await self.connection.eval(lua, [ + self.get_row_tuple(entity.__name__, data) + ]) + + return self.get_tuple_dict(entity.__name__, res[0]) + + async def init_schema(self, entity: type[Entity]) -> None: + if not self.connection.is_connected: + await self.connection.connect() + + if entity not in self.init: + self.init[entity] = True + if entity.__name__ not in self.connection.schema.spaces: + await self.connection.eval('box.schema.create_space(...)', [ + entity.__name__, { + 'engine': 'memtx', + 'if_not_exists': True, + } + ]) + await self.connection.refetch_schema() + await sync_format(self.connection, entity) + await sync_indexes(self.connection, entity) + + def get_tuple_dict(self, space, param): + format = get_current_format(self.connection, space) + return { + format[n]['name']: param[n] for n in range(len(format)) + } + + def get_row_tuple(self, space, param): + return [ + convert_type(field['type'], param[field['name']]) + if field['name'] in param else convert_type(field['type'], None) + for field in get_current_format(self.connection, space) + ] + + +def convert_type(tarantool_type: str, value: Any) -> str | int | float: + if value is None or value == '': + if tarantool_type == 'number' or tarantool_type == 'unsigned': + return 0 + if tarantool_type == 'string' or tarantool_type == 'str': + return value or '' + + if isinstance(value, str): + match(tarantool_type): + case 'unsigned': + return int(value) + + case 'number': + if value[-1] == '-': + return -1 * float(value[0:-1]) + else: + return float(value) + + if isinstance(value, int) and tarantool_type == 'string': + return str(value) + + if isinstance(value, Decimal): + if tarantool_type == 'number': + return float(value) + if tarantool_type == 'unsigned': + return int(value) + + if isinstance(value, type): + return str(value) + + return value + + +def get_index_tuple(index, param): + return [ + convert_type(field.type, param[field.name]) + for field in index.metadata.fields + if field.name in param + ] + + +def cast_index( + connection: Connection, + space: str, + keys: list[str] +) -> SchemaIndex: + if not isinstance(keys, list): + keys = list(keys) + + if not len(keys): + return connection.schema.spaces[space].indexes[0] + + keys.sort() + + for index in connection.schema.spaces[space].indexes.values(): + candidate_keys = [] + for f in index.metadata.fields: + candidate_keys.append(f.name) + if len(candidate_keys) == len(keys): + break + + candidate_keys.sort() + if candidate_keys == keys: + return index + + raise ValueError(f'no index on {space} for [{",".join(keys)}]') + + +async def sync_format(connection, entity) -> None: + format = [{ + 'name': 'bucket_id', + 'type': 'unsigned', + }] + [ + { + 'name': key, + 'type': get_tarantool_type(type), + } + for (key, type) in get_type_hints(entity).items() + ] + + if format != get_current_format(connection, entity.__name__): + box_space = 'box.space.' + entity.__name__ + await connection.eval(box_space + ':format(...)', [format]) + await connection.refetch_schema() + + +def get_current_format(connection, space: str): + if space not in connection.schema.spaces: + raise LookupError(f'Invalid space {space}') + + if not connection.schema.spaces[space].metadata: + return [] + + return [ + {'name': field.name, 'type': field.type} + for field in connection.schema.spaces[space].metadata.fields + ] + + +async def sync_indexes(connection, entity) -> None: + indexes = [Index(entity, ['bucket_id', 'id'], True)] + indexes.extend([ + i for i in get_entity_repository_class(entity).indexes + if i.entity == entity + ]) + indexes.extend([ + Index(i.entity, ['bucket_id'] + i.fields, i.unique) + for i in get_entity_repository_class(entity).indexes + if i.entity == entity + ]) + + box_space = 'box.space.' + entity.__name__ + changes: bool = False + for index in indexes: + if index.name not in connection.schema.spaces[entity.__name__].indexes: + changes = True + await connection.eval(box_space + ':create_index(...)', [ + index.name, {'parts': index.fields, 'unique': index.unique} + ]) + + if changes: + await connection.refetch_schema() + + +def get_tarantool_type(type) -> str: + if type is float: + return 'number' + if type is int: + return 'unsigned' + return 'str' diff --git a/registry/registry.py b/registry/registry.py index e0e39ca..10804fc 100644 --- a/registry/registry.py +++ b/registry/registry.py @@ -54,7 +54,7 @@ async def bootstrap(self) -> None: if not primary: raise LookupError('primary storage not found') - driver = await get_driver(primary.driver, primary.dsn) + driver = get_driver(primary.driver, primary.dsn) for repository in self.repositories.values(): if isinstance(repository, BucketRepository): @@ -72,7 +72,7 @@ async def find_or_create( if data is None: data = {} if query is None: - query = data + query = dict(**data) context = await self.context(entity, key) for key, value in entity.get_default_values().items(): if key not in data: @@ -137,7 +137,7 @@ async def context(self, entity: type[Entity], key: Any) -> QueryContext: await self.bootstrap() repository = self.get_repository(get_entity_repository_class(entity)) - bucket = await self.get_bucket(repository, key) + bucket = await self.get_bucket(repository.__class__, key) if not bucket.storage_id: storage = await repository.cast_storage(self.storages) @@ -145,7 +145,7 @@ async def context(self, entity: type[Entity], key: Any) -> QueryContext: else: storage = self.get_storage(bucket.storage_id) - driver = await get_driver(storage.driver, storage.dsn) + driver = get_driver(storage.driver, storage.dsn) if bucket.status == BucketStatus.NEW: await repository.init_schema(driver) @@ -161,14 +161,18 @@ async def context(self, entity: type[Entity], key: Any) -> QueryContext: return QueryContext(bucket, driver, entity, repository) - async def get_bucket(self, repository: Repository, key: Any): - if isinstance(repository, BucketRepository | StorageRepository): + async def get_bucket( + self, + repository: type[Repository], + key: Optional[Any] = None + ) -> Bucket: + if repository is BucketRepository or repository is StorageRepository: buckets = self.get_repository(BucketRepository) bucket = buckets.map[Bucket][repository.bucket_id] if isinstance(bucket, Bucket): return bucket - key = await repository.transform_key(key) + key = await self.get_repository(repository).transform_key(key) return await self.find_or_create( entity=Bucket, query={ diff --git a/registry/repository.py b/registry/repository.py index 9f0f63e..3f341b4 100644 --- a/registry/repository.py +++ b/registry/repository.py @@ -13,6 +13,10 @@ class Index: fields: list[str] unique: bool = False + @property + def name(self) -> str: + return '_'.join(self.fields) + class UniqueIndex(Index): unique = True @@ -20,21 +24,17 @@ class UniqueIndex(Index): class Repository: entities: list[type[Entity]] - indexes: Optional[list[Index]] = None + indexes: list[Index] = [] def __init__(self) -> None: - if not self.indexes: - self.indexes = [] - self.map: dict[type[Entity], dict[int, Entity]] = {} for entity in self.entities: - self.indexes.insert(0, UniqueIndex(entity, ['id'])) self.map[entity] = {} async def cast_storage(self, storages: list[Storage]) -> Storage: return storages[0] - async def transform_key(self, key: Any) -> str: + async def transform_key(self, key: Optional[Any] = None) -> str: return '' async def init_data(self, bucket: Bucket, driver: Driver) -> None: @@ -80,12 +80,16 @@ def get_entity_repository_map() -> dict[type[Entity], type[Repository]]: class BucketRepository(Repository): bucket_id: int = 1 entities = [Bucket] + indexes: list[Index] = [ + Index(entity=Bucket, fields=['repository', 'key'], unique=True) + ] async def bootstrap(self, driver: Driver) -> None: bucket_row = await driver.find_or_create( entity=Bucket, query={ - 'id': BucketRepository.bucket_id + 'bucket_id': BucketRepository.bucket_id, + 'id': BucketRepository.bucket_id, }, data={ 'bucket_id': BucketRepository.bucket_id, @@ -100,7 +104,8 @@ async def bootstrap(self, driver: Driver) -> None: storage_row = await driver.find_or_create( entity=Bucket, query={ - 'id': StorageRepository.bucket_id + 'bucket_id': BucketRepository.bucket_id, + 'id': StorageRepository.bucket_id, }, data={ 'bucket_id': BucketRepository.bucket_id, @@ -127,6 +132,7 @@ async def bootstrap(self, driver: Driver, storage: Storage) -> None: await driver.find_or_create( entity=Storage, query=dict( + bucket_id=StorageRepository.bucket_id, id=storage.id, ), data=dict( diff --git a/requirements.txt b/requirements.txt index 6bb2058..797bc3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ anyio asyncpg asynctnt +dsnparse pytest pytest-asyncio diff --git a/tests/test_repository.py b/tests/test_repository.py index c4ef7e7..bbc42bc 100644 --- a/tests/test_repository.py +++ b/tests/test_repository.py @@ -3,6 +3,7 @@ from pytest import mark from registry.drivers import get_driver +from registry.drivers.tarantool import TarantoolDriver from registry.entity import Entity, Storage from registry.registry import Registry from registry.repository import Index, Repository @@ -26,7 +27,7 @@ class ActionRepository(Repository): ActionTrigger ] indexes = [ - Index(ActionTrigger, ['id']) + Index(Action, ['type', 'owner_id']), ] @@ -34,8 +35,31 @@ class ActionRepository(Repository): @mark.parametrize("storage", [ Storage(1, StorageClass.HOT, StorageDriver.MEMORY, '1'), Storage(1, StorageClass.HOT, StorageDriver.MEMORY, '2'), + Storage( + 1, StorageClass.HOT, StorageDriver.TARANTOOL, 'tcp://tarantool:3301' + ), ]) async def test_hello(storage: Storage): + if storage.driver is StorageDriver.TARANTOOL: + driver: TarantoolDriver = get_driver(storage.driver, storage.dsn) + await driver.connection.connect() + await driver.connection.eval(''' + local todo = {} + for i, space in box.space._space:pairs() do + if space[1] >= 512 then + table.insert(todo, space[3]) + end + end + for i, name in pairs(todo) do + box.space[name]:drop() + end + for i, s in box.space._vsequence:pairs() do + box.sequence[s.name]:drop() + end + box.space._schema:update('max_id', {{'=', 'value', 511}}) + ''') + await driver.connection.refetch_schema() + registry = Registry([storage]) assert len(await registry.find(Action)) == 0 @@ -64,9 +88,13 @@ async def test_hello(storage: Storage): # storage level persistence check [storage] = registry.storages - driver = await get_driver(storage.driver, storage.dsn) + driver = get_driver(storage.driver, storage.dsn) assert len(await driver.find(Action, queries=[{}])) == 2 + bucket = await registry.get_bucket(ActionRepository) + # default values peristence - first_action_dict = await driver.find_one(Action, queries=[{'id': 1}]) + first_action_dict = await driver.find_one(Action, queries=[ + {'bucket_id': bucket.id, 'id': 1} + ]) assert first_action_dict['owner_id'] == 0