Add AGC with ignore_agc arguments
vballoli committed Feb 17, 2021
1 parent 7d55fa6 commit eaa4b87
Showing 2 changed files with 21 additions and 2 deletions.
1 change: 1 addition & 0 deletions nfnets/__init__.py
@@ -1,3 +1,4 @@
 from .base import WSConv2d, WSConvTranspose2d
 from .sgd_agc import SGD_AGC
+from .agc import AGC
 from .utils import replace_conv
22 changes: 20 additions & 2 deletions nfnets/agc.py
@@ -13,8 +13,11 @@ class AGC(optim.Optimizer):
         optim (torch.optim.Optimizer): Optimizer with base class optim.Optimizer
         clipping (float, optional): clipping value (default: 1e-3)
         eps (float, optional): eps (default: 1e-3)
+        model (torch.nn.Module, optional): The original model
+        ignore_agc (str, Iterable, optional): Layers for AGC to ignore
     """
-    def __init__(self, params, optim: optim.Optimizer, clipping: float=1e-2, eps: float=1e-3):
+
+    def __init__(self, params, optim: optim.Optimizer, clipping: float = 1e-2, eps: float = 1e-3, model=None, ignore_agc=["fc"]):
         if clipping < 0.0:
             raise ValueError("Invalid clipping value: {}".format(clipping))
         if eps < 0.0:
@@ -24,6 +27,22 @@ def __init__(self, params, optim: optim.Optimizer, clipping: float=1e-2, eps: float=1e-3):
 
         defaults = dict(clipping=clipping, eps=eps)
         defaults = {**defaults, **optim.defaults}
+
+        if not isinstance(ignore_agc, Iterable):
+            ignore_agc = [ignore_agc]
+
+        if model is not None:
+            assert ignore_agc not in [
+                None, []], "You must specify ignore_agc for AGC to ignore fc-like(or other) layers"
+            names = [name for name, module in model.named_modules()]
+
+            for module_name in ignore_agc:
+                if module_name not in names:
+                    raise ModuleNotFoundError(
+                        "Module name {} not found in the model".format(module_name))
+            parameters = [{"params": module.parameters()} for name,
+                          module in model.named_modules() if name not in ignore_agc]
+
         super(AGC, self).__init__(params, defaults)
 
     @torch.no_grad()
@@ -54,4 +73,3 @@ def step(self, closure=None):
             p.grad.data.copy_(torch.where(trigger, clipped_grad, p.grad))
 
         return self.optim.step(closure)
-
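For context, a minimal usage sketch of the new arguments (not part of the commit itself): the torchvision ResNet, the SGD base optimizer, and the layer name "fc" below are illustrative assumptions; only the AGC constructor signature comes from the diff above.

import torch
import torchvision.models as models
from nfnets import AGC  # exported by the __init__.py change above

# Illustrative model and base optimizer; "fc" is the final linear layer
# in torchvision ResNets, matching the default ignore_agc=["fc"].
model = models.resnet18()
base_optim = torch.optim.SGD(model.parameters(), lr=0.1)

# Wrap the base optimizer. Because model is passed, the new code checks
# that every name in ignore_agc appears among model.named_modules().
optimizer = AGC(model.parameters(), base_optim, clipping=1e-2,
                model=model, ignore_agc=["fc"])

loss = model(torch.randn(2, 3, 224, 224)).sum()
loss.backward()
optimizer.step()  # adaptively clips gradients, then runs base_optim.step()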
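Continuing the same sketch, the new validation path fails fast on unknown names; "classifier" is deliberately a module name absent from the assumed ResNet:

try:
    AGC(model.parameters(), base_optim, model=model, ignore_agc=["classifier"])
except ModuleNotFoundError as err:
    print(err)  # Module name classifier not found in the model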