Description
I was experimenting with some timing tests and found that using odeint to calculate the growth functions is quite slow.
As a hack, I replaced it with an RK4 integration inside the growth function itself, which seems to be much faster.
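from jax import jit
from jax.lax import scan

# ode, cosmo, G_ic and lna come from the surrounding growth function
ode_jit = jit(ode)

def rk4_ode_jit(carry, t):
    # one classic RK4 step from t_prev to t
    y, t_prev = carry
    h = t - t_prev
    k1 = ode_jit(y, t_prev, cosmo)
    k2 = ode_jit(y + h * k1 / 2, t_prev + h / 2, cosmo)
    k3 = ode_jit(y + h * k2 / 2, t_prev + h / 2, cosmo)
    k4 = ode_jit(y + h * k3, t, cosmo)
    y = y + 1.0 / 6.0 * h * (k1 + 2 * k2 + 2 * k3 + k4)
    return (y, t), y

# the first scan step has h = 0, so G[0] equals G_ic
(yf, _), G = scan(rk4_ode_jit, (G_ic, lna[0]), lna)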
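For anyone who wants to try the idea outside pmwd, here is a minimal self-contained sketch of the same fixed-step RK4 + `scan` pattern. The toy ODE `dy/dt = -rate * y`, the helper name `rk4_solve`, and the grid are assumptions made purely for illustration; they are not the pmwd growth ODE. Unlike the snippet above, this version scans over `ts[1:]` and prepends the initial condition, which avoids the zero-length first step.

```python
import jax.numpy as jnp
from jax import jit
from jax.lax import scan


def toy_ode(y, t, rate):
    # Hypothetical stand-in for the growth ODE: dy/dt = -rate * y.
    return -rate * y


def rk4_solve(func, y0, ts, *args):
    """Integrate dy/dt = func(y, t, *args), one classic RK4 step per interval of ts."""

    def step(carry, t):
        y, t_prev = carry
        h = t - t_prev
        k1 = func(y, t_prev, *args)
        k2 = func(y + h * k1 / 2, t_prev + h / 2, *args)
        k3 = func(y + h * k2 / 2, t_prev + h / 2, *args)
        k4 = func(y + h * k3, t, *args)
        y = y + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        return (y, t), y

    # Scan over ts[1:] so no zero-length step is taken, then prepend y0.
    _, ys = scan(step, (y0, ts[0]), ts[1:])
    return jnp.concatenate([y0[None], ys])


ts = jnp.linspace(0.0, 2.0, 65)
y0 = jnp.array([1.0, 2.0])
rate = 1.5

# func is marked static so that it is not traced as an array argument.
ys = jit(rk4_solve, static_argnums=0)(toy_ode, y0, ts, rate)

# Compare against the exact solution y0 * exp(-rate * t).
exact = y0 * jnp.exp(-rate * ts)[:, None]
max_err = jnp.max(jnp.abs(ys - exact))
print("max abs error:", max_err)
```

Keep in mind that a fixed-step scheme like this has no error control, so its accuracy depends entirely on how fine the time grid is, whereas odeint adapts its step size to a tolerance. It would be worth comparing the two outputs on the real growth ODE before swapping one for the other.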
I then timed a 64^3 simulation, passing the cosmological parameters and initial modes as input and measuring the time for different outputs (Boltzmann solve only vs. Boltzmann solve + LPT).
@jit
def simulate_boltz(modes, omegam, conf):
    '''Evaluate growth & transfer function with odeint
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann(cosmo)
    mesh = None
    return mesh, cosmo

@jit
def simulate_boltz_rk4(modes, omegam, conf):
    '''Evaluate growth & transfer function with custom rk4
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann_rk4(cosmo)
    mesh = None
    return mesh, cosmo

@jit
def simulate(modes, omegam, conf):
    '''Run LPT simulation with evaluating growth & transfer function with odeint
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann(cosmo)
    ptcl, obsvbl = lpt(modes, cosmo)
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo

@jit
def simulate_rk4(modes, omegam, conf):
    '''Run LPT simulation with evaluating growth & transfer function with custom rk4
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann_rk4(cosmo)
    ptcl, obsvbl = lpt(modes, cosmo)
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo

@jit
def simulate_nbody(modes, cosmo):
    '''Run LPT simulation without evaluating growth & transfer function
    '''
    ptcl, obsvbl = lpt(modes, cosmo)
    conf = cosmo.conf
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo
The time taken for each of these is:
Time taken for boltzmann: 0.5971660375595093
Time taken for boltzmann rk4: 0.007928729057312012
Time taken for LPT: 0.0041596412658691405
Time taken for simulation (Boltzmann + LPT): 0.463437557220459
Time taken for simulation rk4 (Boltzmann + LPT): 0.04284675121307373
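RK4 seems to be much faster than odeint at generating the growth functions: roughly 75x for the Boltzmann solve alone (0.597 vs. 0.0079) and roughly 11x for the full Boltzmann + LPT simulation (0.463 vs. 0.043).
If the way I am running the simulations is sensible and these timings give an accurate picture,
should we figure out a better (jaxified) way to code this?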
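One way to check whether timings like these reflect steady-state cost rather than compilation is to time the first call separately from later calls and to block on the result, since JAX dispatches work asynchronously. The sketch below does this for odeint on a toy ODE. The helper `time_call`, the toy ODE, and the tolerances are assumptions for illustration only and are not taken from the attached script.

```python
import time

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def toy_ode(y, t, rate):
    # Hypothetical stand-in for the growth ODE: dy/dt = -rate * y.
    return -rate * y


@jax.jit
def solve(y0, ts, rate):
    # Tolerances chosen for this float32 toy problem (an assumption).
    return odeint(toy_ode, y0, ts, rate, rtol=1e-6, atol=1e-6)


def time_call(fn, *args, repeats=5):
    """Return (first_call_seconds, mean_seconds_of_later_calls)."""
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))  # first call includes tracing and compilation
    first = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(repeats):
        jax.block_until_ready(fn(*args))  # wait for the asynchronous result
    steady = (time.perf_counter() - start) / repeats
    return first, steady


ts = jnp.linspace(0.0, 2.0, 65)
y0 = jnp.array([1.0, 2.0])
first, steady = time_call(solve, y0, ts, 1.5)
print(f"first call (with compile): {first:.4f} s, later calls: {steady:.6f} s")
```

If the gap between odeint and RK4 persists in the later-call numbers, it is a genuine runtime difference and not a compilation artifact.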
I have attached the full script as a txt file (copy it into the pmwd/pmwd folder, rename it to .py, and it should run).
test_growth.txt