Hi,
I'm trying to use orbax to checkpoint arbitrary penzai models, whose parameters are of type `pz.ParameterValue`, which in turn contain `pz.NamedArray` instances. These instances carry named axes, so I thought I'd try my hand at implementing a class derived from `type_handlers.TypeHandler`. However, I can't see where in this workflow the axis names would be stored. I saw that there is a `TypeHandler.metadata` method, but it seems to be called only during restore. And `TypeHandler.serialize` doesn't seem to offer any chance to customize the output except at a very low level, the tensorstore level.
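For concreteness, here is one possible workaround that avoids a custom `TypeHandler` entirely: strip the axis names out into a JSON-serializable sidecar before saving, checkpoint only the plain array data, and reattach the names after restoring. This is a sketch under assumptions, not penzai's or orbax's API. `NamedArrayLike`, `split_axis_names`, and `merge_axis_names` are hypothetical names, and `NamedArrayLike` is a simplified stand-in for `pz.NamedArray` (an ordered axis-name to size mapping plus raw data). Parameters are assumed to be nested dicts.

```python
import json
from dataclasses import dataclass
from typing import Any


# Hypothetical stand-in for pz.NamedArray (NOT the real penzai class):
# pairs raw array data with an ordered mapping of axis name -> axis size.
@dataclass
class NamedArrayLike:
    named_axes: dict  # e.g. {"in": 2, "out": 3}; insertion order is kept
    data: Any         # raw data a checkpointer could save as a plain array


def split_axis_names(tree, path=()):
    """Return (plain_tree, axis_meta).

    plain_tree has every NamedArrayLike replaced by its raw data.
    axis_meta maps a "/"-joined key path to a list of (name, size) pairs.
    Assumes dict keys do not themselves contain "/".
    """
    if isinstance(tree, NamedArrayLike):
        return tree.data, {"/".join(path): list(tree.named_axes.items())}
    if isinstance(tree, dict):
        plain, meta = {}, {}
        for key, value in tree.items():
            plain[key], sub_meta = split_axis_names(value, path + (str(key),))
            meta.update(sub_meta)
        return plain, meta
    # Any other leaf (scalar, unnamed array, ...) passes through unchanged.
    return tree, {}


def merge_axis_names(plain_tree, axis_meta, path=()):
    """Inverse of split_axis_names: rewrap recorded paths as NamedArrayLike."""
    key = "/".join(path)
    if key in axis_meta:
        # JSON turns (name, size) tuples into lists; dict() accepts either.
        return NamedArrayLike(named_axes=dict(axis_meta[key]), data=plain_tree)
    if isinstance(plain_tree, dict):
        return {
            k: merge_axis_names(v, axis_meta, path + (str(k),))
            for k, v in plain_tree.items()
        }
    return plain_tree


# Usage: `plain` is what would go to the checkpointer; `meta_json` would be
# written next to it (for example as a separate JSON file).
params = {
    "linear": {"weights": NamedArrayLike({"in": 2, "out": 3}, [[1.0] * 3] * 2)},
    "step": 7,
}
plain, meta = split_axis_names(params)
meta_json = json.dumps(meta)  # '{"linear/weights": [["in", 2], ["out", 3]]}'
restored = merge_axis_names(plain, json.loads(meta_json))
assert restored == params
```

The design choice here is to keep the array payload free of any penzai-specific types, so whatever checkpointer saves ordinary arrays can handle it, while the small amount of axis-name metadata travels separately as JSON.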
Am I missing something else? It would be nice to be able to use orbax with penzai models.
On the penzai side, they say that orbax can be used, but there is no example of saving and loading an arbitrary model.