@@ -280,10 +280,34 @@ def train_loop(config, elastic_manager, state=None):
280
280
# the step is restored back to the latest snapshot when a slice is lost
281
281
while step < config .steps :
282
282
try :
283
+ elastic_manager .maybe_snapshot (
284
+ step = step ,
285
+ snapshot_jax_arrays = {
286
+ "params" : state .params ,
287
+ "opt_state" : state .opt_state ,
288
+ },
289
+ block = True ,
290
+ )
291
+
283
292
if (config .elastic_mode == "fast-resume" and
284
293
elastic_manager .good_slice_count < elastic_manager .total_slice_count ):
285
294
wait_for_all_slices (elastic_manager , config .elastic_wait_period )
286
295
296
+ ret = elastic_manager .maybe_reshard_up (
297
+ step = step ,
298
+ snapshot_jax_arrays = {
299
+ "params" : state .params ,
300
+ "opt_state" : state .opt_state ,
301
+ },
302
+ elastic_handler = elastic_handler ,
303
+ handler_kwargs = {
304
+ "config" : config ,
305
+ "elastic_manager" : elastic_manager ,
306
+ "checkpoint_manager" : checkpoint_manager ,
307
+ },
308
+ )
309
+
310
+ if ret is not None :
287
311
(
288
312
config ,
289
313
step ,
@@ -297,19 +321,7 @@ def train_loop(config, elastic_manager, state=None):
297
321
metric_logger ,
298
322
writer ,
299
323
input_data_shardings ,
300
- ) = elastic_manager .maybe_reshard_up (
301
- step = step ,
302
- snapshot_jax_arrays = {
303
- "params" : state .params ,
304
- "opt_state" : state .opt_state ,
305
- },
306
- elastic_handler = elastic_handler ,
307
- handler_kwargs = {
308
- "config" : config ,
309
- "elastic_manager" : elastic_manager ,
310
- "checkpoint_manager" : checkpoint_manager ,
311
- },
312
- )
324
+ ) = ret
313
325
step += 1
314
326
315
327
if step == first_profiling_step or prof .should_activate_periodic_profile (step ):
@@ -354,44 +366,6 @@ def train_loop(config, elastic_manager, state=None):
354
366
if step == last_profiling_step or prof .should_deactivate_periodic_profile (step ):
355
367
prof .deactivate (blocking_object = state )
356
368
357
- elastic_manager .maybe_snapshot (
358
- step = step ,
359
- snapshot_jax_arrays = {
360
- "params" : state .params ,
361
- "opt_state" : state .opt_state ,
362
- },
363
- block = True ,
364
- )
365
-
366
- ret = elastic_manager .maybe_reshard_up (
367
- step = step ,
368
- snapshot_jax_arrays = {
369
- "params" : state .params ,
370
- "opt_state" : state .opt_state ,
371
- },
372
- elastic_handler = elastic_handler ,
373
- handler_kwargs = {
374
- "config" : config ,
375
- "elastic_manager" : elastic_manager ,
376
- "checkpoint_manager" : checkpoint_manager ,
377
- },
378
- )
379
-
380
- if ret is not None :
381
- (
382
- config ,
383
- step ,
384
- state ,
385
- mesh ,
386
- checkpoint_manager ,
387
- data_iterator ,
388
- p_train_step ,
389
- example_batch ,
390
- learning_rate_schedule ,
391
- metric_logger ,
392
- writer ,
393
- input_data_shardings ,
394
- ) = ret
395
369
396
370
if step == start_step :
397
371
max_utils .print_mem_stats ("After params initialized" )
0 commit comments