Conversation


@jingweiz jingweiz commented Nov 16, 2017

Hey,
We added SGDW and AdamW to optim, according to the new ICLR submission from Loshchilov and Hutter: Decoupled Weight Decay Regularization.
We also found an inconsistency between the current implementation of Adam and the original Adam paper, and we asked about it here: https://discuss.pytorch.org/t/adam-implementation/10031. We tried to make it consistent in commit f195087a3666f85a417ee7561bec439bb68f81c3, but it then failed during testing, which might be due to some inconsistency with legacy.optim.adam, so we changed it back in the last commit. It would be great if you could check whether this is actually an inconsistency, and if so we can add the change back.
Thanks in advance!

@radekosmulski

There seems to be an sgdw.py file in the root of the repository. Not sure if this was intentional?

The docstring for SGDRCosineLR in the example section seems to be referencing CosineLR, and the parameter names also seem off.

In the step method, we call get_lr but don't pass it epoch_idx, which seems to be a required parameter (I tried this and it blows up).

Thank you for working on this 👍 Happy to help in whatever capacity I can - will continue to test this as new versions become available.

@TiRune
Contributor

TiRune commented Dec 6, 2017

I tested the current code on a simple SENet network on CIFAR10. The current AdamW code I found here works nearly identically to SGD with momentum 0.9, and does not seem to decay the learning rate at all. In contrast, the original Adam optimiser decays the learning rate properly and converges very differently from SGD.

I found the reason for this, and will leave this comment here for other people who are looking for a solution. The initial model used learning rate decay in steps, after two set periods of iterations. In the old Adam/SGD setting, this means the weight decay also gets decayed by a factor (in my case 0.1) every set of iterations. To get the same results as with the original methods, you therefore have to take care that the weight_decay term now has a different interpretation! E.g. if the final decay was done by multiplying by a factor of 0.01, your weight decay should be the original setting * 0.01 for the same results.
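A minimal single-tensor sketch of this interaction (an illustration only, not the PR's code; it uses a plain gradient step rather than Adam's adaptive step, just to isolate the lr-schedule scaling):

import torch

lr, weight_decay, lr_decay_factor = 0.1, 5e-4, 0.01  # illustrative values
p = torch.randn(10)
grad = torch.randn(10)

# Coupled L2 / "old" weight decay: the decay term is folded into the gradient,
# so a stepped lr schedule (final multiplier 0.01 here) shrinks the decay too.
lr_t = lr * lr_decay_factor
p_coupled = p - lr_t * (grad + weight_decay * p)

# Decoupled weight decay (not multiplied by lr, matching the AdamW code
# discussed here): applied to the weights directly and untouched by the lr
# schedule, so reproducing the old behaviour at this stage of training
# requires rescaling weight_decay by the same 0.01 factor.
p_decoupled = p - lr_t * grad - weight_decay * p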

p.data.addcdiv_(-step_size, exp_avg, denom)

if group['weight_decay'] != 0:
    p.data.add_(-group['weight_decay'], p.data)

This comment was marked as off-topic.

@weiyangfb added the "awaiting response (this tag is deprecated)" label Aug 14, 2018
@geyang

geyang commented Nov 20, 2018

Just want to bump this PR up.

@i-zaitsev

i-zaitsev commented Nov 29, 2018

When do you think the fix will be merged into master?

@OutSorcerer

OutSorcerer commented Jan 27, 2019

The fact that the paper was accepted to ICLR 2019 should be also taken into consideration.

@taohu88

taohu88 commented Feb 3, 2019

BTW, to do this the right way, the weight decay should be multiplied by the lr. I saw that many proposals didn't do that. Not doing so breaks how people set weight decay the old way: suddenly they have to set the weight decay much lower.


@taohu88 taohu88 left a comment


To be clear, we should do this:

                # p = p - wd * lr * p
                # According to the paper, the weight decay should be applied here.
                # To be consistent with the general approach, the weight decay is always multiplied by lr.
                wd = group['weight_decay']
                if wd != 0:
                    p.data.add_(-wd * lr, p.data)

                # Note that lr * p.grad = step_size * exp_avg / denom,
                # so p = p - lr * p.grad = p - step_size * exp_avg / denom.
                p.data.addcdiv_(-step_size, exp_avg, denom)
                # Net result: p = p_old - wd * lr * p_old - step_size * exp_avg / denom

return ([base_lr * lr_multi for base_lr in self.base_lrs],
        [base_weight_decay * lr_multi * weight_decay_norm_multi
         for base_weight_decay in self.base_weight_decays])

If we always require the weight decay to be multiplied by the lr in the optimizer itself, then there is no need to do the lr multiplication here. It is better not to do it here, but in the optimizer.

@BramVanroy

This is labelled as awaiting response, but I'm not sure from whom. From @jingweiz? Or from the torch developers?

@bhack
Contributor

bhack commented Mar 30, 2019

Any news?

@naifrec
Contributor

naifrec commented Apr 9, 2019

Hello, does anyone know if Fixing Weight Decay Regularization in Adam could be the reason why Keras (TF backend) seems to outperform PyTorch (given the exact same training pipeline) when using Adam? If so, can we prioritize merging this PR?

@yet-another-account

What is this PR waiting on right now?

@soumith
Member

soumith commented May 1, 2019

The title might seem like this is fixing a software bug, but it isn't.
The authors propose a change to the Adam algorithm, that performs better under certain conditions.

This clarification is relevant to the comments made by @naifrec.

About the awaiting-response tag: it was added to get a response from the author on @colllin's comments on the PR.

@soumith
Member

soumith commented Jun 20, 2019

cc @vincentqb: can you review this and get it to completion if it makes sense? Use the guidelines that I shared with you separately.

@vincentqb vincentqb changed the title Fixing Weight Decay Regularization in Adam Decoupled Weight Decay Regularization Jun 28, 2019
@vincentqb
Contributor

vincentqb commented Jun 28, 2019

The title might seem like this is fixing a software bug, but it isn't.
The authors propose a change to the Adam algorithm, that performs better under certain conditions.

Changed the title from "Fixing Weight Decay Regularization in Adam" to "Decoupled Weight Decay Regularization" to reflect the updated name of the paper being cited.

facebook-github-bot pushed a commit that referenced this pull request Jul 2, 2019
Summary:
# What is this?
This is an implementation of the AdamW optimizer as implemented in [the fastai library](https://github.com/fastai/fastai/blob/803894051bef32304ceea0c8ea5e04db64ff26b8/fastai/callback.py) and as initially introduced in the paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101). It decouples the weight decay regularization step from the optimization step during training.

There have already been several abortive attempts to push this into pytorch in some form or fashion: #17468, #10866, #3740, #4429. Hopefully this one goes through.
# Why is this important?
Via a simple reparameterization, it can be shown that L2 regularization has a weight decay effect in the case of SGD optimization. Because of this, L2 regularization became synonymous with the concept of weight decay. However, it can be shown that the equivalence of L2 regularization and weight decay breaks down for more complex adaptive optimization schemes. It was shown in the paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) that this is the reason why models trained with SGD achieve better generalization than those trained with Adam. Weight decay is a very effective regularizer. L2 regularization, in and of itself, is much less effective. By explicitly decaying the weights, we can achieve state-of-the-art results while also taking advantage of the quick convergence properties that adaptive optimization schemes have.
# How was this tested?
There were test cases added to `test_optim.py` and I also ran a [little experiment](https://gist.github.com/mjacar/0c9809b96513daff84fe3d9938f08638) to validate that this implementation is equivalent to the fastai implementation.
Pull Request resolved: #21250

Differential Revision: D16060339

Pulled By: vincentqb

fbshipit-source-id: ded7cc9cfd3fde81f655b9ffb3e3d6b3543a4709
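For later readers, a minimal usage sketch of the torch.optim.AdamW optimizer added by the commit above; the model and hyperparameter values here are illustrative only, not a recommendation:

import torch
from torch import nn, optim

model = nn.Linear(20, 2)  # toy model, illustrative only
# Decoupled weight decay: applied to the parameters directly at each step rather
# than added to the gradients, so values tuned for classic Adam + L2 usually
# need to be re-tuned (see the lr-schedule discussion earlier in this thread).
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

for _ in range(3):  # dummy training steps on random data
    optimizer.zero_grad()
    loss = model(torch.randn(8, 20)).pow(2).mean()
    loss.backward()
    optimizer.step()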
p.data.addcdiv_(-step_size, exp_avg, denom)

if group['weight_decay'] != 0:
    p.data.add_(-group['weight_decay'], p.data)

This comment was marked as resolved.


Isn't this because you are subtracting the derivative of the L2 norm of the weights, not the L2 norm itself?

@vincentqb vincentqb mentioned this pull request Jul 2, 2019
xzhu1900 pushed a commit to xzhu1900/pytorch that referenced this pull request Jul 5, 2019
@farhadrgh
Contributor

Is there any particular reason you are not adding AdamaxW?

@github-actions
Contributor

github-actions bot commented Mar 1, 2022

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
Stale pull requests will automatically be closed 30 days after being marked Stale

@github-actions github-actions bot added the Stale label Mar 1, 2022
@facebook-github-bot
Contributor

Hi @jingweiz!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@github-actions github-actions bot removed the Stale label Mar 29, 2022
@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.
