From 107233d58c5b3ceaf8b998c0d3ac95145b8c52a9 Mon Sep 17 00:00:00 2001
From: Ethan Marx <61295922+EthanMarx@users.noreply.github.com>
Date: Thu, 7 Nov 2024 16:39:07 -0500
Subject: [PATCH] Add callback for tracking gradient norm (#313)

* add gradient tracker callback

* add norm type to log message
---
 projects/train/train/callbacks.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py
index 6d57292a..b322b188 100644
--- a/projects/train/train/callbacks.py
+++ b/projects/train/train/callbacks.py
@@ -10,6 +10,7 @@
 from lightning import pytorch as pl
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch.utilities import grad_norm
 
 BOTO_RETRY_EXCEPTIONS = (ClientError, ConnectTimeoutError)
 
@@ -125,3 +126,13 @@ def on_train_start(self, trainer, pl_module):
             os.path.join(save_dir, "wandb_url.txt"), "w"
         ) as f:
             f.write(url)
+
+
+class GradientTracker(Callback):
+    def __init__(self, norm_type: int = 2):
+        self.norm_type = norm_type
+
+    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
+        norms = grad_norm(pl_module, norm_type=self.norm_type)
+        total_norm = norms[f"grad_{float(self.norm_type)}_norm_total"]
+        pl_module.log(f"grad_norm_{self.norm_type}", total_norm)
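
Usage sketch (illustrative, not part of the patch): the new GradientTracker is
registered with a Lightning Trainer like any other callback. The DemoModel, the
dataloader, and the import path train.callbacks below are hypothetical stand-ins;
only the callback wiring reflects the diff. Lightning's grad_norm utility returns
a dict of per-layer norms plus a "grad_{p}_norm_total" entry, which is the value
the callback logs before each optimizer step.

    # Minimal sketch of using the GradientTracker callback added by this patch.
    # Model, data, and import path are assumptions for illustration only.
    import torch
    from lightning import pytorch as pl

    from train.callbacks import GradientTracker  # import path assumed from the diff


    class DemoModel(pl.LightningModule):
        """Placeholder model; any LightningModule would work."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=1e-2)


    if __name__ == "__main__":
        x = torch.randn(64, 8)
        y = torch.randn(64, 1)
        loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(x, y), batch_size=16
        )
        trainer = pl.Trainer(
            max_epochs=1,
            # Log the total L2 gradient norm before every optimizer step
            callbacks=[GradientTracker(norm_type=2)],
        )
        trainer.fit(DemoModel(), loader)

With norm_type=2 (the default), the metric appears in the logger under the key
"grad_norm_2"; passing a different p-norm changes both the computation and the
logged key accordingly.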