Commit a7af6ed

Further improve L-BFGS with pre-conditioning and circular buffer
To avoid copies, we instead compute the correct indices. We also make use of the Hessian if available.
1 parent 4023c49 commit a7af6ed
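The shift-free update described in the message can be sketched in isolation. Previously, every accepted correction pair triggered internal::ShiftLeft, copying m - 1 columns of each memory matrix; the patch instead tracks how many corrections are stored (mem_count_) and where the oldest one sits (mem_pos_), and maps a chronological index to a physical column on the fly. A minimal standalone sketch of that mapping, storing plain ints rather than correction pairs purely for illustration:

```cpp
#include <array>
#include <cstdio>

constexpr int m = 4;  // buffer capacity, as in the solver's template parameter

int main() {
  std::array<int, m> buffer{};
  int mem_count = 0;  // number of entries stored so far (at most m)
  int mem_pos = 0;    // index of the oldest entry once the buffer is full

  for (int value = 0; value < 7; ++value) {
    if (mem_count < m) {
      buffer[mem_count++] = value;  // still free space: append
    } else {
      buffer[mem_pos] = value;      // full: overwrite the oldest entry
      mem_pos = (mem_pos + 1) % m;
    }
  }

  // Walk the entries oldest-first, using the same index mapping as the patch:
  // [0 .. mem_count-1] while filling, (mem_pos + i) % m once full.
  for (int i = 0; i < mem_count; ++i) {
    const int idx = (mem_count < m) ? i : (mem_pos + i) % m;
    std::printf("chronological %d -> column %d holds %d\n", i, idx, buffer[idx]);
  }
  return 0;
}
```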

2 files changed: +83, -71 lines

include/cppoptlib/solver/lbfgs.h

Lines changed: 82 additions & 70 deletions
```diff
@@ -11,146 +11,152 @@
 #include "solver.h"  // NOLINT
 
 namespace cppoptlib::solver {
-namespace internal {
-
-template <int m, class T>
-void ShiftLeft(T *matrix) {
-  matrix->leftCols(m - 1) = matrix->rightCols(m - 1).eval();
-}
-
-}  // namespace internal
 
 template <typename function_t, int m = 10>
 class Lbfgs : public Solver<function_t> {
-  static_assert(function_t::DiffLevel ==
-                        cppoptlib::function::Differentiability::First ||
-                    function_t::DiffLevel ==
-                        cppoptlib::function::Differentiability::Second,
-                "L-BFGS only supports first- or second-order "
-                "differentiable functions");
+  static_assert(
+      function_t::DiffLevel == cppoptlib::function::Differentiability::First ||
+          function_t::DiffLevel ==
+              cppoptlib::function::Differentiability::Second,
+      "L-BFGS only supports first- or second-order differentiable functions");
 
  private:
   using Superclass = Solver<function_t>;
   using progress_t = typename Superclass::progress_t;
   using state_t = typename function_t::state_t;
-
   using scalar_t = typename function_t::scalar_t;
-  using hessian_t = typename function_t::matrix_t;
-  using matrix_t = typename function_t::matrix_t;
   using vector_t = typename function_t::vector_t;
+  using matrix_t = typename function_t::matrix_t;
 
+  // Storage for the correction pairs using Eigen matrices.
   using memory_matrix_t = Eigen::Matrix<scalar_t, Eigen::Dynamic, m>;
   using memory_vector_t = Eigen::Matrix<scalar_t, 1, m>;
 
  public:
   EIGEN_MAKE_ALIGNED_OPERATOR_NEW
-
   using Superclass::Superclass;
 
   void InitializeSolver(const state_t &initial_state) override {
-    const size_t dim_ = initial_state.x.rows();
-    x_diff_memory_ = memory_matrix_t::Zero(dim_, m);
-    grad_diff_memory_ = memory_matrix_t::Zero(dim_, m);
-    alpha = memory_vector_t::Zero(m);
-    memory_idx_ = 0;
+    const size_t dim = initial_state.x.rows();
+    x_diff_memory_ = memory_matrix_t::Zero(dim, m);
+    grad_diff_memory_ = memory_matrix_t::Zero(dim, m);
+    alpha.resize(m);
+    // Reset the circular buffer:
+    mem_count_ = 0;
+    mem_pos_ = 0;
     scaling_factor_ = 1;
   }
 
   state_t OptimizationStep(const function_t &function, const state_t &current,
                            const progress_t &progress) override {
-    vector_t search_direction = current.gradient;
-
     constexpr scalar_t eps = std::numeric_limits<scalar_t>::epsilon();
     const scalar_t relative_eps =
         static_cast<scalar_t>(eps) *
         std::max<scalar_t>(scalar_t{1.0}, current.x.norm());
 
-    // Algorithm 7.4 (L-BFGS two-loop recursion)
-    // // Determine how many stored corrections to use (up to m but not more
-    // than available).
-    int k = 0;
-    if (progress.num_iterations > 0) {
-      k = std::min<int>(m, memory_idx_ - 1);
+    // --- Preconditioning ---
+    // If second-order information is available, use a diagonal preconditioner.
+    vector_t precond = vector_t::Ones(current.x.size());
+    if constexpr (function_t::DiffLevel ==
+                  cppoptlib::function::Differentiability::Second) {
+      precond = current.hessian.diagonal().cwiseAbs().array() + eps;
+      precond = precond.cwiseInverse();
     }
+    // Precondition the gradient.
+    vector_t grad_precond = precond.asDiagonal() * current.gradient;
+
+    // --- Two-Loop Recursion ---
+    // Start with the preconditioned gradient as the initial search direction.
+    vector_t search_direction = grad_precond;
+
+    // Determine the number of corrections available for the two-loop recursion.
+    // We exclude the most recent correction (which was just computed) from use.
+    int k = (mem_count_ > 0 ? static_cast<int>(mem_count_) - 1 : 0);
 
-    // First loop (backward pass) for the L-BFGS two-loop recursion.
+    // --- First Loop (Backward Pass) ---
+    // Iterate over stored corrections in reverse chronological order.
     for (int i = k - 1; i >= 0; i--) {
-      // alpha_i <- rho_i*s_i^T*q
+      // Compute the index in chronological order.
+      // When mem_count_ < m, corrections are stored in order [0 ...
+      // mem_count_-1]. When full, they are stored cyclically starting at
+      // mem_pos_ (oldest) up to (mem_pos_ + m - 1) mod m.
+      int idx = (mem_count_ < m ? i : ((mem_pos_ + i) % m));
       const scalar_t denom =
-          x_diff_memory_.col(i).dot(grad_diff_memory_.col(i));
+          x_diff_memory_.col(idx).dot(grad_diff_memory_.col(idx));
       if (std::abs(denom) < eps) {
         continue;
       }
       const scalar_t rho = 1.0 / denom;
-      alpha(i) = rho * x_diff_memory_.col(i).dot(search_direction);
-      // q <- q - alpha_i*y_i
-      search_direction -= alpha(i) * grad_diff_memory_.col(i);
+      alpha(i) = rho * x_diff_memory_.col(idx).dot(search_direction);
+      search_direction -= alpha(i) * grad_diff_memory_.col(idx);
     }
 
-    // apply initial Hessian approximation: r <- H_k^0*q
+    // Apply the initial Hessian approximation.
     search_direction *= scaling_factor_;
 
-    // Second loop (forward pass).
+    // --- Second Loop (Forward Pass) ---
     for (int i = 0; i < k; i++) {
-      // beta <- rho_i * y_i^T * r
+      int idx = (mem_count_ < m ? i : ((mem_pos_ + i) % m));
       const scalar_t denom =
-          x_diff_memory_.col(i).dot(grad_diff_memory_.col(i));
+          x_diff_memory_.col(idx).dot(grad_diff_memory_.col(idx));
       if (std::abs(denom) < eps) {
         continue;
       }
       const scalar_t rho = 1.0 / denom;
       const scalar_t beta =
-          rho * grad_diff_memory_.col(i).dot(search_direction);
-      // r <- r + s_i * ( alpha_i - beta)
-      search_direction += x_diff_memory_.col(i) * (alpha(i) - beta);
+          rho * grad_diff_memory_.col(idx).dot(search_direction);
+      search_direction += x_diff_memory_.col(idx) * (alpha(i) - beta);
     }
 
-    // stop with result "H_k*f_f'=q"
-
-    // any issues with the descent direction ?
-    // Check the descent direction for validity.
+    // Check descent direction validity.
     scalar_t descent_direction = -current.gradient.dot(search_direction);
     scalar_t alpha_init =
         (current.gradient.norm() > eps) ? 1.0 / current.gradient.norm() : 1.0;
     if (!std::isfinite(descent_direction) ||
         descent_direction > -eps * relative_eps) {
-      // If the descent direction is invalid or not a descent, revert to
-      // steepest descent.
-      search_direction = -current.gradient.eval();
-      memory_idx_ = 0;
+      // Fall back to steepest descent if necessary.
+      search_direction = -current.gradient;
+      // Reset the correction history if the descent is invalid.
+      mem_count_ = 0;
+      mem_pos_ = 0;
       alpha_init = 1.0;
     }
 
+    // Perform a line search.
     const scalar_t rate = linesearch::MoreThuente<function_t, 1>::Search(
         current.x, -search_direction, function, alpha_init);
 
     const state_t next = function.GetState(current.x - rate * search_direction);
 
+    // Compute the differences for the new correction pair.
     const vector_t x_diff = next.x - current.x;
     const vector_t grad_diff = next.gradient - current.gradient;
 
-    // Update the history
-    if (x_diff.dot(grad_diff) > eps * grad_diff.squaredNorm()) {
-      if (memory_idx_ < m) {
-        x_diff_memory_.col(memory_idx_) = x_diff.eval();
-        grad_diff_memory_.col(memory_idx_) = grad_diff.eval();
+    // --- Curvature Condition Check with Cautious Update ---
+    // We require:
+    //   x_diff.dot(grad_diff) > ||x_diff||^2 * ||current.gradient|| *
+    //   cautious_factor_
+    const scalar_t threshold =
+        x_diff.squaredNorm() * current.gradient.norm() * cautious_factor_;
+    if (x_diff.dot(grad_diff) > threshold) {
+      // Add the new correction pair into the circular buffer.
+      if (mem_count_ < static_cast<size_t>(m)) {
+        // Still have free space.
+        x_diff_memory_.col(mem_count_) = x_diff;
+        grad_diff_memory_.col(mem_count_) = grad_diff;
+        mem_count_++;
       } else {
-        internal::ShiftLeft<m>(&x_diff_memory_);
-        internal::ShiftLeft<m>(&grad_diff_memory_);
-
-        x_diff_memory_.rightCols(1) = x_diff;
-        grad_diff_memory_.rightCols(1) = grad_diff;
+        // Buffer full; overwrite the oldest correction.
+        x_diff_memory_.col(mem_pos_) = x_diff;
+        grad_diff_memory_.col(mem_pos_) = grad_diff;
+        mem_pos_ = (mem_pos_ + 1) % m;
       }
-
-      memory_idx_++;
     }
-    // Adaptive damping in Hessian approximation: update the scaling factor.
-    constexpr scalar_t fallback_value =
-        scalar_t(1e7);  // Fallback value if update is unstable.
+    // Update the scaling factor (adaptive damping).
+    constexpr scalar_t fallback_value = scalar_t(1e7);
     const scalar_t grad_diff_norm_sq = grad_diff.dot(grad_diff);
     if (std::abs(grad_diff_norm_sq) > eps) {
       scalar_t temp_scaling = grad_diff.dot(x_diff) / grad_diff_norm_sq;
-      // If temp_scaling is non-finite or excessively large, use fallback.
       if (!std::isfinite(temp_scaling) ||
           std::abs(temp_scaling) > fallback_value) {
         scaling_factor_ = fallback_value;
@@ -167,10 +173,16 @@ class Lbfgs : public Solver<function_t> {
  private:
   memory_matrix_t x_diff_memory_;
   memory_matrix_t grad_diff_memory_;
-  size_t memory_idx_ = 0;
+  // Circular buffer state:
+  size_t mem_count_ = 0;  // Number of corrections stored so far (max m).
+  size_t mem_pos_ = 0;    // Index of the oldest correction in the buffer.
 
-  memory_vector_t alpha;
+  memory_vector_t
+      alpha;  // Storage for the coefficients in the two-loop recursion.
   scalar_t scaling_factor_ = 1;
+  // Cautious factor to determine whether to accept a new correction pair.
+  // You may want to expose this parameter or adjust its default value.
+  scalar_t cautious_factor_ = 1e-6;
 };
 
 }  // namespace cppoptlib::solver
```
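For illustration, the preconditioning step added in this diff can be reproduced in a few lines of Eigen: when the function is twice differentiable, the gradient entering the two-loop recursion is scaled by the element-wise inverse of |diag(H)| + eps. This is a self-contained sketch with made-up numbers, not library code:

```cpp
#include <Eigen/Dense>

#include <iostream>
#include <limits>

int main() {
  constexpr double eps = std::numeric_limits<double>::epsilon();

  // Illustrative data: a Hessian whose diagonal is badly scaled.
  Eigen::MatrixXd hessian(2, 2);
  hessian << 100.0, 0.0,
             0.0,   0.01;
  Eigen::VectorXd gradient = Eigen::VectorXd::Ones(2);

  // As in the patch: invert |diag(H)| + eps element-wise and scale the
  // gradient, evening out the per-coordinate curvature.
  Eigen::VectorXd precond =
      (hessian.diagonal().cwiseAbs().array() + eps).inverse().matrix();
  Eigen::VectorXd grad_precond = precond.asDiagonal() * gradient;

  std::cout << grad_precond.transpose() << "\n";  // approx. 0.01 100
  return 0;
}
```

With a unit preconditioner this reduces to plain L-BFGS, which is why the patch initializes precond to ones and only fills it in under if constexpr when second-order information exists.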

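The history update is also more conservative than before: instead of the old test x_diff.dot(grad_diff) > eps * grad_diff.squaredNorm(), a pair (s, y) = (x_diff, grad_diff) is accepted only when s.dot(y) > ||s||^2 * ||gradient|| * cautious_factor_, which keeps near-zero and negative-curvature pairs out of the inverse-Hessian approximation. A small sketch of the same test with invented values:

```cpp
#include <Eigen/Dense>

#include <iostream>

// Cautious acceptance rule mirroring the patch: store the pair only when
// s.dot(y) exceeds ||s||^2 * ||g|| * cautious_factor.
bool AcceptCorrectionPair(const Eigen::VectorXd &s, const Eigen::VectorXd &y,
                          const Eigen::VectorXd &g,
                          double cautious_factor = 1e-6) {
  const double threshold = s.squaredNorm() * g.norm() * cautious_factor;
  return s.dot(y) > threshold;
}

int main() {
  Eigen::VectorXd s(2), y(2), g(2);
  s << 1e-3, 0.0;  // step: x_next - x_current
  y << 2e-3, 0.0;  // gradient difference, positive curvature along s
  g << 1.0, 1.0;   // current gradient

  std::cout << std::boolalpha;
  std::cout << AcceptCorrectionPair(s, y, g) << "\n";  // true: pair stored
  y << -2e-3, 0.0;  // negative curvature
  std::cout << AcceptCorrectionPair(s, y, g) << "\n";  // false: pair skipped
  return 0;
}
```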
src/examples/simple.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -18,7 +18,7 @@ using FunctionXd = cppoptlib::function::Function<
     double, Eigen::Dynamic, cppoptlib::function::Differentiability::Second>;
 
 class Function : public FunctionXd {
-public:
+ public:
   EIGEN_MAKE_ALIGNED_OPERATOR_NEW
 
   scalar_t operator()(const vector_t &x, vector_t *gradient = nullptr,
```
