When does recreating a model trigger nnx.jit recompilation?
#4474
Unanswered · NiklasKappel asked this question in Q&A · Replies: 2 comments, 2 replies
Further testing reveals the culprit is the use of …

```python
from functools import partial  # only needed for the commented-out alternative

from flax import nnx


class CNN(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
        # self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
        self.avg_pool = nnx.avg_pool
        # 3136 = 64 channels * 7 * 7 (a 28x28 input halved by two 2x2 pools)
        self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, 10, rngs=rngs)

    def __call__(self, x):
        x = self.avg_pool(nnx.relu(self.conv1(x)), window_shape=(2, 2), strides=(2, 2))
        x = self.avg_pool(nnx.relu(self.conv2(x)), window_shape=(2, 2), strides=(2, 2))
        x = x.reshape(x.shape[0], -1)  # flatten
        x = nnx.relu(self.linear1(x))
        x = self.linear2(x)
        return x
```

My guess is that …
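Assuming the culprit is the `partial(...)` object from the commented-out line, one way to keep the tutorial's `self.avg_pool(x)` call style might be a callable that compares by value instead of by identity. The sketch below assumes the compilation cache looks such an attribute up the way `jax.jit` looks up a static argument, i.e. by `==` and `hash`. It checks this with plain `jax.jit` and `static_argnums`, not with `nnx.jit`, and `StridedPool` and `step` are hypothetical names.

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp

traces = []  # appended to only while JAX traces `step`


@dataclasses.dataclass(frozen=True)
class StridedPool:
    # Hypothetical stand-in for partial(nnx.avg_pool, ...): a frozen dataclass
    # gets __eq__ and __hash__ generated from its fields, so two instances
    # built with the same stride compare equal.
    stride: int

    def __call__(self, x):
        return x[:: self.stride]


@functools.partial(jax.jit, static_argnums=0)
def step(fn, x):
    traces.append(fn)  # Python side effect: runs at trace time only
    return fn(x).sum()


x = jnp.arange(8.0)
a, b = StridedPool(2), StridedPool(2)
print(a == b, hash(a) == hash(b))  # True True
step(a, x)
step(b, x)  # cache hit: b == a, so no retrace
print(len(traces))  # 1
```

Whether `nnx.jit` reuses its compiled function when a module stores such an object is an assumption; the `print` side-effect check described in the question would show it either way.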
Maybe the behavior is different with `jax.tree_util.Partial`? Have you tried?
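For reference, with plain `jax.jit` a `jax.tree_util.Partial` is a pytree: the wrapped function becomes part of the tree structure and the bound arguments become leaves. Passed as an ordinary (non-static) argument, two separately created `Partial`s over the same function with same-shaped bound arrays should therefore share one compiled version, as the sketch below checks (`scale` and `apply` are hypothetical names). It does not show how `nnx.jit` treats a `Partial` stored on a module, and bound Python ints such as `window_shape=(2, 2)` would likely become traced leaves, which a pooling call needing a static window may not accept.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

traces = []  # appended to only while JAX traces `apply`


def scale(x, factor):
    return x * factor


@jax.jit
def apply(fn, x):
    # fn is flattened as a pytree: `scale` goes into the tree structure,
    # while the bound `factor` array becomes a traced leaf.
    traces.append(1)
    return fn(x)


x = jnp.ones(3)
y2 = apply(Partial(scale, factor=jnp.float32(2.0)), x)
y3 = apply(Partial(scale, factor=jnp.float32(3.0)), x)
print(len(traces))  # 1: same structure and shapes, so no retrace
print(y2, y3)
```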
Consider this MWE that uses the model from the MNIST tutorial:
Note that, when running `try_CNN`, the `foo_step` function is JIT-compiled twice (i.e. the `print` side effect is triggered twice): once with the first CNN instance and once with the second, even though the two instances are essentially the same. When running `try_Foo`, though, `foo_step` is compiled only once. The latter is the behavior I would expect from `jax.jit`, since recreating a model does not change anything about the shapes of the parameter arrays etc.

What is the reason for the extra compilation happening in `try_CNN`?
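One plausible explanation, not confirmed in this thread: `nnx.jit` keeps non-array module attributes as part of the static graph structure it compiles against. If the MWE's `CNN` stores `partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))`, as in the commented-out line of the `CNN` snippet above, every new instance carries a fresh `functools.partial` object. `functools.partial` defines no value-based `__eq__`, so two such objects never compare equal and the cached compilation cannot be reused. The sketch below reproduces that mechanism with plain `jax.jit` and a static argument instead of NNX; `pool`, `pool_by_2` and `step` are hypothetical names.

```python
import functools

import jax
import jax.numpy as jnp

traces = []  # appended to only while JAX traces `step`


def pool(x, window):
    # Hypothetical stand-in for a pooling op with a static window size.
    return x[::window]


def pool_by_2(x):
    # A single module-level function object, reused on every call.
    return pool(x, 2)


@functools.partial(jax.jit, static_argnums=0)
def step(fn, x):
    traces.append(fn)  # Python side effect: runs at trace time only
    return fn(x).sum()


x = jnp.arange(8.0)

# Two partials built the same way are still two distinct, unequal objects.
p1 = functools.partial(pool, window=2)
p2 = functools.partial(pool, window=2)
print(p1 == p2)  # False
step(p1, x)
step(p2, x)
print(len(traces))  # 2: the second call misses the cache

# Reusing one function object hits the cache after the first trace.
step(pool_by_2, x)
step(pool_by_2, x)
print(len(traces))  # 3
```

Storing one shared function object, as `self.avg_pool = nnx.avg_pool` does in the `CNN` snippet above, corresponds to the second half of the sketch.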