Instability with HyperProp #2

hadaev8 opened this issue Aug 4, 2020 · 39 comments

hadaev8 commented Aug 4, 2020

I tried it on two tasks but got NaNs during training. Any suggestions?

JRC1995 commented Aug 7, 2020

I am not sure. I don't remember testing the LaProp-based optimizers all that much. Is there any instability with the other optimizers that use hypergradient descent? There could still be a bug in the repo too. Most of the optimizers here are a mishmash of multiple techniques, so it's also possible that some of these combinations simply don't work out, or even lead to mathematical issues I have missed; if you aren't using something standard, the problem may be with the combination idea itself. As for me, I think I mostly did some sanity runs on hypergrad in the first few epochs on a more basic optimizer. Nothing seemed obviously wrong from the changes in the lr, but it's hard to say just from that. My implementation of hypergrad is also a bit different from the original paper's. Not sure if that causes any instability.

hadaev8 commented Aug 7, 2020

I tested HyperProp on 4 tasks. The first two are fine (nvidia waveglow and resnet image classification), one crashes when I use the weight decay param (if I remember the case right; maybe it just crashes) (a fork of nvidia tacotron2), and the last task is a transformer image GAN.

I also tested HDQHSGDW on the last task and it seems to be fine.
If you have time for debugging, I can share the code and data of the last case to reproduce the issue (it's a colab notebook, so it's easy to share and run).

What optimizer would you recommend as a default no-brainer drop-in replacement?
I took HyperProp because it's last in the file, so I assumed it's the most advanced.

JRC1995 commented Aug 7, 2020

It depends on what you are trying to do. Research? Research on optimizers, specifically? Just applications? Or simply playing around with optimizers?

I haven't personally tried the optimizers exhaustively, so I can't really tell which synergistic combination is good to go. The whole hyperxxx series of optimizers may not have been tried by anyone, especially given that my implementation differs from the original hypergradient descent; plus, I think I also added some extra gradient computation to handle weight decay and other things.

Generally, it's probably just better to stick with AdamW. This repo is mostly experimental and more for playing around.
With AdamW you can add Lookahead, or try the RAdam + Lookahead combination (some call it Ranger). That should be decent. QHRAdam + Lookahead should be decent too. Some have reported good results with AdaMod, so AdaMod + Lookahead may be a good choice as well.

This should recover AdaMod + Lookahead:

optimizer = DemonRanger(params=model.parameters(),
                        lr=config.lr,
                        weight_decay=config.wd,
                        betas=(0.9,0.999,0.999), # restore default AdamW betas
                        nus=(1.0,1.0), # disables QHMomentum
                        IA=False, # disables Iterate Averaging
                        rectify=False, # disables RAdam Rectification
                        AdaMod_bias_correct=False, # disables AdaMod bias correction (not used originally)
                        use_demon=False, # disables Decaying Momentum (DEMON)
                        use_gc=False, # disables gradient centralization
                        amsgrad=False # disables amsgrad
                        )

Hypergradient descent is interesting conceptually. You can try it and compare. The original paper claims that even not-much-tuned hypergradient descent on the lr is better than tuning the lr in Adam and so on, but I don't think it was extensively tested on multiple tasks. Also, the gradients for hypergradient descent get more complicated with AdaMod I think, and there could be some bugs in my gradient computation in that case; it's probably slow too. If you want to try that you probably have to look at HyperRangerMod.
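
For reference, the basic idea is small; here is a rough sketch of plain hypergradient descent on SGD's lr, following the original paper rather than this repo's exact code (the names are just illustrative):

def sgd_hd_step(params, prev_grads, lr, hyper_lr):
    # For SGD, d(theta_t)/d(lr) = -prev_grad, so the gradient step on the lr is
    # lr <- lr + hyper_lr * dot(current grad, previous grad).
    hypergrad = sum((p.grad * g_prev).sum() for p, g_prev in zip(params, prev_grads))
    lr = lr + hyper_lr * hypergrad.item()
    for p, g_prev in zip(params, prev_grads):
        g_prev.copy_(p.grad)        # remember this step's gradient for the next lr update
        p.data.add_(-lr * p.grad)   # usual SGD step with the adapted lr
    return lr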

hadaev8 commented Aug 8, 2020

Production/hobby.
For now I'm on my own modification of Ranger + Lookahead + gc + diffgrad. But you did much more extensive work.
I did not find the computation overhead very significant (around +10%).
https://i.imgur.com/z9PD2br.png
Very noisy graph, but it seems HyperProp (red) converges much faster.

hadaev8 commented Sep 4, 2020

I ran more tests with all the optimizers (and different params) on my TTS task, and for some reason they performed much worse than my current optimizer.
Orange is my version of Ranger, dark green is default AdamW.
[chart]
Also, I noticed a NaN loss on validation with HyperProp + demon.

Any suggestions on how to track down the reason for this behavior?

JRC1995 commented Sep 4, 2020

Hmm...not sure.

saurabh-kataria commented:

I had a similar experience in my application: speech enhancement. All optimizers gave much worse performance except DemonRanger, which matched the performance of my default optimizers (Adam, RAdam, Ranger, AdamW), although I did not play around with the optimizers' parameters much.
I did not experience NaNs in HyperProp BTW. Also, IWA (Iterative Weight Averaging) did not make any difference. I'm wondering if it is similar to SWA (Stochastic Weight Averaging), which does help me BTW.

JRC1995 commented Sep 6, 2020

@saurabh-kataria @hadaev8 also, when you say all optimizers gave much worse performance, do you mean that even the recovered AdamW optimizer (from the Readme) of this repo gives worse performance than, say, your default AdamW?

hadaev8 commented Sep 6, 2020

For now, I have only tried all the optimizers with 2-3 features disabled.
I will try to do it in reverse: start from AdamW and add features.
Also, your AdamW implementation is not the same as vanilla pytorch's; is that intended?

JRC1995 commented Sep 6, 2020

@hadaev8 depends on what you mean by "not the same". On the surface, my implementation is a more general optimizer from which AdamW can be recovered with specific hyperparameters, so from that perspective it is of course different from vanilla pytorch. However, if you notice some significant difference in the core AdamW-related operations within the source code, let me know. There shouldn't be any intended difference in the operations at the semantic level once the same hyperparameters are used. Though I don't remember if I ever referred to the vanilla pytorch AdamW; I think I looked at some other repo for reference.

hadaev8 commented Sep 6, 2020

JRC1995 commented Sep 6, 2020

@hadaev8
It should be the same. We need weight - lr*weight_decay*weight.
Both p.mul_(1 - group['lr'] * group['weight_decay']) (vanilla pytorch, i.e. weight*(1 - lr*weight_decay)) and p.add_(-group['lr'] * group['weight_decay'], p) (what I am doing, i.e. weight - lr*weight_decay*weight) should result in the same thing.
You can verify it using:

import torch as T

x = T.tensor(10.0)
print(x.add_(-0.1, x))
x = T.tensor(10.0)
print(x.mul_(1 - 0.1))

saurabh-kataria commented Sep 6, 2020

I meant I compared with pytorch internal implementations of Adam, AdamW, etc. For Ranger, I meant comparison with: https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
I didn't experiment with "recovered versions of Adam etc."

hadaev8 commented Sep 7, 2020

@JRC1995
Yes, you are right; still, I prefer how vanilla pytorch looks.
I'm running tests now.
Adam extracted from DemonRanger seems to be on par with vanilla.
QHMomentum + AdaMod causes NaN on validation.

Validation loss 52480:       nan  
Traceback (most recent call last):
  File "train.py", line 487, in <module>
    args.n_gpus, args.local_rank, args.group_name, hparams)
  File "train.py", line 443, in train
    hparams.distributed_run, rank, criterion)
  File "train.py", line 164, in validate
    logger.log_model(model, iteration)
  File "/content/tacotron2/utils/logger.py", line 15, in log_model
    self.add_histogram(tag, value.data.cpu().numpy(), iteration)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/tensorboard/writer.py", line 425, in add_histogram
    histogram(tag, values, bins, max_bins=max_bins), global_step, walltime)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/tensorboard/summary.py", line 234, in histogram
    hist = make_histogram(values.astype(float), bins, max_bins)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/tensorboard/summary.py", line 272, in make_histogram
    raise ValueError('The histogram is empty, please file a bug report.')
ValueError: The histogram is empty, please file a bug report.

Also,
rectify + AdaMod + weight_decay=1e-2 seems to be fine.
Just QHMomentum is fine too.
So it's strange.

JRC1995 commented Sep 7, 2020

@hadaev8
Yes, I think QHMomentum should be fine if AdamW is fine, because it's mostly a simple extension.
It may also simply be that you get "nan" because you got unlucky with the random seed, and perhaps the types of models you are running have more of a tendency towards instability, which is for some reason exacerbated by certain synergistic combinations and hyperparameters.

I have actually used AdaMod and QHMomentum together in a project without such issues. The code was a bit different (it was based on the code here, so it shouldn't really differ):

from torch.optim.optimizer import Optimizer, required
import torch
import math
import numpy as np


class generic_optimizer(Optimizer):
    def __init__(self, params, config,
                 lr=1e-3, betas=(0.9, 0.999, 0.999), nus=(1.0, 1.0), eps=1e-8,
                 padam_p=0.5, amsgrad=False, AdaMod=False,
                 warmup=False, warmup_steps=2000, use_demon=False,
                 use_gc=False, weight_decay=0.01,
                 lookahead_k=0, lookahead_alpha=0.8,
                 iterate_average=False, iterate_average_cycle=None,
                 epochs=None, steps_per_epoch=None):

        attributes_str = ["lr", "betas", "nus", "padam_p", "eps", "lookahead_k", "lookahead_alpha", "amsgrad", "AdaMod",
                          "use_demon", "warmup", "warmup_steps", "iterate_average", "use_gc", "iterate_average_cycle",
                          "weight_decay"]

        attributes = {}

        for key in attributes_str:
            try:
                attributes[key] = eval("config."+key)
            except:
                attributes[key] = eval(key)

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(attributes['lr']))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(attributes['eps']))
        if not 0.0 <= padam_p <= 0.5:
            raise ValueError("Invalid padam_p value: {}".format(attributes['padam_p']))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(attributes['betas'][0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(attributes['betas'][1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError("Invalid beta parameter at index 2: {}".format(attributes['betas'][2]))
        if not 0.0 <= nus[0] <= 1.0:
            raise ValueError("Invalid nu parameter at index 0: {}".format(attributes['nus'][0]))
        if not 0.0 <= nus[1] <= 1.0:
            raise ValueError("Invalid nu parameter at index 1: {}".format(attributes['nus'][1]))
        if not 0.0 <= lookahead_alpha <= 1.0:
            raise ValueError("Invalid lookahead alpha parameter: {}".format(
                attributes['lookahead_alpha']))
        if attributes['use_demon'] is True and (epochs is None or steps_per_epoch is None):
            raise ValueError(
                "Missing epochs and steps_per_epoch values (Needed if use_demon is True)")

        self.amsgrad = attributes['amsgrad']
        self.AdaMod = attributes['AdaMod']

        self.warmup = attributes['warmup']
        self.warmup_steps = attributes['warmup_steps']
        self.use_demon = attributes['use_demon']

        self.use_gc = attributes['use_gc']

        self.lookahead_k = attributes['lookahead_k']

        self.iterate_average = attributes['iterate_average']
        self.iterate_average_cycle = attributes['iterate_average_cycle']

        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        self.iterate_average_cycle = steps_per_epoch if iterate_average_cycle is None else iterate_average_cycle

        # total number of steps (needed for DEMON); None if epochs/steps_per_epoch are not given
        self.T = None if (self.epochs is None or self.steps_per_epoch is None) else self.epochs * self.steps_per_epoch

        defaults = dict(lr=attributes['lr'],
                        betas=attributes['betas'],
                        nus=attributes['nus'],
                        eps=attributes['eps'],
                        padam_p=attributes['padam_p'],
                        lookahead_alpha=attributes['lookahead_alpha'],
                        weight_decay=attributes['weight_decay'])
        super(generic_optimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(generic_optimizer, self).__setstate__(state)

    def apply_demon(self, beta1_init, state):
        if self.use_demon:
            temp = 1-(state['step']/self.T)
            beta1 = beta1_init * temp / ((1-beta1_init)+beta1_init*temp)
        else:
            beta1 = beta1_init
        return beta1

    def apply_warmup(self, steps, lr):
        if self.warmup:
            w = min([1.0, steps/self.warmup_steps])
            return w*lr
        return lr

    def decide(self, activate_iterate_average, step):
        iterate_average_step = False
        lookahead_step = False

        if self.iterate_average and activate_iterate_average:
            lookahead_step = False
            if step % self.iterate_average_cycle == 0:
                iterate_average_step = True
        elif self.lookahead_k == 0:
            lookahead_step = False
        else:
            if step % self.lookahead_k == 0:
                lookahead_step = True
            else:
                lookahead_step = False

        return lookahead_step, iterate_average_step

    def step(self, activate_iterate_average=False, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('AdvRanger does not support sparse gradients')

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                    if self.lookahead_k > 0:
                        state['cached_params'] = p.data.clone()
                    if self.iterate_average:
                        state['num_models'] = 0

                    if self.amsgrad:
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
                    if self.AdaMod:
                        state['n_avg'] = torch.zeros_like(p.data)

                if 'cached_params' not in state and self.iterate_average and activate_iterate_average:
                    state['cached_params'] = p.data.clone()

                state['step'] += 1

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                beta1_init, beta2, beta3 = group['betas']
                nu1, nu2 = group['nus']

                beta1 = self.apply_demon(beta1_init, state)
                lr = self.apply_warmup(state['step'], group['lr'])

                wd = group['weight_decay']
                alpha = group['lookahead_alpha']

                lookahead_step, iterate_average_step = self.decide(activate_iterate_average,
                                                                   state['step'])

                if self.use_gc and grad.view(-1).size(0) > 1:
                    grad.add_(-grad.mean(dim=tuple(range(1, len(list(grad.size())))), keepdim=True))

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                momentum = exp_avg.clone()
                momentum.div_(1 - (beta1 ** state['step'])).mul_(nu1).add_(1-nu1, grad)

                if wd != 0:
                    p.data.add_(-wd*lr, p.data)

                beta2_t = beta2 ** state['step']

                if self.amsgrad and state['step'] > 1:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    vt = max_exp_avg_sq.clone()
                else:
                    vt = exp_avg_sq.clone()

                bias_correction2 = 1 - beta2_t
                vt.div_(bias_correction2)
                if nu2 != 1.0:
                    vt.mul_(nu2).addcmul_(1-nu2, grad, grad)
                denom = vt.pow_(group['padam_p']).add_(group['eps'])
                n = lr/denom

                if self.AdaMod:
                    n_avg = state['n_avg']
                    n_avg.mul_(beta3).add_(1 - beta3, n)
                    torch.min(n, n_avg, out=n)

                p.data.add_(-n*momentum)

                if lookahead_step:
                    p.data.mul_(alpha).add_(1.0 - alpha, state['cached_params'])
                    state['cached_params'].copy_(p.data)

                if iterate_average_step:
                    p.data.add_(state["num_models"], state['cached_params']
                                ).div_(state["num_models"]+1.0)
                    state['cached_params'].copy_(p.data)
                    state["num_models"] += 1

        return loss

with config:

        self.weight_decay = 1e-4
        self.lr = 1e-3
        self.lookahead_k = 5
        self.use_gc = False  # gradient centralization
        self.AdaMod = True
        self.warmup = True
        self.betas = (0.999, 0.999, 0.999)  # (beta1, beta2, beta3) # beta3 correspond to AdaMod
        self.nus = (0.7, 1.0)  # nu1 and nu2 for Quasi-Hyperbolic momentum
        self.demon = False  # Decaying Momentum

and everything else being the default
and max_grad_norm = 5

grad norm is applied outside the optimizer as:

T.nn.utils.clip_grad_norm_(self.parameters, self.config["max_grad_norm"])

Though in this case grad norm clipping may be the "magic ingredient" (not sure whether I would have faced, or did face, any issues without it).

It was a Transformer model with a few differences from the original; the most significant one (in terms of changing the training dynamics) was probably the use of ReZero.

hadaev8 commented Sep 7, 2020

It may also simply be that you get "nan" because you got unlucky with the random seed, and perhaps the types of models you are running have more of a tendency towards instability, which is for some reason exacerbated by certain synergistic combinations and hyperparameters.

That would be hard to debug.

I'm using the same seed for every experiment; still, of course it may be a bad seed in some cases.
I'm using grad norm clipping with a max of 1, but I don't think it's really different from 5.

Here is another chart:
[chart]
Blue is the crashed QHMomentum + AdaMod, and the pink near the blue is QHMomentum + gc.
I expected gc to always be a good addition, huh.

JRC1995 commented Sep 7, 2020

Yes, the value of the grad norm probably wouldn't make or break things; I was just wondering whether the absence or presence of grad norm clipping itself may be a factor, but it seems like it isn't, since we were both using it. I think in my case I am using gc sort of indiscriminately, but originally it was used only for convnets... not sure, I forgot about it. I also had a bad time with gc when using ReZero, because I was using a scalar parameter, and with gc the gradient of a scalar parameter just becomes zero -- so scalar parameters are never learned. Thus the way you initialize the parameters also plays a role (for example, even with the same ReZero setup, if I instead initialize the scalar ReZero alpha parameters for multiple layers in a batch (as a vector), gc would not turn the gradient into 0). I think at the very least I updated the code after that to disable gc for scalar parameters, but there can still be other things that the implementation is not too careful about.

hadaev8 commented Sep 7, 2020

The condition grad.dim() > 1 should cover ReZero, right?

hadaev8 commented Sep 7, 2020

QHMomentum + gc also raised NaN.
Btw, why is your beta1 0.999?

JRC1995 commented Sep 7, 2020

The condition grad.dim() > 1 should cover ReZero, right?

Yes. I already have a condition, though it's uglier; I didn't know about dim.
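
For the record, that check would amount to something like this (a minimal sketch, not the repo's exact code; use_gc stands for the optimizer's flag):

# apply gradient centralization only to parameters with more than one dimension,
# so scalar params (e.g. a single ReZero alpha) keep their gradient untouched
if use_gc and grad.dim() > 1:
    grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))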

Btw, why is your beta1 0.999?

Probably because that was the default for QHAdam. In the readme, you can see that I change the beta1 when using Adam without QH Momentum.

hadaev8 commented Sep 8, 2020

weight_decay + rectify + use_gc + amsgrad is ok
weight_decay + rectify + AdaMod + use_gc + amsgrad is ok too
So I think hypergrad causes the problems.
I will test more hypergrad setups.

hadaev8 commented Sep 9, 2020

So, every modification to Adam seems to be ok except QH.
amsgrad seems to add a huge loss gap, and I guess it is turned off by default in vanilla pytorch's Adam for a reason.

I tried HyperProp again (all features disabled) and the lr for one parameter seems to be too large.
It grows and grows; right now it is:
lr tensor(132.3673, device='cuda:0')
Maybe it's ReZero (I'm using it only once in the model); could you advise how to track which parameter it is?

Also, with HDM=False, hypergrad_lr=1e-8 (that should be the default setup, right?) the learning rate does not change.

JRC1995 commented Sep 10, 2020

Yes, you can try to track which parameter is experiencing that. It could be that the hypergradient lr is not that stable in a more complex setting. IIRC, the original code was for a simpler setting without weight decay (and without considering gradients for that), and without proper per-parameter individual gradient changes. Instead of HyperProp I would actually try out HDQHSGDW or HyperRangerMod to play around with hypergrad lr.
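
If it helps, something along these lines could locate the parameter (a rough sketch; model and optimizer are your objects, and the state key holding the adapted lr is a guess -- check the actual HyperProp source for the real name):

for name, p in model.named_parameters():
    state = optimizer.state.get(p, {})
    if 'lr' in state:
        lr_val = state['lr']
        lr_val = lr_val.max().item() if hasattr(lr_val, 'max') else float(lr_val)
        if lr_val > 1.0:  # arbitrary threshold for "suspiciously large"
            print(name, lr_val)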

Also... actually, if you add diffgrad I am not sure, but it can mean that the hypergradient will require different gradients (so you would have to manually change the gradient math inside the code)... sorry about that.

I think HyperProp has the highest chance of having a bug, since I didn't try it out too much IIRC, and its gradient computation has a higher chance of containing a mistake (at least for the other two, I think I could extend more easily from the pre-established hypergradient math in the paper or repo -- though it had to be extended to add AdaMod and weight decay). Furthermore, new out-of-nowhere optimizers like LaProp probably also have a higher chance of not being "good" --- there's a reason most people still use Adam/AdamW: it has been consistently decent and has sort of stood the test of time, even if there may have been theoretical mistakes. HyperProp is a wild combination of a rarely used and tested technique (hypergradient lr) and a new optimizer out of nowhere (LaProp), which can easily lead to unstable or weird stuff.

That's why, for a more grounded experiment with hypergradients, I would recommend HyperRangerMod (which extends the Adam/DemonRanger base with a hypergradient option) or HDQHSGDW (which extends good ol' SGD with momentum --- interesting because even now almost nothing really beats a well-tuned SGD or SGD with momentum; Adam and the like are better for getting decent results quickly, but simpler SGDs usually win out in the long run. The problem is "well tuning" it, which may take more work (idk, I am not really an optimizer guy --- I just made this repo when I was in the mood), and hypergradient lr may have the potential to remedy that. HDQHSGDW also adds QHMomentum, since the paper showed it to be a good addition to SGD with momentum too. It's also more minimalistic in that sense).

JRC1995 commented Sep 10, 2020

I don't really know what the "defaults" are... for LaProp there probably isn't any real official default. For Adam, I don't remember if there was a good recommended default. It was probably something like what you posted, but you'd have to check the hypergrad repo and/or the paper to confirm.

hadaev8 commented Sep 10, 2020

Yes, so far only one experiment beat vanilla AdamW: RAdam + AdaMod + weight decay.
It seems like lookahead, amsgrad, and gc lead to worse loss, at least in the short term (like 10-20% of the training config).
What do you think: does it make sense to use RAdam and AdaMod at the same time? It seems like these modifications solve the same things.

I'm not using diffgrad for HyperProp; I also turned off nostalgia, demon, gc, and nu.
With HDM=True/False and hypergrad_lr=1e-3 it seems to be stable but bad.
With HDM=True and hypergrad_lr=1e-2 it failed with NaN.
Honestly, I'm trying HyperProp because it has fewer hyperparameters.

HDQHSGDW was very slow in the short term; also, it seems like nobody uses SGD for NLP tasks.
Maybe I will test it in a long run later.

Also, maybe it's important: I use a dropout scheduler for the first 10% of training. Maybe it adds instability.

https://i.imgur.com/jrCFSyl.png

JRC1995 commented Sep 10, 2020

Yes, it seems AdaMod and RAdam would both reduce the initial variance, and more. You can still probably try using them together, though there probably wouldn't be much of a benefit.

Short-term measurement can be tricky. I think lookahead tends to regularize more, which can lead to slower fitting --- though I guess if you are plotting validation performance, I can't say that adding regularization would necessarily hurt it in the short term. I have also tried lowering p for PAdam; it seemed to make training appear worse in the short term, but I suspect it makes things better in the long run. I didn't have the patience for it though, and I think I wasn't really getting much out of it even with longer runs.

hadaev8 commented Sep 12, 2020

I guess at 30% of the training config I can consider gradient centralization a bad modification.
A bit strange.
Diffgrad gives a slightly better loss, at least in the short term.
Also, it seems like I have good overfitting, lol.
[chart]

hadaev8 commented Sep 12, 2020

Btw, there is an inconsistency: not all optimizers accept alpha=1.

JRC1995 commented Sep 12, 2020

If you have alpha=1 you essentially make lookahead redundant. A more principled way to disable lookahead is setting k=0 IIRC; that will deactivate all the lookahead computation. I think the initial range checking is a bit inconsistent across optimizers (some accepting alpha=1, some not), but ultimately it doesn't matter a lot -- you can get away with k=0 whenever you want alpha=1.

hadaev8 commented Sep 12, 2020

I mean I disabled lookahead with k=0, alpha=1, like in the readme.
Of course it doesn't matter what alpha is, but I copy-pasted it from DemonRanger and got an error.
I think it should be consistent across all optimizers.

JRC1995 commented Sep 12, 2020

Fixed it.

hadaev8 commented Sep 14, 2020

Here is another comparison.
[chart]
This time the adaptive-lr models do not crash.
Maybe because I turned off nostalgia and gc.
Still, it's worse. hypermod without QH is almost on par, but still a bit worse.

Also, I wonder: as far as I understand, the lr change is performed by the SGD algorithm. Do you think it is possible to apply Adam to this value?

JRC1995 commented Sep 14, 2020

Also, I wonder: as far as I understand, the lr change is performed by the SGD algorithm. Do you think it is possible to apply Adam to this value?

I didn't get this part. Could you elaborate on what you mean by "lr change performed by SGD"?

hadaev8 commented Sep 14, 2020

https://openreview.net/forum?id=BkrsAzWAb
It says they apply the gradient descent algorithm (not the stochastic one) to the lr value.
Still, the difference between GD and SGD is that SGD is computed on a batch and GD on the whole dataset,
so it seems strange, because the lr changes are computed per batch.
So what do you think: is it possible / does it make sense to replace it with the Adam algorithm?
In my imagination it should give better learning rate values.

JRC1995 commented Sep 14, 2020

I think they just meant it loosely, or in a more general sense; it's still based on per-batch statistics. The original paper itself applies hypergradient descent to Adam, and hyperranger also allows applying the hypergradient to Adam, if that's what you were asking for.
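
Roughly, the Adam variant in the paper updates the lr like this (a loose sketch of the idea, not this repo's exact code; m_hat_prev, v_hat_prev, grad, eps and hypergrad_lr are placeholder names):

u_prev = m_hat_prev / (v_hat_prev.sqrt() + eps)   # previous Adam update direction
hypergrad = (grad * u_prev).sum()                 # d(loss)/d(lr), up to sign
lr = lr + hypergrad_lr * hypergrad                # a gradient step on the lr itself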

hadaev8 commented Sep 14, 2020

Just checked the code, and the hypergradient for Adam (and its variations in the repo) uses adaptive moments for the lr too. Never mind then.

hadaev8 commented Sep 15, 2020

Btw, I want to experiment with lr dropout.
The implementation seems to be easy:
I just need to make a mask and multiply.
But this line confuses me:
https://github.com/HuangxingLin123/Learning-Rate-Dropout/blob/master/cifar10/adam_lrd.py#L79
Why do we need to clone the variable there? Could you explain why it's cloned?
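
Something like this is what I have in mind (just a sketch of the idea; p, lr, update and keep_prob are placeholders, not the linked repo's exact code):

import torch

# sample a 0/1 mask per parameter element; elements with mask 0 skip this step's update
mask = torch.bernoulli(torch.full_like(p.data, keep_prob))
p.data.add_(-lr * update * mask)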

JRC1995 commented Sep 15, 2020

Not sure. It doesn't look like cloning is necessary there, though it probably shouldn't hurt much if it's there and unnecessary.

hadaev8 commented Sep 19, 2020

Some observations.
Demon vs no demon; the train loss at the top shows instability too:
[chart]
Not sure why it is unstable at the end; maybe because of weight decay?
Lr dropout vs no lr dropout:
[chart]
So, lr dropout is on par with the default (unlucky seed, I guess) or a bit better.
Another set of QH values:
[chart]
I took the values from the paper's translation experiment, which should be similar to the TTS task.
For now it is at least on par with no QH; let's see how it develops.
