
Added Pseudo-inverse preconditioner for EqQP. #133

Draft · wants to merge 1 commit into main from eq_qp_precond
Conversation

GeoffNN (Contributor) commented Dec 17, 2021

This allows precomputing a preconditioner and sharing it across multiple
outer loops, where the inner loop solves an equality-constrained QP.
This should provide speedups when the parameters of the inner-loop QP
don't change too much.

TODO: modify the implicit diff decorator so that the jvp also uses the same preconditioner, since the backward system shares the same linear operator with the forward.
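The idea described above can be sketched roughly as follows. This is not the PR's actual code, just an illustration under assumptions: a pseudo-inverse of a reference KKT matrix is computed once and reused as the preconditioner M across repeated solves of nearby equality-constrained QPs (min ½xᵀQx + cᵀx s.t. Ax = b); all function names here are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres


def kkt_matrix(Q, A):
  # Dense KKT operator [[Q, A^T], [A, 0]] of the equality-constrained QP.
  m = A.shape[0]
  return jnp.block([[Q, A.T], [A, jnp.zeros((m, m))]])


def make_preconditioner(Q0, A0):
  # Pseudo-inverse of a *reference* KKT matrix, amortized over the outer
  # loop; M should approximate the inverse of the current KKT matrix.
  P = jnp.linalg.pinv(kkt_matrix(Q0, A0))
  return lambda v: P @ v


def solve_eq_qp(Q, c, A, b, M):
  # GMRES rather than CG, since the KKT matrix is indefinite.
  K = kkt_matrix(Q, A)
  rhs = jnp.concatenate([-c, b])
  sol, _ = gmres(lambda v: K @ v, rhs, M=M)
  n = Q.shape[0]
  return sol[:n], sol[n:]  # primal, dual
```

When the inner QP's parameters drift only slightly, the fixed M stays a good approximation of the current KKT inverse and the iterative solver converges in few iterations.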

GeoffNN (Contributor, Author) commented Dec 17, 2021

CC @Algue-Rythme, this is pretty similar to your Jacobi preconditioner for OSQP.

@GeoffNN GeoffNN force-pushed the eq_qp_precond branch 2 times, most recently from ffefbf1 to e118328 Compare December 17, 2021 02:39
Algue-Rythme (Collaborator) left a comment

This PseudoInverse preconditioner is a good idea if you plan on solving a sequence of similar QPs.

I left a few comments.

@@ -77,7 +77,7 @@ def test_qp_eq_only(self):
hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
sol = qp.run(**hyperparams).params
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b)
self._check_derivative_Q_c_A_b(qp, Q, c, A, b)
Algue-Rythme (Collaborator):
Nice catch! Maybe you need self._check_derivative_Q_c_A_b(qp, None, Q, c, A, b) here.

GeoffNN (Contributor, Author):
Removed the second argument from the self._check_derivative_Q_c_A_b method, since it doesn't actually use it.



def row_matvec(block, x):
return sum(jax.tree_util.tree_map(jnp.dot, block, x))
Algue-Rythme (Collaborator):
If my understanding is correct, block here is actually a tuple of blocks, and x a tuple of vectors with the same structure? So technically it is not a row-vector product, since the result is not a scalar as one would expect. Maybe add a docstring and consider renaming the function.

GeoffNN (Contributor, Author):
Done, let me know :)
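The point under discussion can be checked with a toy example: when block is a tuple of matrices and x a matching tuple of vectors, summing the tree_map of jnp.dot yields one block-row of the full product, a vector rather than a scalar. The inputs below are illustrative, not from the PR.

```python
import jax
import jax.numpy as jnp


def block_row_matvec(block, x):
  # Sums Q @ x1 + A^T @ x2 for block = (Q, A^T) and x = (x1, x2):
  # one block-row of the full block-matrix-vector product.
  return sum(jax.tree_util.tree_map(jnp.dot, block, x))


Q = jnp.eye(2) * 2.0
A = jnp.array([[1.0, 1.0]])
x1 = jnp.array([0.5, 0.5])
x2 = jnp.array([-1.0])
# First block-row of [[Q, A^T], [A, 0]] applied to (x1, x2):
out = block_row_matvec((Q, A.T), (x1, x2))  # a length-2 vector, not a scalar
```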

[C, D]]

"""
return jax.tree_util.tree_map(
Algue-Rythme (Collaborator):
Since you are stopping the recursion at depth 1 (i.e. you only retrieve [A, B] and [C, D]), I believe it would be clearer to hardcode it instead of using tree_map, which is a bit overkill here. For example, consider writing upper_block = self.blocks[0].

GeoffNN (Contributor, Author):
True; my hope was to leave a stub that could be extended to block matrices of any size, not just 2x2. I can go either way on this, let me know what you think.

@jax.tree_util.register_pytree_node_class
@dataclass
class BlockLinearOperator:
"""Represents a linear operator defined by blocks over a block pytree.
Algue-Rythme (Collaborator):
What does the pytree refer to? The 2x2 tuple Tuple[Tuple[jnp.array]]? Or is it something more complicated?

GeoffNN (Contributor, Author):

It is -- without registering it as a pytree node, I get this error: type <class 'jaxopt._src.linear_operator.BlockLinearOperator'> is not a valid JAX type.
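The registration requirement mentioned here can be illustrated with a minimal class (not the PR's BlockLinearOperator): without register_pytree_node_class, instances of a custom class cannot be passed through JAX transformations or tree utilities.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class Scale:
  """A toy linear operator; registration makes it a valid JAX type."""
  factor: jnp.ndarray

  def tree_flatten(self):
    # children (traced leaves), aux_data (static metadata).
    return (self.factor,), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)


op = Scale(jnp.array(3.0))
# Works only because Scale is registered as a pytree node:
doubled = jax.tree_util.tree_map(lambda x: 2 * x, op)
```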

@GeoffNN GeoffNN force-pushed the eq_qp_precond branch 3 times, most recently from 64bb1a1 to af04a6d Compare December 17, 2021 20:14
@GeoffNN GeoffNN marked this pull request as draft December 17, 2021 20:36
@GeoffNN GeoffNN force-pushed the eq_qp_precond branch 2 times, most recently from c78943e to 5a3ca42 Compare December 17, 2021 20:48
@@ -57,7 +57,8 @@ def eq_fun(primal_var, params_eq):

# It is required to post_process the output of `idf.make_kkt_optimality_fun`
# to make the signatures of optimality_fun() and run() agree.
def optimality_fun(params, params_obj, params_eq):
# The M argument is needed for using preconditioners.
def optimality_fun(params, params_obj, params_eq, M=None):
Algue-Rythme (Collaborator):
At first sight, I'm rather -1 on introducing a preconditioner M in optimality_fun, since I don't think it plays any role in the optimality conditions (it would not make sense to differentiate with respect to M). Since run and optimality_fun need to have the same signature, this rules out adding M to run as well.

I think I would go for something like this instead:

preconditioner = PseudoInversePreconditioner(params_obj, params_eq)
qp = EqualityConstrainedQP(preconditioner=preconditioner)
qp.run(params_obj=params_obj, params_eq=params_eq)

Typically, stuff that doesn't need to be differentiated should go to the constructor.

If you want to differentiate wrt params_eq or params_obj, you may need to use

EqualityConstrainedQP(preconditioner=lax.stop_gradient(preconditioner))

instead. Not entirely sure if PseudoInversePreconditioner should live in JAXopt or in user land.

GeoffNN (Contributor, Author):

It doesn't play a role in the optimality conditions -- it could play a role in the backward pass though, if we manage to pass the argument to the backward solver as well. This would hopefully speed things up, since the forward and backward linear systems share the same linear operator.

With the current API, if we're solving the QP as the inner problem of a bi-level problem, then we need to build a new QP solver instance at each step of the outer loop, and pass solve=partial(linearsolver, M=preconditioner) to both solve and implicit_diff_solve.

Something that is still unclear to me: does building a new instance of the QP solver necessarily trigger recompilation of the run method at each iteration?
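The workaround described in this comment can be sketched as follows: functools.partial binds a fixed preconditioner into a generic linear solver, and the same bound solver would be handed to both the forward and implicit-diff paths. The linearsolver signature and the jaxopt keyword names are assumptions taken from the comment above, not verified API.

```python
from functools import partial

import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres


def linearsolver(matvec, b, M=None):
  # Generic preconditioned linear solve (illustrative helper).
  x, _ = gmres(matvec, b, M=M)
  return x


K = jnp.array([[2.0, 1.0], [1.0, 3.0]])  # stand-in for the shared KKT operator
P = jnp.linalg.pinv(K)
preconditioned_solve = partial(linearsolver, M=lambda v: P @ v)

# Hypothetical usage, per the comment above:
# qp = EqualityConstrainedQP(solve=preconditioned_solve,
#                            implicit_diff_solve=preconditioned_solve)
x = preconditioned_solve(lambda v: K @ v, jnp.array([1.0, 0.0]))
```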
