Skip to content

Commit 10b4de5

Browse files
committed
WIP minor
1 parent 6c23227 commit 10b4de5

File tree

2 files changed

+15
-13
lines changed

2 files changed

+15
-13
lines changed

pmwd/nbody.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def itp(a, obsvbl):
194194
if conf.a_snapshots is not None:
195195
for a in conf.a_snapshots:
196196
obsvbl = cond(jnp.logical_and(a_prev < a, a <= a_next),
197-
partial(itp, a), lambda *args: obsvbl, obsvbl)
197+
partial(itp, a), lambda *args: obsvbl, obsvbl)
198198

199199
obsvbl['ptcl_prev'] = ptcl
200200

@@ -212,7 +212,7 @@ def observe_init(a, ptcl, obsvbl, cosmo, conf):
212212
# all output snapshots
213213
obsvbl['snapshots'] = {
214214
a_snap: Particles(ptcl.conf, ptcl.pmid, jnp.zeros_like(ptcl.disp),
215-
vel=jnp.zeros_like(ptcl.vel))
215+
vel=jnp.zeros_like(ptcl.vel))
216216
for a_snap in conf.a_snapshots
217217
}
218218
# the nbody a step of output snapshots, (,]
@@ -227,13 +227,13 @@ def observe_adj(a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, cosmo
227227
if conf.a_snapshots is not None:
228228
for a_snap, a_step in zip(conf.a_snapshots, obsvbl['snap_a_step']):
229229
ptcl_cot, cosmo_cot = cond(a_step[1] == a_next, itp_next_adj,
230-
lambda *args: (ptcl_cot, cosmo_cot),
231-
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
232-
ptcl, a_step[0], a_step[1], a_snap, cosmo)
230+
lambda *args: (ptcl_cot, cosmo_cot),
231+
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
232+
ptcl, a_step[0], a_step[1], a_snap, cosmo)
233233
ptcl_cot, cosmo_cot = cond(a_step[1] == a_prev, itp_prev_adj,
234-
lambda *args: (ptcl_cot, cosmo_cot),
235-
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
236-
ptcl, a_step[0], a_step[1], a_snap, cosmo)
234+
lambda *args: (ptcl_cot, cosmo_cot),
235+
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
236+
ptcl, a_step[0], a_step[1], a_snap, cosmo)
237237

238238
return ptcl_cot, cosmo_cot
239239

@@ -244,9 +244,9 @@ def observe_adj_init(a, ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, cosmo_cot, co
244244
# check if the last ptcl is used in interpolation
245245
for a_snap, a_step in zip(conf.a_snapshots, obsvbl['snap_a_step']):
246246
ptcl_cot, cosmo_cot = cond(a_step[1] == a, itp_next_adj,
247-
lambda *args: (ptcl_cot, cosmo_cot),
248-
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
249-
ptcl, a_step[0], a_step[1], a_snap, cosmo)
247+
lambda *args: (ptcl_cot, cosmo_cot),
248+
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
249+
ptcl, a_step[0], a_step[1], a_snap, cosmo)
250250

251251
return ptcl_cot, cosmo_cot
252252

pmwd/obs_util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
def itp_prev(ptcl0, a0, a1, a, cosmo):
99
"""Cubic Hermite interpolation is a linear combination of two ptcls, this
10-
function returns the disp and vel from the first ptcl at a0."""
10+
function returns the disp and vel from the first ptcl at a0."""
1111
Da = a1 - a0
1212
t = (a - a0) / Da
1313
a3E0 = a0**3 * jnp.sqrt(E2(a0, cosmo))
@@ -27,6 +27,7 @@ def itp_prev(ptcl0, a0, a1, a, cosmo):
2727

2828

2929
def itp_prev_adj(ptcl_cot, cosmo_cot, iptcl_cot, ptcl0, a0, a1, a, cosmo):
30+
"""Update ptcl_cot and cosmo_cot given the iptcl_cot and the vjp with itp_prev."""
3031
# iptcl_cot is the cotangent of the interpolated ptcl
3132
(disp, vel), itp_prev_vjp = vjp(itp_prev, ptcl0, a0, a1, a, cosmo)
3233
ptcl0_cot, a0_cot, a1_cot, a_cot, cosmo_cot_itp = itp_prev_vjp(
@@ -41,7 +42,7 @@ def itp_prev_adj(ptcl_cot, cosmo_cot, iptcl_cot, ptcl0, a0, a1, a, cosmo):
4142

4243
def itp_next(ptcl1, a0, a1, a, cosmo):
4344
"""Cubic Hermite interpolation is a linear combination of two ptcls, this
44-
function returns the disp and vel from the second ptcl at a1."""
45+
function returns the disp and vel from the second ptcl at a1."""
4546
Da = a1 - a0
4647
t = (a - a0) / Da
4748
a3E1 = a1**3 * jnp.sqrt(E2(a1, cosmo))
@@ -61,6 +62,7 @@ def itp_next(ptcl1, a0, a1, a, cosmo):
6162

6263

6364
def itp_next_adj(ptcl_cot, cosmo_cot, iptcl_cot, ptcl1, a0, a1, a, cosmo):
65+
"""Update ptcl_cot and cosmo_cot given the iptcl_cot and the vjp with itp_next."""
6466
# iptcl_cot is the cotangent of the interpolated ptcl
6567
(disp, vel), itp_next_vjp = vjp(itp_next, ptcl1, a0, a1, a, cosmo)
6668
ptcl1_cot, a0_cot, a1_cot, a_cot, cosmo_cot_itp = itp_next_vjp(

0 commit comments

Comments
 (0)