
Commit ce3a1d7

Updates to support fast resume. Working in v5e-32. Code can be simplified later
1 parent 54ab493 commit ce3a1d7

File tree

3 files changed: +115 −31 lines changed

  MaxText/configs/base.yml
  MaxText/elastic_train.py
  MaxText/pyconfig.py

MaxText/configs/base.yml

Lines changed: 10 additions & 1 deletion
@@ -736,4 +736,13 @@ projector_output_dim_for_vit: 4096
 rope_theta_for_vit: 10000
 vision_output_dim_for_vit: 4096
 pixel_shuffle_ratio_for_vit: 0.5
-projector_dropout_for_vit: 0.0
+projector_dropout_for_vit: 0.0
+
+
+## Elastic training flags
+elastic_mode: "fast-resume"
+elastic_reshard_check_period: 1
+elastic_snapshot_period: 5
+elastic_max_elastic_down_event_count: 100
+elastic_max_reshard_retry_count: 3
+elastic_wait_period: 30
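
These flags make the elastic behaviour configurable instead of hard-coded. A minimal launch sketch, assuming the usual MaxText convention that any base.yml key can be overridden as a key=value argument to pyconfig.initialize; the import path and file paths below are illustrative:

# Hypothetical launch sketch: override the new elastic flags per run.
# Assumes MaxText's key=value override convention for pyconfig.initialize.
from MaxText import pyconfig  # import path is an assumption

argv = [
    "MaxText/elastic_train.py",
    "MaxText/configs/base.yml",
    "elastic_mode=fast-resume",
    "elastic_snapshot_period=5",
    "elastic_wait_period=30",
]
config = pyconfig.initialize(argv)
print(config.elastic_mode, config.elastic_snapshot_period, config.elastic_wait_period)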

MaxText/elastic_train.py

Lines changed: 71 additions & 30 deletions
@@ -280,6 +280,38 @@ def train_loop(config, elastic_manager, state=None):
   # the step is restored back to the latest snapshot when a slice is lost
   while step < config.steps:
     try:
+      if (config.elastic_mode == "fast-resume" and
+          elastic_manager.good_slice_count < elastic_manager.total_slice_count):
+        wait_for_all_slices(elastic_manager, config.elastic_wait_period)
+
+        (
+            config,
+            step,
+            state,
+            mesh,
+            checkpoint_manager,
+            data_iterator,
+            p_train_step,
+            example_batch,
+            learning_rate_schedule,
+            metric_logger,
+            writer,
+            input_data_shardings,
+        ) = elastic_manager.maybe_reshard_up(
+            step=step,
+            snapshot_jax_arrays={
+                "params": state.params,
+                "opt_state": state.opt_state,
+            },
+            elastic_handler=elastic_handler,
+            handler_kwargs={
+                "config": config,
+                "elastic_manager": elastic_manager,
+                "checkpoint_manager": checkpoint_manager,
+            },
+        )
+        step += 1
+
       if step == first_profiling_step or prof.should_activate_periodic_profile(step):
         optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
         prof.activate(blocking_object=state, optional_postfix=optional_postfix)
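
The block above is the heart of fast resume: whenever a slice is missing, the loop blocks in wait_for_all_slices and then reshards back up to the full mesh in a single maybe_reshard_up call. A tiny self-contained sketch of just the entry condition, using an invented stub in place of manager.Manager:

# Toy illustration of the fast-resume entry condition; StubManager is invented.
from dataclasses import dataclass

@dataclass
class StubManager:
    """Invented stand-in for manager.Manager; only the two counters matter here."""
    good_slice_count: int
    total_slice_count: int

def should_block_and_reshard_up(elastic_mode: str, mgr: StubManager) -> bool:
    # Mirrors the condition added at the top of the try block.
    return elastic_mode == "fast-resume" and mgr.good_slice_count < mgr.total_slice_count

print(should_block_and_reshard_up("fast-resume", StubManager(3, 4)))     # True: wait, then maybe_reshard_up
print(should_block_and_reshard_up("fast-resume", StubManager(4, 4)))     # False: all slices healthy
print(should_block_and_reshard_up("replica-resize", StubManager(3, 4)))  # False: this new path is fast-resume only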
@@ -344,6 +376,7 @@ def train_loop(config, elastic_manager, state=None):
              "checkpoint_manager": checkpoint_manager,
          },
      )
+
      if ret is not None:
        (
            config,
@@ -366,7 +399,20 @@ def train_loop(config, elastic_manager, state=None):
       step += 1

     except jax.errors.JaxRuntimeError as error:
-      ret = elastic_manager.maybe_reshard_down(
+      (
+          config,
+          step,
+          state,
+          mesh,
+          checkpoint_manager,
+          data_iterator,
+          p_train_step,
+          example_batch,
+          learning_rate_schedule,
+          metric_logger,
+          writer,
+          input_data_shardings,
+      ) = elastic_manager.maybe_reshard_down(
           error=error,
           elastic_handler=elastic_handler,
           handler_kwargs={
@@ -375,21 +421,6 @@ def train_loop(config, elastic_manager, state=None):
               "checkpoint_manager": checkpoint_manager,
           },
       )
-      if ret is not None:
-        (
-            config,
-            step,
-            state,
-            mesh,
-            checkpoint_manager,
-            data_iterator,
-            p_train_step,
-            example_batch,
-            learning_rate_schedule,
-            metric_logger,
-            writer,
-            input_data_shardings,
-        ) = ret

   if checkpoint_manager is not None:
     if (int(state.step) - 1) % config.checkpoint_period != 0:
@@ -426,16 +457,21 @@ def train_loop(config, elastic_manager, state=None):
   return state


-def wait_for_all_slices(elastic_manager: manager.Manager) -> None:
-  elastic_manager.good_slice_indices = elastic_manager.get_slice_availability()
-  while len(elastic_manager.good_slice_indices) < elastic_manager.total_slice_count:
+def wait_for_all_slices(
+    elastic_manager: manager.Manager,
+    wait_period: int,
+) -> set[int]:
+  good_slice_indices = elastic_manager.get_slice_availability()
+  while len(good_slice_indices) < elastic_manager.total_slice_count:
     max_logging.log(
         f"Only {elastic_manager.good_slice_count} slices out of {elastic_manager.total_slice_count} available. "
-        "Sleeping for 5 seconds."
+        f"Sleeping for {wait_period} seconds."
     )
-    time.sleep(5)
-    elastic_manager.good_slice_indices = elastic_manager.get_slice_availability()
+    time.sleep(wait_period)
+    good_slice_indices = elastic_manager.get_slice_availability()
+
   max_logging.log("All slices are available")
+  return good_slice_indices


 def elastic_initialize(devices: Sequence[jax.Device]) -> manager.Manager:
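
Because wait_for_all_slices now takes the poll period as an argument and returns the good slice indices instead of mutating the manager, the polling loop is easy to exercise on its own. A self-contained sketch with a stub manager (the stub, its slice counts, and the rejoin schedule are invented for illustration; the real manager comes from the pathwaysutils elastic machinery):

import time

class StubManager:
    """Invented stand-in for manager.Manager: one slice rejoins on the third poll."""
    total_slice_count = 4

    def __init__(self):
        self._polls = 0

    def get_slice_availability(self):
        self._polls += 1
        return {0, 1, 2, 3} if self._polls >= 3 else {0, 1, 2}

    @property
    def good_slice_count(self):
        return 4 if self._polls >= 3 else 3

def wait_for_all_slices_demo(elastic_manager, wait_period):
    # Same shape as the rewritten wait_for_all_slices, with max_logging swapped for print.
    good_slice_indices = elastic_manager.get_slice_availability()
    while len(good_slice_indices) < elastic_manager.total_slice_count:
        print(f"Only {elastic_manager.good_slice_count} of "
              f"{elastic_manager.total_slice_count} slices available. "
              f"Sleeping for {wait_period} seconds.")
        time.sleep(wait_period)
        good_slice_indices = elastic_manager.get_slice_availability()
    print("All slices are available")
    return good_slice_indices

print(wait_for_all_slices_demo(StubManager(), wait_period=0.1))  # -> {0, 1, 2, 3}

Note that elastic_initialize below still calls the real function with a hard-coded 30-second period; elastic_configure cannot change that particular wait because it runs before the config is parsed.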
@@ -447,17 +483,11 @@ def elastic_initialize(devices: Sequence[jax.Device]) -> manager.Manager:
   Returns:
     The initialized elastic manager
   """
-  elastic_manager = manager.Manager(
-      devices,
-      reshard_check_period=1,
-      snapshot_period=5,
-      max_elastic_down_event_count=100,
-      max_reshard_retry_count=3,
-  )
+  elastic_manager = manager.Manager(devices)

   # Do not start training until all slices are available
   # TODO: b/408455557 - Migrate to pathwaysutils and make configurable
-  wait_for_all_slices(elastic_manager)
+  elastic_manager.good_slice_indices = wait_for_all_slices(elastic_manager, 30)

   pyconfig.HyperParameters.global_batch_size_to_train_on = property(
       lambda self: elastic_manager.scale_by_good_slices(self.get_keys()["global_batch_size_to_train_on"])
@@ -472,6 +502,14 @@ def elastic_initialize(devices: Sequence[jax.Device]) -> manager.Manager:

   return elastic_manager

+def elastic_configure(
+    config: pyconfig.HyperParameters,
+    elastic_manager: manager.Manager,
+):
+  elastic_manager.reshard_check_period = config.elastic_reshard_check_period
+  elastic_manager.snapshot_period = config.elastic_snapshot_period
+  elastic_manager.max_elastic_down_event_count = config.elastic_max_elastic_down_event_count
+  elastic_manager.max_reshard_retry_count = config.elastic_max_reshard_retry_count

 def main(argv: Sequence[str]) -> None:
   pathwaysutils.initialize()
@@ -486,6 +524,9 @@ def main(argv: Sequence[str]) -> None:
   elastic_manager = elastic_initialize(jax.devices())

   config = pyconfig.initialize(argv)
+
+  elastic_configure(config, elastic_manager)
+
   max_utils.print_system_information()
   validate_train_config(config)
   os.environ["TFDS_DATA_DIR"] = config.dataset_path or ""
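
Taken together, main() now has a fixed ordering: the elastic manager is created, and all slices are awaited, before the config exists, so the elastic_* flags can only be applied to the manager afterwards via elastic_configure. A minimal startup sketch of that ordering; the import paths and argv are illustrative:

# Hypothetical startup sketch mirroring main(): manager first, then config, then configure.
import jax
import pathwaysutils
from MaxText import pyconfig                                              # import path is an assumption
from MaxText.elastic_train import elastic_initialize, elastic_configure  # import path is an assumption

pathwaysutils.initialize()
elastic_manager = elastic_initialize(jax.devices())   # blocks until all slices are up (30 s polls)
config = pyconfig.initialize(["elastic_train.py", "MaxText/configs/base.yml"])
elastic_configure(config, elastic_manager)            # elastic_* flags take effect from here on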

MaxText/pyconfig.py

Lines changed: 34 additions & 0 deletions
@@ -147,6 +147,37 @@ def validate_rope_type(rope_type: str) -> None:
     raise ValueError(f"Invalid RoPE type was passed. Got: {rope_type}. Valid options: {valid_rope_types}")


+def validate_elastic(
+    elastic_mode: str | None,
+    elastic_reshard_check_period: int | None,
+):
+  modes = {
+      "replica-resize",
+      "fast-resume",
+  }
+
+  if elastic_mode not in modes:
+    raise ValueError(f"{elastic_mode=} must be in {modes}")
+
+  if elastic_mode == "fast-resume" and elastic_reshard_check_period not in {None, 1}:
+    raise ValueError(
+        f"For {elastic_mode=}, {elastic_reshard_check_period=} must be None or 1"
+    )
+
+
+def get_elastic_defaults(keys) -> dict[str, str | int | None]:
+  elastic_defaults = {
+      "elastic_mode": "replica-resize",
+      "elastic_reshard_check_period": 1,
+      "elastic_snapshot_period": 1,
+      "elastic_max_elastic_down_event_count": None,
+      "elastic_max_reshard_retry_count": None,
+      "elastic_wait_period": 30,
+  }
+
+  return {k: v for k, v in elastic_defaults.items() if k not in keys}
+
+
 def validate_keys(keys):
   validate_attention_kernel(keys["attention"])
   validate_attention_type(keys["attention_type"])
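
Assuming the two helpers are importable from pyconfig (the import path below is an assumption), their intended behaviour is easy to check in isolation: defaults are only supplied for elastic keys the user did not set, and validate_elastic rejects unknown modes while pinning the reshard check period for fast-resume:

from MaxText import pyconfig  # import path is an assumption

# Defaults only fill in keys the user did not set (mirrors raw_keys |= get_elastic_defaults(raw_keys)).
raw_keys = {"elastic_mode": "fast-resume"}
raw_keys |= pyconfig.get_elastic_defaults(raw_keys)
assert raw_keys["elastic_mode"] == "fast-resume"        # user-provided value preserved
assert raw_keys["elastic_reshard_check_period"] == 1    # default supplied
assert raw_keys["elastic_wait_period"] == 30            # default supplied

pyconfig.validate_elastic("fast-resume", 1)       # OK: fast-resume requires a check period of None or 1
pyconfig.validate_elastic("replica-resize", 10)   # OK: no restriction on the check period

try:
    pyconfig.validate_elastic("not-a-mode", None)
except ValueError as err:
    print(err)  # unknown mode is rejected with the "must be in {modes}" message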
@@ -160,6 +191,7 @@ def validate_keys(keys):
   validate_model_call_mode(keys["model_call_mode"])
   validate_prefill_and_target_lengths(keys["max_prefill_predict_length"], keys["max_target_length"])
   validate_rope_type(keys["rope_type"])
+  validate_elastic(keys["elastic_mode"], keys["elastic_reshard_check_period"])

   assert (keys["load_parameters_path"] == "" and keys["load_full_state_path"] == "") or keys[
       "enable_checkpointing"
@@ -592,6 +624,8 @@ def user_init(raw_keys):

     raw_keys["decoder_block"] = DecoderBlockType(raw_keys["decoder_block"])

+    raw_keys |= get_elastic_defaults(raw_keys)
+
   @staticmethod
   def configure_gpt3_task(raw_keys):
     """dynamically configure gpt3 task based on training rules"""
