Skip to content

Commit

Permalink
Fix connect/pool create cancellation and gc after (#16)
Browse files Browse the repository at this point in the history
* Fix connect/pool create cancellation and gc after
  • Loading branch information
DriverX authored Jul 18, 2022
1 parent 15495f4 commit dee63f1
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 21 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ jobs:
src:
- '.github/workflows/**'
- 'src/aioredis_cluster/_aioredis/**'
- 'src/aioredis_cluster/aioredis/**'
- 'tests/aioredis_tests/**'
- setup.cfg
- id: force_run
Expand Down
11 changes: 7 additions & 4 deletions src/aioredis_cluster/aioredis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
)

try:
from aioredis.connection import RedisConnection, create_connection
from aioredis.connection import RedisConnection
except ImportError:
from .._aioredis.connection import RedisConnection, create_connection
from .._aioredis.connection import RedisConnection

try:
from aioredis.errors import (
Expand Down Expand Up @@ -61,9 +61,9 @@
WatchVariableError,
)
try:
from aioredis.pool import ConnectionsPool, create_pool
from aioredis.pool import ConnectionsPool
except ImportError:
from .._aioredis.pool import ConnectionsPool, create_pool
from .._aioredis.pool import ConnectionsPool
try:
from aioredis.pubsub import Channel
except ImportError:
Expand All @@ -73,6 +73,9 @@
except ImportError:
from .._aioredis.sentinel import RedisSentinel, create_sentinel

from .connection import create_connection
from .pool import create_pool


__version__ = "1.3.1"

Expand Down
120 changes: 120 additions & 0 deletions src/aioredis_cluster/aioredis/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import asyncio
import logging
import socket
import ssl
from typing import List, Tuple, Type, Union

from .abc import AbcConnection
from .util import parse_url


try:
from aioredis.stream import open_connection, open_unix_connection
except ImportError:
from .._aioredis.stream import open_connection, open_unix_connection
try:
from aioredis.connection import MAX_CHUNK_SIZE, RedisConnection
except ImportError:
from .._aioredis.connection import MAX_CHUNK_SIZE, RedisConnection


logger = logging.getLogger(__name__)


async def create_connection(
address: Union[str, Tuple[str, int], List],
*,
db: int = None,
password: str = None,
ssl: Union[bool, ssl.SSLContext] = None,
encoding: str = None,
parser=None,
timeout: float = None,
connection_cls: Type[AbcConnection] = None,
loop=None,
) -> AbcConnection:
"""Creates redis connection.
Opens connection to Redis server specified by address argument.
Address argument can be one of the following:
* A tuple representing (host, port) pair for TCP connections;
* A string representing either Redis URI or unix domain socket path.
SSL argument is passed through to asyncio.create_connection.
By default SSL/TLS is not used.
By default any timeout is applied at the connection stage, however
you can set a limitted time used trying to open a connection via
the `timeout` Kw.
Encoding argument can be used to decode byte-replies to strings.
By default no decoding is done.
Parser parameter can be used to pass custom Redis protocol parser class.
By default hiredis.Reader is used (unless it is missing or platform
is not CPython).
Return value is RedisConnection instance or a connection_cls if it is
given.
This function is a coroutine.
"""
assert isinstance(address, (tuple, list, str)), "tuple or str expected"
if isinstance(address, str):
address, options = parse_url(address)
logger.debug("Parsed Redis URI %r", address)
db = options.setdefault("db", db)
password = options.setdefault("password", password)
encoding = options.setdefault("encoding", encoding)
timeout = options.setdefault("timeout", timeout)
if "ssl" in options:
assert options["ssl"] or (not options["ssl"] and not ssl), (
"Conflicting ssl options are set",
options["ssl"],
ssl,
)
ssl = ssl or options["ssl"]

if timeout is not None and timeout <= 0:
raise ValueError("Timeout has to be None or a number greater than 0")

if connection_cls:
assert issubclass(
connection_cls, AbcConnection
), "connection_class does not meet the AbcConnection contract"
cls = connection_cls
else:
cls = RedisConnection

if isinstance(address, (list, tuple)):
host, port = address
logger.debug("Creating tcp connection to %r", address)
reader, writer = await asyncio.wait_for(
open_connection(host, port, limit=MAX_CHUNK_SIZE, ssl=ssl), timeout
)
sock = writer.transport.get_extra_info("socket")
if sock is not None:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
address = sock.getpeername()
address = tuple(address[:2]) # type: ignore
else:
logger.debug("Creating unix connection to %r", address)
reader, writer = await asyncio.wait_for(
open_unix_connection(address, ssl=ssl, limit=MAX_CHUNK_SIZE), timeout
)
sock = writer.transport.get_extra_info("socket")
if sock is not None:
address = sock.getpeername()

