Commit ba47d7f

minseo25 authored and jijoongmoon committed
[optimizer] Fix AdamW logic and apply decoupled weight decay
This patch corrects the AdamW optimizer implementation to match the paper and the PyTorch reference.

Signed-off-by: Minseo Kim <[email protected]>
1 parent 8f9025d commit ba47d7f
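
For reference, a summary of the update rule implemented by this patch (a reading of the diff below, not part of the commit): with base learning rate η, λ = weight_decay, zero-based iteration t, and gradient g_t,

$$
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t,\\
v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^{2},\\
w &\leftarrow w\,(1 - \eta\,\lambda) \qquad \text{(decoupled weight decay, applied with the base learning rate)},\\
\eta_t &= \eta\,\frac{\sqrt{1-\beta_2^{\,t+1}}}{1-\beta_1^{\,t+1}},\\
w &\leftarrow w - \eta_t\,\frac{m_t}{\sqrt{v_t}+\epsilon}.
\end{aligned}
$$

The decay multiplies the weights directly instead of being folded into the gradient as L2 regularization, and ε is added to √v_t while the bias correction is carried entirely by η_t, matching the code in adamw.cpp below.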

File tree

7 files changed: +118 -28 lines changed

nntrainer/optimizers/adam.cpp

Lines changed: 3 additions & 3 deletions
@@ -58,17 +58,17 @@ void Adam::setProperty(const std::vector<std::string> &values) {
   Optimizer::setProperty(left);
 }
 
-double Adam::getUpdatedLearningRate(unsigned int iteration, double ll) const {
+double Adam::getUpdatedLearningRate(unsigned int iteration, double lr) const {
   auto &beta1 = std::get<PropsB1>(adam_props).get();
   auto &beta2 = std::get<PropsB2>(adam_props).get();
 
   std::function<float(double)> biasCorrection = [&](float f) {
     return 1.0f - pow(f, iteration + 1);
   };
 
-  ll *= sqrt(biasCorrection(beta2)) / biasCorrection(beta1);
+  lr *= sqrt(biasCorrection(beta2)) / biasCorrection(beta1);
 
-  return ll;
+  return lr;
 }
 
 void Adam::applyGradient(RunOptimizerContext &context) {

nntrainer/optimizers/adam.h

Lines changed: 2 additions & 2 deletions
@@ -121,11 +121,11 @@ class Adam : public Optimizer {
   /**
    * @brief Get updated learning rate
    *
-   * @param ll learning rate
+   * @param lr learning rate
   *
    * @return updated learning rate
    */
-  double getUpdatedLearningRate(unsigned int iteration, double ll) const;
+  double getUpdatedLearningRate(unsigned int iteration, double lr) const;
 };
 } /* namespace nntrainer */

nntrainer/optimizers/adamw.cpp

Lines changed: 36 additions & 15 deletions
@@ -23,21 +23,32 @@
 
 namespace nntrainer {
 
-AdamW::AdamW() : adam_props(PropsB1(), PropsB2(), PropsEpsilon(), TorchRef()) {
+AdamW::AdamW() :
+  adam_props(PropsB1(), PropsB2(), PropsEpsilon(), TorchRef(),
+             PropsWeightDecayW()) {
   /** default properties */
-  auto &[b1, b2, eps, torch_ref] = adam_props;
+  auto &[b1, b2, eps, torch_ref, weight_decay] = adam_props;
   b1.set(0.9f);
   b2.set(0.999f);
-  eps.set(1.0e-7f);
+  eps.set(1.0e-8f);
   torch_ref.set(false);
+  weight_decay.set(0.0f);
 }
 
 AdamW::~AdamW() {}
 
 enum AdamParams { wm, wv };
 
 std::vector<TensorDim> AdamW::getOptimizerVariableDim(const TensorDim &dim) {
-  return {dim, dim};
+  /**
+   * @note We assume the optimizer parameters should be full precision to
+   * maintain the accuracy even in mixed precision training.
+   */
+  TensorDim wm_dim(dim);
+  TensorDim wv_dim(dim);
+  wm_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+  wv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+  return {wm_dim, wv_dim};
 }
 
 void AdamW::exportTo(Exporter &exporter,
@@ -51,6 +62,14 @@ void AdamW::setProperty(const std::vector<std::string> &values) {
   Optimizer::setProperty(left);
 }
 
+double AdamW::getUpdatedLearningRate(unsigned int iteration, double lr) const {
+  auto &beta1 = std::get<PropsB1>(adam_props).get();
+  auto &beta2 = std::get<PropsB2>(adam_props).get();
+  auto biasCorrection = [&](double f) { return 1.0 - pow(f, iteration + 1); };
+  lr *= sqrt(biasCorrection(beta2)) / biasCorrection(beta1);
+  return lr;
+}
+
 void AdamW::applyGradient(RunOptimizerContext &context) {
   Tensor empty_tensor;
 
@@ -68,13 +87,8 @@ void AdamW::applyGradient(RunOptimizerContext &context) {
   auto &beta1 = std::get<PropsB1>(adam_props).get();
   auto &beta2 = std::get<PropsB2>(adam_props).get();
   auto &epsilon = std::get<PropsEpsilon>(adam_props).get();
-  auto &torch_ref = std::get<TorchRef>(adam_props).get();
+  auto &weight_decay = std::get<PropsWeightDecayW>(adam_props).get();
 
-  // This is implementation of adam from original paper.
-  // This is not deleted intentionally.
-  unsigned int iteration = context.getIteration();
-  float biasCorrection1 = 1 - pow(beta1, iteration + 1);
-  float biasCorrection2 = 1 - pow(beta2, iteration + 1);
   Tensor &wm = context.getOptimizerVariable(AdamParams::wm);
   Tensor &wv = context.getOptimizerVariable(AdamParams::wv);
 
@@ -84,16 +98,23 @@ void AdamW::applyGradient(RunOptimizerContext &context) {
   wv.multiply_i(beta2);
   wv.add_i(x_grad.multiply(x_grad), 1.0f - beta2);
 
-  wv.divide_i(biasCorrection2);
+  // Decoupled weight decay: w = w - lr * wd * w
+  if (weight_decay > 0.0) {
+    Tensor &w = context.isMixedPrecision() ? context.getWeightFP32()
+                                           : context.getWeight();
+    w.multiply_i(1.0f - (context.getLearningRate() * weight_decay));
+  }
+
+  // Adam update with bias-corrected lr
+  double lr_t =
+    getUpdatedLearningRate(context.getIteration(), context.getLearningRate());
+
   std::function<double(double)> sqrtEps = [epsilon](double f) {
     return 1 / (sqrtDouble(f) + epsilon);
   };
   x_grad = wv.apply<float>(sqrtEps, x_grad);
-  x_grad.divide_i(biasCorrection1);
   x_grad.multiply_i(wm);
-  context.calcWeightDecayGradient();
-
-  context.applyGradient(context.getLearningRate(), x_grad);
+  context.applyGradient(lr_t, x_grad);
 }
 
 } // namespace nntrainer
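
To make the tensor-level changes above concrete, here is a minimal self-contained sketch of the same per-parameter step on plain std::vector<float>. It is an illustration only; the names (adamw_step, AdamWState, and so on) are invented here and this is not nntrainer code.

#include <cmath>
#include <cstddef>
#include <vector>

struct AdamWState {
  std::vector<float> m, v; // first and second moments, kept in full precision
};

void adamw_step(std::vector<float> &w, const std::vector<float> &grad,
                AdamWState &s, unsigned int iteration, double lr,
                double beta1 = 0.9, double beta2 = 0.999,
                double epsilon = 1.0e-8, double weight_decay = 0.0) {
  // Bias-corrected learning rate, as in getUpdatedLearningRate()
  const double bc1 = 1.0 - std::pow(beta1, iteration + 1);
  const double bc2 = 1.0 - std::pow(beta2, iteration + 1);
  const double lr_t = lr * std::sqrt(bc2) / bc1;

  for (std::size_t i = 0; i < w.size(); ++i) {
    // Exponential moving averages of the gradient and the squared gradient
    s.m[i] = beta1 * s.m[i] + (1.0 - beta1) * grad[i];
    s.v[i] = beta2 * s.v[i] + (1.0 - beta2) * grad[i] * grad[i];

    // Decoupled weight decay: shrink the weight using the *base* lr
    if (weight_decay > 0.0)
      w[i] *= 1.0 - lr * weight_decay;

    // Adam update with the bias-corrected lr
    w[i] -= lr_t * s.m[i] / (std::sqrt(s.v[i]) + epsilon);
  }
}

With weight_decay = 0 the step reduces to plain Adam (in this epsilon convention), which is a quick way to check that the decay is decoupled from the moment estimates.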

nntrainer/optimizers/adamw.h

Lines changed: 20 additions & 1 deletion
@@ -24,6 +24,15 @@
 
 namespace nntrainer {
 
+/**
+ * @brief weight decay property for AdamW
+ */
+class PropsWeightDecayW : public Property<double> {
+public:
+  static constexpr const char *key = "weight_decay";
+  using prop_tag = double_prop_tag;
+};
+
 /**
  * @class AdamW Optimizer class
  * @brief AdamW Optimizer
@@ -78,7 +87,17 @@ class AdamW : public Optimizer {
   void setProperty(const std::vector<std::string> &values) override;
 
 private:
-  std::tuple<PropsB1, PropsB2, PropsEpsilon, TorchRef> adam_props;
+  std::tuple<PropsB1, PropsB2, PropsEpsilon, TorchRef, PropsWeightDecayW>
+    adam_props;
+
+  /**
+   * @brief Get updated learning rate
+   *
+   * @param lr learning rate
+   *
+   * @return updated learning rate
+   */
+  double getUpdatedLearningRate(unsigned int iteration, double lr) const;
 };
 } /* namespace nntrainer */
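
The new weight_decay key declared above is consumed through the usual string-property path; a usage sketch mirroring the unit tests added at the end of this commit (it assumes the same headers and engine setup as unittest_nntrainer_internal.cpp):

auto &eg = nntrainer::Engine::Global();
auto ac = eg.getRegisteredContext("cpu");
// Create AdamW with a 0.01 decoupled weight decay
std::unique_ptr<nntrainer::Optimizer> op =
  ac->createOptimizerObject("adamw", {"weight_decay=0.01"});

Unrecognized keys are still rejected, as the negative tests below expect.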

nntrainer/optimizers/optimizer_context.cpp

Lines changed: 12 additions & 2 deletions
@@ -22,6 +22,13 @@ Tensor &RunOptimizerContext::getWeight() const {
   return weight->getVariableRef();
 }
 
+/**
+ * @brief Get the Weight FP32 tensor object (master weight for mixed precision)
+ */
+Tensor &RunOptimizerContext::getWeightFP32() const {
+  return weight->getVariableFP32Ref();
+}
+
 /**
  * @brief Get the Weight Gradient tensor object
  */
@@ -63,7 +70,10 @@ void RunOptimizerContext::applyLossScale(Tensor &fp32_grad) {
   fp32_grad.divide_i(loss_scale);
 }
 
-void RunOptimizerContext::calcWeightDecayGradient() {
-  weight->calcWeightDecayGradient();
+/**
+ * @brief Return if the underlying weight is mixed precision
+ */
+bool RunOptimizerContext::isMixedPrecision() const {
+  return weight->isMixedPrecision();
 }
 } // namespace nntrainer

nntrainer/optimizers/optimizer_context.h

Lines changed: 13 additions & 5 deletions
@@ -44,13 +44,26 @@ class RunOptimizerContext {
    */
   Tensor &getWeight() const;
 
+  /**
+   * @brief Get the Weight FP32 tensor object (master weight for mixed
+   * precision)
+   *
+   * @return Tensor& Reference to the FP32 master weight tensor
+   */
+  Tensor &getWeightFP32() const;
+
   /**
    * @brief Get the Weight Gradient tensor object
    *
    * @return Tensor& Reference to the weight grad tensor
    */
   Tensor &getGradient() const;
 
+  /**
+   * @brief Return if the underlying weight is mixed precision
+   */
+  bool isMixedPrecision() const;
+
   /**
    * @brief Get the optimizer variable associated to this weight
    *
@@ -102,11 +115,6 @@
    */
   void applyLossScale(Tensor &fp32_grad);
 
-  /**
-   * @brief Calculate gradient from the decay of the weight
-   */
-  void calcWeightDecayGradient();
-
 private:
   Weight *weight;   /**< weights for the optimizer */
   size_t iteration; /**< iteration number */
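
A small self-contained illustration (not nntrainer code, helper invented here) of why the decoupled decay targets the FP32 master weight when isMixedPrecision() is true: with lr = 1e-3 and weight_decay = 1e-2, the per-step decay of a weight near 1.0 is about 1e-5, which is below FP16 resolution (one ulp at 1.0 is roughly 9.8e-4) and would round away. FP16 rounding is simulated by keeping 11 significant bits.

#include <cmath>
#include <cstdio>

// Round x to FP16-like precision: keep 11 significant bits (the FP16
// mantissa width); exponent range and subnormals are ignored for simplicity.
static float round_to_fp16_precision(float x) {
  int e;
  float m = std::frexp(x, &e);           // x = m * 2^e with 0.5 <= |m| < 1
  m = std::round(m * 2048.0f) / 2048.0f; // quantize the mantissa
  return std::ldexp(m, e);
}

int main() {
  const float lr = 1.0e-3f, weight_decay = 1.0e-2f;
  float weight_fp16 = 1.0f; // working copy, FP16 precision simulated
  float master_fp32 = 1.0f; // FP32 master copy of the same weight

  // Decay applied to the half-precision copy is lost to rounding ...
  weight_fp16 =
    round_to_fp16_precision(weight_fp16 * (1.0f - lr * weight_decay));
  // ... while the same decay applied to the FP32 master copy is kept.
  master_fp32 *= 1.0f - lr * weight_decay;

  std::printf("fp16-like copy: %.7f\n", weight_fp16); // prints 1.0000000
  std::printf("fp32 master   : %.7f\n", master_fp32); // prints 0.9999900
  return 0;
}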

test/unittest/unittest_nntrainer_internal.cpp

Lines changed: 32 additions & 0 deletions
@@ -124,6 +124,38 @@ TEST(nntrainer_Optimizer, create_09_n) {
   EXPECT_ANY_THROW(op = ac->createOptimizerObject("non-existing type", {}));
 }
 
+/**
+ * @brief Optimizer create
+ */
+TEST(nntrainer_Optimizer, create_adamw_01_p) {
+  std::unique_ptr<nntrainer::Optimizer> op;
+  auto &eg = nntrainer::Engine::Global();
+  auto ac = eg.getRegisteredContext("cpu");
+  EXPECT_NO_THROW(op =
+                    ac->createOptimizerObject("adamw", {"weight_decay=0.01"}));
+}
+
+/**
+ * @brief Optimizer create
+ */
+TEST(nntrainer_Optimizer, create_adamw_02_n) {
+  std::unique_ptr<nntrainer::Optimizer> op;
+  auto &eg = nntrainer::Engine::Global();
+  auto ac = eg.getRegisteredContext("cpu");
+  EXPECT_ANY_THROW(op = ac->createOptimizerObject("adamw", {"unknown"}));
+}
+
+/**
+ * @brief Optimizer create
+ */
+TEST(nntrainer_Optimizer, create_adamw_03_n) {
+  std::unique_ptr<nntrainer::Optimizer> op;
+  auto &eg = nntrainer::Engine::Global();
+  auto ac = eg.getRegisteredContext("cpu");
+  EXPECT_ANY_THROW(op =
+                    ac->createOptimizerObject("adamw", {"learning_rate:0.1"}));
+}
+
 TEST(nntrainer_throw_if, throw_invalid_arg_p) {
   try {
     NNTR_THROW_IF(1 == 1, std::invalid_argument) << "error msg";
