
module with list of modules #4845

@RisingPhoelix

Description


Is it possible to have a model that contains a list of submodules and decide at call time which submodule to use? Here's a minimal (non-working) example:

```python
import flax.nnx as nnx
import jax


class MyModule(nnx.Module):
    def __init__(self, indim, outdim, t_max, rngs=None):
        super().__init__()
        self.indim = indim
        self.outdim = outdim
        self.t_max = t_max

        # One linear layer per time step
        self.models = [nnx.Linear(indim, outdim, rngs=rngs) for _ in range(t_max)]

    def __call__(self, x, t):
        return nnx.switch(t, self.models, x)  # ERROR

rngs = nnx.Rngs(0)
model = MyModule(indim=10, outdim=5, t_max=3, rngs=rngs)

# Example usage
x = jax.random.normal(jax.random.PRNGKey(0), (1, 10))
t = jax.numpy.array(1)  # Choose a time step
output = model(x, t)
print(output)
```

I saw that `nnx.Sequential` does something similar using `nnx.data`, but in my installation `nnx` has no attribute `data`. Why is that?

System Info

flax 0.10.7
jax 0.6.2
jaxlib 0.6.2

Ubuntu 22.04.4 LTS (WSL)
