diff --git a/nfnets/agc.py b/nfnets/agc.py index ff5c8e6..f947e17 100644 --- a/nfnets/agc.py +++ b/nfnets/agc.py @@ -68,7 +68,7 @@ def step(self, closure=None): grad_norm = unitwise_norm(p.grad.detach()) max_norm = param_norm * group['clipping'] - trigger = grad_norm < max_norm + trigger = grad_norm > max_norm clipped_grad = p.grad * \ (max_norm / torch.max(grad_norm, diff --git a/nfnets/sgd_agc.py b/nfnets/sgd_agc.py index 99646f3..db57b01 100644 --- a/nfnets/sgd_agc.py +++ b/nfnets/sgd_agc.py @@ -107,7 +107,7 @@ def step(self, closure=None): grad_norm = unitwise_norm(p.grad.detach()) max_norm = param_norm * group['clipping'] - trigger = grad_norm < max_norm + trigger = grad_norm > max_norm clipped_grad = p.grad * \ (max_norm / torch.max(grad_norm, diff --git a/setup.py b/setup.py index dd7f9f8..11f19b3 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name = 'nfnets-pytorch', packages = find_packages(), - version = '0.0.8', + version = '0.0.9', license='MIT', description = 'NFNets, PyTorch', long_description=long_description,