Skip to content

Commit

Permalink
Merge branch 'main' into jiemingz/ckpt_mem_fixes
Browse files: browse the repository at this point in the history
  • Loading branch information
JimmyZhang12 authored Mar 26, 2024
2 parents 3d8dd95 + 439c14b commit c4545fe
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions nemo/core/classes/modelPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -1732,7 +1732,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> O
if self.device.type == 'cuda':
if hasattr(self, '_nsys_profile_enabled'):
if self._nsys_profile_enabled and not self._profile_complete:
- if batch_idx == self._nsys_profile_start_step and get_rank() in self._nsys_profile_ranks:
+ if batch_idx >= self._nsys_profile_start_step and get_rank() in self._nsys_profile_ranks:
logging.info("====== Start nsys profiling ======")
torch.cuda.cudart().cudaProfilerStart()
if self._nsys_profile_gen_shape:
Expand Down Expand Up @@ -1769,7 +1769,7 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int =
if self.device.type == 'cuda':
if hasattr(self, '_nsys_profile_enabled'):
if self._nsys_profile_enabled and not self._profile_complete:
- if batch_idx == self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks:
+ if batch_idx >= self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks:
logging.info("====== End nsys profiling ======")
torch.cuda.cudart().cudaProfilerStop()
self._profile_complete = True
Expand Down

0 comments on commit c4545fe

Please sign in to comment.