Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Configuration and API change #25

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
:alt: logo


##############################
particle mesh with derivatives
==============================
##############################

``pmwd`` is a differentiable cosmological particle-mesh forward model.
The C\ :sub:`2` symmetry of the name symbolizes the reversibility of the
Expand All @@ -29,8 +30,9 @@ interesting projected patterns.
<video src="https://user-images.githubusercontent.com/7311098/212061152-2b1be0ac-bfc4-4b57-87fe-d5b9b5c38e8c.mp4"></video>


************
Installation
------------
************

.. code:: sh

Expand All @@ -41,15 +43,19 @@ Installation
pip install -e .[dev] # to install development dependencies


********
Examples
--------
********

See `docs/examples <docs/examples>`_.


..
*******
Testing
-------
*******
pytest-mypy
pytest-mypy-testing

.. code:: sh

Expand All @@ -70,8 +76,9 @@ See `docs/examples <docs/examples>`_.


..
**********************
References & Citations
----------------------
**********************

We refer the users to the following references for ...
Please cite the following papers:
Expand Down
33 changes: 26 additions & 7 deletions pmwd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,38 @@
"""pmwd: particle mesh with derivatives"""


from pmwd.configuration import Configuration
from pmwd.cosmology import Cosmology, SimpleLCDM, Planck18, E2, H_deriv, Omega_m_a
from pmwd.boltzmann import (transfer_integ, transfer_fit, transfer, growth_integ,
growth, varlin_integ, varlin, boltzmann, linear_power)
from pmwd.particles import (Particles, ptcl_enmesh,
ptcl_pos, ptcl_rpos, ptcl_rsd, ptcl_los)
import jax
import jax.numpy as jnp



# TODO pmwd.cosmology.{background,perturbation,cosmology} ?
from pmwd.background import E2, H_deriv, Omega_m_a, distance_cache, distance
from pmwd.perturbation import (transfer_cache, transfer_fit, transfer,
growth_cache, growth,
varlin_cache, varlin,
linear_power)
from pmwd.cosmology import Cosmology, SimpleLCDM, Planck18
from pmwd.solver import Solver
from pmwd.modes import white_noise, linear_modes
#from pmwd.particles import (Particles, ptcl_mass, ptcl_enmesh, #FIXME conf problem
# ptcl_rpos, ptcl_rsd, ptcl_los)
from pmwd.particles import Particles, ptcl_mass
#from pmwd.observables import FIXME
from pmwd.scatter import scatter
from pmwd.gather import gather
from pmwd.gravity import laplace, neg_grad, gravity
from pmwd.modes import white_noise, linear_modes
from pmwd.lpt import lpt
from pmwd.nbody import nbody
try:
from pmwd._version import __version__
except ModuleNotFoundError:
pass # not installed


#FIXME bump version of mcfit (after disabled x64) and make that minimum requirement here
#FIXME jax.config.update("jax_enable_x64", False)


#move this to some style script for ipynb's, or the worst to every ipynb
#jnp.set_printoptions(precision=3, edgeitems=2, linewidth=128)
240 changes: 240 additions & 0 deletions pmwd/background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
from functools import partial

from jax import value_and_grad
import jax.numpy as jnp
from jax.lax import switch


def E2(a, cosmo):
r"""Squared relative Hubble parameter, :math:`E^2`, normalized at :math:`a=1`.

Parameters
----------
a : ArrayLike
Scale factors.
cosmo : Cosmology

Returns
-------
E2 : jax.Array of cosmo.dtype
Squared relative Hubble parameter.

Notes
-----
The squared Hubble parameter,

.. math::

H^2(a) = H_0^2 E^2(a),

has the time dependence

.. math::

E^2(a) = \Omega_\mathrm{m} a^{-3} + \Omega_\mathrm{k} a^{-2}
+ \Omega_\mathrm{de} a^{-3 (1 + w_0 + w_a)} e^{-3 w_a (1 - a)}.

"""
a = jnp.asarray(a, dtype=cosmo.dtype)

de_a = a**(-3 * (1 + cosmo.w_0 + cosmo.w_a)) * jnp.exp(-3 * cosmo.w_a * (1 - a))
return cosmo.Omega_m * a**-3 + cosmo.Omega_K * a**-2 + cosmo.Omega_de * de_a


@partial(jnp.vectorize, excluded=(1,))
def H_deriv(a, cosmo):
r"""Hubble parameter derivatives, :math:`\mathrm{d}\ln H / \mathrm{d}\ln a`.

Parameters
----------
a : ArrayLike
Scale factors.
cosmo : Cosmology

Returns
-------
dlnH_dlna : jax.Array of cosmo.dtype
Hubble parameter derivatives.

"""
a = jnp.asarray(a, dtype=cosmo.dtype)

E2_value, E2_grad = value_and_grad(E2)(a, cosmo)
return 0.5 * a * E2_grad / E2_value


def Omega_m_a(a, cosmo):
r"""Matter density parameters, :math:`\Omega_\mathrm{m}(a)`.

Parameters
----------
a : ArrayLike
Scale factors.
cosmo : Cosmology

Returns
-------
Omega : jax.Array of cosmo.dtype
Matter density parameters.

Notes
-----

.. math::

\Omega_\mathrm{m}(a) = \frac{\Omega_\mathrm{m} a^{-3}}{E^2(a)}

"""
a = jnp.asarray(a, dtype=cosmo.dtype)

return cosmo.Omega_m / (a**3 * E2(a, cosmo))


