Skip to content

Commit

Permalink
Merge pull request tomerfiliba-org#531 from notEvil/i530
Browse files Browse the repository at this point in the history
Fix race condition in `Connection.serve` and `AsyncResult.wait`
  • Loading branch information
comrumino authored Mar 18, 2023
2 parents ba07bae + 25430e5 commit 99c5abe
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 6 deletions.
7 changes: 5 additions & 2 deletions rpyc/core/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,19 @@ def wait(self):
"""Waits for the result to arrive. If the AsyncResult object has an
expiry set, and the result did not arrive within that timeout,
an :class:`AsyncResultTimeout` exception is raised"""
while not (self._is_ready or self.expired):
while self._waiting():
# Serve the connection since we are not ready. Suppose
# the reply for our seq is served. The callback is this class
# so __call__ sets our obj and _is_ready to true.
self._conn.serve(self._ttl)
self._conn.serve(self._ttl, waiting=self._waiting)

# Check if we timed out before result was ready
if not self._is_ready:
raise AsyncResultTimeout("result expired")

def _waiting(self):
return not (self._is_ready or self.expired)

def add_callback(self, func):
"""Adds a callback to be invoked when the result arrives. The callback
function takes a single argument, which is the current AsyncResult
Expand Down
24 changes: 20 additions & 4 deletions rpyc/core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def _get_seq_id(self): # IO
return next(self._seqcounter)

def _send(self, msg, seq, args): # IO
data = brine.dump((msg, seq, args))
data = brine.I1.pack(msg) + brine.dump((seq, args)) # see _dispatch
if self._bind_threads:
this_thread = self._get_thread()
data = brine.I8I8.pack(this_thread.id, this_thread._remote_thread_id) + data
Expand Down Expand Up @@ -392,10 +392,13 @@ def _seq_request_callback(self, msg, seq, is_exc, obj):
self._config["logger"].debug(debug_msg.format(msg, seq))

def _dispatch(self, data): # serving---dispatch?
msg, seq, args = brine.load(data)
msg, = brine.I1.unpack(data[:1]) # unpack just msg to minimize time to release
if msg == consts.MSG_REQUEST:
if self._bind_threads:
self._get_thread()._occupation_count += 1
else:
self._recvlock.release()
seq, args = brine.load(data[1:])
self._dispatch_request(seq, args)
else:
if self._bind_threads:
Expand All @@ -404,15 +407,21 @@ def _dispatch(self, data): # serving---dispatch?
if this_thread._occupation_count == 0:
this_thread._remote_thread_id = UNBOUND_THREAD_ID
if msg == consts.MSG_REPLY:
seq, args = brine.load(data[1:])
obj = self._unbox(args)
self._seq_request_callback(msg, seq, False, obj)
if not self._bind_threads:
self._recvlock.release() # releasing here fixes race condition with AsyncResult.wait
elif msg == consts.MSG_EXCEPTION:
if not self._bind_threads:
self._recvlock.release()
seq, args = brine.load(data[1:])
obj = self._unbox_exc(args)
self._seq_request_callback(msg, seq, True, obj)
else:
raise ValueError(f"invalid message type: {msg!r}")

def serve(self, timeout=1, wait_for_lock=True): # serving
def serve(self, timeout=1, wait_for_lock=True, waiting=lambda: True): # serving
"""Serves a single request or reply that arrives within the given
time frame (default is 1 sec). Note that the dispatching of a request
might trigger multiple (nested) requests, thus this function may be
Expand All @@ -427,10 +436,17 @@ def serve(self, timeout=1, wait_for_lock=True): # serving
# Exit early if we cannot acquire the recvlock
if not self._recvlock.acquire(False):
if wait_for_lock:
if not waiting(): # unlikely, but the result could've arrived and another thread could've won the race to acquire
return False
# Wait condition for recvlock release; recvlock is not underlying lock for condition
return self._recv_event.wait(timeout.timeleft())
else:
return False
if not waiting(): # the result arrived and we won the race to acquire, unlucky
self._recvlock.release()
with self._recv_event:
self._recv_event.notify_all()
return False
# Assume the receive rlock is acquired and incremented
# We must release once BEFORE dispatch, dispatch any data, and THEN notify all (see issue #527 and #449)
try:
Expand All @@ -442,11 +458,11 @@ def serve(self, timeout=1, wait_for_lock=True): # serving
self.close() # sends close async request
raise
else:
self._recvlock.release()
if data:
self._dispatch(data) # Dispatch will unbox, invoke callbacks, etc.
return True
else:
self._recvlock.release()
return False
finally:
with self._recv_event:
Expand Down
70 changes: 70 additions & 0 deletions tests/test_race.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import rpyc
import rpyc.core.async_ as rc_async_
import rpyc.core.protocol as rc_protocol
import contextlib
import signal
import threading
import time
import unittest


class TestRace(unittest.TestCase):
def setUp(self):
self.connection = rpyc.classic.connect_thread()

self.a_str = rpyc.async_(self.connection.builtin.str)

def tearDown(self):
self.connection.close()

def test_asyncresult_race(self):
with _patch():
def hook():
time.sleep(0.2) # loose race

_AsyncResult._HOOK = hook

threading.Thread(target=self.connection.serve_all).start()
time.sleep(0.1) # wait for thread to serve

# schedule KeyboardInterrupt
thread_id = threading.get_ident()
_ = lambda: signal.pthread_kill(thread_id, signal.SIGINT)
timer = threading.Timer(1, _)
timer.start()

a_result = self.a_str("") # request
time.sleep(0.1) # wait for race to start
try:
a_result.wait()
except KeyboardInterrupt:
raise Exception("deadlock")

timer.cancel()


class _AsyncResult(rc_async_.AsyncResult):
_HOOK = None

def __call__(self, *args, **kwargs):
hook = type(self)._HOOK
if hook is not None:
hook()
return super().__call__(*args, **kwargs)


@contextlib.contextmanager
def _patch():
AsyncResult = rc_async_.AsyncResult
try:
rc_async_.AsyncResult = _AsyncResult
rc_protocol.AsyncResult = _AsyncResult # from import
yield

finally:
rc_async_.AsyncResult = AsyncResult
rc_protocol.AsyncResult = AsyncResult


if __name__ == "__main__":
unittest.main()

0 comments on commit 99c5abe

Please sign in to comment.