Skip to content

Commit

Permalink
Growing parameter capacity as training progress
Browse files Browse the repository at this point in the history
This is done through Optimizer. Two arguments are added for optimizer:

capacity_ratio: scheduler for controlling the number of training elements of a parameter.
min_capacity: minimal number elements of each parameter being traing

To dynamically change capacity, we assign a random number for each element of
the parameter. An element is turned on if its assigned random number
is less than capacity_ratio. To save memory, we don't store the
random numbers. Instead, we save the random number generator state.
  • Loading branch information
emailweixu committed Sep 21, 2023
1 parent a37049a commit ffc13c6
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
1 change: 1 addition & 0 deletions alf/examples/benchmarks/locomotion/locomotion_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import alf
from alf.utils.math_ops import clipped_exp
from alf.optimizers import AdamTF
from alf.utils.schedulers import LinearScheduler

alf.config(
"create_environment", num_parallel_environments=1, env_name="Ant-v3")
Expand Down
55 changes: 54 additions & 1 deletion alf/optimizers/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
import copy
import numpy as np
import torch
from typing import Callable
from typing import Callable, Union

import alf
from alf.utils import common
from alf.utils import tensor_utils
from alf.utils.schedulers import as_scheduler, ConstantScheduler, Scheduler
from . import adam_tf, adamw, nero_plus
from .utils import get_opt_arg

Expand Down Expand Up @@ -103,6 +104,8 @@ def __init__(self,
clip_by_global_norm=False,
parvi=None,
repulsive_weight=1.,
capacity_ratio: Union[float, Scheduler] = 1.0,
min_capacity: int = 8192,
name=None,
**kwargs):
"""
Expand Down Expand Up @@ -142,6 +145,14 @@ def __init__(self,
repulsive_weight (float): the weight of the repulsive gradient term
for parameters with attribute ``ensemble_group``.
capacity_ratio: For each parameter, `numel() * capacity_ratio`
elements are turned on for training. The remaining elements
are freezed. ``capacity_ratio`` can be a scheduler to control
the capacity over the training process.
min_capacity (int): For each parameter, at least so many elements
are turned on for training.
name (str): the name displayed when summarizing the gradient norm. If
None, then a global name in the format of "class_name_i" will be
created, where "i" is the global optimizer id.
Expand All @@ -165,6 +176,9 @@ def __init__(self,
self._gradient_clipping = gradient_clipping
self._clip_by_global_norm = clip_by_global_norm
self._parvi = parvi
self._capacity_ratio = alf.utils.schedulers.as_scheduler(
capacity_ratio)
self._min_capacity = min_capacity
self._norms = {} # norm of each parameter
if parvi is not None:
assert parvi in ['svgd', 'gfsf'
Expand Down Expand Up @@ -213,6 +227,32 @@ def step(self, closure=None):
if self._parvi is not None:
self._parvi_step()

capacity_ratio = self._capacity_ratio()
if capacity_ratio < 1:
# To achieve this, we assign a random number for each element of
# the parameter. An element is turned on if its assigned random number
# is less than capacity_ratio. To save memory, we don't store the
# random numbers. Instead, we save the random number generator state.
rng_state = torch.get_rng_state()
states = {}
for param_group in self.param_groups:
for p in param_group['params']:
state = self.state[p]
s = {}
if 'rng_state' not in state:
rng_state = torch.get_rng_state()
s['rng_state'] = rng_state
else:
rng_state = state['rng_state']
torch.set_rng_state(rng_state)
n = p.numel()
ratio = max(self._min_capacity / n, capacity_ratio)
mask = torch.rand_like(p) >= ratio
s['mask'] = mask
s['old_param'] = p.data.clone()
states[p] = s
torch.set_rng_state(rng_state)

super(NewCls, self).step(closure=closure)

if not isinstance(self, NeroPlus):
Expand All @@ -222,6 +262,19 @@ def step(self, closure=None):
param.data.mul_(
self._norms[param] / (param.norm() + 1e-30))

if capacity_ratio < 1:
for param_group in self.param_groups:
for p in param_group['params']:
state = self.state[p]
s = states[p]
# The following is faster than p.data[mask] = old_param[mask]
p.data.copy_(
torch.where(s['mask'], s['old_param'], p.data))
del s['mask']
del s['old_param']
if 'rng_state' in s:
state['rng_state'] = s['rng_state']

@common.add_method(NewCls)
def _parvi_step(self):
for param_group in self.param_groups:
Expand Down

0 comments on commit ffc13c6

Please sign in to comment.