 
 namespace nntrainer {
 
-AdamW::AdamW() : adam_props(PropsB1(), PropsB2(), PropsEpsilon(), TorchRef()) {
+AdamW::AdamW() :
+  adam_props(PropsB1(), PropsB2(), PropsEpsilon(), TorchRef(),
+             PropsWeightDecayW()) {
   /** default properties */
-  auto &[b1, b2, eps, torch_ref] = adam_props;
+  auto &[b1, b2, eps, torch_ref, weight_decay] = adam_props;
   b1.set(0.9f);
   b2.set(0.999f);
-  eps.set(1.0e-7f);
+  eps.set(1.0e-8f);
   torch_ref.set(false);
+  weight_decay.set(0.0f);
 }
 
 AdamW::~AdamW() {}
 
 enum AdamParams { wm, wv };
 
 std::vector<TensorDim> AdamW::getOptimizerVariableDim(const TensorDim &dim) {
-  return {dim, dim};
+  /**
+   * @note We assume the optimizer parameters should be full precision to
+   * maintain the accuracy even in mixed precision training.
+   */
+  TensorDim wm_dim(dim);
+  TensorDim wv_dim(dim);
+  wm_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+  wv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+  return {wm_dim, wv_dim};
 }
 
 void AdamW::exportTo(Exporter &exporter,
@@ -51,6 +62,14 @@ void AdamW::setProperty(const std::vector<std::string> &values) {
   Optimizer::setProperty(left);
 }
 
+double AdamW::getUpdatedLearningRate(unsigned int iteration, double lr) const {
+  auto &beta1 = std::get<PropsB1>(adam_props).get();
+  auto &beta2 = std::get<PropsB2>(adam_props).get();
+  auto biasCorrection = [&](double f) { return 1.0 - pow(f, iteration + 1); };
+  lr *= sqrt(biasCorrection(beta2)) / biasCorrection(beta1);
+  return lr;
+}
+
 void AdamW::applyGradient(RunOptimizerContext &context) {
   Tensor empty_tensor;
 
@@ -68,13 +87,8 @@ void AdamW::applyGradient(RunOptimizerContext &context) {
   auto &beta1 = std::get<PropsB1>(adam_props).get();
   auto &beta2 = std::get<PropsB2>(adam_props).get();
   auto &epsilon = std::get<PropsEpsilon>(adam_props).get();
-  auto &torch_ref = std::get<TorchRef>(adam_props).get();
+  auto &weight_decay = std::get<PropsWeightDecayW>(adam_props).get();
 
-  // This is implementation of adam from original paper.
-  // This is not deleted intentionally.
-  unsigned int iteration = context.getIteration();
-  float biasCorrection1 = 1 - pow(beta1, iteration + 1);
-  float biasCorrection2 = 1 - pow(beta2, iteration + 1);
   Tensor &wm = context.getOptimizerVariable(AdamParams::wm);
   Tensor &wv = context.getOptimizerVariable(AdamParams::wv);
 
@@ -84,16 +98,23 @@ void AdamW::applyGradient(RunOptimizerContext &context) {
   wv.multiply_i(beta2);
   wv.add_i(x_grad.multiply(x_grad), 1.0f - beta2);
 
-  wv.divide_i(biasCorrection2);
+  // Decoupled weight decay: w = w - lr * wd * w
+  if (weight_decay > 0.0) {
+    Tensor &w = context.isMixedPrecision() ? context.getWeightFP32()
+                                           : context.getWeight();
+    w.multiply_i(1.0f - (context.getLearningRate() * weight_decay));
+  }
+
+  // Adam update with bias-corrected lr
+  double lr_t =
+    getUpdatedLearningRate(context.getIteration(), context.getLearningRate());
+
   std::function<double(double)> sqrtEps = [epsilon](double f) {
     return 1 / (sqrtDouble(f) + epsilon);
   };
   x_grad = wv.apply<float>(sqrtEps, x_grad);
-  x_grad.divide_i(biasCorrection1);
   x_grad.multiply_i(wm);
-  context.calcWeightDecayGradient();
-
-  context.applyGradient(context.getLearningRate(), x_grad);
+  context.applyGradient(lr_t, x_grad);
 }
 
 } // namespace nntrainer
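
For reference, a minimal self-contained sketch of the update rule this diff implements: decoupled weight decay applied directly to the weight, then an Adam step with a bias-corrected learning rate. The scalar names (w, grad, m, v) and the hyperparameter values are illustrative only, not part of the nntrainer API; the defaults mirror the constructor above.

// Sketch of one AdamW step for a single scalar weight.
// Not nntrainer code; shows the math behind applyGradient()/getUpdatedLearningRate().
#include <cmath>
#include <cstdio>

int main() {
  double w = 1.0, grad = 0.5;                       // weight and its gradient
  double m = 0.0, v = 0.0;                          // first/second moment estimates
  const double lr = 1e-3, wd = 1e-2;                // base lr and weight decay (assumed values)
  const double beta1 = 0.9, beta2 = 0.999, eps = 1.0e-8;

  for (unsigned int iter = 0; iter < 3; ++iter) {
    // decoupled weight decay: shrink the weight itself, not the gradient
    w *= 1.0 - lr * wd;

    // moment updates, identical to plain Adam
    m = beta1 * m + (1.0 - beta1) * grad;
    v = beta2 * v + (1.0 - beta2) * grad * grad;

    // bias-corrected learning rate, as computed by getUpdatedLearningRate()
    double lr_t = lr * std::sqrt(1.0 - std::pow(beta2, iter + 1)) /
                  (1.0 - std::pow(beta1, iter + 1));

    // parameter update
    w -= lr_t * m / (std::sqrt(v) + eps);
    std::printf("iter %u: w = %f\n", iter, w);
  }
  return 0;
}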