
Commit 19831a5
address comments
hnyu committed Dec 17, 2024
1 parent d64fd5c commit 19831a5
Showing 1 changed file with 50 additions and 29 deletions.
79 changes: 50 additions & 29 deletions alf/algorithms/distributed_off_policy_algorithm.py
@@ -71,7 +71,8 @@ def create_zmq_socket(type: int, ip: str, port: int, id: str = None):
"""A helper function for creating a ZMQ socket.
Args:
type: type of the socket.
type: type of the ZMQ socket, e.g., zmq.DEALER, zmq.PUB, etc. See
https://sachabarbs.wordpress.com/2014/08/21/zeromq-2-the-socket-types-2/
ip: ip address. If it's '*', then `socket.bind()` will be used.
port: port number.
id: identity of the socket (optional). Only required for DEALER
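For orientation, a minimal sketch of what this helper plausibly looks like with pyzmq. This is an illustration written for this review, not the repository's actual code; the second return value is assumed to be the ZMQ context, matching the ``socket, _ = create_zmq_socket(...)`` unpacking used below.

import zmq

def create_zmq_socket_sketch(type: int, ip: str, port: int, id: str = None):
    # Hypothetical re-implementation for illustration only.
    ctx = zmq.Context.instance()
    socket = ctx.socket(type)
    if id is not None:
        # A DEALER needs a stable identity so the ROUTER peer can
        # address replies back to it.
        socket.setsockopt_string(zmq.IDENTITY, id)
    if ip == '*':
        socket.bind(f"tcp://*:{port}")        # server side binds
    else:
        socket.connect(f"tcp://{ip}:{port}")  # client side connects
    return socket, ctx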
@@ -253,14 +254,20 @@ def receive_experience_data(replay_buffer: ReplayBuffer,


def pull_params_from_trainer(memory_name: str, unroller_id: str,
params_socket_k: int):
params_socket_rank: int):
""" Once new params arrive, we put it in the shared memory and mark updated.
Later after the current unroll finishes, the unroller can load the
new params.
Args:
memory_name: the name of the shared memory which is used to store the
updated params for the main process.
unroller_id: the id of the unroller.
params_socket_rank: which DDP rank will be syncing params with this unroller.
"""
socket, _ = create_zmq_socket(
zmq.DEALER, _trainer_addr_config.ip,
_trainer_addr_config.port + _params_port_offset + params_socket_k,
_trainer_addr_config.port + _params_port_offset + params_socket_rank,
unroller_id + "_params")
params = SharedMemory(name=memory_name)
# signifies that this unroller is ready to receive params
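The rest of this function is folded below. A hedged sketch of the flow it presumably implements; ``UnrollerMessage.OK`` and the surrounding names come from this file, everything else is an assumption:

# Illustrative sketch only -- the real body is folded in this diff.
socket.send_string("ready")                # assumed ready signal to the trainer
new_params = socket.recv()                 # serialized state dict bytes
params.buf[:len(new_params)] = new_params  # stage into the shared memory block
# ... then mark the shared memory as updated so the main process reloads ...
socket.send_multipart([UnrollerMessage.OK.encode()])  # ack; see trainer side below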
@@ -273,15 +280,15 @@ def pull_params_from_trainer(memory_name: str, unroller_id: str,


@alf.configurable(whitelist=[
'max_utd_ratio', 'push_params_every_n_iters', 'checkpoint', 'name',
'max_utd_ratio', 'push_params_every_n_grad_updates', 'checkpoint', 'name',
'optimizer'
])
class DistributedTrainer(DistributedOffPolicyAlgorithm):
def __init__(self,
core_alg_ctor: Callable,
*args,
max_utd_ratio: float = 10.,
push_params_every_n_iters: int = 1,
push_params_every_n_grad_updates: int = 1,
env: AlfEnvironment = None,
config: TrainerConfig = None,
optimizer: alf.optimizers.Optimizer = None,
@@ -303,8 +310,8 @@ def __init__(self,
sync gradients among subprocesses after each backward.
A larger value will make the trainer more likely overfit to the
replay buffer data, while a smaller value will lead to data wastage.
push_params_every_n_iters: push model parameters to the unroller
every this number of iterations.
push_params_every_n_grad_updates: push model parameters to the unroller
every this number of gradient updates.
*args: additional args to pass to ``core_alg_ctor``.
**kwargs: additional kwargs to pass to ``core_alg_ctor``.
"""
@@ -320,11 +327,11 @@ def __init__(self,
name=name,
**kwargs)

self._push_params_every_n_iters = push_params_every_n_iters
self._push_params_every_n_grad_updates = push_params_every_n_grad_updates

# Ports:
# 1. registration port: self._port + self._ddp_rank
# 2. params port: self._port + _params_port_offset
# 2. params port: self._port + _params_port_offset + self._ddp_rank

self._max_utd_ratio = max_utd_ratio

@@ -334,15 +341,15 @@ def __init__(self,

self._params_socket, _ = create_zmq_socket(
zmq.ROUTER, '*', self._port + _params_port_offset + self._ddp_rank)
# 3 sec timeout for receiving unroller's acknowledgement
# In case some unrollers might die, we don't want to block forever
self._params_socket.setsockopt(zmq.RCVTIMEO, 3000)

assert config.unroll_length == -1, (
'unroll_length must be -1 (no unrolling)')
# Total number of gradient updates so far
self._total_updates = 0
self._daemons_started = False
# How many times ``train_iter()`` has been called.
# Cannot directly use ``alf.summary.get_global_counter()`` because it
# may be incremented every mini-batch
self._num_train_iters = 0

def _observe_for_replay(self, exp: Experience):
raise RuntimeError(
@@ -379,14 +386,20 @@ def _send_params_to_unroller(self,
buffer = io.BytesIO()
torch.save(self._opt_free_state_dict(), buffer)
self._params_socket.send_multipart([unroller_id1, buffer.getvalue()])
try:
_, message = self._params_socket.recv_multipart()
assert message == UnrollerMessage.OK.encode()
logging.debug(f"[worker-{self._ddp_rank}] Params sent to unroller"
f" {unroller_id.decode()}.")
return True
except zmq.Again:
return False
# Poll up to 3 seconds for the unroller's acknowledgement.
# Some unrollers might have died, and we don't want to block forever.
for _ in range(30):
try:
_, message = self._params_socket.recv_multipart(
flags=zmq.NOBLOCK)
assert message == UnrollerMessage.OK.encode()
logging.debug(
f"[worker-{self._ddp_rank}] Params sent to unroller"
f" {unroller_id.decode()}.")
return True
except zmq.Again:
time.sleep(0.1)
return False

def _create_unroller_registration_thread(self):
self._new_unroller_ips_and_ports = mp.Queue()
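The registration thread's body is folded here. A hedged sketch of the trainer-side handshake it plausibly performs, inferred from the ``"worker-0: N k"`` format parsed in ``_register_to_trainer`` further down; the round-robin rank assignment is an assumption:

# Hedged sketch; the actual folded code may differ.
def _wait_unroller_registration_sketch(registration_socket, num_workers, queue):
    next_rank = 0
    while True:
        # ROUTER sockets prepend the sender's identity frame.
        unroller_id, _ = registration_socket.recv_multipart()
        # Tell the unroller how many DDP workers exist and which rank
        # it should sync params with (round-robin is an assumption).
        reply = f"worker-0: {num_workers} {next_rank}"
        registration_socket.send_multipart([unroller_id, reply.encode()])
        queue.put(unroller_id)
        next_rank = (next_rank + 1) % num_workers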
@@ -445,7 +458,7 @@ def _wait_unroller_registration():
thread.start()

def _create_data_receiver_subprocess(self):
"""Create a proc to receive experience data from unrollers.
"""Create a process to receive experience data from unrollers.
"""
# First create the replay buffer in the main process. For this, we need
# to create a dummy experience to set up the replay buffer.
@@ -479,7 +492,13 @@ def utd(self):
return self._total_updates / total_exps

def _train_iter_off_policy(self):
if not self._daemons_started:
if self._num_train_iters == 0:
# The first call comes from ``Trainer._restore_checkpoint()``,
# where the ckpt (if any) is loaded after this function returns.
self._num_train_iters += 1
return super()._train_iter_off_policy()

if self._num_train_iters == 1:
# Only open the unroller registration after we are sure that
# the trainer's ckpt (if any) has been loaded, so that the trainer
# will send correct params to any newly added unroller.
Expand All @@ -488,7 +507,6 @@ def _train_iter_off_policy(self):
# Instead, we launch a separate data receiver process that continuously
# pulls data from unrollers.
self._create_data_receiver_subprocess()
self._daemons_started = True

# A worker will pause when either of the following happens:
# 1. replay buffer is not ready (initial collect steps not reached)
@@ -504,17 +522,20 @@ def _train_iter_off_policy(self):
steps = super()._train_iter_off_policy()
self._total_updates += self._config.num_updates_per_train_iter

if (alf.summary.get_global_counter() %
self._push_params_every_n_iters == 0):
if (self._total_updates % self._push_params_every_n_grad_updates == 0):
# Sending params to all the connected unrollers.
dead_unrollers = []
logging.info(f"Rank {self._ddp_rank} sends params to unrollers "
f"{self._unrollers_to_update_params}")
for unroller_id in self._unrollers_to_update_params:
if not self._send_params_to_unroller(unroller_id):
dead_unrollers.append(unroller_id)
# remove dead unrollers
for unroller_id in dead_unrollers:
self._unrollers_to_update_params.remove(unroller_id)

self._num_train_iters += 1

return steps
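To put numbers on the UTD (update-to-data) gate: the ``utd`` property above is the ratio of gradient updates to experiences received, and a worker presumably pauses while it exceeds ``max_utd_ratio``. The exact pause condition is folded in this diff, so treat this reading as an assumption:

# Illustrative arithmetic for the assumed UTD gate.
total_updates = 5_000   # gradient updates so far
total_exps = 1_000      # experiences received from unrollers
utd = total_updates / total_exps   # 5.0
max_utd_ratio = 10.0
print(utd > max_utd_ratio)         # False: keep training; pause above 10.0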


@@ -581,9 +602,9 @@ def _register_to_trainer(self):
message = register_socket.recv_string()
assert message.startswith('worker-0:')
# message format: "worker-0: N k"
num_trainer_workers, params_socket_k = message.split(' ')[1:]
num_trainer_workers, params_socket_rank = message.split(' ')[1:]
self._num_trainer_workers = int(num_trainer_workers)
self._params_socket_k = int(params_socket_k)
self._params_socket_rank = int(params_socket_rank)
logging.info(
f'Found {self._num_trainer_workers} workers on the trainer. ')
# Randomly select a worker as the cycle start so that multiple unrollers
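A quick check of the ``"worker-0: N k"`` wire format parsed above, with hypothetical values:

message = "worker-0: 4 2"   # hypothetical: 4 trainer workers, params rank 2
num_trainer_workers, params_socket_rank = message.split(' ')[1:]
print(int(num_trainer_workers), int(params_socket_rank))  # -> 4 2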
Expand Down Expand Up @@ -621,7 +642,7 @@ def _create_pull_params_subprocess(self):
process = mp.Process(
target=pull_params_from_trainer,
args=(self._shared_alg_params.name, self._id,
self._params_socket_k),
self._params_socket_rank),
daemon=True)
process.start()