def distance_cache(cosmo):
r"""Cache the comoving and physical distance tables at ``cosmo.distance_a``.

Parameters
----------
cosmo : Cosmology

Returns
-------
cosmo : Cosmology
A new object containing a distance table, in unit :math:`L`, shape ``(2,
cosmo.distance_a_num,)``, and precision `cosmo.dtype`.

Notes
-----
The comoving horizon in the conformal time :math:`\eta`

.. math::

c \eta = \int_0^t \frac{c \mathrm{d} t}{a(t)}
= d_H \int_0^a \frac{\mathrm{d} a'}{a'^2 E(a')}
= d_H \int_z^\infty \frac{\mathrm{d} z'}{E(z')}.

The light-travel distance in the age or physical time :math:`t`

.. math::

ct = \int_0^t c \mathrm{d} t
= d_H \int_0^a \frac{\mathrm{d} a'}{a' E(a')}
= d_H \int_z^\infty \frac{\mathrm{d} z'}{(1+z') E(z')}.

"""
#FIXME in the future use jax.scipy.integrate.cumulative_trapezoid or Cubic Hermite spline antiderivatives
a = cosmo.distance_a[1:]
da = jnp.diff(cosmo.distance_a, prepend=0)

cdetada = cosmo.d_H / (a**2 * jnp.sqrt(E2(a, cosmo)))
cdetada = jnp.concatenate((jnp.array([0, 0]), cdetada))
cdeta = (cdetada[:-1] + cdetada[1:]) / 2 * da
ceta = jnp.cumsum(cdeta)

cdtda = cosmo.d_H / (a * jnp.sqrt(E2(a, cosmo)))
cdtda = jnp.concatenate((jnp.array([0, 0]), cdtda))
cdt = (cdtda[:-1] + cdtda[1:]) / 2 * da
ct = jnp.cumsum(cdt)

distance = jnp.stack((ceta, ct), axis=0)

#return cosmo.replace(distance=distance)
return distance


def _SK_closed(chi, Ksqrt):
return jnp.sin(Ksqrt * chi) / Ksqrt

def _SK_flat(chi, Ksqrt):
return chi

def _SK_open(chi, Ksqrt):
return jnp.sinh(Ksqrt * chi) / Ksqrt


def distance(a_e, cosmo, type='radial', a_o=1):
r"""Interpolate the distances and compute different distance or time measures
between emissions and observations.

Parameters
----------
a_e : ArrayLike
Scale factors at emission.
cosmo : Cosmology
type : str in {'radial', 'transverse', 'angdiam', 'luminosity', 'light', 'conformal', 'lookback'}, optional
Type of distances or times to return, among radial comoving distance, transverse
comoving distance, angular diameter distance, luminosity distance, light-travel
distance, conformal time, and lookback time.
a_o : ArrayLike, optional
Scale factors at observation.

Returns
-------
d : jax.Array
Distances in :math:`L` or times in :math:`T`.

Notes
-----
The line-of-sight or radial comoving distance, related to the conformal time
:math:`\eta`

.. math::

\chi = c \eta
= \int_{t_\mathrm{e}}^{t_\mathrm{o}} \frac{c \mathrm{d} t}{a(t)}
= d_H \int_{a_\mathrm{e}}^{a_\mathrm{o}} \frac{\mathrm{d} a'}{a'^2 E(a'}
= d_H \int_{z_\mathrm{o}}^{z_\mathrm{e}} \frac{\mathrm{d} z'}{E(z')}.

The transverse comoving or comoving angular diameter distance

.. math::

r = \frac{S_K(\sqrt{|K|} \chi)}{\sqrt{|K|}},

where :math:`S_K` is sin, identity, or sinh for positive, zero, or negative
:math:`K`, respectively.

The angular diameter distance and luminosity distance

.. math::

d_\mathrm{A} &= \frac{a_\mathrm{e}}{a_\mathrm{o}} r, \\
d_L &= \frac{a_\mathrm{o}}{a_\mathrm{e}} r.

The light-travel distance in the lookback or physical time :math:`t`

.. math::

ct = \int_{t_\mathrm{e}}^{t_\mathrm{o}} c \mathrm{d} t
= d_H \int_{a_\mathrm{e}}^{a_\mathrm{o}} \frac{\mathrm{d} a'}{a' E(a'}
= d_H \int_{z_\mathrm{o}}^{z_\mathrm{e}} \frac{\mathrm{d} z'}{(1+z') E(z')}.

"""
if cosmo.distance is None:
raise ValueError('distance table is empty: run Cosmology.cache or distance_cache first')

a_e = jnp.asarray(a_e)
a_o = jnp.asarray(a_o)

phys = 1 if type in {'light', 'lookback'} else 0
d_o = jnp.interp(a_o, cosmo.distance_a, cosmo.distance[phys])
d_e = jnp.interp(a_e, cosmo.distance_a, cosmo.distance[phys])
d = d_o - d_e

if type in {'lookback', 'conformal'}:
d /= cosmo.c

if type in {'radial', 'light', 'lookback', 'conformal'}:
return d

branches = _SK_closed, _SK_flat, _SK_open
Ksqrt = jnp.sqrt(jnp.abs(cosmo.K))
d = switch(jnp.int8(jnp.sign(cosmo.Omega_K)) + 1, branches, d, Ksqrt)
if type == 'transverse':
return d
if type == 'angdiam':
return a_e / a_o * d
if type == 'luminosity':
return a_o / a_e * d

raise ValueError(f'{type=} not supported')
Loading