diff --git a/megatron/arguments.py b/megatron/arguments.py index 9228da6ee9..6580dba80a 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -1570,5 +1570,11 @@ def _add_profiler_args(parser): type=str, default='2,3', help="Which steps to profile. Format: ,") + + group.add_argument("--profile-ranks", + type=int, + nargs='+', + default=None, + help="Which ranks to profile. Format: 0 1 2 3") return parser diff --git a/megatron/profiler.py b/megatron/profiler.py index c98096482a..aeab144846 100644 --- a/megatron/profiler.py +++ b/megatron/profiler.py @@ -36,7 +36,9 @@ def is_end_step(): def is_capture_step(): return cur_step >= start_step and cur_step <= end_step - if args.profile.startswith('pt'): + if args.profile.startswith('pt') and ( + args.profile_ranks is None or torch.distributed.get_rank() in args.profile_ranks + ): schedule = torch.profiler.schedule(wait=0, warmup=0, active=active_steps, repeat=1) activities = [torch.profiler.ProfilerActivity.CPU] activities.extend([torch.profiler.ProfilerActivity.HPU] if device.startswith("hpu") else [])