
Commit daa5571

Reduce amount of work in State

1 parent a7af6ed commit daa5571

File tree

16 files changed: +182 −205 lines

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -58,3 +58,4 @@ jobs:
       - name: Build
         run: |
           bazel run simple
+          bazel run constrained_simple
```

README.md

Lines changed: 8 additions & 5 deletions

```diff
@@ -103,18 +103,21 @@ int main(int argc, char const *argv[]) {
   std::cout << "Initial point: " << x.transpose() << std::endl;
 
   // Evaluate
-  auto state = f.GetState(x);
-  std::cout << "Function value at initial point: " << f(x) << std::endl;
-  std::cout << "Gradient at initial point: " << state.gradient << std::endl;
+  Eigen::VectorXd gradient(2);
+  double value = f(x, &gradient);
+  std::cout << "Function value at initial point: " << value << std::endl;
+  std::cout << "Gradient at initial point: " << gradient << std::endl;
 
   // Minimize the Rosenbrock function using the BFGS solver.
   using Solver = cppoptlib::solver::Bfgs<Rosenbrock>;
+  auto init_state = f.GetState(x);
   Solver solver;
-  auto [solution_state, solver_progress] = solver.Minimize(f, x);
+  auto [solution_state, solver_progress] = solver.Minimize(f, init_state);
 
   // Display the results of the optimization.
   std::cout << "Optimal solution: " << solution_state.x.transpose() << std::endl;
-  std::cout << "Optimal function value: " << solution_state.value << std::endl;
+  std::cout << "Optimal function value: " << f(solution_state.x) << std::endl;
+
   std::cout << "Number of iterations: " << solver_progress.num_iterations << std::endl;
   std::cout << "Solver status: " << solver_progress.status << std::endl;
```
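Taken together, the README hunk shows the new calling convention: values and gradients come from `operator()`, while `GetState` merely wraps the point. A minimal sketch of the resulting flow, assuming a `Rosenbrock` function type as in the README (its class body is not part of this diff):

```cpp
// Sketch only: assumes a Rosenbrock type modelled on the README example.
Rosenbrock f;
Eigen::VectorXd x(2);
x << -1.2, 1.0;

Eigen::VectorXd gradient(2);
double value = f(x, &gradient);  // one call yields both value and gradient

cppoptlib::solver::Bfgs<Rosenbrock> solver;
auto init_state = f.GetState(x);  // the state now carries only x
auto [solution_state, solver_progress] = solver.Minimize(f, init_state);

// The value is no longer cached on the state; recompute it at the solution.
double optimal = f(solution_state.x);
```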

include/cppoptlib/constrained_function.h

Lines changed: 3 additions & 39 deletions

```diff
@@ -23,10 +23,7 @@ struct ConstrainedState : public State<base_t> {
   using unconstrained_base_t = FunctionBase<typename base_t::scalar_t,
                                             base_t::Dim, base_t::DiffLevel, 0>;
 
-  typename base_t::scalar_t value = 0;
   typename base_t::vector_t x;
-  typename base_t::vector_t gradient;
-
   std::array<typename base_t::scalar_t, base_t::NumConstraints>
       lagrange_multipliers;
   std::array<typename base_t::scalar_t, base_t::NumConstraints> violations;
@@ -46,19 +43,15 @@ struct ConstrainedState : public State<base_t> {
   }
 
   void CopyState(const self_t &rhs) {
-    value = rhs.value;
     x = rhs.x.eval();
-    gradient = rhs.gradient.eval();
     penalty = rhs.penalty;
     lagrange_multipliers = rhs.lagrange_multipliers;
     violations = rhs.violations;
   }
 
   State<unconstrained_base_t> AsUnconstrained() const {
     State<unconstrained_base_t> state;
-    state.value = value;
     state.x = x.eval();
-    state.gradient = gradient.eval();
     return state;
   }
 };
@@ -90,18 +83,7 @@ class UnconstrainedFunctionAdapter
     const typename cfunction_t::state_t inner = constrained_function.GetState(
         x, constrained_state.lagrange_multipliers, constrained_state.penalty);
     typename cfunction_t::unconstrained_function_t::state_t unconstrained_state;
-    unconstrained_state.value = inner.value;
     unconstrained_state.x = inner.x;
-    if constexpr ((cfunction_t::unconstrained_function_t::Differentiability ==
-                   Differentiability::First) ||
-                  (cfunction_t::unconstrained_function_t::Differentiability ==
-                   Differentiability::Second)) {
-      unconstrained_state.gradient = inner.gradient;
-    }
-    if constexpr (cfunction_t::unconstrained_function_t::Differentiability ==
-                  Differentiability::Second) {
-      unconstrained_state.hessian = inner.hessian;
-    }
     return unconstrained_state;
   }
 
@@ -185,30 +167,12 @@ struct ConstrainedFunction {
 
     state_t constrained_state;
     constrained_state.x = objective_state.x;
-    constrained_state.value = objective_state.value;
-    constrained_state.gradient = objective_state.gradient;
+    constrained_state.penalty = penalty;
 
-    // Sum augmented penalties for hard constraints.
     for (std::size_t i = 0; i < TNumConstraints; ++i) {
-      const typename function_t::state_t constraint_state =
-          constraints_[i]->GetState(x);
-      const scalar_t cost = constraint_state.value;
-      const scalar_t violation = cost;
-
-      const scalar_t lambda = lagrange_multipliers[i];
-      const scalar_t aug_cost =
-          violation + lambda * violation +
-          static_cast<scalar_t>(0.5) * penalty * violation * violation;
-      constrained_state.value += aug_cost;
-      // Augmented gradient (only active if the constraint is violated).
-      const scalar_t a = scalar_t(1) + lambda + penalty * violation;
-      const typename base_t::vector_t scaled_local_grad =
-          a * constraint_state.gradient;
-      typename base_t::vector_t aug_grad =
-          (cost > scalar_t(0)) ? scaled_local_grad
-                               : base_t::vector_t::Zero(x.size());
-      constrained_state.gradient += aug_grad;
+      const scalar_t violation = constraints_[i]->operator()(x);
       constrained_state.violations[i] = violation;
+      constrained_state.lagrange_multipliers[i] = lagrange_multipliers[i];
    }
 
     return constrained_state;
```
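For reference, the deleted loop folded each constraint's augmented-Lagrangian contribution directly into the state. With violation $c_i(x)$, multiplier $\lambda_i$, and penalty $\rho$, the removed code added, per constraint,

```math
c_i(x) + \lambda_i\, c_i(x) + \tfrac{\rho}{2}\, c_i(x)^2
```

to the value, plus a matching gradient term that was active only when $c_i(x) > 0$. After this commit the state records only $c_i(x)$, $\lambda_i$, and $\rho$, so this sum presumably has to be assembled wherever value and gradient are actually evaluated, which is not visible in this excerpt.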

include/cppoptlib/function.h

Lines changed: 4 additions & 62 deletions

```diff
@@ -45,75 +45,20 @@ class Function : public FunctionBase<TScalar, TDim, TDifferentiability> {
   using state_t = State<Function<TScalar, TDim, TDifferentiability>>;
 };
 
-template <class TScalar, int TDim>
-struct State<FunctionBase<TScalar, TDim, Differentiability::None>> {
-  using base_t = FunctionBase<TScalar, TDim, Differentiability::None>;
-  using state_t = State<base_t>;
-
-  typename base_t::scalar_t value = 0;
-  typename base_t::vector_t x;
-
-  State() = default;
-
-  State(const state_t &rhs) : value(rhs.value), x(rhs.x.eval()) {}
-
-  state_t &operator=(const state_t &rhs) {
-    if (this != &rhs) {
-      value = rhs.value;
-      x = rhs.x.eval();
-    }
-    return *this;
-  }
-};
-
-template <class TScalar, int TDim>
-struct State<FunctionBase<TScalar, TDim, Differentiability::First>> {
-  using base_t = FunctionBase<TScalar, TDim, Differentiability::First>;
-  using state_t = State<base_t>;
-
-  typename base_t::scalar_t value = 0;
-  typename base_t::vector_t x;
-  typename base_t::vector_t gradient;
-
-  State() = default;
-
-  State(const state_t &rhs)
-      : value(rhs.value), x(rhs.x.eval()), gradient(rhs.gradient.eval()) {}
-
-  state_t &operator=(const state_t &rhs) {
-    if (this != &rhs) {
-      value = rhs.value;
-      x = rhs.x.eval();
-      gradient = rhs.gradient.eval();
-    }
-    return *this;
-  }
-};
-
-template <class TScalar, int TDim>
-struct State<FunctionBase<TScalar, TDim, Differentiability::Second>> {
-  using base_t = FunctionBase<TScalar, TDim, Differentiability::Second>;
+template <class TScalar, int TDim, Differentiability TDifferentiability>
+struct State<FunctionBase<TScalar, TDim, TDifferentiability>> {
+  using base_t = FunctionBase<TScalar, TDim, TDifferentiability>;
   using state_t = State<base_t>;
 
-  typename base_t::scalar_t value = 0;
   typename base_t::vector_t x;
-  typename base_t::vector_t gradient;
-  typename base_t::matrix_t hessian;
 
   State() = default;
 
-  State(const state_t &rhs)
-      : value(rhs.value),
-        x(rhs.x.eval()),
-        gradient(rhs.gradient.eval()),
-        hessian(rhs.hessian.eval()) {}
+  State(const state_t &rhs) : x(rhs.x.eval()) {}
 
   state_t &operator=(const state_t &rhs) {
     if (this != &rhs) {
-      value = rhs.value;
       x = rhs.x.eval();
-      gradient = rhs.gradient.eval();
-      hessian = rhs.hessian.eval();
     }
     return *this;
   }
@@ -141,7 +86,6 @@ class Function<TScalar, TDim, Differentiability::None>
   state_t GetState(const typename base_t::vector_t &x) const {
     state_t state;
     state.x = x;
-    state.value = this->operator()(x);
     return state;
   }
 };
@@ -168,7 +112,6 @@ class Function<TScalar, TDim, Differentiability::First>
   state_t GetState(const typename base_t::vector_t &x) const {
     state_t state;
     state.x = x;
-    state.value = this->operator()(x, &state.gradient);
     return state;
   }
 };
@@ -196,7 +139,6 @@ class Function<TScalar, TDim, Differentiability::Second>
   state_t GetState(const typename base_t::vector_t &x) const {
     state_t state;
     state.x = x;
-    state.value = this->operator()(x, &state.gradient, &state.hessian);
     return state;
   }
 };
```
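With the three `State` specializations collapsed into one `x`-only template, value, gradient, and Hessian now flow exclusively through `operator()`. A minimal sketch of a first-order function under this convention; the class name is hypothetical, and the namespace spelling and optional-pointer signature are assumptions read off the deleted `this->operator()(x, &state.gradient)` call above, not documented API:

```cpp
#include <Eigen/Core>

// Hypothetical 2-D sphere function; base-class and typedef names are assumed
// from the signatures visible in this diff.
class Sphere : public cppoptlib::function::Function<
                   double, 2, cppoptlib::function::Differentiability::First> {
 public:
  // Value is always returned; the gradient is filled only when requested.
  double operator()(const vector_t &x, vector_t *gradient = nullptr) const {
    if (gradient != nullptr) {
      *gradient = 2 * x;  // gradient of x^T x
    }
    return x.squaredNorm();
  }
};
```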

include/cppoptlib/solver/augmented_lagrangian.h

Lines changed: 2 additions & 1 deletion

```diff
@@ -34,7 +34,8 @@ class AugmentedLagrangian : public Solver<function_t> {
   AugmentedLagrangian(const solver_t &inner_solver)
       : inner_solver_(inner_solver) {}
 
-  void InitializeSolver(const state_t & /*initial_state*/) override {}
+  void InitializeSolver(const function_t & /*function*/,
+                        const state_t & /*initial_state*/) override {}
 
   state_t OptimizationStep(const function_t &function, const state_t &state,
                            const progress_t & /*progress*/) override {
```
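This signature change, mirrored in the solvers below, implies the base `Solver::InitializeSolver` hook now receives the function alongside the initial state. A sketch of the presumed base-class declaration, which is not part of this excerpt:

```cpp
// Assumed shape of the updated hook in solver.h (not shown in this diff);
// whether it is pure virtual or has a default body is not visible here.
virtual void InitializeSolver(const function_t &function,
                              const state_t &initial_state);
```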

include/cppoptlib/solver/bfgs.h

Lines changed: 11 additions & 5 deletions

```diff
@@ -32,32 +32,38 @@ class Bfgs : public Solver<function_t> {
 
   using Superclass::Superclass;
 
-  void InitializeSolver(const state_t &initial_state) override {
+  void InitializeSolver(const function_t & /*function*/,
+                        const state_t &initial_state) override {
     dim_ = initial_state.x.rows();
     inverse_hessian_ =
         matrix_t::Identity(initial_state.x.rows(), initial_state.x.rows());
   }
 
   state_t OptimizationStep(const function_t &function, const state_t &current,
                            const progress_t & /*progress*/) override {
-    vector_t search_direction = -inverse_hessian_ * current.gradient;
+    vector_t current_gradient;
+    function(current.x, &current_gradient);
+
+    vector_t search_direction = -inverse_hessian_ * current_gradient;
 
     // If not positive definit re-initialize Hessian.
-    const scalar_t phi = current.gradient.dot(search_direction);
+    const scalar_t phi = current_gradient.dot(search_direction);
     if ((phi > 0) || std::isnan(phi)) {
       // no, we reset the hessian approximation
       inverse_hessian_ = matrix_t::Identity(dim_, dim_);
-      search_direction = -current.gradient;
+      search_direction = -current_gradient;
    }
 
     const scalar_t rate = linesearch::MoreThuente<function_t, 1>::Search(
         current.x, search_direction, function);
 
     const state_t next = function.GetState(current.x + rate * search_direction);
+    vector_t next_gradient;
+    function(next.x, &next_gradient);
 
     // Update inverse Hessian estimate.
     const vector_t s = rate * search_direction;
-    const vector_t y = next.gradient - current.gradient;
+    const vector_t y = next_gradient - current_gradient;
     const scalar_t rho = 1.0 / y.dot(s);
 
     inverse_hessian_ =
```
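With gradients gone from `State`, each BFGS step now evaluates the gradient twice, at `current.x` and at `next.x`. The quantities `s`, `y`, and `rho` defined above feed the standard inverse-Hessian update, which the tail of this hunk (cut off here) presumably completes:

```math
H_{k+1} = \left(I - \rho\, s\, y^\top\right) H_k \left(I - \rho\, y\, s^\top\right) + \rho\, s\, s^\top,
\qquad \rho = \frac{1}{y^\top s}
```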

include/cppoptlib/solver/conjugated_gradient_descent.h

Lines changed: 9 additions & 6 deletions

```diff
@@ -32,20 +32,23 @@ class ConjugatedGradientDescent : public Solver<function_t> {
 
   using Superclass::Superclass;
 
-  void InitializeSolver(const state_t &initial_state) override {
-    previous_gradient_ = initial_state.gradient;
+  void InitializeSolver(const function_t &function,
+                        const state_t &initial_state) override {
+    function(initial_state.x, &previous_gradient_);
   }
 
   state_t OptimizationStep(const function_t &function, const state_t &current,
                            const progress_t &progress) override {
+    vector_t current_gradient;
+    function(current.x, &current_gradient);
     if (progress.num_iterations == 0) {
-      search_direction_ = -current.gradient;
+      search_direction_ = -current_gradient;
     } else {
-      const double beta = current.gradient.dot(current.gradient) /
+      const double beta = current_gradient.dot(current_gradient) /
                           (previous_gradient_.dot(previous_gradient_));
-      search_direction_ = -current.gradient + beta * search_direction_;
+      search_direction_ = -current_gradient + beta * search_direction_;
     }
-    previous_gradient_ = current.gradient;
+    previous_gradient_ = current_gradient;
 
     const scalar_t rate = linesearch::Armijo<function_t, 1>::Search(
         current.x, search_direction_, function);
```
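The `beta` computed here is the Fletcher–Reeves coefficient, now formed from freshly evaluated gradients instead of cached state fields:

```math
\beta_k^{\mathrm{FR}} = \frac{\nabla f_k^\top \nabla f_k}{\nabla f_{k-1}^\top \nabla f_{k-1}},
\qquad d_k = -\nabla f_k + \beta_k^{\mathrm{FR}}\, d_{k-1}
```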

include/cppoptlib/solver/gradient_descent.h

Lines changed: 6 additions & 3 deletions

```diff
@@ -31,14 +31,17 @@ class GradientDescent : public Solver<function_t> {
 
   using Superclass::Superclass;
 
-  void InitializeSolver(const state_t & /*initial_state*/) override {}
+  void InitializeSolver(const function_t & /*function*/,
+                        const state_t & /*initial_state*/) override {}
 
   state_t OptimizationStep(const function_t &function, const state_t &current,
                            const progress_t & /*progress*/) override {
+    vector_t gradient;
+    function(current.x, &gradient);
     const scalar_t rate = linesearch::MoreThuente<function_t, 1>::Search(
-        current.x, -current.gradient, function);
+        current.x, -gradient, function);
 
-    return function.GetState(current.x - rate * current.gradient);
+    return function.GetState(current.x - rate * gradient);
   }
 };
 
```
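The step itself is plain steepest descent with a More–Thuente line search, where `rate` plays the role of the step size:

```math
x_{k+1} = x_k - \alpha_k\, \nabla f(x_k)
```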
