diff --git a/alf/algorithms/distributed_off_policy_algorithm.py b/alf/algorithms/distributed_off_policy_algorithm.py
index 26f4c070b..d2152d99f 100644
--- a/alf/algorithms/distributed_off_policy_algorithm.py
+++ b/alf/algorithms/distributed_off_policy_algorithm.py
@@ -152,6 +152,7 @@ def __init__(self,
         self._core_alg = core_alg
         self._port = port
         self._ddp_rank = max(0, PerProcessContext().ddp_rank)
+        self._num_ranks = PerProcessContext().num_processes
 
     def _opt_free_state_dict(self) -> dict:
         """Return `self._core_alg` state dict without optimizers.
@@ -251,19 +252,22 @@ def receive_experience_data(replay_buffer: ReplayBuffer,
         time.sleep(0.1)
 
 
-def pull_params_from_trainer(memory_name: str, unroller_id: str):
+def pull_params_from_trainer(memory_name: str, unroller_id: str,
+                             params_socket_k: int):
     """ Once new params arrive, we put it in the shared memory and mark updated.
     Later after the current unroll finishes, the unroller can load the new params.
     """
     socket, _ = create_zmq_socket(
         zmq.DEALER, _trainer_addr_config.ip,
-        _trainer_addr_config.port + _params_port_offset,
+        _trainer_addr_config.port + _params_port_offset + params_socket_k,
         unroller_id + "_params")
     params = SharedMemory(name=memory_name)
+    # Signal that this unroller is ready to receive params
+    socket.send_string(UnrollerMessage.OK)
     while True:
         data = socket.recv()
-        params.buf[:1] = b'1'
+        params.buf[0] = 1
         params.buf[1:] = data
         socket.send_string(UnrollerMessage.OK)
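The write side above and the read side in ``_check_paramss_update`` (last hunk below) now agree on a byte-level protocol: byte 0 of the shared memory is an integer "updated" flag, and the remaining bytes hold the torch-serialized state dict. A minimal self-contained sketch of that protocol, with a made-up segment name and a toy state dict:

```python
import io
import torch
from multiprocessing.shared_memory import SharedMemory

state_dict = {'w': torch.zeros(4)}
buf = io.BytesIO()
torch.save(state_dict, buf)
data = buf.getvalue()

# Size is payload + 1 flag byte, as in _create_pull_params_subprocess.
shm = SharedMemory(create=True, size=len(data) + 1, name='params_demo')
# Writer side (the pull subprocess): set the flag and copy the payload.
shm.buf[0] = 1
shm.buf[1:] = data
# Reader side (the unroller's main process): consume if the flag is set.
if shm.buf[0] == 1:
    restored = torch.load(io.BytesIO(shm.buf[1:]), map_location='cpu')
    shm.buf[0] = 0
assert torch.equal(restored['w'], state_dict['w'])
shm.close()
shm.unlink()
```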
@@ -328,9 +332,13 @@ def __init__(self,
         # by the parent ``RLAlgorithm``
         self.observe_for_replay = self._observe_for_replay
 
-        if self.is_main_ddp_rank:
-            self._params_socket, _ = create_zmq_socket(
-                zmq.ROUTER, '*', self._port + _params_port_offset)
+        logging.info("Trainer params port: %s",
+                     self._port + _params_port_offset + self._ddp_rank)
+        self._params_socket, _ = create_zmq_socket(
+            zmq.ROUTER, '*', self._port + _params_port_offset + self._ddp_rank)
+        # 3 sec timeout for receiving the unroller's acknowledgement.
+        # In case some unrollers die, we don't want to block forever.
+        self._params_socket.setsockopt(zmq.RCVTIMEO, 3000)
 
         assert config.unroll_length == -1, (
             'unroll_length must be -1 (no unrolling)')
@@ -346,74 +354,93 @@ def _observe_for_replay(self, exp: Experience):
     def is_main_ddp_rank(self):
         return self._ddp_rank == 0
 
-    def _send_params_to_unroller(self, unroller_id: str) -> bool:
+    def _send_params_to_unroller(self,
+                                 unroller_id: str,
+                                 first_time: bool = False) -> bool:
         """Send model params to a specified unroller.
 
         Args:
             unroller_id: id (bytes str) of the unroller.
+            first_time: whether this is the first time this function gets called.
+                For the first time, we need to wait for the unroller's socket-ready
+                signal.
 
         Returns:
             bool: True if the unroller is still alive.
         """
+        unroller_id1 = unroller_id + b'_params'
+        if first_time:
+            # Block until the unroller is ready to receive params.
+            # If we don't do so, the outgoing params might get lost before
+            # the receiving socket is actually created.
+            unroller_id_, message = self._params_socket.recv_multipart()
+            assert unroller_id_ == unroller_id1
+            assert message == UnrollerMessage.OK.encode()
+        # Get all parameters/buffers in a state dict and send them out
         buffer = io.BytesIO()
         torch.save(self._opt_free_state_dict(), buffer)
-        self._params_socket.send_multipart(
-            [unroller_id + b'_params',
-             buffer.getvalue()])
-        success = False
-        for _ in range(100):  # 1s in total for acknowledgement
-            try:
-                # In case some unrollers might die, we don't want to block forever
-                _, message = self._params_socket.recv_multipart(
-                    flags=zmq.NOBLOCK)
-                assert message == UnrollerMessage.OK.encode()
-                logging.debug(
-                    f"[worker-0] Params sent to unroller {unroller_id.decode()}."
-                )
-                success = True
-                break
-            except zmq.Again:
-                time.sleep(0.01)
-        return success
+        self._params_socket.send_multipart([unroller_id1, buffer.getvalue()])
+        try:
+            _, message = self._params_socket.recv_multipart()
+            assert message == UnrollerMessage.OK.encode()
+            logging.debug(f"[worker-{self._ddp_rank}] Params sent to unroller"
+                          f" {unroller_id.decode()}.")
+            return True
+        except zmq.Again:
+            return False
 
     def _create_unroller_registration_thread(self):
         self._new_unroller_ips_and_ports = mp.Queue()
-        self._connected_unrollers = set()
+        self._unrollers_to_update_params = set()
+        registered_unrollers = set()
 
         def _wait_unroller_registration():
             """Wait for new registration from a unroller.
             """
+            total_unrollers = 0
            # Each rank has its own port number and a registration socket to
            # handle new unrollers.
            register_socket, _ = create_zmq_socket(zmq.ROUTER, '*',
                                                   self._port + self._ddp_rank)
            while True:
                unroller_id, message = register_socket.recv_multipart()
-                if unroller_id not in self._connected_unrollers:
+                if unroller_id not in registered_unrollers:
                    # A new unroller has connected to the trainer
-                    self._connected_unrollers.add(unroller_id)
                    # The init message should always be: 'init'
                    assert message.decode() == 'init'
                    _, unroller_ip, unroller_port = unroller_id.decode().split(
                        '-')
-                    logging.info(
-                        f"Rank {self._ddp_rank} registered {unroller_ip} {unroller_port}"
-                    )
                    # Store the new unroller ip and port so that later each rank
                    # can connect to it for experience data.
                    self._new_unroller_ips_and_ports.put((unroller_ip,
                                                          int(unroller_port)))
+                    registered_unrollers.add(unroller_id)
+                    logging.info(
+                        f"Rank {self._ddp_rank} registered {unroller_ip} {unroller_port}"
+                    )
+
                    if self.is_main_ddp_rank:
                        # Send the number of workers to the new unroller,
                        # so that it is able to know other workers.
+                        # Also send the DDP rank that's responsible for the
+                        # unroller's params syncing. See ``_train_iter_off_policy``,
+                        # where the params sending tasks are distributed.
+                        k = total_unrollers % self._num_ranks
                        register_socket.send_multipart([
                            unroller_id,
-                            (f'worker-0: {PerProcessContext().num_processes}'
-                             ).encode()
+                            (f'worker-0: {self._num_ranks} {k}').encode()
                        ])
+
+                    # Then we check if its params socket communicates with the
+                    # current rank.
+                    if total_unrollers % self._num_ranks == self._ddp_rank:
+                        self._unrollers_to_update_params.add(unroller_id)
                        # Always first sync the params with a new unroller.
-                        self._send_params_to_unroller(unroller_id)
+                        assert self._send_params_to_unroller(
+                            unroller_id, first_time=True)
+
+                    total_unrollers += 1
 
         thread = threading.Thread(target=_wait_unroller_registration)
         thread.daemon = True
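The ``first_time`` handshake above exists because a ROUTER socket silently drops messages addressed to an identity it has not seen yet. A toy, runnable sketch of the same pattern (port and identity are made up), including the ``RCVTIMEO`` timeout that replaces the old 100-iteration ``NOBLOCK`` polling loop:

```python
import zmq

ctx = zmq.Context()
router = ctx.socket(zmq.ROUTER)        # trainer side
router.bind('tcp://*:5555')
router.setsockopt(zmq.RCVTIMEO, 3000)  # recv raises zmq.Again after 3 s

dealer = ctx.socket(zmq.DEALER)        # unroller side
dealer.setsockopt(zmq.IDENTITY, b'unroller-127.0.0.1-5000_params')
dealer.connect('tcp://localhost:5555')
dealer.send_string('OK')               # "my params socket is ready"

ident, msg = router.recv_multipart()   # block until the unroller is ready
assert msg == b'OK'
router.send_multipart([ident, b'<serialized state dict>'])
print(dealer.recv())                   # b'<serialized state dict>'
```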
@@ -479,16 +506,16 @@ def _train_iter_off_policy(self):
         steps = super()._train_iter_off_policy()
         self._total_updates += self._config.num_updates_per_train_iter
 
-        if (self.is_main_ddp_rank and alf.summary.get_global_counter() %
+        if (alf.summary.get_global_counter() %
                 self._push_params_every_n_iters == 0):
             # Sending params to all the connected unrollers.
             dead_unrollers = []
-            for unroller_id in self._connected_unrollers:
+            for unroller_id in self._unrollers_to_update_params:
                 if not self._send_params_to_unroller(unroller_id):
                     dead_unrollers.append(unroller_id)
             # remove dead unrollers
             for unroller_id in dead_unrollers:
-                self._connected_unrollers.remove(unroller_id)
+                self._unrollers_to_update_params.remove(unroller_id)
 
         return steps
 
@@ -498,7 +525,6 @@ class DistributedUnroller(DistributedOffPolicyAlgorithm):
     def __init__(self,
                  core_alg_ctor: Callable,
                  *args,
-                 deploy_mode: bool = False,
                  env: AlfEnvironment = None,
                  config: TrainerConfig = None,
                  checkpoint: str = None,
@@ -510,8 +536,6 @@ def __init__(self,
             core_alg_ctor: creates the algorithm to be wrapped by this class.
                 This algorithm's ``predict_step()`` and ``rollout_step()`` will
                 be used for evaluation and rollout.
-            deploy_mode: True if this unroller is used for deployment. In this
-                case, the unroller will not communicate with a trainer.
             checkpoint: this in-alg ckpt will be ignored if ``deploy_mode==False``.
             *args: additional args to pass to ``core_alg_ctor``.
             **kwargs: additional kwargs to pass to ``core_alg_ctor``.
@@ -535,17 +559,14 @@ def __init__(self,
         self._id = f"unroller-{ip}-{self._port}"
 
         # For sending experience data
-        if not deploy_mode:
-            self._exp_socket, _ = create_zmq_socket(zmq.ROUTER, '*',
-                                                    self._port)
-            self._create_pull_params_subprocess()
+        self._exp_socket, _ = create_zmq_socket(zmq.ROUTER, '*', self._port,
+                                                self._id)
 
         # Record the current worker the data is being sent to
         # To maintain load balance, we want to cycle through the workers
         # in a round-robin fashion.
         self._current_worker = 0
-        self._deploy_mode = deploy_mode
 
         # Whether this unroller has registered to all trainer workers
         self._registered = False
@@ -561,9 +582,10 @@ def _register_to_trainer(self):
         register_socket.send_string('init')
         message = register_socket.recv_string()
         assert message.startswith('worker-0:')
-        # message format: "worker-0: N"
-        num_trainer_workers = message.split(':')[1]
+        # message format: "worker-0: N k"
+        num_trainer_workers, params_socket_k = message.split(' ')[1:]
         self._num_trainer_workers = int(num_trainer_workers)
+        self._params_socket_k = int(params_socket_k)
         logging.info(
             f'Found {self._num_trainer_workers} workers on the trainer. ')
         # Randomly select a worker as the cycle start so that multiple unrollers
@@ -580,6 +602,8 @@
         for i in range(self._num_trainer_workers):
             register_socket.send_string('init')
 
+        # Give the socket time to flush the 'init' messages before closing it
+        time.sleep(1.)
         register_socket.close()
         cxt.term()
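The registration reply format changed from ``worker-0: N`` to ``worker-0: N k``: ``N`` is the number of trainer workers and ``k`` is the rank that owns this unroller's params socket (``k = total_unrollers % num_ranks`` at registration time). A small parsing sketch with made-up values:

```python
message = 'worker-0: 4 2'  # hypothetical reply from the trainer
num_trainer_workers, params_socket_k = message.split(' ')[1:]
assert int(num_trainer_workers) == 4  # cycle experience over 4 workers
assert int(params_socket_k) == 2      # pull params from rank 2, i.e. from
                                      # port base + _params_port_offset + 2
```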
@@ -590,14 +614,16 @@ def _create_pull_params_subprocess(self):
         size = len(buffer.getvalue())
         # Create a shared memory object to store the new params
         # The first char indicates whether the params have been updated
-        self._params = SharedMemory(create=True, size=size + 1, name='params')
+        self._shared_alg_params = SharedMemory(
+            create=True, size=size + 1, name='params_' + self._id)
         # Initialize the update char to False (not updated)
-        self._params.buf[:1] = b'0'
+        self._shared_alg_params.buf[0] = 0
 
         mp.set_start_method('fork', force=True)
         process = mp.Process(
             target=pull_params_from_trainer,
-            args=(self._params.name, self._id),
+            args=(self._shared_alg_params.name, self._id,
+                  self._params_socket_k),
             daemon=True)
         process.start()
@@ -607,8 +633,6 @@ def observe_for_replay(self, exp: Experience):
         Every time we make sure a full episode is sent to the same DDP rank,
         if multi-gpu training is enabled on the trainer.
         """
-        if self._deploy_mode:
-            return
         # First prune exp's replay state to save communication overhead
         exp = alf.utils.common.prune_exp_replay_state(
             exp, self._use_rollout_state, self.rollout_state_spec,
@@ -644,33 +668,40 @@ def _check_paramss_update(self) -> bool:
         """Returns True if params have been updated.
         """
         # Check if the params have been updated
-        if bytes(self._params.buf[:1]) == b'1':
-            params = bytes(self._params.buf[1:])
-            buffer = io.BytesIO(params)
+        if self._shared_alg_params.buf[0] == 1:
+            buffer = io.BytesIO(self._shared_alg_params.buf[1:])
             state_dict = torch.load(buffer, map_location='cpu')
             # We might only update part of the params
             self._core_alg.load_state_dict(state_dict, strict=False)
             logging.debug("Params updated from the trainer.")
-            self._params.buf[:1] = b'0'
+            self._shared_alg_params.buf[0] = 0
             return True
         return False
 
-    def _train_iter_off_policy(self):
-        if not self._registered and not self._deploy_mode:
+    def train_iter(self):
+        """Perform one training iteration of the unroller.
+
+        No actual training happens in this function; the unroller only checks
+        whether updated params have arrived from the trainer.
+        """
+        if not self._registered:
             # We need lazy registration so that trainer's params has a higher
             # priority than the unroller's loaded params (if enabled).
             self._register_to_trainer()
             # Wait until the unroller receives the first params update from trainer
             # We don't want to do this in ``__init__`` because the params might
             # get overwritten by a checkpointer.
+            self._create_pull_params_subprocess()
             while True:
                 if self._check_paramss_update():
                     break
                 time.sleep(0.01)
             self._registered = True
 
+        # Copied from super().train_iter()
+        if self._config.empty_cache:
+            torch.cuda.empty_cache()
         # Experience will be sent to the trainer in this function
         self._unroll_iter_off_policy()
-        if not self._deploy_mode:
-            self._check_paramss_update()
+        self._check_paramss_update()
         return 0
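One more note on the shared-memory rename: keying the segment by ``self._id`` (``'params_' + self._id``) presumably avoids name collisions when several unrollers run on the same host, since POSIX shared-memory names are system-wide. A quick illustration with hypothetical unroller ids:

```python
from multiprocessing.shared_memory import SharedMemory

# Distinct per-unroller names can coexist on one host:
a = SharedMemory(create=True, size=8, name='params_unroller-10.0.0.1-5000')
b = SharedMemory(create=True, size=8, name='params_unroller-10.0.0.1-5001')
# With the old fixed name, the second create would raise FileExistsError:
#     SharedMemory(create=True, size=8, name='params')  # unroller 1: ok
#     SharedMemory(create=True, size=8, name='params')  # unroller 2: error
a.close()
a.unlink()
b.close()
b.unlink()
```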