Skip to content

Commit 4645097

Browse files
committed
wip: use-cases, batch inverse
1 parent 39c4325 commit 4645097

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

examples/use-cases/inverse.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import timeit
2+
3+
from jax import jit, vmap
4+
import casadi as cs
5+
import jax.numpy as jnp
6+
import numpy as np
7+
import pinocchio as pin
8+
import pinocchio.casadi as cpin
9+
from robot_descriptions import panda_mj_description
10+
11+
from jaxadi import convert
12+
13+
model = pin.buildModelFromMJCF(panda_mj_description.MJCF_PATH)
14+
cmodel = cpin.Model(model)
15+
cdata = cmodel.createData()
16+
17+
q = cs.SX.sym("q", model.nq)
18+
dq = cs.SX.sym("dq", model.nv)
19+
cpin.computeAllTerms(cmodel, cdata, q, dq)
20+
M = cdata.M
21+
22+
# print(M)
23+
# print(cs.inv(M))
24+
25+
fn = cs.Function("fn", [q, dq], [cs.inv(M)])
26+
jax_fn = convert(fn)
27+
vjax_fn = jit(vmap(jax_fn, in_axes=(0, 0)))
28+
num_batches = 100000
29+
# Warmup
30+
vjax_fn(jnp.ones((model.nq, num_batches)), jnp.ones((model.nv, num_batches)))
31+
num_runs = 100000
32+
33+
# Compare times
34+
print("CasADi:")
35+
print(timeit.timeit(lambda: fn(np.ones(model.nq), np.ones(model.nv)), number=num_runs))
36+
print("JAX-ADi:")
37+
print(
38+
timeit.timeit(
39+
lambda: vjax_fn(jnp.ones((model.nq, num_batches)), jnp.ones((model.nv, num_batches))),
40+
number=num_runs // num_batches,
41+
)
42+
)

0 commit comments

Comments
 (0)