Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pmwd/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ class Configuration:

chunk_size: int = 2**24

# observables
a_snapshots: Optional[Tuple[float]] = None

def __post_init__(self):
if self._is_transforming():
return
Expand Down
97 changes: 83 additions & 14 deletions pmwd/nbody.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from jax import value_and_grad, jit, vjp, custom_vjp
import jax.numpy as jnp
from jax.tree_util import tree_map
from jax.lax import cond

from pmwd.boltzmann import growth
from pmwd.cosmology import E2, H_deriv
from pmwd.gravity import gravity
from pmwd.obs_util import interptcl, itp_prev_adj, itp_next_adj
from pmwd.particles import Particles


def _G_D(a, cosmo, conf):
Expand Down Expand Up @@ -183,11 +186,69 @@ def coevolve_init(a, ptcl, cosmo, conf):


def observe(a_prev, a_next, ptcl, obsvbl, cosmo, conf):
pass
def itp(a, obsvbl):
snap = interptcl(obsvbl['ptcl_prev'], ptcl, a_prev, a_next, a, cosmo)
obsvbl['snapshots'][a] = snap
return obsvbl

if conf.a_snapshots is not None:
for a in conf.a_snapshots:
obsvbl = cond(jnp.logical_and(a_prev < a, a <= a_next),
partial(itp, a), lambda *args: obsvbl, obsvbl)

obsvbl['ptcl_prev'] = ptcl

return obsvbl


def observe_init(a, ptcl, obsvbl, cosmo, conf):
pass
# a dict to carry all observables and related useful information
obsvbl = {}

# to carry the prev ptcl, starting with lpt ptcl
obsvbl['ptcl_prev'] = ptcl

if conf.a_snapshots is not None:
# all output snapshots
obsvbl['snapshots'] = {
a_snap: Particles(ptcl.conf, ptcl.pmid, jnp.zeros_like(ptcl.disp),
vel=jnp.zeros_like(ptcl.vel))
for a_snap in conf.a_snapshots
}
# the nbody a step of output snapshots, (,]
idx = jnp.searchsorted(conf.a_nbody, jnp.asarray(conf.a_snapshots), side='left')
obsvbl['snap_a_step'] = jnp.array((conf.a_nbody[idx-1], conf.a_nbody[idx])).T

return obsvbl


def observe_adj(a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, cosmo_cot, conf):

if conf.a_snapshots is not None:
for a_snap, a_step in zip(conf.a_snapshots, obsvbl['snap_a_step']):
ptcl_cot, cosmo_cot = cond(a_step[1] == a_next, itp_next_adj,
lambda *args: (ptcl_cot, cosmo_cot),
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
ptcl, a_step[0], a_step[1], a_snap, cosmo)
ptcl_cot, cosmo_cot = cond(a_step[1] == a_prev, itp_prev_adj,
lambda *args: (ptcl_cot, cosmo_cot),
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
ptcl, a_step[0], a_step[1], a_snap, cosmo)

return ptcl_cot, cosmo_cot


def observe_adj_init(a, ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, cosmo_cot, conf):

if conf.a_snapshots is not None:
# check if the last ptcl is used in interpolation
for a_snap, a_step in zip(conf.a_snapshots, obsvbl['snap_a_step']):
ptcl_cot, cosmo_cot = cond(a_step[1] == a, itp_next_adj,
lambda *args: (ptcl_cot, cosmo_cot),
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
ptcl, a_step[0], a_step[1], a_snap, cosmo)

return ptcl_cot, cosmo_cot


@jit
Expand Down Expand Up @@ -224,53 +285,61 @@ def nbody(ptcl, obsvbl, cosmo, conf, reverse=False):


@jit
def nbody_adj_init(a, ptcl, ptcl_cot, obsvbl_cot, cosmo, conf):
#ptcl_cot = observe_adj(a_prev, a_next, ptcl, ptcl_cot, obsvbl_cot, cosmo)
def nbody_adj_init(a, ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, conf):

#ptcl, ptcl_cot = coevolve_adj(a_prev, a_next, ptcl, ptcl_cot, cosmo)

ptcl, ptcl_cot, cosmo_cot_force = force_adj(a, ptcl, ptcl_cot, cosmo, conf)

cosmo_cot = tree_map(jnp.zeros_like, cosmo)

