Description
Feedback appreciated!
Motivation
A variety of Orbax Checkpoint users have expressed concerns that the API is too complex. As one user stated:
I would say orbax makes the hard things possible, but makes the easy things hard. If we can progress to the point where it makes hard things possible, and easy things easy -- that would be brilliant!
Another user concurred:
There is a steep learning curve using orbax, compared to, let's say, safetensors.
Orbax core APIs need simplification in order to provide a better user experience.
Overview
Orbax Checkpoint (OCP) will introduce a V1 API (the current API is denoted "V0"), located at orbax.checkpoint.v1. This will serve as the new entry point for all users. However, migration will not be required. The public documentation page will reference only the V1 API.
The V0 API will continue to exist as the underlying implementation of the V1 API in the short term, while being gradually deprecated in the long term.
The new API will be designed to address a number of complaints about the current API, including its complexity, verbosity, and steep learning curve. The current API has failed to incorporate the principle of progressive disclosure of complexity, instead opting for maximum flexibility while failing to simplify common use cases. While maximum flexibility will still be possible, the most common use cases must be easy to understand and expressible in concise code.
API Surface
Basic users:
- save_pytree / save_pytree_async - Saves a PyTree synchronously or asynchronously.
- load_pytree / load_pytree_async - Loads a PyTree synchronously or asynchronously, plus retrieving metadata.
- Checkpointer - The primary entry point for managing a sequence of checkpoints. It offers automatic garbage collection, configurable save interval policies, etc., with a similar interface to the free functions.
Intermediate users:
- "checkpointable" - The concept of a logically distinct unit of a checkpoint that has minimal relation to the other units and is often separated for loading (e.g. params/opt_state, dataset, other metadata).
- save_checkpointables / load_checkpointables - Save and load arbitrary checkpointables (e.g. dataset, embeddings); a PyTree may be one of the checkpointables. Includes core customization behaviors like partial load, partial update, and model surgery.
- Checkpointable - Support custom checkpointables by implementing an interface.
Advanced users:
- CheckpointableHandler - Allows fine-grained save/restore behavior customization (particularly formats) for a given Checkpointable.
- LeafHandler - Allows customizing PyTree save/restore behavior for custom leaf objects.
- Context - Allows configuring specific operations (e.g. save/load); the free functions above can run within a specific given Context (configuration).
- configure - Allows overriding the global Context and setting global options.
Use Cases
import orbax.checkpoint as ocp
Single-Checkpoint Use Cases
The following scenarios enumerate functionalities that users need when saving and restoring a single checkpoint, independently of the sequence of checkpoints that is typically required during training. These scenarios can be common when debugging checkpoints locally, or when running evaluations.
Many of the usage patterns listed here also apply when managing a sequence of checkpoints.
Save synchronously and asynchronously
ocp.save_pytree(path, pytree_state)
response = ocp.save_pytree_async(path, pytree_state)
response.result() # Wait for completion
Restore synchronously and asynchronously
restored = ocp.load_pytree(path)
response = ocp.load_pytree_async(path)
restored = response.result() # Wait for completion
Save and restore with optional arguments to customize behavior
ocp.save_pytree(path, pytree, partial_update=True)
restored = ocp.load_pytree(path, abstract_pytree, partial_load=True)
Save multiple logically distinct checkpointables
ocp.save_checkpointables(path, dict(pytree=pytree, dataset=ds_iter, foo=foo))
ocp.load_checkpointables(path, dict(pytree=abstract_pytree)) # Restore only pytree
ocp.load_pytree(path, abstract_pytree) # Same as above
result = ocp.load_checkpointables(
path, dict(pytree=abstract_pytree, dataset=...)
)
pytree, dataset = result['pytree'], result['dataset']
Obtain metadata
ocp.metadata(path) # -> CheckpointMetadata
Support custom checkpointables
class MyCustomClass:  # Implements Checkpointable (see below)
  …  # Some properties.

  async def save_async(self, path: Path) -> AsyncResponse[None]:
    serialized = self.properties_as_json()
    return await ocp.save_async(path, serialized)  # Saves as a basic JSON file.

  @classmethod
  async def load_async(
      cls, path: Path, abstract_checkpointable: None = None,
  ) -> AsyncResponse[MyCustomClass]:
    # Loading this object does not require any abstract information, so we set
    # abstract_checkpointable to None. The user could also define an abstract class
    # corresponding to MyCustomClass (the concrete class) and accept that instead.
    serialized = await ocp.load_async(path)  # Read JSON.
    return cls(**serialized)

  async def metadata(self, path: Path) -> None:
    return None

custom_obj = MyCustomClass(...)
ocp.save_checkpointables(path, dict(pytree=pytree, custom_obj=custom_obj))
Support custom formats
This differs from the above in that instead of some specific user-defined object that is always saved and loaded in the same way, we instead have a core object (like a PyTree) that needs to be handled in an alternative way. For example, this allows us to configure checkpointing behavior in a non-Orbax format (e.g. Roc or PyTorch).
ocp.save_pytree(
path,
pytree,
handler=RocHandler(format=einshape_numpy_proto)
)
Context based customization
See below for more details. The following example allows saving to a process-local filesystem.
# Update the global context statically.
ocp.configure(
    multiprocessing_options=MultiprocessingOptions(primary_host=None),
)
...
ocp.save_pytree(local_fs_path, pytree)

# Save pytree in a specific context.
multiprocessing_options = MultiprocessingOptions(primary_host=None)
with ocp.Context(multiprocessing_options=multiprocessing_options):
  ocp.save_pytree(local_fs_path, pytree)
Sequence-of-Checkpoints Use Cases
These use cases roughly correspond to those served by the existing CheckpointManager object. Note, however, that the class is now named Checkpointer, and that it tries to reuse constructs introduced to serve the single-checkpoint use case.
Saving and restoring
with ocp.Checkpointer(directory) as ckptr:
  ckptr.save_pytree(0, train_state)  # Save
  f = ckptr.save_pytree_async(1, train_state)  # Async save
  f.result()
  ckptr.load_pytree()  # Restore the latest step
  ckptr.load_pytree(1)  # Restore a specific step
  ckptr.load_pytree(1, abstract_train_state)  # Reshard / cast
  f = ckptr.load_pytree_async(1)  # Async restore
  restored = f.result()
Saving and restoring with multiple checkpointables
with ocp.Checkpointer(directory) as ckptr:
  ckptr.save_checkpointables(step, dict(pytree=train_state, dataset=train_iter))
  ckptr.load_pytree(step, abstract_train_state)
Obtain metadata
with ocp.Checkpointer(directory) as ckptr:
  ckptr.metadata()  # -> RootMetadata: root-directory-level metadata
  ckptr.metadata(step)  # -> CheckpointMetadata: checkpoint-level metadata
Determine when to save
# In the future, this will default to ContinuousCheckpointingPolicy (at least
# for internal users).
save_decision_policy = EveryNStepsPolicy(steps=1000)
with ocp.Checkpointer(directory, save_decision_policy) as ckptr:
  ckptr.should_save(step)  # -> bool
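A save-decision policy can be understood as a small predicate object consulted by should_save. The sketch below is illustrative only; the actual SaveDecisionPolicy interface is not specified here, and the single-method shape is an assumption.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class EveryNStepsPolicy:
  """Illustrative sketch: save at step 0 and at every multiple of `steps`."""

  steps: int

  def should_save(self, step: int) -> bool:
    # The Checkpointer would consult this predicate before each save.
    return step % self.steps == 0
```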
Identify existing checkpoints
with ocp.Checkpointer(directory) as ckptr:
  ckptr.latest_step()  # -> int
  ckptr.steps()  # -> set[int]
Handle garbage collection
with ocp.Checkpointer(directory, preservation_policy=LatestN(10)) as ckptr:
  …
Rank checkpoints by metrics
# Mimics the builtin `sorted` function; `reverse` may be True or False.
preservation_policy = BestN(10, lambda m: m['accuracy'], reverse=True)
with ocp.Checkpointer(directory, preservation_policy) as ckptr:
  ckptr.save_pytree(step, …, metrics={'accuracy': 0.9, 'loss': 0.65})
Context based customization
# Update the global context statically.
ocp.configure(
    multiprocessing_options=MultiprocessingOptions(primary_host=None),
)
...
with ocp.Checkpointer(directory) as ckptr:
  ckptr.save_pytree(0, train_state)  # Save

# Checkpointer with a specific context.
# Overrides the global context.
with ocp.Checkpointer(directory, context=ocp.Context(...)) as ckptr:
  ckptr.save_pytree(0, train_state)  # Save

# Checkpointer operation with a specific context.
# Overrides both the global and Checkpointer-level context.
with ocp.Checkpointer(directory) as ckptr:
  with ocp.Context(...):
    ckptr.save_pytree(0, train_state)  # Save
Training loop example
Note: model surgery complexity omitted.
def init_or_restore(ckptr: Checkpointer) -> PyTree:
  if exp_cfg.init_checkpoint_path:  # Restore an initial checkpoint (e.g. finetuning)
    return ckptr.load_pytree(exp_cfg.init_checkpoint_path, transform_fn=surgery_fn)
  else:  # Init from scratch
    return init()

with ocp.Checkpointer(directory, **other_options) as ckptr:
  if ckptr.latest_step() is not None:  # Recovering after restart
    train_state = ckptr.load_pytree()
  else:
    train_state = init_or_restore(ckptr)
    ckptr.save_pytree(0, train_state)  # Save the initial model
  for step in range(start_step, end_step):
    train_state = train_step(train_state)
    if ckptr.should_save(step):
      ckptr.save_checkpointables(step, dict(state=train_state, dataset=train_iter))
Tree-Specific Use Cases
Restore with resharding/reshaping/casting
abstract_tree = {
    'a': jax.ShapeDtypeStruct(shape, dtype, sharding),  # restore as jax.Array
    'b': np.empty(shape, dtype),  # restore as np.ndarray
    'c': '',  # restore as string
}
ocp.load_pytree(path, abstract_tree)
Partial restoration
Partial restore is a way to solve the most common use case of loading a different tree than is present in the checkpoint - where leaves or subtrees can be omitted. The canonical example is to skip loading the optimizer state when you're doing evaluation.
In contrast, model surgery is the more complete version of this, where the user can manipulate trees/leaves in arbitrary ways, as well as load multiple trees and merge them.
abstract_tree = {
    'params': { … },
    # Note: 'opt_state' is omitted to avoid loading it.
    'step': None,  # Skip loading 'step'.
}
# Unsafe variant: would require partial_load=True to be the default.
ocp.load_pytree(path, abstract_tree)
# Safe variant: partial_load must be opted into.
ocp.load_pytree(path, abstract_tree, partial_load=True)
Restore with model surgery
def transform_fn(source: PyTree) -> PyTree:
  ...

ocp.load_and_transform(path, abstract_tree, transform_fn)
Multi-model restore with model surgery
def transform_fn(source_a: PyTree, source_b: PyTree) -> PyTree:
  ...

ocp.load_and_transform([path_a, path_b], abstract_tree, transform_fn)
Partial write (update)
ocp.save_pytree(path, partial_pytree_one, partial_update=True)
ocp.save_pytree(path, partial_pytree_two, partial_update=True)
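The intended merge semantics of partial_update can be sketched as a recursive tree merge, where leaves present in the update overwrite stored leaves and everything else is preserved. This helper is hypothetical, not part of the proposed API:

```python
def merge_partial(existing: dict, update: dict) -> dict:
  """Merges a partial tree into a stored tree (illustrative sketch)."""
  merged = dict(existing)
  for key, value in update.items():
    if isinstance(value, dict) and isinstance(merged.get(key), dict):
      # Recurse into shared subtrees.
      merged[key] = merge_partial(merged[key], value)
    else:
      # Leaves present in the update overwrite stored leaves.
      merged[key] = value
  return merged
```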
Support custom tree leaves with Context
# With the global context.
ocp.configure(leaf_handlers={MyCustomLeaf: CustomLeafHandler})
…
ocp.save_pytree(path, pytree)

# With a local context.
with ocp.Context(leaf_handlers={MyCustomLeaf: CustomLeafHandler}):
  ocp.save_pytree(path, pytree_with_custom_leaves)
API Definitions
Overview
In orbax.checkpoint.v1, free functions will be the primary entry point for users into the library. These include save_pytree / load_pytree, which deal with single PyTrees, and save_checkpointables / load_checkpointables, which deal with multiple arbitrary checkpointables. (These functions also include async variants and metadata access.)
While these functions operate at the level of an individual checkpoint path, the other main entry point is Checkpointer, which operates at the level of a root directory, under which a sequence of checkpoints corresponding to steps in a training loop is stored. This class makes restrictive assumptions about the set of tasks a user will try to do and the patterns under which it is used. In other words, it is less flexible than APIs oriented around singular paths, but provides more features, like automatic garbage collection, metrics management, and save intervals. It will be suitable for many basic training loops, but not for more advanced users with greater customization needs.
The user-facing API discussed in this doc, especially the free functions, depends on some global configuration: multiprocessing options, timeouts, the LeafHandlerRegistry, etc. This global configuration is called a Context and is implemented as a context manager. The Orbax operations discussed above can be invoked within a Context to customize their behavior with the given configuration.
“Checkpointable” remains a key concept in the V1 API. A “checkpointable” refers to a logical piece of the checkpoint that is distinct in some way from other pieces. Checkpointables are separable; they may or may not be loaded concurrently and some may be omitted from the checkpoint entirely. Checkpointables are often represented by different types, and have different representations on disk. The quintessential example is model params vs. dataset.
Different checkpointables are handled by CheckpointableHandler implementations. These provide the logic for saving and loading particular objects, and also identify which objects they are capable of processing. For user convenience, a Checkpointable interface is also provided, which allows tightly coupling checkpointing logic to a specific object. While less flexible than the CheckpointableHandler, it is also more intuitive, and we intend to expose it as the main interface for checkpointable customization.
The lowest-level user-accessible abstraction is the LeafHandler, which deals specifically with processing individual leaves of a PyTree. Implementing and registering a subclass of this object allows storing custom leaves in a PyTree.
Free Functions
Free functions serve as the primary entry point for users.
Checkpointable = Any
### SAVING ###
def save_pytree(
    directory: PathLike,
    pytree: PyTree,
    *,
    partial_update: bool = False,
    force: bool = False,
    handler: Type[CheckpointableHandler] = PyTreeHandler,
):
  …

def save_pytree_async(...) -> AsyncResponse[None]:
  …

def save_checkpointables(
    directory: PathLike,
    checkpointables: dict[str, Checkpointable],
    *,
    force: bool = False,
):
  …

def save_checkpointables_async(...) -> AsyncResponse[None]:
  …

### LOADING ###

def load_pytree(
    directory: PathLike,
    abstract_pytree: PyTree | None = None,
    *,
    partial_load: bool = False,
) -> PyTree:
  …

def load_pytree_async(...) -> AsyncResponse[PyTree]:
  …

def load_checkpointables(
    directory: PathLike,
    abstract_checkpointables: dict[str, Checkpointable] | None = None,
) -> dict[str, Checkpointable]:
  …

def load_checkpointables_async(...) -> AsyncResponse[dict[str, Checkpointable]]:
  …

def metadata(directory: PathLike) -> CheckpointMetadata:
  …

### MODEL SURGERY ###

def load_and_transform(
    directories: Sequence[PathLike],
    abstract_pytree: PyTree,
    transform_fn: TransformFn,
) -> PyTree:
  …
Futures vs. wait_until_finished
Currently, async saving in Orbax does not return futures to the user, but instead relies on a wait_until_finished method for the user to block on the result of the save. However, it makes sense to abandon this model for a few reasons.
First, in the typical training checkpointing use case, users rarely block on the result of a save, typically doing so only before exiting the program. They rely on the library itself to block if a save is already ongoing when they try to save again.
Second, the use of context managers makes both futures and wait_until_finished unnecessary in many cases, as exiting the context automatically waits. This use case may be more common for small-scale experimentation, or one-off PyTree writes.
Third, and most importantly, load_async has no viable alternative for providing its result to the user other than via a future. The benefits of aligning the APIs of save_async and load_async outweigh any other potential arguments, in my view.
Further note that we should aim to move away from the "Future" terminology. Despite being superficially familiar to many users, the term can create confusion with the other Future implementations.
Instead we will opt for a construct like AsyncResponse. This is a simple container class returned by asynchronous APIs. It contains a method like result that allows blocking on the save/load operation and retrieving the operation result. In this respect, it is similar to ocp.Future in the current codebase.
Alternatives to save_pytree/save_checkpointables, load_pytree/load_checkpointables
We can combine both functionalities into one method, e.g.
def load(
    path: PathLike,
    abstract_pytree: PyTree | None = None,
    *,
    extra_checkpointables: dict[str, Any] | None = None,
):
The real difficulty is how to represent the return type for load. In order to mirror the function inputs, we must return a tuple of (pytree, extra_checkpointables). (We could also just return a single dictionary representing all checkpointables, but this is undesirable for users only providing pytree, since they will not necessarily be aware of how the checkpoint is represented, or that checkpointables is a concept they need to know. Returning a tuple of (pytree, extra_checkpointables) also requires the user to know that extra_checkpointables is an important argument, but that is less bad than requiring the user to know the arbitrary name "pytree" in order to access their loaded tree.)
If we accept that a tuple is the ideal return type, in order to mirror the inputs, the return type must be:
tuple[PyTree, dict[str, Any] | None]
This interface is fairly inflexible and not that user-friendly. The issues are:
- pytree must be present in every checkpoint.
- There is no way to restore extra_checkpointables and not pytree, since abstract_pytree=None is used to indicate "restore the pytree however you can".
- Different treatment of None for abstract_pytree and extra_checkpointables is confusing.
- The return type is complex and not well-suited to users only interested in the pytree (this violates progressive disclosure of complexity).
- Return types depend on both the inputs and what is in the checkpoint.
The combined interface is more trouble than it's worth. Splitting into load_pytree / load_checkpointables meshes well with progressive disclosure of complexity, simplifies input and output signatures, and makes return types more predictable, at the cost of using two functions instead of one.
Context and Configuration
@dataclasses.dataclass(frozen=True)
class Context:
  options: Options

  def __enter__(self) -> Context:
    ...
    return self

  def __exit__(self, ...):
    ...
Dealing with global configurations (Context)
A “global configuration” is a setting that applies at multiple levels of the Orbax stack. These settings must be applied in the same way to multiple different layers, or the inconsistency can result in unexpected errors.
This includes groups of options like:
@dataclasses.dataclass(frozen=True, kw_only=True)
class Options:
  # save_timeout / load_timeout.
  async_options: AsyncOptions
  # Settings for e.g. primary_host, active_processes.
  multiprocessing_options: MultiprocessingOptions
  # Options controlling path permissions, data governance annotations (internal),
  # CNS2 options (internal), etc.
  file_options: FileOptions
  # Options for enabling hashing/signing behavior in save/load.
  signing_options: SigningOptions
  # Other options.
  ...
For example, a setting like save_timeout / load_timeout applies globally to an entire operation, rather than being set separately for save / load, CheckpointableHandler, and LeafHandler. Another example is primary_host, which must have the same setting in every layer, or risk difficult-to-debug breakages.
Practically speaking, global options (corresponding roughly to the existing options.py) are not commonly modified. They are typically set once as global configuration. In rare cases, individual operations may need to modify the settings with greater flexibility.
All configurations
Orbax provides a lot of options for configuring specific behaviors at various levels.
For the V1 API, we can subdivide options into a number of categories. Of these, the most interesting are PyTree-related options and Array-related options, which comprise the bulk of all options.
- AsyncOptions
  - timeout_secs
  - barrier_sync_fn
  - post_finalization_callback
  - create_directories_asynchronously
- MultiprocessingOptions
  - primary_host
  - active_processes
  - barrier_sync_key_prefix
- FileOptions
  - path_permission_mode
  - data_governance_annotations
  - cns2_storage_options
  - temporary_path_class  # Atomicity
- SecurityOptions
  - tree_verity_options
- PyTrees
  - array_storage_options_creator  # Creates an ArrayStorageOptions struct on a per-leaf basis that customizes save behavior for individual array leaves. If ArrayStorageOptions are set globally, this option will override them.
  - leaf_handler_registry  # LeafHandlers used for PyTree leaves
  - enable_descriptor
  - pytree_metadata_options
  - array_metadata_validator  # Not user-facing, mostly for internal testing
  - partial_update  # Enable partial tree update
  - partial_load  # Enable partial tree loading
- Arrays
  - Saving
    - concurrent_bytes
    - OCDBT options
      - use_ocdbt
      - ocdbt_target_data_file_size
      - enable_post_merge_validation
    - Storage options  # Can be customized per-array if we have multiple arrays.
      - dtype  # Cast type for storage
      - chunk_byte_size  # Loose cap on the size of Tensorstore chunks
      - shard_axes  # Chunks subdivided along this axis
      - metadata_key  # .zarray file name
    - use_zarr3
    - enable_pinned_host_transfer
    - enable_write_sharding_file
    - use_replica_parallel
  - Loading
    - concurrent_bytes
    - enable_padding_and_truncation
    - Single-replica restore+broadcast
      - replica_axis_index
      - primary_replica_id
      - broadcast_memory_limit_bytes
      - broadcast_memory_scaling_factor
It is clear that most of these options do not need to be modified often. When they do, a global setting is possible, and only in rare cases do users need per-operation customization. If all such settings are placed within a global Options, there is considerably less doubt for users about where to find a particular setting. In an inherently complicated landscape with many different settings, it will never be "easy" to find a particular option, but it will be less difficult if all are placed under a common structure.
ocp.configure(
    Options(
        async_options=ocp.Options.AsyncOptions(timeout=60),
        array_options=ocp.Options.ArrayOptions(
            concurrent_bytes=1e9,
            ocdbt_options=ocp.Options.ArrayOptions.OcdbtOptions(use_ocdbt=False),
            # Cast everything to bfloat16.
            storage_options=ocp.Options.ArrayOptions.StorageOptions(dtype=bfloat16),
        ),
    )
)

# Alternatively:
ocp.configure(
    Options({
        'async_options': {'timeout': 60},
        'array_options': {
            'concurrent_bytes': 1e9,
            'ocdbt_options': {'use_ocdbt': False},
            'storage_options': {'dtype': bfloat16},
        },
    })
)
It is important to distinguish settings from options that unlock commonly-used operations, like model surgery, partial loading, partial updating, and forced overwrite. These operations are core functionalities that users often wish to enable or disable. As such, they should be located in the signature of the function they are used in (e.g. partial_load: bool in load, and force: bool in save). These options can still be settings in Options, to enable global defaults, but can also be exposed directly in save or load as local overrides.
Checkpointing in a Training Loop
In the existing library, the division of labor between Checkpointer and CheckpointManager has not always been well understood, as users often conceptualize the two terms interchangeably. Now, however, the user-facing API is oriented around free functions (save / load), so we can have a single Checkpointer class that behaves much as the current CheckpointManager does.
In the long run, we will aim to achieve a level of composability that allows users to effectively write their own implementation of Checkpointer with minimal additional code. The Checkpointer itself should be a Protocol with a rigid interface; we will be resistant to adding new features without substantial discussion and agreement that the proposed feature is a core element of checkpointing in a training loop.
Checkpointer will live under orbax.checkpoint.training to make explicit its intended use in training loops. Users with greater customization requirements will be encouraged to use lower-level APIs.
The save/load interface should mirror the free functions almost identically.
class Checkpointer:

  def __init__(
      self,
      root_directory: epath.PathLike | RootDirectoryFormat,
      *,
      # Defaults to continuous checkpointing.
      save_decision_policy: SaveDecisionPolicy | None = None,
      # Defaults to preserving the latest few. See design.
      preservation_policy: PreservationPolicy | None = None,
      step_name_format: NameFormat[step_lib.Metadata] | None = None,
      metric_comparator: MetricComparator | None = None,
      # Defaults to async deletion.
      deletion_options: DeletionOptions | None = None,
      # Context.
      context: Context | None = None,
      …
  ):
    …

  @property
  def directory(self) -> Path:
    …

  def latest_step(self) -> StepInfo:
    …

  def steps(self) -> Sequence[StepInfo]:
    …

  def save_pytree(...): …
  def save_checkpointables(...): …
  def save_pytree_async(...): …
  def save_checkpointables_async(...): …
  def load_pytree(...): …
  def load_checkpointables(...): …
  def load_pytree_async(...): …
  def load_checkpointables_async(...): …

  def metadata(self, step: int | None = None) -> RootMetadata | CheckpointMetadata:
    """Retrieves root-level or checkpoint-level metadata, depending on `step`."""
    …

  def reload(self):
    """Reloads internal properties from the root directory."""
    …
Checkpointable
"Checkpointable" is a core concept, at least for any Orbax user beyond a beginner level. As described in the overview, a checkpointable is a logical piece of the checkpoint that is distinct in some way from other pieces: checkpointables are separable, may or may not be loaded concurrently, may be omitted entirely, and often have different types and on-disk representations (the quintessential example being model params vs. dataset).
While the V0 API drew a distinction between "items" and "PyTree leaves", this distinction was unnecessary. A jax.Array, str, or scalar is "checkpointable" in the same way that a PyTree composed of these objects is.
We can introduce a Protocol to represent this concept, called Checkpointable, which defines the methods needed to save and load the object.
T = TypeVar('T')
AbstractT = TypeVar('AbstractT')
class Checkpointable(Protocol[T, AbstractT]):

  async def save(self, directory: Path) -> AsyncResponse[None]:
    …

  @classmethod
  async def load(
      cls, directory: Path, abstract_checkpointable: AbstractT | None = None
  ) -> AsyncResponse[T]:
    …

  async def metadata(self, directory: Path) -> AsyncResponse[AbstractT]:
    …
When I have a certain object that requires customized logic for serialization, I can easily define the saving and loading logic associated with that object.
class MyObject:
  …  # Some properties

  async def save(self, directory: Path) -> AsyncResponse[None]:
    …

  @classmethod
  async def load(
      cls,
      directory: Path,
      abstract_checkpointable: AbstractMyObject | None = None,
  ) -> AsyncResponse[MyObject]:
    …
Implementing the Checkpointable interface should not be necessary in most cases, as most users simply want to save an object like an array or a PyTree. Even in the case of custom, user-defined PyTree objects, Checkpointable should rarely be needed.
Furthermore, it is important to handle the case where a single PyTree may be saved in multiple different ways. This is common when writing format converters (e.g. Roc -> Orbax, PyTorch -> Orbax, etc.). In these cases, the checkpointable object is the same, but the checkpointing logic is different.
For these cases, CheckpointableHandler makes more sense, as it provides checkpointing logic for a recognizable type and can be swapped in and out as needed.
T = TypeVar('T')
AbstractT = TypeVar('AbstractT')
class CheckpointableHandler(Protocol[T, AbstractT]):

  async def save(
      …
  ) -> AsyncResponse[None]:
    …

  async def load(
      …
  ) -> AsyncResponse[T]:
    …

  async def metadata(self, directory: epath.Path) -> AsyncResponse[AbstractT]:
    …

  def is_handleable(self, checkpointable: T | AbstractT) -> bool:
    """Given any object, determines whether it can be stored with this handler."""
    …
Here are some concrete examples:
class PyTreeHandler(CheckpointableHandler[PyTree, PyTree]):

  async def save(self, path: Path, checkpointable: PyTree) -> AsyncResponse[None]:
    per_leaf_handlers = collect_per_leaf_checkpointable_handlers(checkpointable)
    save_responses = []
    for handler_type, leaf in per_leaf_handlers:
      # Construct the per-leaf handler.
      save_responses.append(await handler_type().save(path, leaf))
    tree_metadata_response = await save_tree_metadata(path, checkpointable)
    # Include finalize behavior in this response.
    return UnifiedResponse(
        *save_responses,
        tree_metadata_response,
    )

# Sometimes checkpointables are always restorable without an
# abstract checkpointable, in which case it may be None.
class DatasetHandler(CheckpointableHandler[tf.data.Iterator, None]):
  …

# For a singular np.ndarray handler, we can define the following abstract type:
class AbstractNumpyArray:

  @property
  def shape(self) -> tuple[int, …]:
    …

  @property
  def dtype(self) -> np.dtype:
    …

class NumpyHandler(CheckpointableHandler[np.ndarray, AbstractNumpyArray]):
  …
Determining which CheckpointableHandler can save/restore a checkpointable
When the user provides an object to save or restore, how can we determine which handler is appropriate to deal with this object? Ultimately, we do not have that many core CheckpointableHandlers. These include PyTree, JSON, Proto, and Array. Except for a JSON object, which is by definition a PyTree, all objects are easily distinguishable. Furthermore, a user generally doesn't care how their object is stored, as long as it can be stored successfully. Users seeking maximum performance will go hunting for the ideal handler of their own accord. Setting reasonable defaults thus satisfies both beginners and advanced users.
Each handler can define an is_handleable method that determines whether it is capable of storing the given object. When the user does not explicitly specify a handler, we check all globally-registered CheckpointableHandlers and select the first one capable of saving or restoring the object. Registration order matters, so we can ensure JsonHandler is always preferred over PyTreeHandler for sufficiently simple objects.
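The resolution logic can be sketched as a first-match scan over an ordered registry. The registry mechanism and the is_handleable implementations below are assumptions for illustration:

```python
class JsonHandler:
  """Accepts only sufficiently simple, JSON-compatible structures."""

  def is_handleable(self, checkpointable) -> bool:
    if isinstance(checkpointable, (str, int, float, bool, type(None))):
      return True
    if isinstance(checkpointable, (list, tuple)):
      return all(self.is_handleable(v) for v in checkpointable)
    if isinstance(checkpointable, dict):
      return all(self.is_handleable(v) for v in checkpointable.values())
    return False


class PyTreeHandler:
  """Accepts any nested container (array-leaf checks elided for brevity)."""

  def is_handleable(self, checkpointable) -> bool:
    return isinstance(checkpointable, (dict, list, tuple))


# Registration order matters: JsonHandler is consulted before PyTreeHandler.
_REGISTRY = [JsonHandler(), PyTreeHandler()]


def resolve_handler(checkpointable):
  """Returns the first registered handler capable of storing the object."""
  for handler in _REGISTRY:
    if handler.is_handleable(checkpointable):
      return handler
  raise ValueError(f'No registered handler for {type(checkpointable)}')
```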
PyTree Leaf Handlers
Leaf = TypeVar('Leaf')
AbstractLeaf = TypeVar('AbstractLeaf')

@dataclasses.dataclass
class SerializationParam(Generic[Leaf]):
  name: str
  keypath: jax.tree.KeyPath
  value: Leaf

@dataclasses.dataclass
class SerializationContext:
  path: epath.Path

@dataclasses.dataclass
class DeserializationParam(Generic[AbstractLeaf]):
  name: str
  value: AbstractLeaf | None = None

@dataclasses.dataclass
class DeserializationContext:
  path: epath.Path

class LeafHandler(Protocol[Leaf, AbstractLeaf]):

  async def serialize(
      self,
      params: list[SerializationParam[Leaf]],
      context: SerializationContext,
  ) -> AsyncResponse[None]:
    ...

  async def deserialize(
      self,
      params: list[DeserializationParam[AbstractLeaf]],
      context: DeserializationContext,
  ) -> AsyncResponse[list[Leaf]]:
    …

  async def metadata(
      self,
      params: list[DeserializationParam[AbstractLeaf]],
      context: DeserializationContext,
  ) -> AsyncResponse[list[AbstractLeaf]]:
    …

  def finalize(self):
    …
class ArrayHandler(LeafHandler[jax.Array, jax.ShapeDtypeStruct]):
  …

class AbstractNumpyArray:

  @property
  def shape(self) -> tuple[int, …]:
    …

  @property
  def dtype(self) -> np.dtype:
    …

class NumpyHandler(LeafHandler[np.ndarray, AbstractNumpyArray]):
  …
Customizing per-leaf save behavior
An argument to PyTreeHandler that allows easily setting per-leaf behaviors is array_storage_options_creator. This is just a function that is applied to the input PyTree via jax.tree.map_with_path and returns an ArrayStorageOptions struct, which contains a number of per-leaf settings. These options are only relevant to arrays, and are only applied to appropriate leaves.
@dataclasses.dataclass
class ArrayStorageOptions:
  # Cast a leaf when saving.
  dtype: jnp.dtype | None = None
  # Specify a target size for storage chunks.
  chunk_byte_size: int | None = None
  # Specify axes to prioritize for subchunking.
  shard_axes: tuple[int, …] = tuple()

class ArrayStorageOptionsCreator(Protocol):
  """Creates arguments to customize per-leaf saving behavior.

  The function is called by `PyTreeHandler` using::

    jax.tree.map_with_path(storage_options_creator, checkpointable)

  The user may provide a function that returns `ArrayStorageOptions`, which
  will then be applied to each leaf while saving.
  """

  def __call__(self, key: jax.tree.KeyPath, value: Any) -> ArrayStorageOptions:
    …
Eliminating RestoreArgs, ArrayRestoreArgs, etc.
When Orbax was first created, there was no notion that every leaf type had a corresponding abstract type that could be used to restore it. (jax.Array was not yet solidified as a concept, and jax.ShapeDtypeStruct existed, but sharding did not really exist yet.) As such, RestoreArgs was introduced to capture restoration arguments relevant to a particular leaf.
Now, rather than needing to know and understand an entirely new set of classes, the user only needs to understand the rather intuitive idea that every concrete leaf type has a corresponding abstract type, that conveys properties without storing real data.
For standard types, these include:
jax.Array -> jax.ShapeDtypeStruct
np.ndarray -> AbstractNumpyArray # Uses duck typing
int -> int
float -> float
bytes -> bytes
str -> str
The abstract type itself conveys the desired restoration type, if the user wishes to convert from jax.Array to np.ndarray, for example. The desired restoration shape, dtype, and sharding are also conveyed. The only other per-leaf restoration parameter used in Orbax today is strict, which controls whether padding and truncation are allowed. In practice, however, this is always a setting used at a global level, not a per-leaf level.
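The concrete-to-abstract correspondence can be sketched with duck typing: anything exposing shape/dtype maps to an abstract array, while scalars and strings stand for themselves. The AbstractNumpyArray here is a simplified stand-in for the duck-typed class above, and as_abstract is a hypothetical helper:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class AbstractNumpyArray:
  # Shape and dtype without any real data (simplified stand-in).
  shape: tuple
  dtype: str


def as_abstract(leaf):
  """Maps a concrete leaf to its corresponding abstract type (sketch)."""
  if hasattr(leaf, 'shape') and hasattr(leaf, 'dtype'):
    return AbstractNumpyArray(tuple(leaf.shape), str(leaf.dtype))
  if isinstance(leaf, (int, float, bytes, str)):
    return leaf  # Scalars and strings are their own abstract type.
  raise TypeError(f'No abstract type for {type(leaf)}')
```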
What if we need to add additional per-leaf restoration options in the future?
Any such option is likely to be highly specialized, as no such need has been revealed after multiple years of Orbax development. It is always possible in this case to introduce a new LeafHandler with a different abstract type that carries additional properties.