conn = cls(reader, writer, encoding=encoding, address=address, parser=parser)

try:
if password is not None:
await conn.auth(password)
if db is not None:
await conn.select(db)
except (asyncio.CancelledError, Exception):
conn.close()
await conn.wait_closed()
raise
return conn
81 changes: 81 additions & 0 deletions src/aioredis_cluster/aioredis/pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import asyncio
import ssl
from typing import List, Tuple, Type, Union

from .abc import AbcConnection, AbcPool
from .util import parse_url


try:
from aioredis.pool import ConnectionsPool
except ImportError:
from .._aioredis.pool import ConnectionsPool


async def create_pool(
address: Union[str, Tuple[str, int], List],
*,
db: int = None,
password: str = None,
ssl: Union[bool, ssl.SSLContext] = None,
encoding: str = None,
minsize: int = 1,
maxsize: int = 10,
parser=None,
create_connection_timeout: float = None,
pool_cls: Type[AbcPool] = None,
connection_cls: Type[AbcConnection] = None,
loop=None,
):
# FIXME: rewrite docstring
"""Creates Redis Pool.
By default it creates pool of Redis instances, but it is
also possible to create pool of plain connections by passing
``lambda conn: conn`` as commands_factory.
*commands_factory* parameter is deprecated since v0.2.9
All arguments are the same as for create_connection.
Returns RedisPool instance or a pool_cls if it is given.
"""
if pool_cls:
assert issubclass(pool_cls, AbcPool), "pool_class does not meet the AbcPool contract"
cls = pool_cls
else:
cls = ConnectionsPool
if isinstance(address, str):
address, options = parse_url(address)
db = options.setdefault("db", db)
password = options.setdefault("password", password)
encoding = options.setdefault("encoding", encoding)
create_connection_timeout = options.setdefault("timeout", create_connection_timeout)
if "ssl" in options:
assert options["ssl"] or (not options["ssl"] and not ssl), (
"Conflicting ssl options are set",
options["ssl"],
ssl,
)
ssl = ssl or options["ssl"]
# TODO: minsize/maxsize

pool = cls(
address,
db,
password,
encoding,
minsize=minsize,
maxsize=maxsize,
ssl=ssl,
parser=parser,
create_connection_timeout=create_connection_timeout,
connection_cls=connection_cls,
)
try:
await pool._fill_free(override_min=False)
except (asyncio.CancelledError, Exception):
pool.close()
await pool.wait_closed()
raise
return pool
5 changes: 3 additions & 2 deletions src/aioredis_cluster/aioredis/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
try:
from aioredis.util import _NOTSET, wait_convert, wait_ok
from aioredis.util import _NOTSET, parse_url, wait_convert, wait_ok
except ImportError:
from .._aioredis.util import _NOTSET, wait_convert, wait_ok
from .._aioredis.util import _NOTSET, parse_url, wait_convert, wait_ok


(_NOTSET,)
Expand All @@ -10,4 +10,5 @@
__all__ = (
"wait_convert",
"wait_ok",
"parse_url",
)
12 changes: 1 addition & 11 deletions src/aioredis_cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,16 +636,6 @@ async def _try_execute(

pool = await self._pooler.ensure_pool(node_addr)

# pool_size = pool.size
# if pool_size >= pool.maxsize and pool.freesize == 0:
# logger.warning(
# "ConnectionPool to %s size limit reached (minsize:%s, maxsize:%s, current:%s])",
# node_addr,
# pool.minsize,
# pool.maxsize,
# pool_size,
# )

if props.asking:
logger.info("Send ASKING to %s for command %r", node_addr, ctx.cmd_name)

Expand Down Expand Up @@ -760,7 +750,7 @@ async def _create_pool(
if opts is None:
opts = {}

default_opts = dict(
default_opts: Dict[str, Any] = dict(
pool_cls=self._pool_cls,
password=self._password,
encoding=self._encoding,
Expand Down
5 changes: 1 addition & 4 deletions src/aioredis_cluster/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,6 @@ def get_connection(self, command: TBytesOrStr, args=()):
continue
if conn.in_pubsub:
continue
if is_pubsub:
self._pubsub_conn = conn
self._pool.remove(conn)
return conn, conn.address
return None, self._address # figure out

Expand Down Expand Up @@ -437,7 +434,7 @@ async def _create_new_connection(self, address) -> AbcConnection:
try:
conn: AbcConnection = await create_connection(
address,
db=self._db,
db=None,
password=self._password,
ssl=self._ssl,
encoding=self._encoding,
Expand Down

0 comments on commit dee63f1

Please sign in to comment.