We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3cb18e5 commit 7c88219Copy full SHA for 7c88219
benchmarks/run_benchmark.py
@@ -1,3 +1,4 @@
1
+import gc
2
import os
3
import time
4
@@ -73,6 +74,8 @@ def run_cusadi_benchmark(fn, inputs):
73
74
results["N_EVALS"] = N_EVALS
75
76
for fn in benchmark_fns:
77
+ with torch.no_grad():
78
+ torch.cuda.empty_cache()
79
fn_name = fn.name()
80
for i, n_envs in enumerate(N_ENVS_SWEEP):
81
print(f"Running CUDA benchmark for {n_envs} environments with function {fn_name}...")
@@ -128,6 +131,7 @@ def main():
128
131
if PathsProvider.RUN_CUSADI:
129
132
cuda_results = run_cuda_benchmarks()
130
133
np.savez(f"{cur_dir}/cuda_benchmark_results.npz", **cuda_results)
134
+ gc.collect()
135
136
jaxadi_results = run_jaxadi_benchmarks()
137
np.savez(f"{cur_dir}/jaxadi_benchmark_results.npz", **jaxadi_results)
0 commit comments