
LevenbergMarquardt and pytrees #587

Open
gbruno16 wants to merge 2 commits into main

Conversation

gbruno16 commented Mar 31, 2024

I propose replacing the JAX NumPy operations in LevenbergMarquardt with the corresponding ones in tree_utils, to address issues #505 and #579. With this change, the snippet in issue #505 appears to run correctly, both with and without geodesic acceleration (using the solver solve_cg).
However, QR, LU, and Cholesky still fail since they require the flattened versions of both the Jacobian and parameters.
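
To illustrate the kind of substitution involved (a minimal sketch, not the actual diff; the helpers `tree_vdot` and `tree_add_scalar_mul` do exist in `jaxopt.tree_util`):

```python
# Minimal sketch: tree_util operations accept arbitrary pytrees,
# whereas the jnp counterparts expect flat arrays.
import jax.numpy as jnp
from jaxopt import tree_util as tu

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}           # parameters as a pytree
update = {"w": jnp.full((3, 2), 0.1), "b": jnp.full(2, 0.1)}  # update with same structure

# jnp.vdot(params, update) or params + alpha * update would require flattening;
# the tree_util versions work directly on the pytrees:
inner = tu.tree_vdot(params, update)                        # <params, update>
new_params = tu.tree_add_scalar_mul(params, -0.5, update)   # params - 0.5 * update
```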

Regarding the computation of the initial value of the damping_factor: using self.damping_parameter * jnp.max(jtj_diag) requires materializing the full identity matrix. Perhaps, for large problems like the one in issue #579, it would be useful to give the user the option to choose an initial damping_factor without computing jtj_diag at all (in the same way as in Marquardt's original paper, https://www.jstor.org/stable/2098941, p. 438)?
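
As a rough sketch of what that option could look like (the name init_damping_factor and the helper below are made up for illustration, not part of the current API):

```python
import jax.numpy as jnp

# Hypothetical helper: skip computing jtj_diag entirely when the user supplies
# an initial damping factor, as in Marquardt's original scheme.
def initial_damping(damping_parameter, jtj_diag=None, init_damping_factor=None):
    if init_damping_factor is not None:
        return init_damping_factor
    # Fallback to the current behaviour, which needs the diagonal of J^T J.
    return damping_parameter * jnp.max(jtj_diag)
```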


vroulet (Collaborator) commented Mar 31, 2024

Hello @gbruno16,
Good to see you on this repo too!
A few comments:

  • We are currently in the process of migrating jaxopt into optax, so it may be worth thinking about creating such an optimizer directly in optax. I'd be happy to help with this, but be aware that it will take some time. In particular, optax works with gradient transformations and does not yet handle solvers the way jaxopt does. But again, I have thought about it, so we could do it together (ping me by mail if you are interested!).
  • If you want to stick with jaxopt's API, I think you can largely revamp this function, or even start from scratch with a simpler implementation (it would be helpful for the migration anyway). By a simpler implementation, I mean one that
    • never materializes the Jacobian and only works with jvps/vjps (note that once you have e.g. the jvp, you can use jax.linear_transpose to get the vjp; see the sketch after this list),
    • uses a CG solver that can work directly with linear operators (so no headache of doing an LU etc., and no materialization of the Jacobian),
    • leaves out geodesic acceleration to start with (so the code is a bit simpler; geodesic acceleration can be added later), and
    • keeps the gain_ratio logic.
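
Something along these lines, perhaps: a minimal sketch of the matrix-free step under the assumptions above (not jaxopt's actual implementation; lm_step is a made-up name, and it relies on jax.vjp/jax.jvp plus JAX's pytree-aware CG):

```python
import jax
from jax.scipy.sparse.linalg import cg
from jaxopt import tree_util as tu

def lm_step(residual_fn, params, damping):
    """One matrix-free LM step: solve (J^T J + mu I) delta = -J^T r with CG."""
    # Residuals and a reverse-mode linear map v -> J^T v; J is never materialized.
    residuals, vjp_fn = jax.vjp(residual_fn, params)

    def jvp_fn(v):
        # Forward-mode linear map v -> J v.
        return jax.jvp(residual_fn, (params,), (v,))[1]

    def matvec(v):
        # v -> (J^T J + mu I) v, acting directly on parameter pytrees.
        jtj_v = vjp_fn(jvp_fn(v))[0]
        return tu.tree_add_scalar_mul(jtj_v, damping, v)

    rhs = tu.tree_scalar_mul(-1.0, vjp_fn(residuals)[0])  # -J^T r
    delta, _ = cg(matvec, rhs)
    return tu.tree_add(params, delta)
```

The gain_ratio logic and the damping update would then wrap around a step like this.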


gbruno16 (Author) commented Apr 1, 2024

This sounds interesting to me! I will send you an email.
