diff --git a/timm/layers/fast_norm.py b/timm/layers/fast_norm.py
index 7581197bb..e7cbbb949 100644
--- a/timm/layers/fast_norm.py
+++ b/timm/layers/fast_norm.py
@@ -24,6 +24,8 @@
     has_apex_rmsnorm = False
 
+has_torch_rms_norm = hasattr(F, 'rms_norm')
+
 # fast (ie lower precision LN) can be disabled with this flag if issues crop up
 _USE_FAST_NORM = False  # defaulting to False for now
 
 
@@ -75,7 +77,6 @@ def fast_group_norm(
     if is_autocast_enabled(x.device.type):
         # normally native AMP casts GN inputs to float32
         # here we use the low precision autocast dtype
-        # FIXME what to do re CPU autocast?
         dt = get_autocast_dtype(x.device.type)
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
@@ -101,14 +102,12 @@ def fast_layer_norm(
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
         dt = get_autocast_dtype(x.device.type)
-        # FIXME what to do re CPU autocast?
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
     with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)
 
 
-
 def rms_norm(
     x: torch.Tensor,
     normalized_shape: List[int],
@@ -148,8 +147,19 @@ def fast_rms_norm(
         else:
             return fused_rms_norm_affine(x, weight, normalized_shape, eps)
 
-    # fallback
-    return rms_norm(x, normalized_shape, weight, eps)
+    if is_autocast_enabled(x.device.type):
+        # normally native AMP casts LN inputs to float32
+        # apex LN does not, this is behaving like Apex
+        dt = get_autocast_dtype(x.device.type)
+        x, weight = x.to(dt), weight.to(dt)
+
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
+        if has_torch_rms_norm:
+            x = F.rms_norm(x, normalized_shape, weight, eps)
+        else:
+            x = rms_norm(x, normalized_shape, weight, eps)
+
+    return x
 
 
 def simple_norm(
diff --git a/timm/layers/norm.py b/timm/layers/norm.py
index dec868e31..f718750dc 100644
--- a/timm/layers/norm.py
+++ b/timm/layers/norm.py
@@ -11,17 +11,24 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm, simple_norm
+
+try:
+    from torch.nn.functional import rms_norm
+except ImportError:
+    from .fast_norm import rms_norm
 
 
 class GroupNorm(nn.GroupNorm):
+    _fast_norm: torch.jit.Final[bool]
+
     def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
         # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
         super().__init__(num_groups, num_channels, eps=eps, affine=affine)
-        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
 
     def forward(self, x):
-        if self.fast_norm:
+        if self._fast_norm:
             return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         else:
             return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
@@ -31,13 +38,14 @@ class GroupNorm1(nn.GroupNorm):
     """ Group Normalization with 1 group.
     Input: tensor in shape [B, C, *]
     """
+    _fast_norm: torch.jit.Final[bool]
 
     def __init__(self, num_channels, **kwargs):
         super().__init__(1, num_channels, **kwargs)
-        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.fast_norm:
+        if self._fast_norm:
             return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         else:
             return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
@@ -46,6 +54,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class LayerNorm(nn.LayerNorm):
     """ LayerNorm w/ fast norm option
     """
+    _fast_norm: torch.jit.Final[bool]
+
     def __init__(self, num_channels, eps=1e-6, affine=True):
         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
@@ -60,6 +70,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class LayerNorm2d(nn.LayerNorm):
     """ LayerNorm for channels of '2D' spatial NCHW tensors """
+    _fast_norm: torch.jit.Final[bool]
+
     def __init__(self, num_channels, eps=1e-6, affine=True):
         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
@@ -121,10 +133,11 @@ def forward(self, x) -> torch.Tensor:
 class RmsNorm(nn.Module):
     """ RmsNorm w/ fast (apex) norm if available
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool
 
     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -136,6 +149,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -150,17 +165,21 @@ def reset_parameters(self) -> None:
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
         # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
-        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
         return x
 
 
 class RmsNorm2d(nn.Module):
     """ RmsNorm w/ fast (apex) norm if available
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool
 
     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -172,6 +191,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -187,7 +208,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
         # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
         # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
-        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
         x = x.permute(0, 3, 1, 2)
         return x
 
@@ -195,10 +219,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class SimpleNorm(nn.Module):
     """ SimpleNorm (x / std(x))
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool
 
     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -210,6 +235,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -222,17 +249,21 @@ def reset_parameters(self) -> None:
         nn.init.ones_(self.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
         return x
 
 
 class SimpleNorm2d(nn.Module):
     """ SimpleNorm for NCHW tensors
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool
 
     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -244,6 +275,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -257,6 +290,9 @@ def reset_parameters(self) -> None:
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
-        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
         x = x.permute(0, 3, 1, 2)
         return x
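
Reviewer note (not part of the patch): a minimal sketch of how the new dispatch is expected to behave, assuming a checkout containing this change, that set_fast_norm/is_fast_norm are importable from timm.layers.fast_norm as today, and, for the built-in path, a PyTorch recent enough to provide torch.nn.functional.rms_norm. With the fast-norm flag left at its default (off), RmsNorm.forward now calls the module-level rms_norm import (F.rms_norm when available, else the pure-Python fallback); with the flag enabled, it routes through fast_rms_norm, which prefers the apex fused kernel and otherwise uses the autocast-aware F.rms_norm/rms_norm path added above. The variable names below are illustrative only.

import torch

from timm.layers.fast_norm import set_fast_norm
from timm.layers.norm import RmsNorm

x = torch.randn(2, 196, 384)

# default: _fast_norm is False at construction, so forward() uses the
# plain rms_norm import (F.rms_norm on recent PyTorch, otherwise the
# pure-Python fallback from timm.layers.fast_norm)
norm_ref = RmsNorm(384)
y_ref = norm_ref(x)

# flip the global flag; modules created afterwards capture it and
# dispatch through fast_rms_norm (apex fused kernel if installed,
# otherwise the new F.rms_norm / rms_norm fallback in this patch)
set_fast_norm(True)
norm_fast = RmsNorm(384)
norm_fast.load_state_dict(norm_ref.state_dict())
y_fast = norm_fast(x)

# in float32 without autocast the two paths should agree closely
torch.testing.assert_close(y_ref, y_fast)

The same pattern applies to SimpleNorm/SimpleNorm2d, which now fall back to simple_norm when the flag is off instead of always routing through fast_simple_norm.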