odeint is slow #5

@modichirag

While running some timing tests, I found that using odeint to calculate the growth functions is quite slow.
As a hack, I replaced it with an RK4 integration inside the growth function itself, which seems to be much faster.

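    # assumes: from jax import jit; from jax.lax import scan
    # ode, cosmo, G_ic and lna come from the enclosing growth function
    ode_jit = jit(ode)

    def rk4_ode_jit(carry, t):
        # carry holds the current state and the previous time point
        y, t_prev = carry
        h = t - t_prev  # step size; zero on the first iteration
        k1 = ode_jit(y, t_prev, cosmo)
        k2 = ode_jit(y + h * k1 / 2, t_prev + h / 2, cosmo)
        k3 = ode_jit(y + h * k2 / 2, t_prev + h / 2, cosmo)
        k4 = ode_jit(y + h * k3, t, cosmo)
        y = y + 1.0 / 6.0 * h * (k1 + 2 * k2 + 2 * k3 + k4)
        return (y, t), y

    # one fixed RK4 step per entry of lna; G stacks the state at each point
    (yf, _), G = scan(rk4_ode_jit, (G_ic, lna[0]), lna)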
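
To check that this carry pattern behaves as intended, here is a minimal self-contained sketch of the same fixed-step RK4 scan applied to a toy problem with a known solution. The names `toy_ode` and `rk4_solve` are hypothetical and not part of pmwd, and the ODE dy/dt = -y stands in for the growth ODE. Note that, unlike an adaptive solver, a fixed-step scheme has no error control, so its accuracy depends entirely on how finely the time grid is sampled.

```python
import jax.numpy as jnp
from jax import jit
from jax.lax import scan


def toy_ode(y, t):
    # Hypothetical stand-in for the growth ODE: dy/dt = -y,
    # whose exact solution is y(t) = y(0) * exp(-t).
    return -y


@jit
def rk4_solve(y0, ts):
    def step(carry, t):
        y, t_prev = carry
        h = t - t_prev  # zero on the first iteration, so ys[0] == y0
        k1 = toy_ode(y, t_prev)
        k2 = toy_ode(y + h * k1 / 2, t_prev + h / 2)
        k3 = toy_ode(y + h * k2 / 2, t_prev + h / 2)
        k4 = toy_ode(y + h * k3, t)
        y = y + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        return (y, t), y

    # One RK4 step per grid point; ys stacks the state at every entry of ts.
    _, ys = scan(step, (y0, ts[0]), ts)
    return ys


ts = jnp.linspace(0.0, 1.0, 65)
ys = rk4_solve(jnp.array(1.0), ts)

# Largest deviation from the exact solution over the whole grid.
max_err = float(jnp.max(jnp.abs(ys - jnp.exp(-ts))))
print("max abs error:", max_err)
```

Running the same comparison against the odeint result for the real growth ODE would show whether the grid in `lna` is fine enough for the RK4 version to be a safe replacement.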

I then ran timing tests for a 64^3 simulation, passing the cosmology parameters and initial modes as input and measuring the time for different outputs (Boltzmann solve only vs. Boltzmann + LPT).

@jit
def simulate_boltz(modes, omegam, conf):
    '''Evaluate growth & transfer functions with odeint.
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann(cosmo)
    mesh = None
    return mesh, cosmo

@jit
def simulate_boltz_rk4(modes, omegam, conf):
    '''Evaluate growth & transfer functions with custom rk4.
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann_rk4(cosmo)
    mesh = None
    return mesh, cosmo

@jit
def simulate(modes, omegam, conf):
    '''Run LPT simulation, evaluating growth & transfer functions with odeint.
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann(cosmo)
    ptcl, obsvbl = lpt(modes, cosmo)
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo

@jit
def simulate_rk4(modes, omegam, conf):
    '''Run LPT simulation, evaluating growth & transfer functions with custom rk4.
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann_rk4(cosmo)
    ptcl, obsvbl = lpt(modes, cosmo)
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo


@jit
def simulate_nbody(modes, cosmo):
    '''Run LPT simulation without evaluating growth & transfer functions.
    '''
    ptcl, obsvbl = lpt(modes, cosmo)
    conf = cosmo.conf
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo

The time taken for each of these is:

Time taken for boltzmann: 0.5971660375595093
Time taken for boltzmann rk4: 0.007928729057312012
Time taken for LPT: 0.0041596412658691405
Time taken for simulation (Boltzmann + LPT): 0.463437557220459
Time taken for simulation rk4 (Boltzmann + LPT): 0.04284675121307373

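Judging by these numbers, rk4 seems to be much faster than odeint for generating the growth rate: roughly 75x faster for the Boltzmann step alone and about 11x faster for the full Boltzmann + LPT simulation.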

If the way I am running these simulations is sensible and the timing numbers portray an accurate picture,
should we figure out a better (jaxified) way to code this?
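
One way to check whether the timings are representative is to separate the first call of a jitted function (which includes tracing and compilation) from later calls, and to wait for the result before stopping the clock, since JAX dispatches work asynchronously. The following is a minimal sketch under those assumptions; `time_jitted` and `toy_workload` are hypothetical names used only for illustration and are not part of pmwd or of the attached script.

```python
import time

import jax.numpy as jnp
from jax import jit


def time_jitted(fn, *args, repeats=10):
    """Return (first_call_seconds, mean_steady_state_seconds)."""
    t0 = time.perf_counter()
    # First call: includes tracing and compilation.
    fn(*args).block_until_ready()
    first = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(repeats):
        # Block so the timer covers the computation, not just the dispatch.
        fn(*args).block_until_ready()
    steady = (time.perf_counter() - t0) / repeats
    return first, steady


@jit
def toy_workload(x):
    # Hypothetical stand-in for one of the simulate_* functions.
    return jnp.sum(jnp.sin(x) ** 2)


first, steady = time_jitted(toy_workload, jnp.arange(1_000_000.0))
print(f"first call (with compile): {first:.6f} s")
print(f"steady state (mean):       {steady:.6f} s")
```

The simulate functions above return a `(mesh, cosmo)` tuple rather than a single array, so the blocking call would need to be applied to the array leaves of that output instead of to the return value directly.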

I have attached the full script as a txt file (copy it into the pmwd/pmwd folder, convert it to .py, and it should run):
test_growth.txt
