From dee63f16e51e9d412a9aba81e40a663f9ca324a6 Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov <274471+DriverX@users.noreply.github.com> Date: Tue, 19 Jul 2022 02:59:47 +0300 Subject: [PATCH] Fix connect/pool create cancellation and gc after (#16) * Fix connect/pool create cancellation and gc after --- .github/workflows/ci.yml | 1 + src/aioredis_cluster/aioredis/__init__.py | 11 +- src/aioredis_cluster/aioredis/connection.py | 120 ++++++++++++++++++++ src/aioredis_cluster/aioredis/pool.py | 81 +++++++++++++ src/aioredis_cluster/aioredis/util.py | 5 +- src/aioredis_cluster/cluster.py | 12 +- src/aioredis_cluster/pool.py | 5 +- 7 files changed, 214 insertions(+), 21 deletions(-) create mode 100644 src/aioredis_cluster/aioredis/connection.py create mode 100644 src/aioredis_cluster/aioredis/pool.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 79292de..c16c34a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -104,6 +104,7 @@ jobs: src: - '.github/workflows/**' - 'src/aioredis_cluster/_aioredis/**' + - 'src/aioredis_cluster/aioredis/**' - 'tests/aioredis_tests/**' - setup.cfg - id: force_run diff --git a/src/aioredis_cluster/aioredis/__init__.py b/src/aioredis_cluster/aioredis/__init__.py index c2f8fad..b6225eb 100644 --- a/src/aioredis_cluster/aioredis/__init__.py +++ b/src/aioredis_cluster/aioredis/__init__.py @@ -16,9 +16,9 @@ ) try: - from aioredis.connection import RedisConnection, create_connection + from aioredis.connection import RedisConnection except ImportError: - from .._aioredis.connection import RedisConnection, create_connection + from .._aioredis.connection import RedisConnection try: from aioredis.errors import ( @@ -61,9 +61,9 @@ WatchVariableError, ) try: - from aioredis.pool import ConnectionsPool, create_pool + from aioredis.pool import ConnectionsPool except ImportError: - from .._aioredis.pool import ConnectionsPool, create_pool + from .._aioredis.pool import ConnectionsPool try: from aioredis.pubsub import Channel except ImportError: @@ -73,6 +73,9 @@ except ImportError: from .._aioredis.sentinel import RedisSentinel, create_sentinel +from .connection import create_connection +from .pool import create_pool + __version__ = "1.3.1" diff --git a/src/aioredis_cluster/aioredis/connection.py b/src/aioredis_cluster/aioredis/connection.py new file mode 100644 index 0000000..e844ba2 --- /dev/null +++ b/src/aioredis_cluster/aioredis/connection.py @@ -0,0 +1,120 @@ +import asyncio +import logging +import socket +import ssl +from typing import List, Tuple, Type, Union + +from .abc import AbcConnection +from .util import parse_url + + +try: + from aioredis.stream import open_connection, open_unix_connection +except ImportError: + from .._aioredis.stream import open_connection, open_unix_connection +try: + from aioredis.connection import MAX_CHUNK_SIZE, RedisConnection +except ImportError: + from .._aioredis.connection import MAX_CHUNK_SIZE, RedisConnection + + +logger = logging.getLogger(__name__) + + +async def create_connection( + address: Union[str, Tuple[str, int], List], + *, + db: int = None, + password: str = None, + ssl: Union[bool, ssl.SSLContext] = None, + encoding: str = None, + parser=None, + timeout: float = None, + connection_cls: Type[AbcConnection] = None, + loop=None, +) -> AbcConnection: + """Creates redis connection. + + Opens connection to Redis server specified by address argument. + Address argument can be one of the following: + * A tuple representing (host, port) pair for TCP connections; + * A string representing either Redis URI or unix domain socket path. + + SSL argument is passed through to asyncio.create_connection. + By default SSL/TLS is not used. + + By default any timeout is applied at the connection stage, however + you can set a limitted time used trying to open a connection via + the `timeout` Kw. + + Encoding argument can be used to decode byte-replies to strings. + By default no decoding is done. + + Parser parameter can be used to pass custom Redis protocol parser class. + By default hiredis.Reader is used (unless it is missing or platform + is not CPython). + + Return value is RedisConnection instance or a connection_cls if it is + given. + + This function is a coroutine. + """ + assert isinstance(address, (tuple, list, str)), "tuple or str expected" + if isinstance(address, str): + address, options = parse_url(address) + logger.debug("Parsed Redis URI %r", address) + db = options.setdefault("db", db) + password = options.setdefault("password", password) + encoding = options.setdefault("encoding", encoding) + timeout = options.setdefault("timeout", timeout) + if "ssl" in options: + assert options["ssl"] or (not options["ssl"] and not ssl), ( + "Conflicting ssl options are set", + options["ssl"], + ssl, + ) + ssl = ssl or options["ssl"] + + if timeout is not None and timeout <= 0: + raise ValueError("Timeout has to be None or a number greater than 0") + + if connection_cls: + assert issubclass( + connection_cls, AbcConnection + ), "connection_class does not meet the AbcConnection contract" + cls = connection_cls + else: + cls = RedisConnection + + if isinstance(address, (list, tuple)): + host, port = address + logger.debug("Creating tcp connection to %r", address) + reader, writer = await asyncio.wait_for( + open_connection(host, port, limit=MAX_CHUNK_SIZE, ssl=ssl), timeout + ) + sock = writer.transport.get_extra_info("socket") + if sock is not None: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + address = sock.getpeername() + address = tuple(address[:2]) # type: ignore + else: + logger.debug("Creating unix connection to %r", address) + reader, writer = await asyncio.wait_for( + open_unix_connection(address, ssl=ssl, limit=MAX_CHUNK_SIZE), timeout + ) + sock = writer.transport.get_extra_info("socket") + if sock is not None: + address = sock.getpeername() + + conn = cls(reader, writer, encoding=encoding, address=address, parser=parser) + + try: + if password is not None: + await conn.auth(password) + if db is not None: + await conn.select(db) + except (asyncio.CancelledError, Exception): + conn.close() + await conn.wait_closed() + raise + return conn diff --git a/src/aioredis_cluster/aioredis/pool.py b/src/aioredis_cluster/aioredis/pool.py new file mode 100644 index 0000000..9f7395f --- /dev/null +++ b/src/aioredis_cluster/aioredis/pool.py @@ -0,0 +1,81 @@ +import asyncio +import ssl +from typing import List, Tuple, Type, Union + +from .abc import AbcConnection, AbcPool +from .util import parse_url + + +try: + from aioredis.pool import ConnectionsPool +except ImportError: + from .._aioredis.pool import ConnectionsPool + + +async def create_pool( + address: Union[str, Tuple[str, int], List], + *, + db: int = None, + password: str = None, + ssl: Union[bool, ssl.SSLContext] = None, + encoding: str = None, + minsize: int = 1, + maxsize: int = 10, + parser=None, + create_connection_timeout: float = None, + pool_cls: Type[AbcPool] = None, + connection_cls: Type[AbcConnection] = None, + loop=None, +): + # FIXME: rewrite docstring + """Creates Redis Pool. + + By default it creates pool of Redis instances, but it is + also possible to create pool of plain connections by passing + ``lambda conn: conn`` as commands_factory. + + *commands_factory* parameter is deprecated since v0.2.9 + + All arguments are the same as for create_connection. + + Returns RedisPool instance or a pool_cls if it is given. + """ + if pool_cls: + assert issubclass(pool_cls, AbcPool), "pool_class does not meet the AbcPool contract" + cls = pool_cls + else: + cls = ConnectionsPool + if isinstance(address, str): + address, options = parse_url(address) + db = options.setdefault("db", db) + password = options.setdefault("password", password) + encoding = options.setdefault("encoding", encoding) + create_connection_timeout = options.setdefault("timeout", create_connection_timeout) + if "ssl" in options: + assert options["ssl"] or (not options["ssl"] and not ssl), ( + "Conflicting ssl options are set", + options["ssl"], + ssl, + ) + ssl = ssl or options["ssl"] + # TODO: minsize/maxsize + + pool = cls( + address, + db, + password, + encoding, + minsize=minsize, + maxsize=maxsize, + ssl=ssl, + parser=parser, + create_connection_timeout=create_connection_timeout, + connection_cls=connection_cls, + ) + try: + await pool._fill_free(override_min=False) + except (asyncio.CancelledError, Exception): + pool.close() + await pool.wait_closed() + raise + return pool diff --git a/src/aioredis_cluster/aioredis/util.py b/src/aioredis_cluster/aioredis/util.py index a63b177..bd7cf2f 100644 --- a/src/aioredis_cluster/aioredis/util.py +++ b/src/aioredis_cluster/aioredis/util.py @@ -1,7 +1,7 @@ try: - from aioredis.util import _NOTSET, wait_convert, wait_ok + from aioredis.util import _NOTSET, parse_url, wait_convert, wait_ok except ImportError: - from .._aioredis.util import _NOTSET, wait_convert, wait_ok + from .._aioredis.util import _NOTSET, parse_url, wait_convert, wait_ok (_NOTSET,) @@ -10,4 +10,5 @@ __all__ = ( "wait_convert", "wait_ok", + "parse_url", ) diff --git a/src/aioredis_cluster/cluster.py b/src/aioredis_cluster/cluster.py index eddde18..c7dc4fa 100644 --- a/src/aioredis_cluster/cluster.py +++ b/src/aioredis_cluster/cluster.py @@ -636,16 +636,6 @@ async def _try_execute( pool = await self._pooler.ensure_pool(node_addr) - # pool_size = pool.size - # if pool_size >= pool.maxsize and pool.freesize == 0: - # logger.warning( - # "ConnectionPool to %s size limit reached (minsize:%s, maxsize:%s, current:%s])", - # node_addr, - # pool.minsize, - # pool.maxsize, - # pool_size, - # ) - if props.asking: logger.info("Send ASKING to %s for command %r", node_addr, ctx.cmd_name) @@ -760,7 +750,7 @@ async def _create_pool( if opts is None: opts = {} - default_opts = dict( + default_opts: Dict[str, Any] = dict( pool_cls=self._pool_cls, password=self._password, encoding=self._encoding, diff --git a/src/aioredis_cluster/pool.py b/src/aioredis_cluster/pool.py index d8ee564..eb543f6 100644 --- a/src/aioredis_cluster/pool.py +++ b/src/aioredis_cluster/pool.py @@ -266,9 +266,6 @@ def get_connection(self, command: TBytesOrStr, args=()): continue if conn.in_pubsub: continue - if is_pubsub: - self._pubsub_conn = conn - self._pool.remove(conn) return conn, conn.address return None, self._address # figure out @@ -437,7 +434,7 @@ async def _create_new_connection(self, address) -> AbcConnection: try: conn: AbcConnection = await create_connection( address, - db=self._db, + db=None, password=self._password, ssl=self._ssl, encoding=self._encoding,