
Unexpected speedup from wrapping function call in trivial jax.lax.cond statement #18440

Open
cgiovanetti opened this issue Oct 17, 2024 · 1 comment


@cgiovanetti

See jax-ml/jax#21065 from the jax github; someone there recommended we report this upstream.

We noticed a strange speed-up when a trivial lax.cond statement is used to call a function, rather than calling the function directly.

In the reproduction below, we JIT a main() function that contains a lax.scan() loop. On each scan step, we can wrap the function we call in a lax.cond() whose predicate is that the loop index i (which runs from 0 to Ny-1 over Ny steps) is greater than -1, which is always true. This seemingly unnecessary wrapper somehow causes a speed-up.
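The pattern in isolation looks like this (a minimal sketch; f, f_fill, direct, and via_cond are illustrative names, not from the reproduction below). Wrapping a call in lax.cond with an always-true predicate is semantically a no-op, so the two variants must agree:

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    return jnp.sin(x) + 1.0

def f_fill(x):
    # Trivial false branch: never taken, since the predicate below is always true.
    return jnp.zeros_like(x)

@jax.jit
def direct(x):
    return f(x)

@jax.jit
def via_cond(x, i):
    # i >= 0 for a scan index, so i > -1 always holds and f is always called.
    return lax.cond(i > -1, f, f_fill, x)

x = jnp.linspace(0.0, 1.0, 4)
same = jnp.allclose(direct(x), via_cond(x, 0))
```

Both variants compute the same values; only the compiled code (and, in our case, the runtime) differs.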

from functools import partial
import jax.numpy as jnp
import jax
from jax import jit, lax
from jax.scipy.ndimage import map_coordinates
import time

Nx = 2
Ny = 2
x_axis = jnp.linspace(5., 12.75, Nx)

@partial(jit, static_argnums=(1,))
def main(B, use_cond):
    y_axis = jnp.linspace(0, 1, Ny)

    def loop_in_main(carry, i):
        B = carry
        y = y_axis[i]

        """ Obtain an array A using interp_A_from_B(), picking one of three ways """
        if not use_cond:
            # Case 1: We simply run interp_A_from_B() every step
            A = interp_A_from_B((y, y_axis, B))
            # optionally impose optimization barrier
            # A = lax.optimization_barrier(A)

        else:
            # Case 2: We use a seemingly trivial lax.cond wrapper, but will still always run
            # interp_A_from_B since index i is always greater than -1.
            # For some reason we observe a speed up over case 1.
            A = lax.cond(i>-1, interp_A_from_B, false_func, (y, y_axis, B))

        # Update B array with values of A from this loop.
        B = set_B_to_A(i, B, A)

        return B, None
    # Use lax.scan to run loop and update B Ny times.
    # Index i will run through jnp.arange(Ny) = (0, 1, 2, ..., Ny-1)
    B, _ = lax.scan(loop_in_main, B, jnp.arange(Ny))

    return B

@jit  # lax.cond compiles both branches; jitting may be related to the speedup
def interp_A_from_B(params):
    # B is a (Nx, Ny) array.
    # A is a (Nx,) array.

    y, y_axis, B    = params
    # Precise value of y to interpolate at.
    y_prime         = y - jnp.log(x_axis[1:Nx] / x_axis[:Nx-1])
    # Convert to index position within y_axis, to use with ndimage.map_coordinates.
    y_prime_indices = jnp.interp(y_prime, y_axis, jnp.arange(Ny))
    # Interpolated version of A from B via 2D map_coordinates.
    interp          = map_coordinates(B, [jnp.arange(1, Nx), y_prime_indices], order=1)
    # Here, only use the interpolated result for values of y_prime larger than the smallest y in y_axis.
    condition       = y_prime < y_axis[0]

    # Put A array together, with some fill in values for where we don't want the interpolated value.
    A               = condition * jnp.exp(-x_axis[:Nx-1]) \
                    + (1-condition) * interp
    A               = jnp.append(A, jnp.exp(-x_axis[-1]))

    return A

def set_B_to_A(i, B, A):
    # Update a column of B with the current value of A.
    B = B.at[:, i].set(A)
    return B

def false_func(params):
    # Trivial false function, sets all entries of A to some fill values if called.
    A = jnp.exp(-x_axis)
    return A

""" Running main() a couple times to see the speed """
# Initial value of B is an (Nx, Ny) array of zeros.
B = jnp.zeros((Nx, Ny), dtype="float32")
num_iter = 500

for use_cond in [True, False]:
    jax.block_until_ready(main(B, use_cond))

    s = time.time()
    for i in range(num_iter):
        jax.block_until_ready(main(B, use_cond))
    print((time.time() - s) / num_iter)

The use_cond=True timing, where the conditional is used, is consistently faster than the use_cond=False timing, where there is no lax.cond (about 5e-6 s vs 6e-6 s per call in this example, though larger Nx and Ny exacerbate the difference). Adding an optimization barrier does not reduce the runtime of the variant that does not call lax.cond.

I tried to peek at the HLO output but unfortunately it's a bit too byzantine for me to glean anything from.
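For anyone who wants to compare the two compiled programs, a sketch using JAX's ahead-of-time API (lower/compile): as_text() on the lowered object gives the pre-optimization IR, and as_text() on the compiled object gives the post-optimization HLO. The toy function f below is a stand-in; the same calls apply to the jitted main above, e.g. main.lower(B, True).compile().as_text().

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) * 2.0

x = jnp.ones((4,), dtype=jnp.float32)

# Pre-optimization IR for the traced program.
lowered = f.lower(x)
pre_text = lowered.as_text()

# Post-optimization HLO, after XLA's passes have run.
compiled = lowered.compile()
post_text = compiled.as_text()
```

Diffing the post-optimization dumps for use_cond=True vs use_cond=False is usually more tractable than reading either module in full.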

@cheshire
Contributor

@ezhulenev @penpornk I understand this is on CPU
