Get the PyTree structure of flax model without initializing weights #1421
Unanswered
patil-suraj asked this question in Show and tell
Replies: 1 comment
Thanks for this great tip @patil-suraj
The `jax.eval_shape` function can be used to get the PyTree structure of Flax model params and optimizer state (or the output of any JAX function, for that matter) without having to actually initialize them.
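A minimal sketch of the idea, using only JAX so that it runs on its own. The `init_params` function and its layer sizes are hypothetical stand-ins for a Flax model's `init` method, and `init_momentum` is a hypothetical stand-in for an optimizer's state initializer; neither comes from the original post.

```python
import jax
import jax.numpy as jnp


def init_params(rng, in_dim=784, hidden=512, out_dim=10):
    """Hypothetical stand-in for a Flax `model.init`: params of a small MLP."""
    k1, k2 = jax.random.split(rng)
    return {
        "dense1": {
            "kernel": jax.random.normal(k1, (in_dim, hidden)),
            "bias": jnp.zeros((hidden,)),
        },
        "dense2": {
            "kernel": jax.random.normal(k2, (hidden, out_dim)),
            "bias": jnp.zeros((out_dim,)),
        },
    }


def init_momentum(params):
    """Hypothetical optimizer state: one zero buffer per parameter."""
    return jax.tree_util.tree_map(jnp.zeros_like, params)


rng = jax.random.PRNGKey(0)

# eval_shape traces the function abstractly, so no parameter arrays are
# allocated; every leaf of the result is a jax.ShapeDtypeStruct.
abstract_params = jax.eval_shape(init_params, rng)

# ShapeDtypeStructs can be fed back into eval_shape, so the optimizer
# state structure is also available without any real arrays.
abstract_opt_state = jax.eval_shape(init_momentum, abstract_params)

# Reduce each leaf to its shape for a compact view of the tree.
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, abstract_params)
print(shapes)
```

With an actual Flax module, the same pattern should apply by passing `model.init` together with an RNG key and an example input to `jax.eval_shape` in place of `init_params`.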