1818"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
1919
2020from typing import Optional , Any
21+ import jax
22+ import numpy as np
2123import os
24+
25+ import orbax .checkpoint
2226from maxdiffusion import max_logging
2327from etils import epath
2428from flax .training import train_state
@@ -90,7 +94,7 @@ def load_stable_diffusion_configs(
9094):
9195 f"""
9296 Loads Orbax configurations for different stable diffusion models
93-
97+
9498 Args:
9599 checkpoint_manager (`orbax.checkpoint.checkpoint_manager`)
96100 checkpoint_type (`str`) : use sd or sdxl
@@ -140,8 +144,37 @@ def load_params_from_path(
140144 return restored ["params" ]
141145
142146
def _find_idx(array: np.ndarray, replica_axis_idx: int):
  """Returns the index along the given axis that the current host belongs to.

  The array is scanned in `np.ndenumerate` order and the first entry whose
  `process_index` equals `jax.process_index()` determines the result.

  Args:
    array: array of device-like objects, each exposing a `process_index`
      attribute.
    replica_axis_idx: axis of `array` whose coordinate is returned.

  Returns:
    The coordinate along `replica_axis_idx` of the first entry that belongs to
    the current process.

  Raises:
    ValueError: if no entry of `array` belongs to the current process, which
      includes the case of an empty array.
  """
  current_process = jax.process_index()
  for position, device in np.ndenumerate(array):
    if device.process_index == current_process:
      return position[replica_axis_idx]
  # Falling through used to return the coordinate of the last device visited
  # (or fail on None for an empty array), which silently picked a replica the
  # current host is not part of.
  raise ValueError(
      f"No device owned by process {current_process} was found in the device array; "
      "cannot determine which replica the current host belongs to."
  )
154+
155+
def _replica_devices(device_array: np.ndarray, replica_axis_idx: int):
  """Selects the slice of devices forming the replica of the current host.

  The replica is the slice of `device_array` taken along `replica_axis_idx` at
  the position reported by `_find_idx` for the current host.

  Args:
    device_array: array of mesh devices (for example the `devices` attribute
      of a mesh).
    replica_axis_idx: axis of `device_array` along which replicas are laid out.

  Returns:
    An array with the same number of dimensions as `device_array`, whose
    `replica_axis_idx` axis has length 1 and holds only the devices of the
    replica the current host is in.
  """
  host_replica = _find_idx(device_array, replica_axis_idx)
  # np.take with a scalar index drops the replica axis; expand_dims puts it
  # back with length 1 so the result keeps the rank of the input.
  return np.expand_dims(
      np.take(device_array, host_replica, axis=replica_axis_idx),
      axis=replica_axis_idx,
  )
171+
172+
143173def load_state_if_possible (
144- checkpoint_manager : CheckpointManager , abstract_unboxed_pre_state : train_state .TrainState , checkpoint_item : str
174+ checkpoint_manager : CheckpointManager ,
175+ abstract_unboxed_pre_state : train_state .TrainState ,
176+ checkpoint_item : str ,
177+ enable_single_replica_ckpt_restoring : bool ,
145178):
146179 """Loads TrainState as possible from the inputs.
147180
@@ -151,6 +184,8 @@ def load_state_if_possible(
151184 abstract_unboxed_pre_state: an unboxed, abstract TrainState that Orbax
152185 matches type against.
153186 checkpoint_item: the name of the checkpoint item that is being loaded. Ex: vae_state
187+ enable_single_replica_ckpt_restoring: bool flag for restoring checkpoints
188+ with SingleReplicaArrayHandler
154189
155190 Returns:
156191 A tuple of (train_state, train_state_params) where full_train_state captures
@@ -167,9 +202,44 @@ def load_state_if_possible(
167202 return None
168203 else :
169204 max_logging .log (f"restoring from this run's directory latest step { latest_step } " )
170- try :
171- item = {checkpoint_item : orbax .checkpoint .args .StandardRestore (item = abstract_unboxed_pre_state )}
172- return checkpoint_manager .restore (latest_step , args = orbax .checkpoint .args .Composite (** item ))
173- except :
174- max_logging .log (f"could not load { checkpoint_item } from orbax" )
175- return None
205+ # try:
206+ if True :
207+ if not enable_single_replica_ckpt_restoring :
208+ item = {checkpoint_item : orbax .checkpoint .args .PyTreeRestore (item = abstract_unboxed_pre_state )}
209+ return checkpoint_manager .restore (latest_step , args = orbax .checkpoint .args .Composite (** item ))
210+
211+ def map_to_pspec (data ):
212+ pspec = data .sharding .spec
213+ mesh = data .sharding .mesh
214+ if not enable_single_replica_ckpt_restoring :
215+ return ocp .type_handlers .ArrayRestoreArgs (mesh = mesh , mesh_axes = pspec )
216+ replica_axis_index = 0
217+ replica_devices = _replica_devices (mesh .devices , replica_axis_index )
218+ replica_mesh = jax .sharding .Mesh (replica_devices , mesh .axis_names )
219+ single_replica_sharding = jax .sharding .NamedSharding (replica_mesh , pspec )
220+
221+ return ocp .type_handlers .SingleReplicaArrayRestoreArgs (
222+ sharding = jax .sharding .NamedSharding (mesh , pspec ),
223+ single_replica_sharding = single_replica_sharding ,
224+ global_shape = data .shape ,
225+ dtype = data .dtype ,
226+ )
227+
228+ array_handler = ocp .type_handlers .SingleReplicaArrayHandler (
229+ replica_axis_index = 0 ,
230+ broadcast_memory_limit_bytes = 1024 * 1024 * 1000 , # 1000 MB limit
231+ )
232+ ocp .type_handlers .register_type_handler (jax .Array , array_handler , override = True )
233+
234+ restore_args = jax .tree_util .tree_map (
235+ map_to_pspec ,
236+ abstract_unboxed_pre_state ,
237+ )
238+ item = {checkpoint_item : ocp .args .PyTreeRestore (item = abstract_unboxed_pre_state , restore_args = restore_args )}
239+ return checkpoint_manager .restore (
240+ latest_step ,
241+ args = orbax .checkpoint .args .Composite (** item )
242+ )
243+ # except:
244+ # max_logging.log(f"could not load {checkpoint_item} from orbax")
245+ # return None
0 commit comments