add timeout for new_group (NVIDIA#8998)
* add timeout for new_group

Signed-off-by: acphile <simplephile@outlook.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: acphile <simplephile@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
3 people committed Apr 23, 2024
1 parent 153dae9 commit a3825d5
Showing 2 changed files with 10 additions and 2 deletions.
1 change: 1 addition & 0 deletions Dockerfile
@@ -67,6 +67,7 @@ WORKDIR /workspace/
 RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \
     cd Megatron-LM && \
     git checkout 36e9b6bf3d8034b10c9bbd9fc357c2df2bd1515c && \
+    git cherry-pick -n e69187bc3679ea5841030a165d587bb48b56ee77 && \
     pip install .

 # Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771
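The cherry-picked Megatron-LM commit is what adds the distributed_timeout_minutes argument consumed below in nlp_overrides.py. As a rough illustration only (not part of this diff; the helper name is made up), the minutes value is expected to end up as a timedelta handed to torch.distributed.new_group, so collectives on the newly created model-parallel groups fail after the timeout instead of hanging indefinitely:

from datetime import timedelta

import torch.distributed as dist


def new_group_with_timeout(ranks, distributed_timeout_minutes: int = 30):
    # Hypothetical sketch: requires torch.distributed.init_process_group() to
    # have been called already. A finite timeout makes a stuck collective on
    # the new group error out after the given number of minutes rather than
    # block forever.
    return dist.new_group(ranks=ranks, timeout=timedelta(minutes=distributed_timeout_minutes))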
11 changes: 9 additions & 2 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -113,7 +113,9 @@
 NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE = "NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE"


-def init_model_parallel(sharp: bool, nccl_communicator_config_path: str = None) -> None:
+def init_model_parallel(
+    sharp: bool, nccl_communicator_config_path: str = None, distributed_timeout_minutes: int = 30
+) -> None:
     """ Initializes Megatron-LM model parallel if using model parallelism.
     Args:
@@ -139,6 +141,7 @@ def init_model_parallel(sharp: bool, nccl_communicator_config_path: str = None)
                 use_sharp=sharp,
                 expert_model_parallel_size=app_state.expert_model_parallel_size,
                 order='tp-pp-dp' if app_state.use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
+                distributed_timeout_minutes=distributed_timeout_minutes,
             )

             # assert that fake tp and pp rank match after model parallel init
@@ -219,7 +222,11 @@ def setup_distributed(self, global_rank: int = None, world_size: int = None) ->
         app_state = AppState()

         if app_state.model_parallel_size is not None:
-            init_model_parallel(self.sharp, self.nccl_communicator_config_path)
+            init_model_parallel(
+                self.sharp,
+                self.nccl_communicator_config_path,
+                distributed_timeout_minutes=self._timeout.total_seconds() / 60,
+            )

     def configure_ddp(self):
         """ Override LightningModule ddp if using model parallel.
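On the NeMo side, the strategy's existing _timeout (a datetime.timedelta) is now converted to minutes and forwarded to init_model_parallel, so the model-parallel process groups share the same timeout as the rest of distributed setup. A minimal usage sketch with a made-up 45-minute value:

from datetime import timedelta

strategy_timeout = timedelta(minutes=45)  # stands in for the strategy's _timeout
distributed_timeout_minutes = strategy_timeout.total_seconds() / 60
assert distributed_timeout_minutes == 45.0

# The value is then passed the same way the diff above does:
# init_model_parallel(sharp, nccl_communicator_config_path,
#                     distributed_timeout_minutes=distributed_timeout_minutes)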
