Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Feat cailey sgd #1127

Draft
wants to merge 10 commits into
base: dev
Choose a base branch
from
Next Next commit
Cailey SGD
  • Loading branch information
pablomlago committed Nov 28, 2024
commit 21e5e00853e47bb66882b982ba19e0c445b0f684
197 changes: 197 additions & 0 deletions src/brevitas/optim/sgdg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a reference to the code origin

# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This code is originally from: https://github.com/JunLi-Galios/Optimization-on-Stiefel-Manifold-via-Cayley-Transform/blob/master/stiefel_optimizer.py

import random

import torch
from torch.optim.optimizer import Optimizer


def unit(v, dim: int = 1, eps: float = 1e-8):
vnorm = norm(v, dim)
return v / vnorm.add(eps), vnorm


def norm(v, dim: int = 1):
assert len(v.size()) == 2
return v.norm(p=2, dim=dim, keepdim=True)


def matrix_norm_one(W):
out = torch.abs(W)
out = torch.sum(out, dim=0)
out = torch.max(out)
return out


def Cayley_loop(X, W, tan_vec, t): #
[n, p] = X.size()
Y = X + t * tan_vec
for i in range(5):
Y = X + t * torch.matmul(W, 0.5 * (X + Y))

return Y.t()


def qr_retraction(tan_vec): # tan_vec, p-by-n, p <= n
[p, n] = tan_vec.size()
tan_vec.t_()
q, r = torch.linalg.qr(tan_vec)
d = torch.diag(r, 0)
ph = d.sign()
q *= ph.expand_as(q)
q.t_()

return q


episilon = 1e-8


class SGDG(Optimizer):
r"""This optimizer updates variables with two different routines
based on the boolean variable 'stiefel'.

If stiefel is True, the variables will be updated by SGD-G proposed
as decorrelated weight matrix.

If stiefel is False, the variables will be updated by SGD.
This routine was taken from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py.

Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups

-- common parameters
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
stiefel (bool, optional): whether to use SGD-G (default: False)

-- parameters in case stiefel is False
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)

-- parameters in case stiefel is True
omega (float, optional): orthogonality regularization factor (default: 0)
grad_clip (float, optional): threshold for gradient norm clipping (default: None)
"""

def __init__(
self,
params,
lr: float = 1e-3,
momentum: int = 0,
dampening: int = 0,
weight_decay: int = 0,
nesterov: bool = False,
stiefel: bool = False,
omega: int = 0,
grad_clip=None,
) -> None:
defaults = dict(
lr=lr,
momentum=momentum,
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov,
stiefel=stiefel,
omega=0,
grad_clip=grad_clip,
)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(SGDG, self).__init__(params, defaults)

def __setstate__(self, state) -> None:
super(SGDG, self).__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)

def step(self, closure=None):
"""Performs a single optimization step.

Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
momentum = group["momentum"]
stiefel = group["stiefel"]

for p in group["params"]:
if p.grad is None:
continue

unity, _ = unit(p.data.view(p.size()[0], -1))
if stiefel and unity.size()[0] <= unity.size()[1]:
weight_decay = group["weight_decay"]
dampening = group["dampening"]
nesterov = group["nesterov"]

rand_num = random.randint(1, 101)
if rand_num == 1:
unity = qr_retraction(unity)

g = p.grad.data.view(p.size()[0], -1)

lr = group["lr"]

param_state = self.state[p]
if "momentum_buffer" not in param_state:
param_state["momentum_buffer"] = torch.zeros(g.t().size())
if p.is_cuda:
param_state["momentum_buffer"] = param_state["momentum_buffer"].cuda()

V = param_state["momentum_buffer"]
V = momentum * V - g.t()
MX = torch.mm(V, unity)
XMX = torch.mm(unity, MX)
XXMX = torch.mm(unity.t(), XMX)
W_hat = MX - 0.5 * XXMX
W = W_hat - W_hat.t()
t = 0.5 * 2 / (matrix_norm_one(W) + episilon)
alpha = min(t, lr)

p_new = Cayley_loop(unity.t(), W, V, alpha)
V_new = torch.mm(W, unity.t()) # n-by-p
# check_identity(p_new.t())
p.data.copy_(p_new.view(p.size()))
V.copy_(V_new)

else:
d_p = p.grad.data
# defined.
try:
if weight_decay != 0:
# defined.
d_p.add_(weight_decay, p.data)
except:
pass
if momentum != 0:
param_state = self.state[p]
if "momentum_buffer" not in param_state:
buf = param_state["momentum_buffer"] = d_p.clone()
else:
buf = param_state["momentum_buffer"]
# always defined.
buf.mul_(momentum).add_(1 - dampening, d_p)
# defined.
if nesterov:
d_p = d_p.add(momentum, buf)
else:
d_p = buf

p.data.add_(-group["lr"], d_p)

return loss
128 changes: 128 additions & 0 deletions tests/brevitas/optim/test_cailey_sgd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""
Copyright (C) 2024, Advanced Micro Devices, Inc.
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)

All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU,
NEC Laboratories America and IDIAP Research Institute nor the names
of its contributors may be used to endorse or promote products derived
from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
"""

