Description
See jax-ml/jax#21065 on the JAX GitHub; someone there recommended we report this upstream.

We noticed a strange speed-up when a trivial lax.cond is used to call a function rather than calling the function directly.

In the reproduction below, we JIT a main() function that contains a lax.scan() loop. In each loop step we optionally wrap the function we call in a lax.cond() whose predicate is that the loop index i (which runs from 0 to Ny-1, for Ny steps) is greater than -1, and is therefore always true. This seemingly unnecessary wrapper somehow causes a speed-up.
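To show the pattern in isolation before the full reproduction, here is a minimal sketch (the function names here are illustrative stand-ins, not part of the reproduction):

```python
from jax import lax

def work(x):
    return x * 2.0  # stand-in for the real computation

def unused_branch(x):
    return x  # trivial false branch that is never taken

i = 0  # a scan index is always >= 0, so i > -1 is always true

# Direct call:
y_direct = work(1.0)

# Seemingly redundant lax.cond wrapper that still always calls work():
y_cond = lax.cond(i > -1, work, unused_branch, 1.0)
```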
```python
from functools import partial
import time

import jax
import jax.numpy as jnp
from jax import jit, lax
from jax.scipy.ndimage import map_coordinates

Nx = 2
Ny = 2
x_axis = jnp.linspace(5., 12.75, Nx)

@partial(jit, static_argnums=(1,))
def main(B, use_cond):
    y_axis = jnp.linspace(0, 1, Ny)

    def loop_in_main(carry, i):
        B = carry
        y = y_axis[i]
        # Obtain an array A using interp_A_from_B(), in one of two ways.
        if not use_cond:
            # Case 1: simply run interp_A_from_B() every step.
            A = interp_A_from_B((y, y_axis, B))
            # Optionally impose an optimization barrier:
            # A = lax.optimization_barrier(A)
        else:
            # Case 2: wrap the call in a seemingly trivial lax.cond. The index i
            # is always greater than -1, so interp_A_from_B still runs on every
            # step, yet for some reason we observe a speed-up over case 1.
            A = lax.cond(i > -1, interp_A_from_B, false_func, (y, y_axis, B))
        # Update the B array with the values of A from this loop iteration.
        B = set_B_to_A(i, B, A)
        return B, None

    # Use lax.scan to run the loop and update B Ny times.
    # Index i runs through jnp.arange(Ny) = (0, 1, ..., Ny-1).
    B, _ = lax.scan(loop_in_main, B, jnp.arange(Ny))
    return B

@jit  # lax.cond compiles both branches, which might itself provide a speed-up
def interp_A_from_B(params):
    # B is a (Nx, Ny) array; A is a (Nx,) array.
    y, y_axis, B = params
    # Precise values of y to interpolate at.
    y_prime = y - jnp.log(x_axis[1:Nx] / x_axis[:Nx - 1])
    # Convert to index positions within y_axis, for use with ndimage.map_coordinates.
    y_prime_indices = jnp.interp(y_prime, y_axis, jnp.arange(Ny))
    # Interpolated version of A from B via 2D map_coordinates.
    interp = map_coordinates(B, [jnp.arange(1, Nx), y_prime_indices], order=1)
    # Only use the interpolated result where y_prime is at least the smallest y in y_axis.
    condition = y_prime < y_axis[0]
    # Assemble A, with fill-in values where we don't want the interpolated value.
    A = condition * jnp.exp(-x_axis[:Nx - 1]) + (1 - condition) * interp
    A = jnp.append(A, jnp.exp(-x_axis[-1]))
    return A

def set_B_to_A(i, B, A):
    # Update one column of B with the current value of A.
    B = B.at[:, i].set(A)
    return B

def false_func(params):
    # Trivial false branch: sets all entries of A to fill values if called.
    A = jnp.exp(-x_axis)
    return A

# Run main() a couple of times to see the speed.
# The initial value of B is just an (Nx, Ny) array of zeros.
B = jnp.zeros((Nx, Ny), dtype="float32")
num_iter = 500
for use_cond in [True, False]:
    jax.block_until_ready(main(B, use_cond))  # warm-up compile
    s = time.time()
    for i in range(num_iter):
        jax.block_until_ready(main(B, use_cond))
    print((time.time() - s) / num_iter)
```
The "True" output, where we're using the conditional, is consistently faster than the "False" output, where there is no lax.cond (5e-6 seconds vs 6e-6 seconds for this example, though large Nx, Ny exacerbate this difference). Adding an optimization barrier does not reduce the runtime of the block where lax.cond is not called.
I tried to peek at the HLO output but unfortunately it's a bit too byzantine for me to glean anything from.
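In case it is useful to others, this is one way to dump the optimized HLO for the two cases (a sketch using jax's ahead-of-time lowering API, assuming a reasonably recent jax version; main and B are the objects from the reproduction above):

```python
# Lower and compile each variant of main, then write out the optimized HLO.
for use_cond in [True, False]:
    compiled = main.lower(B, use_cond).compile()
    with open(f"hlo_use_cond_{use_cond}.txt", "w") as fh:
        fh.write(compiled.as_text())
```

Alternatively, setting XLA_FLAGS=--xla_dump_to=/some/dir before running dumps the HLO at each compilation stage.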