Skip to content

Commit

Permalink
Fix Redis connections after reconnect - consumer starts consuming the…
Browse files Browse the repository at this point in the history
… tasks after crash. (#2007)

* Add more logs

* Launch _on_connection_disconnect in Conection only if channel was added properly to the poller

* Prepare test which check the flow of the channel removal from poller

* Change the comment
  • Loading branch information
awmackowiak authored Jun 12, 2024
1 parent 1217865 commit dcb43be
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 2 deletions.
10 changes: 8 additions & 2 deletions kombu/transport/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def __init__(self, *args, **kwargs):

if not self.ack_emulation: # disable visibility timeout
self.QoS = virtual.QoS

self._registered = False
self._queue_cycle = cycle_by_name(self.queue_order_strategy)()
self.Client = self._get_client()
self.ResponseError = self._get_response_error()
Expand All @@ -747,6 +747,9 @@ def __init__(self, *args, **kwargs):
raise

self.connection.cycle.add(self) # add to channel poller.
# and set to true after sucessfuly added channel to the poll.
self._registered = True

# copy errors, in case channel closed but threads still
# are still waiting for data.
self.connection_errors = self.connection.connection_errors
Expand Down Expand Up @@ -1201,7 +1204,10 @@ def _connparams(self, asynchronous=False):
class Connection(connection_cls):
def disconnect(self, *args):
super().disconnect(*args)
channel._on_connection_disconnect(self)
# We remove the connection from the poller
# only if it has been added properly.
if channel._registered:
channel._on_connection_disconnect(self)
connection_cls = Connection

connparams['connection_class'] = connection_cls
Expand Down
156 changes: 156 additions & 0 deletions t/unit/transport/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,17 +346,173 @@ class XTransport(Transport):
Channel = XChannel

conn = Connection(transport=XTransport)
conn.transport.cycle = Mock(name='cycle')
client.ping.side_effect = RuntimeError()
with pytest.raises(RuntimeError):
conn.channel()
pool.disconnect.assert_called_with()
pool.disconnect.reset_mock()
# Ensure that the channel without ensured connection to Redis
# won't be added to the cycle.
conn.transport.cycle.add.assert_not_called()
assert len(conn.transport.channels) == 0

pool_at_init = [None]
with pytest.raises(RuntimeError):
conn.channel()
pool.disconnect.assert_not_called()

def test_redis_connection_added_to_cycle_if_ping_succeeds(self):
"""Test should check the connection is added to the cycle only
if the ping to Redis was finished successfully."""
# given: mock pool and client
pool = Mock(name='pool')
client = Mock(name='client')

# override channel class with given mocks
class XChannel(Channel):
def __init__(self, *args, **kwargs):
self._pool = pool
super().__init__(*args, **kwargs)

def _get_client(self):
return lambda *_, **__: client

# override Channel in Transport with given channel
class XTransport(Transport):
Channel = XChannel

# when: create connection with overridden transport
conn = Connection(transport=XTransport)
conn.transport.cycle = Mock(name='cycle')
# create the channel
chan = conn.channel()
# then: check if ping was called
client.ping.assert_called_once()
# the connection was added to the cycle
conn.transport.cycle.add.assert_called_once()
assert len(conn.transport.channels) == 1
# the channel was flaged as registered into poller
assert chan._registered

def test_redis_on_disconnect_channel_only_if_was_registered(self):
"""Test shoud check if the _on_disconnect method is called only
if the channel was registered into the poller."""
# given: mock pool and client
pool = Mock(name='pool')
client = Mock(
name='client',
ping=Mock(return_value=True)
)

# create RedisConnectionMock class
# for the possibility to run disconnect method
class RedisConnectionMock:
def disconnect(self, *args):
pass

# override Channel method with given mocks
class XChannel(Channel):
connection_class = RedisConnectionMock

def __init__(self, *args, **kwargs):
self._pool = pool
# counter to check if the method was called
self.on_disconect_count = 0
super().__init__(*args, **kwargs)

def _get_client(self):
return lambda *_, **__: client

def _on_connection_disconnect(self, connection):
# increment the counter when the method is called
self.on_disconect_count += 1

# create the channel
chan = XChannel(Mock(
_used_channel_ids=[],
channel_max=1,
channels=[],
client=Mock(
transport_options={},
hostname="127.0.0.1",
virtual_host=None)))
# create the _connparams with overriden connection_class
connparams = chan._connparams(asynchronous=True)
# create redis.Connection
redis_connection = connparams['connection_class']()
# the connection was added to the cycle
chan.connection.cycle.add.assert_called_once()
# and the ping was called
client.ping.assert_called_once()
# the channel was registered
assert chan._registered
# than disconnect the Redis connection
redis_connection.disconnect()
# the on_disconnect counter should be incremented
assert chan.on_disconect_count == 1

def test_redis__on_disconnect_should_not_be_called_if_not_registered(self):
"""Test should check if the _on_disconnect method is not called because
the connection to Redis isn't established properly."""
# given: mock pool
pool = Mock(name='pool')
# client mock with ping method which return ConnectionError
from redis.exceptions import ConnectionError
client = Mock(
name='client',
ping=Mock(side_effect=ConnectionError())
)

# create RedisConnectionMock
# for the possibility to run disconnect method
class RedisConnectionMock:
def disconnect(self, *args):
pass

# override Channel method with given mocks
class XChannel(Channel):
connection_class = RedisConnectionMock

def __init__(self, *args, **kwargs):
self._pool = pool
# counter to check if the method was called
self.on_disconect_count = 0
super().__init__(*args, **kwargs)

def _get_client(self):
return lambda *_, **__: client

def _on_connection_disconnect(self, connection):
# increment the counter when the method is called
self.on_disconect_count += 1

# then: exception was risen
with pytest.raises(ConnectionError):
# when: create the channel
chan = XChannel(Mock(
_used_channel_ids=[],
channel_max=1,
channels=[],
client=Mock(
transport_options={},
hostname="127.0.0.1",
virtual_host=None)))
# create the _connparams with overriden connection_class
connparams = chan._connparams(asynchronous=True)
# create redis.Connection
redis_connection = connparams['connection_class']()
# the connection wasn't added to the cycle
chan.connection.cycle.add.assert_not_called()
# the ping was called once with the exception
client.ping.assert_called_once()
# the channel was not registered
assert not chan._registered
# then: disconnect the Redis connection
redis_connection.disconnect()
# the on_disconnect counter shouldn't be incremented
assert chan.on_disconect_count == 0

def test_get_redis_ConnectionError(self):
from redis.exceptions import ConnectionError

Expand Down

0 comments on commit dcb43be

Please sign in to comment.