Skip to content

Commit

Permalink
Added Pseudo-inverse preconditioner for EqQP.
Browse files Browse the repository at this point in the history
This allows to precompute a preconditioner, and share it across multiple
outer loops, where the inner loop is solving an Equality Constrained QP.
This should provide speedups when the parameters of the inner loop QP
don't change too much.

TODO: modify the implicit diff decorator so that the jvp also uses the preconditioner.
  • Loading branch information
GeoffNN committed Dec 17, 2021
1 parent ced83e9 commit ffefbf1
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 9 deletions.
7 changes: 5 additions & 2 deletions jaxopt/_src/eq_qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class EqualityConstrainedQP(base.Solver):
implicit_diff_solve: Optional[Callable] = None
jit: bool = True

def _refined_solve(self, matvec, b, init, maxiter, tol):
def _refined_solve(self, matvec, b, init, maxiter, tol, **kwargs):
# Instead of solving S x = b
# We solve \bar{S} x = b
#
Expand Down Expand Up @@ -152,13 +152,14 @@ def matvec_regularized_qp(_, x):
maxiter=self.refine_maxiter,
tol=tol,
)
return solver.run(init_params=init, A=None, b=b)[0]
return solver.run(init_params=init, A=None, b=b, **kwargs)[0]

def run(
self,
init_params: Optional[base.KKTSolution] = None,
params_obj: Optional[Any] = None,
params_eq: Optional[Any] = None,
**kwargs,
) -> base.OptStep:
"""Solves 0.5 * x^T Q x + c^T x subject to Ax = b.
Expand All @@ -168,6 +169,7 @@ def run(
init_params: ignored.
params_obj: (Q, c) or (params_Q, c) if matvec_Q is provided.
params_eq: (A, b) or (params_A, b) if matvec_A is provided.
**kwargs: Keyword args provided to the solver.
Returns:
(params, state), where params = (primal_var, dual_var_eq, None)
"""
Expand Down Expand Up @@ -200,6 +202,7 @@ def matvec(u):
init=init_params,
tol=self.tol,
maxiter=self.maxiter,
**kwargs,
)
else:
primal, dual_eq = self._refined_solve(
Expand Down
58 changes: 58 additions & 0 deletions jaxopt/_src/eq_qp_preconditioned.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Preconditioned solvers for equality constrained quadratic programming."""

from typing import Optional, Any
from dataclasses import dataclass
import jax.numpy as jnp
import jaxopt
from jaxopt._src import base
from jaxopt._src import linear_operator


@dataclass
class PseudoInversePreconditionedEqQP(base.Solver):
qp_solver: jaxopt.EqualityConstrainedQP

def init_params(self, params_obj, params_eq):
"""Computes the matvec associated to the pseudo inverse of the KKT matrix."""
Q, p = params_obj
A, b = params_eq
del p, b

kkt_mat = jnp.block([[Q, A.T], [A, jnp.zeros((A.shape[0], A.shape[0]))]])
kkt_mat_pinv = jnp.linalg.pinv(kkt_mat)

d = Q.shape[0]

pinv_blocks = (
(kkt_mat_pinv[:d, :d], kkt_mat_pinv[:d, d:]),
(kkt_mat_pinv[d:, :d], kkt_mat_pinv[d:, d:]),
)
return linear_operator.BlockLinearOperator(pinv_blocks)

def run(
self,
init_params: Optional[base.KKTSolution] = None,
params_obj: Optional[Any] = None,
params_eq: Optional[Any] = None,
params_precond=None,
**kwargs
):
# TODO(gnegiar): the M parameter should be passed to both
# the QP solve and the implicit_diff_solve
return self.qp_solver.run(
init_params, params_obj, params_eq, M=params_precond, **kwargs
)
71 changes: 65 additions & 6 deletions jaxopt/_src/linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
"""Interface for linear operators."""

import functools
from dataclasses import dataclass
from typing import Tuple

import jax
import jax.numpy as jnp
import numpy as onp

from jaxopt.tree_util import tree_map, tree_sum, tree_mul
from jaxopt.tree_util import tree_map


class DenseLinearOperator:

def __init__(self, pytree):
self.pytree = pytree

Expand All @@ -33,7 +34,7 @@ def matvec(self, x):
return tree_map(jnp.dot, self.pytree, x)

def rmatvec(self, _, y):
return tree_map(lambda w,yi: jnp.dot(w.T, yi), self.pytree, y)
return tree_map(lambda w, yi: jnp.dot(w.T, yi), self.pytree, y)

def matvec_and_rmatvec(self, x, y):
return self.matvec(x), self.rmatvec(x, y)
Expand All @@ -52,11 +53,11 @@ def col_norm(w):
if not squared:
col_norms = jnp.sqrt(col_norms)
return col_norms

return tree_map(col_norm, self.pytree)


class FunctionalLinearOperator:

def __init__(self, fun, params):
self.fun = functools.partial(fun, params)

Expand All @@ -71,7 +72,7 @@ def rmatvec(self, x, y):

def matvec_and_rmatvec(self, x, y):
matvec_x, vjp = jax.vjp(self.matvec, x)
rmatvec_y, = vjp(y)
(rmatvec_y,) = vjp(y)
return matvec_x, rmatvec_y

def normal_matvec(self, x):
Expand All @@ -85,3 +86,61 @@ def _make_linear_operator(matvec):
return DenseLinearOperator
else:
return functools.partial(FunctionalLinearOperator, matvec)


def row_matvec(block, u):
return sum(jax.tree_util.tree_map(jnp.dot, block, u))


# TODO(gnegiar): Extend to arbitrary block shapes.
@jax.tree_util.register_pytree_node_class
@dataclass
class BlockLinearOperator:
"""Represents a linear operator defined by blocks over a block pytree.
Attributes:
blocks: a 2x2 block matrix of the form
[[A, B]
[C, D]]
"""

blocks: Tuple[Tuple[jnp.array]]

def __call__(self, x):
return self.matvec(x)

def matvec(self, x):
"""Performs the block matvec with u defined by blocks.
The matvec is of form:
[u1, u2]
[[A, B] *
[C, D]]
"""
return jax.tree_util.tree_map(
lambda block: row_matvec(block, u),
self.blocks,
is_leaf=lambda x: x is self.blocks[0] or x is self.blocks[1],
)

def rmatvec(self, x, y):
return self.matvec_and_rmatvec(x, y)[1]

def matvec_and_rmatvec(self, x, y):
matvec_x, vjp = jax.vjp(self.matvec, x)
(rmatvec_y,) = vjp(y)
return matvec_x, rmatvec_y

def normal_matvec(self, x):
"""Computes A^T A x from matvec(x) = A x."""
matvec_x, vjp = jax.vjp(self.matvec, x)
return vjp(matvec_x)[0]

def tree_flatten(self):
return self.blocks, None

@classmethod
def tree_unflatten(cls, aux_data, children):
del aux_data
return cls(children)
82 changes: 82 additions & 0 deletions tests/eq_qp_preconditioned_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from jax import test_util as jtu
import jax.numpy as jnp

from jaxopt._src.eq_qp_preconditioned import PseudoInversePreconditionedEqQP
from jaxopt import EqualityConstrainedQP
import numpy as onp


class PreconditionedEqualityConstrainedQPTest(jtu.JaxTestCase):
def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b):
def fun(Q, c, A, b):
Q = 0.5 * (Q + Q.T)

hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
# reduce the primal variables to a scalar value for test purpose.
return jnp.sum(solver.run(**hyperparams).params[0])

# Derivative w.r.t. A.
rng = onp.random.RandomState(0)
V = rng.rand(*A.shape)
V /= onp.sqrt(onp.sum(V ** 2))
eps = 1e-4
deriv_jax = jnp.vdot(V, jax.grad(fun, argnums=2)(Q, c, A, b))
deriv_num = (fun(Q, c, A + eps * V, b) - fun(Q, c, A - eps * V, b)) / (2 * eps)
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)

# Derivative w.r.t. b.
v = rng.rand(*b.shape)
v /= onp.sqrt(onp.sum(v ** 2))
eps = 1e-4
deriv_jax = jnp.vdot(v, jax.grad(fun, argnums=3)(Q, c, A, b))
deriv_num = (fun(Q, c, A, b + eps * v) - fun(Q, c, A, b - eps * v)) / (2 * eps)
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)

# Derivative w.r.t. Q
W = rng.rand(*Q.shape)
W /= onp.sqrt(onp.sum(W ** 2))
eps = 1e-4
deriv_jax = jnp.vdot(W, jax.grad(fun, argnums=0)(Q, c, A, b))
deriv_num = (fun(Q + eps * W, c, A, b) - fun(Q - eps * W, c, A, b)) / (2 * eps)
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)

# Derivative w.r.t. c
w = rng.rand(*c.shape)
w /= onp.sqrt(onp.sum(w ** 2))
eps = 1e-4
deriv_jax = jnp.vdot(w, jax.grad(fun, argnums=1)(Q, c, A, b))
deriv_num = (fun(Q, c + eps * w, A, b) - fun(Q, c - eps * w, A, b)) / (2 * eps)
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)

def test_pseudoinverse_preconditioner(self):
Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
qp = EqualityConstrainedQP(tol=1e-7)
preconditioned_qp = PseudoInversePreconditionedEqQP(qp)
params_obj = (Q, c)
params_eq = (A, b)
params_precond = preconditioned_qp.init_params(params_obj, params_eq)
hyperparams = dict(
params_obj=params_obj,
params_eq=params_eq,
)
sol = preconditioned_qp.run(**hyperparams, params_precond=params_precond).params
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
self._check_derivative_Q_c_A_b(qp, Q, c, A, b)
2 changes: 1 addition & 1 deletion tests/eq_qp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_qp_eq_only(self):
hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
sol = qp.run(**hyperparams).params
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b)
self._check_derivative_Q_c_A_b(qp, Q, c, A, b)

def test_qp_eq_with_init(self):
Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
Expand Down

0 comments on commit ffefbf1

Please sign in to comment.