Thanks for the clear summary @ntessore. For the spherical harmonic transforms, we've just released a new JAX code: s2fft. This is a very early release; we're still polishing the code and writing the related paper. For the fast Legendre transforms, it seems there is already a DCT in JAX. Could the iterative matrix construction and its application be ported to Python? Otherwise, we could perhaps consider alternative approaches to the Legendre transform, perhaps leveraging components from spherical harmonic transform algorithms.
-
Let me reply to this as a new comment:
That's good, so we nominally only have to figure out a big matrix multiplication.
Porting it is not a problem as such, but in our validation tests the matrix was routinely 100,000 × 100,000 (and many of the accuracy tests used 10,000,000 × 10,000,000). However, maybe one can exploit the structure of the matrix better: it is a triangular matrix in which every second off-diagonal is zero. Maybe one could look at efficient triangular matrix-vector multiplications in Python and try to make one compatible with the recurrence? The one we use goes down the diagonal one row at a time and across all columns of that row, but maybe one can also step across two columns at a time and go down an entire diagonal.
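For scale, a dense 100,000 × 100,000 float64 matrix would need 8 × 10^10 bytes (80 GB), so the matrix has to be generated on the fly. The two traversal orders discussed above can be sketched in NumPy as follows. This is a minimal sketch, assuming the matrix is upper triangular with A[i, j] = 0 whenever j - i is odd; the callables `entry(i, j)` and `diagonal(k)` are hypothetical stand-ins for whatever the recurrence produces, and this is not the flt recurrence itself.

```python
import numpy as np


def matvec_by_rows(entry, x):
    """Compute y = A @ x one row at a time, without storing A.

    Assumes A is upper triangular and A[i, j] == 0 when j - i is odd.
    `entry(i, j)` is a hypothetical callable returning A[i, j].
    """
    n = x.shape[0]
    y = np.zeros(n)
    for i in range(n):
        # only columns j >= i with the same parity as i are non-zero
        y[i] = sum(entry(i, j) * x[j] for j in range(i, n, 2))
    return y


def matvec_by_diagonals(diagonal, x):
    """Compute y = A @ x one superdiagonal at a time, without storing A.

    `diagonal(k)` is a hypothetical callable returning the k-th
    superdiagonal, i.e. the array A[i, i + k] for i = 0, ..., n - k - 1.
    """
    n = x.shape[0]
    y = np.zeros(n)
    # odd superdiagonals vanish, so only every second diagonal is visited
    for k in range(0, n, 2):
        y[: n - k] += diagonal(k) * x[k:]
    return y
```

The diagonal version replaces the inner Python loop with one vectorised update per diagonal, which maps naturally onto array libraries; in JAX the in-place update would be written as y = y.at[: n - k].add(...).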
I have no particular attachment to our current method, but we did come up with it because no existing alternative worked well enough. All methods that I am aware of map between … and end up with a result where mapping …
-
It was helpfully pointed out to me by @nstarman that the … So let's try and get something done in that direction. I have opened issue glass-dev/glass#67 for more specific implementation details. Hopefully, we can get a first draft PR soon. @EiffL, let me know if you or others need to be in the loop.
-
This discussion is to try and unify the questions of: What about GPU or TPU support / autodiff / JAX support / PyTorch support / TensorFlow support / etc.?
Short answer
All of this should be possible, but it requires an amount of work that depends on where you want to make the cut: from surface-level modifications (Gaussian fields), to deep programming work (lognormal fields), to more difficult conceptual questions (galaxy catalogues).
Longer answer and to-do list
GLASS has been designed with acceleration and differentiability in mind from the beginning. As a matter of fact, what was ultimately released is the fourth complete rewrite of GLASS, and part of the reason for each rewrite was that at some point I knew the old code would not be flexible enough to support things like autodiff or GPU computation.
So this has been very much part of the plan. In fact, I can give a reasonable to-do list of what is necessary to implement e.g. JAX support in the existing library functions.
Overall, my strategy and hope is that GLASS should eventually be written in a way that is completely agnostic about the type of array it is working on, and thus support all of numpy, jax, dask, etc. natively in one code base. This could involve e.g. use of the Array API.
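One way to realise this is to look up the namespace from the input arrays instead of importing numpy directly. The sketch below assumes the inputs implement the Array API's __array_namespace__() method (NumPy 2 arrays do, as do recent JAX arrays) and falls back to NumPy otherwise. array_namespace and weighted_sum are hypothetical helpers for illustration; the array-api-compat package provides a more complete version of the former.

```python
import numpy as np


def array_namespace(*arrays):
    """Return the Array API namespace shared by all inputs.

    Arrays implementing the Array API standard expose __array_namespace__();
    anything without it (e.g. older NumPy arrays) falls back to NumPy.
    """
    spaces = {
        a.__array_namespace__() if hasattr(a, "__array_namespace__") else np
        for a in arrays
    }
    if len(spaces) != 1:
        raise TypeError("inputs belong to different array libraries")
    return spaces.pop()


def weighted_sum(weights, fields):
    """Weighted sum over the leading axis, without hard-coding numpy."""
    xp = array_namespace(weights, fields)
    return xp.sum(weights[:, None] * fields, axis=0)
```

Library functions written this way never name numpy in their body, so the same code path would run on whichever array type the caller passes in.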
Gaussian fields
The first step should definitely be making Gaussian matter fields work, because lognormal ones pose a greater difficulty (see below), and this is still a good test case.
In principle, only surface-level changes to the existing code are necessary: with the exception of healpy's spherical harmonic transforms, everything needed to sample random fields is GLASS' own code, precisely for the purpose of this discussion.
What needs to change is that the functions in glass.fields must be made array agnostic, so that they work equally on e.g. numpy arrays and JAX arrays. This should not be hard. The most tricky bit would be the iternorm generator for the iterative sampling of a multivariate normal, but I actually have half a prototype for that. Everything else appears to only require figuring out how to change from the numpy namespace to something more flexible.
For sampling Gaussian random fields, that leaves the spherical harmonic transforms, and perhaps others can comment on our options.
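For readers unfamiliar with the idea, the iterative sampling of a multivariate normal mentioned above can be sketched with a row-by-row Cholesky factorisation: each new component combines a fresh standard normal with the ones already drawn, so only the current row of the factor is needed. This is a generic NumPy sketch under that assumption, not the iternorm generator in glass.fields, whose interface and algorithm may differ; cholesky_rows and sample_iteratively are hypothetical names.

```python
import numpy as np


def cholesky_rows(cov):
    """Yield the rows of the lower Cholesky factor L of `cov`, one at a time.

    Row i only needs cov[i, :i + 1] and the rows already produced
    (Cholesky-Banachiewicz order), so rows can be consumed as they appear.
    """
    rows = []
    for i in range(cov.shape[0]):
        row = np.zeros(i + 1)
        for j in range(i):
            # L[i, j] = (cov[i, j] - sum_{k<j} L[i, k] L[j, k]) / L[j, j]
            row[j] = (cov[i, j] - row[:j] @ rows[j][:j]) / rows[j][j]
        # L[i, i] = sqrt(cov[i, i] - sum_{k<i} L[i, k]^2)
        row[i] = np.sqrt(cov[i, i] - row[:i] @ row[:i])
        rows.append(row)
        yield row


def sample_iteratively(cov, rng):
    """Yield the components of one draw x ~ N(0, cov), one at a time."""
    z = np.empty(cov.shape[0])
    for i, row in enumerate(cholesky_rows(cov)):
        z[i] = rng.standard_normal()  # fresh independent standard normal
        yield row @ z[: i + 1]  # x_i = sum_{j <= i} L[i, j] z_j
```

The in-place writes (row[j] = ..., z[i] = ...) are NumPy-specific; an array-agnostic version would need to express them functionally, e.g. with .at[...].set(...) in JAX.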
Lognormal fields
Lognormal fields (or rather, any transformation of a Gaussian field) are difficult. The actual transformation of the Gaussian field (for lognormal, Y = exp(X + c) - 1) is easy; as above, it only requires swapping out the numpy namespace for something generic.
But lognormal fields require finding the Gaussian angular power spectra (Gls) that result in the correct realised angular power spectra (Cls) after transformation of the field. To get the desired accuracy in the Cls, we use the nonlinear solver we present in arXiv:2302.01942. That solver has two layers. One is the solver itself, which is my gaussiancl package. That is Python code, and I think it can easily be made array agnostic, because it performs simple arithmetic. The code can be seen here: https://github.com/cltools/gaussiancl/blob/main/gaussiancl.py
The problem is the second layer, which is the conversion between angular power spectra and angular correlation functions C(θ). That is nominally done by my transformcl package, but the actual numeric heavy lifting is done by a Fast Legendre Transform implemented in the flt package. That is C code using the method we present in arXiv:2302.01942, which has two steps: (1) a discrete cosine transform (DCT) that takes C(θ) to its Fourier coefficients, and (2) a matrix multiplication that takes those coefficients to the Legendre coefficients. The implementation of that can be found here: https://github.com/ntessore/flt/blob/main/flt.pyx
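For reference, the conversion that transformcl and flt accelerate can be written as a direct Legendre series. The sketch below assumes the usual normalisation C(θ) = Σ_l (2l + 1)/(4π) C_l P_l(cos θ), with inverse C_l = 2π ∫ C(θ) P_l(cos θ) d cos θ; it does not use the DCT-based method, and its quadratic cost and memory make it far too slow for the sizes mentioned in this thread. A direct transform like this could, however, serve as a correctness check for an array-agnostic port. cl_to_corr and corr_to_cl are hypothetical names.

```python
import numpy as np
from numpy.polynomial import legendre


def cl_to_corr(cl, theta):
    """C(theta) = sum_l (2l + 1) / (4 pi) C_l P_l(cos theta), summed directly."""
    ell = np.arange(len(cl))
    return legendre.legval(np.cos(theta), (2 * ell + 1) / (4 * np.pi) * cl)


def corr_to_cl(corr, lmax):
    """C_l = 2 pi int C(theta) P_l(cos theta) dcos(theta), by Gauss-Legendre quadrature.

    `corr` is a callable of theta; lmax + 1 nodes integrate exactly when
    C(theta) is band-limited to lmax.
    """
    x, w = legendre.leggauss(lmax + 1)
    p = legendre.legvander(x, lmax)  # p[i, l] = P_l(x_i)
    return 2 * np.pi * (w * corr(np.arccos(x))) @ p
```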
Making this same code array agnostic would require (1) either a generic DCT implementation or a DCT implementation for each array type, and (2) coding up the matrix multiplication in an array agnostic manner.
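For point (1), there is a standard route that needs only an FFT, which numpy.fft and jax.numpy.fft both provide (JAX also ships jax.scipy.fft.dct directly): a DCT-II of length N can be computed from one complex FFT of a reordered input (Makhoul's algorithm). This is a minimal sketch assuming the unnormalised DCT-II convention in the docstring; flt may use a different DCT type or normalisation, and dct2 is a hypothetical name.

```python
import math

import numpy as np


def dct2(x, xp=np):
    """Unnormalised DCT-II, X_k = sum_n x_n cos(pi k (2n + 1) / (2N)), via one FFT.

    Uses Makhoul's reordering; `xp` is the array namespace, e.g. numpy or jax.numpy.
    """
    n = x.shape[0]
    i = xp.arange(n)
    # v = [x_0, x_2, x_4, ..., x_5, x_3, x_1]: evens ascending, then odds descending
    v = xp.take(x, xp.where(2 * i < n, 2 * i, 2 * n - 2 * i - 1))
    # X_k = Re(exp(-i pi k / (2N)) * FFT(v)_k)
    phase = xp.exp(-1j * math.pi * i / (2 * n))
    return xp.real(phase * xp.fft.fft(v))
```

Because only arange, where, take, exp, real and fft are used, the same function should work unchanged when passed jax.numpy as xp.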
The latter point is not trivial; the reason that it is done in C is that the matrices are too big to construct explicitly, and we instead compute them iteratively at the same time as doing the matrix multiplication itself. This happens here: https://github.com/ntessore/flt/blob/main/dctdlt.c
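By contrast, the lognormal transformation itself is the easy part. A minimal sketch, assuming X is a zero-mean Gaussian field with variance σ² and choosing c = -σ²/2 so that Y has zero mean (since E[exp(X)] = exp(σ²/2)). With that choice the correlations satisfy ξ_Y(θ) = exp(ξ_X(θ)) - 1, a pointwise relation in θ, which is why the spectra have to be taken through C(θ) and back. lognormal and lognormal_corr are hypothetical names.

```python
import numpy as np


def lognormal(x, var, xp=np):
    """Y = exp(X + c) - 1 with c = -var / 2, so that E[Y] = 0 for X ~ N(0, var)."""
    return xp.exp(x - var / 2) - 1


def lognormal_corr(xi_gauss, xp=np):
    """Correlation of Y given the correlation xi of X: exp(xi) - 1."""
    return xp.exp(xi_gauss) - 1
```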
Weak lensing
This should be a simple case of removing a few numpy namespace uses. The actual operation performed is a weighted sum of the matter fields, so it should always work with whatever array type is being used.
Galaxies
Instead of going into detail about the individual steps in the galaxies sector, let me ask a question: How can we make … differentiable?
Summary
I think we should be able to get Gaussian random fields and lensing working more or less right away.
If that looks promising, maybe we can come up with a plan for how to make the Gaussian angular power spectra work.