@@ -280,6 +280,38 @@ def train_loop(config, elastic_manager, state=None):
280
280
# the step is restored back to the latest snapshot when a slice is lost
281
281
while step < config .steps :
282
282
try :
283
+ if (config .elastic_mode == "fast-resume" and
284
+ elastic_manager .good_slice_count < elastic_manager .total_slice_count ):
285
+ wait_for_all_slices (elastic_manager , config .elastic_wait_period )
286
+
287
+ (
288
+ config ,
289
+ step ,
290
+ state ,
291
+ mesh ,
292
+ checkpoint_manager ,
293
+ data_iterator ,
294
+ p_train_step ,
295
+ example_batch ,
296
+ learning_rate_schedule ,
297
+ metric_logger ,
298
+ writer ,
299
+ input_data_shardings ,
300
+ ) = elastic_manager .maybe_reshard_up (
301
+ step = step ,
302
+ snapshot_jax_arrays = {
303
+ "params" : state .params ,
304
+ "opt_state" : state .opt_state ,
305
+ },
306
+ elastic_handler = elastic_handler ,
307
+ handler_kwargs = {
308
+ "config" : config ,
309
+ "elastic_manager" : elastic_manager ,
310
+ "checkpoint_manager" : checkpoint_manager ,
311
+ },
312
+ )
313
+ step += 1
314
+
283
315
if step == first_profiling_step or prof .should_activate_periodic_profile (step ):
284
316
optional_postfix = f"step_{ step } " if config .profile_periodically_period > 0 else ""
285
317
prof .activate (blocking_object = state , optional_postfix = optional_postfix )
@@ -344,6 +376,7 @@ def train_loop(config, elastic_manager, state=None):
344
376
"checkpoint_manager" : checkpoint_manager ,
345
377
},
346
378
)
379
+
347
380
if ret is not None :
348
381
(
349
382
config ,
@@ -366,7 +399,20 @@ def train_loop(config, elastic_manager, state=None):
366
399
step += 1
367
400
368
401
except jax .errors .JaxRuntimeError as error :
369
- ret = elastic_manager .maybe_reshard_down (
402
+ (
403
+ config ,
404
+ step ,
405
+ state ,
406
+ mesh ,
407
+ checkpoint_manager ,
408
+ data_iterator ,
409
+ p_train_step ,
410
+ example_batch ,
411
+ learning_rate_schedule ,
412
+ metric_logger ,
413
+ writer ,
414
+ input_data_shardings ,
415
+ ) = elastic_manager .maybe_reshard_down (
370
416
error = error ,
371
417
elastic_handler = elastic_handler ,
372
418
handler_kwargs = {
@@ -375,21 +421,6 @@ def train_loop(config, elastic_manager, state=None):
375
421
"checkpoint_manager" : checkpoint_manager ,
376
422
},
377
423
)
378
- if ret is not None :
379
- (
380
- config ,
381
- step ,
382
- state ,
383
- mesh ,
384
- checkpoint_manager ,
385
- data_iterator ,
386
- p_train_step ,
387
- example_batch ,
388
- learning_rate_schedule ,
389
- metric_logger ,
390
- writer ,
391
- input_data_shardings ,
392
- ) = ret
393
424
394
425
if checkpoint_manager is not None :
395
426
if (int (state .step ) - 1 ) % config .checkpoint_period != 0 :
@@ -426,16 +457,21 @@ def train_loop(config, elastic_manager, state=None):
426
457
return state
427
458
428
459
429
def wait_for_all_slices(
    elastic_manager: manager.Manager,
    wait_period: int,
) -> set[int]:
  """Blocks until every slice is reported as available.

  Polls `elastic_manager.get_slice_availability()` and sleeps between polls
  until the number of returned indices reaches
  `elastic_manager.total_slice_count`. This function does not assign the
  result to the manager; the caller decides what to do with the returned
  indices.

  Args:
    elastic_manager: The elastic manager used to query slice availability.
    wait_period: Number of seconds to sleep between availability polls.

  Returns:
    The slice indices returned by the final availability poll.
  """
  good_slice_indices = elastic_manager.get_slice_availability()
  while len(good_slice_indices) < elastic_manager.total_slice_count:
    # Log the count from the poll just taken. The poll result is only held
    # locally here, so `elastic_manager.good_slice_count` is presumably not
    # refreshed by this loop and could report a stale number.
    max_logging.log(
        f"Only {len(good_slice_indices)} slices out of {elastic_manager.total_slice_count} available. "
        f"Sleeping for {wait_period} seconds."
    )
    time.sleep(wait_period)
    good_slice_indices = elastic_manager.get_slice_availability()

  max_logging.log("All slices are available")
  return good_slice_indices
439
475
440
476
441
477
def elastic_initialize (devices : Sequence [jax .Device ]) -> manager .Manager :
@@ -447,17 +483,11 @@ def elastic_initialize(devices: Sequence[jax.Device]) -> manager.Manager:
447
483
Returns:
448
484
The initialized elastic manager
449
485
"""
450
- elastic_manager = manager .Manager (
451
- devices ,
452
- reshard_check_period = 1 ,
453
- snapshot_period = 5 ,
454
- max_elastic_down_event_count = 100 ,
455
- max_reshard_retry_count = 3 ,
456
- )
486
+ elastic_manager = manager .Manager (devices )
457
487
458
488
# Do not start training until all slices are available
459
489
# TODO: b/408455557 - Migrate to pathwaysutils and make configurable
460
- wait_for_all_slices (elastic_manager )
490
+ elastic_manager . good_slice_indices = wait_for_all_slices (elastic_manager , 30 )
461
491
462
492
pyconfig .HyperParameters .global_batch_size_to_train_on = property (
463
493
lambda self : elastic_manager .scale_by_good_slices (self .get_keys ()["global_batch_size_to_train_on" ])
@@ -472,6 +502,14 @@ def elastic_initialize(devices: Sequence[jax.Device]) -> manager.Manager:
472
502
473
503
return elastic_manager
474
504
505
def elastic_configure(
    config: pyconfig.HyperParameters,
    elastic_manager: manager.Manager,
):
  """Copies the elastic settings from the config onto the elastic manager.

  Each manager attribute listed below is set from the config attribute of the
  same name prefixed with `elastic_`.

  Args:
    config: The hyperparameters holding the `elastic_*` settings.
    elastic_manager: The elastic manager whose attributes are overwritten.
  """
  manager_attributes = (
      "reshard_check_period",
      "snapshot_period",
      "max_elastic_down_event_count",
      "max_reshard_retry_count",
  )
  for attribute in manager_attributes:
    setattr(elastic_manager, attribute, getattr(config, f"elastic_{attribute}"))
475
513
476
514
def main (argv : Sequence [str ]) -> None :
477
515
pathwaysutils .initialize ()
@@ -486,6 +524,9 @@ def main(argv: Sequence[str]) -> None:
486
524
elastic_manager = elastic_initialize (jax .devices ())
487
525
488
526
config = pyconfig .initialize (argv )
527
+
528
+ elastic_configure (config , elastic_manager )
529
+
489
530
max_utils .print_system_information ()
490
531
validate_train_config (config )
491
532
os .environ ["TFDS_DATA_DIR" ] = config .dataset_path or ""
0 commit comments