Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition when cleaning up. #143

Merged
merged 21 commits into from
Jul 10, 2023
9 changes: 8 additions & 1 deletion fed/_private/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from fed.cleanup import CleanupManager


class GlobalContext:
def __init__(self) -> None:
self._seq_count = 0
self._cleanup_manager = CleanupManager()

def next_seq_id(self):
def next_seq_id(self) -> int:
self._seq_count += 1
return self._seq_count

def get_cleanup_manager(self) -> CleanupManager:
return self._cleanup_manager


_global_context = None

Expand All @@ -34,4 +40,5 @@ def get_global_context():

def clear_global_context():
global _global_context
_global_context.get_cleanup_manager().graceful_stop()
_global_context = None
6 changes: 3 additions & 3 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
start_recv_proxy,
start_send_proxy,
)
from fed.cleanup import set_exit_on_failure_sending, wait_sending
from fed.fed_object import FedObject
from fed.utils import is_ray_object_refs, setup_logger

Expand Down Expand Up @@ -215,7 +214,9 @@ def init(
)

logger.info(f'Started rayfed with {cluster_config}')
set_exit_on_failure_sending(exit_on_failure_cross_silo_sending)
get_global_context().get_cleanup_manager().start(
exit_when_failure_sending=exit_on_failure_cross_silo_sending)

recv_actor_config = fed_config.ProxyActorConfig(
resource_label=cross_silo_recv_resource_label)
# Start recv proxy
Expand Down Expand Up @@ -249,7 +250,6 @@ def shutdown():
"""
Shutdown a RayFed client.
"""
wait_sending()
compatible_utils._clear_internal_kv()
clear_global_context()
logger.info('Shutdowned rayfed.')
Expand Down
169 changes: 78 additions & 91 deletions fed/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,97 +23,84 @@

logger = logging.getLogger(__name__)

_sending_obj_refs_q = None

_check_send_thread = None

_EXIT_ON_FAILURE_SENDING = False


def set_exit_on_failure_sending(exit_when_failure_sending: bool):
global _EXIT_ON_FAILURE_SENDING
_EXIT_ON_FAILURE_SENDING = exit_when_failure_sending


def get_exit_when_failure_sending():
global _EXIT_ON_FAILURE_SENDING
return _EXIT_ON_FAILURE_SENDING


def _check_sending_objs():
def _signal_exit():
os.kill(os.getpid(), signal.SIGTERM)

global _sending_obj_refs_q
if not _sending_obj_refs_q:
_sending_obj_refs_q = deque()

while True:
try:
obj_ref = _sending_obj_refs_q.popleft()
except IndexError:
time.sleep(0.5)
continue
if isinstance(obj_ref, bool):
break
try:
res = ray.get(obj_ref)
except Exception as e:
logger.warn(f'Failed to send {obj_ref} with error: {e}')
res = False
if not res and get_exit_when_failure_sending():
logger.warn('Signal self to exit.')
_signal_exit()
break

logger.info('Check sending thread was exited.')
global _check_send_thread
_check_send_thread = None
logger.info('Clearing sending queue.')
_sending_obj_refs_q = None


def _main_thread_monitor():
main_thread = threading.main_thread()
main_thread.join()
notify_to_exit()


_monitor_thread = None


def _start_check_sending():
global _sending_obj_refs_q
if not _sending_obj_refs_q:
_sending_obj_refs_q = deque()

global _check_send_thread
if not _check_send_thread:
_check_send_thread = threading.Thread(target=_check_sending_objs)
_check_send_thread.start()
class CleanupManager:
"""
This class is used to manage the related works when the fed driver exiting.
It monitors whether the main thread is broken and it needs wait until all sending
objects get repsonsed.

The main logic path is:
A. If `fed.shutdown()` is invoked in the main thread and every thing works well,
the `graceful_stop()` will be invoked as well and the checking thread will be
notifiled to exit gracefully.

B. If the main thread are broken before sending the notification flag to the
sending thread, the monitor thread will detech that and it joins until the main
jovany-wang marked this conversation as resolved.
Show resolved Hide resolved
thread exited, then notifys the checking thread.
"""

