Why can TrainState be used as an argument of a jit-ed function without triggering an error? #1858
Hi,
But TrainState can be used in a jit-ed function without triggering any error, for example: We have to mark
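When you look at the definition of the class `TrainState`, you can see that it extends `struct.PyTreeNode`. This means that it is a dataclass that acts like a JAX pytree node, so JAX knows how to flatten/unflatten it for JAX transformations. Under `jax.jit`, the array fields (`step`, `params`, `opt_state`) become traced leaves, while the non-array fields `apply_fn` and `tx` are declared with `struct.field(pytree_node=False)` and are kept as static metadata, so the whole object can be passed as a regular argument.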