From 087d17179d30a8c37959155b9bc4270e7c21bb57 Mon Sep 17 00:00:00 2001
From: bhavnicksm
Date: Wed, 8 Mar 2023 21:38:43 +0530
Subject: [PATCH 1/5] Fix MNIST example

---
 tests/mnist.py | 60 +++++++++++++++++++++++-------------------------------
 1 file changed, 25 insertions(+), 35 deletions(-)

diff --git a/tests/mnist.py b/tests/mnist.py
index 8385c24..b665d95 100644
--- a/tests/mnist.py
+++ b/tests/mnist.py
@@ -23,7 +23,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from nadir import nadir as optim
+import nadir as nd
 
 from torch.optim.lr_scheduler import StepLR
 
@@ -45,28 +45,29 @@
 args.device : bool = 'cuda' if torch.cuda.is_available() else 'cpu'
 args.log_interval : int = 10
 args.epochs : int = 10
-args.betas : Tuple[float, float] = (0.9, 0.99)
-args.eps : float = 1e-16
-args.optimizer : Any = optim.Adam
+args.betas : Tuple[float, float] = (0.9, 0.999)
+args.eps : float = 1e-8
+args.optimizer : Any = nd.SGD
 
-# with open("random_seeds.txt", 'r') as file:
-#     file_str = file.read().split('\n')
-#     seeds = [int(num) for num in file_str]
-args.random_seeds : List[int] = [42]
+with open("random_seeds.txt", 'r') as file:
+    file_str = file.read().split('\n')
+    seeds = [int(num) for num in file_str]
+args.random_seeds : List[int] = seeds
 args.seed : int = args.random_seeds[0]
 
 # writing the logging args as a namespace obj
 largs = argparse.Namespace()
 
-largs.run_name : str = 'DoE-Adam'
+largs.run_name : str = 'Nadir-Adadelta 2'
 largs.run_seed : str = args.seed
 
 # Initialising the seeds
-torch.manual_seed(args.seed)
-torch.cuda.manual_seed(args.seed)
-np.random.seed(args.seed)
-random.seed(args.seed)
+def set_seed(x : int):
+    torch.manual_seed(x)
+    torch.cuda.manual_seed(x)
+    np.random.seed(x)
+    random.seed(x)
 
 class MNISTestNet(nn.Module):
     def __init__(self):
@@ -138,7 +139,7 @@ def seed_worker():
 
     train_loader = torch.utils.data.DataLoader(
         datasets.MNIST(
-            '../../data',
+            './data',
             train=True,
             download=True,
             transform=transforms.Compose(
@@ -157,7 +158,7 @@ def seed_worker():
 
     test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
-            '../data',
+            './data',
            train=False,
            transform=transforms.Compose(
                [
@@ -176,35 +177,23 @@
 
 
 
-def mnist_tester(optimizer=None, args = None, largs = None):
+def mnist_tester(model, optimizer=None, args = None):
     train_loss = []
     test_loss = []
 
-    torch.manual_seed(args.random_seeds[0])
+    set_seed(args.random_seeds[0])
 
     device = args.device
     use_cuda = True if device == torch.device('cuda') else False
 
     train_loader, test_loader = prepare_loaders(args, use_cuda)
 
-    model = MNISTestNet().to(device)
-
     # create grid of images and write to wandb
-    images, labels = next(iter(train_loader))
-    img_grid = utils.make_grid(images)
-    wandb.log({'mnist_images': img_grid})
-
-    # custom optimizer from torch_optimizer package
-    if args.optimizer == optim.SGD:
-        config = optim.SGDConfig(lr=args.learning_rate)
-    elif args.optimizer == optim.Adam:
-        config = optim.AdamConfig(lr=args.learning_rate, betas=args.betas, eps=args.eps)
-    # config = config(lr=args.learning_rate)
-    optimizer = optimizer(model.parameters(), config)
-    # optimizer = optim(model.parameters(), lr=args.learning_rate)
+    # images, labels = next(iter(train_loader))
+    # img_grid = utils.make_grid(images)
+    # wandb.log({'mnist_images': img_grid})
 
     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
-
     for epoch in (pbar := tqdm(range(1, args.epochs + 1))):
         loss=train(args, model, device, train_loader, optimizer, epoch)
         tloss=test(model, device, test_loader)
@@ -227,12 +216,13 @@ def mnist_tester(optimizer=None, args = None, largs = None):
 
 
     # Initialising the optimiser
-    optimizer = args.optimizer
+    model = MNISTestNet().to(args.device)
+    # config = nd.AdadeltaConfig(lr = args.learning_rate, beta_1=args.betas[0], beta_2=args.betas[1])
+    optimizer = nd.Adadelta(model.parameters())
     # config = AutoConfig(args.params..)
     # optimizer = args.optimizer(config)
 
-    # Running the mnist_tester
-    mnist_tester(optimizer, args, largs)
+    mnist_tester(model, optimizer, args)
 
     run.finish()
\ No newline at end of file

From aa210f608fcc8895679437e00f671810616ce5bd Mon Sep 17 00:00:00 2001
From: bhavnicksm
Date: Wed, 8 Mar 2023 21:38:59 +0530
Subject: [PATCH 2/5] Change version info in __init__

---
 src/nadir/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/nadir/__init__.py b/src/nadir/__init__.py
index aeb70f2..48c13a0 100644
--- a/src/nadir/__init__.py
+++ b/src/nadir/__init__.py
@@ -26,7 +26,7 @@
 from .sgd import SGD, SGDConfig
 
 
-__version__ = "0.0.1"
+__version__ = "0.0.2"
 
 __all__ = ('Adadelta',
            'AdadeltaConfig',

From 5437967728d76baf91ea7963a9f506fd981db6da Mon Sep 17 00:00:00 2001
From: bhavnicksm
Date: Wed, 8 Mar 2023 21:39:26 +0530
Subject: [PATCH 3/5] Add AMSGrad and WeightDecay (AdamW style)

---
 src/nadir/base.py | 48 +++++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 37 insertions(+), 11 deletions(-)

diff --git a/src/nadir/base.py b/src/nadir/base.py
index 19a9cf7..d6dba6c 100644
--- a/src/nadir/base.py
+++ b/src/nadir/base.py
@@ -28,12 +28,14 @@ class BaseConfig:
     beta_1 : float = 0.0
     beta_2 : float = 0.0
     eps : float = 1E-8
+    weight_decay : float = 0.0
+    amsgrad : bool = False
 
     def dict(self):
         return self.__dict__
 
 
-class BaseOptimizer(Optimizer):
+class BaseOptimizer (Optimizer):
 
     def __init__ (self, params, config: BaseConfig = BaseConfig()):
         if not config.lr > 0.0:
@@ -53,18 +55,21 @@ def init_state(self,
                    state, 
                    group, 
                    param):
-        
+        state['step'] = 0
+        
         if self.config.momentum:
-            state['momentum_step'] = 0
             state['momentum'] = torch.zeros_like(param, memory_format=torch.preserve_format)
+        
         if self.config.adaptive:
-            state['adaptive_step'] = 0
             state['adaptivity'] = torch.zeros_like(param, memory_format=torch.preserve_format)
-        
+        
+        if self.config.amsgrad:
+            state['amsgrad'] = torch.zeros_like(param, memory_format=torch.preserve_format)
+        
 
     def momentum(self, 
                  state, 
                  grad):
-        step = state['momentum_step']
+        step = state['step']
         m = state['momentum']
         beta_1 = self.config.beta_1
@@ -72,15 +77,29 @@ def momentum(self,
         m_hat = m.div(1 - beta_1**(step + 1))
 
         state['momentum'] = m
-        state['momentum_step'] = step + 1
-
         return m_hat
-    
+
+    def amsgrad(adaptivity):
+
+        def __adaptivity__(self, state, grad):
+            u = adaptivity(self, state, grad)
+
+            if self.config.amsgrad:
+                v = state['amsgrad']
+                v = torch.max(v, u)
+                state['amsgrad'] = v
+                return v
+
+            return u
+
+        return __adaptivity__
+
+    @amsgrad
     def adaptivity(self, 
                    state, 
                    grad):
-        step = state['adaptive_step']
+        step = state['step']
         v = state['adaptivity']
         beta_2 = self.config.beta_2
@@ -88,7 +107,6 @@ def adaptivity(self,
         v_hat = v.div(1 - beta_2**(step + 1))
 
         state['adaptivity'] = v
-        state['step'] = step + 1
         return torch.sqrt(v_hat + self.config.eps)
 
     def update(self, 
@@ -111,6 +129,12 @@ def update(self,
         else:
             param.data.add_(upd, alpha = -1 * lr)
 
+        if self.config.weight_decay > 0:
+            param.data.add_(param.data, 
+                            alpha = -1 * lr * self.config.weight_decay)
+
+        state['step'] += 1
+
     @torch.no_grad()
     def step(self, closure = None):
         loss = None
@@ -129,5 +153,7 @@ def step(self, closure = None):
                 state = self.state[param]
                 if len(state) == 0:
                     self.init_state(state, group, param)
+
+                self.update(state, group, grad, param)
 
         return loss
\ No newline at end of file

From ba8af2c827d4c968dfff43a764779416a65ef36f Mon Sep 17 00:00:00 2001
From: bhavnicksm
Date: Wed, 8 Mar 2023 21:39:46 +0530
Subject: [PATCH 4/5] Add Lion Optimizer

---
 src/nadir/lion.py | 62 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)
 create mode 100644 src/nadir/lion.py

diff --git a/src/nadir/lion.py b/src/nadir/lion.py
new file mode 100644
index 0000000..526f5dd
--- /dev/null
+++ b/src/nadir/lion.py
@@ -0,0 +1,62 @@
+### Copyright 2023 [Dawn Of Eve]
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Any, Optional
+from dataclasses import dataclass
+
+import torch
+
+from .base import BaseOptimizer
+from .base import BaseConfig
+
+
+__all__ = ['LionConfig', 'Lion']
+
+@dataclass
+class LionConfig(BaseConfig):
+    lr : float = 1E-4
+    momentum : bool = True
+    beta_1 : float = 0.9
+    beta_2 : float = 0.99
+    eps : float = 1E-8
+    weight_decay : float = 0.
+
+
+class Lion(BaseOptimizer):
+
+    def __init__(self, params, config: LionConfig = LionConfig()):
+        if not config.momentum:
+            raise ValueError(f"Invalid value for momentum in config: {config.momentum} ",
+                             "Value must be True")
+        if not 1 > config.beta_1 > 0.:
+            raise ValueError(f"Invalid value for beta_1 in config: {config.beta_1} ",
+                             "Value must be between 1 and 0")
+        if not 1 > config.beta_2 > 0.:
+            raise ValueError(f"Invalid value for beta_2 in config: {config.beta_2} ",
+                             "Value must be between 1 and 0")
+        super().__init__(params, config)
+        self.config = config
+
+    def momentum(self, state, grad):
+        m = state['momentum']
+        beta_1 = self.config.beta_1
+        beta_2 = self.config.beta_2
+
+        u = m.mul(beta_1).add_(grad, alpha=(1-beta_1))
+
+        m.mul_(beta_2).add_(grad, alpha=(1-beta_2))
+
+        state['momentum'] = m
+
+        return torch.sign(u)
\ No newline at end of file

From ff0b86d5c8258863b260815e7b7d5dd931f54b1c Mon Sep 17 00:00:00 2001
From: bhavnicksm
Date: Wed, 8 Mar 2023 21:39:55 +0530
Subject: [PATCH 5/5] Remove MNIST data

---
 .gitignore | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 5f66781..dc3397e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,6 @@
 dist/*
 
 src/nadir.egg-info/*
-nadir.egg-info/*
\ No newline at end of file
+nadir.egg-info/*
+
+tests/data/*
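
Taken together, these patches change how the MNIST example builds its optimizer: the model is constructed first and passed into mnist_tester() along with an optimizer built from the nadir package, and Lion becomes available as a module of its own. The sketch below is a minimal, hypothetical usage example and is not part of the patch series; it assumes only what the diffs above show (Adadelta, SGD and their configs exported from the package root, Lion defined in src/nadir/lion.py and not re-exported from __init__.py in this series), and the learning-rate values are illustrative rather than taken from the patches.

import torch
import nadir as nd
# Lion is not added to nadir/__init__.py in this series, so import it from its module.
from nadir.lion import Lion, LionConfig

# Any nn.Module works; a single linear layer keeps the sketch small.
model = torch.nn.Linear(784, 10)

# Default-config construction, as tests/mnist.py does after patch 1/5.
optimizer = nd.Adadelta(model.parameters())

# Config-based construction, mirroring the commented-out AdadeltaConfig line in patch 1/5;
# the lr value here is illustrative.
sgd = nd.SGD(model.parameters(), nd.SGDConfig(lr=1e-3))

# Lion (patch 4/5) follows the same pattern; LionConfig defaults momentum=True,
# which Lion.__init__ requires.
lion = Lion(model.parameters(), LionConfig(lr=1e-4, beta_1=0.9, beta_2=0.99))

# A single training step looks the same for every optimizer above.
inputs, targets = torch.randn(8, 784), torch.randint(0, 10, (8,))
loss = torch.nn.functional.cross_entropy(model(inputs), targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()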