diff --git a/src/refiners/training_utils/__init__.py b/src/refiners/training_utils/__init__.py index 8913ea1ed..9fc4b5844 100644 --- a/src/refiners/training_utils/__init__.py +++ b/src/refiners/training_utils/__init__.py @@ -15,7 +15,6 @@ Optimizers, TrainingConfig, ) -from refiners.training_utils.gradient_clipping import GradientClippingConfig from refiners.training_utils.trainer import Trainer, register_callback, register_model from refiners.training_utils.wandb import WandbConfig, WandbMixin @@ -52,7 +51,6 @@ "OptimizerConfig", "TrainingConfig", "ClockConfig", - "GradientClippingConfig", "Optimizers", "LRSchedulerType", ] diff --git a/src/refiners/training_utils/common.py b/src/refiners/training_utils/common.py index 46b070c66..4aa0b1a6d 100644 --- a/src/refiners/training_utils/common.py +++ b/src/refiners/training_utils/common.py @@ -7,19 +7,18 @@ import numpy as np import torch from loguru import logger -from torch import Tensor, cuda, nn +from torch import cuda, nn from refiners.fluxion.utils import manual_seed def compute_grad_norm(parameters: Iterable[nn.Parameter]) -> float: """ - Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value. + Computes the gradient norm of the parameters in the given iterable. + + We use the `torch.nn.utils.clip_grad_norm_` function to process the gradients efficiently on the GPU or CPU. """ - gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None] - assert gradients, "The model has no gradients to compute the norm." - total_norm = torch.stack(tensors=[gradient.norm() for gradient in gradients]).norm().item() # type: ignore - return total_norm # type: ignore + return nn.utils.clip_grad.clip_grad_norm_(parameters, float("inf")).item() def count_learnable_parameters(parameters: Iterable[nn.Parameter]) -> int: diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index 254d7e100..0523e9210 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -12,7 +12,6 @@ from refiners.training_utils.clock import ClockConfig from refiners.training_utils.common import TimeUnit, TimeValue, parse_number_unit_field -from refiners.training_utils.gradient_clipping import GradientClippingConfig # PyTorch optimizer parameters type # TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced @@ -28,6 +27,7 @@ class TrainingConfig(BaseModel): batch_size: int = 1 gradient_accumulation: TimeValue = TimeValue(number=1, unit=TimeUnit.STEP) evaluation_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION) + gradient_clipping_max_norm: float | None = None evaluation_seed: int = 0 model_config = ConfigDict(extra="forbid") @@ -167,7 +167,6 @@ class BaseConfig(BaseModel): optimizer: OptimizerConfig lr_scheduler: LRSchedulerConfig clock: ClockConfig = ClockConfig() - gradient_clipping: GradientClippingConfig = GradientClippingConfig() model_config = ConfigDict(extra="forbid") diff --git a/src/refiners/training_utils/gradient_clipping.py b/src/refiners/training_utils/gradient_clipping.py deleted file mode 100644 index 28701c870..000000000 --- a/src/refiners/training_utils/gradient_clipping.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import TYPE_CHECKING, Any, Iterable - -import torch -from torch import nn - -from refiners.training_utils.callback import Callback, CallbackConfig - -if TYPE_CHECKING: - from refiners.training_utils.config import BaseConfig - from refiners.training_utils.trainer import Trainer - - -def clip_gradient_norm(parameters: Iterable[nn.Parameter], total_norm: float, clip_norm: float = 1.0) -> None: - """ - Clips the gradient norm of the parameters of a given model similar to `clip_grad_norm_`. - """ - gradients = [p.grad.detach() for p in parameters if p.grad is not None] - assert gradients, "The model has no gradients to clip." - clip_coefficient = torch.tensor(data=clip_norm / (total_norm + 1e-6)).clamp(max=1) - for gradient in gradients: - gradient.mul_(other=clip_coefficient) # type: ignore - - -def clip_gradient_value(parameters: Iterable[nn.Parameter], clip_value: float) -> None: - """ - Clips the gradients of the parameters of a given model at an individual level similar to `clip_grad_value_`. - """ - gradients = [p.grad.detach() for p in parameters if p.grad is not None] - assert gradients, "The model has no gradients to clip." - for gradient in gradients: - gradient.clamp_(min=-clip_value, max=clip_value) - - -class GradientClippingConfig(CallbackConfig): - clip_grad_norm: float | None = None - clip_grad_value: float | None = None - - -class GradientClipping(Callback["Trainer[BaseConfig, Any]"]): - def __init__(self, config: GradientClippingConfig) -> None: - self.config = config - - def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - clip_norm = self.config.clip_grad_norm - if clip_norm is not None: - clip_gradient_norm( - parameters=trainer.learnable_parameters, total_norm=trainer.total_gradient_norm, clip_norm=clip_norm - ) - - clip_value = self.config.clip_grad_value - if clip_value is not None: - clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 6905c3181..50c9065de 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -37,7 +37,6 @@ scoped_seed, ) from refiners.training_utils.config import BaseConfig, LRSchedulerType, ModelConfig -from refiners.training_utils.gradient_clipping import GradientClipping, GradientClippingConfig class WarmupScheduler(LRScheduler): @@ -161,10 +160,6 @@ def clock(self, config: ClockConfig) -> TrainingClock: verbose=config.verbose, ) - @register_callback() - def gradient_clipping(self, config: GradientClippingConfig) -> GradientClipping: - return GradientClipping(config) - @property def models(self) -> ModelRegistry: return self._models @@ -351,6 +346,8 @@ def backward(self) -> None: self._call_callbacks(event_name="on_backward_end") if self.clock.is_optimizer_step: self._call_callbacks(event_name="on_optimizer_step_begin") + max_norm = self.config.training.gradient_clipping_max_norm or float("inf") + self.grad_norm = nn.utils.clip_grad.clip_grad_norm_(self.learnable_parameters, max_norm=max_norm).item() self.optimizer.step() self.optimizer.zero_grad() self._call_callbacks(event_name="on_optimizer_step_end") diff --git a/src/refiners/training_utils/wandb.py b/src/refiners/training_utils/wandb.py index ac167b58c..035dd9362 100644 --- a/src/refiners/training_utils/wandb.py +++ b/src/refiners/training_utils/wandb.py @@ -112,6 +112,7 @@ def on_compute_loss_end(self, trainer: "TrainerWithWandb") -> None: trainer.wandb_log(data={"step_loss": loss_value}) def on_optimizer_step_end(self, trainer: "TrainerWithWandb") -> None: + trainer.wandb_log(data={"total_grad_norm": trainer.grad_norm}) avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses) trainer.wandb_log(data={"average_iteration_loss": avg_iteration_loss}) self.iteration_losses = [] @@ -124,9 +125,6 @@ def on_epoch_end(self, trainer: "TrainerWithWandb") -> None: def on_lr_scheduler_step_end(self, trainer: "TrainerWithWandb") -> None: trainer.wandb_log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]}) - def on_backward_end(self, trainer: "TrainerWithWandb") -> None: - trainer.wandb_log(data={"total_grad_norm": trainer.total_gradient_norm}) - class WandbMixin(ABC): config: Any diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml index bfc8b2d2a..eebb211d0 100644 --- a/tests/training_utils/mock_config.toml +++ b/tests/training_utils/mock_config.toml @@ -5,9 +5,6 @@ use_activation = true [clock] verbose = false -[gradient_clipping] -clip_grad_norm = 1.0 - [training] duration = "100:epoch" seed = 0 @@ -17,6 +14,7 @@ batch_size = 4 gradient_accumulation = "4:step" evaluation_interval = "5:epoch" evaluation_seed = 1 +gradient_clipping_max_norm = 1.0 [optimizer] optimizer = "SGD" diff --git a/tests/training_utils/mock_config_2_models.toml b/tests/training_utils/mock_config_2_models.toml index 9980a6f67..302c70b7a 100644 --- a/tests/training_utils/mock_config_2_models.toml +++ b/tests/training_utils/mock_config_2_models.toml @@ -8,9 +8,6 @@ requires_grad = true [clock] verbose = false -[gradient_clipping] -clip_grad_norm = 1.0 - [training] duration = "100:epoch" seed = 0 @@ -18,6 +15,7 @@ batch_size = 4 gradient_accumulation = "4:step" evaluation_interval = "5:epoch" evaluation_seed = 1 +gradient_clipping_max_norm = 1.0 [optimizer] optimizer = "SGD"