def __init__(self) -> None:
# `deque()` is thread safe on `popleft` and `append` operations.
# See https://docs.python.org/3/library/collections.html#deque-objects
self._sending_obj_refs_q = deque()
self._check_send_thread = None
self._monitor_thread = None

def start(self, exit_when_failure_sending=False):
self._exit_when_failure_sending = exit_when_failure_sending

def __check_func():
self._check_sending_objs()

self._check_send_thread = threading.Thread(target=__check_func)
self._check_send_thread.start()
logger.info('Start check sending thread.')

global _monitor_thread
if not _monitor_thread:
_monitor_thread = threading.Thread(target=_main_thread_monitor)
_monitor_thread.start()
logger.info('Start check sending monitor thread.')


def push_to_sending(obj_ref: ray.ObjectRef):
_start_check_sending()
global _sending_obj_refs_q
_sending_obj_refs_q.append(obj_ref)


def notify_to_exit():
# Sending the termination signal
push_to_sending(True)
logger.info('Notify check sending thread to exit.')


def wait_sending():
global _check_send_thread
if _check_send_thread:
notify_to_exit()
_check_send_thread.join()
def _main_thread_monitor():
main_thread = threading.main_thread()
main_thread.join()
self._notify_to_exit()

self._monitor_thread = threading.Thread(target=_main_thread_monitor)
self._monitor_thread.start()
logger.info('Start check sending monitor thread.')

def graceful_stop(self):
assert self._check_send_thread is not None
self._notify_to_exit()
self._check_send_thread.join()

def _notify_to_exit(self):
# Sending the termination signal
self.push_to_sending(True)
logger.info('Notify check sending thread to exit.')

def push_to_sending(self, obj_ref: ray.ObjectRef):
self._sending_obj_refs_q.append(obj_ref)

def _check_sending_objs(self):
def _signal_exit():
os.kill(os.getpid(), signal.SIGTERM)

assert self._sending_obj_refs_q is not None

while True:
try:
obj_ref = self._sending_obj_refs_q.popleft()
except IndexError:
time.sleep(0.1)
continue
if isinstance(obj_ref, bool):
break
try:
res = ray.get(obj_ref)
except Exception as e:
logger.warn(f'Failed to send {obj_ref} with error: {e}')
res = False
if not res and self._exit_when_failure_sending:
logger.warn('Signal self to exit.')
_signal_exit()
break

logger.info('Check sending thread was exited.')
4 changes: 2 additions & 2 deletions fed/proxy/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
import fed.utils as fed_utils
from fed._private import constants
from fed._private.grpc_options import get_grpc_options, set_max_message_length
from fed.cleanup import push_to_sending
from fed.config import get_cluster_config
from fed.grpc import fed_pb2, fed_pb2_grpc
from fed.utils import setup_logger
from fed._private.global_context import get_global_context

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -467,7 +467,7 @@ def send(
upstream_seq_id=upstream_seq_id,
downstream_seq_id=downstream_seq_id,
)
push_to_sending(res)
get_global_context().get_cleanup_manager().push_to_sending(res)
return res


Expand Down
43 changes: 22 additions & 21 deletions tests/test_internal_kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,29 @@
import fed._private.compatible_utils as compatible_utils


def test_kv_init():
def run(party):
compatible_utils.init_ray("local")
cluster = {
'alice': {'address': '127.0.0.1:11010', 'listen_addr': '0.0.0.0:11010'},
'bob': {'address': '127.0.0.1:11011', 'listen_addr': '0.0.0.0:11011'},
}
assert compatible_utils.kv is None
fed.init(cluster=cluster, party=party)
assert compatible_utils.kv
assert not compatible_utils.kv.put(b"test_key", b"test_val")
assert compatible_utils.kv.get(b"test_key") == b"test_val"

time.sleep(5)
fed.shutdown()
ray.shutdown()

