
Switch gradient clipping to native torch.nn.utils.clip_grad_norm_
limiteinductive committed Mar 19, 2024
1 parent 68fe725 commit 38c86f5
Showing 8 changed files with 11 additions and 76 deletions.
2 changes: 0 additions & 2 deletions src/refiners/training_utils/__init__.py
@@ -15,7 +15,6 @@
Optimizers,
TrainingConfig,
)
-from refiners.training_utils.gradient_clipping import GradientClippingConfig
from refiners.training_utils.trainer import Trainer, register_callback, register_model
from refiners.training_utils.wandb import WandbConfig, WandbMixin

@@ -52,7 +51,6 @@
"OptimizerConfig",
"TrainingConfig",
"ClockConfig",
"GradientClippingConfig",
"Optimizers",
"LRSchedulerType",
]
11 changes: 5 additions & 6 deletions src/refiners/training_utils/common.py
@@ -7,19 +7,18 @@
import numpy as np
import torch
from loguru import logger
-from torch import Tensor, cuda, nn
+from torch import cuda, nn

from refiners.fluxion.utils import manual_seed


def compute_grad_norm(parameters: Iterable[nn.Parameter]) -> float:
"""
-Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value.
+Computes the gradient norm of the parameters in the given iterable.
+We use the `torch.nn.utils.clip_grad_norm_` function to process the gradients efficiently on the GPU or CPU.
"""
-gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None]
-assert gradients, "The model has no gradients to compute the norm."
-total_norm = torch.stack(tensors=[gradient.norm() for gradient in gradients]).norm().item() # type: ignore
-return total_norm # type: ignore
+return nn.utils.clip_grad.clip_grad_norm_(parameters, float("inf")).item()


def count_learnable_parameters(parameters: Iterable[nn.Parameter]) -> int:
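The new `compute_grad_norm` relies on the fact that `torch.nn.utils.clip_grad_norm_` returns the total gradient norm and, with `max_norm=float("inf")`, never rescales the gradients. A minimal sketch (not part of the repository) checking that the returned value matches what the removed stacked-norm computation produced:

```python
import torch
from torch import nn

# Toy model with populated gradients.
model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

before = [p.grad.clone() for p in model.parameters()]

# With an infinite max_norm, clip_grad_norm_ only reports the total norm.
total_norm = nn.utils.clip_grad.clip_grad_norm_(model.parameters(), max_norm=float("inf")).item()

# Same quantity the removed implementation computed by stacking per-parameter norms.
manual = torch.stack([p.grad.norm() for p in model.parameters()]).norm().item()
assert abs(total_norm - manual) < 1e-4

# Gradients are left untouched because the clip coefficient never drops below 1.
assert all(torch.allclose(b, p.grad) for b, p in zip(before, model.parameters()))
```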
3 changes: 1 addition & 2 deletions src/refiners/training_utils/config.py
@@ -12,7 +12,6 @@

from refiners.training_utils.clock import ClockConfig
from refiners.training_utils.common import TimeUnit, TimeValue, parse_number_unit_field
-from refiners.training_utils.gradient_clipping import GradientClippingConfig

# PyTorch optimizer parameters type
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
@@ -28,6 +27,7 @@ class TrainingConfig(BaseModel):
batch_size: int = 1
gradient_accumulation: TimeValue = TimeValue(number=1, unit=TimeUnit.STEP)
evaluation_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
+gradient_clipping_max_norm: float | None = None
evaluation_seed: int = 0

model_config = ConfigDict(extra="forbid")
@@ -167,7 +167,6 @@ class BaseConfig(BaseModel):
optimizer: OptimizerConfig
lr_scheduler: LRSchedulerConfig
clock: ClockConfig = ClockConfig()
-gradient_clipping: GradientClippingConfig = GradientClippingConfig()

model_config = ConfigDict(extra="forbid")

52 changes: 0 additions & 52 deletions src/refiners/training_utils/gradient_clipping.py

This file was deleted.

7 changes: 2 additions & 5 deletions src/refiners/training_utils/trainer.py
@@ -37,7 +37,6 @@
scoped_seed,
)
from refiners.training_utils.config import BaseConfig, LRSchedulerType, ModelConfig
-from refiners.training_utils.gradient_clipping import GradientClipping, GradientClippingConfig


class WarmupScheduler(LRScheduler):
@@ -161,10 +160,6 @@ def clock(self, config: ClockConfig) -> TrainingClock:
verbose=config.verbose,
)

-@register_callback()
-def gradient_clipping(self, config: GradientClippingConfig) -> GradientClipping:
-return GradientClipping(config)

@property
def models(self) -> ModelRegistry:
return self._models
@@ -351,6 +346,8 @@ def backward(self) -> None:
self._call_callbacks(event_name="on_backward_end")
if self.clock.is_optimizer_step:
self._call_callbacks(event_name="on_optimizer_step_begin")
+max_norm = self.config.training.gradient_clipping_max_norm or float("inf")
+self.grad_norm = nn.utils.clip_grad.clip_grad_norm_(self.learnable_parameters, max_norm=max_norm).item()
self.optimizer.step()
self.optimizer.zero_grad()
self._call_callbacks(event_name="on_optimizer_step_end")
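For reference, the optimizer-step path above can be read as the following standalone sketch, where `optimizer_step` and `max_norm_cfg` are illustrative names standing in for the trainer method and `config.training.gradient_clipping_max_norm`:

```python
import torch
from torch import nn

def optimizer_step(
    params: list[nn.Parameter],
    optimizer: torch.optim.Optimizer,
    max_norm_cfg: float | None,
) -> float:
    # A None config value disables clipping by falling back to an infinite threshold;
    # the returned norm is what the wandb callback logs as "total_grad_norm".
    max_norm = max_norm_cfg or float("inf")
    grad_norm = nn.utils.clip_grad.clip_grad_norm_(params, max_norm=max_norm).item()
    optimizer.step()
    optimizer.zero_grad()
    return grad_norm
```

Passing `1.0` clips the total gradient norm to 1.0 before the step, while passing `None` only measures it.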
4 changes: 1 addition & 3 deletions src/refiners/training_utils/wandb.py
@@ -112,6 +112,7 @@ def on_compute_loss_end(self, trainer: "TrainerWithWandb") -> None:
trainer.wandb_log(data={"step_loss": loss_value})

def on_optimizer_step_end(self, trainer: "TrainerWithWandb") -> None:
trainer.wandb_log(data={"total_grad_norm": trainer.grad_norm})
avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
trainer.wandb_log(data={"average_iteration_loss": avg_iteration_loss})
self.iteration_losses = []
@@ -124,9 +125,6 @@ def on_epoch_end(self, trainer: "TrainerWithWandb") -> None:
def on_lr_scheduler_step_end(self, trainer: "TrainerWithWandb") -> None:
trainer.wandb_log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})

-def on_backward_end(self, trainer: "TrainerWithWandb") -> None:
-trainer.wandb_log(data={"total_grad_norm": trainer.total_gradient_norm})


class WandbMixin(ABC):
config: Any
4 changes: 1 addition & 3 deletions tests/training_utils/mock_config.toml
@@ -5,9 +5,6 @@ use_activation = true
[clock]
verbose = false

-[gradient_clipping]
-clip_grad_norm = 1.0

[training]
duration = "100:epoch"
seed = 0
@@ -17,6 +14,7 @@ batch_size = 4
gradient_accumulation = "4:step"
evaluation_interval = "5:epoch"
evaluation_seed = 1
+gradient_clipping_max_norm = 1.0

[optimizer]
optimizer = "SGD"
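A quick way to check that the relocated key is picked up — a sketch assuming Python 3.11's `tomllib`, a local checkout path, and that `TrainingConfig`'s validators accept the `"100:epoch"`-style strings shown above:

```python
import tomllib

from refiners.training_utils import TrainingConfig  # export shown in __init__.py above

with open("tests/training_utils/mock_config.toml", "rb") as f:
    data = tomllib.load(f)

# The key now lives under [training] instead of a dedicated [gradient_clipping] table.
training = TrainingConfig(**data["training"])
assert training.gradient_clipping_max_norm == 1.0
```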
4 changes: 1 addition & 3 deletions tests/training_utils/mock_config_2_models.toml
@@ -8,16 +8,14 @@ requires_grad = true
[clock]
verbose = false

-[gradient_clipping]
-clip_grad_norm = 1.0

[training]
duration = "100:epoch"
seed = 0
batch_size = 4
gradient_accumulation = "4:step"
evaluation_interval = "5:epoch"
evaluation_seed = 1
+gradient_clipping_max_norm = 1.0

[optimizer]
optimizer = "SGD"
