diff --git a/deel/torchlip/utils/lconv_norm.py b/deel/torchlip/utils/lconv_norm.py
index bb63031..7ba0ef3 100644
--- a/deel/torchlip/utils/lconv_norm.py
+++ b/deel/torchlip/utils/lconv_norm.py
@@ -24,7 +24,6 @@
 # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
 # CRIAQ and ANITI - https://www.deel.ai/
 # =====================================================================================
-from typing import Any
 from typing import Tuple
 
 import numpy as np
@@ -35,15 +34,15 @@
 
 def compute_lconv_coef(
     kernel_size: Tuple[int, ...],
-    input_shape: Tuple[int, ...],
+    input_shape: Tuple[int, ...] = None,
     strides: Tuple[int, ...] = (1, 1),
 ) -> float:
     # See https://arxiv.org/abs/2006.06520
     stride = np.prod(strides)
     k1, k2 = kernel_size
 
-    h, w = input_shape[-2:]
-    if stride == 1:
+    if stride == 1 and input_shape is not None:
+        h, w = input_shape[-2:]
         k1_div2 = (k1 - 1) / 2
         k2_div2 = (k2 - 1) / 2
         coefLip = np.sqrt(
@@ -59,7 +58,7 @@
 
 
 class _LConvNorm(nn.Module):
-    """Parametrization module for Lipschitz normalization."""
+    """Parametrization module for kernel normalization of lipschitz convolution."""
 
     def __init__(self, lconv_coefficient: float) -> None:
         super().__init__()
@@ -69,56 +68,6 @@ def forward(self, weight: torch.Tensor) -> torch.Tensor
         return weight * self.lconv_coefficient
 
 
-class LConvNormHook:
-
-    """
-    Kernel normalization for Lipschitz convolution. Normalize weights
-    based on input shape and kernel size, see https://arxiv.org/abs/2006.06520
-    """
-
-    def apply(self, module: torch.nn.Module, name: str = "weight") -> None:
-        self.name = name
-        self.coefficient = None
-
-        if not isinstance(module, torch.nn.Conv2d):
-            raise RuntimeError(
-                "Can only apply lconv_norm hooks on 2D-convolutional layer."
-            )
-
-        module.register_forward_pre_hook(self)
-
-    def __call__(self, module: torch.nn.Conv2d, inputs: Any):
-        coefficient = compute_lconv_coef(
-            module.kernel_size, inputs[0].shape[-4:], module.stride
-        )
-        # the parametrization is updated only if the coefficient has changed
-        if coefficient != self.coefficient:
-            if hasattr(module, "parametrizations"):
-                self.remove_parametrization(module)
-            parametrize.register_parametrization(
-                module, self.name, _LConvNorm(coefficient)
-            )
-            self.coefficient = coefficient
-
-    def remove_parametrization(self, module: nn.Module) -> nn.Module:
-        r"""
-        Removes the normalization reparameterization from a module.
-
-        Args:
-            module: Containing module.
-
-        Example:
-            >>> m = bjorck_norm(nn.Linear(20, 40))
-            >>> remove_bjorck_norm(m)
-        """
-        for key, m in module.parametrizations[self.name]._modules.items():
-            if isinstance(m, _LConvNorm):
-                if len(module.parametrizations[self.name]) == 1:
-                    parametrize.remove_parametrizations(module, self.name)
-                else:
-                    del module.parametrizations[self.name]._modules[key]
-
-
 def lconv_norm(module: torch.nn.Conv2d, name: str = "weight") -> torch.nn.Conv2d:
     r"""
     Applies Lipschitz normalization to a kernel in the given convolutional.
@@ -142,26 +91,27 @@ def lconv_norm(module: torch.nn.Conv2d, name: str = "weight") -> torch.nn.Conv2d
 
         Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
     """
-    LConvNormHook().apply(module, name)
+    coefficient = compute_lconv_coef(module.kernel_size, None, module.stride)
+    parametrize.register_parametrization(module, name, _LConvNorm(coefficient))
     return module
 
 
-def remove_lconv_norm(module: torch.nn.Conv2d) -> torch.nn.Conv2d:
+def remove_lconv_norm(module: torch.nn.Conv2d, name: str = "weight") -> torch.nn.Conv2d:
     r"""
-    Removes the Lipschitz normalization hook from a module.
+    Removes the normalization parametrization for lipschitz convolution from a module.
 
     Args:
         module: Containing module.
+        name: Name of weight parameter.
 
     Example:
 
         >>> m = lconv_norm(nn.Conv2d(16, 16, (3, 3)))
         >>> remove_lconv_norm(m)
     """
-    for k, hook in module._forward_pre_hooks.items():
-        if isinstance(hook, LConvNormHook):
-            hook.remove_parametrization(module)
-            del module._forward_pre_hooks[k]
-            return module
-
-    raise ValueError("lconv_norm not found in {}".format(module))
+    for key, m in module.parametrizations[name]._modules.items():
+        if isinstance(m, _LConvNorm):
+            if len(module.parametrizations[name]) == 1:
+                parametrize.remove_parametrizations(module, name)
+            else:
+                del module.parametrizations[name]._modules[key]
diff --git a/tests/test_parametrizations.py b/tests/test_parametrizations.py
index 2da0c3f..29b995c 100644
--- a/tests/test_parametrizations.py
+++ b/tests/test_parametrizations.py
@@ -111,7 +111,7 @@ def test_lconv_norm():
     """
     m = torch.nn.Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))
     torch.nn.init.orthogonal_(m.weight)
-    w1 = m.weight * compute_lconv_coef(m.kernel_size, (1, 1, 5, 5), m.stride)
+    w1 = m.weight * compute_lconv_coef(m.kernel_size, None, m.stride)
 
     # lconv norm parametrization
     lconv_norm(m)
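
For context (not part of the patch): a minimal usage sketch of the parametrization-based API this diff switches to. The import path mirrors the file modified above; the tensor shapes are illustrative assumptions only.

    import torch
    from deel.torchlip.utils.lconv_norm import lconv_norm, remove_lconv_norm

    # Registers the _LConvNorm parametrization on the weight. The coefficient is
    # now computed once from kernel_size and stride (input_shape=None), so no
    # forward pre-hook is needed anymore.
    m = lconv_norm(torch.nn.Conv2d(16, 16, (3, 3)))

    # The kernel is multiplied by the coefficient each time m.weight is accessed.
    y = m(torch.rand(1, 16, 32, 32))

    # Strips the parametrization; m.weight returns to the raw, unscaled kernel.
    remove_lconv_norm(m)

Because the coefficient no longer depends on the input shape, it is fixed at registration time, which is what allows the LConvNormHook class and its forward pre-hook bookkeeping to be deleted.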