See jax-ml/jax#21065 on the JAX GitHub; someone there recommended we report this upstream.

We noticed a strange speed-up when a function is called through a trivial lax.cond rather than called directly.

In the reproduction below, we JIT a main() function containing a lax.scan() loop. In each loop step, we optionally wrap the function we call in a lax.cond() whose predicate is that the loop index i (which runs from 0 to Ny-1 over Ny steps) is greater than -1, which is always true. This seemingly unnecessary wrapper somehow causes a speed-up.
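The pattern can be seen in isolation: wrapping a call in lax.cond with an always-true predicate is semantically a no-op, yet the cond primitive survives tracing. A minimal sketch, independent of the full repro below (`f` and `g` here are hypothetical stand-in branches, not functions from the repro):

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    return jnp.sin(x) + 1.0

def g(x):
    # Never taken: the predicate below is always true.
    return jnp.cos(x)

def wrapped(i, x):
    # i >= 0 for every scan index, so this always calls f(x),
    # but the traced program still contains a two-branch conditional.
    return lax.cond(i > -1, f, g, x)

x = jnp.float32(0.5)
direct = f(x)
via_cond = wrapped(0, x)
print(jnp.allclose(direct, via_cond))                 # results agree
print("cond" in str(jax.make_jaxpr(wrapped)(0, x)))   # cond primitive remains
```

Whether XLA later simplifies that conditional away, and how that interacts with the surrounding scan, is exactly what seems to differ between the two cases below.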
```python
from functools import partial
import time

import jax
import jax.numpy as jnp
from jax import jit, lax
from jax.scipy.ndimage import map_coordinates

Nx = 2
Ny = 2
x_axis = jnp.linspace(5., 12.75, Nx)

@partial(jit, static_argnums=(1,))
def main(B, use_cond):
    y_axis = jnp.linspace(0, 1, Ny)

    def loop_in_main(carry, i):
        B = carry
        y = y_axis[i]
        # Obtain the array A via interp_A_from_B(), in one of two ways.
        if not use_cond:
            # Case 1: simply run interp_A_from_B() every step.
            A = interp_A_from_B((y, y_axis, B))
            # Optionally impose an optimization barrier:
            # A = lax.optimization_barrier(A)
        else:
            # Case 2: wrap the call in a seemingly trivial lax.cond. The
            # predicate i > -1 is always true, so interp_A_from_B still runs
            # every step, yet we observe a speed-up over case 1.
            A = lax.cond(i > -1, interp_A_from_B, false_func, (y, y_axis, B))
        # Update the B array with the values of A from this step.
        B = set_B_to_A(i, B, A)
        return B, None

    # Use lax.scan to run the loop and update B Ny times.
    # Index i runs through jnp.arange(Ny) = (0, 1, ..., Ny-1).
    B, _ = lax.scan(loop_in_main, B, jnp.arange(Ny))
    return B

@jit  # lax.cond jits both branches anyway, which may itself provide a speed-up
def interp_A_from_B(params):
    # B is an (Nx, Ny) array; A is an (Nx,) array.
    y, y_axis, B = params
    # Precise value of y to interpolate at.
    y_prime = y - jnp.log(x_axis[1:Nx] / x_axis[:Nx-1])
    # Convert to index positions within y_axis, for use with map_coordinates.
    y_prime_indices = jnp.interp(y_prime, y_axis, jnp.arange(Ny))
    # Interpolated version of A from B via 2D map_coordinates.
    interp = map_coordinates(B, [jnp.arange(1, Nx), y_prime_indices], order=1)
    # Only use the interpolated result where y_prime is at least the smallest
    # y in y_axis; elsewhere, fall back to fill-in values.
    condition = y_prime < y_axis[0]
    A = condition * jnp.exp(-x_axis[:Nx-1]) + (1 - condition) * interp
    A = jnp.append(A, jnp.exp(-x_axis[-1]))
    return A

def set_B_to_A(i, B, A):
    # Update column i of B with the current value of A.
    return B.at[:, i].set(A)

def false_func(params):
    # Trivial false branch: returns fill values for A if it were ever called.
    return jnp.exp(-x_axis)

# Run main() a few times to measure the speed.
# The initial value of B is an (Nx, Ny) array of zeros.
B = jnp.zeros((Nx, Ny), dtype="float32")
num_iter = 500
for use_cond in [True, False]:
    jax.block_until_ready(main(B, use_cond))  # warm-up / compile
    s = time.time()
    for i in range(num_iter):
        jax.block_until_ready(main(B, use_cond))
    print(use_cond, (time.time() - s) / num_iter)
```
The "True" output, where we use the conditional, is consistently faster than the "False" output, where there is no lax.cond (about 5e-6 s vs 6e-6 s per call in this example, and larger Nx, Ny exacerbate the difference). Adding an optimization barrier does not reduce the runtime of the case without lax.cond.
I tried to peek at the HLO output but unfortunately it's a bit too byzantine for me to glean anything from.
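For anyone who wants to compare the compiled programs directly, JAX's ahead-of-time API can dump the post-optimization HLO for each static value of use_cond. A sketch using a hypothetical stand-in for main() (so it runs standalone; the repro's actual main works the same way, since use_cond is a static argument):

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit, lax

# Hypothetical stand-in for main() from the repro: a jitted function with a
# static use_cond flag, so each flag value compiles to a separate executable.
@partial(jit, static_argnums=(1,))
def main_stub(B, use_cond):
    def body(carry, i):
        f = lambda b: b + 1.0
        if use_cond:
            out = lax.cond(i > -1, f, lambda b: b, carry)
        else:
            out = f(carry)
        return out, None
    B, _ = lax.scan(body, B, jnp.arange(B.shape[1]))
    return B

B = jnp.zeros((2, 2), dtype="float32")
for use_cond in (True, False):
    lowered = main_stub.lower(B, use_cond)
    hlo_text = lowered.compile().as_text()  # HLO after XLA optimization passes
    print(f"use_cond={use_cond}: {len(hlo_text.splitlines())} HLO lines")
```

Diffing the two dumps (e.g. with `lowered.as_text()` for the pre-optimization StableHLO as well) might show where the schedules diverge.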