1111#include " solver.h" // NOLINT
1212
1313namespace cppoptlib ::solver {
namespace internal {

// Shifts the columns of `matrix` one position to the left, discarding the
// first column; the last column keeps its previous value. The `.eval()`
// forces a temporary so the overlapping self-assignment does not alias.
// NOTE(review): the circular-buffer history update in Lbfgs no longer calls
// this helper; it is kept only for backward compatibility and can be removed
// once no other translation unit references it.
template <int m, class T>
void ShiftLeft(T *matrix) {
  matrix->leftCols(m - 1) = matrix->rightCols(m - 1).eval();
}

}  // namespace internal
2214
2315template <typename function_t , int m = 10 >
2416class Lbfgs : public Solver <function_t > {
25- static_assert (function_t ::DiffLevel ==
26- cppoptlib::function::Differentiability::First ||
27- function_t ::DiffLevel ==
28- cppoptlib::function::Differentiability::Second,
29- " L-BFGS only supports first- or second-order "
30- " differentiable functions" );
17+ static_assert (
18+ function_t ::DiffLevel == cppoptlib::function::Differentiability::First ||
19+ function_t ::DiffLevel ==
20+ cppoptlib::function::Differentiability::Second,
21+ " L-BFGS only supports first- or second-order differentiable functions" );
3122
3223 private:
3324 using Superclass = Solver<function_t >;
3425 using progress_t = typename Superclass::progress_t ;
3526 using state_t = typename function_t ::state_t ;
36-
3727 using scalar_t = typename function_t ::scalar_t ;
38- using hessian_t = typename function_t ::matrix_t ;
39- using matrix_t = typename function_t ::matrix_t ;
4028 using vector_t = typename function_t ::vector_t ;
29+ using matrix_t = typename function_t ::matrix_t ;
4130
31+ // Storage for the correction pairs using Eigen matrices.
4232 using memory_matrix_t = Eigen::Matrix<scalar_t , Eigen::Dynamic, m>;
4333 using memory_vector_t = Eigen::Matrix<scalar_t , 1 , m>;
4434
4535 public:
4636 EIGEN_MAKE_ALIGNED_OPERATOR_NEW
47-
4837 using Superclass::Superclass;
4938
5039 void InitializeSolver (const state_t &initial_state) override {
51- const size_t dim_ = initial_state.x .rows ();
52- x_diff_memory_ = memory_matrix_t::Zero (dim_, m);
53- grad_diff_memory_ = memory_matrix_t::Zero (dim_, m);
54- alpha = memory_vector_t::Zero (m);
55- memory_idx_ = 0 ;
40+ const size_t dim = initial_state.x .rows ();
41+ x_diff_memory_ = memory_matrix_t::Zero (dim, m);
42+ grad_diff_memory_ = memory_matrix_t::Zero (dim, m);
43+ alpha.resize (m);
44+ // Reset the circular buffer:
45+ mem_count_ = 0 ;
46+ mem_pos_ = 0 ;
5647 scaling_factor_ = 1 ;
5748 }
5849
5950 state_t OptimizationStep (const function_t &function, const state_t ¤t,
6051 const progress_t &progress) override {
61- vector_t search_direction = current.gradient ;
62-
6352 constexpr scalar_t eps = std::numeric_limits<scalar_t >::epsilon ();
6453 const scalar_t relative_eps =
6554 static_cast <scalar_t >(eps) *
6655 std::max<scalar_t >(scalar_t {1.0 }, current.x .norm ());
6756
68- // Algorithm 7.4 (L-BFGS two-loop recursion)
69- // // Determine how many stored corrections to use (up to m but not more
70- // than available).
71- int k = 0 ;
72- if (progress.num_iterations > 0 ) {
73- k = std::min<int >(m, memory_idx_ - 1 );
57+ // --- Preconditioning ---
58+ // If second-order information is available, use a diagonal preconditioner.
59+ vector_t precond = vector_t::Ones (current.x .size ());
60+ if constexpr (function_t ::DiffLevel ==
61+ cppoptlib::function::Differentiability::Second) {
62+ precond = current.hessian .diagonal ().cwiseAbs ().array () + eps;
63+ precond = precond.cwiseInverse ();
7464 }
65+ // Precondition the gradient.
66+ vector_t grad_precond = precond.asDiagonal () * current.gradient ;
67+
68+ // --- Two-Loop Recursion ---
69+ // Start with the preconditioned gradient as the initial search direction.
70+ vector_t search_direction = grad_precond;
71+
72+ // Determine the number of corrections available for the two-loop recursion.
73+ // We exclude the most recent correction (which was just computed) from use.
74+ int k = (mem_count_ > 0 ? static_cast <int >(mem_count_) - 1 : 0 );
7575
76- // First loop (backward pass) for the L-BFGS two-loop recursion.
76+ // --- First Loop (Backward Pass) ---
77+ // Iterate over stored corrections in reverse chronological order.
7778 for (int i = k - 1 ; i >= 0 ; i--) {
78- // alpha_i <- rho_i*s_i^T*q
79+ // Compute the index in chronological order.
80+ // When mem_count_ < m, corrections are stored in order [0 ...
81+ // mem_count_-1]. When full, they are stored cyclically starting at
82+ // mem_pos_ (oldest) up to (mem_pos_ + m - 1) mod m.
83+ int idx = (mem_count_ < m ? i : ((mem_pos_ + i) % m));
7984 const scalar_t denom =
80- x_diff_memory_.col (i ).dot (grad_diff_memory_.col (i ));
85+ x_diff_memory_.col (idx ).dot (grad_diff_memory_.col (idx ));
8186 if (std::abs (denom) < eps) {
8287 continue ;
8388 }
8489 const scalar_t rho = 1.0 / denom;
85- alpha (i) = rho * x_diff_memory_.col (i).dot (search_direction);
86- // q <- q - alpha_i*y_i
87- search_direction -= alpha (i) * grad_diff_memory_.col (i);
90+ alpha (i) = rho * x_diff_memory_.col (idx).dot (search_direction);
91+ search_direction -= alpha (i) * grad_diff_memory_.col (idx);
8892 }
8993
90- // apply initial Hessian approximation: r <- H_k^0*q
94+ // Apply the initial Hessian approximation.
9195 search_direction *= scaling_factor_;
9296
93- // Second loop (forward pass).
97+ // --- Second Loop (Forward Pass) ---
9498 for (int i = 0 ; i < k; i++) {
95- // beta <- rho_i * y_i^T * r
99+ int idx = (mem_count_ < m ? i : ((mem_pos_ + i) % m));
96100 const scalar_t denom =
97- x_diff_memory_.col (i ).dot (grad_diff_memory_.col (i ));
101+ x_diff_memory_.col (idx ).dot (grad_diff_memory_.col (idx ));
98102 if (std::abs (denom) < eps) {
99103 continue ;
100104 }
101105 const scalar_t rho = 1.0 / denom;
102106 const scalar_t beta =
103- rho * grad_diff_memory_.col (i).dot (search_direction);
104- // r <- r + s_i * ( alpha_i - beta)
105- search_direction += x_diff_memory_.col (i) * (alpha (i) - beta);
107+ rho * grad_diff_memory_.col (idx).dot (search_direction);
108+ search_direction += x_diff_memory_.col (idx) * (alpha (i) - beta);
106109 }
107110
108- // stop with result "H_k*f_f'=q"
109-
110- // any issues with the descent direction ?
111- // Check the descent direction for validity.
111+ // Check descent direction validity.
112112 scalar_t descent_direction = -current.gradient .dot (search_direction);
113113 scalar_t alpha_init =
114114 (current.gradient .norm () > eps) ? 1.0 / current.gradient .norm () : 1.0 ;
115115 if (!std::isfinite (descent_direction) ||
116116 descent_direction > -eps * relative_eps) {
117- // If the descent direction is invalid or not a descent, revert to
118- // steepest descent.
119- search_direction = -current.gradient .eval ();
120- memory_idx_ = 0 ;
117+ // Fall back to steepest descent if necessary.
118+ search_direction = -current.gradient ;
119+ // Reset the correction history if the descent is invalid.
120+ mem_count_ = 0 ;
121+ mem_pos_ = 0 ;
121122 alpha_init = 1.0 ;
122123 }
123124
125+ // Perform a line search.
124126 const scalar_t rate = linesearch::MoreThuente<function_t , 1 >::Search (
125127 current.x , -search_direction, function, alpha_init);
126128
127129 const state_t next = function.GetState (current.x - rate * search_direction);
128130
131+ // Compute the differences for the new correction pair.
129132 const vector_t x_diff = next.x - current.x ;
130133 const vector_t grad_diff = next.gradient - current.gradient ;
131134
132- // Update the history
133- if (x_diff.dot (grad_diff) > eps * grad_diff.squaredNorm ()) {
134- if (memory_idx_ < m) {
135- x_diff_memory_.col (memory_idx_) = x_diff.eval ();
136- grad_diff_memory_.col (memory_idx_) = grad_diff.eval ();
135+ // --- Curvature Condition Check with Cautious Update ---
136+ // We require:
137+ // x_diff.dot(grad_diff) > ||x_diff||^2 * ||current.gradient|| *
138+ // cautious_factor_
139+ const scalar_t threshold =
140+ x_diff.squaredNorm () * current.gradient .norm () * cautious_factor_;
141+ if (x_diff.dot (grad_diff) > threshold) {
142+ // Add the new correction pair into the circular buffer.
143+ if (mem_count_ < static_cast <size_t >(m)) {
144+ // Still have free space.
145+ x_diff_memory_.col (mem_count_) = x_diff;
146+ grad_diff_memory_.col (mem_count_) = grad_diff;
147+ mem_count_++;
137148 } else {
138- internal::ShiftLeft<m>(&x_diff_memory_);
139- internal::ShiftLeft<m>(&grad_diff_memory_);
140-
141- x_diff_memory_.rightCols (1 ) = x_diff;
142- grad_diff_memory_.rightCols (1 ) = grad_diff;
149+ // Buffer full; overwrite the oldest correction.
150+ x_diff_memory_.col (mem_pos_) = x_diff;
151+ grad_diff_memory_.col (mem_pos_) = grad_diff;
152+ mem_pos_ = (mem_pos_ + 1 ) % m;
143153 }
144-
145- memory_idx_++;
146154 }
147- // Adaptive damping in Hessian approximation: update the scaling factor.
148- constexpr scalar_t fallback_value =
149- scalar_t (1e7 ); // Fallback value if update is unstable.
155+ // Update the scaling factor (adaptive damping).
156+ constexpr scalar_t fallback_value = scalar_t (1e7 );
150157 const scalar_t grad_diff_norm_sq = grad_diff.dot (grad_diff);
151158 if (std::abs (grad_diff_norm_sq) > eps) {
152159 scalar_t temp_scaling = grad_diff.dot (x_diff) / grad_diff_norm_sq;
153- // If temp_scaling is non-finite or excessively large, use fallback.
154160 if (!std::isfinite (temp_scaling) ||
155161 std::abs (temp_scaling) > fallback_value) {
156162 scaling_factor_ = fallback_value;
@@ -167,10 +173,16 @@ class Lbfgs : public Solver<function_t> {
167173 private:
168174 memory_matrix_t x_diff_memory_;
169175 memory_matrix_t grad_diff_memory_;
170- size_t memory_idx_ = 0 ;
176+ // Circular buffer state:
177+ size_t mem_count_ = 0 ; // Number of corrections stored so far (max m).
178+ size_t mem_pos_ = 0 ; // Index of the oldest correction in the buffer.
171179
172- memory_vector_t alpha;
180+ memory_vector_t
181+ alpha; // Storage for the coefficients in the two-loop recursion.
173182 scalar_t scaling_factor_ = 1 ;
183+ // Cautious factor to determine whether to accept a new correction pair.
184+ // You may want to expose this parameter or adjust its default value.
185+ scalar_t cautious_factor_ = 1e-6 ;
174186};
175187
176188} // namespace cppoptlib::solver
0 commit comments