Growing parameter capacity as training progress #1539

Merged · 1 commit · Sep 22, 2023
alf/optimizers/optimizers.py: 54 additions & 1 deletion
@@ -15,11 +15,12 @@
import copy
import numpy as np
import torch
from typing import Callable
from typing import Callable, Union

import alf
from alf.utils import common
from alf.utils import tensor_utils
from alf.utils.schedulers import as_scheduler, ConstantScheduler, Scheduler
from . import adam_tf, adamw, nero_plus
from .utils import get_opt_arg

@@ -103,6 +104,8 @@ def __init__(self,
clip_by_global_norm=False,
parvi=None,
repulsive_weight=1.,
capacity_ratio: Union[float, Scheduler] = 1.0,
min_capacity: int = 8192,
name=None,
**kwargs):
"""
@@ -142,6 +145,14 @@ def __init__(self,

repulsive_weight (float): the weight of the repulsive gradient term
for parameters with attribute ``ensemble_group``.

capacity_ratio (float | Scheduler): for each parameter, ``numel() * capacity_ratio``
elements are turned on for training; the remaining elements are
frozen. ``capacity_ratio`` can be a scheduler so that the capacity
grows over the course of training.
min_capacity (int): for each parameter, at least this many elements
are turned on for training.

name (str): the name displayed when summarizing the gradient norm. If
None, then a global name in the format of "class_name_i" will be
created, where "i" is the global optimizer id.
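As a note on the two new arguments documented above: the effective per-parameter ratio is the larger of ``capacity_ratio`` and ``min_capacity / numel()``, the same quantity computed later in ``step()``. A minimal sketch of that interaction (the helper ``trainable_elements`` is made up for illustration and is not part of this PR):

import torch

def trainable_elements(p: torch.Tensor, capacity_ratio: float,
                       min_capacity: int = 8192) -> int:
    """Rough count of elements left trainable for parameter ``p``."""
    n = p.numel()
    # min_capacity raises the ratio for small parameters; cap at 1.0.
    ratio = min(1.0, max(min_capacity / n, capacity_ratio))
    return round(n * ratio)

p = torch.zeros(1000, 1000)                # 1,000,000 elements
print(trainable_elements(p, 0.25))         # 250000: capacity_ratio dominates
print(trainable_elements(p, 0.001))        # 8192: min_capacity dominates
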
@@ -165,6 +176,9 @@ def __init__(self,
self._gradient_clipping = gradient_clipping
self._clip_by_global_norm = clip_by_global_norm
self._parvi = parvi
self._capacity_ratio = as_scheduler(capacity_ratio)
self._min_capacity = min_capacity
self._norms = {} # norm of each parameter
if parvi is not None:
assert parvi in ['svgd', 'gfsf'
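For reference, ``as_scheduler`` lets ``capacity_ratio`` be either a plain float or a ``Scheduler``. Conceptually it behaves roughly like the sketch below; this is an illustration of the expected behavior, not ALF's actual implementation:

from typing import Callable, Union

def as_scheduler_sketch(value: Union[float, Callable[[], float]]) -> Callable[[], float]:
    # A Scheduler is callable with no arguments and returns the current value;
    # a bare float is wrapped into a constant callable.
    if callable(value):
        return value
    return lambda: float(value)

ratio_fn = as_scheduler_sketch(0.5)
print(ratio_fn())   # 0.5, queried once per step as self._capacity_ratio()
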
@@ -213,6 +227,32 @@ def step(self, closure=None):
if self._parvi is not None:
self._parvi_step()

capacity_ratio = self._capacity_ratio()
if capacity_ratio < 1:
# To achieve this, we assign a random number for each element of
# the parameter. An element is turned on if its assigned random number
# is less than capacity_ratio. To save memory, we don't store the
# random numbers. Instead, we save the random number generator state.
# Save the global RNG state so that generating the masks below does
# not perturb the rest of the training's random number stream.
saved_rng_state = torch.get_rng_state()
states = {}
for param_group in self.param_groups:
for p in param_group['params']:
state = self.state[p]
s = {}
if 'rng_state' not in state:
rng_state = torch.get_rng_state()
s['rng_state'] = rng_state
else:
rng_state = state['rng_state']
torch.set_rng_state(rng_state)
n = p.numel()
ratio = max(self._min_capacity / n, capacity_ratio)
mask = torch.rand_like(p) >= ratio
s['mask'] = mask
s['old_param'] = p.data.clone()
states[p] = s
torch.set_rng_state(saved_rng_state)

super(NewCls, self).step(closure=closure)

if not isinstance(self, NeroPlus):
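The key trick in the block above is that the per-element masks are never stored; only the RNG state is. Restoring that state before ``torch.rand_like`` reproduces the identical mask at every subsequent ``step()``. A self-contained sketch of that property (CPU tensors, default generator):

import torch

p = torch.randn(4, 4)
saved_state = torch.get_rng_state()

torch.set_rng_state(saved_state)
mask_a = torch.rand_like(p) >= 0.5   # elements to freeze on step 1

torch.set_rng_state(saved_state)
mask_b = torch.rand_like(p) >= 0.5   # regenerated on step 2

assert torch.equal(mask_a, mask_b)   # the same elements stay frozen
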
@@ -222,6 +262,19 @@ def step(self, closure=None):
param.data.mul_(
self._norms[param] / (param.norm() + 1e-30))

if capacity_ratio < 1:
for param_group in self.param_groups:
for p in param_group['params']:
state = self.state[p]
s = states[p]
# The following is faster than p.data[mask] = old_param[mask]
p.data.copy_(
torch.where(s['mask'], s['old_param'], p.data))
del s['mask']
del s['old_param']
if 'rng_state' in s:
state['rng_state'] = s['rng_state']

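Putting the two halves of ``step()`` together, the capacity limiting amounts to: draw a freeze mask, let the wrapped optimizer update everything, then write the old values back into the frozen positions. A hedged sketch with plain SGD standing in for the wrapped optimizer (``capacity_limited_sgd_step`` is illustrative only, not part of this PR):

import torch

def capacity_limited_sgd_step(p: torch.Tensor, lr: float,
                              capacity_ratio: float,
                              min_capacity: int = 8192) -> None:
    n = p.numel()
    ratio = max(min_capacity / n, capacity_ratio)
    mask = torch.rand_like(p) >= ratio      # True -> frozen element
    old_param = p.data.clone()
    p.data.add_(p.grad, alpha=-lr)          # the "real" optimizer update
    # Undo the update on frozen elements. torch.where avoids the slower
    # boolean-indexing assignment p.data[mask] = old_param[mask].
    p.data.copy_(torch.where(mask, old_param, p.data))

p = torch.randn(1000, 1000, requires_grad=True)
p.grad = torch.ones_like(p)
capacity_limited_sgd_step(p, lr=0.1, capacity_ratio=0.3)

Note that because the restore happens after ``super().step()``, the wrapped optimizer's internal statistics (e.g. Adam moments) are still updated for frozen elements; only the parameter values themselves are held fixed.
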
@common.add_method(NewCls)
def _parvi_step(self):
for param_group in self.param_groups: