Hi @dfm, as we have discussed in #79, I have the following thoughts about presenting the idea of using kernels as Pytrees.
- Creating a model by providing the `kernel` as a `param` seems a more natural and efficient way of using `tinygp`. Considering that, shall we add or update the following code (Colab link) to the Tips section? (`log_kernel` may not be the best name here, so you may suggest something better.)

```python
import jax
import jax.numpy as jnp

from tinygp import kernels, GaussianProcess
def build_gp(params):
    log_kernel, log_noise = params
    # Exponentiate every leaf of the kernel pytree to get back to positive values
    kernel = jax.tree_map(jnp.exp, log_kernel)
    noise = jnp.exp(log_noise)
    # `X` is the training input, assumed to be defined earlier in the example
    return GaussianProcess(kernel, X, diag=noise)


@jax.jit
def loss(params):
    # `y` is the observed data, also assumed to be defined earlier
    gp = build_gp(params)
    return -gp.log_probability(y)

log_amp = -0.1
log_scale = 0.0
log_noise = -1.0
log_kernel = log_amp * kernels.ExpSquared(scale=log_scale)
params = (log_kernel, log_noise)
loss(params)
```
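We could also point out that gradients propagate through the kernel pytree itself. A rough illustration building on the snippet above (it assumes the same `X` and `y`, and is not meant as final wording for the docs):

```python
# The gradient has the same (kernel, noise) pytree structure as `params`,
# so every kernel parameter gets its own gradient leaf.
grads = jax.grad(loss)(params)
print(grads)
```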
- I think the "`kernel` as a Pytree" approach will have the most impact in the Getting Started section due to the use of kernel combinations. The current code can be translated to something like the following (Colab link). I could not make `jaxopt` work for some reason (maybe it is detecting the parameters as floats instead of DeviceArrays), so I used `optax` instead (the error trace is in the Colab); a sketch of the kind of `optax` loop I mean is included after this list.
```python
import jax
import jax.numpy as jnp
import numpy as np  # `np.log` is used for the initial parameter values

from tinygp import kernels, GaussianProcess

jax.config.update("jax_enable_x64", True)

def build_kernel():
    k1 = np.log(66.0) * kernels.ExpSquared(np.log(67.0))
    k2 = (
        np.log(2.4)
        * kernels.ExpSquared(np.log(90.0))
        * kernels.ExpSineSquared(
            scale=np.log(1.0),
            gamma=np.log(4.3),
        )
    )
    k3 = np.log(0.66) * kernels.RationalQuadratic(
        alpha=np.log(1.2), scale=np.log(0.78)
    )
    k4 = np.log(0.18) * kernels.ExpSquared(np.log(1.6))
    kernel = k1 + k2 + k3 + k4
    return kernel

def build_gp(params, X):
    # We want most of our parameters to be positive so we take the `exp` here
    # Note that we're using `jnp` instead of `np`
    kernel, noise, mean = params
    kernel = jax.tree_map(jnp.exp, kernel)
    return GaussianProcess(kernel, X, diag=jnp.exp(noise), mean=mean)


def neg_log_likelihood(params, X, y):
    gp = build_gp(params, X)
    return -gp.log_probability(y)

kernel = build_kernel()
log_noise = np.log(0.19)
mean = np.float64(340.0)
params_init = (kernel, log_noise, mean)
# `jax` can be used to differentiate functions, and also note that we're calling
# `jax.jit` for the best performance.
obj = jax.jit(jax.value_and_grad(neg_log_likelihood))
print(f"Initial negative log likelihood: {obj(params_init, t, y)[0]}")
print(
    f"Gradient of the negative log likelihood, wrt the parameters:\n{obj(params_init, t, y)[1]}"
)
```

- In the Custom Kernels section, you have mentioned the following:
  > Besides describing this interface, we also show how tinygp can support arbitrary [JAX pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) as input.

  I did not find anything related to this in that section. Was this line written with something like what we are discussing now in mind? If so, I can modify the current code for the spectral mixture kernel to showcase the new approach.
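For reference, the `optax` loop I have in mind for the Getting Started translation is roughly the following. It is only a sketch: the optimizer choice, learning rate, and number of steps are placeholders rather than tuned values, and it reuses `obj`, `params_init`, `t`, and `y` from the code above.

```python
import optax

# optax works on arbitrary pytrees, so the kernel can sit directly inside
# `params` without any manual flattening of its parameters.
opt = optax.adam(learning_rate=1e-3)  # placeholder learning rate
opt_state = opt.init(params_init)

params = params_init
for _ in range(1000):  # placeholder number of steps
    loss_val, grads = obj(params, t, y)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

print(f"Final negative log likelihood: {obj(params, t, y)[0]}")
```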
Please let me know your thoughts on these proposals.
P.S.: Feel free to drop quick suggestions directly on the Colab as comments!
Edit:
I think that, to make a new kernel work with the above approach, it needs to be defined something like this, right?
```python
import jax.numpy as jnp

from tinygp import kernels
from tinygp.helpers import dataclass, field, JAXArray


@dataclass
class Linear(kernels.Kernel):
    scale: JAXArray = field(default_factory=lambda: jnp.ones(()))
    sigma: JAXArray = field(default_factory=lambda: jnp.zeros(()))

    def evaluate(self, X1, X2):
        return (X1 / self.scale) @ (X2 / self.scale) + jnp.square(self.sigma)
```
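If that is the case, the new kernel should drop straight into the pattern above, since (as far as I understand the helpers) the dataclass is registered as a pytree. A quick, untested illustration:

```python
import jax

# The kernel instance behaves like any other pytree of arrays, so it can be
# stored in log space inside `params` and transformed with `tree_map`.
log_linear = Linear(scale=jnp.zeros(()), sigma=jnp.log(0.1))
print(jax.tree_util.tree_leaves(log_linear))

kernel = jax.tree_map(jnp.exp, log_linear)  # back to the constrained values
```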