We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7c88219 commit 3c5bd18Copy full SHA for 3c5bd18
benchmarks/run_benchmark.py
@@ -108,8 +108,8 @@ def run_jaxadi_benchmark(fn, inputs):
108
fn_name = fn.name()
109
110
# apply jaxadi
111
- jax_fn = convert(fn, compile=True)
112
- vmapped_fn = jax.vmap(jax_fn)
+ jax_fn = convert(fn, compile=False)
+ vmapped_fn = jax.jit(jax.vmap(jax_fn))
113
114
for i, n_envs in enumerate(N_ENVS_SWEEP):
115
print(f"Running Jaxadi benchmark for {n_envs} environments with function {fn_name}...")
@@ -124,6 +124,9 @@ def run_jaxadi_benchmark(fn, inputs):
124
# remove the compiled function from the memory and inputs
125
del inputs
126
127
+ del jax_fn, vmapped_fn
128
+ gc.collect()
129
+
130
return results
131
132
0 commit comments