Skip to content

Commit b2dd416

Browse files
committed
Fix off-by-one error in L-BFGS Solver Implementation
The two-loop recursion skipped the most recent stored correction when accessing the memory.
1 parent 1435ef7 commit b2dd416

File tree

1 file changed

+40
-31
lines changed

1 file changed

+40
-31
lines changed

include/cppoptlib/solver/lbfgs.h

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
#include "../linesearch/more_thuente.h"
3535
#include "Eigen/Core"
36-
#include "solver.h" // NOLINT
36+
#include "solver.h" // NOLINT
3737

3838
namespace cppoptlib::solver {
3939

@@ -49,7 +49,7 @@ class Lbfgs
4949
cppoptlib::function::DifferentiabilityMode::Second,
5050
"L-BFGS only supports first- or second-order differentiable functions");
5151

52-
private:
52+
private:
5353
using StateType = typename cppoptlib::function::FunctionState<
5454
typename FunctionType::ScalarType, FunctionType::Dimension>;
5555
using Superclass = Solver<FunctionType, StateType>;
@@ -63,7 +63,7 @@ class Lbfgs
6363
using memory_MatrixType = Eigen::Matrix<ScalarType, Eigen::Dynamic, m>;
6464
using memory_VectorType = Eigen::Matrix<ScalarType, 1, m>;
6565

66-
public:
66+
public:
6767
EIGEN_MAKE_ALIGNED_OPERATOR_NEW
6868
using Superclass::Superclass;
6969

@@ -107,43 +107,52 @@ class Lbfgs
107107
// Start with the preconditioned gradient as the initial search direction.
108108
VectorType search_direction = grad_precond;
109109

110-
// Determine the number of corrections available for the two-loop recursion.
111-
// We exclude the most recent correction (which was just computed) from use.
112-
int k = (mem_count_ > 0 ? static_cast<int>(mem_count_) - 1 : 0);
110+
// Determine the actual number of stored corrections to use
111+
const int k = static_cast<int>(mem_count_);
113112

114-
// --- First Loop (Backward Pass) ---
115-
// Iterate over stored corrections in reverse chronological order.
113+
// First loop: computes q = q - alpha_i * y_i
114+
// Iterates from the newest correction (i = k-1) to the oldest (i = 0).
115+
// i is the chronological index of a correction: 0=oldest,
116+
// k-1=newest
116117
for (int i = k - 1; i >= 0; i--) {
117118
// Compute the index in chronological order.
118119
// When mem_count_ < m, corrections are stored in order [0 ...
119120
// mem_count_-1]. When full, they are stored cyclically starting at
120121
// mem_pos_ (oldest) up to (mem_pos_ + m - 1) mod m.
121-
int idx = (mem_count_ < m ? i : ((mem_pos_ + i) % m));
122-
const ScalarType denom =
123-
x_diff_memory_.col(idx).dot(grad_diff_memory_.col(idx));
124-
if (std::abs(denom) < eps) {
122+
const int idx = (mem_count_ < m) ? i : (mem_pos_ + i) % m;
123+
124+
const VectorType &s_col = x_diff_memory_.col(idx);
125+
const VectorType &y_col = grad_diff_memory_.col(idx);
126+
127+
const ScalarType s_dot_y = s_col.dot(y_col);
128+
if (std::abs(s_dot_y) < eps) { // Avoid division by zero or near-zero
125129
continue;
126130
}
127-
const ScalarType rho = 1.0 / denom;
128-
alpha(i) = rho * x_diff_memory_.col(idx).dot(search_direction);
129-
search_direction -= alpha(i) * grad_diff_memory_.col(idx);
131+
const ScalarType rho_val = static_cast<ScalarType>(1.0) / s_dot_y;
132+
alpha(i) = rho_val * s_col.dot(search_direction);
133+
search_direction -= alpha(i) * y_col;
130134
}
131135

132-
// Apply the initial Hessian approximation.
136+
// Apply the initial Hessian approximation H_k^0 = gamma_k * I
137+
// gamma_k = s_{k-1}^T y_{k-1} / (y_{k-1}^T y_{k-1})
138+
// Here, scaling_factor_ is this gamma_k from the *previous* iteration.
133139
search_direction *= scaling_factor_;
134140

135-
// --- Second Loop (Forward Pass) ---
141+
// Second loop: computes r = r + s_i * (alpha_i - beta_i)
142+
// Iterates from the oldest correction (i = 0) to the newest (i = k-1)
136143
for (int i = 0; i < k; i++) {
137-
int idx = (mem_count_ < m ? i : ((mem_pos_ + i) % m));
138-
const ScalarType denom =
139-
x_diff_memory_.col(idx).dot(grad_diff_memory_.col(idx));
140-
if (std::abs(denom) < eps) {
144+
const int idx = (mem_count_ < m) ? i : (mem_pos_ + i) % m;
145+
146+
const VectorType &s_col = x_diff_memory_.col(idx);
147+
const VectorType &y_col = grad_diff_memory_.col(idx);
148+
149+
const ScalarType s_dot_y = s_col.dot(y_col);
150+
if (std::abs(s_dot_y) < eps) {
141151
continue;
142152
}
143-
const ScalarType rho = 1.0 / denom;
144-
const ScalarType beta =
145-
rho * grad_diff_memory_.col(idx).dot(search_direction);
146-
search_direction += x_diff_memory_.col(idx) * (alpha(i) - beta);
153+
const ScalarType rho_val = static_cast<ScalarType>(1.0) / s_dot_y;
154+
const ScalarType beta = rho_val * y_col.dot(search_direction);
155+
search_direction += s_col * (alpha(i) - beta);
147156
}
148157

149158
// Check descent direction validity.
@@ -210,21 +219,21 @@ class Lbfgs
210219
return next;
211220
}
212221

213-
private:
222+
private:
214223
memory_MatrixType x_diff_memory_;
215224
memory_MatrixType grad_diff_memory_;
216225
// Circular buffer state:
217-
size_t mem_count_ = 0; // Number of corrections stored so far (max m).
218-
size_t mem_pos_ = 0; // Index of the oldest correction in the buffer.
226+
size_t mem_count_ = 0; // Number of corrections stored so far (max m).
227+
size_t mem_pos_ = 0; // Index of the oldest correction in the buffer.
219228

220229
memory_VectorType
221-
alpha; // Storage for the coefficients in the two-loop recursion.
230+
alpha; // Storage for the coefficients in the two-loop recursion.
222231
ScalarType scaling_factor_ = 1;
223232
// Cautious factor to determine whether to accept a new correction pair.
224233
// You may want to expose this parameter or adjust its default value.
225234
ScalarType cautious_factor_ = 1e-6;
226235
};
227236

228-
} // namespace cppoptlib::solver
237+
} // namespace cppoptlib::solver
229238

230-
#endif // INCLUDE_CPPOPTLIB_SOLVER_LBFGS_H_
239+
#endif // INCLUDE_CPPOPTLIB_SOLVER_LBFGS_H_

0 commit comments

Comments
 (0)