From 9aa94dd9f362c74c1d56807579d387dfd6ced234 Mon Sep 17 00:00:00 2001 From: Wei Xu Date: Tue, 8 Aug 2023 20:36:09 -0700 Subject: [PATCH] Growing parameter capacity as training progress This is done through Optimizer. Two arguments are added for optimizer: capacity_ratio: scheduler for controlling the number of training elements of a parameter. min_capacity: minimal number elements of each parameter being traing To dynamically change capacity, we assign a random number for each element of the parameter. An element is turned on if its assigned random number is less than capacity_ratio. To save memory, we don't store the random numbers. Instead, we save the random number generator state. --- alf/optimizers/optimizers.py | 55 +++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/alf/optimizers/optimizers.py b/alf/optimizers/optimizers.py index e1974dcf6..8c9075da9 100644 --- a/alf/optimizers/optimizers.py +++ b/alf/optimizers/optimizers.py @@ -15,11 +15,12 @@ import copy import numpy as np import torch -from typing import Callable +from typing import Callable, Union import alf from alf.utils import common from alf.utils import tensor_utils +from alf.utils.schedulers import as_scheduler, ConstantScheduler, Scheduler from . import adam_tf, adamw, nero_plus from .utils import get_opt_arg @@ -103,6 +104,8 @@ def __init__(self, clip_by_global_norm=False, parvi=None, repulsive_weight=1., + capacity_ratio: Union[float, Scheduler] = 1.0, + min_capacity: int = 8192, name=None, **kwargs): """ @@ -142,6 +145,14 @@ def __init__(self, repulsive_weight (float): the weight of the repulsive gradient term for parameters with attribute ``ensemble_group``. + + capacity_ratio: For each parameter, `numel() * capacity_ratio` + elements are turned on for training. The remaining elements + are frozen. ``capacity_ratio`` can be a scheduler to control + the capacity over the training process. + min_capacity (int): For each parameter, at least so many elements + are turned on for training. + name (str): the name displayed when summarizing the gradient norm. If None, then a global name in the format of "class_name_i" will be created, where "i" is the global optimizer id. @@ -165,6 +176,9 @@ def __init__(self, self._gradient_clipping = gradient_clipping self._clip_by_global_norm = clip_by_global_norm self._parvi = parvi + self._capacity_ratio = alf.utils.schedulers.as_scheduler( + capacity_ratio) + self._min_capacity = min_capacity self._norms = {} # norm of each parameter if parvi is not None: assert parvi in ['svgd', 'gfsf' @@ -213,6 +227,32 @@ def step(self, closure=None): if self._parvi is not None: self._parvi_step() + capacity_ratio = self._capacity_ratio() + if capacity_ratio < 1: + # To achieve this, we assign a random number for each element of + # the parameter. An element is turned on if its assigned random number + # is less than capacity_ratio. To save memory, we don't store the + # random numbers. Instead, we save the random number generator state. + rng_state = torch.get_rng_state() + states = {} + for param_group in self.param_groups: + for p in param_group['params']: + state = self.state[p] + s = {} + if 'rng_state' not in state: + rng_state = torch.get_rng_state() + s['rng_state'] = rng_state + else: + rng_state = state['rng_state'] + torch.set_rng_state(rng_state) + n = p.numel() + ratio = max(self._min_capacity / n, capacity_ratio) + mask = torch.rand_like(p) >= ratio + s['mask'] = mask + s['old_param'] = p.data.clone() + states[p] = s + torch.set_rng_state(rng_state) + super(NewCls, self).step(closure=closure) if not isinstance(self, NeroPlus): @@ -222,6 +262,19 @@ def step(self, closure=None): param.data.mul_( self._norms[param] / (param.norm() + 1e-30)) + if capacity_ratio < 1: + for param_group in self.param_groups: + for p in param_group['params']: + state = self.state[p] + s = states[p] + # The following is faster than p.data[mask] = old_param[mask] + p.data.copy_( + torch.where(s['mask'], s['old_param'], p.data)) + del s['mask'] + del s['old_param'] + if 'rng_state' in s: + state['rng_state'] = s['rng_state'] + @common.add_method(NewCls) def _parvi_step(self): for param_group in self.param_groups: