3333
3434#include " ../linesearch/more_thuente.h"
3535#include " Eigen/Core"
36- #include " solver.h" // NOLINT
36+ #include " solver.h" // NOLINT
3737
3838namespace cppoptlib ::solver {
3939
@@ -49,7 +49,7 @@ class Lbfgs
4949 cppoptlib::function::DifferentiabilityMode::Second,
5050 " L-BFGS only supports first- or second-order differentiable functions" );
5151
52- private:
52+ private:
5353 using StateType = typename cppoptlib::function::FunctionState<
5454 typename FunctionType::ScalarType, FunctionType::Dimension>;
5555 using Superclass = Solver<FunctionType, StateType>;
@@ -63,7 +63,7 @@ class Lbfgs
6363 using memory_MatrixType = Eigen::Matrix<ScalarType, Eigen::Dynamic, m>;
6464 using memory_VectorType = Eigen::Matrix<ScalarType, 1 , m>;
6565
66- public:
66+ public:
6767 EIGEN_MAKE_ALIGNED_OPERATOR_NEW
6868 using Superclass::Superclass;
6969
@@ -107,43 +107,52 @@ class Lbfgs
107107 // Start with the preconditioned gradient as the initial search direction.
108108 VectorType search_direction = grad_precond;
109109
110- // Determine the number of corrections available for the two-loop recursion.
111- // We exclude the most recent correction (which was just computed) from use.
112- int k = (mem_count_ > 0 ? static_cast <int >(mem_count_) - 1 : 0 );
110+ // Determine the actual number of stored corrections to use
111+ const int k = static_cast <int >(mem_count_);
113112
114- // --- First Loop (Backward Pass) ---
115- // Iterate over stored corrections in reverse chronological order.
113+ // First loop: computes q = q - alpha_i * y_i
114+ // Iterates from the newest correction (k-1) to the oldest (k-m_actual)
115+ // conceptual_idx refers to the chronological order: 0=oldest,
116+ // num_valid_corrections-1=newest
116117 for (int i = k - 1 ; i >= 0 ; i--) {
117118 // Compute the index in chronological order.
118119 // When mem_count_ < m, corrections are stored in order [0 ...
119120 // mem_count_-1]. When full, they are stored cyclically starting at
120121 // mem_pos_ (oldest) up to (mem_pos_ + m - 1) mod m.
121- int idx = (mem_count_ < m ? i : ((mem_pos_ + i) % m));
122- const ScalarType denom =
123- x_diff_memory_.col (idx).dot (grad_diff_memory_.col (idx));
124- if (std::abs (denom) < eps) {
122+ const int idx = (mem_count_ < m) ? i : (mem_pos_ + i) % m;
123+
124+ const VectorType &s_col = x_diff_memory_.col (idx);
125+ const VectorType &y_col = grad_diff_memory_.col (idx);
126+
127+ const ScalarType s_dot_y = s_col.dot (y_col);
128+ if (std::abs (s_dot_y) < eps) { // Avoid division by zero or near-zero
125129 continue ;
126130 }
127- const ScalarType rho = 1.0 / denom ;
128- alpha (i) = rho * x_diff_memory_. col (idx) .dot (search_direction);
129- search_direction -= alpha (i) * grad_diff_memory_. col (idx) ;
131+ const ScalarType rho_val = static_cast <ScalarType>( 1.0 ) / s_dot_y ;
132+ alpha (i) = rho_val * s_col .dot (search_direction);
133+ search_direction -= alpha (i) * y_col ;
130134 }
131135
132- // Apply the initial Hessian approximation.
136+ // Apply the initial Hessian approximation H_k^0 = gamma_k * I
137+ // gamma_k = s_{k-1}^T y_{k-1} / (y_{k-1}^T y_{k-1})
138+ // Here, scaling_factor_ is this gamma_k from the *previous* iteration.
133139 search_direction *= scaling_factor_;
134140
135- // --- Second Loop (Forward Pass) ---
141+ // Second loop: computes r = r + s_i * (alpha_i - beta_i)
142+ // Iterates from the oldest correction (k-m_actual) to the newest (k-1)
136143 for (int i = 0 ; i < k; i++) {
137- int idx = (mem_count_ < m ? i : ((mem_pos_ + i) % m));
138- const ScalarType denom =
139- x_diff_memory_.col (idx).dot (grad_diff_memory_.col (idx));
140- if (std::abs (denom) < eps) {
144+ const int idx = (mem_count_ < m) ? i : (mem_pos_ + i) % m;
145+
146+ const VectorType &s_col = x_diff_memory_.col (idx);
147+ const VectorType &y_col = grad_diff_memory_.col (idx);
148+
149+ const ScalarType s_dot_y = s_col.dot (y_col);
150+ if (std::abs (s_dot_y) < eps) {
141151 continue ;
142152 }
143- const ScalarType rho = 1.0 / denom;
144- const ScalarType beta =
145- rho * grad_diff_memory_.col (idx).dot (search_direction);
146- search_direction += x_diff_memory_.col (idx) * (alpha (i) - beta);
153+ const ScalarType rho_val = static_cast <ScalarType>(1.0 ) / s_dot_y;
154+ const ScalarType beta = rho_val * y_col.dot (search_direction);
155+ search_direction += s_col * (alpha (i) - beta);
147156 }
148157
149158 // Check descent direction validity.
@@ -210,21 +219,21 @@ class Lbfgs
210219 return next;
211220 }
212221
213- private:
222+ private:
214223 memory_MatrixType x_diff_memory_;
215224 memory_MatrixType grad_diff_memory_;
216225 // Circular buffer state:
217- size_t mem_count_ = 0 ; // Number of corrections stored so far (max m).
218- size_t mem_pos_ = 0 ; // Index of the oldest correction in the buffer.
226+ size_t mem_count_ = 0 ; // Number of corrections stored so far (max m).
227+ size_t mem_pos_ = 0 ; // Index of the oldest correction in the buffer.
219228
220229 memory_VectorType
221- alpha; // Storage for the coefficients in the two-loop recursion.
230+ alpha; // Storage for the coefficients in the two-loop recursion.
222231 ScalarType scaling_factor_ = 1 ;
223232 // Cautious factor to determine whether to accept a new correction pair.
224233 // You may want to expose this parameter or adjust its default value.
225234 ScalarType cautious_factor_ = 1e-6 ;
226235};
227236
228- } // namespace cppoptlib::solver
237+ } // namespace cppoptlib::solver
229238
230- #endif // INCLUDE_CPPOPTLIB_SOLVER_LBFGS_H_
239+ #endif // INCLUDE_CPPOPTLIB_SOLVER_LBFGS_H_
0 commit comments