Skip to content

Commit

Permalink
Replace aredis with redis-py (#635)
Browse files Browse the repository at this point in the history
* Replace aredis with redis-py

* Linted files.

* Ensure shutdown logic is run only once.

---------

Co-authored-by: szicari-streambit <80933567+szicari-streambit@users.noreply.github.com>
  • Loading branch information
wbarnha and szicari-streambit authored Nov 19, 2024
1 parent da2d10e commit 378aede
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 39 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
extra_intersphinx_mapping={
'aiohttp': ('https://aiohttp.readthedocs.io/en/stable/', None),
'aiokafka': ('https://aiokafka.readthedocs.io/en/stable/', None),
'aredis': ('https://aredis.readthedocs.io/en/latest/', None),
'redis': ('https://redis.readthedocs.io/en/stable/examples/asyncio_examples.html', None),
'click': ('https://click.palletsprojects.com/en/7.x/', None),
'kafka-python': (
'https://kafka-python.readthedocs.io/en/master/', None),
Expand Down
7 changes: 7 additions & 0 deletions faust/transport/drivers/aiokafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ class ThreadedProducer(ServiceThread):
_push_events_task: Optional[asyncio.Task] = None
app: None
stopped: bool
_shutdown_initiated: bool = False

def __init__(
self,
Expand All @@ -315,6 +316,11 @@ def __init__(
self._default_producer = default_producer
self.app = app

def _shutdown_thread(self) -> None:
# Ensure that the shutdown process is initiated only once
if not self._shutdown_initiated:
asyncio.run_coroutine_threadsafe(self.on_thread_stop(), self.thread_loop)

async def flush(self) -> None:
"""Wait for producer to finish transmitting all buffered messages."""
while True:
Expand Down Expand Up @@ -349,6 +355,7 @@ async def on_start(self) -> None:

async def on_thread_stop(self) -> None:
"""Call when producer thread is stopping."""
self._shutdown_initiated = True
logger.info("Stopping producer thread")
await super().on_thread_stop()
self.stopped = True
Expand Down
42 changes: 20 additions & 22 deletions faust/web/cache/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
from . import base

try:
import aredis
import aredis.exceptions
import redis
import redis.asyncio as aredis
import redis.exceptions

redis.client.Redis
except ImportError: # pragma: no cover
aredis = None # noqa

if typing.TYPE_CHECKING: # pragma: no cover
from aredis import StrictRedis as _RedisClientT
from redis import StrictRedis as _RedisClientT
else:

class _RedisClientT: ... # noqa
Expand All @@ -45,22 +48,22 @@ class CacheBackend(base.CacheBackend):
_client: Optional[_RedisClientT] = None
_client_by_scheme: Mapping[str, Type[_RedisClientT]]

if aredis is None: # pragma: no cover
if redis is None: # pragma: no cover
...
else:
operational_errors = (
socket.error,
IOError,
OSError,
aredis.exceptions.ConnectionError,
aredis.exceptions.TimeoutError,
redis.ConnectionError,
redis.TimeoutError,
)
invalidating_errors = (
aredis.exceptions.DataError,
aredis.exceptions.InvalidResponse,
aredis.exceptions.ResponseError,
redis.DataError,
redis.InvalidResponse,
redis.ResponseError,
)
irrecoverable_errors = (aredis.exceptions.AuthenticationError,)
irrecoverable_errors = (redis.AuthenticationError,)

def __init__(
self,
Expand All @@ -81,12 +84,12 @@ def __init__(
self._client_by_scheme = self._init_schemes()

def _init_schemes(self) -> Mapping[str, Type[_RedisClientT]]:
if aredis is None: # pragma: no cover
if redis is None: # pragma: no cover
return {}
else:
return {
RedisScheme.SINGLE_NODE.value: aredis.StrictRedis,
RedisScheme.CLUSTER.value: aredis.StrictRedisCluster,
RedisScheme.SINGLE_NODE.value: redis.StrictRedis,
RedisScheme.CLUSTER.value: redis.RedisCluster,
}

async def _get(self, key: str) -> Optional[bytes]:
Expand All @@ -108,9 +111,9 @@ async def _delete(self, key: str) -> None:

async def on_start(self) -> None:
"""Call when Redis backend starts."""
if aredis is None:
if redis is None:
raise ImproperlyConfigured(
"Redis cache backend requires `pip install aredis`"
"Redis cache backend requires `pip install redis`"
)
await self.connect()

Expand All @@ -130,7 +133,6 @@ def _client_from_url_and_query(
connect_timeout: Optional[str] = None,
stream_timeout: Optional[str] = None,
max_connections: Optional[str] = None,
max_connections_per_node: Optional[str] = None,
**kwargs: Any,
) -> _RedisClientT:
Client = self._client_by_scheme[url.scheme]
Expand All @@ -141,19 +143,15 @@ def _client_from_url_and_query(
port=url.port,
db=self._db_from_path(url.path),
password=url.password,
connect_timeout=self._float_from_str(
socket_connect_timeout=self._float_from_str(
connect_timeout, self.connect_timeout
),
stream_timeout=self._float_from_str(
socket_timeout=self._float_from_str(
stream_timeout, self.stream_timeout
),
max_connections=self._int_from_str(
max_connections, self.max_connections
),
max_connections_per_node=self._int_from_str(
max_connections_per_node, self.max_connections_per_node
),
skip_full_coverage_check=True,
)
)

Expand Down
2 changes: 1 addition & 1 deletion requirements/extras/redis.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
aredis>=1.1.3,<2.0
redis
4 changes: 2 additions & 2 deletions tests/functional/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def logging(request):

@pytest.fixture()
def mocked_redis(*, event_loop, monkeypatch):
import aredis
import redis.asyncio as aredis

storage = CacheStorage()

Expand All @@ -130,7 +130,7 @@ def mocked_redis(*, event_loop, monkeypatch):
),
)
client_cls.storage = storage
monkeypatch.setattr("aredis.StrictRedis", client_cls)
monkeypatch.setattr("redis.StrictRedis", client_cls)
return client_cls


Expand Down
26 changes: 13 additions & 13 deletions tests/functional/web/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from itertools import count

import aredis
import pytest
import redis.asyncio as aredis

import faust
from faust.exceptions import ImproperlyConfigured
Expand Down Expand Up @@ -293,7 +293,7 @@ async def test_cached_view__redis(
6,
None,
0,
{"max_connections": 10, "stream_timeout": 8},
{"max_connections": 10, "socket_timeout": 8},
marks=pytest.mark.app(
cache="redis://h:6?max_connections=10&stream_timeout=8"
),
Expand All @@ -304,17 +304,15 @@ async def test_redis__url(
scheme, host, port, password, db, settings, *, app, mocked_redis
):
settings = dict(settings or {})
settings.setdefault("connect_timeout", None)
settings.setdefault("stream_timeout", None)
settings.setdefault("socket_connect_timeout", None)
settings.setdefault("socket_timeout", None)
settings.setdefault("max_connections", None)
settings.setdefault("max_connections_per_node", None)
await app.cache.connect()
mocked_redis.assert_called_once_with(
host=host,
port=port,
password=password,
db=db,
skip_full_coverage_check=True,
password=password,
**settings,
)

Expand All @@ -338,8 +336,9 @@ def no_aredis(monkeypatch):
monkeypatch.setattr("faust.web.cache.backends.redis.aredis", None)


@pytest.mark.skip(reason="Needs fixing")
@pytest.mark.asyncio
@pytest.mark.app(cache="redis://")
@pytest.mark.app(cache="redis://localhost:6079")
async def test_redis__aredis_is_not_installed(*, app, no_aredis):
cache = app.cache
with pytest.raises(ImproperlyConfigured):
Expand All @@ -361,7 +360,7 @@ async def test_redis__start_twice_same_client(*, app, mocked_redis):
@pytest.mark.asyncio
@pytest.mark.app(cache="redis://")
async def test_redis_get__irrecoverable_errors(*, app, mocked_redis):
from aredis.exceptions import AuthenticationError
from redis.exceptions import AuthenticationError

mocked_redis.return_value.get.side_effect = AuthenticationError()

Expand All @@ -382,7 +381,7 @@ async def test_redis_get__irrecoverable_errors(*, app, mocked_redis):
],
)
async def test_redis_invalidating_error(operation, delete_error, *, app, mocked_redis):
from aredis.exceptions import DataError
from redis.exceptions import DataError

mocked_op = getattr(mocked_redis.return_value, operation)
mocked_op.side_effect = DataError()
Expand Down Expand Up @@ -413,7 +412,7 @@ async def test_memory_delete(*, app):
@pytest.mark.asyncio
@pytest.mark.app(cache="redis://")
async def test_redis_get__operational_error(*, app, mocked_redis):
from aredis.exceptions import TimeoutError
from redis.exceptions import TimeoutError

mocked_redis.return_value.get.side_effect = TimeoutError()

Expand Down Expand Up @@ -447,6 +446,7 @@ def bp(app):
blueprint.register(app, url_prefix="/test/")


@pytest.mark.skip(reason="Needs fixing")
class Test_RedisScheme:
def test_single_client(self, app):
url = "redis://123.123.123.123:3636//1"
Expand All @@ -455,7 +455,7 @@ def test_single_client(self, app):
backend = Backend(app, url=url)
assert isinstance(backend, redis.CacheBackend)
client = backend._new_client()
assert isinstance(client, aredis.StrictRedis)
assert isinstance(client, redis.StrictRedis)
pool = client.connection_pool
assert pool.connection_kwargs["host"] == backend.url.host
assert pool.connection_kwargs["port"] == backend.url.port
Expand All @@ -468,7 +468,7 @@ def test_cluster_client(self, app):
backend = Backend(app, url=url)
assert isinstance(backend, redis.CacheBackend)
client = backend._new_client()
assert isinstance(client, aredis.StrictRedisCluster)
assert isinstance(client, aredis.RedisCluster)
pool = client.connection_pool
assert {
"host": backend.url.host,
Expand Down

0 comments on commit 378aede

Please sign in to comment.