From c447bca936b962bc437ae2e43ca148025cb3ecad Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Thu, 4 Apr 2024 11:26:01 -0400 Subject: [PATCH] fixing transpose bug --- src/tinygp/solvers/quasisep/core.py | 2 +- tests/test_solvers/test_quasisep/test_core.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/tinygp/solvers/quasisep/core.py b/src/tinygp/solvers/quasisep/core.py index 8867761..814cb21 100644 --- a/src/tinygp/solvers/quasisep/core.py +++ b/src/tinygp/solvers/quasisep/core.py @@ -295,7 +295,7 @@ def parallel_matmul(self, x: JAXArray) -> JAXArray: def impl(sm, sn): return (sn[0] @ sm[0], sn[0] @ sm[1] + sn[1]) - states = jax.vmap(lambda u, x: (u.a, jnp.outer(u.p, x)))(self, x) + states = jax.vmap(lambda u, x: (u.a.T, jnp.outer(u.p, x)))(self, x) f = jax.lax.associative_scan(impl, states, reverse=True)[1] f = jnp.concatenate((f[1:], jnp.zeros_like(f[:1])), axis=0) return jax.vmap(jnp.dot)(self.q, f) diff --git a/tests/test_solvers/test_quasisep/test_core.py b/tests/test_solvers/test_quasisep/test_core.py index af56569..bd03396 100644 --- a/tests/test_solvers/test_quasisep/test_core.py +++ b/tests/test_solvers/test_quasisep/test_core.py @@ -6,6 +6,7 @@ import pytest from numpy import random as np_random +from tinygp.kernels.quasisep import Matern52 from tinygp.solvers.quasisep.core import ( DiagQSM, LowerTriQSM, @@ -17,7 +18,7 @@ from tinygp.test_utils import assert_allclose -@pytest.fixture(params=["random", "celerite"]) +@pytest.fixture(params=["random", "celerite", "matern"]) def name(request): return request.param @@ -104,6 +105,17 @@ def get_matrices(name): a = jnp.stack([jnp.diag(v) for v in jnp.exp(-c[None] * dt[:, None])], axis=0) p = jnp.einsum("ni,nij->nj", p, a) + elif name == "matern": + t = jnp.sort(random.uniform(0, 10, N)) + kernel = Matern52(1.5, 1.0) + matrix = kernel.to_symm_qsm(t) + diag = matrix.diag.d + p = matrix.lower.p + q = matrix.lower.q + a = matrix.lower.a + l = matrix.to_dense() + u = l.T + else: raise AssertionError()