Using tinygp with the "kernel as a Pytree" approach #80

@patel-zeel

Hi @dfm, as we discussed in #79, here are my thoughts on presenting the idea of using kernels as Pytrees.

1. Creating a model by passing the kernel itself as a parameter seems a more natural and efficient way to use tinygp. Given that, shall we add the following code (Colab link) to the Tips section, or update the existing example there?

(`log_kernel` may not be the best name here, so feel free to suggest something better.)

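```python
import jax
import jax.numpy as jnp
import numpy as np

from tinygp import GaussianProcess, kernels

# Placeholder training data (hypothetical); replace with your own X and y.
rng = np.random.default_rng(42)
X = np.sort(rng.uniform(0.0, 10.0, 50))
y = np.sin(X) + 0.1 * rng.normal(size=len(X))


def build_gp(params):
    log_kernel, log_noise = params
    # Exponentiate every leaf of the kernel pytree so that all kernel
    # parameters (here the amplitude and the scale) are positive.
    kernel = jax.tree_util.tree_map(jnp.exp, log_kernel)
    noise = jnp.exp(log_noise)
    return GaussianProcess(kernel, X, diag=noise)


@jax.jit
def loss(params):
    gp = build_gp(params)
    return -gp.log_probability(y)


log_amp = -0.1
log_scale = 0.0
log_noise = -1.0
# Multiplying by a scalar stores log_amp inside the kernel pytree, so the
# tree_map in build_gp exponentiates it along with the scale.
log_kernel = log_amp * kernels.ExpSquared(scale=log_scale)
params = (log_kernel, log_noise)
loss(params)
```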
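
The `tree_map(jnp.exp, ...)` call in `build_gp` works because a kernel is a pytree whose leaves are its numeric parameters: `exp` is applied to each leaf while the kernel structure is kept, and `jax.grad` returns gradients in that same structure. The sketch below shows this with a hypothetical `ToyExpSquared` class registered by hand through `jax.tree_util.register_pytree_node_class`; it only mimics the idea and is not how tinygp defines its kernels.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_leaves, tree_map


@register_pytree_node_class
class ToyExpSquared:
    """Hypothetical kernel-like pytree: amp * exp(-(x1 - x2)**2 / (2 * scale**2))."""

    def __init__(self, amp, scale):
        self.amp = amp
        self.scale = scale

    def __call__(self, x1, x2):
        r2 = jnp.square(x1 - x2)
        return self.amp * jnp.exp(-0.5 * r2 / jnp.square(self.scale))

    def tree_flatten(self):
        # The numeric parameters are the leaves; there is no static metadata.
        return (self.amp, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(log_kernel):
    # Exponentiate every leaf; the result is again a ToyExpSquared.
    kernel = tree_map(jnp.exp, log_kernel)
    return kernel(0.0, 1.0)


log_kernel = ToyExpSquared(amp=jnp.array(-0.1), scale=jnp.array(0.0))
kernel = tree_map(jnp.exp, log_kernel)
print(type(kernel).__name__, [float(v) for v in tree_leaves(kernel)])

# jax.grad returns a ToyExpSquared holding d(loss)/d(log-parameter) per leaf.
grads = jax.grad(loss)(log_kernel)
print(type(grads).__name__, [float(g) for g in tree_leaves(grads)])
```

Here both gradient entries equal exp(-0.6), roughly 0.549: at a separation of 1 and unit scale, the derivative of the loss with respect to either log-parameter is the loss itself.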

2. I think the "kernel as a Pytree" approach will have the most impact in the Getting Started section, because that example combines several kernels. The current code could be translated to something like the following (Colab link). I could not get jaxopt to work (perhaps it treats the parameters as floats instead of DeviceArrays), so I used optax instead; the error trace is in the Colab.

```python
import jax
import jax.numpy as jnp
import numpy as np

from tinygp import kernels, GaussianProcess

jax.config.update("jax_enable_x64", True)


def build_kernel():
    # Every numeric argument is a log-parameter; build_gp exponentiates them.
    k1 = np.log(66.0) * kernels.ExpSquared(np.log(67.0))
    k2 = (
        np.log(2.4)
        * kernels.ExpSquared(np.log(90.0))
        * kernels.ExpSineSquared(
            scale=np.log(1.0),
            gamma=np.log(4.3),
        )
    )
    k3 = np.log(0.66) * kernels.RationalQuadratic(
        alpha=np.log(1.2), scale=np.log(0.78)
    )
    k4 = np.log(0.18) * kernels.ExpSquared(np.log(1.6))
    kernel = k1 + k2 + k3 + k4

    return kernel


def build_gp(params, X):
    # We want the kernel parameters and the noise to be positive, so we take
    # the `exp` here; the mean is left unconstrained.
    # Note that we're using `jnp` instead of `np`.
    kernel, log_noise, mean = params
    kernel = jax.tree_util.tree_map(jnp.exp, kernel)
    return GaussianProcess(kernel, X, diag=jnp.exp(log_noise), mean=mean)


def neg_log_likelihood(params, X, y):
    gp = build_gp(params, X)
    return -gp.log_probability(y)


kernel = build_kernel()
log_noise = np.log(0.19)
mean = np.float64(340.0)

params_init = (kernel, log_noise, mean)

# `jax` can be used to differentiate functions, and also note that we're calling
# `jax.jit` for the best performance.
obj = jax.jit(jax.value_and_grad(neg_log_likelihood))

# `t` and `y` are the data arrays loaded earlier in the Getting Started tutorial.
value, grad = obj(params_init, t, y)
print(f"Initial negative log likelihood: {value}")
print(f"Gradient of the negative log likelihood, wrt the parameters:\n{grad}")
```
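
Because the gradient returned by `jax.value_and_grad` has the same tree structure as `params_init` (a kernel pytree, the log noise, and the mean), any optimizer that operates on pytrees can update the kernel directly. The sketch below uses plain gradient descent via `jax.tree_util.tree_map` in place of optax so that it depends only on JAX; `ToyKernel`, the quadratic `toy_loss`, and `LEARNING_RATE` are hypothetical stand-ins for the tinygp kernel, `neg_log_likelihood`, and a tuned step size.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

LEARNING_RATE = 0.1  # hypothetical step size


class ToyKernel(NamedTuple):
    """Hypothetical kernel-like pytree; NamedTuples are pytrees in JAX."""

    log_amp: jax.Array
    log_scale: jax.Array


def toy_loss(params):
    # Quadratic stand-in for neg_log_likelihood, minimized at
    # log_amp=1, log_scale=-0.5, log_noise=0.2, mean=3.
    kernel, log_noise, mean = params
    return (
        (kernel.log_amp - 1.0) ** 2
        + (kernel.log_scale + 0.5) ** 2
        + (log_noise - 0.2) ** 2
        + (mean - 3.0) ** 2
    )


@jax.jit
def step(params):
    value, grads = jax.value_and_grad(toy_loss)(params)
    # grads has the same structure as params: (ToyKernel, log_noise, mean).
    new_params = tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, value


params = (ToyKernel(jnp.array(0.0), jnp.array(0.0)), jnp.array(0.0), jnp.array(0.0))
for _ in range(200):
    params, value = step(params)

kernel, log_noise, mean = params
print(kernel, log_noise, mean, value)
```

With optax, the `tree_map` update would instead go through `optimizer.update(grads, opt_state, params)` followed by `optax.apply_updates(params, updates)`, which also accept arbitrary pytrees.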

3. In the Custom Kernels section, you mention the following:

> Besides describing this interface, we also show how tinygp can support arbitrary [JAX pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) as input.

I could not find anything in that section that demonstrates this. Was this line meant to describe something like what we are discussing here? If so, I can modify the current spectral mixture kernel example to showcase the new approach.
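
If the spectral mixture example were rewritten in this style, each parameter would be an array with one entry per mixture component, and a single `tree_map(jnp.exp, ...)` would still convert all of them at once. The pure-JAX sketch below illustrates that for 1-D inputs, using the spectral mixture form k(τ) = Σ_q w_q exp(-2π²τ²/ℓ_q²) cos(2π f_q τ), with the bandwidth written as 1/ℓ_q² (an assumed parameterization). The `SpectralMixture` class, its field names, and the parameter values are hypothetical and do not reproduce the code in the tinygp docs.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


@register_pytree_node_class
class SpectralMixture:
    """Hypothetical 1-D spectral mixture kernel; one array entry per component."""

    def __init__(self, weight, scale, freq):
        self.weight = weight
        self.scale = scale
        self.freq = freq

    def evaluate(self, x1, x2):
        tau = jnp.abs(x1 - x2)
        # Sum over components of w * exp(-2 pi^2 tau^2 / scale^2) * cos(2 pi freq tau).
        return jnp.sum(
            self.weight
            * jnp.exp(-2.0 * jnp.pi**2 * tau**2 / self.scale**2)
            * jnp.cos(2.0 * jnp.pi * self.freq * tau)
        )

    def tree_flatten(self):
        return (self.weight, self.scale, self.freq), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# Log-parameters of a two-component mixture (arbitrary illustrative values).
log_kernel = SpectralMixture(
    weight=jnp.log(jnp.array([1.0, 0.5])),
    scale=jnp.log(jnp.array([2.0, 0.3])),
    freq=jnp.log(jnp.array([0.1, 1.5])),
)


def k(log_kernel, x1, x2):
    # One tree_map maps every log-parameter array to its positive value.
    return tree_map(jnp.exp, log_kernel).evaluate(x1, x2)


k_zero = k(log_kernel, 0.0, 0.0)
k_lag = k(log_kernel, 0.0, 0.7)
# Gradient w.r.t. the log-parameters: a SpectralMixture of shape-(2,) arrays.
grads = jax.grad(k)(log_kernel, 0.0, 0.7)
print(k_zero, k_lag, grads.weight.shape)
```

At τ = 0 every exponential and cosine factor is 1, so `k_zero` equals the sum of the weights (1.5 here), and the gradient keeps the per-component array shapes of the parameters.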

Please let me know your thoughts on these proposals.

P.S.: Feel free to leave quick suggestions directly in the Colab as comments!

Edit:
I think that for a new kernel to work with the above approach, it needs to be defined something like this, right?

```python
import jax.numpy as jnp

from tinygp import kernels
from tinygp.helpers import dataclass, field, JAXArray


@dataclass
class Linear(kernels.Kernel):
    scale: JAXArray = field(default_factory=lambda: jnp.ones(()))
    sigma: JAXArray = field(default_factory=lambda: jnp.zeros(()))

    def evaluate(self, X1, X2):
        # jnp.dot also accepts 0-d (scalar) inputs from a 1-D dataset, which @ does not
        return jnp.dot(X1 / self.scale, X2 / self.scale) + jnp.square(self.sigma)
```
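
The `@dataclass` decorator matters because it is what makes `scale` and `sigma` pytree leaves that `tree_map` can reach. To show the idea without tinygp, here is a hypothetical `pytree_dataclass` decorator built on `jax.tree_util.register_pytree_node`; we assume `tinygp.helpers.dataclass` serves a similar purpose, but this is not its implementation. The sketch also builds a covariance matrix by applying `evaluate` to every pair of inputs with nested `jax.vmap` and compares it with the closed form of the linear kernel; `ToyLinear` and the data are illustrative only.

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_map


def pytree_dataclass(cls):
    """Hypothetical decorator: a frozen dataclass whose fields are pytree leaves."""
    cls = dataclasses.dataclass(frozen=True)(cls)
    names = [f.name for f in dataclasses.fields(cls)]

    def flatten(obj):
        return tuple(getattr(obj, n) for n in names), None

    def unflatten(aux_data, children):
        return cls(**dict(zip(names, children)))

    register_pytree_node(cls, flatten, unflatten)
    return cls


@pytree_dataclass
class ToyLinear:
    scale: jax.Array
    sigma: jax.Array

    def evaluate(self, x1, x2):
        # Works for vector inputs and for 0-d (scalar) inputs alike.
        return jnp.dot(x1 / self.scale, x2 / self.scale) + jnp.square(self.sigma)


log_kernel = ToyLinear(scale=jnp.log(2.0), sigma=jnp.log(0.5))
kernel = tree_map(jnp.exp, log_kernel)  # both fields are exponentiated
print(kernel)

X = jnp.array([[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]])
# K[i, j] = evaluate(X[i], X[j]), built by mapping over both arguments.
K = jax.vmap(jax.vmap(kernel.evaluate, in_axes=(None, 0)), in_axes=(0, None))(X, X)
# Closed form of the same matrix: (X / scale) @ (X / scale).T + sigma**2.
K_closed = (X / kernel.scale) @ (X / kernel.scale).T + kernel.sigma**2
print(K)
```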
