Skip to content

Commit 6a9edc3

Browse files
committed
[PyTorch] Fix backward compatibility for checkpoint API (#748)
* Args can be None Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix other arg types Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
1 parent 35a8754 commit 6a9edc3

File tree

1 file changed

+2
-17
lines changed

1 file changed

+2
-17
lines changed

transformer_engine/pytorch/distributed.py

+2-17
Original file line numberDiff line numberDiff line change
@@ -516,12 +516,6 @@ def checkpoint(
516516
kwargs : dict
517517
dictionary of string keys for keyword arguments to :attr:`function`.
518518
"""
519-
only_tensor_args = True
520-
for arg in args:
521-
if not isinstance(arg, torch.Tensor):
522-
only_tensor_args = False
523-
break
524-
525519
# Pop out te.distributed.checkpoint() arguments
526520
global _USE_REENTRANT_ACTIVATION_RECOMPUTE
527521
_USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
@@ -530,23 +524,14 @@ def checkpoint(
530524
get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)
531525

532526
# Ensure backward compatibility.
533-
if not only_tensor_args:
527+
if (len(args) > 3 and isinstance(args[0], bool) and callable(args[1])
528+
and isinstance(args[2], None | dist_group_type)):
534529
warnings.warn(
535530
"Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
536531
"future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
537532
"`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
538533
DeprecationWarning, stacklevel=2,
539534
)
540-
assert len(args) > 3, "Incorrect number of arguments for deprecated `checkpoint` API."
541-
assert (
542-
isinstance(args[0], bool) and callable(args[1])
543-
and isinstance(args[2], None | dist_group_type)
544-
), "Incorrect arguments for deprecated `checkpoint` API."
545-
for arg in args[3:]:
546-
assert (
547-
isinstance(arg, None | torch.Tensor)
548-
), f"Expected tensor argument, found {type(arg)}."
549-
550535
distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3] # pylint: disable=unbalanced-tuple-unpacking
551536
args = args[3:]
552537

0 commit comments

Comments
 (0)