from copy import deepcopy
from itertools import product
import math
import sys
from typing import List, Union
import unittest

from hypothesis import given
import numpy as np
import pytest
import pytest_cases
from pytest_cases import fixture
from scipy.stats import ortho_group
import torch
from torch.nn import Parameter
import torch.nn as nn
from torch.optim.lr_scheduler import LinearLR

from brevitas.optim.sgdg import SGDG
from tests.conftest import SEED

torch.manual_seed(SEED)

from torch.testing._internal.common_optimizers import OptimizerInput

OPTIMIZER_KWARGS = [{
"stiefel": True}, {
"stiefel": True, "lr": 1e-2}, {
"stiefel": True, "lr": torch.tensor(0.001)}]
LR_SCHEDULER_ARGS = [
None,
(LinearLR, {
"start_factor": 1.0, "end_factor": 0.0, "total_iters": 20}),]
DEVICES = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
DTYPES = [torch.float32]

device_dtype_parametrize = pytest_cases.parametrize("device, dtype", list(product(DEVICES, DTYPES)))


class TestCaileySGD:

@device_dtype_parametrize
@pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS)
@pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS)
def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args):
optim_cls = SGDG
# Generate a random orthogonal matrix of size NxN. Columns represent orthonormal vector in R^{N}
N = 5
P = 3
weight_orthogonal = ortho_group(dim=N, seed=SEED).rvs()
weight_orthonormal = weight_orthogonal / np.linalg.norm(weight_orthogonal, ord=2, axis=0)
# Verify that the matrix is orthonormal
assert np.allclose(np.matmul(weight_orthonormal.T, weight_orthonormal), np.eye(N))
# Initialize weights, the Cailey SGD optimizer expects a matrix of size PxN, given the
# condition unity.size()[0] <= unity.size()[1]
weight = Parameter(
torch.from_numpy(weight_orthonormal[:, :P].T).to(device=device, dtype=dtype))

optimizer = optim_cls([weight], **deepcopy(optimizer_kwargs))
scheduler = None if lr_scheduler_args is None else lr_scheduler_args[0](
optimizer, **lr_scheduler_args[1])

def closure():
optimizer.zero_grad()
loss = (weight - torch.eye(N, P, device=device, dtype=dtype).t()).pow(2).sum()
loss.backward()
return loss

initial_value = closure().item()
for _ in range(20):
closure()
optimizer.step()
if scheduler is not None:
scheduler.step()

# Verify that iterates stay within the Stiefel manifold
assert torch.allclose(
weight.detach().cpu() @ weight.detach().cpu().t(),
torch.eye(P, P, device=device, dtype=dtype).detach().cpu(),
atol=1e-5,
rtol=1e-6)

if optimizer_kwargs.get("maximize", False):
assert closure().item() > initial_value
else:
assert closure().item() < initial_value