Description
Hello! I'm trying to define an interface for my model implementation with typing.Protocol, but I run into trouble when a class inherits from both nnx.Module and the protocol. How can I fix this?
System information
macOS 15.5, Apple M2
Flax 0.10.6, JAX 0.6.1, jaxlib 0.6.1
Python 3.12
Problem you have encountered:
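Defining a class that inherits from both nnx.Module and a typing.Protocol subclass, e.g. class TEST(nnx.Module, SomeInterface) where SomeInterface subclasses Protocol, raises an error.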
What you expected to happen:
The class should be defined without errors and work as a regular nnx.Module that implements the interface.
Logs, error messages, etc:
TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases
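This TypeError comes from Python's metaclass rule: subclasses of typing.Protocol are created by typing's own protocol metaclass, nnx.Module has its own metaclass, and a class listing both as bases needs a metaclass that is a subclass of each. The sketch below reproduces the conflict without Flax, using hypothetical stand-ins NodeMeta and Node for nnx.Module's metaclass and base class (NodeMeta is assumed here to derive from abc.ABCMeta), and then resolves it by deriving a combined metaclass from both. This is a general Python technique; whether it interacts cleanly with the real nnx metaclass is an assumption to test. With Flax installed, the equivalent would be class CombinedMeta(type(nnx.Module), type(ModuleInterface)).

```python
import abc
import typing


class NodeMeta(abc.ABCMeta):
    """Hypothetical stand-in for the metaclass behind nnx.Module."""


class Node(metaclass=NodeMeta):
    """Hypothetical stand-in for nnx.Module."""


class Interface(typing.Protocol):
    def forward(self, x: float) -> float: ...


# NodeMeta and typing's protocol metaclass are siblings (both derive from
# ABCMeta), so neither is a subclass of the other and class creation fails.
conflict = None
try:
    class Broken(Node, Interface):
        def forward(self, x: float) -> float:
            return x
except TypeError as err:
    conflict = str(err)
print(conflict)


# A metaclass that derives from both bases' metaclasses satisfies the rule.
class CombinedMeta(type(Node), type(Interface)):
    pass


class Fixed(Node, Interface, metaclass=CombinedMeta):
    def forward(self, x: float) -> float:
        return 2 * x


print(Fixed().forward(3.0))  # 6.0
```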
Steps to reproduce:
Run the code below:
```python
import typing

import jax.numpy as jnp
from flax import nnx


class ModuleInterface(typing.Protocol):
    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        ...


# Fails at class-definition time with the metaclass-conflict TypeError above.
class Module(nnx.Module, ModuleInterface):
    def __init__(self, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(4, 16, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.linear1(x)

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        return self(x)
```
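Because typing.Protocol is structural, a concrete class satisfies ModuleInterface simply by defining a compatible forward method; it does not have to list the protocol among its bases. Writing class Module(nnx.Module) without ModuleInterface leaves a single metaclass and avoids the conflict, while static type checkers still accept the module wherever ModuleInterface is expected, and decorating the protocol with @typing.runtime_checkable enables isinstance checks. The sketch below shows this pattern, again using hypothetical stand-ins NodeMeta and Node in place of nnx.Module so it runs without Flax; Doubler and run are illustrative names, not part of any library.

```python
import abc
import typing


class NodeMeta(abc.ABCMeta):
    """Hypothetical stand-in for the metaclass behind nnx.Module."""


class Node(metaclass=NodeMeta):
    """Hypothetical stand-in for nnx.Module."""


@typing.runtime_checkable
class ModuleInterface(typing.Protocol):
    def forward(self, x: float) -> float: ...


# Only the framework base class is inherited, so there is one metaclass.
class Doubler(Node):
    def __call__(self, x: float) -> float:
        return 2 * x

    def forward(self, x: float) -> float:
        return self(x)


def run(model: ModuleInterface, x: float) -> float:
    # A static type checker accepts Doubler here because forward() matches.
    return model.forward(x)


model = Doubler()
print(isinstance(model, ModuleInterface))  # True
print(run(model, 3.0))  # 6.0
```

Note that a runtime_checkable isinstance check only verifies that forward exists, not its signature; signature compatibility is left to the static type checker.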