ptcl_cot, cosmo_cot = observe_adj_init(a, ptcl, ptcl_cot, obsvbl, obsvbl_cot,
cosmo, cosmo_cot, conf)

return ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force


@jit
def nbody_adj_step(a_prev, a_next, ptcl, ptcl_cot, obsvbl_cot, cosmo, cosmo_cot, cosmo_cot_force, conf):
#ptcl_cot = observe_adj(a_prev, a_next, ptcl, ptcl_cot, obsvbl_cot, cosmo, conf)
def nbody_adj_step(a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot,
cosmo, cosmo_cot, cosmo_cot_force, conf):

#ptcl, ptcl_cot = coevolve_adj(a_prev, a_next, ptcl, ptcl_cot, cosmo, conf)

ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force = integrate_adj(
a_prev, a_next, ptcl, ptcl_cot, obsvbl_cot, cosmo, cosmo_cot, cosmo_cot_force, conf)

ptcl_cot, cosmo_cot = observe_adj(a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot,
cosmo, cosmo_cot, conf)

return ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force


def nbody_adj(ptcl, ptcl_cot, obsvbl_cot, cosmo, conf, reverse=False):
def nbody_adj(ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, conf, reverse=False):
"""N-body time integration with adjoint equation."""
a_nbody = conf.a_nbody[::-1] if reverse else conf.a_nbody

ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force = nbody_adj_init(
a_nbody[-1], ptcl, ptcl_cot, obsvbl_cot, cosmo, conf)
a_nbody[-1], ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, conf)

for a_prev, a_next in zip(a_nbody[:0:-1], a_nbody[-2::-1]):
ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force = nbody_adj_step(
a_prev, a_next, ptcl, ptcl_cot, obsvbl_cot, cosmo, cosmo_cot, cosmo_cot_force, conf)
a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot,
cosmo, cosmo_cot, cosmo_cot_force, conf)

return ptcl, ptcl_cot, cosmo_cot


def nbody_fwd(ptcl, obsvbl, cosmo, conf, reverse):
ptcl, obsvbl = nbody(ptcl, obsvbl, cosmo, conf, reverse)
return (ptcl, obsvbl), (ptcl, cosmo, conf)
return (ptcl, obsvbl), (ptcl, obsvbl, cosmo, conf)

def nbody_bwd(reverse, res, cotangents):
ptcl, cosmo, conf = res
ptcl, obsvbl, cosmo, conf = res
ptcl_cot, obsvbl_cot = cotangents

ptcl, ptcl_cot, cosmo_cot = nbody_adj(ptcl, ptcl_cot, obsvbl_cot, cosmo, conf,
reverse=reverse)
ptcl, ptcl_cot, cosmo_cot = nbody_adj(
ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, conf, reverse=reverse)

return ptcl_cot, obsvbl_cot, cosmo_cot, None
return ptcl_cot, None, cosmo_cot, None

nbody.defvjp(nbody_fwd, nbody_bwd)
109 changes: 109 additions & 0 deletions pmwd/obs_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from jax import jit, vjp
import jax.numpy as jnp

from pmwd.particles import Particles
from pmwd.cosmology import E2


def itp_prev(ptcl0, a0, a1, a, cosmo):
"""Cubic Hermite interpolation is a linear combination of two ptcls, this
function returns the disp and vel from the first ptcl at a0."""
Da = a1 - a0
t = (a - a0) / Da
a3E0 = a0**3 * jnp.sqrt(E2(a0, cosmo))
# displacement
h00 = 2 * t**3 - 3 * t**2 + 1
h10 = t**3 - 2 * t**2 + t
disp = h00 * ptcl0.disp + h10 * Da / a3E0 * ptcl0.vel
# velocity
# derivatives of the Hermite basis functions
h00 = 6 * t**2 - 6 * t
h10 = 3 * t**2 - 4 * t + 1
vel = h00 / Da * ptcl0.disp + h10 / a3E0 * ptcl0.vel
vel *= a**3 * jnp.sqrt(E2(a, cosmo))

dtype = ptcl0.conf.float_dtype
return disp.astype(dtype), vel.astype(dtype)


def itp_prev_adj(ptcl_cot, cosmo_cot, iptcl_cot, ptcl0, a0, a1, a, cosmo):
"""Update ptcl_cot and cosmo_cot given the iptcl_cot and the vjp with itp_prev."""
# iptcl_cot is the cotangent of the interpolated ptcl
(disp, vel), itp_prev_vjp = vjp(itp_prev, ptcl0, a0, a1, a, cosmo)
ptcl0_cot, a0_cot, a1_cot, a_cot, cosmo_cot_itp = itp_prev_vjp(
(iptcl_cot.disp, iptcl_cot.vel))

