Use torch F.rms_norm when possible, select fast vs normal paths appropriately and test with torchscript
rwightman committed Dec 29, 2024
1 parent e0cacbf commit 5809c2f
Showing 2 changed files with 64 additions and 18 deletions.
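
The diff below follows one pattern throughout: feature-detect torch.nn.functional.rms_norm, fall back to timm's own rms_norm implementation when it is absent, and latch the fast-vs-normal decision into a torch.jit.Final[bool] attribute so TorchScript sees a constant rather than a module-level global. A minimal sketch of that pattern (SketchRmsNorm is an illustrative name, not timm code; the real modules set the flag from is_fast_norm() and dispatch to fast_rms_norm or rms_norm):

import torch
import torch.nn as nn
import torch.nn.functional as F

has_torch_rms_norm = hasattr(F, 'rms_norm')  # present in newer PyTorch releases

class SketchRmsNorm(nn.Module):
    # Final annotation lets TorchScript treat the flag as a constant and prune the untaken branch.
    _fast_norm: torch.jit.Final[bool]

    def __init__(self, channels: int, eps: float = 1e-6):
        super().__init__()
        self.normalized_shape = (channels,)
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(channels))
        self._fast_norm = has_torch_rms_norm  # decided once at construction, no globals in forward

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._fast_norm:
            return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
        # manual fallback: x / sqrt(mean(x^2) + eps), scaled elementwise by weight
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) * self.weight
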
20 changes: 15 additions & 5 deletions timm/layers/fast_norm.py
@@ -24,6 +24,8 @@
has_apex_rmsnorm = False


has_torch_rms_norm = hasattr(F, 'rms_norm')

# fast (ie lower precision LN) can be disabled with this flag if issues crop up
_USE_FAST_NORM = False # defaulting to False for now

@@ -75,7 +77,6 @@ def fast_group_norm(
if is_autocast_enabled(x.device.type):
# normally native AMP casts GN inputs to float32
# here we use the low precision autocast dtype
# FIXME what to do re CPU autocast?
dt = get_autocast_dtype(x.device.type)
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None

@@ -101,14 +102,12 @@ def fast_layer_norm(
# normally native AMP casts LN inputs to float32
# apex LN does not, this is behaving like Apex
dt = get_autocast_dtype(x.device.type)
# FIXME what to do re CPU autocast?
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None

with torch.amp.autocast(device_type=x.device.type, enabled=False):
return F.layer_norm(x, normalized_shape, weight, bias, eps)



def rms_norm(
x: torch.Tensor,
normalized_shape: List[int],
@@ -148,8 +147,19 @@ def fast_rms_norm(
else:
return fused_rms_norm_affine(x, weight, normalized_shape, eps)

# fallback
return rms_norm(x, normalized_shape, weight, eps)
if is_autocast_enabled(x.device.type):
# normally native AMP casts LN inputs to float32
# apex LN does not, this is behaving like Apex
dt = get_autocast_dtype(x.device.type)
x, weight = x.to(dt), weight.to(dt)

with torch.amp.autocast(device_type=x.device.type, enabled=False):
if has_torch_rms_norm:
x = F.rms_norm(x, normalized_shape, weight, eps)
else:
x = rms_norm(x, normalized_shape, weight, eps)

return x


def simple_norm(
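
In the new fast_rms_norm above, autocast inputs are first cast to the autocast dtype, then the norm itself runs with autocast disabled, using F.rms_norm when PyTorch provides it and timm's rms_norm helper otherwise. For readers who want the math spelled out, the function below is an assumed-equivalent reference of what that RMS normalization computes (reference_rms_norm is an illustrative name, not part of timm or PyTorch):

import torch

def reference_rms_norm(x: torch.Tensor, normalized_shape, weight=None, eps: float = 1e-6):
    # Normalize over the trailing dims given by normalized_shape:
    # y = x / sqrt(mean(x^2) + eps), optionally scaled elementwise by weight.
    dims = tuple(range(-len(normalized_shape), 0))
    x = x * torch.rsqrt(x.pow(2).mean(dim=dims, keepdim=True) + eps)
    if weight is not None:
        x = x * weight
    return x
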
62 changes: 49 additions & 13 deletions timm/layers/norm.py
@@ -11,17 +11,24 @@
import torch.nn as nn
import torch.nn.functional as F

from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm, simple_norm

try:
from torch.nn.functional import rms_norm
except ImportError:
from .fast_norm import rms_norm


class GroupNorm(nn.GroupNorm):
_fast_norm: torch.jit.Final[bool]

def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
# NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
super().__init__(num_groups, num_channels, eps=eps, affine=affine)
self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

def forward(self, x):
if self.fast_norm:
if self._fast_norm:
return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
@@ -31,13 +38,14 @@ class GroupNorm1(nn.GroupNorm):
""" Group Normalization with 1 group.
Input: tensor in shape [B, C, *]
"""
_fast_norm: torch.jit.Final[bool]

def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.fast_norm:
if self._fast_norm:
return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
@@ -46,6 +54,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class LayerNorm(nn.LayerNorm):
""" LayerNorm w/ fast norm option
"""
_fast_norm: torch.jit.Final[bool]

def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
@@ -60,6 +70,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class LayerNorm2d(nn.LayerNorm):
""" LayerNorm for channels of '2D' spatial NCHW tensors """
_fast_norm: torch.jit.Final[bool]

def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
@@ -121,10 +133,11 @@ def forward(self, x) -> torch.Tensor:
class RmsNorm(nn.Module):
""" RmsNorm w/ fast (apex) norm if available
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool

def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@@ -136,6 +149,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
@@ -150,17 +165,21 @@ def reset_parameters(self) -> None:
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE fast norm fallback needs our rms norm impl, so both paths through here.
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
if self._fast_norm:
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
return x


class RmsNorm2d(nn.Module):
""" RmsNorm w/ fast (apex) norm if available
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool

def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@@ -172,6 +191,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
@@ -187,18 +208,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 3, 1)
# NOTE fast norm fallback needs our rms norm impl, so both paths through here.
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
if self._fast_norm:
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
x = x.permute(0, 3, 1, 2)
return x


class SimpleNorm(nn.Module):
""" SimpleNorm (x / std(x))
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool

def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@@ -210,6 +235,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
@@ -222,17 +249,21 @@ def reset_parameters(self) -> None:
nn.init.ones_(self.weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
if self._fast_norm:
x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
return x


class SimpleNorm2d(nn.Module):
""" SimpleNorm for NCHW tensors
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool

def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@@ -244,6 +275,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)

if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
@@ -257,6 +290,9 @@ def reset_parameters(self) -> None:

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 3, 1)
x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
if self._fast_norm:
x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
x = x.permute(0, 3, 1, 2)
return x
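
Because each module now carries _fast_norm as a torch.jit.Final[bool] instead of consulting a global inside forward, the layers stay scriptable, which is what the commit message's "test with torchscript" refers to. A small usage check along those lines (assuming timm with this commit installed; shapes and the exact comparison are illustrative):

import torch
from timm.layers.norm import RmsNorm  # defined in the file changed above

m = RmsNorm(64)
x = torch.randn(2, 196, 64)            # channels-last input, normalized over the final dim
scripted = torch.jit.script(m)         # relies on the jit.Final flag, not module globals
torch.testing.assert_close(scripted(x), m(x))  # scripted and eager outputs should match
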
