Skip to content

Commit 3c5bd18

Browse files
committed
fix: change order of jit vmap
1 parent 7c88219 commit 3c5bd18

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

benchmarks/run_benchmark.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ def run_jaxadi_benchmark(fn, inputs):
108108
fn_name = fn.name()
109109

110110
# apply jaxadi
111-
jax_fn = convert(fn, compile=True)
112-
vmapped_fn = jax.vmap(jax_fn)
111+
jax_fn = convert(fn, compile=False)
112+
vmapped_fn = jax.jit(jax.vmap(jax_fn))
113113

114114
for i, n_envs in enumerate(N_ENVS_SWEEP):
115115
print(f"Running Jaxadi benchmark for {n_envs} environments with function {fn_name}...")
@@ -124,6 +124,9 @@ def run_jaxadi_benchmark(fn, inputs):
124124
# remove the compiled function from the memory and inputs
125125
del inputs
126126

127+
del jax_fn, vmapped_fn
128+
gc.collect()
129+
127130
return results
128131

129132

0 commit comments

Comments
 (0)