Description
Is it possible to have a model that contains a list of submodules, so that I can choose which submodule to use when calling the model? Here's a minimal (non-working) example:
import flax.nnx as nnx
import jax


class MyModule(nnx.Module):
    def __init__(self, indim, outdim, t_max, rngs=None):
        super().__init__()
        self.indim = indim
        self.outdim = outdim
        self.t_max = t_max
        # Define a simple linear layer as an example
        self.models = [nnx.Linear(indim, outdim, rngs=rngs) for _ in range(t_max)]

    def __call__(self, x, t):
        return nnx.switch(t, self.models, x)  # ERROR


rngs = nnx.Rngs(0)
model = MyModule(indim=10, outdim=5, t_max=3, rngs=rngs)

# Example usage
x = jax.random.normal(jax.random.PRNGKey(0), (1, 10))
t = jax.numpy.array(1)  # Choose a time step
output = model(x, t)
print(output)
I saw that nnx.Sequential has something similar using nnx.data, but for me, nnx has no attribute nnx.data. Why is that?
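For reference, here is the behavior I am after, sketched in plain JAX rather than NNX. This is only an illustration under my own assumptions: init_linear, apply_linear, and apply_selected are hypothetical helper names, the parameters are plain dicts instead of nnx.Linear modules, and the initialization scale is arbitrary. It uses jax.lax.switch with one branch per layer, so the index t can be a traced array.

```python
import jax
import jax.numpy as jnp


def init_linear(key, indim, outdim):
    # Hypothetical helper: plain-dict parameters standing in for nnx.Linear.
    return {
        "kernel": jax.random.normal(key, (indim, outdim)) * 0.1,
        "bias": jnp.zeros((outdim,)),
    }


def apply_linear(p, x):
    # Affine map x @ W + b for a single layer.
    return x @ p["kernel"] + p["bias"]


def apply_selected(params, x, t):
    # One branch per layer; each branch receives the same operands (params, x)
    # and picks its own layer with a static Python index i.
    branches = [
        lambda ps, inp, i=i: apply_linear(ps[i], inp)
        for i in range(len(params))
    ]
    return jax.lax.switch(t, branches, params, x)


t_max = 3
keys = jax.random.split(jax.random.PRNGKey(0), t_max)
params = [init_linear(k, 10, 5) for k in keys]

x = jax.random.normal(jax.random.PRNGKey(1), (1, 10))
t = jnp.array(1)  # Choose a time step

out = apply_selected(params, x, t)
out_jit = jax.jit(apply_selected)(params, x, t)  # also works with a traced t
print(out.shape)
```

All branches must return outputs of the same shape and dtype, which holds here because every layer maps indim to outdim. What I would like is the equivalent of apply_selected with the layers stored as submodules of an nnx.Module.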
System Info
flax 0.10.7
jax 0.6.2
jaxlib 0.6.2
Ubuntu 22.04.4 LTS (WSL)