
Commit 8657ec3

adding an option to enable single replica restore + broadcast option in maxdiffusion
1 parent e4b4205 commit 8657ec3

8 files changed (+136, -50 lines)

src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py

Lines changed: 4 additions & 4 deletions

@@ -224,12 +224,12 @@ def config_to_json(model_or_config):
         "scheduler_config": ocp.args.JsonSave(config_to_json(pipeline.scheduler)),
     }

-    items["unet_state"] = ocp.args.StandardSave(train_states["unet_state"])
-    items["vae_state"] = ocp.args.StandardSave(train_states["vae_state"])
-    items["text_encoder_state"] = ocp.args.StandardSave(train_states["text_encoder_state"])
+    items["unet_state"] = ocp.args.PyTreeSave(train_states["unet_state"])
+    items["vae_state"] = ocp.args.PyTreeSave(train_states["vae_state"])
+    items["text_encoder_state"] = ocp.args.PyTreeSave(train_states["text_encoder_state"])

     if hasattr(pipeline, "text_encoder_2"):
-      items["text_encoder_2_state"] = ocp.args.StandardSave(train_states["text_encoder_2_state"])
+      items["text_encoder_2_state"] = ocp.args.PyTreeSave(train_states["text_encoder_2_state"])
       items["text_encoder_2_config"] = ocp.args.JsonSave(config_to_json(pipeline.text_encoder_2.config))

     tokenizer_config = {"path": self.config.tokenizer_model_name_or_path}
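
For context, a minimal standalone sketch (not part of this commit, and assuming a recent orbax-checkpoint release that provides the Composite/args API) of the PyTreeSave call the hunk above switches to; the directory and item name are placeholders:

import numpy as np
import orbax.checkpoint as ocp

state = {"w": np.ones(2), "b": np.zeros(2)}  # toy stand-in for a train state pytree

mngr = ocp.CheckpointManager("/tmp/maxdiffusion_ckpt_demo")  # hypothetical directory
mngr.save(0, args=ocp.args.Composite(unet_state=ocp.args.PyTreeSave(state)))
mngr.wait_until_finished()  # block until the (possibly async) save completes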

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 78 additions & 8 deletions
@@ -18,7 +18,11 @@
 """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""

 from typing import Optional, Any
+import jax
+import numpy as np
 import os
+
+import orbax.checkpoint
 from maxdiffusion import max_logging
 from etils import epath
 from flax.training import train_state
@@ -90,7 +94,7 @@ def load_stable_diffusion_configs(
 ):
   f"""
   Loads Orbax configurations for different stable diffusion models
-
+
   Args:
   checkpoint_manager (`orbax.checkpoint.checkpoint_manager`)
   checkpoint_type (`str`) : use sd or sdxl
@@ -140,8 +144,37 @@ def load_params_from_path(
   return restored["params"]


+def _find_idx(array: np.ndarray, replica_axis_idx: int):
+  """Returns the index along given dimension that the current host belongs to."""
+  idx = None
+  for idx, val in np.ndenumerate(array):
+    if val.process_index == jax.process_index():
+      break
+  return idx[replica_axis_idx]
+
+
+def _replica_devices(device_array: np.ndarray, replica_axis_idx: int):
+  """Returns the devices from the replica that current host belongs to.
+
+  Replicas are assumed to be restricted to the first axis.
+
+  Args:
+    device_array: devices of the mesh that can be obtained by mesh.devices()
+    replica_axis_idx: axis dimension along which replica is taken
+
+  Returns:
+    devices inside the replica that current host is in
+  """
+  idx = _find_idx(device_array, replica_axis_idx)
+  replica_result = np.take(device_array, idx, axis=replica_axis_idx)
+  return np.expand_dims(replica_result, axis=replica_axis_idx)
+
+
 def load_state_if_possible(
-    checkpoint_manager: CheckpointManager, abstract_unboxed_pre_state: train_state.TrainState, checkpoint_item: str
+    checkpoint_manager: CheckpointManager,
+    abstract_unboxed_pre_state: train_state.TrainState,
+    checkpoint_item: str,
+    enable_single_replica_ckpt_restoring: bool,
 ):
   """Loads TrainState as possible from the inputs.

@@ -151,6 +184,8 @@ def load_state_if_possible(
     abstract_unboxed_pre_state: an unboxed, abstract TrainState that Orbax
       matches type against.
     checkpoint_item: the name of the checkpoint item that is being loaded. Ex: vae_state
+    enable_single_replica_ckpt_restoring: bool flag for restoring checkpoint
+      with SingleReplicaArrayHandler

   Returns:
     A tuple of (train_state, train_state_params) where full_train_state captures
@@ -167,9 +202,44 @@ def load_state_if_possible(
     return None
   else:
     max_logging.log(f"restoring from this run's directory latest step {latest_step}")
-    try:
-      item = {checkpoint_item: orbax.checkpoint.args.StandardRestore(item=abstract_unboxed_pre_state)}
-      return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
-    except:
-      max_logging.log(f"could not load {checkpoint_item} from orbax")
-      return None
+    # try:
+    if True:
+      if not enable_single_replica_ckpt_restoring:
+        item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
+        return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
+
+      def map_to_pspec(data):
+        pspec = data.sharding.spec
+        mesh = data.sharding.mesh
+        if not enable_single_replica_ckpt_restoring:
+          return ocp.type_handlers.ArrayRestoreArgs(mesh=mesh, mesh_axes=pspec)
+        replica_axis_index = 0
+        replica_devices = _replica_devices(mesh.devices, replica_axis_index)
+        replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names)
+        single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec)
+
+        return ocp.type_handlers.SingleReplicaArrayRestoreArgs(
+            sharding=jax.sharding.NamedSharding(mesh, pspec),
+            single_replica_sharding=single_replica_sharding,
+            global_shape=data.shape,
+            dtype=data.dtype,
+        )
+
+      array_handler = ocp.type_handlers.SingleReplicaArrayHandler(
+          replica_axis_index=0,
+          broadcast_memory_limit_bytes=1024 * 1024 * 1000,  # 1000 MB limit
+      )
+      ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)
+
+      restore_args = jax.tree_util.tree_map(
+          map_to_pspec,
+          abstract_unboxed_pre_state,
+      )
+      item = {checkpoint_item: ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)}
+      return checkpoint_manager.restore(
+          latest_step,
+          args=orbax.checkpoint.args.Composite(**item)
+      )
+    # except:
+    #   max_logging.log(f"could not load {checkpoint_item} from orbax")
+    #   return None
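
To make the helper functions above concrete, here is a toy illustration (not part of the commit) of what _replica_devices computes, using plain integers in place of jax devices on a hypothetical 2x4 device mesh whose replicas live on axis 0:

import numpy as np

device_array = np.arange(8).reshape(2, 4)  # two replicas, four devices each
replica_axis_idx = 0
idx = 1  # suppose _find_idx determined that this host lives in replica 1

replica = np.take(device_array, idx, axis=replica_axis_idx)  # -> array([4, 5, 6, 7])
replica = np.expand_dims(replica, axis=replica_axis_idx)     # -> shape (1, 4)
print(replica)  # [[4 5 6 7]]

The expand_dims call keeps the replica axis in place, so the result can be passed to jax.sharding.Mesh with the same axis names as the full mesh, as the restore path above does.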

src/maxdiffusion/configs/base14.yml

Lines changed: 10 additions & 8 deletions
Apart from whitespace-only cleanup of existing comment and config lines (and a newline fix at the end of the file), the change adds the new flag directly below checkpoint_every:

@@ -144,6 +144,8 @@ enable_data_shuffling: True

 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
+# enables one replica to read the ckpt then broadcast to the rest
+enable_single_replica_ckpt_restoring: False

 # Prepare image latents and text encoder outputs
 # during dataset creation to reduce memory consumption.
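
A hypothetical call-site sketch (not taken from this commit; the surrounding names such as config, checkpoint_manager, and abstract_state are assumptions) of how a checkpointer might thread the new YAML flag into load_state_if_possible; only the function, flag, and item names come from the diff above:

from maxdiffusion.checkpointing.checkpointing_utils import load_state_if_possible

def restore_unet_state(config, checkpoint_manager, abstract_state):
  restored = load_state_if_possible(
      checkpoint_manager,
      abstract_unboxed_pre_state=abstract_state,
      checkpoint_item="unet_state",
      enable_single_replica_ckpt_restoring=config.enable_single_replica_ckpt_restoring,
  )
  # load_state_if_possible returns None when nothing could be restored;
  # the Composite restore result is indexed by checkpoint item name.
  return None if restored is None else restored["unet_state"]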

src/maxdiffusion/configs/base21.yml

Lines changed: 10 additions & 8 deletions
As in base14.yml, the remaining hunks are whitespace-only cleanup; the substantive change adds the same flag:

@@ -146,6 +146,8 @@ enable_data_shuffling: True

 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
+# enables one replica to read the ckpt then broadcast to the rest
+enable_single_replica_ckpt_restoring: False

 # Prepare image latents and text encoder outputs
 # during dataset creation to reduce memory consumption.

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 10 additions & 8 deletions
Again, aside from whitespace-only cleanup, the file gains the same flag:

@@ -159,6 +159,8 @@ enable_data_shuffling: True

 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
+# enables one replica to read the ckpt then broadcast to the rest
+enable_single_replica_ckpt_restoring: False

 # Prepare image latents and text encoder outputs
 # during dataset creation to reduce memory consumption.

src/maxdiffusion/configs/base_xl.yml

Lines changed: 9 additions & 7 deletions
The pattern repeats here: whitespace-only cleanup plus the new flag:

@@ -147,6 +147,8 @@ enable_data_shuffling: True

 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
+# enables one replica to read the ckpt then broadcast to the rest
+enable_single_replica_ckpt_restoring: False

 # Prepare image latents and text encoder outputs
 # during dataset creation to reduce memory consumption.
