Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 6, 2023
1 parent 00f556b commit b878e45
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 37 deletions.
34 changes: 11 additions & 23 deletions src/tinygp/kernels/quasisep.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def init(
alpha: JAXArray,
beta: JAXArray,
eta: JAXArray | None = 1e-30,
) -> "CARMA":
) -> CARMA:
r"""Construct a CARMA kernel using the alpha, beta parameters
Args:
Expand Down Expand Up @@ -724,9 +724,7 @@ def init(
h1 = (c * h2 - jnp.sqrt(a * d2 - s2 * h2_2)) / (d + eta * real_mask)
om_complex = jnp.array([h1, h2])

obsmodel = (om_real * real_mask) + jnp.ravel(om_complex)[
::2
] * complex_mask
obsmodel = (om_real * real_mask) + jnp.ravel(om_complex)[::2] * complex_mask

## return class
return cls(
Expand All @@ -745,7 +743,7 @@ def init(
@classmethod
def from_quads(
cls, alpha_quads: JAXArray, beta_quads: JAXArray, beta_mult: JAXArray
) -> "CARMA":
) -> CARMA:
"""Construct a CARMA kernel using the roots of the characteristic polynomials
The roots can be re-parameterized as the coefficients of a product
Expand All @@ -764,9 +762,7 @@ def from_quads(
beta_quads = jnp.atleast_1d(beta_quads)
beta_mult = jnp.atleast_1d(beta_mult)

alpha = CARMA.quads2poly(jnp.append(alpha_quads, jnp.array([1.0])))[
:-1
]
alpha = CARMA.quads2poly(jnp.append(alpha_quads, jnp.array([1.0])))[:-1]
beta = CARMA.quads2poly(jnp.append(beta_quads, beta_mult))

return CARMA.init(alpha, beta)
Expand Down Expand Up @@ -794,9 +790,7 @@ def quads2poly(quads_coeffs: JAXArray) -> JAXArray:
size = quads_coeffs.shape[0] - 1
remain = size % 2
nPair = size // 2
mult_f = quads_coeffs[
-1:
] # The coeff of highest order term in the output
mult_f = quads_coeffs[-1:] # The coeff of highest order term in the output

poly = jax.lax.cond(
remain == 1,
Expand Down Expand Up @@ -831,7 +825,7 @@ def poly2quads(poly_coeffs: JAXArray) -> tuple[JAXArray, JAXArray]:
equations. The last entry should a scaling factor, which corresponds to the coefficient of the highest order term in the full polynomial.
"""

quads = jnp.empty((0))
quads = jnp.empty(0)
mult_f = poly_coeffs[-1]
roots = CARMA.roots(poly_coeffs / mult_f)
odd = bool(len(roots) & 0x1)
Expand Down Expand Up @@ -859,9 +853,7 @@ def poly2quads(poly_coeffs: JAXArray) -> tuple[JAXArray, JAXArray]:
return jnp.append(quads, jnp.array(mult_f))

@staticmethod
def carma_acvf(
arroots: JAXArray, arparam: JAXArray, maparam: JAXArray
) -> JAXArray:
def carma_acvf(arroots: JAXArray, arparam: JAXArray, maparam: JAXArray) -> JAXArray:
"""Compute the coefficient of each term in the autocovariance function (ACVF) given CARMA parameters
Args:
Expand Down Expand Up @@ -906,30 +898,26 @@ def design_matrix(self) -> JAXArray:
## for complex exponential components
dm_complex_diag = jnp.diag(self.arroots.real * self.complex_mask)
# upper triangle entries
dm_complex_u = jnp.diag(
(self.arroots.imag * self.complex_select)[:-1], k=1
)
dm_complex_u = jnp.diag((self.arroots.imag * self.complex_select)[:-1], k=1)

return dm_real + dm_complex_diag + -dm_complex_u.T + dm_complex_u

def stationary_covariance(self) -> JAXArray:
p = self.acf.shape[0]

## for real exponential components
diag = jnp.diag(
jnp.where(self.acf.real > 0, jnp.ones(p), -jnp.ones(p))
)
diag = jnp.diag(jnp.where(self.acf.real > 0, jnp.ones(p), -jnp.ones(p)))

## for complex exponential components
diag_complex = jnp.diag(
2
* jnp.square(
(

self.arroots.real
/ (self.arroots.imag + self._eta)
* jnp.roll(self.complex_select, 1)
* self.complex_mask
)

)
)
c_over_d = self.arroots.real / (self.arroots.imag + self._eta)
Expand Down
11 changes: 3 additions & 8 deletions tests/test_kernels/test_quasisep.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,18 +97,15 @@ def test_carma(data):
validate_kernels = [
quasisep.Exp(scale=100.0, sigma=np.sqrt(0.5)),
quasisep.Celerite(25.0 / 6, 2.5, 0.6, -0.8),
quasisep.Exp(1.0, np.sqrt(4.04040404))
+ quasisep.Exp(10.0, np.sqrt(4.5959596)),
quasisep.Exp(1.0, np.sqrt(4.04040404)) + quasisep.Exp(10.0, np.sqrt(4.5959596)),
]

# Compare log_probability & normalization
for i in range(len(carma2_kernels)):
gp1 = GaussianProcess(carma2_kernels[i], x, diag=0.1)
gp2 = GaussianProcess(validate_kernels[i], x, diag=0.1)

np.testing.assert_allclose(
gp1.log_probability(y), gp2.log_probability(y)
)
np.testing.assert_allclose(gp1.log_probability(y), gp2.log_probability(y))
np.testing.assert_allclose(
gp1.solver.normalization(), gp2.solver.normalization()
)
Expand All @@ -118,9 +115,7 @@ def test_carma_jit(data):
x, y, t = data

def build_gp(params):
carma_kernel = quasisep.CARMA.init(
alpha=params["alpha"], beta=params["beta"]
)
carma_kernel = quasisep.CARMA.init(alpha=params["alpha"], beta=params["beta"])
return GaussianProcess(carma_kernel, x, diag=0.01, mean=0.0)

@jax.jit
Expand Down
8 changes: 2 additions & 6 deletions tests/test_solvers/test_kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def data(random):
1.5 * quasisep.Matern52(1.5) + 0.3 * quasisep.Exp(1.5),
quasisep.Matern52(1.5) * quasisep.SHO(omega=1.5, quality=0.1),
1.5 * quasisep.Matern52(1.5) * quasisep.Celerite(1.1, 0.8, 0.9, 0.1),
quasisep.CARMA.init(
alpha=np.array([1.4, 2.3, 1.5]), beta=np.array([0.1, 0.5])
),
quasisep.CARMA.init(alpha=np.array([1.4, 2.3, 1.5]), beta=np.array([0.1, 0.5])),
]
)
def kernel(request):
Expand Down Expand Up @@ -67,6 +65,4 @@ def test_consistent_with_direct(kernel, data):
gp2 = GaussianProcess(kernel, x, diag=0.1, solver=QuasisepSolver)

np.testing.assert_allclose(gp1.log_probability(y), gp2.log_probability(y))
np.testing.assert_allclose(
gp1.solver.normalization(), gp2.solver.normalization()
)
np.testing.assert_allclose(gp1.solver.normalization(), gp2.solver.normalization())

0 comments on commit b878e45

Please sign in to comment.