assert compatible_utils.kv is None
with pytest.raises(ValueError):
# Make sure the kv actor is non-exist no matter whether it's in client mode
ray.get_actor("_INTERNAL_KV_ACTOR")
def run(party):
compatible_utils.init_ray("local")
cluster = {
'alice': {'address': '127.0.0.1:11010', 'listen_addr': '0.0.0.0:11010'},
'bob': {'address': '127.0.0.1:11011', 'listen_addr': '0.0.0.0:11011'},
}
assert compatible_utils.kv is None
fed.init(cluster=cluster, party=party)
assert compatible_utils.kv
assert not compatible_utils.kv.put(b"test_key", b"test_val")
assert compatible_utils.kv.get(b"test_key") == b"test_val"

time.sleep(5)
fed.shutdown()

assert compatible_utils.kv is None
with pytest.raises(ValueError):
# Make sure the kv actor is non-exist no matter whether it's in client mode
ray.get_actor("_INTERNAL_KV_ACTOR")
ray.shutdown()


def test_kv_init():
p_alice = multiprocessing.Process(target=run, args=('alice',))
p_bob = multiprocessing.Process(target=run, args=('bob',))
p_alice.start()
Expand Down
43 changes: 22 additions & 21 deletions tests/test_listen_addr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,29 @@ def get_value(self):
return self._value


def test_listen_addr():
def run(party, is_inner_party):
compatible_utils.init_ray(address='local')
cluster = {
'alice': {'address': '127.0.0.1:11012', 'listen_addr': '0.0.0.0:11012'},
'bob': {'address': '127.0.0.1:11011', 'listen_addr': '0.0.0.0:11011'},
}
fed.init(cluster=cluster, party=party)

o = f.party("alice").remote()
actor_location = "alice" if is_inner_party else "bob"
my = My.party(actor_location).remote(o)
val = my.get_value.remote()
result = fed.get(val)
assert result == 100
assert fed.get(o) == 100
import time

time.sleep(5)
fed.shutdown()
ray.shutdown()
def run(party, is_inner_party):
compatible_utils.init_ray(address='local')
cluster = {
'alice': {'address': '127.0.0.1:11012', 'listen_addr': '0.0.0.0:11012'},
'bob': {'address': '127.0.0.1:11011', 'listen_addr': '0.0.0.0:11011'},
}
fed.init(cluster=cluster, party=party)

o = f.party("alice").remote()
actor_location = "alice" if is_inner_party else "bob"
my = My.party(actor_location).remote(o)
val = my.get_value.remote()
result = fed.get(val)
assert result == 100
assert fed.get(o) == 100
import time

time.sleep(5)
fed.shutdown()
ray.shutdown()


def test_listen_addr():
p_alice = multiprocessing.Process(target=run, args=('alice', True))
p_bob = multiprocessing.Process(target=run, args=('bob', True))
p_alice.start()
Expand Down
12 changes: 0 additions & 12 deletions tests/test_repeat_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,10 @@
import multiprocessing

import pytest
import time
import fed
import fed._private.compatible_utils as compatible_utils
import ray

from fed.cleanup import _start_check_sending, push_to_sending


@fed.remote
class My:
Expand All @@ -48,16 +45,8 @@ def bar(self, li):

def run(party):
def _run():
assert fed.cleanup._sending_obj_refs_q is None
compatible_utils.init_ray(address='local')
fed.init(cluster=cluster, party=party)
_start_check_sending()
time.sleep(0.5)
assert fed.cleanup._sending_obj_refs_q is not None
push_to_sending(True)
# Slightly longer than the queue polling
time.sleep(0.6)
assert fed.cleanup._sending_obj_refs_q is None

my1 = My.party("alice").remote()
my2 = My.party("bob").remote()
Expand All @@ -71,7 +60,6 @@ def _run():

fed.shutdown()
ray.shutdown()
assert fed.cleanup._sending_obj_refs_q is None

_run()
_run()
Expand Down
Loading
Loading