disp_cot = ptcl_cot.disp + ptcl0_cot.disp
vel_cot = ptcl_cot.vel + ptcl0_cot.vel
ptcl_cot = ptcl_cot.replace(disp=disp_cot, vel=vel_cot)
cosmo_cot += cosmo_cot_itp
return ptcl_cot, cosmo_cot


def itp_next(ptcl1, a0, a1, a, cosmo):
"""Cubic Hermite interpolation is a linear combination of two ptcls, this
function returns the disp and vel from the second ptcl at a1."""
Da = a1 - a0
t = (a - a0) / Da
a3E1 = a1**3 * jnp.sqrt(E2(a1, cosmo))
# displacement
h01 = - 2 * t**3 + 3 * t**2
h11 = t**3 - t**2
disp = h01 * ptcl1.disp + h11 * Da / a3E1 * ptcl1.vel
# velocity
# derivatives of the Hermite basis functions
h01 = - 6 * t**2 + 6 * t
h11 = 3 * t**2 - 2 * t
vel = h01 / Da * ptcl1.disp + h11 / a3E1 * ptcl1.vel
vel *= a**3 * jnp.sqrt(E2(a, cosmo))

dtype = ptcl1.conf.float_dtype
return disp.astype(dtype), vel.astype(dtype)


def itp_next_adj(ptcl_cot, cosmo_cot, iptcl_cot, ptcl1, a0, a1, a, cosmo):
"""Update ptcl_cot and cosmo_cot given the iptcl_cot and the vjp with itp_next."""
# iptcl_cot is the cotangent of the interpolated ptcl
(disp, vel), itp_next_vjp = vjp(itp_next, ptcl1, a0, a1, a, cosmo)
ptcl1_cot, a0_cot, a1_cot, a_cot, cosmo_cot_itp = itp_next_vjp(
(iptcl_cot.disp, iptcl_cot.vel))

disp_cot = ptcl_cot.disp + ptcl1_cot.disp
vel_cot = ptcl_cot.vel + ptcl1_cot.vel
ptcl_cot = ptcl_cot.replace(disp=disp_cot, vel=vel_cot)
cosmo_cot += cosmo_cot_itp
return ptcl_cot, cosmo_cot


def interptcl(ptcl0, ptcl1, a0, a1, a, cosmo):
"""Given two ptcl snapshots, get the interpolated one at a given time using
cubic Hermite interpolation."""
Da = a1 - a0
t = (a - a0) / Da
a3E0 = a0**3 * jnp.sqrt(E2(a0, cosmo))
a3E1 = a1**3 * jnp.sqrt(E2(a1, cosmo))
# displacement
h00 = 2 * t**3 - 3 * t**2 + 1
h10 = t**3 - 2 * t**2 + t
h01 = - 2 * t**3 + 3 * t**2
h11 = t**3 - t**2
disp = (h00 * ptcl0.disp + h10 * Da / a3E0 * ptcl0.vel +
h01 * ptcl1.disp + h11 * Da / a3E1 * ptcl1.vel)
# velocity
# derivatives of the Hermite basis functions
h00 = 6 * t**2 - 6 * t
h10 = 3 * t**2 - 4 * t + 1
h01 = - 6 * t**2 + 6 * t
h11 = 3 * t**2 - 2 * t
vel = (h00 / Da * ptcl0.disp + h10 / a3E0 * ptcl0.vel +
h01 / Da * ptcl1.disp + h11 / a3E1 * ptcl1.vel)
vel *= a**3 * jnp.sqrt(E2(a, cosmo))

iptcl = Particles(ptcl0.conf, ptcl0.pmid, disp, vel=vel)
return iptcl


def interptcl_adj(iptcl_cot, ptcl0, ptcl1, a0, a1, a, cosmo):
iptcl, interptcl_vjp = vjp(interptcl, ptcl0, ptcl1, a0, a1, a, cosmo)
ptcl0_cot, ptcl1_cot, a0_cot, a1_cot, a_cot, cosmo_cot_itp = interptcl_vjp(iptcl_cot)
return ptcl0_cot, ptcl1_cot, cosmo_cot_itp