Serializing nested objects #1159
I noticed that flax's dataclasses have their from/to_state_dict methods defined automatically. That's great! However, I now have a complex object that is not a dataclass, and I would like to define flax from/to state dict functions for it. Some of its fields are dataclasses. I defined the methods as follows (taken from here):

    import copy

    def serialize_classical_variational_state(vstate):
        state_dict = {
            "variables": vstate.variables,
            "sampler_state": vstate.sampler_state,
        }
        return state_dict

    def deserialize_classical_variational_state(vstate, state_dict):
        new_vstate = copy.copy(vstate)
        new_vstate.variables = state_dict["variables"]
        new_vstate.sampler_state = state_dict["sampler_state"]
        return new_vstate

Note that variables is a frozen pytree (parameters of a model) and sampler_state is a dataclass. I noticed that when [...]

    import flax.serialization

    def serialize_classical_variational_state(vstate):
        # Convert the nested fields explicitly instead of storing them as-is.
        state_dict = {
            "variables": flax.serialization.to_state_dict(vstate.variables),
            "sampler_state": flax.serialization.to_state_dict(vstate.sampler_state),
        }
        return state_dict

Is that right?
Replies: 1 comment
Yes, and the same goes for deserialisation where you should use from_state_dict.