Skip to content

Commit

Permalink
fixing transpose bug
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Apr 4, 2024
1 parent 56503a4 commit c447bca
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/tinygp/solvers/quasisep/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def parallel_matmul(self, x: JAXArray) -> JAXArray:
def impl(sm, sn):
return (sn[0] @ sm[0], sn[0] @ sm[1] + sn[1])

states = jax.vmap(lambda u, x: (u.a, jnp.outer(u.p, x)))(self, x)
states = jax.vmap(lambda u, x: (u.a.T, jnp.outer(u.p, x)))(self, x)
f = jax.lax.associative_scan(impl, states, reverse=True)[1]
f = jnp.concatenate((f[1:], jnp.zeros_like(f[:1])), axis=0)
return jax.vmap(jnp.dot)(self.q, f)
Expand Down
14 changes: 13 additions & 1 deletion tests/test_solvers/test_quasisep/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from numpy import random as np_random

from tinygp.kernels.quasisep import Matern52
from tinygp.solvers.quasisep.core import (
DiagQSM,
LowerTriQSM,
Expand All @@ -17,7 +18,7 @@
from tinygp.test_utils import assert_allclose


@pytest.fixture(params=["random", "celerite"])
@pytest.fixture(params=["random", "celerite", "matern"])
def name(request):
return request.param

Expand Down Expand Up @@ -104,6 +105,17 @@ def get_matrices(name):
a = jnp.stack([jnp.diag(v) for v in jnp.exp(-c[None] * dt[:, None])], axis=0)
p = jnp.einsum("ni,nij->nj", p, a)

elif name == "matern":
t = jnp.sort(random.uniform(0, 10, N))
kernel = Matern52(1.5, 1.0)
matrix = kernel.to_symm_qsm(t)
diag = matrix.diag.d
p = matrix.lower.p
q = matrix.lower.q
a = matrix.lower.a
l = matrix.to_dense()
u = l.T

else:
raise AssertionError()

Expand Down

0 comments on commit c447bca

Please sign in to comment.