tinygp and emcee #206

Open
pilchat opened this issue Feb 21, 2024 · 2 comments

Comments

@pilchat

pilchat commented Feb 21, 2024

Hi,

I have code that samples a likelihood function with emcee in a multithreaded setup. The likelihood is computed with tinygp.

I have now vectorized part of the likelihood function so that it runs on GPUs, and I'm wondering about the smartest way to use tinygp in a vectorized fashion. I've been looking for tutorials (I'm not an expert here, unfortunately), but the only one of yours I could find uses NumPyro. Is there any specific reason not to use emcee? More importantly, is there a way to use tinygp in a vectorized way?

Thanks

@dfm
Owner

dfm commented Feb 21, 2024

NumPyro will typically get better sampling performance since it uses gradient-based sampling methods and takes better advantage of JAX's JIT compilation. That said, there shouldn't be any real reason why you can't use emcee.

Can you say a little more about what you mean by "vectorized" here?

@pilchat
Author

pilchat commented Feb 21, 2024

My emcee sampler uses a vectorized likelihood function: it accepts a 2D array and relies on the native vectorization of numpy (or jax.numpy). The only problem is that, as far as I understand, the kernel module in tinygp accepts 1D arrays, so I have to find a way to handle this to run efficiently on a GPU. This step is actually the bottleneck of the entire code: the GP computation is the most computationally demanding task, and I need to optimize it as much as possible.

I have read about avoiding for loops and using either numpy.vectorize() or map(), but I'm a bit lost among all the possibilities and short on time for proper testing. That's why I'm asking for advice.

Thanks again
