Growing parameter capacity as training progresses (#1539)
This is done through the Optimizer. Two arguments are added to the optimizer:

capacity_ratio: a scheduler (or a constant) controlling the fraction of each parameter's elements that are trained.
min_capacity: the minimal number of elements of each parameter that are trained.

To dynamically change capacity, we assign a random number to each element of
a parameter. An element is turned on for training if its assigned random number
is less than capacity_ratio. To save memory, we don't store the random numbers;
instead, we save the state of the random number generator and regenerate them when needed.
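
As a rough sketch of this mechanism (a hypothetical standalone helper, not the ALF implementation; note it uses the opposite mask polarity from the committed code, where the mask marks the frozen elements): the per-parameter RNG state is saved once, the same random numbers are regenerated from it on every step, and the set of trainable elements therefore only grows as capacity_ratio increases.

    import torch

    def trainable_mask(p: torch.Tensor, state: dict, capacity_ratio: float,
                       min_capacity: int = 8192) -> torch.Tensor:
        # Save this parameter's generator state the first time it is seen.
        if 'rng_state' not in state:
            state['rng_state'] = torch.get_rng_state()
        saved = torch.get_rng_state()  # keep global randomness undisturbed
        torch.set_rng_state(state['rng_state'])
        # Keep at least min_capacity trainable elements (in expectation).
        ratio = max(min_capacity / p.numel(), capacity_ratio)
        mask = torch.rand_like(p) < ratio  # True => element is trained
        torch.set_rng_state(saved)
        return mask

    p = torch.randn(1000, 1000)
    state = {}
    m1 = trainable_mask(p, state, capacity_ratio=0.1)
    m2 = trainable_mask(p, state, capacity_ratio=0.5)
    assert bool(((m1 | m2) == m2).all())  # the trainable set only grows

    # Freezing then amounts to restoring the old values of the masked-off
    # elements after the underlying optimizer step:
    old = p.data.clone()
    p.data.add_(torch.randn_like(p))             # stand-in for an optimizer update
    p.data.copy_(torch.where(~m1, old, p.data))  # frozen elements keep old values
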
emailweixu authored Sep 22, 2023
1 parent 3940def commit 61f1565
Showing 1 changed file with 54 additions and 1 deletion.
alf/optimizers/optimizers.py (54 additions, 1 deletion)
@@ -15,11 +15,12 @@
import copy
import numpy as np
import torch
from typing import Callable
from typing import Callable, Union

import alf
from alf.utils import common
from alf.utils import tensor_utils
from alf.utils.schedulers import as_scheduler, ConstantScheduler, Scheduler
from . import adam_tf, adamw, nero_plus
from .utils import get_opt_arg

@@ -103,6 +104,8 @@ def __init__(self,
                 clip_by_global_norm=False,
                 parvi=None,
                 repulsive_weight=1.,
                 capacity_ratio: Union[float, Scheduler] = 1.0,
                 min_capacity: int = 8192,
                 name=None,
                 **kwargs):
        """
@@ -142,6 +145,14 @@ def __init__(self,
            repulsive_weight (float): the weight of the repulsive gradient term
                for parameters with attribute ``ensemble_group``.
            capacity_ratio: For each parameter, ``numel() * capacity_ratio``
                elements are turned on for training. The remaining elements
                are frozen. ``capacity_ratio`` can be a scheduler to control
                the capacity over the training process.
            min_capacity (int): For each parameter, at least this many elements
                are turned on for training.
            name (str): the name displayed when summarizing the gradient norm. If
                None, then a global name in the format of "class_name_i" will be
                created, where "i" is the global optimizer id.
@@ -165,6 +176,9 @@ def __init__(self,
        self._gradient_clipping = gradient_clipping
        self._clip_by_global_norm = clip_by_global_norm
        self._parvi = parvi
        self._capacity_ratio = alf.utils.schedulers.as_scheduler(
            capacity_ratio)
        self._min_capacity = min_capacity
        self._norms = {}  # norm of each parameter
        if parvi is not None:
            assert parvi in ['svgd', 'gfsf'
@@ -213,6 +227,32 @@ def step(self, closure=None):
        if self._parvi is not None:
            self._parvi_step()

        capacity_ratio = self._capacity_ratio()
        if capacity_ratio < 1:
            # To achieve this, we assign a random number to each element of
            # the parameter. An element is turned on if its assigned random
            # number is less than capacity_ratio. To save memory, we don't
            # store the random numbers. Instead, we save the random number
            # generator state.
            rng_state = torch.get_rng_state()
            states = {}
            for param_group in self.param_groups:
                for p in param_group['params']:
                    state = self.state[p]
                    s = {}
                    if 'rng_state' not in state:
                        rng_state = torch.get_rng_state()
                        s['rng_state'] = rng_state
                    else:
                        rng_state = state['rng_state']
                    torch.set_rng_state(rng_state)
                    n = p.numel()
                    ratio = max(self._min_capacity / n, capacity_ratio)
                    mask = torch.rand_like(p) >= ratio  # True => frozen
                    s['mask'] = mask
                    s['old_param'] = p.data.clone()
                    states[p] = s
            torch.set_rng_state(rng_state)

        super(NewCls, self).step(closure=closure)

        if not isinstance(self, NeroPlus):
@@ -222,6 +262,19 @@ def step(self, closure=None):
                    param.data.mul_(
                        self._norms[param] / (param.norm() + 1e-30))

        if capacity_ratio < 1:
            for param_group in self.param_groups:
                for p in param_group['params']:
                    state = self.state[p]
                    s = states[p]
                    # The following is faster than p.data[mask] = old_param[mask]
                    p.data.copy_(
                        torch.where(s['mask'], s['old_param'], p.data))
                    del s['mask']
                    del s['old_param']
                    if 'rng_state' in s:
                        state['rng_state'] = s['rng_state']

    @common.add_method(NewCls)
    def _parvi_step(self):
        for param_group in self.param_groups:
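
For context, a hypothetical configuration sketch of how the new arguments might be used. The `capacity_ratio` and `min_capacity` names come from this commit, but the optimizer class and the `LinearScheduler` constructor signature are assumptions; check alf/optimizers and alf/utils/schedulers.py for the exact API.

    import alf
    from alf.utils import schedulers

    # Assumed: alf.optimizers.Adam is the wrapped torch Adam and LinearScheduler
    # takes (progress_type, schedule) with (progress, value) pairs; verify
    # against alf/utils/schedulers.py before use.
    optimizer = alf.optimizers.Adam(
        lr=1e-3,
        # Start by training 20% of each parameter's elements and ramp up to
        # 100% over the course of training.
        capacity_ratio=schedulers.LinearScheduler(
            progress_type="percent", schedule=[(0.0, 0.2), (1.0, 1.0)]),
        # Parameters with at most 8192 elements are always fully trained.
        min_capacity=8192)

Because `capacity_ratio` is passed through `as_scheduler`, a plain float (e.g. `capacity_ratio=0.5`) also works and keeps the capacity fixed throughout training.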
