From 19831a535e89a3ceb30dafef02b6d2cdb75bb244 Mon Sep 17 00:00:00 2001 From: "haonan.yu" Date: Tue, 17 Dec 2024 10:25:02 -0800 Subject: [PATCH] address comments --- .../distributed_off_policy_algorithm.py | 79 ++++++++++++------- 1 file changed, 50 insertions(+), 29 deletions(-) diff --git a/alf/algorithms/distributed_off_policy_algorithm.py b/alf/algorithms/distributed_off_policy_algorithm.py index 2d1441ac4..9ea3ed348 100644 --- a/alf/algorithms/distributed_off_policy_algorithm.py +++ b/alf/algorithms/distributed_off_policy_algorithm.py @@ -71,7 +71,8 @@ def create_zmq_socket(type: int, ip: str, port: int, id: str = None): """A helper function for creating a ZMQ socket. Args: - type: type of the socket. + type: type of the ZMQ socket, e.g., zmq.DEALER, zmq.PUB, etc. See + https://sachabarbs.wordpress.com/2014/08/21/zeromq-2-the-socket-types-2/ ip: ip address. If it's '*', then `socket.bind()` will be used. port: port number. id: identity of the socket (optional). Only required for DEALER @@ -253,14 +254,20 @@ def receive_experience_data(replay_buffer: ReplayBuffer, def pull_params_from_trainer(memory_name: str, unroller_id: str, - params_socket_k: int): + params_socket_rank: int): """ Once new params arrive, we put it in the shared memory and mark updated. Later after the current unroll finishes, the unroller can load the new params. + + Args: + memory_name: the name of the shared memory which is used to store the + updated params for the main process. + unroller_id: the id of the unroller. + params_socket_rank: which DDP rank will be syncing params with this unroller. """ socket, _ = create_zmq_socket( zmq.DEALER, _trainer_addr_config.ip, - _trainer_addr_config.port + _params_port_offset + params_socket_k, + _trainer_addr_config.port + _params_port_offset + params_socket_rank, unroller_id + "_params") params = SharedMemory(name=memory_name) # signifies that this unroller is ready to receive params @@ -273,7 +280,7 @@ def pull_params_from_trainer(memory_name: str, unroller_id: str, @alf.configurable(whitelist=[ - 'max_utd_ratio', 'push_params_every_n_iters', 'checkpoint', 'name', + 'max_utd_ratio', 'push_params_every_n_grad_updates', 'checkpoint', 'name', 'optimizer' ]) class DistributedTrainer(DistributedOffPolicyAlgorithm): @@ -281,7 +288,7 @@ def __init__(self, core_alg_ctor: Callable, *args, max_utd_ratio: float = 10., - push_params_every_n_iters: int = 1, + push_params_every_n_grad_updates: int = 1, env: AlfEnvironment = None, config: TrainerConfig = None, optimizer: alf.optimizers.Optimizer = None, @@ -303,8 +310,8 @@ def __init__(self, sync gradients among subprocesses after each backward. A larger value will make the trainer more likely overfit to the replay buffer data, while a smaller value will lead to data wastage. - push_params_every_n_iters: push model parameters to the unroller - every this number of iterations. + push_params_every_n_grad_updates: push model parameters to the unroller + every this number of gradient updates. *args: additional args to pass to ``core_alg_ctor``. **kwargs: additional kwargs to pass to ``core_alg_ctor``. """ @@ -320,11 +327,11 @@ def __init__(self, name=name, **kwargs) - self._push_params_every_n_iters = push_params_every_n_iters + self._push_params_every_n_grad_updates = push_params_every_n_grad_updates # Ports: # 1. registration port: self._port + self._ddp_rank - # 2. params port: self._port + _params_port_offset + # 2. params port: self._port + _params_port_offset + self._ddp_rank self._max_utd_ratio = max_utd_ratio @@ -334,15 +341,15 @@ def __init__(self, self._params_socket, _ = create_zmq_socket( zmq.ROUTER, '*', self._port + _params_port_offset + self._ddp_rank) - # 3 sec timeout for receiving unroller's acknowledgement - # In case some unrollers might die, we don't want to block forever - self._params_socket.setsockopt(zmq.RCVTIMEO, 3000) assert config.unroll_length == -1, ( 'unroll_length must be -1 (no unrolling)') # Total number of gradient updates so far self._total_updates = 0 - self._daemons_started = False + # How many times ``train_iter()`` has been called. + # Cannot directly use ``alf.summary.get_global_counter()`` because it + # may be incremented every mini-batch + self._num_train_iters = 0 def _observe_for_replay(self, exp: Experience): raise RuntimeError( @@ -379,14 +386,20 @@ def _send_params_to_unroller(self, buffer = io.BytesIO() torch.save(self._opt_free_state_dict(), buffer) self._params_socket.send_multipart([unroller_id1, buffer.getvalue()]) - try: - _, message = self._params_socket.recv_multipart() - assert message == UnrollerMessage.OK.encode() - logging.debug(f"[worker-{self._ddp_rank}] Params sent to unroller" - f" {unroller_id.decode()}.") - return True - except zmq.Again: - return False + # 3 sec timeout for receiving unroller's acknowledgement + # In case some unrollers might die, we don't want to block forever + for _ in range(30): + try: + _, message = self._params_socket.recv_multipart( + flags=zmq.NOBLOCK) + assert message == UnrollerMessage.OK.encode() + logging.debug( + f"[worker-{self._ddp_rank}] Params sent to unroller" + f" {unroller_id.decode()}.") + return True + except zmq.Again: + time.sleep(0.1) + return False def _create_unroller_registration_thread(self): self._new_unroller_ips_and_ports = mp.Queue() @@ -445,7 +458,7 @@ def _wait_unroller_registration(): thread.start() def _create_data_receiver_subprocess(self): - """Create a proc to receive experience data from unrollers. + """Create a process to receive experience data from unrollers. """ # First create the replay buffer in the main process. For this, we need # to create a dummy experience to set up the replay buffer. @@ -479,7 +492,13 @@ def utd(self): return self._total_updates / total_exps def _train_iter_off_policy(self): - if not self._daemons_started: + if self._num_train_iters == 0: + # First time will be called by ``Trainer._restore_checkpoint()`` + # where the ckpt (if any) will be loaded after this function. + self._num_train_iters += 1 + return super()._train_iter_off_policy() + + if self._num_train_iters == 1: # Only open the unroller registration after we are sure that # the trainer's ckpt (if any) has been loaded, so that the trainer # will send correct params to any newly added unroller. @@ -488,7 +507,6 @@ def _train_iter_off_policy(self): # Instead, we call a separate data receiver process that consistently # pulls data from unrollers. self._create_data_receiver_subprocess() - self._daemons_started = True # A worker will pause when either happens: # 1. replay buffer is not ready (initial collect steps not reached) @@ -504,10 +522,11 @@ def _train_iter_off_policy(self): steps = super()._train_iter_off_policy() self._total_updates += self._config.num_updates_per_train_iter - if (alf.summary.get_global_counter() % - self._push_params_every_n_iters == 0): + if (self._total_updates % self._push_params_every_n_grad_updates == 0): # Sending params to all the connected unrollers. dead_unrollers = [] + logging.info(f"Rank {self._ddp_rank} sends params to unrollers " + f"{self._unrollers_to_update_params}") for unroller_id in self._unrollers_to_update_params: if not self._send_params_to_unroller(unroller_id): dead_unrollers.append(unroller_id) @@ -515,6 +534,8 @@ def _train_iter_off_policy(self): for unroller_id in dead_unrollers: self._unrollers_to_update_params.remove(unroller_id) + self._num_train_iters += 1 + return steps @@ -581,9 +602,9 @@ def _register_to_trainer(self): message = register_socket.recv_string() assert message.startswith('worker-0:') # message format: "worker-0: N k" - num_trainer_workers, params_socket_k = message.split(' ')[1:] + num_trainer_workers, params_socket_rank = message.split(' ')[1:] self._num_trainer_workers = int(num_trainer_workers) - self._params_socket_k = int(params_socket_k) + self._params_socket_rank = int(params_socket_rank) logging.info( f'Found {self._num_trainer_workers} workers on the trainer. ') # Randomly select a worker as the cycle start so that multiple unrollers @@ -621,7 +642,7 @@ def _create_pull_params_subprocess(self): process = mp.Process( target=pull_params_from_trainer, args=(self._shared_alg_params.name, self._id, - self._params_socket_k), + self._params_socket_rank), daemon=True) process.start()