@@ -516,12 +516,6 @@ def checkpoint(
516
516
kwargs : dict
517
517
dictionary of string keys for keyword arguments to :attr:`function`.
518
518
"""
519
- only_tensor_args = True
520
- for arg in args :
521
- if not isinstance (arg , torch .Tensor ):
522
- only_tensor_args = False
523
- break
524
-
525
519
# Pop out te.distributed.checkpoint() arguments
526
520
global _USE_REENTRANT_ACTIVATION_RECOMPUTE
527
521
_USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs .pop ("use_reentrant" , True )
@@ -530,23 +524,14 @@ def checkpoint(
530
524
get_rng_state_tracker = kwargs .pop ("get_rng_state_tracker" , None )
531
525
532
526
# Ensure backward compatibility.
533
- if not only_tensor_args :
527
+ if (len (args ) > 3 and isinstance (args [0 ], bool ) and callable (args [1 ])
528
+ and isinstance (args [2 ], None | dist_group_type )):
534
529
warnings .warn (
535
530
"Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
536
531
"future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
537
532
"`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`." ,
538
533
DeprecationWarning , stacklevel = 2 ,
539
534
)
540
- assert len (args ) > 3 , "Incorrect number of arguments for deprecated `checkpoint` API."
541
- assert (
542
- isinstance (args [0 ], bool ) and callable (args [1 ])
543
- and isinstance (args [2 ], None | dist_group_type )
544
- ), "Incorrect arguments for deprecated `checkpoint` API."
545
- for arg in args [3 :]:
546
- assert (
547
- isinstance (arg , None | torch .Tensor )
548
- ), f"Expected tensor argument, found { type (arg )} ."
549
-
550
535
distribute_saved_activations , get_rng_state_tracker , tp_group = args [:3 ] # pylint: disable=unbalanced-tuple-unpacking
551
536
args = args [3 :]
552
537
0 commit comments