Skip to content

Commit 8cda6c9

Browse files
committed
Simplified code
1 parent ce3a1d7 commit 8cda6c9

File tree

1 file changed

+25 -51 lines changed

1 file changed

+25
-51
lines changed

MaxText/elastic_train.py

Lines changed: 25 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -280,10 +280,34 @@ def train_loop(config, elastic_manager, state=None):
280280
# the step is restored back to the latest snapshot when a slice is lost
281281
while step < config.steps:
282282
try:
283+
elastic_manager.maybe_snapshot(
284+
step=step,
285+
snapshot_jax_arrays={
286+
"params": state.params,
287+
"opt_state": state.opt_state,
288+
},
289+
block=True,
290+
)
291+
283292
if (config.elastic_mode == "fast-resume" and
284293
elastic_manager.good_slice_count < elastic_manager.total_slice_count):
285294
wait_for_all_slices(elastic_manager, config.elastic_wait_period)
286295

296+
ret = elastic_manager.maybe_reshard_up(
297+
step=step,
298+
snapshot_jax_arrays={
299+
"params": state.params,
300+
"opt_state": state.opt_state,
301+
},
302+
elastic_handler=elastic_handler,
303+
handler_kwargs={
304+
"config": config,
305+
"elastic_manager": elastic_manager,
306+
"checkpoint_manager": checkpoint_manager,
307+
},
308+
)
309+
310+
if ret is not None:
287311
(
288312
config,
289313
step,
@@ -297,19 +321,7 @@ def train_loop(config, elastic_manager, state=None):
297321
metric_logger,
298322
writer,
299323
input_data_shardings,
300-
) = elastic_manager.maybe_reshard_up(
301-
step=step,
302-
snapshot_jax_arrays={
303-
"params": state.params,
304-
"opt_state": state.opt_state,
305-
},
306-
elastic_handler=elastic_handler,
307-
handler_kwargs={
308-
"config": config,
309-
"elastic_manager": elastic_manager,
310-
"checkpoint_manager": checkpoint_manager,
311-
},
312-
)
324+
) = ret
313325
step += 1
314326

315327
if step == first_profiling_step or prof.should_activate_periodic_profile(step):
@@ -354,44 +366,6 @@ def train_loop(config, elastic_manager, state=None):
354366
if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
355367
prof.deactivate(blocking_object=state)
356368

357-
elastic_manager.maybe_snapshot(
358-
step=step,
359-
snapshot_jax_arrays={
360-
"params": state.params,
361-
"opt_state": state.opt_state,
362-
},
363-
block=True,
364-
)
365-
366-
ret = elastic_manager.maybe_reshard_up(
367-
step=step,
368-
snapshot_jax_arrays={
369-
"params": state.params,
370-
"opt_state": state.opt_state,
371-
},
372-
elastic_handler=elastic_handler,
373-
handler_kwargs={
374-
"config": config,
375-
"elastic_manager": elastic_manager,
376-
"checkpoint_manager": checkpoint_manager,
377-
},
378-
)
379-
380-
if ret is not None:
381-
(
382-
config,
383-
step,
384-
state,
385-
mesh,
386-
checkpoint_manager,
387-
data_iterator,
388-
p_train_step,
389-
example_batch,
390-
learning_rate_schedule,
391-
metric_logger,
392-
writer,
393-
input_data_shardings,
394-
) = ret
395369

396370
if step == start_step:
397371
max_utils.print_mem_stats("After params initialized")

0 commit comments

Comments (0)