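If your input is a Python list, you have to use a for loop, because `jax.vmap` maps over an array axis, not over the elements of a Python list. To use jaxy things like `jax.vmap`, first stack the modules and the inputs into single pytrees whose leaves carry a leading batch axis. This assumes `nnx.Module` instances behave as JAX pytrees, as they do in recent Flax versions:

```python
import jax
import jax.numpy as jnp
from flax import nnx

# Ten independently initialized layers and ten matching inputs.
linears = [nnx.Linear(8, 8, rngs=nnx.Rngs(i)) for i in range(10)]
inputs = [jax.random.normal(jax.random.key(i), (4, 8)) for i in range(10)]
# Stack each list into one pytree; every leaf gains a leading axis of size 10.
linears = jax.tree.map(lambda *xs: jnp.stack(xs), *linears)
inputs = jax.tree.map(lambda *xs: jnp.stack(xs), *inputs)

@jax.vmap  # maps over axis 0 of both arguments: module i gets input i
def f(m, x):
    return m(x)

outs = f(linears, inputs)  # shape (10, 4, 8)
```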
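The same stack-then-`vmap` pattern works on plain parameter pytrees, without Flax. A minimal sketch, assuming each model is a hypothetical dict of weights `{"w", "b"}` applied by a hypothetical pure function `apply`; it also computes the original list-comprehension result so the two can be compared:

```python
import jax
import jax.numpy as jnp


def apply(params, x):
    # Hypothetical pure "model": a single dense layer.
    return x @ params["w"] + params["b"]


def init(key, d_in=8, d_out=8):
    # Hypothetical initializer returning one parameter pytree per model.
    kw, kb = jax.random.split(key)
    return {
        "w": jax.random.normal(kw, (d_in, d_out)),
        "b": jax.random.normal(kb, (d_out,)),
    }


n = 10
params_list = [init(jax.random.key(i)) for i in range(n)]
inputs_list = [jax.random.normal(jax.random.key(100 + i), (4, 8)) for i in range(n)]

# Loop baseline: one Python-level call per (model, input) pair.
loop_outs = [apply(p, x) for p, x in zip(params_list, inputs_list)]

# Stack: every leaf of the params pytree gains a leading axis of size n.
stacked_params = jax.tree.map(lambda *xs: jnp.stack(xs), *params_list)
stacked_inputs = jnp.stack(inputs_list)

# vmap maps over axis 0 of both arguments, so model i sees input i.
vmap_outs = jax.vmap(apply)(stacked_params, stacked_inputs)

print(vmap_outs.shape)  # (10, 4, 8)
```

The requirement is that all models share the same structure and leaf shapes, since `jnp.stack` needs identically shaped leaves; separate instantiations of the same module class normally satisfy this.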
I have two same-size lists, `models: List[nnx.Module]` and `inps: List[jax.Array]`. The entries of `models` are separate instantiations of the same `nnx.Module` subclass. I calculate the outputs as `outs = [m(i) for m, i in zip(models, inps)]`. How can I parallelize this without a for loop?
From what I understand, `jax.vmap` and `jax.lax.map` aren't made to do this cleanly.
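Once the models and inputs are stacked as in the answer, `jax.lax.map` can consume them too: it takes a pytree of stacked arrays and calls the function once per leading-axis slice, sequentially rather than batched, which can lower peak memory compared with `vmap`. A sketch, again assuming a hypothetical dict-of-weights model and a hypothetical pure `apply` function:

```python
import jax
import jax.numpy as jnp


def apply(params, x):
    # Hypothetical pure "model": a single dense layer.
    return x @ params["w"] + params["b"]


n = 5
keys = jax.random.split(jax.random.key(0), n)
params_list = [
    {"w": jax.random.normal(k, (8, 8)), "b": jnp.zeros((8,))} for k in keys
]
inputs_list = [jnp.ones((4, 8)) * i for i in range(n)]

# One pytree whose leaves all have a leading axis of size n.
stacked = (
    jax.tree.map(lambda *xs: jnp.stack(xs), *params_list),
    jnp.stack(inputs_list),
)

# lax.map hands the function one (params_i, x_i) slice of the tuple per step.
outs = jax.lax.map(lambda pair: apply(*pair), stacked)
```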