Commit a437b5b

address comments

hnyu committed Dec 13, 2024
1 parent 7866483 commit a437b5b

Showing 1 changed file with 91 additions and 60 deletions.

alf/algorithms/distributed_off_policy_algorithm.py
@@ -152,6 +152,7 @@ def __init__(self,
         self._core_alg = core_alg
         self._port = port
         self._ddp_rank = max(0, PerProcessContext().ddp_rank)
+        self._num_ranks = PerProcessContext().num_processes
 
     def _opt_free_state_dict(self) -> dict:
         """Return `self._core_alg` state dict without optimizers.
@@ -251,19 +252,22 @@ def receive_experience_data(replay_buffer: ReplayBuffer,
             time.sleep(0.1)
 
 
-def pull_params_from_trainer(memory_name: str, unroller_id: str):
+def pull_params_from_trainer(memory_name: str, unroller_id: str,
+                             params_socket_k: int):
     """Once new params arrive, we put them in the shared memory and mark them
     updated. Later, after the current unroll finishes, the unroller can load
     the new params.
     """
     socket, _ = create_zmq_socket(
         zmq.DEALER, _trainer_addr_config.ip,
-        _trainer_addr_config.port + _params_port_offset,
+        _trainer_addr_config.port + _params_port_offset + params_socket_k,
         unroller_id + "_params")
     params = SharedMemory(name=memory_name)
+    # Signify that this unroller is ready to receive params.
+    socket.send_string(UnrollerMessage.OK)
     while True:
         data = socket.recv()
-        params.buf[:1] = b'1'
+        params.buf[0] = 1
         params.buf[1:] = data
         socket.send_string(UnrollerMessage.OK)

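The shared-memory handoff above relies on a simple layout: byte 0 is an "updated" flag and the remaining bytes hold the torch-serialized state dict. Below is a minimal, self-contained sketch of both sides of that layout; the helper names are illustrative, not from the repo.

    import io
    from multiprocessing.shared_memory import SharedMemory

    import torch


    def write_params(shm: SharedMemory, data: bytes) -> None:
        # Writer side: set the flag, then the payload, mirroring
        # ``pull_params_from_trainer`` above.
        shm.buf[0] = 1
        shm.buf[1:1 + len(data)] = data


    def read_params_if_updated(shm: SharedMemory):
        # Reader side: load params only when the flag marks them as fresh.
        if shm.buf[0] != 1:
            return None
        state_dict = torch.load(
            io.BytesIO(bytes(shm.buf[1:])), map_location='cpu')
        shm.buf[0] = 0  # mark consumed
        return state_dict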
@@ -328,9 +332,13 @@ def __init__(self,
         # by the parent ``RLAlgorithm``
         self.observe_for_replay = self._observe_for_replay
 
-        if self.is_main_ddp_rank:
-            self._params_socket, _ = create_zmq_socket(
-                zmq.ROUTER, '*', self._port + _params_port_offset)
+        print("Trainer params port: ",
+              self._port + _params_port_offset + self._ddp_rank)
+        self._params_socket, _ = create_zmq_socket(
+            zmq.ROUTER, '*', self._port + _params_port_offset + self._ddp_rank)
+        # 3 sec timeout for receiving the unroller's acknowledgement.
+        # In case some unrollers might die, we don't want to block forever.
+        self._params_socket.setsockopt(zmq.RCVTIMEO, 3000)
 
         assert config.unroll_length == -1, (
             'unroll_length must be -1 (no unrolling)')
@@ -346,74 +354,93 @@ def _observe_for_replay(self, exp: Experience):
     def is_main_ddp_rank(self):
         return self._ddp_rank == 0
 
-    def _send_params_to_unroller(self, unroller_id: str) -> bool:
+    def _send_params_to_unroller(self,
+                                 unroller_id: str,
+                                 first_time: bool = False) -> bool:
         """Send model params to a specified unroller.
 
         Args:
             unroller_id: id (bytes str) of the unroller.
+            first_time: whether this is the first time this function is called
+                for this unroller. If so, we need to wait for the unroller's
+                socket-ready signal before sending.
 
         Returns:
             bool: True if the unroller is still alive.
         """
+        unroller_id1 = unroller_id + b'_params'
+        if first_time:
+            # Block until the unroller is ready to receive params.
+            # If we don't do so, the outgoing params might get lost before
+            # the receiving socket is actually created.
+            unroller_id_, message = self._params_socket.recv_multipart()
+            assert unroller_id_ == unroller_id1
+            assert message == UnrollerMessage.OK.encode()
 
         # Get all parameters/buffers in a state dict and send them out
         buffer = io.BytesIO()
         torch.save(self._opt_free_state_dict(), buffer)
-        self._params_socket.send_multipart(
-            [unroller_id + b'_params',
-             buffer.getvalue()])
-        success = False
-        for _ in range(100):  # 1s in total for acknowledgement
-            try:
-                # In case some unrollers might die, we don't want to block forever
-                _, message = self._params_socket.recv_multipart(
-                    flags=zmq.NOBLOCK)
-                assert message == UnrollerMessage.OK.encode()
-                logging.debug(
-                    f"[worker-0] Params sent to unroller {unroller_id.decode()}."
-                )
-                success = True
-                break
-            except zmq.Again:
-                time.sleep(0.01)
-        return success
+        self._params_socket.send_multipart([unroller_id1, buffer.getvalue()])
+        try:
+            _, message = self._params_socket.recv_multipart()
+            assert message == UnrollerMessage.OK.encode()
+            logging.debug(f"[worker-{self._ddp_rank}] Params sent to unroller"
+                          f" {unroller_id.decode()}.")
+            return True
+        except zmq.Again:
+            return False

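The readiness/acknowledgement handshake in `_send_params_to_unroller` and `pull_params_from_trainer` is a standard ROUTER/DEALER pattern: the DEALER announces itself first so the ROUTER learns its identity before pushing data, and every payload is answered with an OK. A self-contained pyzmq sketch under assumed port and identity values (the repo's `create_zmq_socket` helper is bypassed here):

    import zmq

    ctx = zmq.Context()

    # Trainer side: ROUTER with a receive timeout so a dead peer can't block it.
    router = ctx.socket(zmq.ROUTER)
    router.setsockopt(zmq.RCVTIMEO, 3000)
    router.bind("tcp://*:5555")

    # Unroller side: DEALER with a fixed identity.
    dealer = ctx.socket(zmq.DEALER)
    dealer.setsockopt_string(zmq.IDENTITY, "unroller-0_params")
    dealer.connect("tcp://localhost:5555")

    dealer.send_string("OK")                   # 1. receiver signals readiness
    ident, msg = router.recv_multipart()       # 2. router learns the identity
    assert msg == b"OK"
    router.send_multipart([ident, b"params"])  # 3. push payload to that peer
    assert dealer.recv() == b"params"
    dealer.send_string("OK")                   # 4. acknowledge receipt
    _, ack = router.recv_multipart()
    assert ack == b"OK"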
     def _create_unroller_registration_thread(self):
         self._new_unroller_ips_and_ports = mp.Queue()
-        self._connected_unrollers = set()
+        self._unrollers_to_update_params = set()
+        registered_unrollers = set()
 
         def _wait_unroller_registration():
             """Wait for new registrations from unrollers."""
+            total_unrollers = 0
+            # Each rank has its own port number and a registration socket to
+            # handle new unrollers.
             register_socket, _ = create_zmq_socket(zmq.ROUTER, '*',
                                                    self._port + self._ddp_rank)
             while True:
                 unroller_id, message = register_socket.recv_multipart()
-                if unroller_id not in self._connected_unrollers:
+                if unroller_id not in registered_unrollers:
                     # A new unroller has connected to the trainer
-                    self._connected_unrollers.add(unroller_id)
                     # The init message should always be: 'init'
                     assert message.decode() == 'init'
                     _, unroller_ip, unroller_port = unroller_id.decode().split(
                         '-')
-                    logging.info(
-                        f"Rank {self._ddp_rank} registered {unroller_ip} {unroller_port}"
-                    )
                     # Store the new unroller ip and port so that later each rank
                     # can connect to it for experience data.
                     self._new_unroller_ips_and_ports.put((unroller_ip,
                                                           int(unroller_port)))
+                    registered_unrollers.add(unroller_id)
+                    logging.info(
                        f"Rank {self._ddp_rank} registered {unroller_ip} {unroller_port}"
+                    )
 
                     if self.is_main_ddp_rank:
                         # Send the number of workers to the new unroller,
                         # so that it is able to know other workers.
+                        # Also send the DDP rank that's responsible for the
+                        # unroller's params syncing; see ``_train_iter_off_policy``,
+                        # where the params-sending tasks are distributed.
+                        k = total_unrollers % self._num_ranks
                         register_socket.send_multipart([
                             unroller_id,
-                            (f'worker-0: {PerProcessContext().num_processes}'
-                             ).encode()
+                            (f'worker-0: {self._num_ranks} {k}').encode()
                         ])
 
+                    # Then we check if this unroller's params socket talks to
+                    # the current rank.
+                    if total_unrollers % self._num_ranks == self._ddp_rank:
+                        self._unrollers_to_update_params.add(unroller_id)
                         # Always first sync the params with a new unroller.
-                        self._send_params_to_unroller(unroller_id)
+                        assert self._send_params_to_unroller(
+                            unroller_id, first_time=True)
 
+                    total_unrollers += 1
 
         thread = threading.Thread(target=_wait_unroller_registration)
         thread.daemon = True
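The `k = total_unrollers % self._num_ranks` rule above deals unrollers out to DDP ranks round-robin, so the params-sending load spreads evenly. A tiny check with hypothetical counts:

    num_ranks = 3
    assignments = {i: i % num_ranks for i in range(6)}  # unroller index -> rank
    assert assignments == {0: 0, 1: 1, 2: 2, 3: 0, 4: 1, 5: 2}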
@@ -479,16 +506,16 @@ def _train_iter_off_policy(self):
         steps = super()._train_iter_off_policy()
         self._total_updates += self._config.num_updates_per_train_iter
 
-        if (self.is_main_ddp_rank and alf.summary.get_global_counter() %
+        if (alf.summary.get_global_counter() %
                 self._push_params_every_n_iters == 0):
             # Send params to the unrollers assigned to this rank.
             dead_unrollers = []
-            for unroller_id in self._connected_unrollers:
+            for unroller_id in self._unrollers_to_update_params:
                 if not self._send_params_to_unroller(unroller_id):
                     dead_unrollers.append(unroller_id)
             # Remove dead unrollers
             for unroller_id in dead_unrollers:
-                self._connected_unrollers.remove(unroller_id)
+                self._unrollers_to_update_params.remove(unroller_id)
 
         return steps

@@ -498,7 +525,6 @@ class DistributedUnroller(DistributedOffPolicyAlgorithm):
     def __init__(self,
                  core_alg_ctor: Callable,
                  *args,
-                 deploy_mode: bool = False,
                  env: AlfEnvironment = None,
                  config: TrainerConfig = None,
                  checkpoint: str = None,
@@ -510,8 +536,6 @@ def __init__(self,
             core_alg_ctor: creates the algorithm to be wrapped by this class.
                 This algorithm's ``predict_step()`` and ``rollout_step()`` will
                 be used for evaluation and rollout.
-            deploy_mode: True if this unroller is used for deployment. In this
-                case, the unroller will not communicate with a trainer.
             checkpoint: this in-alg ckpt will be ignored if ``deploy_mode==False``.
             *args: additional args to pass to ``core_alg_ctor``.
             **kwargs: additional kwargs to pass to ``core_alg_ctor``.
@@ -535,17 +559,14 @@ def __init__(self,
         self._id = f"unroller-{ip}-{self._port}"
 
         # For sending experience data
-        if not deploy_mode:
-            self._exp_socket, _ = create_zmq_socket(zmq.ROUTER, '*',
-                                                    self._port)
-            self._create_pull_params_subprocess()
+        self._exp_socket, _ = create_zmq_socket(zmq.ROUTER, '*', self._port,
+                                                self._id)
 
         # Record the current worker the data is being sent to.
         # To maintain load balance, we want to cycle through the workers
         # in a round-robin fashion.
         self._current_worker = 0
 
-        self._deploy_mode = deploy_mode
         # Whether this unroller has registered to all trainer workers
         self._registered = False

@@ -561,9 +582,10 @@ def _register_to_trainer(self):
         register_socket.send_string('init')
         message = register_socket.recv_string()
         assert message.startswith('worker-0:')
-        # message format: "worker-0: N"
-        num_trainer_workers = message.split(':')[1]
+        # message format: "worker-0: N k"
+        num_trainer_workers, params_socket_k = message.split(' ')[1:]
         self._num_trainer_workers = int(num_trainer_workers)
+        self._params_socket_k = int(params_socket_k)
         logging.info(
             f'Found {self._num_trainer_workers} workers on the trainer.')
         # Randomly select a worker as the cycle start so that multiple unrollers
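The registration reply parsed above is one space-separated string. A worked example with hypothetical values:

    message = "worker-0: 4 2"  # 4 trainer ranks; rank 2 serves this unroller's params
    num_trainer_workers, params_socket_k = message.split(' ')[1:]
    assert int(num_trainer_workers) == 4
    assert int(params_socket_k) == 2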
@@ -580,6 +602,8 @@ def _register_to_trainer(self):
         for i in range(self._num_trainer_workers):
             register_socket.send_string('init')
 
+        # Sleep to prevent closing the socket before the messages are sent.
+        time.sleep(1.)
         register_socket.close()
         cxt.term()

@@ -590,14 +614,16 @@ def _create_pull_params_subprocess(self):
         size = len(buffer.getvalue())
         # Create a shared memory object to store the new params.
         # The first byte indicates whether the params have been updated.
-        self._params = SharedMemory(create=True, size=size + 1, name='params')
+        self._shared_alg_params = SharedMemory(
+            create=True, size=size + 1, name='params_' + self._id)
         # Initialize the update byte to 0 (not updated)
-        self._params.buf[:1] = b'0'
+        self._shared_alg_params.buf[0] = 0
 
         mp.set_start_method('fork', force=True)
         process = mp.Process(
             target=pull_params_from_trainer,
-            args=(self._params.name, self._id),
+            args=(self._shared_alg_params.name, self._id,
+                  self._params_socket_k),
             daemon=True)
         process.start()

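In `_create_pull_params_subprocess`, the parent creates the shared block under a per-unroller name and the forked child re-attaches by that name. A minimal runnable sketch of that pairing (names and sizes are illustrative; `fork` assumes a Unix-like OS):

    import multiprocessing as mp
    from multiprocessing.shared_memory import SharedMemory


    def child(memory_name: str):
        shm = SharedMemory(name=memory_name)  # attach by name; no create
        shm.buf[0] = 1
        shm.close()


    if __name__ == '__main__':
        shm = SharedMemory(create=True, size=8, name='params_unroller-demo')
        shm.buf[0] = 0
        mp.set_start_method('fork', force=True)
        p = mp.Process(target=child, args=(shm.name,), daemon=True)
        p.start()
        p.join()
        assert shm.buf[0] == 1  # the child's write is visible to the parent
        shm.close()
        shm.unlink()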
@@ -607,8 +633,6 @@ def observe_for_replay(self, exp: Experience):
         We make sure each full episode is sent to the same DDP rank, if
         multi-GPU training is enabled on the trainer.
         """
-        if self._deploy_mode:
-            return
         # First prune exp's replay state to save communication overhead
         exp = alf.utils.common.prune_exp_replay_state(
             exp, self._use_rollout_state, self.rollout_state_spec,
Expand Down Expand Up @@ -644,33 +668,40 @@ def _check_paramss_update(self) -> bool:
"""Returns True if params have been updated.
"""
# Check if the params have been updated
if bytes(self._params.buf[:1]) == b'1':
params = bytes(self._params.buf[1:])
buffer = io.BytesIO(params)
if self._shared_alg_params.buf[0] == 1:
buffer = io.BytesIO(self._shared_alg_params.buf[1:])
state_dict = torch.load(buffer, map_location='cpu')
# We might only update part of the params
self._core_alg.load_state_dict(state_dict, strict=False)
logging.debug("Params updated from the trainer.")
self._params.buf[:1] = b'0'
self._shared_alg_params.buf[0] = 0
return True
return False

-    def _train_iter_off_policy(self):
-        if not self._registered and not self._deploy_mode:
+    def train_iter(self):
+        """Perform one training iteration of the unroller.
+
+        No actual training happens in this function; the unroller only checks
+        whether updated params have arrived.
+        """
+        if not self._registered:
             # We need lazy registration so that the trainer's params have a
             # higher priority than the unroller's loaded params (if enabled).
             self._register_to_trainer()
+            # Wait until the unroller receives the first params update from
+            # the trainer. We don't want to do this in ``__init__`` because
+            # the params might get overwritten by a checkpointer.
+            self._create_pull_params_subprocess()
+            while True:
+                if self._check_paramss_update():
+                    break
+                time.sleep(0.01)
             self._registered = True
 
+        # Copied from super().train_iter()
         if self._config.empty_cache:
             torch.cuda.empty_cache()
         # Experience will be sent to the trainer in this function
         self._unroll_iter_off_policy()
-        if not self._deploy_mode:
-            self._check_paramss_update()
+        self._check_paramss_update()
+        return 0
