Commit

[FEAT][print_num_params]
Kye committed Dec 17, 2023
1 parent 06f02c6 commit 40f0f00
Showing 2 changed files with 32 additions and 0 deletions.
3 changes: 3 additions & 0 deletions zeta/utils/__init__.py
@@ -8,11 +8,14 @@
    save_memory_snapshot,
)
from zeta.utils.disable_logging import disable_warnings_and_logs
from zeta.utils.params import print_num_params, print_main

__all__ = [
    "track_cuda_memory_usage",
    "benchmark",
    "print_cuda_memory_usage",
    "save_memory_snapshot",
    "disable_warnings_and_logs",
    "print_num_params",
    "print_main",
]
29 changes: 29 additions & 0 deletions zeta/utils/params.py
@@ -0,0 +1,29 @@
import torch.distributed as dist


def print_num_params(model):
    """Print the number of trainable parameters in a model.

    On a distributed run, only rank 0 prints; otherwise every process prints.

    Args:
        model (torch.nn.Module): The model whose trainable parameters are counted.
    """
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    if dist.is_available() and dist.is_initialized():
        if dist.get_rank() == 0:
            print(f"Number of parameters in model: {n_params}")
    else:
        print(f"Number of parameters in model: {n_params}")


def print_main(msg):
    """Print a message only on the main process (rank 0).

    Args:
        msg (str): The message to print.
    """
    if dist.is_available() and dist.is_initialized():
        if dist.get_rank() == 0:
            print(msg)
    else:
        print(msg)
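
For reference, a minimal usage sketch of the two helpers added in this commit. The toy nn.Linear model and the single-process (no process group) setup are illustrative assumptions, not part of the commit:

import torch.nn as nn

from zeta.utils.params import print_num_params, print_main

# Toy model used purely for illustration.
model = nn.Linear(128, 64)

# Without an initialized process group, both helpers fall back to a plain print.
print_num_params(model)  # e.g. "Number of parameters in model: 8256"
print_main("Training started")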
