Improved Linear Control Ops, with Numba Dispatches #1840
Conversation
Pull request overview
This PR refactors and improves linear control operations in PyTensor by introducing Schur and TRSYL as fundamental building blocks. The refactoring consolidates solve_sylvester, solve_continuous_lyapunov, and solve_discrete_lyapunov to flow through a common implementation using these new Ops, enabling better graph optimization and native backend support for Numba and JAX.
Changes:
- Introduces `Schur` decomposition Op and `TRSYL` (Sylvester equation solver) as new building blocks
- Refactors `solve_sylvester`, `solve_continuous_lyapunov`, and `solve_discrete_lyapunov` to use the new Ops
- Adds Numba and JAX dispatches for the new Ops with comprehensive test coverage
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| pytensor/tensor/slinalg.py | Core implementation of TRSYL, Schur, and refactored linear control solvers with gradient support |
| pytensor/tensor/rewriting/linalg.py | Updated rewriting rules to reference renamed SolveBilinearDiscreteLyapunov Op |
| pytensor/link/numba/dispatch/slinalg.py | Numba dispatches for Schur and TRSYL with proper dtype handling |
| pytensor/link/numba/dispatch/linalg/_LAPACK.py | Low-level LAPACK bindings for gees (Schur) and trsyl |
| pytensor/link/jax/dispatch/slinalg.py | JAX dispatch for Schur decomposition |
| tests/tensor/test_slinalg.py | Comprehensive tests for solve_sylvester, refactored solve_continuous_lyapunov tests, and new Schur tests |
| tests/link/numba/test_slinalg.py | Numba-specific tests for Schur with overwrite and sort parameter validation |
| tests/link/jax/test_slinalg.py | JAX-specific tests for Schur |
ricardoV94 left a comment:
This looks pretty good, comments are pretty minor, besides what the bot found
```python
VS = np.empty((_LDVS, _N), dtype=dtype)
RWORK = np.empty(_N, dtype=w_type)
BWORK = val_to_int_ptr(1)
INFO = val_to_int_ptr(1)
```
Info isn't used?
```python
if op.sort is not None:
    raise NotImplementedError("jax.scipy.linalg.schur only supports sort=None.")
```
Instead of failing, issue a warning that it is ignored?
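A minimal sketch of that warn-and-ignore alternative (same names as the hunk above; only the `warnings` import is added):

```python
import warnings

if op.sort is not None:
    # Warn that the argument has no effect on this backend, but keep going
    warnings.warn(
        "jax.scipy.linalg.schur does not support sort; the argument is ignored.",
        UserWarning,
    )
```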
```python
    INFO,
)

if int_ptr_to_val(INFO) < 0:
```
> 0 never happens or it's fine?
It can be 1, which means there were multiple eigenvalues of the same value, so a solution was found by jittering the matrix. A solution is returned in that case, so we can pass it back. For reference, scipy spams warnings in this case but also gives you back a matrix.
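A sketch of INFO handling along those lines (`int_ptr_to_val` is this PR's helper; the message text is made up, and inside jitted code the warning would have to go through `numba.objmode`):

```python
import warnings

info = int_ptr_to_val(INFO)
if info < 0:
    raise ValueError(f"Argument {-info} to trsyl had an illegal value")
elif info == 1:
    # Common (or near-common) eigenvalues: LAPACK perturbed the problem but
    # still returned a usable solution, so warn (as scipy does) and pass it back
    warnings.warn("trsyl: common eigenvalues; solution computed from perturbed values")
```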
```python
        return np.full_like(C_copy, np.nan)

    # CC now contains the solution, scale it
    X = SCALE * C_copy
```
Reuse the array and then return it?
```diff
- X = SCALE * C_copy
+ C_copy *= SCALE
```
```python
sort = op.sort

if sort is not None:
    raise NotImplementedError(
```
Same: warn, or fall back to obj mode, don't fail
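A sketch of the obj-mode fallback (assuming SciPy is importable at runtime; the dtype strings are placeholders for whatever dtypes the dispatch actually supports):

```python
import numpy as np
import scipy.linalg
from numba import njit, objmode


@njit
def schur_with_sort(A, sort):
    # Let scipy handle the sort keyword in object mode instead of raising;
    # the fast LAPACK path stays reserved for sort=None
    with objmode(T="float64[:, :]", Z="float64[:, :]"):
        T, Z, _ = scipy.linalg.schur(A, output="real", sort=sort)
    return T, Z


T, Z = schur_with_sort(np.eye(3), "lhp")
```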
pytensor/tensor/slinalg.py (outdated)
```python
AxA = vec_kron(A, A.conj())
eye = pt.eye(AxA.shape[-1])

vec_Q = Q.reshape((*Q.shape[:-2], -1))
```
join_dims, cough cough
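For context, the direct solve here relies on the Kronecker identity `vec(A X B) = (Bᵀ ⊗ A) vec(X)` (column-major vec), which turns the discrete Lyapunov equation into one big linear solve. A standalone NumPy check of that reduction (not PR code):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
A = 0.25 * rng.normal(size=(n, n))  # spectral radius < 1, so a solution exists
Q = rng.normal(size=(n, n))
Q = Q @ Q.T  # symmetric RHS, as in the covariance use case


def vec(M):
    return M.flatten(order="F")  # column-major vec, matching the identity


# A X A^H - X = -Q  <=>  (I - conj(A) ⊗ A) vec(X) = vec(Q)
X = np.linalg.solve(np.eye(n * n) - np.kron(A.conj(), A), vec(Q))
X = X.reshape((n, n), order="F")

assert np.allclose(A @ X @ A.conj().T - X, -Q)
```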
```python
    sort: Literal["lhp", "rhp", "iuc", "ouc"] | None = None,
):
    self.output = output
    self.gufunc_signature = "(m,m)->(m,m),(m,m)"
```
if gufunc_signature is constant, define it as an Op attribute, instead of in __init__
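A sketch of that change (class body abbreviated; the `__init__` signature is guessed from the hunk):

```python
from pytensor.graph.op import Op


class Schur(Op):
    # identical for every instance, so define it once at class level
    gufunc_signature = "(m,m)->(m,m),(m,m)"

    def __init__(self, output="real", sort=None):
        self.output = output
        self.sort = sort
```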
```python
    ids=["float", "complex", "batch_float"],
)
def test_solve_continuous_sylvester(shape: tuple[int], use_complex: bool):
    # batch-complex case got an error from BatchedDot not implemented for complex numbers
```
then the rewrite that introduces BatchedDot shouldn't try to introduce it?
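A sketch of such a guard in the rewrite (the rewrite name and tracked Op here are hypothetical; the point is the early bail-out on complex inputs):

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.math import Dot


@node_rewriter([Dot])  # track whatever the real rewrite tracks
def local_batched_dot_guard(fgraph, node):
    # BatchedDot is not implemented for complex numbers, so skip rather
    # than produce a graph that fails later
    if any(inp.dtype.startswith("complex") for inp in node.inputs):
        return None
    ...  # actual specialization goes here
```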
tests/tensor/test_slinalg.py (outdated)
```python
precision = int(dtype[-2:])  # 64 or 32
dtype = f"complex{int(2 * precision)}"
```
```diff
- precision = int(dtype[-2:])  # 64 or 32
- dtype = f"complex{int(2 * precision)}"
+ dtype = "complex128" if dtype == "float64" else "complex64"
```
Hey @jessegrabowski, yes I did. I added the Sylvester solver. I believe you can implement the Lyapunov like so (see the sketch below). I would also be happy to open a PR and add this method directly to Jax.
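A sketch of the reduction being referred to, assuming a JAX `solve_sylvester` with the same contract as `scipy.linalg.solve_sylvester` (i.e., it solves `a @ x + x @ b = q`):

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_sylvester  # assumed available, per the comment above


def solve_continuous_lyapunov(a, q):
    # A X + X A^H = Q is the Sylvester equation with b = A^H
    return solve_sylvester(a, jnp.conj(a).T, q)
```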
If we have
This is really super cool! Please let me know what I can do to help out!!

Drag that monster statespace PR over the finish line :P
pytensor/tensor/slinalg.py (outdated)

```python
class SolveContinuousLyapunov(Op): ...

class TRSYL(Op): ...
```
There's a tensor._linalg module, don't be shy to start populating it instead of these nlinalg / slinalg
good idea!
Description
Another day, another PR with a huge pile of matrix decompositions nobody asked for :)
`solve_discrete_lyapunov` is quite important for statespace, because it's used to initialize a (big) covariance matrix. There are two modes for `solve_discrete_lyapunov`: direct and bilinear. We have bilinear support in the C backend, but Numba and JAX only support direct. Direct involves making a kronecker product of a (big) covariance matrix, so it sucks.

Several linear control Ops: `solve_sylvester`, `solve_discrete_lyapunov`, and `solve_continuous_lyapunov` are all actually the same Op. They involve a Schur decomposition of the inputs, followed by a call to the TRSYL routine. Aside from that, it's just some dots and solves.
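For reference, that pipeline in SciPy terms, as a standalone sketch of the Bartels-Stewart algorithm (this is essentially what `scipy.linalg.solve_sylvester` does internally, not code from this PR):

```python
import numpy as np
from scipy.linalg import get_lapack_funcs, schur


def sylvester_via_schur(a, b, q):
    # Solve A X + X B = Q: Schur-decompose A and B^H, reduce the RHS,
    # solve the triangular system with LAPACK's TRSYL, then rotate back
    r, u = schur(a, output="real")
    s, v = schur(b.conj().T, output="real")
    f = u.conj().T @ q @ v
    trsyl = get_lapack_funcs("trsyl", (r, s, f))
    y, scale, info = trsyl(r, s, f, tranb="C")
    if info < 0:
        raise ValueError(f"Illegal value in argument {-info} of trsyl")
    return u @ (scale * y) @ v.conj().T
```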
This PR adds the `Schur` and `TRSYL` Ops, then uses these two components to implement all of the linear control ops (except for `solve_discrete_are`, that's another kettle of fish). These two new Ops don't have gradients, but the solvers that use them do, and `lop_override` is used to handle this. TRSYL is also not added to `__all__`, as it is not intended to be user-facing. `Schur` is a well-known matrix decomposition, so I exposed that one.

This has several advantages:

- Gradients only need to be defined once, on `solve_sylvester`, so I removed all of the specialized Ops/gradients for the others

As it turns out, we can implement Schur in Numba and JAX, and TRSYL in Numba. This PR also does that.
I think @Dekermanjian added some code to JAX that would let us dispatch these ops to jax as well, but I don't know if it's merged/released. I would be willing to do that in this PR if it's out.
Related Issue
Checklist
Type of change