Describe the bug

The AdamW implementation (see here) does not truly decouple the weight decay and learning rate parameters in line with the AdamW paper. This coupling complicates hyperparameter (HP) tuning, because tuning the learning rate also changes the effective weight decay (WD) used to train the model.

The implementation computes the update as

$$w_{t} = (1 - \eta_{\text{effective}}\, \lambda)\, w_{t-1} - \eta_{\text{effective}}\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon},$$

where $\eta_{\text{effective}} = \eta_t\, \eta_{\text{max}}$, with $\eta_t$ denoting the scheduler multiplier and $\eta_{\text{max}}$ the max/base LR. This clearly couples LR and WD and is not in line with the paper, which proposes to compute the update as

$$w_{t} = (1 - \eta_t\, \lambda)\, w_{t-1} - \eta_{\text{effective}}\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon},$$

i.e. the weight decay is scaled by the schedule $\eta_t$ only, not by the base LR $\eta_{\text{max}}$.

For easier and more intuitive tuning, it would be useful to enable the completely decoupled version of AdamW via the simple fix $\lambda \leftarrow (\eta_{\text{effective}} / \eta_{\text{max}})\, \lambda$, with the update $w_{t} = (1 - \lambda)\, w_{t-1} - \eta_{\text{effective}}\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$.

Note: This bug also exists in the AdamW implementations in PyTorch and Optax and has already been highlighted a few times across different papers, libraries, and blogs. More links below for reference.
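To make the coupling concrete, here is a minimal, self-contained NumPy sketch of a single parameter update under both formulations. It is purely illustrative and not the library's code (the toy momentum values and the halfway-decayed schedule are assumptions); it only shows that the coupled rule's applied decay scales with the base LR while the decoupled rule's does not:

```python
import numpy as np

def adamw_step_coupled(w, m_hat, v_hat, lr_effective, weight_decay, eps=1e-8):
    """Coupled formulation (as implemented): decay is scaled by the effective LR,
    so changing the base LR also changes the effective weight decay."""
    w = (1.0 - lr_effective * weight_decay) * w
    return w - lr_effective * m_hat / (np.sqrt(v_hat) + eps)

def adamw_step_decoupled(w, m_hat, v_hat, lr_effective, lr_max, weight_decay, eps=1e-8):
    """Fully decoupled variant proposed in the issue: decay is scaled only by the
    schedule multiplier eta_t = lr_effective / lr_max, never by the base LR."""
    eta_t = lr_effective / lr_max
    w = (1.0 - eta_t * weight_decay) * w
    return w - lr_effective * m_hat / (np.sqrt(v_hat) + eps)

# Toy numbers: same schedule position, two different base LRs.
w, m_hat, v_hat, wd = 1.0, 0.1, 0.04, 0.1
for lr_max in (1e-3, 1e-2):
    lr_eff = 0.5 * lr_max  # e.g. halfway through a linear decay schedule
    print(f"lr_max={lr_max:g}  "
          f"coupled decay={lr_eff * wd:.1e}  "
          f"decoupled decay={(lr_eff / lr_max) * wd:.1e}  "
          f"coupled w={adamw_step_coupled(w, m_hat, v_hat, lr_eff, wd):.6f}  "
          f"decoupled w={adamw_step_decoupled(w, m_hat, v_hat, lr_eff, lr_max, wd):.6f}")
```

With the coupled rule the per-step decay is $\eta_{\text{effective}}\,\lambda$ (5e-5 vs 5e-4 for the two base LRs above), whereas the decoupled rule always applies $\eta_t\,\lambda$ = 5e-2 regardless of the base LR.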
For better or for worse, I think "AdamW" now refers to the LR-coupled version. In addition to PyTorch and JAX, I see this formulation in Keras (and therefore TensorFlow), PaddlePaddle, and MXNet. If we implement an LR-decoupled variant, we should give it a new name or make it an opt-in option so we don't confuse users.
There has been a lot of discussion in other frameworks:
Allowing the user to invoke the fully decoupled version via either option (an opt-in flag or a new name) would be helpful. A couple more references on the potential utility of independent WD are below.
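In the meantime, the fix formula above also implies a user-side workaround for implementations that apply the decay as $\eta_{\text{effective}}\,\lambda$: pass $\lambda / \eta_{\text{max}}$ as the weight-decay argument, so the applied decay becomes $(\eta_{\text{effective}} / \eta_{\text{max}})\,\lambda = \eta_t\,\lambda$, i.e. the fully decoupled value. A small sketch of the arithmetic (toy numbers, no particular framework assumed):

```python
# Assumption: the optimizer applies decay as lr_effective * weight_decay (the coupled rule).
# Rescaling weight_decay by 1 / lr_max then reproduces the fully decoupled behaviour.
lr_max, weight_decay = 1e-3, 0.1

for eta_t in (1.0, 0.5, 0.1):                                # scheduler multiplier over training
    lr_effective = eta_t * lr_max
    coupled_decay = lr_effective * (weight_decay / lr_max)   # what the implementation applies
    decoupled_decay = eta_t * weight_decay                   # what the issue asks for
    assert abs(coupled_decay - decoupled_decay) < 1e-12
    print(f"eta_t={eta_t}: applied decay {coupled_decay:.3f} (independent of lr_max)")
```

The downside of this trick is that the passed weight_decay value no longer has its usual meaning, which is why an explicit opt-in or a separately named optimizer would still be the cleaner solution.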