How can I make a non-trainable variable using nnx.Module? #4533
Unanswered
SangminLee0828 asked this question in Q&A
Replies: 2 comments · 3 replies
My quick take is: by assigning those parameters to self, they are found by jax.grad as parameters to differentiate (although I am surprised, because nnx.Param exists for a good reason). However, if you want something purely static, maybe nnx.Module, or a class at all, is not needed? (if you want to compute the mean and variance on the fly)
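For example, a module-free version along those lines could be a plain function where the statistics are ordinary arguments (a minimal sketch; the name `normalize` and the epsilon value are illustrative):

```python
import jax.numpy as jnp

def normalize(x, mean, variance, eps=1e-6):
    # mean/variance are plain arguments, not module state, so nothing here
    # is ever registered as a trainable parameter.
    return (x - mean) / jnp.sqrt(variance + eps)
```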
The solution is to create a filter for the trainable Variables and pass it to both the optimizer and nnx.grad:

```python
import jax.numpy as jnp
import optax
from flax import nnx


class Classifier(nnx.Module):
    def __init__(self, embed_dim, num_classes, backbone, rngs):
        self.backbone = backbone
        self.head = nnx.Linear(embed_dim, num_classes, rngs=rngs)

    def __call__(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x


def load_model():
    return nnx.Linear(784, 1024, rngs=nnx.Rngs(0))


backbone = load_model()
classifier = Classifier(1024, 10, backbone, rngs=nnx.Rngs(1))

# filter to select only Params on the 'head' path
head_params = nnx.All(nnx.Param, nnx.PathContains('head'))

optimizer = nnx.Optimizer(
    classifier,
    tx=optax.adamw(3e-4),
    wrt=head_params,  # only update head params
)

# simple train step
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        logits = model(x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    diff_state = nnx.DiffState(0, head_params)  # differentiate only head params of the first argument
    grads = nnx.grad(loss_fn, argnums=diff_state)(model)
    optimizer.update(grads)


x = jnp.ones((1, 784))
y = jnp.ones((1,), jnp.int32)
train_step(classifier, optimizer, x, y)
```
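Note that the same filter is used in two places: `wrt=head_params` tells the optimizer which variables to track and update, and `nnx.DiffState(0, head_params)` restricts `nnx.grad` to those same variables of the first argument, so everything under `backbone` stays frozen on both sides.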
Hi,
I am trying to create a normalization layer. This normalization layer has 'mean' and 'variance' inside, so when values come in, the outputs are normalized using the stored mean and variance.
https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization
How can I make 'self.mean' and 'self.variance' not trainable?
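One way to do this in NNX is to store the statistics in a variable type other than `nnx.Param`, such as `nnx.BatchStat` (or a custom `nnx.Variable` subclass): Param-based filters, including `nnx.grad`'s default, then skip them, so they are neither differentiated nor updated by the optimizer. Below is a minimal sketch assuming the mean and variance are supplied up front; the `Normalization` class and the epsilon value are illustrative, not an existing Flax layer:

```python
import jax.numpy as jnp
from flax import nnx

class Normalization(nnx.Module):
    def __init__(self, mean, variance):
        # BatchStat is an nnx.Variable subclass that is not an nnx.Param,
        # so Param-based filters (gradients, optimizer updates) leave it alone.
        self.mean = nnx.BatchStat(jnp.asarray(mean))
        self.variance = nnx.BatchStat(jnp.asarray(variance))

    def __call__(self, x):
        return (x - self.mean.value) / jnp.sqrt(self.variance.value + 1e-6)

norm = Normalization(mean=jnp.zeros(4), variance=jnp.ones(4))
y = norm(jnp.ones((2, 4)))  # normalized with the stored, non-trainable statistics
```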