Hi, I noticed that jit compilation fails when my model is passed as a parameter to a function, but not when it is just used from within the outer scope. Here is a minimal working example:

jit compiling works fine for

If I remove

(Additional context: I normally create a flax
`jax.jit` won't work on a function with a `Module` argument type. We can demonstrate this by making a dummy function with the same argument signature:

We get an error:

One solution is to specify that the argument is static (check out the API documentation on `static_argnums`) via:

But

Alternatively, you could pass in a string to the

Flax also provides "lifted" versions of JAX transformations, which would allow you to use

But I think for your situation, one of the above solutions should suffice.