Conversation

@jessegrabowski (Member) commented Jan 12, 2026

Description

Another day, another PR with a huge pile of matrix decompositions nobody asked for :)

solve_discrete_lyapunov is quite important for statespace, because it's used to initialize a (big) covariance matrix. There are two modes for solve_discrete_lyapunov: direct and bilinear. We have bilinear support in the C backend, but Numba and JAX only support direct. Direct involves forming a Kronecker product of the (big) transition matrix with itself, yielding an n² × n² linear system, so it sucks.
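
For context, the direct mode is essentially the following vectorization (a numpy sketch mirroring what scipy does, not this PR's code):

```python
import numpy as np

def direct_discrete_lyapunov(A, Q):
    # Solve X - A X A^H = Q via (I - kron(A, conj(A))) vec(X) = vec(Q),
    # a dense (n^2 x n^2) linear system -- hence "it sucks" for big n.
    n = A.shape[0]
    lhs = np.eye(n * n) - np.kron(A, A.conj())
    x = np.linalg.solve(lhs, Q.flatten())
    return x.reshape(Q.shape)
```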

Several linear control Ops (solve_sylvester, solve_discrete_lyapunov, and solve_continuous_lyapunov) are actually all the same Op: a Schur decomposition of the inputs, followed by a call to the LAPACK TRSYL routine. Aside from that, it's just some dots and solves.
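
For reference, this is roughly the reduction scipy.linalg.solve_sylvester itself performs (a minimal real-matrix sketch using scipy directly, not the new Ops):

```python
import numpy as np
from scipy.linalg import schur
from scipy.linalg.lapack import dtrsyl

def sylvester_via_schur(A, B, C):
    # Solve A X + X B = C for real matrices by reducing A and B to
    # quasi-triangular real Schur form.
    R, U = schur(A, output="real")    # A = U @ R @ U.T
    S, V = schur(B, output="real")    # B = V @ S @ V.T
    F = U.T @ C @ V                   # rotate the right-hand side
    # LAPACK TRSYL solves R @ Y + Y @ S = scale * F for quasi-triangular R, S;
    # scale is an overflow guard and is almost always exactly 1.0
    Y, scale, info = dtrsyl(R, S, F)
    return U @ (scale * Y) @ V.T      # rotate the solution back
```

All three solvers share this skeleton; they differ only in how A, B, and C are prepared.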

This PR adds the Schur and TRSYL Ops, then uses these two components to implement all of the linear control ops (except solve_discrete_are; that's another kettle of fish). These two new Ops don't have gradients, but the solvers built from them do, with lop_override used to handle this. TRSYL is also not added to __all__, since it isn't intended to be user-facing. Schur is a well-known matrix decomposition, so I exposed that one.

This has several advantages:

  • By exposing the inner graphs (previously hidden inside scipy.linalg calls) we can apply linalg rewrites to them
  • Less code to maintain -- all of the linear control Ops now flow through solve_sylvester, so I removed all of the specialized Ops/gradients for the others
  • Free backend dispatching -- as long as we can implement Schur and TRSYL in a backend, we get all the rest of the stuff "for free".

As it turns out, we can implement Schur in Numba and JAX, and TRSYL in Numba. This PR also does that.

I think @Dekermanjian added some code to JAX that would let us dispatch these ops to jax as well, but I don't know if it's merged/released. I would be willing to do that in this PR if it's out.

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

Copilot AI left a comment


Pull request overview

This PR refactors and improves linear control operations in PyTensor by introducing Schur and TRSYL as fundamental building blocks. The refactoring consolidates solve_sylvester, solve_continuous_lyapunov, and solve_discrete_lyapunov to flow through a common implementation using these new Ops, enabling better graph optimization and native backend support for Numba and JAX.

Changes:

  • Introduces Schur decomposition Op and TRSYL (Sylvester equation solver) as new building blocks
  • Refactors solve_sylvester, solve_continuous_lyapunov, and solve_discrete_lyapunov to use the new Ops
  • Adds Numba and JAX dispatches for the new Ops with comprehensive test coverage

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| pytensor/tensor/slinalg.py | Core implementation of TRSYL, Schur, and refactored linear control solvers with gradient support |
| pytensor/tensor/rewriting/linalg.py | Updated rewriting rules to reference the renamed SolveBilinearDiscreteLyapunov Op |
| pytensor/link/numba/dispatch/slinalg.py | Numba dispatches for Schur and TRSYL with proper dtype handling |
| pytensor/link/numba/dispatch/linalg/_LAPACK.py | Low-level LAPACK bindings for gees (Schur) and trsyl |
| pytensor/link/jax/dispatch/slinalg.py | JAX dispatch for Schur decomposition |
| tests/tensor/test_slinalg.py | Comprehensive tests for solve_sylvester, refactored solve_continuous_lyapunov tests, and new Schur tests |
| tests/link/numba/test_slinalg.py | Numba-specific tests for Schur with overwrite and sort parameter validation |
| tests/link/jax/test_slinalg.py | JAX-specific tests for Schur |

@ricardoV94 (Member) left a comment


This looks pretty good; comments are pretty minor, besides what the bot found.

```python
VS = np.empty((_LDVS, _N), dtype=dtype)
RWORK = np.empty(_N, dtype=w_type)
BWORK = val_to_int_ptr(1)
INFO = val_to_int_ptr(1)
```

Info isn't used?

Comment on lines 193 to 194
```python
if op.sort is not None:
    raise NotImplementedError("jax.scipy.linalg.schur only supports sort=None.")
```

Instead of failing, issue a warning that it is ignored?

```python
    INFO,
)

if int_ptr_to_val(INFO) < 0:
```

INFO > 0 never happens, or it's fine?

@jessegrabowski (Member, Author) replied:

It can be 1, which means A and B had common (or very close) eigenvalues, so a solution was found by perturbing ("jittering") the matrix. A solution is returned in that case, so we can pass it back. For reference, scipy spams warnings in this case but also gives you back a matrix.
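
For illustration, the handling being described could look like this (a hypothetical sketch based on LAPACK's documented trsyl semantics; info here is the integer read back from the INFO pointer, not the PR's exact code):

```python
import warnings

if info < 0:
    # a bug on our side: argument number -info had an illegal value
    raise ValueError(f"Illegal value in argument {-info} of trsyl")
elif info == 1:
    # A and B have common or very close eigenvalues; LAPACK solved a
    # perturbed problem, so the returned solution is still usable
    warnings.warn(
        "trsyl: A and B have common or close eigenvalues; "
        "the solution was computed using perturbed values."
    )
```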

```python
return np.full_like(C_copy, np.nan)

# CC now contains the solution, scale it
X = SCALE * C_copy
```

Reuse the array and then return it?

Suggested change
```diff
- X = SCALE * C_copy
+ C_copy *= SCALE
```

```python
sort = op.sort

if sort is not None:
    raise NotImplementedError(
```

Same here: warn, or fall back to object mode; don't fail.
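
e.g. something like this for the warning option (a sketch using the PR's op.sort attribute; the warning text is illustrative, not the PR's code):

```python
import warnings

if op.sort is not None:
    warnings.warn(
        "sort is not implemented in the numba backend and will be ignored."
    )
```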

```python
AxA = vec_kron(A, A.conj())
eye = pt.eye(AxA.shape[-1])

vec_Q = Q.reshape((*Q.shape[:-2], -1))
```

join_dims *cough cough*

```python
    sort: Literal["lhp", "rhp", "iuc", "ouc"] | None = None,
):
    self.output = output
    self.gufunc_signature = "(m,m)->(m,m),(m,m)"
```

If gufunc_signature is constant, define it as a class-level Op attribute instead of in __init__.

```python
    ids=["float", "complex", "batch_float"],
)
def test_solve_continuous_sylvester(shape: tuple[int], use_complex: bool):
    # batch-complex case got an error from BatchedDot not implemented for complex numbers
```

Then the rewrite that introduces BatchedDot shouldn't try to introduce it for complex inputs?

Comment on lines 931 to 932

```python
precision = int(dtype[-2:])  # 64 or 32
dtype = f"complex{int(2 * precision)}"
```

Suggested change
```diff
- precision = int(dtype[-2:])  # 64 or 32
- dtype = f"complex{int(2 * precision)}"
+ dtype = "complex128" if dtype == "float64" else "complex64"
```

@Dekermanjian commented Jan 12, 2026

> I think @Dekermanjian added some code to JAX that would let us dispatch these ops to jax as well, but I don't know if it's merged/released. I would be willing to do that in this PR if it's out.

Hey @jessegrabowski, yes I did. I added the Sylvester solver. I believe you can implement the Lyapunov like so:

```python
import jax.numpy as jnp
import jax.scipy as jsp

def solve_discrete_lyapunov(A, C, method="bilinear"):
    if method == "bilinear":
        # Cayley-transform A X A^H - X = -C into a continuous Lyapunov
        # equation, then solve that as a Sylvester equation
        eye = jnp.eye(A.shape[0])
        aH = A.conj().transpose()
        aHI_inv = jnp.linalg.inv(aH + eye)
        b = jnp.dot(aH - eye, aHI_inv)
        c = 2 * jnp.dot(jnp.dot(jnp.linalg.inv(A + eye), C), aHI_inv)
        # b^H X + X b = -c; the second argument is b, since (b^H)^H == b
        return jsp.linalg.solve_sylvester(b.conj().transpose(), b, -c)
    if method == "direct":
        # vectorize: (I - kron(A, conj(A))) vec(X) = vec(C)
        lhs = jnp.kron(A, A.conj())
        lhs = jnp.eye(lhs.shape[0]) - lhs
        x = jsp.linalg.solve(lhs, C.flatten())
        return jnp.reshape(x, C.shape)
```

I would also be happy to open a PR and add this method directly to Jax.

@jessegrabowski (Member, Author) commented:

> I would also be happy to open a PR and add this method directly to Jax.

If we have jsp.linalg.solve_sylvester I will just dispatch to that; then we don't need a lyapunov dispatch at all. This PR removes the lyapunov Op and essentially builds the code you posted as a pytensor graph (except using solve instead of inv), so we would be able to dispatch that whole graph, potentially with optimizations once we can infer matrix structures.
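
Concretely, the graph looks roughly like this (a sketch using pytensor's existing solve and solve_sylvester, with solves in place of explicit inverses; not the PR's exact code):

```python
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve, solve_sylvester

def bilinear_lyapunov_graph(A, Q):
    # Solve A X A^H - X + Q = 0 via the Cayley transform, as above,
    # but with solves instead of explicit matrix inverses.
    eye = pt.eye(A.shape[0])
    aH = A.conj().T
    # B = (A^H - I) @ inv(A^H + I), written as a right-division:
    # X @ M = N  <=>  X = solve(M.T, N.T).T
    B = solve((aH + eye).T, (aH - eye).T).T
    # C = 2 * inv(A + I) @ Q @ inv(A^H + I)
    C = 2 * solve((aH + eye).T, solve(A + eye, Q).T).T
    # continuous Lyapunov B^H X + X B = -C, written as a Sylvester equation
    return solve_sylvester(B.conj().T, B, -C)
```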

@Dekermanjian commented:

> > I would also be happy to open a PR and add this method directly to Jax.
>
> If we have jsp.linalg.solve_sylvester I will just dispatch to that; then we don't need a lyapunov dispatch at all. This PR removes the lyapunov Op and essentially builds the code you posted as a pytensor graph (except using solve instead of inv), so we would be able to dispatch that whole graph, potentially with optimizations once we can infer matrix structures.

This is really super cool! Please let me know what I can do to help out!!

@jessegrabowski (Member, Author) commented:

Drag that monster statespace PR over the finish line :P



```diff
-class SolveContinuousLyapunov(Op):
+class TRSYL(Op):
```
@ricardoV94 (Member) commented Jan 13, 2026:


There's a tensor._linalg module, don't be shy to start populating it instead of these nlinalg / slinalg

@jessegrabowski (Member, Author) replied:

good idea!
