Merge pull request #1110 from daoducanhc/chris-oct-11th
Add the implementation of MSSGD for model selection
lzjpaul authored Oct 11, 2023
2 parents 9d8a2f5 + 063effd commit 830e1f7
Showing 1 changed file with 98 additions and 0 deletions.
examples/model_selection_psql/ms_mlp/train_ms_model.py (98 additions, 0 deletions)
@@ -61,6 +61,104 @@ def call_with_returns(self, loss):
        # print ("call_with_returns after apply loss.data: \n", loss.data)
        return pn_p_g_list

# MSSGD -- subclass of MSOptimizer
class MSSGD(MSOptimizer):
    """Implements stochastic gradient descent (optionally with momentum).

    Nesterov momentum is based on the formula from `On the importance of
    initialization and momentum in deep learning`__.

    Args:
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Typical usage example:
        >>> from singa import opt
        >>> optimizer = opt.SGD(lr=0.1, momentum=0.9)
        >>> optimizer.update()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of momentum, the update can be written as

        .. math::
            v = \rho * v + g \\
            p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and other frameworks, which
        employ an update of the form

        .. math::
            v = \rho * v + lr * g \\
            p = p - v

        The Nesterov version is analogously modified.
    """

    def __init__(self,
                 lr=0.1,
                 momentum=0,
                 dampening=0,
                 weight_decay=0,
                 nesterov=False,
                 dtype=tensor.float32):
        super(MSSGD, self).__init__(lr)

        # init momentum
        if type(momentum) == float or type(momentum) == int:
            if momentum < 0.0:
                raise ValueError("Invalid momentum value: {}".format(momentum))
            self.momentum = Constant(momentum)
        elif isinstance(momentum, DecayScheduler):
            self.momentum = momentum
            momentum = momentum.init_value
        else:
            raise TypeError("Wrong momentum type")
        # self.dtype = dtype
        # self.mom_value = self.momentum(self.step_counter).as_type(self.dtype)
        self.mom_value = self.momentum(self.step_counter)

        # init dampening
        if type(dampening) == float or type(dampening) == int:
            self.dampening = Constant(dampening)
        elif isinstance(dampening, DecayScheduler):
            self.dampening = dampening
            dampening = dampening.init_value
        else:
            raise TypeError("Wrong dampening type")
        # self.dam_value = self.dampening(self.step_counter).as_type(self.dtype)
        self.dam_value = self.dampening(self.step_counter)

        # init weight_decay
        if type(weight_decay) == float or type(weight_decay) == int:
            if weight_decay < 0.0:
                raise ValueError(
                    "Invalid weight_decay value: {}".format(weight_decay))
            self.weight_decay = Constant(weight_decay)
        elif isinstance(weight_decay, DecayScheduler):
            self.weight_decay = weight_decay
        else:
            raise TypeError("Wrong weight_decay type")
        # self.decay_value = self.weight_decay(self.step_counter).as_type(self.dtype)
        self.decay_value = self.weight_decay(self.step_counter)

        # init other params
        self.nesterov = nesterov
        self.moments = dict()

        # check value
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")

# Data augmentation
def augmentation(x, batch_size):
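The constructor above normalizes every hyperparameter into a scheduler object that is re-evaluated against self.step_counter on each step, so a bare float or int behaves as a constant schedule while a DecayScheduler can vary over time. Below is a minimal self-contained sketch of that pattern; the Constant and ExponentialDecay classes here are simplified stand-ins written for illustration, not SINGA's actual implementations.

# Sketch of the "hyperparameter as a function of the step counter" pattern
# used by MSSGD.__init__ (stand-in classes with assumed semantics).

class Constant:
    """Returns the same value at every step."""

    def __init__(self, value):
        self.init_value = value

    def __call__(self, step):
        return self.init_value


class ExponentialDecay:
    """Toy decay schedule: init_value * rate**step."""

    def __init__(self, init_value, rate):
        self.init_value = init_value
        self.rate = rate

    def __call__(self, step):
        return self.init_value * self.rate**step


def normalize(hp):
    # Mirrors the constructor's branches: wrap numbers, pass schedulers through.
    if type(hp) == float or type(hp) == int:
        return Constant(hp)
    return hp


momentum = normalize(0.9)                    # float -> constant schedule
lr = normalize(ExponentialDecay(0.1, 0.99))  # scheduler passes through

for step in range(3):
    # Analogous to `self.mom_value = self.momentum(self.step_counter)`.
    print(step, momentum(step), lr(step))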
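Mirroring the docstring's usage example, constructing the new optimizer could look like the sketch below; the hyperparameter values are arbitrary, and the error message is the one raised by the check at the end of __init__.

# Hypothetical usage of MSSGD (defined above in train_ms_model.py).
mssgd = MSSGD(lr=0.005, momentum=0.9, weight_decay=1e-5)

# Invalid combinations fail fast, e.g. Nesterov with the default momentum=0:
try:
    MSSGD(lr=0.1, nesterov=True)
except ValueError as err:
    print(err)  # Nesterov momentum requires a momentum and zero dampening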
